Help needed with retroactive differentiability


(Dan Zheng) #1

Hello SIL experts,

I'm working on automatic differentiation and have some questions about implementing retroactive differentiability. My main question and a tentative solution are near the bottom of this post.


Automatic differentiation (AD) is a technique for automatically computing derivatives of functions. My team is implementing an expressive and flexible AD system for Swift on the tensorflow branch; you can read more in this slightly outdated manifesto. I'll keep this post concise and explain just enough to frame the problem.


Here's an example of AD using the gradient(of:) differential operator:

// cubed = x^3
func cubed(_ x: Float) -> Float {
  return x * x * x
}
// Differentiate `cubed` with respect to parameter `x`.
// dcubed = 3x^2
let dcubed = gradient(of: cubed)
// dcubed(2) = 3(2^2) = 3*4 = 12.
print(dcubed(2)) // 12
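The value in this example can be sanity-checked without any AD machinery. The sketch below is not part of the AD system; it approximates the derivative of `cubed` at 2 with a central finite difference. The helper name `centralDifference` is hypothetical.

```swift
// Same function as in the example: cubed(x) = x^3.
func cubed(_ x: Float) -> Float {
  return x * x * x
}

// Hypothetical helper (not an AD API): central finite difference,
// (f(x + h) - f(x - h)) / (2h). For a cubic this equals 3x^2 + h^2,
// so with h = 0.01 the result is within about 1e-4 of the true slope,
// plus Float rounding error.
func centralDifference(of f: (Float) -> Float, at x: Float, h: Float = 0.01) -> Float {
  return (f(x + h) - f(x - h)) / (2 * h)
}

let approx = centralDifference(of: cubed, at: 2)
print(approx) // approximately 12
```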

The way this works is that primitive differentiable functions declare their derivatives (either vector-Jacobian product functions (VJPs) or Jacobian-vector product functions (JVPs)) via the @differentiable attribute. The compiler identifies these derivatives and chains them together, following the chain rule, to differentiate functions composed of primitive differentiable functions.

For example, here's the definition of Float.*, used in cubed(_:) above:

extension Float {
  // `_vjpMultiply` is registered as the derivative of `Float.*`.
  @differentiable(vjp: _vjpMultiply(lhs:rhs:))
  public static func * (lhs: Float, rhs: Float) -> Float {
    ...
  }

  // `_vjpMultiply` takes original arguments and returns:
  // (Float,       (Float) -> (Float, Float))
  //  ^~~~~         ^~~~~      ^~~~~~~~~~~~
  //  orig. result  vector     vector-Jacobian products
  static func _vjpMultiply(
    lhs: Float, rhs: Float
  ) -> (Float, (Float) -> (Float, Float)) {
    return (lhs * rhs, { v in (rhs * v, lhs * v) })
  }
}
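To make the chaining concrete, here is a hand-written sketch of what differentiating `cubed` amounts to conceptually: evaluate the original computation while collecting pullbacks, then run the pullbacks in reverse and sum the contributions to `x`. The free function `vjpMultiply` mirrors `_vjpMultiply` above; `vjpCubed` and `gradientOfCubed` are hypothetical names, and the code the compiler actually generates differs.

```swift
// Mirrors `Float._vjpMultiply` above: returns lhs * rhs and a pullback
// mapping an incoming vector v to (∂/∂lhs, ∂/∂rhs) scaled by v.
func vjpMultiply(_ lhs: Float, _ rhs: Float) -> (Float, (Float) -> (Float, Float)) {
  return (lhs * rhs, { v in (rhs * v, lhs * v) })
}

// Hypothetical hand-chained VJP for `cubed(x) = (x * x) * x`.
func vjpCubed(_ x: Float) -> (Float, (Float) -> Float) {
  let (square, pbSquare) = vjpMultiply(x, x)   // square = x * x
  let (cube, pbCube) = vjpMultiply(square, x)  // cube = square * x
  return (cube, { v in
    // Run the pullbacks in reverse order of the original computation.
    let (dSquare, dx1) = pbCube(v)
    let (dx2, dx3) = pbSquare(dSquare)
    // `x` is used three times, so its three contributions are summed.
    return dx1 + dx2 + dx3
  })
}

// The gradient of a scalar function is its pullback applied to 1.
func gradientOfCubed(at x: Float) -> Float {
  let (_, pullback) = vjpCubed(x)
  return pullback(1)
}

print(gradientOfCubed(at: 2)) // 12.0
```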

@differentiable is a straightforward way for functions to declare their derivatives, and it is sufficient for many AD use cases. However, because the attribute must be written on the original function's declaration, @differentiable alone cannot retroactively register derivatives for existing functions, such as functions defined in other modules.

Retroactive derivative registration enables registering derivatives for a function without changing the original function's declaration, similar to how protocol extensions enable extending types with functionality without changing the original type declaration.
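For comparison, here is the familiar kind of retroactivity in Swift: an extension adds a method and a protocol conformance to `Float`, a standard library type, without modifying its declaration. `Cubable` and `cubeAll` are hypothetical names used only for illustration.

```swift
// A protocol declared in "our" module (hypothetical).
protocol Cubable {
  func cubed() -> Self
}

// `Float` is declared in the standard library, yet this extension
// retroactively adds a method and a conformance without touching
// Float's original declaration.
extension Float: Cubable {
  func cubed() -> Float {
    return self * self * self
  }
}

// Generic code can now use Float through the retroactive conformance.
func cubeAll<T: Cubable>(_ values: [T]) -> [T] {
  return values.map { $0.cubed() }
}

print(cubeAll([1, 2, 3] as [Float])) // [1.0, 8.0, 27.0]
```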

We've scoped out a design for retroactive differentiability using the @differentiating attribute, which retroactively registers functions as derivatives of other functions. Here's an example:

import Darwin
// Even though `sin` is defined in another module,
// I can register a derivative for it.
@differentiating(sin(_:))
func _vjpSin(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  return (value: sin(x), pullback: { v in v * cos(x) })
}

// `gradient(of: sin)` should look up `_vjpSin` and work.
let dsin = gradient(of: sin)
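Conceptually, retroactive registration populates a table from original functions to derivatives that a differential operator can consult. The plain-Swift sketch below models that with a dictionary keyed by the original function's name; it does not use the actual @differentiating attribute, and `DerivativeRegistry`, `register(_:for:)`, and `gradient(of:at:)` are hypothetical names.

```swift
import Foundation  // for sin/cos on Float

// A VJP: returns the original value and a pullback.
typealias VJP = (Float) -> (value: Float, pullback: (Float) -> Float)

// Hypothetical registry mapping an original function's name to its VJP.
// Swift functions are not Hashable, so a string name stands in for the
// original function's identity.
struct DerivativeRegistry {
  private var vjps: [String: VJP] = [:]

  mutating func register(_ vjp: @escaping VJP, for original: String) {
    vjps[original] = vjp
  }

  // Gradient of a scalar function: look up its VJP, apply the pullback to 1.
  func gradient(of original: String, at x: Float) -> Float? {
    guard let vjp = vjps[original] else { return nil }
    return vjp(x).pullback(1)
  }
}

// Same derivative as `_vjpSin` above, registered "retroactively":
// `sin` itself is untouched.
func _vjpSin(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  return (value: sin(x), pullback: { v in v * cos(x) })
}

var registry = DerivativeRegistry()
registry.register(_vjpSin, for: "sin(_:)")

print(registry.gradient(of: "sin(_:)", at: 0) as Any) // Optional(1.0)
print(registry.gradient(of: "cos(_:)", at: 0) as Any) // nil: nothing registered
```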

The syntax/semantics of @differentiating are mostly scoped out. However, the SIL-level implementation is not clear. We want the implementation to be general enough to support cross-module use cases like:

  • Module A defines function foo.
  • Module B imports module A and defines @differentiating(foo) func vjpFoo.
  • Module C imports modules A and B and calls gradient(of: foo), which should look up vjpFoo defined in module B.

How should we implement retroactive differentiability at the SIL level? We feel there's a lot of potential overlap with the design of protocol extensions/conformances in SIL, and we want to be informed by that design and reuse code where appropriate.

Here's a tentative solution, involving name mangling (a new mangling scheme for retroactive derivatives) and lookup:

  • When SILGen'ing a function with the @differentiating attribute, we lower it normally but also create a "redirection thunk" with a specially mangled, module-prefixed name that simply calls the normally lowered function.
    • Example: during SILGen for @differentiating(foo) func vjpFoo, we lower it normally (fake mangled name: s10ModuleB$vjpFoo), then create a redirection thunk (fake mangled name: s10ModuleB$retroactive_vjp_for_ModuleA$foo).
  • During SILGen, differential operators like gradient(of:) are lowered to SIL AutoDiffFunctionInsts. The differentiation transform currently processes these AutoDiffFunctionInsts and tries to look up their associated derivative functions. If lookup fails, the transform would then look for retroactive derivatives in each imported module, by mangling that module's name into the retroactive derivative mangling scheme.
    • Example: say derivative lookup for foo fails in the differentiation transform. Then, for every imported module, we look up the function $retroactive_vjp_for_ModuleA$foo prefixed with that module's name. Thus, s10ModuleB$retroactive_vjp_for_ModuleA$foo is found and used as the derivative.
    • Question: are there more efficient lookup strategies? Performing lookup in every imported module doesn't seem terribly efficient.
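The lookup step of this tentative scheme could look roughly like the sketch below, which probes one redirection-thunk name per imported module. The name format is modeled loosely on the fake mangled names above, without the `s10` prefix (real Swift mangling differs). `retroactiveThunkName`, `lookUpRetroactiveVJP`, the symbol-table dictionary, and the body of `foo` (assumed here to be `x * x`) are all hypothetical.

```swift
// A VJP: returns the original value and a pullback.
typealias VJP = (Float) -> (value: Float, pullback: (Float) -> Float)

// Builds a redirection-thunk name modeled loosely on the fake names in the
// post; real Swift mangling is different and length-prefixed.
func retroactiveThunkName(definingModule: String,
                          originalModule: String,
                          original: String) -> String {
  return "\(definingModule)$retroactive_vjp_for_\(originalModule)$\(original)"
}

// Probes every imported module in order. Each lookup costs
// O(number of imported modules) dictionary probes.
func lookUpRetroactiveVJP(original: String,
                          originalModule: String,
                          importedModules: [String],
                          symbols: [String: VJP]) -> (module: String, vjp: VJP)? {
  for module in importedModules {
    let name = retroactiveThunkName(definingModule: module,
                                    originalModule: originalModule,
                                    original: original)
    if let vjp = symbols[name] {
      return (module: module, vjp: vjp)
    }
  }
  return nil
}

// Hypothetical symbol table: module B's thunk for ModuleA.foo,
// assuming foo(x) = x * x (derivative 2x).
let symbols: [String: VJP] = [
  "ModuleB$retroactive_vjp_for_ModuleA$foo": { x in (value: x * x, pullback: { v in 2 * x * v }) }
]

if let found = lookUpRetroactiveVJP(original: "foo", originalModule: "ModuleA",
                                    importedModules: ["ModuleA", "ModuleB"],
                                    symbols: symbols) {
  print(found.module, found.vjp(3).pullback(1)) // ModuleB 6.0
}
```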

Alternative solutions:

  • Lower AST @differentiating attribute into SIL [differentiating] attribute, which stores the mangled name of the function-to-differentiate. Then, perform similar lookup in the differentiation pass.
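One way to avoid per-module probing, sketched here under the assumption that each lowered derivative records the mangled name of its original function (as this alternative proposes), is to build an index in a single pass over the visible functions and then answer each lookup with one dictionary access. `FunctionInfo` and `buildDerivativeIndex` are hypothetical, and the names are plain strings rather than real mangled names.

```swift
// Hypothetical summary of a lowered function: its name and, for derivatives,
// the name of the original function recorded by `[differentiating]`.
struct FunctionInfo {
  let name: String
  let differentiatingTarget: String?
}

// Builds an index from original-function name to derivative name in one pass
// over all visible functions; each later lookup is then a dictionary access.
func buildDerivativeIndex(_ functions: [FunctionInfo]) -> [String: String] {
  var index: [String: String] = [:]
  for function in functions {
    if let target = function.differentiatingTarget {
      index[target] = function.name  // a later entry replaces an earlier one
    }
  }
  return index
}

let functions = [
  FunctionInfo(name: "ModuleA.foo", differentiatingTarget: nil),
  FunctionInfo(name: "ModuleB.vjpFoo", differentiatingTarget: "ModuleA.foo"),
]
let index = buildDerivativeIndex(functions)
print(index["ModuleA.foo"] ?? "none") // ModuleB.vjpFoo
```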

Feedback about the tentative solution would be greatly appreciated! An explanation of how protocol extensions are represented in SIL would also be highly appreciated.


cc @rxwei @Douglas_Gregor


(Richard Wei) #2

@Joe_Groff @John_McCall @Slava_Pestov

Any help would be appreciated!


(John McCall) #3

Feels like you might want a new top-level entity, like sil_vtable, that associates the original function with the differentials, and it would be keyed by the original function name.


(Dan Zheng) #4

Creating a new top-level entity like sil_vtable seems quite heavyweight; do you feel it's the right approach?
Since type extensions don't require a top-level entity in SIL, I wonder if retroactive differentiability can be implemented similarly (like the tentative solution above)?


(Jordan Rose) #5

What you're doing sounds like a kind of conformance for a function, and conformances have identity. Extensions don't require a top-level entity, but conformances do. (sil_witness_table rather than sil_vtable, but same idea.)


(Richard Wei) #6

In the following example:

Module A:

func foo(x: Float) -> Float

Module B:

import A
@differentiating(A.foo(x:))
func foo_jvp(x: Float) -> (value: Float, differential: (Float) -> Float)

Module C:

import A
import B
@differentiating(A.foo(x:))
func foo_jvp(x: Float) -> (value: Float, differential: (Float) -> Float)

We'd expect differentiation of A.foo(x:) in module C to use the derivative defined in C. It feels like a witness table for functions, but one whose entries can be overridden.


(Richard Wei) #7

This could result in the following tables.

In module B:

differentiability_witness @A.foo {
  jvp: @B.foo_jvp
  vjp: @B.foo_vjp
}

In module C:

differentiability_witness @A.foo {
  jvp: @C.foo_jvp
  vjp: @C.foo_vjp
}

In the differentiation transform, we can then look up differentiability_witness tables in the current module first, and in imported modules if none exists there. This would also let us get rid of the [differentiable jvp @... vjp @...] attribute on SIL function declarations, because it's less powerful than lookup tables.
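The lookup order described here (current module first, then imported modules) can be modeled in plain Swift as below. `DifferentiabilityWitness`, `Module`, and `findWitness` are hypothetical stand-ins for the proposed SIL construct, symbol names are plain strings, and module D is an invented module that imports A and B but declares no witness of its own. The sketch only illustrates resolution order, which gives module C's witness precedence over module B's.

```swift
// Hypothetical in-memory model of a `differentiability_witness` entry.
struct DifferentiabilityWitness {
  let original: String  // e.g. "A.foo"
  let jvp: String       // e.g. "B.foo_jvp"
  let vjp: String
}

struct Module {
  let name: String
  let witnesses: [String: DifferentiabilityWitness]  // keyed by original
  let imports: [Module]
}

// Look in the current module first; if there is no witness for `original`,
// search direct imports in import order (transitive imports are ignored here).
func findWitness(for original: String, in module: Module) -> DifferentiabilityWitness? {
  if let local = module.witnesses[original] {
    return local
  }
  for imported in module.imports {
    if let witness = imported.witnesses[original] {
      return witness
    }
  }
  return nil
}

let a = Module(name: "A", witnesses: [:], imports: [])
let b = Module(name: "B",
               witnesses: ["A.foo": DifferentiabilityWitness(original: "A.foo", jvp: "B.foo_jvp", vjp: "B.foo_vjp")],
               imports: [a])
let c = Module(name: "C",
               witnesses: ["A.foo": DifferentiabilityWitness(original: "A.foo", jvp: "C.foo_jvp", vjp: "C.foo_vjp")],
               imports: [a, b])
let d = Module(name: "D", witnesses: [:], imports: [a, b])

print(findWitness(for: "A.foo", in: c)?.vjp ?? "none") // C.foo_vjp (local wins)
print(findWitness(for: "A.foo", in: d)?.vjp ?? "none") // B.foo_vjp (from import)
```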


(Richard Wei) #8

Ok, we are about to add this top-level construct to SIL if you folks agree it's the right thing to do! We'll land things in the tensorflow branch and let you review, and will also send PRs to master in the next few months.