AutoDiff: introduce a @differentiable function type attribute

(Richard Wei) #1

In the Automatic Differentiation feature I've been building on the tensorflow branch, I need to extend function types to encode differentiability. Without going deep into AD (I'll send an AD manifesto in ~2 weeks), I'll briefly explain the purpose of such an extension.

If a function's body is available at compile time, its differentiability can be determined from its data flow.

import Foundation // for sin and cos

func foo(x: Float) -> Float {
    return sin(x) * cos(x)
}

_ = #gradient(foo) // (Float) -> Float
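#gradient is a compiler directive on the tensorflow branch, so it can't run in stock Swift. As a hedged illustration of what "determined from its data flow" means, here is forward-mode differentiation with dual numbers in C++, applied to the same body as foo. This is not how the tensorflow branch implements AD; `Dual` and `gradientOfFoo` are hypothetical names. Analytically, d/dx (sin x cos x) = cos²x − sin²x = cos 2x.

```cpp
#include <cassert>
#include <cmath>

// A minimal forward-mode dual number: `value` carries f(x), `deriv` carries f'(x).
struct Dual {
    double value;
    double deriv;
};

// Product rule: (uv)' = u'v + uv'
Dual operator*(Dual a, Dual b) {
    return {a.value * b.value, a.deriv * b.value + a.value * b.deriv};
}

// Chain rule for the two elementary functions used by `foo`.
Dual sin(Dual x) { return {std::sin(x.value), std::cos(x.value) * x.deriv}; }
Dual cos(Dual x) { return {std::cos(x.value), -std::sin(x.value) * x.deriv}; }

// Same body as the Swift `foo`, written over dual numbers.
Dual foo(Dual x) { return sin(x) * cos(x); }

// Hypothetical stand-in for `#gradient(foo)`: seed dx/dx = 1 and read off the derivative.
double gradientOfFoo(double x) { return foo(Dual{x, 1.0}).deriv; }
```

The whole derivative falls out of the operations in the body, which is exactly why this approach breaks down once the body is not visible.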

However, determining differentiability just by looking at a function's body won't work across modules, and requiring function bodies to be serialized is definitely not a long-term option. Here are two important use cases of AD that motivate including differentiability in Swift function types:

  1. Differentiating opaque closures

    func foo(f: (Float) -> Float) {
        _ = #gradient(f) // error: cannot differentiate opaque closures
    }

    If we have a type attribute called @differentiable, the problem is solved:

    func foo(f: @differentiable (Float) -> Float) {
        _ = #gradient(f) // okay!
    }

  2. Differentiating protocol requirements

    protocol P {
        func f(_ x: Float) -> Float
    }
    extension P {
        func g() {
            _ = #gradient(f) // error: cannot differentiate opaque closures
        }
    }

    Similarly, with a @differentiable declaration attribute on the requirement, the problem is solved:

    protocol P {
        // A declaration attribute, which forces implementors of this function
        // to provide a differentiable body, so that the compiler emits a binary
        // representation that contains function pointer(s) for its Jacobian.
        @differentiable
        func f(_ x: Float) -> Float
    }
    extension P {
        func g() {
            _ = #gradient(f) // ok! `f` has type `@differentiable (Float) -> Float`
        }
    }
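Both use cases come down to the same thing: the function value (or the protocol witness) has to carry its derivative with it, because the caller can't look inside. Here is a hedged C++ model of that idea. `DifferentiableFn`, `gradient`, `P`, and `Cube` are hypothetical names, and a single derivative pointer is a simplification of whatever the compiler would actually emit.

```cpp
#include <cassert>

// A plain ("thin") function: just a pointer. Nothing tells a caller how to differentiate it.
using PlainFn = double (*)(double);

// Hypothetical model of `@differentiable (Float) -> Float`:
// the value carries its derivative alongside the original function.
struct DifferentiableFn {
    PlainFn original;
    PlainFn derivative;
};

// `gradient` only accepts the bundle. There is deliberately no overload taking a
// PlainFn, mirroring "error: cannot differentiate opaque closures".
PlainFn gradient(DifferentiableFn f) { return f.derivative; }

double square(double x) { return x * x; }
double twice(double x) { return 2 * x; }

// Hypothetical model of a requirement marked @differentiable: every conforming
// type must supply the derivative as well as the function itself.
struct P {
    virtual ~P() = default;
    virtual double f(double x) const = 0;
    virtual double df(double x) const = 0;       // the extra, derivative requirement
    double g(double x) const { return df(x); }  // plays the role of `#gradient(f)` in the extension
};

struct Cube : P {
    double f(double x) const override { return x * x * x; }
    double df(double x) const override { return 3 * x * x; }
};
```

In both cases the information lives in the type (the struct or the protocol), not in a function body, which is what makes it usable across module boundaries.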

In the simple model illustrated above, @differentiable could be stored in ExtInfo as a single bit. There is a complication, though: differentiability does not always apply to every function parameter.

A function (T0, T1) -> U can be differentiable with respect to only a subset of its parameters, e.g. the first parameter but not the second. In that case, the function type needs to store a bit mask of the parameters it is differentiable with respect to (or, equivalently, of those it is not). The type syntax could look like:

let f: @differentiable (T0, @nodiff T1) -> U
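A hedged sketch of turning a parameter list with @nodiff markers into such a mask, where bit i set means "differentiable with respect to parameter i". The fixed 64-bit word and the names `wrtMask` and `isDifferentiableWrt` are assumptions for illustration; a real implementation might use a variable-length bit vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// nodiff[i] is true when parameter i is marked @nodiff.
// Assumes at most 64 parameters so the mask fits in one word.
std::uint64_t wrtMask(const std::vector<bool>& nodiff) {
    assert(nodiff.size() <= 64);
    std::uint64_t mask = 0;
    for (std::size_t i = 0; i < nodiff.size(); ++i)
        if (!nodiff[i]) mask |= std::uint64_t{1} << i;
    return mask;
}

// Query bit i of the mask (requires i < 64).
bool isDifferentiableWrt(std::uint64_t mask, unsigned i) {
    return (mask >> i) & 1;
}
```

For `@differentiable (T0, @nodiff T1) -> U` the mask is 0b01. The common "differentiable with respect to everything" case is simply all bits set for the parameter count.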

As such, differentiability is not a traditional function representation that can be stored in ExtInfo as one bit: it carries more information than that. What would be the best representation for it?

Decl attribute and type attribute having the same name

Possibly unrelated question: Should differentiability be an attribute of a parameter rather than a function? This

let f: @differentiable (T0, @nodiff T1) -> U

would then be roughly

let f: (@differentiable T0, T1) -> U

though the attribute would probably need a different name then. Then again, the common case may be functions that are differentiable with respect to all parameters.

(Richard Wei) #3

Differentiability is a property of a function, not of its parameter types. Beyond the mathematical definition, the practical reason is that a differentiable function has a different layout.

For example, a normal "thin" function is a single function pointer, whereas a (reverse-mode) differentiable function is a product of these:

  • The original function pointer
  • The primal function pointer
  • The adjoint function pointer
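To make the layout difference concrete, here is a hedged C++ sketch of a reverse-mode differentiable function as a struct of three function pointers, for f(x) = sin(x) * cos(x). The split (the primal computes the value and saves intermediates, the adjoint maps an incoming seed to a gradient) follows the list above. The type and function names, and the exact contents of the saved intermediates, are assumptions; this is not the actual Swift ABI.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical intermediates the primal saves for the adjoint.
struct Checkpoints { double sinX, cosX; };
struct PrimalResult { double value; Checkpoints saved; };

// Original: f(x) = sin(x) * cos(x)
double fOriginal(double x) { return std::sin(x) * std::cos(x); }

// Primal: compute f(x) and record what the adjoint will need.
PrimalResult fPrimal(double x) {
    double s = std::sin(x), c = std::cos(x);
    return {s * c, {s, c}};
}

// Adjoint: given the seed dL/dy and the saved values, return dL/dx.
// Uses d/dx (sin x cos x) = cos^2 x - sin^2 x.
double fAdjoint(double seed, Checkpoints cp) {
    return seed * (cp.cosX * cp.cosX - cp.sinX * cp.sinX);
}

// The differentiable function value: three pointers instead of one.
struct ReverseDifferentiableFn {
    double (*original)(double);
    PrimalResult (*primal)(double);
    double (*adjoint)(double, Checkpoints);
};
```

A thin function is one pointer wide while this bundle is three, so a value of one type can't be passed where the other is expected without conversion. That is the layout argument for making differentiability part of the function type.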

In most cases, a differentiable function is differentiable with respect to all parameters.

In forward-mode differentiation, a function can be differentiable with respect to a single parameter, producing derivatives of all (or a subset of) its results. In that case, results can be selected too. Swift doesn't formally model multiple results as a result list, so result-selection syntax can be trickier than parameter selection.

A full-blown version of parameter/result selection can look like the following, but it's kind of off-topic.

// f(x0, x1, x2) = (y0, y1, y2)
let f: @differentiable (T0, @nodiff T1, T2) -> (@nodiff U0, U1, U2)
// Jacobian
#jacobian(f)(x0, x1, x2) // ((∂y1/∂x0, ∂y1/∂x2), (∂y2/∂x0, ∂y2/∂x2))
// Forward-mode differentiation
#derivatives(f, wrt: .0)(x0, x1, x2) // (∂y1/∂x0, ∂y2/∂x0)
#derivatives(f, wrt: .1)(x0, x1, x2) // error: parameter #1 is not differentiable-with-respect-to
#derivatives(f, wrt: .2)(x0, x1, x2) // (∂y1/∂x2, ∂y2/∂x2)
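As a hedged sanity check of the comments in that example, the following C++ sketch computes which Jacobian entries (result i, parameter j) survive once @nodiff parameters and results are removed, plus the single column a forward-mode `#derivatives(f, wrt: .k)` would produce. `selectedJacobian` and `forwardColumn` are hypothetical helpers that only track indices, not derivative values; an empty column stands in for the "not differentiable-with-respect-to" error.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using Entry = std::pair<int, int>;  // (result index i, parameter index j) for dy_i/dx_j

// One row per selected result, one entry per selected parameter.
std::vector<std::vector<Entry>>
selectedJacobian(const std::vector<bool>& paramNodiff,
                 const std::vector<bool>& resultNodiff) {
    std::vector<std::vector<Entry>> rows;
    for (int i = 0; i < static_cast<int>(resultNodiff.size()); ++i) {
        if (resultNodiff[i]) continue;
        std::vector<Entry> row;
        for (int j = 0; j < static_cast<int>(paramNodiff.size()); ++j)
            if (!paramNodiff[j]) row.emplace_back(i, j);
        rows.push_back(row);
    }
    return rows;
}

// Forward mode w.r.t. parameter k: column k restricted to the selected results.
// Returns an empty column when parameter k is @nodiff (the error case).
std::vector<Entry>
forwardColumn(const std::vector<bool>& paramNodiff,
              const std::vector<bool>& resultNodiff, int k) {
    std::vector<Entry> column;
    if (paramNodiff[k]) return column;
    for (int i = 0; i < static_cast<int>(resultNodiff.size()); ++i)
        if (!resultNodiff[i]) column.emplace_back(i, k);
    return column;
}
```

For `(T0, @nodiff T1, T2) -> (@nodiff U0, U1, U2)` this yields the rows ((1,0),(1,2)) and ((2,0),(2,2)), matching the #jacobian comment.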

(Richard Wei) #4

I took a stab at it and added differentiability to ExtInfo. I think the @nodiff information can be stored either as a bit vector in AnyFunctionType and SILFunctionType, or in each Param.

    //   |representation|isAutoClosure|noEscape|throws|differentiable|
    //   |    0 .. 3    |      4      |    5   |   6  |      7       |
    enum : unsigned {
      RepresentationMask     = 0xF << 0,
      AutoClosureMask        = 1 << 4,
      NoEscapeMask           = 1 << 5,
      ThrowsMask             = 1 << 6,
      DifferentiableMask     = 1 << 7,
      NumMaskBits            = 8
    };
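A hedged sketch of how the new bit would be read and set alongside the existing ones. `ExtInfoBits` is a hypothetical stand-in, not the compiler's ExtInfo class. It also shows the limitation discussed above: the single bit only answers "is this differentiable?", so the per-parameter @nodiff set still needs separate storage.

```cpp
#include <cassert>

// Same bit layout as the enum above.
enum : unsigned {
    RepresentationMask = 0xF << 0,
    AutoClosureMask    = 1 << 4,
    NoEscapeMask       = 1 << 5,
    ThrowsMask         = 1 << 6,
    DifferentiableMask = 1 << 7,
    NumMaskBits        = 8
};

struct ExtInfoBits {
    unsigned bits = 0;

    unsigned representation() const { return bits & RepresentationMask; }
    bool isDifferentiable() const { return bits & DifferentiableMask; }

    // Returns a copy with only the differentiable flag changed; other bits are preserved.
    ExtInfoBits withDifferentiable(bool d) const {
        ExtInfoBits copy = *this;
        copy.bits = d ? (bits | DifferentiableMask) : (bits & ~DifferentiableMask);
        return copy;
    }
};
```

Following the copy-returning `with...` style keeps the info value immutable, which seems to fit how function type info is typically treated; that is a design assumption here, not something confirmed by the compiler source.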

Here's an interesting implementation question: Swift doesn't seem to have any declaration attribute and type attribute that share the same name. Here, @differentiable is both a type attribute and a decl attribute, and it defaults to a type attribute during lexing. Should I add my own AST node?