For the automatic differentiation (AD) feature I've been building on the tensorflow branch, I need to extend function types to encode differentiability. Without going deep into AD (I'll send an AD manifesto in ~2 weeks), I'll briefly explain the purpose of such an extension.
If a function's body is available at compile time, its differentiability can be determined from its data flow.
```swift
func foo(x: Float) -> Float {
    return sin(x) * cos(x)
}
_ = #gradient(foo) // (Float) -> Float
```
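For reference, the function that `#gradient(foo)` should return can be worked out by hand with the product rule, which is exactly the kind of derivation the compiler can perform when it sees the body:

$$
\frac{d}{dx}\bigl(\sin x \cos x\bigr) = \cos x \cos x - \sin x \sin x = \cos^2 x - \sin^2 x = \cos 2x
$$

The result is itself a `(Float) -> Float` function, matching the type in the comment above.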
However, determining differentiability by looking at a function's body won't work across modules, and requiring function bodies to be serialized is definitely not a long-term option. Here are two important use cases of AD that motivate including differentiability in Swift function types:
- Differentiating opaque closures

  ```swift
  func foo(f: (Float) -> Float) {
      _ = #gradient(f) // error: cannot differentiate opaque closures
  }
  ```
  If we had a type attribute called `@differentiable`, the problem would be solved:

  ```swift
  func foo(f: @differentiable (Float) -> Float) {
      _ = #gradient(f) // okay!
  }
  ```
- Differentiating protocol requirements

  ```swift
  protocol P {
      func f(_ x: Float) -> Float
  }

  extension P {
      func g() {
          _ = #gradient(f) // error: cannot differentiate opaque closures
      }
  }
  ```
  Similarly, with a `@differentiable` declaration attribute:

  ```swift
  protocol P {
      // A declaration attribute, which forces implementors of this function
      // to provide a differentiable body, so that the compiler emits a binary
      // representation that contains function pointer(s) for its Jacobian.
      @differentiable
      func f(_ x: Float) -> Float
  }

  extension P {
      func g() {
          _ = #gradient(f) // ok! `f` has type `@differentiable (Float) -> Float`
      }
  }
  ```
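To make both use cases concrete, here is a minimal sketch in today's Swift. It assumes a purely hypothetical model in which a `@differentiable (Float) -> Float` value is a pair of the original function and a function computing its derivative; `DifferentiableFunction`, `gradient(of:)`, and `SinCos` are illustrative names, not part of the tensorflow branch, and the compiler's real representation (e.g. Jacobian-related function pointers) may well differ. The point is that the caller only needs what travels with the value, never the body.

```swift
import Foundation

// Hypothetical model (not the compiler's actual ABI): a value of type
// `@differentiable (Float) -> Float` pairs the original function with a
// function that computes its derivative.
struct DifferentiableFunction {
    let original: (Float) -> Float
    let derivative: (Float) -> Float
}

// Hypothetical stand-in for `#gradient`: it reads the bundled derivative
// and never needs to see the function body.
func gradient(of f: DifferentiableFunction) -> (Float) -> Float {
    return f.derivative
}

// Mirrors `func foo(f: @differentiable (Float) -> Float)` from the
// opaque-closure example; returns the derivative at 0 so the result is observable.
func foo(f: DifferentiableFunction) -> Float {
    return gradient(of: f)(0)
}

// Mirrors a protocol requirement marked `@differentiable`: conforming
// types must supply the derivative alongside the function itself.
protocol P {
    var f: DifferentiableFunction { get }
}

extension P {
    // Works through the protocol, without knowing the conforming type.
    func g(at x: Float) -> Float {
        return gradient(of: f)(x)
    }
}

struct SinCos: P {
    // d/dx (sin x * cos x) = cos^2 x - sin^2 x
    let f = DifferentiableFunction(
        original: { x in sin(x) * cos(x) },
        derivative: { x in cos(x) * cos(x) - sin(x) * sin(x) }
    )
}
```

For example, `SinCos().g(at: 0)` evaluates the bundled derivative at 0, giving 1.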
In the simple model illustrated above, `@differentiable` could simply be stored in `ExtInfo` as one bit. However, it's more complicated than that: differentiability does not always apply to all function parameters.
A function `(T0, T1) -> U` can be differentiable with respect to only a subset of its parameters, e.g. the first parameter but not the second. In that case, we need to store in the function type a bit mask of the parameters it is differentiable with respect to (or, equivalently, not differentiable with respect to). The type syntax for that could look like:
```swift
let f: @differentiable (T0, @nodiff T1) -> U
```
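As a rough illustration of what such a per-parameter mask would carry, here is a minimal sketch using one bit per parameter (bit `i` set when the function is differentiable with respect to parameter `i`). `DifferentiabilityParameterMask` and `FunctionTypeSketch` are hypothetical names, and the fixed-width `UInt64` is purely a simplification for the sketch; it is not how `ExtInfo` or any compiler type is actually laid out.

```swift
// Hypothetical sketch: bit i is set when the function is differentiable
// with respect to parameter i. Assumes at most 64 parameters.
struct DifferentiabilityParameterMask: Equatable {
    private(set) var bits: UInt64

    init(parameterCount: Int, nondifferentiable: Set<Int> = []) {
        precondition((0...64).contains(parameterCount), "sketch supports up to 64 parameters")
        var bits: UInt64 = 0
        for i in 0..<parameterCount where !nondifferentiable.contains(i) {
            bits |= 1 << UInt64(i)
        }
        self.bits = bits
    }

    func isDifferentiable(withRespectTo index: Int) -> Bool {
        return (bits >> UInt64(index)) & 1 == 1
    }
}

// Hypothetical function type carrying the mask; `nil` means the type is
// not `@differentiable` at all.
struct FunctionTypeSketch: CustomStringConvertible {
    var parameterTypes: [String]
    var resultType: String
    var differentiability: DifferentiabilityParameterMask?

    // Renders the type in the proposed syntax, marking excluded parameters `@nodiff`.
    var description: String {
        let params = parameterTypes.enumerated().map { pair -> String in
            if let mask = differentiability, !mask.isDifferentiable(withRespectTo: pair.offset) {
                return "@nodiff " + pair.element
            }
            return pair.element
        }.joined(separator: ", ")
        let prefix = differentiability == nil ? "" : "@differentiable "
        return prefix + "(" + params + ") -> " + resultType
    }
}
```

With `parameterCount: 2, nondifferentiable: [1]`, the mask is `0b01`, and the sketch renders `@differentiable (T0, @nodiff T1) -> U`. Unlike a single `ExtInfo` bit, the information grows with the number of parameters.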
As such, differentiability is not like a traditional function representation that fits in a single `ExtInfo` bit: it carries more information. What would be the best representation for it?