API proposal: fast difference-of-products

Motivation
Expressions of the form a*b - c*d or a*b + c*d are extremely common and are notorious sources of cancellation error in floating-point arithmetic. Some well-known examples include 2x2 matrix determinants and the discriminant term in the quadratic formula. It's quite easy to construct ill-posed examples (for determinants, when the rows/columns are nearly collinear; for discriminants, when the roots of the quadratic are close together) for which essentially all accuracy is lost if a naive computation via two multiplies and an add, or a multiply and an FMA, is used.

Kahan's lecture note goes into these motivating examples in considerable detail, for anyone interested in a more complete treatment of the problem.

Proposed Solution
There is a simple algorithm from Kahan that eliminates the problem entirely with almost zero performance cost. Kahan was particularly interested in 2x2 determinants for the purposes of finding singular values, though he wrote his note in terms of quadratics, but the algorithm is more generally useful; Matt Pharr calls it "difference of products," which I find to be a pretty good name (even though it's also applicable to sums of products). In pseudocode, the algorithm is as follows:

// given a, b, c, d, compute ab - cd
let (head, tail) = Augmented.product(-c, d)
let result = head.addingProduct(a, b) + tail

Instead of two multiplies and an add, this requires two FMAs and a multiply, so the performance impact of adopting it is negligible on modern targets, and it completely solves the problem; instead of essentially arbitrary relative error, this achieves an error bound of 1.5 ulp for binary floating-point types.¹
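For concreteness, here is how the algorithm might be written out using only the standard library's fused `addingProduct(_:_:)` (an illustrative sketch; the function name is mine, not the proposed API):

```swift
/// ab - cd via Kahan's algorithm (illustrative sketch, not the proposed API).
func kahanDifferenceOfProducts(_ a: Double, _ b: Double,
                               _ c: Double, _ d: Double) -> Double {
    let head = -c * d                        // rounded value of -cd
    let tail = (-head).addingProduct(-c, d)  // exact rounding error: (-cd) - head
    return head.addingProduct(a, b) + tail   // fma(a, b, head) + tail
}

// A nearly-singular 2x2 determinant: the exact answer is 100, but the
// naive two-multiplies-and-a-subtract loses everything to cancellation.
let n = 1.0e16
let naive = n * n - (n + 10) * (n - 10)                       // 0.0
let kahan = kahanDifferenceOfProducts(n, n, n + 10, n - 10)   // 100.0
```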

I propose to make this algorithm available in the Augmented enum of RealModule; a rough sketch is up as Pull Request #329 on apple/swift-numerics. As a strawman API for development purposes, I have bound it as:

extension Augmented {
  /// ab - cd
  static func fastDifferenceOfProducts<T: FloatingPoint>(
    _ a: T, _ b: T, _ c: T, _ d: T
  ) -> T
}

Alternatives

  • Naming: I generally like Matt Pharr's name vs "determinant" or similar, because it suggests that it can be used for any difference of products, but I'm open to other suggestions. In particular, "difference of products" suggests the argument ordering (a, b, c, d) for (ab - cd), but a "determinant" name might suggest (ad - bc), and other names would suggest something else. This is primarily of use as a building block for higher-level numerical algorithms, so I do not expect "normal" users to be faced with it often, but it's still worth getting right.
  • Naming (pt 2): Why "fast difference of products" instead of just "difference of products"? I have in mind leaving space for a more expensive sub-ulp accurate (or correctly-rounded) implementation, or an implementation that produces both head and tail. But maybe those would best be placed under a different name, and this should have the shorter name.
  • This algorithm does not handle overflow/underflow. That's OK, neither does Augmented.product. But it's worth keeping in mind with regard to API design.
  • Should there also be an API for ab + cd (or should that be the API)? There isn't one for head-tail a-b; users are expected to write sum(a, -b), but maybe that should be revisited.
  • Maybe Augmented should be reserved for algorithms that have head-tail results or parameters, and algorithms like this that merely use augmented arithmetic belong elsewhere.
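On the ab + cd question: the same two-FMA structure handles a sum of products with only sign changes. A hypothetical sketch (the name is mine, not a proposed binding):

```swift
/// ab + cd with the same two-FMA structure (hypothetical sketch).
func kahanSumOfProducts(_ a: Double, _ b: Double,
                        _ c: Double, _ d: Double) -> Double {
    let head = c * d                        // rounded value of cd
    let tail = (-head).addingProduct(c, d)  // exact rounding error: cd - head
    return head.addingProduct(a, b) + tail  // fma(a, b, head) + tail
}

// Heavy cancellation between ab and cd: the exact answer is 100.
let n = 1.0e16
let s = kahanSumOfProducts(n, n, -(n + 10), n - 10)   // 100.0
```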

¹ this is not obvious, and for a decade or so after publication people only remarked that the error was surprisingly good most of the time. Claude-Pierre Jeannerod, Nicolas Louvet, and Jean-Michel Muller published the clearest proof that I'm aware of.

12 Likes
Bikeshed color

Augmented.differenceOfProducts((a, b), (c, d)). Exact same ABI (I think), follows the more mathy conventions rather than using labeled arguments (though you could consider labeling the second tuple minus), but much quicker to guess the meaning for someone who isn’t familiar with this operation than the fully flat placeholder form.

9 Likes

I do think this should have the “best” name and one that doesn’t mention how it’s implemented with augmented arithmetic: while the underlying algorithm is very interesting to me, I suspect most people will want to use this API for an ideal tradeoff of performance and accuracy regardless of how it’s implemented.

It seems to me that the spelling most consistent with stdlib precedent, and the one users will most likely seek out, is:

(a, b).subtractingProduct(c, d)

It is too bad that we can’t currently extend tuples, but something along those lines would be nice.

1 Like

For the sake of discussion… perhaps this optimization (er, correctness fix) could be made at the compiler level?

No, sometimes users really want ab - cd and the compiler has to respect that, or else engineering is impossible.

4 Likes

For comparison purposes, do we know how computationally expensive each of those would be?

• • •

My first impression is that the name you’ve chosen is good, because the bare Augmented.differenceOfProducts sounds like it should produce both the head and tail.

I do like Jordan’s idea of passing the parameters as pairs though.

• • •

(One could also imagine a fanciful world filled with custom types and operator overloads where Augmented(a*b - c*d) would produce the correct head-tail result, and Augmented.fast(a*b - c*d) would use the algorithm proposed here. But that would defy common expectations and thus not belong in the Numerics library.)

In the past, I implemented the Kahan algorithm in my own code to improve numerical accuracy. It would be great to have this available as a built-in feature.

1 Like

Interesting feature. Am I using it right though?

let n = 1.0e20
let a = n + 10
let b = n
let c = n
let d = n

let det1 = a * b - c * d
print("Normal: \(det1)") // Normal: 0.0
let det2 = fma(a, b, -c*d)
print("fma: \(det2)") // fma: -3.037860284270037e+23

The true result here is nonzero, so the "normal" calculation is wrong.
OTOH the true result should obviously be positive, so the FMA result (being a giant negative number) is "more wrong". What am I doing wrong?

Nope.

First, add a line print(a == n) and you’ll see it’s true.

Change it to print(n.nextUp - n) and you’ll see that the next representable value above n is n + 16384.

Now update your code to use a = n + 16384 (or alternatively, use n = 1.0e16 so that a = n + 10 is representable).
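Those two checks, as a throwaway snippet:

```swift
let n = 1.0e20
print(n + 10 == n)   // true: 10 is less than half an ulp at this magnitude
print(n.nextUp - n)  // 16384.0: the gap between adjacent Doubles near 1e20
```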

But more importantly, your code does not match the proposed semantics here. Instead you need this:

let head = -c * d
let tail = fma(-c, d, -head)
let det3 = fma(a, b, head) + tail
print("Kahan: \(det3)")

(Technically you could remove both minus signs on line 2, and change the plus to a minus on line 3, and it would work the same. But I wrote it this way to match the proposed implementation.)

And even still, the numbers you chose have lots of trailing zeros, so they don’t actually demonstrate the benefits of this approach. Instead, try these values:

n = 1.0e16
a = n
b = n
c = n + 10
d = n - 10

Since c*d is a difference of squares, the mathematically correct value for a*b-c*d is exactly 100. If you run the code you’ll see these results:

det1 = 0.0
det2 = -5366162204393572.0
det3 = 100.0
4 Likes

Thank you!

On something you wrote above

I am thinking of having something like:

var r: Real = a*b - c*d + ...
var d: Double = Double(r) // does the truncation once

that works with a wrapper Real type and produces a "correct" result (working with a higher intermediate precision and truncating the result only at the end).

If you were to do something like that (outside the standard / numeric library), what would you use as an underlying "currency type" for intermediate operations at higher precision, so that all bits of simple expressions are preserved (considering expressions like a*b, but not necessarily higher-bit-count expressions like a*b*c)? As I understand it, a Double significand is 53 bits (52 explicitly stored), so to preserve all bits of a product of two numbers we'd need 106 bits (plus extra bits for the exponent). Is the following currency type looking reasonable and appropriate for the task (albeit quite inefficient)?

struct Real {
    var mantissa: Int128
    var exponent: Int32
}

You should look at compensated summation, which is also by Kahan. The idea is to keep track of the missing “tail” and re-add it at each step.
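In code, classic Kahan compensated summation might look like this (my own sketch, not library code):

```swift
/// Kahan compensated summation: track the rounding error of each
/// addition in `c` and feed it back into the next term.
func kahanSum(_ xs: [Double]) -> Double {
    var sum = 0.0
    var c = 0.0              // running compensation (previously lost low bits)
    for x in xs {
        let y = x - c        // re-inject the part lost on the last step
        let t = sum + y
        c = (t - sum) - y    // what was rounded away in computing t
        sum = t
    }
    return sum
}

// Four 1.0s vanish in a naive left-to-right sum next to 1e16, but survive here:
let xs = [1.0e16, 1.0, 1.0, 1.0, 1.0, -1.0e16]
// xs.reduce(0, +) == 0.0, while kahanSum(xs) == 4.0
```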

There are other tricks that can help, like matching magnitudes at each step, and trying to cause cancellation as early as possible.

Another thing that can help reduce errors, is to pair up the sums like a tournament rather than zipping down the line into a single accumulator. (This can be understood as recursively summing the two halves of the list separately, then adding the results.)
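The tournament pairing can be sketched as a simple recursion:

```swift
/// Pairwise (tournament) summation: recursively sum the two halves and
/// combine; worst-case error grows like O(log n) rather than O(n).
func pairwiseSum(_ xs: ArraySlice<Double>) -> Double {
    if xs.count <= 2 { return xs.reduce(0, +) }
    let mid = xs.startIndex + xs.count / 2
    return pairwiseSum(xs[..<mid]) + pairwiseSum(xs[mid...])
}

let total = pairwiseSum([1.5, 2.25, 3.0, 4.125][...])   // 10.875 (exact)
```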

If you value accuracy over speed, you could even do something like sort the terms by magnitude (perhaps in a priority queue) and repeatedly combine the two smallest terms.

In your use case, that might look like precomputing the (head, tail) pair for each product, and then going through the list either in arbitrary order, or in ascending order, or with a priority queue, and doing something like compensated summation along the way.

On the other hand, if you still want speed, you could just adapt the compensated summation method to incorporate the missing part of the tail from the fma into that from the sum.
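Putting those pieces together, a compensated dot product in the spirit of Ogita, Rump, and Oishi's "Dot2" might look like this (a sketch under the usual no-overflow/underflow assumptions; names are mine):

```swift
/// Compensated dot product: recover each product's exact tail with a fused
/// multiply-add and each addition's exact error with TwoSum, accumulate the
/// corrections separately, and apply them once at the end.
func compensatedDot(_ a: [Double], _ b: [Double]) -> Double {
    var sum = 0.0
    var comp = 0.0                           // accumulated corrections
    for (x, y) in zip(a, b) {
        let p = x * y                        // rounded product (head)
        let pErr = (-p).addingProduct(x, y)  // exact product tail: x*y - p
        // TwoSum: add p to sum and recover the addition's exact error.
        let t = sum + p
        let z = t - sum
        let sErr = (sum - (t - z)) + (p - z)
        sum = t
        comp += pErr + sErr
    }
    return sum + comp
}

// The 2x2 determinant from earlier in the thread, phrased as a dot product:
let det = compensatedDot([1.0e16, 1.0e16 + 10], [1.0e16, -(1.0e16 - 10)])  // 100.0
```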

It would be nice if computer systems provided a very wide floating-point accumulator (e.g. 256 or 512 bits) but I’m not aware of any.

2 Likes

There have been a bunch of proposals to do something along these lines over the years, but basically no one has been willing to pay the performance cost it would entail in practice. Most algorithms are happy with a backwards-stable error bound, and to get reproducible results, defining an accumulation order suffices; you do not need correct rounding.

If you’re interested in the academic literature on the subject, “Kulisch accumulator” is a good search term to start with for hardware, and ReproBLAS is a pretty good representative of the state of the art for software.

5 Likes

Turns out to be quite elegant.

MVP
import Foundation // for fma

indirect enum Real: ExpressibleByFloatLiteral, ExpressibleByIntegerLiteral, CustomStringConvertible {
    case number(Double)
    case mul(Self, Self)
    case sub(Self, Self)
    
    var description: String {
        switch self {
            case .number(let value): "\(value)"
            case .mul(let lhs, let rhs): "\(lhs) * \(rhs)"
            case .sub(let lhs, let rhs): "(\(lhs) - \(rhs))"
        }
    }
    
    init(floatLiteral: Double) { self = .number(floatLiteral) }
    init(integerLiteral: Int) { self = .number(Double(integerLiteral)) }
    init(_ v: Double) { self = .number(v) }
    
    static func * (lhs: Self, rhs: Self) -> Self {
        mul(lhs, rhs)
    }
    static func - (lhs: Self, rhs: Self) -> Self {
        sub(lhs, rhs)
    }
    
    var normalValue: Double {
        switch self {
            case .number(let value): value
            case .mul(let lhs, let rhs): lhs.normalValue * rhs.normalValue
            case .sub(let lhs, let rhs): lhs.normalValue - rhs.normalValue
        }
    }
    
    var fmaValue: Double {
        switch self {
            case .number(let value):
                return value
            case .mul(let lhs, let rhs):
                return lhs.fmaValue * rhs.fmaValue
            case .sub(let lhs, let rhs):
                switch (lhs, rhs) {
                    case (.mul(let a, let b), .mul(let c, let d)):
                        // Kahan's algorithm: negCD is the rounded -c*d, and
                        // fma(c, d, negCD) recovers its rounding error exactly.
                        let negCD = -c.fmaValue * d.fmaValue
                        return fma(a.fmaValue, b.fmaValue, negCD) - fma(c.fmaValue, d.fmaValue, negCD)
                    default:
                        return lhs.fmaValue - rhs.fmaValue
                }
        }
    }
}

extension Double {
    init(_ v: Real) { self = v.fmaValue }
}

Looks and feels (almost) like working with ordinary float values:

let a: Real = 10000000000
let b: Real = 10000000000
let c: Real = 10000010000
let d: Real = 09999990000
let r = a*b - c*d
print(r)  // (10000000000.0 * 10000000000.0 - 10000010000.0 * 9999990000.0)
print(r.normalValue) // 100007936.0
print(r.fmaValue)    // 100000000.0
let x = Double(a*b) - Double(c*d)    // 100007936.0
let y = Double(a*b - c*d)            // 100000000.0