Should we add a fused multiply-add?

When trying to write another numeric type, I drifted into thinking about how to make a bignum type using FixedWidthInteger. There are existing methods for full-width addition and multiplication. A bignum multiplication needs to cascade carries inside the multiplication loop, but for now I have to do a full multiplication and then a separate full addition. I remember that processors sometimes have a fused multiply-add instruction, so maybe we should provide something similar.

My example code uses extension methods, but the real deal should be new requirements on Numeric with default implementations given. The names need a lot of bikeshedding:

extension Numeric {

    /// Returns the sum of the first operand with the product of the other
    /// operands.
    ///
    /// The operation should incur less rounding or other precision loss than
    /// performing the multiplication and the addition as separate steps.
    static func fused(addingOf augend: Self, withProductOf multiplicand: Self, and multiplier: Self) -> Self {
        return augend + multiplicand * multiplier
    }

    /// Returns the difference between the first operand and the product of the
    /// other operands.
    ///
    /// The operation should incur less rounding or other precision loss than
    /// performing the multiplication and the subtraction as separate steps.
    static func fused(subtractionOf minuend: Self, againstProductOf multiplicand: Self, and multiplier: Self) -> Self {
        return minuend - multiplicand * multiplier
    }

    /// Returns the difference between the product of the first two operands and
    /// the third operand.
    ///
    /// The operation should incur less rounding or other precision loss than
    /// performing the multiplication and the subtraction as separate steps.
    static func fused(takingProductOf multiplicand: Self, and multiplier: Self, subtracting subtrahend: Self) -> Self {
        return multiplicand * multiplier - subtrahend
    }

    /// Increases this value by the product of the given operands.
    mutating func addProduct(of multiplicand: Self, and multiplier: Self) {
        self += multiplicand * multiplier
    }

    /// Decreases this value by the product of the given operands.
    mutating func subtractProduct(of multiplicand: Self, and multiplier: Self) {
        self -= multiplicand * multiplier
    }

    /// Multiplies this value by the first argument, then adds the second.
    mutating func multiply(by multiplier: Self, thenAdd addend: Self) {
        self *= multiplier
        self += addend
    }

    /// Multiplies this value by the first argument, then subtracts the second.
    mutating func multiply(by multiplier: Self, thenSubtract subtrahend: Self) {
        self *= multiplier
        self -= subtrahend
    }

    /// Replaces this value with the second argument minus the product of
    /// this value and the first argument.
    mutating func subtractProduct(with multiplier: Self, from minuend: Self) {
        self = minuend - self * multiplier
    }

}

We should probably do similar functions to FixedWidthInteger:

extension FixedWidthInteger {

    /// Returns a tuple containing the high and low parts of the result of
    /// multiplying this value by the first given value then adding the second
    /// given value to the product.
    func multipliedFullWidth(by multiplier: Self, thenAdd addend: Self) -> (high: Self, low: Self.Magnitude) {
        // I think this works....

        let partialProduct = multipliedFullWidth(by: multiplier)
        let newLow: Magnitude, newHigh: Self, overflow: Bool
        if addend >= 0 {
            let carry: Bool
            (newLow, carry) = partialProduct.low.addingReportingOverflow(addend.magnitude)
            (newHigh, overflow) = partialProduct.high.addingReportingOverflow(carry ? 1 : 0)
        } else {
            let borrow: Bool
            (newLow, borrow) = partialProduct.low.subtractingReportingOverflow(addend.magnitude)
            (newHigh, overflow) = partialProduct.high.subtractingReportingOverflow(borrow ? 1 : 0)
        }
        assert(!overflow)
        return (high: newHigh, low: newLow)
    }

    /// Returns a tuple containing the sign, high word, and low word of the
    /// result of multiplying this value by the first given value then
    /// subtracting the second given value from the product.
    func multipliedFullWidth(by multiplier: Self, thenSubtract subtrahend: Self) -> (negate: Bool, high: Self.Magnitude, low: Self.Magnitude) {
        // Figure out this one yourself.
        fatalError("not yet implemented")
    }

    /// Returns a tuple containing the sign, high word, and low word of the
    /// result of multiplying this value by the first given value then
    /// subtracting that product from the second given value.
    func multipliedFullWidth(by multiplier: Self, thenSubtractFrom minuend: Self) -> (negate: Bool, high: Self.Magnitude, low: Self.Magnitude) {
        // Figure out this one yourself.
        fatalError("not yet implemented")
    }

}
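The "I think this works" comment invites a quick check. A sketch, assuming an exhaustive comparison on a narrow type is convincing enough: the block below restates the same logic as multipliedFullWidth(by:thenAdd:) above (restructured slightly so each branch binds its results directly) and compares the (high, low) result against plain Int arithmetic for every Int8 and UInt8 multiplicand and multiplier, with a handful of boundary addends. The helper name checkMulAdd is hypothetical.

```swift
extension FixedWidthInteger {
    /// Same approach as the proposal: take the full-width product, then add the
    /// addend's magnitude (or subtract it, if negative), propagating carry/borrow.
    func multipliedFullWidth(by multiplier: Self, thenAdd addend: Self) -> (high: Self, low: Self.Magnitude) {
        let product = multipliedFullWidth(by: multiplier)
        if addend >= 0 {
            let (low, carry) = product.low.addingReportingOverflow(addend.magnitude)
            let (high, overflow) = product.high.addingReportingOverflow(carry ? 1 : 0)
            assert(!overflow)
            return (high: high, low: low)
        } else {
            let (low, borrow) = product.low.subtractingReportingOverflow(addend.magnitude)
            let (high, overflow) = product.high.subtractingReportingOverflow(borrow ? 1 : 0)
            assert(!overflow)
            return (high: high, low: low)
        }
    }
}

/// Hypothetical helper: rebuilds the two-word result as an Int and compares it
/// with the exact value a * b + c for every a, b of type T and each given addend.
func checkMulAdd<T: FixedWidthInteger>(_ type: T.Type, addends: [T]) -> Bool {
    for a in T.min ... T.max {
        for b in T.min ... T.max {
            for c in addends {
                let r = a.multipliedFullWidth(by: b, thenAdd: c)
                let got = (Int(r.high) << T.bitWidth) + Int(r.low)
                if got != Int(a) * Int(b) + Int(c) { return false }
            }
        }
    }
    return true
}

let signedOK = checkMulAdd(Int8.self, addends: [.min, -1, 0, 1, .max])
let unsignedOK = checkMulAdd(UInt8.self, addends: [0, 1, .max])
print(signedOK, unsignedOK)  // prints "true true"
```

A narrow type keeps the search small (256 × 256 products per addend) while still hitting the extreme cases, such as Int8.min * Int8.min and a negative addend that forces a borrow out of the low word.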

The purpose of making these protocol requirements is so that the standard integer types can transparently use assembly-level instructions where available.
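Why requirements rather than extension methods matters: only a protocol requirement is a customization point, so only a requirement lets a concrete type's specialized version be picked up inside generic code. A minimal sketch, using a hypothetical protocol MultiplyAddable (not a real standard library protocol) in which Double supplies a fused version through the real addingProduct(_:_:) and Int takes the default:

```swift
/// Hypothetical protocol standing in for "new requirements on Numeric".
protocol MultiplyAddable: Numeric {
    /// Requirement: dispatched through the conformance, so it can be customized.
    static func multiplyAdd(_ a: Self, _ b: Self, _ c: Self) -> (value: Self, path: String)
}

extension MultiplyAddable {
    /// Default implementation: a separate multiply, then an add.
    static func multiplyAdd(_ a: Self, _ b: Self, _ c: Self) -> (value: Self, path: String) {
        (a * b + c, "default")
    }

    /// Extension-only method: statically dispatched, so generic callers always
    /// get this body even when a conforming type declares its own version.
    static func multiplyAddExt(_ a: Self, _ b: Self, _ c: Self) -> (value: Self, path: String) {
        (a * b + c, "default")
    }
}

extension Double: MultiplyAddable {
    /// Specialized version: one operation with a single rounding.
    static func multiplyAdd(_ a: Double, _ b: Double, _ c: Double) -> (value: Double, path: String) {
        (c.addingProduct(a, b), "fused")
    }

    /// Shadows the extension method for concrete callers only.
    static func multiplyAddExt(_ a: Double, _ b: Double, _ c: Double) -> (value: Double, path: String) {
        (c.addingProduct(a, b), "fused")
    }
}

extension Int: MultiplyAddable {}  // uses the default implementation

/// Reports which implementation each entry point reaches from generic code.
func evaluate<T: MultiplyAddable>(_ a: T, _ b: T, _ c: T) -> (requirement: String, extensionOnly: String) {
    (T.multiplyAdd(a, b, c).path, T.multiplyAddExt(a, b, c).path)
}

let r = evaluate(2.0, 3.0, 1.0)
print(r.requirement, r.extensionOnly)  // prints "fused default"
```

The extension-only method silently falls back to the generic body for Double, which is exactly what would prevent the standard integer types from substituting a faster instruction.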

This seems like quite a micro-optimization. Is there a reason we can’t expect LLVM’s optimizer to generate this instruction automatically whenever it sees a multiply and add next to each other? Or is there some semantic difference between this and a standard multiply/add?


Minor terminology correction: that's not what a fused multiply-add is. Fused multiply-add (FMA) is strictly a floating-point operation, and refers to performing a multiply and an add without intermediate rounding. We already have that operation in the standard library, as c.addingProduct(a, b).
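To see the difference concretely: with a = 1 + 2^-30, the exact square 1 + 2^-29 + 2^-60 is not representable as a Double, so a * a rounds the 2^-60 term away. Computing -p + a * a as separate operations then gives 0, while addingProduct(_:_:) performs the multiply and add with a single rounding and recovers the lost term exactly. A minimal sketch, assuming the compiler does not contract the separate expression into an FMA on its own (Swift does not do this by default, to my knowledge):

```swift
let a = 1.0 + 0x1p-30                 // 1 + 2^-30, exactly representable
let p = a * a                         // rounded product: 1 + 2^-29 (2^-60 is lost)

let separate = -p + a * a             // multiply, round, then add
let fused = (-p).addingProduct(a, a)  // multiply and add, rounded once

print(separate == 0, fused == 0x1p-60)  // prints "true true"
```

This "exact error of a product" trick is the standard use of FMA in error-free transformations, and it only works because there is no intermediate rounding.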

A more general multiply-add is a perfectly reasonable thing to add, with a couple notes:

  • it should probably go on FixedWidthInteger, rather than Numeric[1].
  • the actual workhorse operation you want is what bignum libraries call muladd2 (a*b + c + d), with a default value of d = 0. This is the most basic building block you need to build a multi-word multiplication out of operations you already have, and the result never overflows a double-width type.

[1] It's not clear that you even need it on FixedWidthInteger; it's possible that you only really need it on the natural word type, but you can certainly define it for FixedWidthInteger.
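The no-overflow claim follows from a one-line bound. For unsigned w-bit words, each of a, b, c, d is at most 2^w - 1, so:

```latex
ab + c + d \le (2^w - 1)^2 + 2(2^w - 1) = 2^{2w} - 2^{w+1} + 1 + 2^{w+1} - 2 = 2^{2w} - 1
```

That is exactly the largest value a 2w-bit unsigned integer can hold, so the result always fits in a (high, low) word pair. With d = 0 the bound holds with room to spare.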

One more note: you don't need any magic builtin support for this; the optimizer handles the following primitives just fine:

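// MARK: - Single-word primitives

// Word is the natural machine word; UInt is assumed here.
typealias Word = UInt

/// Adds b, plus an incoming carry, to a in place and returns the outgoing carry.
@usableFromInline @_transparent
internal func add(_ a: inout Word, _ b: Word, carry c: Bool = false) -> Bool {
  a = a &+ b &+ (c ? 1 : 0)
  // The sum wrapped iff the result is below b (or equal to b when a carry came in).
  return a < b || a == b && c
}

/// Returns a*b + c + d as a (lo, hi) word pair; the result always fits.
@usableFromInline @_transparent
internal func muladd(_ a: Word, _ b: Word, _ c: Word = 0, _ d: Word = 0) -> (lo: Word, hi: Word) {
  var (hi, lo) = a.multipliedFullWidth(by: b)
  hi &+= add(&lo, c) ? 1 : 0
  hi &+= add(&lo, d) ? 1 : 0
  return (lo, hi)
}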

These get you most of what you need to build fast bignum arithmetic in release builds. The standard library could use some tricks to get better debug performance, though.
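As an illustration of how muladd composes into multi-word arithmetic, here is a minimal schoolbook multiplication sketch over little-endian arrays of UInt words. It restates a muladd with the same contract (using addingReportingOverflow(_:) instead of the add helper, so the block stands alone); the array representation and the name multiply(_:_:) are assumptions for illustration, not an existing API.

```swift
typealias Word = UInt

/// Returns a*b + c + d as a (lo, hi) pair; it can never overflow two words.
func muladd(_ a: Word, _ b: Word, _ c: Word = 0, _ d: Word = 0) -> (lo: Word, hi: Word) {
    let (high, low) = a.multipliedFullWidth(by: b)
    var hi = high
    var (lo, carry) = low.addingReportingOverflow(c)
    if carry { hi &+= 1 }
    (lo, carry) = lo.addingReportingOverflow(d)
    if carry { hi &+= 1 }
    return (lo, hi)
}

/// Schoolbook product of two little-endian magnitudes (least significant word first).
func multiply(_ x: [Word], _ y: [Word]) -> [Word] {
    var result = [Word](repeating: 0, count: x.count + y.count)
    for i in x.indices {
        var carry: Word = 0
        for j in y.indices {
            // x[i]*y[j] + result[i+j] + carry has the a*b + c + d shape, so it fits.
            let (lo, hi) = muladd(x[i], y[j], result[i + j], carry)
            result[i + j] = lo
            carry = hi
        }
        // This slot has not been written by any earlier row, so plain assignment is safe.
        result[i + y.count] = carry
    }
    return result
}

print(multiply([2, 3], [4]))                                // prints "[8, 12, 0]"
print(multiply([Word.max], [Word.max]) == [1, Word.max - 1])  // prints "true"
```

Every carry in the inner loop is absorbed by the next muladd call, which is the "cascade carries while in the multiplication loop" step from the original question.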


Would those optimizations be possible if this was tested out in the standard library preview package?

Doing significantly better requires guaranteeing codegen to the appropriate LLVM builtins, which can be accomplished either by always-inline C header shims or by using the stdlib Builtin module. It only really matters for debug builds, so it's not a huge deal.
