Operator overloading with generics

Hi guys, this is my first question here, so hopefully I'll structure it correctly.

As part of a series of homework assignments for numerical mathematics, I'm developing a linear algebra library, and I just stumbled on a problem.

I defined matrices abstractly as:

public protocol MatrixProtocol:
  ExpressibleByArrayLiteral,
  Equatable, 
  BidirectionalCollection
where Element == Vector
{
  associatedtype Value: BinaryFloatingPoint
  ...
}

As a design decision, I provided default implementations for core operators, like multiplication, so that every conforming type automatically has these capabilities.

public func *<M: MatrixProtocol>(_ m1: M, _ m2: M) -> Matrix<M.Value> {
  // slow O(n^3) basic implementation
}

The core data structure is Matrix, an n×m dense matrix backed by an array:

public struct Matrix<T: BinaryFloatingPoint>: MatrixProtocol {
  public typealias Value = T
  
  var buffer: [T]
  ...
}
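For reference, the slow O(n^3) fallback mentioned above can look something like the following. This is a minimal sketch, not the poster's actual code: the type name DenseMatrix and its rows/cols fields are hypothetical, and the buffer is assumed to be stored in row-major order.

```swift
// Hypothetical dense matrix type used only to illustrate the naive product.
struct DenseMatrix<T: BinaryFloatingPoint> {
    let rows: Int
    let cols: Int
    var buffer: [T]  // row-major: element (i, j) lives at buffer[i * cols + j]

    init(rows: Int, cols: Int, buffer: [T]) {
        precondition(buffer.count == rows * cols, "buffer size must be rows * cols")
        self.rows = rows
        self.cols = cols
        self.buffer = buffer
    }

    subscript(i: Int, j: Int) -> T {
        get { buffer[i * cols + j] }
        set { buffer[i * cols + j] = newValue }
    }

    // Naive O(n^3) product: C[i, j] = sum over k of A[i, k] * B[k, j].
    static func * (lhs: DenseMatrix, rhs: DenseMatrix) -> DenseMatrix {
        precondition(lhs.cols == rhs.rows, "inner dimensions must match")
        var result = DenseMatrix(rows: lhs.rows, cols: rhs.cols,
                                 buffer: Array(repeating: 0, count: lhs.rows * rhs.cols))
        for i in 0..<lhs.rows {
            for j in 0..<rhs.cols {
                var sum: T = 0
                for k in 0..<lhs.cols {
                    sum += lhs[i, k] * rhs[k, j]
                }
                result[i, j] = sum
            }
        }
        return result
    }
}

let a = DenseMatrix<Double>(rows: 2, cols: 2, buffer: [1, 2, 3, 4])
let b = DenseMatrix<Double>(rows: 2, cols: 2, buffer: [5, 6, 7, 8])
let c = a * b
print(c.buffer)  // [19.0, 22.0, 43.0, 50.0]
```

The triple loop is exactly the kind of kernel that the Accelerate overloads are meant to replace for Float and Double.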

Now I can provide optimized multiplication overloads using Accelerate:

import Accelerate

public func *(_ m1: Matrix<Double>, _ m2: Matrix<Double>) -> Matrix<Double> {
  vDSP_mmulD(...)
}

public func *(_ m1: Matrix<Float>, _ m2: Matrix<Float>) -> Matrix<Float> {
  vDSP_mmul(...)
}

Now, let me illustrate the problem:

Let's say I'm working on an algorithm that makes heavy use of matrix multiplication, for instance QR decomposition.

func QRDecompose<M: MatrixProtocol>(_ A: M) {
  // dummy implementation
  _ = A * A
}

let A: Matrix<Float> = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
QRDecompose(A)

Since A is known to be a Matrix<Float>, I would expect the compiler to select the optimized multiplication overload, but the default generic one is selected instead.
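The behaviour can be reproduced without any matrix code. In this minimal sketch, Shape, Square and describe are hypothetical stand-ins for MatrixProtocol, Matrix<Float> and the two * overloads; the overload is chosen at compile time from what is known about the argument's type at the call site.

```swift
protocol Shape {}
struct Square: Shape {}

// Analogue of the default generic * for any MatrixProtocol.
func describe<S: Shape>(_ s: S) -> String { "generic" }

// Analogue of the Accelerate-backed * for Matrix<Float>.
func describe(_ s: Square) -> String { "specialized" }

// Inside this function S is only known to conform to Shape, so the call
// below is resolved once, at compile time, to the generic overload.
func useGenerically<S: Shape>(_ s: S) -> String {
    describe(s)
}

let sq = Square()
print(describe(sq))        // prints specialized (static type is Square)
print(useGenerically(sq))  // prints generic (dynamic type is still Square)
```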

Is there anything I can do, besides writing two additional implementations of each algorithm for Matrix<Float> and Matrix<Double> inputs, which would lead to a lot of code duplication?

I hope I've stated my problem clearly.
Thanks

You can check the type at runtime inside *:

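public func *<M: MatrixProtocol>(_ m1: M, _ m2: M) -> Matrix<M.Value> {
  // Compare the scalar type: M itself is a matrix type, so M.self == Double.self is always false
  if M.Value.self == Double.self {
    vDSP_mmulD(...)
  } else if M.Value.self == Float.self {
    vDSP_mmul(...)
  } else {
    // slow O(n^3) basic implementation
  }
}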
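A self-contained sketch of this runtime check, assuming toy types: ToyMatrixProtocol, ToyMatrix, OtherMatrix, multiplyPath and qrLike are all hypothetical, and the fast paths only report which branch would run instead of calling vDSP. A conditional cast with as? performs the check and the conversion in one step, which also handles a different conforming type whose Value happens to be Double.

```swift
protocol ToyMatrixProtocol {
    associatedtype Value: BinaryFloatingPoint
}

struct ToyMatrix<T: BinaryFloatingPoint>: ToyMatrixProtocol {
    typealias Value = T
    var buffer: [T]
}

// A second conforming type whose scalar is also Double but has no buffer.
struct OtherMatrix: ToyMatrixProtocol {
    typealias Value = Double
}

// Reports which branch would run instead of actually multiplying.
func multiplyPath<M: ToyMatrixProtocol>(_ m1: M, _ m2: M) -> String {
    if let a = m1 as? ToyMatrix<Double>, let b = m2 as? ToyMatrix<Double> {
        // a.buffer and b.buffer are statically [Double] here, ready for vDSP_mmulD.
        _ = (a.buffer, b.buffer)
        return "Double fast path"
    } else if let a = m1 as? ToyMatrix<Float>, let b = m2 as? ToyMatrix<Float> {
        _ = (a.buffer, b.buffer)
        return "Float fast path"
    }
    return "generic path"
}

// A generic caller, like QRDecompose, still reaches the fast path at runtime.
func qrLike<M: ToyMatrixProtocol>(_ a: M) -> String {
    multiplyPath(a, a)
}

print(qrLike(ToyMatrix<Float>(buffer: [1, 2, 3, 4])))   // Float fast path
print(qrLike(ToyMatrix<Double>(buffer: [1, 2, 3, 4])))  // Double fast path
print(qrLike(OtherMatrix()))                            // generic path
```

The cast costs a runtime type check per call, which is negligible next to an O(n^3) multiplication, but every new fast-path type still means another branch in this one function.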

Thanks for answering.

It does solve the problem, though I still need to cast m1 and m2 to Matrix<Float> / Matrix<Double>, right? I don't think that should be a performance concern.

From an engineering perspective, though, I'm not sure I like this approach, because it requires modifying the function for every new matrix-like type that I add.
But I guess it will be good enough for now.

That's not how generics work. When compiling QRDecompose<M: MatrixProtocol>, the compiler only knows that M conforms to MatrixProtocol, so the generic implementation is used. If you want to get runtime dispatch based on the scalar type, another option is to create a protocol to bind the scalar type, and add the implementation hooks as protocol requirements:

// warning: I'm coding this in the forums text editor, so no guarantees =)
import Accelerate

protocol MatrixScalar: BinaryFloatingPoint {
  static func gemm(_ m1: Matrix<Self>, _ m2: Matrix<Self>) -> Matrix<Self>
}

extension Float: MatrixScalar {
  static func gemm(_ m1: Matrix<Float>, _ m2: Matrix<Float>) -> Matrix<Float> {
    cblas_sgemm( ... ) // use this instead of vDSP_mmul
  }
}

extension Double: MatrixScalar {
  static func gemm(_ m1: Matrix<Double>, _ m2: Matrix<Double>) -> Matrix<Double> {
    cblas_dgemm( ... ) // use this instead of vDSP_mmulD
  }
}

public struct Matrix<Scalar: MatrixScalar> {
  // ...
  static func *(_ m1: Matrix, _ m2: Matrix) -> Matrix {
    Scalar.gemm(m1, m2)
  }
}
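A runnable variant of the sketch above, with the BLAS calls replaced by a naive loop so it builds without Accelerate. The naiveGemm helper and the computedBy field are hypothetical, demo-only additions that record which scalar's implementation ran; the point is that a generic caller such as square still ends up in Float.gemm.

```swift
protocol MatrixScalar: BinaryFloatingPoint {
    static func gemm(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self>
}

struct Matrix<Scalar: MatrixScalar> {
    let rows: Int
    let cols: Int
    var buffer: [Scalar]      // row-major storage
    var computedBy = ""       // demo-only: records which kernel produced this value

    static func * (lhs: Matrix, rhs: Matrix) -> Matrix {
        // gemm is a protocol requirement, so this call goes through
        // Scalar's witness table even when called from generic code.
        Scalar.gemm(lhs, rhs)
    }
}

// Hypothetical stand-in for cblas_sgemm / cblas_dgemm in this sketch.
func naiveGemm<T: MatrixScalar>(_ a: Matrix<T>, _ b: Matrix<T>) -> Matrix<T> {
    precondition(a.cols == b.rows, "inner dimensions must match")
    var out = [T](repeating: 0, count: a.rows * b.cols)
    for i in 0..<a.rows {
        for j in 0..<b.cols {
            var s: T = 0
            for k in 0..<a.cols { s += a.buffer[i * a.cols + k] * b.buffer[k * b.cols + j] }
            out[i * b.cols + j] = s
        }
    }
    return Matrix(rows: a.rows, cols: b.cols, buffer: out)
}

extension Float: MatrixScalar {
    static func gemm(_ a: Matrix<Float>, _ b: Matrix<Float>) -> Matrix<Float> {
        var c = naiveGemm(a, b)   // real code: cblas_sgemm(...)
        c.computedBy = "Float"
        return c
    }
}

extension Double: MatrixScalar {
    static func gemm(_ a: Matrix<Double>, _ b: Matrix<Double>) -> Matrix<Double> {
        var c = naiveGemm(a, b)   // real code: cblas_dgemm(...)
        c.computedBy = "Double"
        return c
    }
}

// A generic algorithm that only knows its scalar conforms to MatrixScalar.
func square<T: MatrixScalar>(_ m: Matrix<T>) -> Matrix<T> {
    m * m
}

let r = square(Matrix<Float>(rows: 2, cols: 2, buffer: [1, 2, 3, 4]))
print(r.computedBy, r.buffer)  // Float [7.0, 10.0, 15.0, 22.0]
let g = square(Matrix<Double>(rows: 1, cols: 1, buffer: [3]))
print(g.computedBy, g.buffer)  // Double [9.0]
```

Adding a new scalar type then only requires a new conformance, and no existing function has to be edited.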

Is this a general recommendation that the BLAS version of matrix multiplication is preferable to the vDSP version?

Yes. vDSP_mmul wraps the BLAS implementation on recent OS versions, but on older OSes cblas_*gemm is quite a bit faster (and you avoid a layer of indirection). It's the single most valuable function to optimize for most computation-heavy workloads, so it gets the bulk of the attention.


Is this a fundamental limitation of Swift semantics, or can it be sidestepped with a JIT compiler?

A more accurate phrasing would be “at the point of the call, the arguments are only constrained by conformance to MatrixProtocol, so the generic implementation is used.”

I.e., this isn’t a compiler limitation; it’s not a limitation at all, it’s the semantics of the language. The escape hatch is to make the thing you want to customize a protocol requirement, so that it’s dispatched through the witness table.
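The difference between a requirement and a plain extension method shows up in a few lines. In this sketch the names Kernel, Fast, nameRequirement and nameExtensionOnly are hypothetical; only the requirement is looked up through the witness table when called from generic code.

```swift
protocol Kernel {
    func nameRequirement() -> String  // requirement: each type can customize it
}

extension Kernel {
    func nameRequirement() -> String { "default" }    // default witness
    func nameExtensionOnly() -> String { "default" }  // NOT a requirement
}

struct Fast: Kernel {
    func nameRequirement() -> String { "fast" }
    func nameExtensionOnly() -> String { "fast" }  // shadows, but is not a witness
}

func callGenerically<K: Kernel>(_ k: K) -> (String, String) {
    (k.nameRequirement(), k.nameExtensionOnly())
}

let result = callGenerically(Fast())
print(result.0)  // fast: found through the witness table
print(result.1)  // default: the extension method is chosen at compile time
```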
