Hello SIL experts,
I'm working on automatic differentiation and I have some questions about implementing retroactive differentiability. Scroll to the bottom for my question in bold, and a tentative solution.
Automatic differentiation (AD) is a technique for automatically computing derivatives of functions. My team is implementing an expressive and flexible AD system for Swift on the `tensorflow` branch. You can read more in this slightly outdated manifesto. I'll try to keep things concise in this post and explain just enough to frame the problem.
Here's an example of AD using the `gradient(of:)` differential operator:
// cubed = x^3
func cubed(_ x: Float) -> Float {
    return x * x * x
}
// Differentiate `cubed` with respect to parameter `x`.
// dcubed = 3x^2
let dcubed = gradient(of: cubed)
// dcubed(2) = 3(2^2) = 3*4 = 12.
print(dcubed(2)) // 12
The way this works is that primitive differentiable functions declare their derivatives (either vector-Jacobian product functions or Jacobian-vector product functions, `vjp`s or `jvp`s) via the `@differentiable` attribute. The compiler identifies these derivatives and knows how to chain them together to differentiate functions composed of primitive differentiable functions.
For example, here's the definition of `Float.*`, used in `cubed(_:)` above:
extension Float {
    // `_vjpMultiply` is registered as the derivative of `Float.*`.
    @differentiable(vjp: _vjpMultiply(lhs:rhs:))
    public static func * (lhs: Float, rhs: Float) -> Float {
        ...
    }

    // `_vjpMultiply` takes original arguments and returns:
    //   (Float,        (Float) -> (Float, Float))
    //    ^~~~~          ^~~~~      ^~~~~~~~~~~~
    //    orig. result   vector     vector-Jacobian products
    static func _vjpMultiply(
        lhs: Float, rhs: Float
    ) -> (Float, (Float) -> (Float, Float)) {
        return (lhs * rhs, { v in (rhs * v, lhs * v) })
    }
}
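To make the chaining step concrete, here is a minimal hand-written sketch in plain Swift of what composing VJPs for `cubed` could look like. It does not use the compiler's AD support; `vjpMultiply`, `vjpCubed`, and `manualGradient` are hypothetical names invented for this illustration, and the code the differentiation transform actually generates may look quite different.

```swift
// Hypothetical stand-in for the registered derivative of `Float.*`:
// returns the original result and a pullback producing one
// vector-Jacobian product per argument.
func vjpMultiply(
    _ lhs: Float, _ rhs: Float
) -> (value: Float, pullback: (Float) -> (Float, Float)) {
    return (lhs * rhs, { v in (rhs * v, lhs * v) })
}

// Hand-chained VJP for `cubed(x) = (x * x) * x`.
func vjpCubed(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    // Forward pass: run each primitive and keep its pullback.
    let (xx, pullbackInner) = vjpMultiply(x, x)
    let (xxx, pullbackOuter) = vjpMultiply(xx, x)
    return (xxx, { v in
        // Backward pass: apply the pullbacks in reverse order and
        // accumulate every contribution that flows back to `x`.
        let (dxx, dxFromOuter) = pullbackOuter(v)
        let (dxLeft, dxRight) = pullbackInner(dxx)
        return dxLeft + dxRight + dxFromOuter
    })
}

// The gradient is the pullback applied to a seed of 1.
func manualGradient(
    at x: Float,
    of vjp: (Float) -> (value: Float, pullback: (Float) -> Float)
) -> Float {
    return vjp(x).pullback(1)
}

let manualDcubedAt2 = manualGradient(at: 2, of: vjpCubed)
print(manualDcubedAt2) // 12.0
```

The three contributions at `x = 2` are each `x^2 = 4`, which sum to `3x^2 = 12` and match `dcubed(2)` above.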
`@differentiable` is a straightforward way for functions to declare their derivatives; it is good and sufficient for many AD use cases. However, with just `@differentiable`, it is not possible to retroactively register derivatives for existing functions.
Retroactive derivative registration enables registering derivatives for a function without changing the original function's declaration, similar to how protocol extensions enable extending types with functionality without changing the original type declaration.
We've scoped out a design for retroactive differentiability using the `@differentiating` attribute, which retroactively registers functions as derivatives of other functions. Here's an example:
import Darwin
// Even though `sin` is defined in another module,
// I can register a derivative for it.
@differentiating(sin(_:))
func _vjpSin(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (value: sin(x), pullback: { v in v * cos(x) })
}
// `gradient(of: sin)` should look up `_vjpSin` and work.
let dsin = gradient(of: sin)
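For reference, here is a small sketch of what the lookup would ultimately have to produce, written in plain Swift without the `@differentiating` attribute. It calls a VJP shaped like `_vjpSin` by hand and seeds the pullback with 1. The name `vjpSinByHand` is hypothetical, and `Foundation` is imported here only to get `sin` and `cos` portably.

```swift
import Foundation

// Same shape as `_vjpSin` above, minus the attribute (hypothetical name).
func vjpSinByHand(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (value: sin(x), pullback: { v in v * cos(x) })
}

// What `gradient(of: sin)` would be expected to compute at a point:
// the pullback applied to a seed of 1.
let dsinAtZero = vjpSinByHand(0).pullback(1)
print(dsinAtZero) // 1.0, since cos(0) = 1
```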
The syntax/semantics of `@differentiating` are mostly scoped out. However, the SIL-level implementation is not clear. We want the implementation to be general enough to support cross-module use cases like:
- Module A defines function `foo`.
- Module B imports module A and defines `@differentiating(foo) func vjpFoo`.
- Module C imports modules A and B and calls `gradient(of: foo)`, which should look up `vjpFoo` defined in module B.
**How should we implement retroactive differentiability at the SIL level?** We feel there's a lot of potential overlap with the design of protocol extensions/conformances in SIL and we want to be informed by that design and reuse code if appropriate.
Here's a tentative solution, involving name mangling (a new mangling scheme for retroactive derivatives) and lookup:
- When SILGen'ing a function with the `@differentiating` attribute, we lower it normally but also create a "redirection thunk" with a specially mangled, module-prefixed name that simply calls the normally lowered function.
  - Example: during SILGen for `@differentiating(foo) func vjpFoo`, we lower it normally (fake mangled name: `s10ModuleB$vjpFoo`), then create a "redirection thunk" (fake mangled name: `s10ModuleB$retroactive_jvp_for_ModuleA$foo`).
- During SILGen, differential operators like `gradient(of:)` are lowered to SIL `AutoDiffFunctionInst`. The differentiation transform currently processes these `AutoDiffFunctionInst`s and tries to look up their associated derivative functions. If lookup fails, then look for retroactive derivatives from each imported module by mangling the module name into the retroactive derivative mangling scheme.
  - Example: let's say derivative lookup for `foo` fails in the differentiation transform. Then, for every imported module, look up the function `$retroactive_jvp_for_ModuleA$foo` prepended with the module name. Thus, `s10ModuleB$retroactive_jvp_for_ModuleA$foo` will be found and used as the derivative.
  - Question: are there more efficient lookup strategies? Performing lookup in every imported module doesn't seem terribly efficient.
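The two steps above can be modeled in a few lines of plain Swift. This is only a toy simulation of the idea, not compiler code: `ToyModule`, `retroactiveThunkName`, and `lookUpRetroactiveDerivative` are hypothetical names, the symbol table is an ordinary dictionary, and the mangled names follow the fake scheme used in the examples.

```swift
// A derivative function in the value/pullback form used earlier.
typealias ToyVJP = (Float) -> (value: Float, pullback: (Float) -> Float)

// Hypothetical model of a compiled module: a symbol table from
// (fake) mangled names to functions.
struct ToyModule {
    let name: String
    var symbols: [String: ToyVJP] = [:]
}

// Fake retroactive-derivative mangling, following the examples above.
func retroactiveThunkName(
    derivativeModule: String, originalModule: String, original: String
) -> String {
    return "s10\(derivativeModule)$retroactive_jvp_for_\(originalModule)$\(original)"
}

// Fallback lookup: try the module-prefixed thunk name in every imported module.
func lookUpRetroactiveDerivative(
    of original: String, definedIn originalModule: String,
    importedModules: [ToyModule]
) -> ToyVJP? {
    for module in importedModules {
        let name = retroactiveThunkName(
            derivativeModule: module.name,
            originalModule: originalModule,
            original: original)
        if let thunk = module.symbols[name] { return thunk }
    }
    return nil
}

// Module A defines `foo` (modeled here as x^2) and registers no derivative.
let toyModuleA = ToyModule(name: "ModuleA")

// Module B defines `vjpFoo` and, next to it, the redirection thunk.
var toyModuleB = ToyModule(name: "ModuleB")
let toyVjpFoo: ToyVJP = { x in (value: x * x, pullback: { v in 2 * x * v }) }
toyModuleB.symbols["s10ModuleB$vjpFoo"] = toyVjpFoo
toyModuleB.symbols[retroactiveThunkName(
    derivativeModule: "ModuleB", originalModule: "ModuleA", original: "foo"
)] = { x in toyVjpFoo(x) }  // the thunk simply forwards to `vjpFoo`

// Module C imports A and B and needs the derivative of `foo`.
let foundDerivative = lookUpRetroactiveDerivative(
    of: "foo", definedIn: "ModuleA",
    importedModules: [toyModuleA, toyModuleB])
print(foundDerivative != nil) // true
```

The loop over `importedModules` is the part the efficiency question is about: in this model, each failed lookup costs one symbol-table probe per imported module.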
Alternative solutions:
- Lower the AST `@differentiating` attribute into a SIL `[differentiating]` attribute, which stores the mangled name of the function-to-differentiate. Then, perform similar lookup in the differentiation pass.
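As a rough sketch of how this alternative differs, the toy model below attaches the target's mangled name to each function record instead of encoding it in a thunk name. `ToySILFunction`, `ToySILModule`, and `findDerivatives` are hypothetical names, and the mangled names (including `s10ModuleA$foo`) are made up in the style of the fake names above.

```swift
// Hypothetical model of a lowered SIL function.
struct ToySILFunction {
    let mangledName: String
    // Models the proposed SIL `[differentiating]` attribute: the mangled
    // name of the function-to-differentiate, or nil if absent.
    let differentiatingTarget: String?
}

// Hypothetical model of a SIL module: just a list of functions.
struct ToySILModule {
    let name: String
    let functions: [ToySILFunction]
}

// Scan the functions of the given modules for records whose
// attribute names the target.
func findDerivatives(
    ofMangledName target: String, in modules: [ToySILModule]
) -> [ToySILFunction] {
    return modules
        .flatMap { $0.functions }
        .filter { $0.differentiatingTarget == target }
}

let silModuleA = ToySILModule(name: "ModuleA", functions: [
    ToySILFunction(mangledName: "s10ModuleA$foo", differentiatingTarget: nil)
])
let silModuleB = ToySILModule(name: "ModuleB", functions: [
    ToySILFunction(mangledName: "s10ModuleB$vjpFoo",
                   differentiatingTarget: "s10ModuleA$foo")
])

let derivativeRecords = findDerivatives(
    ofMangledName: "s10ModuleA$foo", in: [silModuleA, silModuleB])
print(derivativeRecords.map { $0.mangledName })
```

In this model no extra thunk is emitted, but the lookup still has to visit every imported module, so the efficiency question applies to it as well.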
Feedback about the tentative solution would be greatly appreciated! An explanation of how protocol extensions are represented in SIL would also be highly appreciated.