Thank you for the responses!
I'm sorry, please let me provide more context. The core issue is not reabstraction, but implementing derivative function type calculation in a way that is robust against function type transformations like LoadableByAddress.
Please see "The crux of the issue" for a summary of the core problem.
## Context
Here are the typing rules for derivative functions:
```
// Given a function type `(T0, ...) -> U`,
// (where `T0`, ..., `U` all conform to the `Differentiable` protocol)
//
// The derivative function type is:
//
// (T0, ...) -> (U, (T0.TangentVector, ...) -> (U.TangentVector))
//  ^            ^   ^~~~~~~~~~~~~~~~~~~~       ^~~~~~~~~~~~~~
//  original     |   derivative wrt args        derivative wrt result
//  args         result
//
// The derivative function returns a tuple of:
// - The original result of type `U`.
// - A "linear approximation" function that takes derivatives with respect to
//   the arguments and returns the derivative with respect to the result.
```
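As a concrete instance of this rule, here is a hypothetical hand-written example in Swift with `T0 = U = Double` (for `Double`, `TangentVector == Double`); the function names are made up for illustration:

```swift
func square(_ x: Double) -> Double {
    return x * x
}

// Derivative type per the rule above:
// (Double) -> (Double, (Double.TangentVector) -> Double.TangentVector)
func squareDerivative(_ x: Double) -> (Double, (Double) -> Double) {
    // Return the original result together with a linear approximation
    // function mapping a tangent at `x` to a tangent at the result:
    // d/dx (x * x) = 2 * x.
    return (x * x, { dx in 2 * x * dx })
}

let (value, differential) = squareDerivative(3)
// value == 9, differential(1) == 6
```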
In SIL, there are also parameter/result conventions: parameters and results may be direct or indirect.
Currently, the derivative function type calculation logic uses the parameter/result conventions of the original function type to compute the parameter/result conventions of the returned linear approximation function:
```
Original type:   (@guaranteed Foo) -> @owned Foo
Derivative type: (@guaranteed Foo) -> (
  @owned Foo,
  @owned (@guaranteed Foo.TangentVector) -> @owned Foo.TangentVector
          ^~~~~~~~~~~                       ^~~~~~
  // `@guaranteed Foo.TangentVector` because original argument `Foo` is `@guaranteed`.
  // `@owned Foo.TangentVector` because original result `Foo` is `@owned`.
  // However, `Foo` and `Foo.TangentVector` are unrelated types.
  // Copying the convention may not make sense - it also tightly couples the
  // original and derivative types.
)
```
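To spell out the convention-copying behavior described above, here is a self-contained toy sketch; the `Convention`/`Param`/`Result` types and function names are hypothetical stand-ins, not the actual compiler code:

```swift
enum Convention { case guaranteed, owned, inGuaranteed, inConstant, out }

struct Param { var type: String; var convention: Convention }
struct Result { var type: String; var convention: Convention }

// Each linear approximation parameter/result inherits the convention of the
// corresponding original parameter/result - even though `Foo` and
// `Foo.TangentVector` are unrelated types.
func linearApproximationParams(original: [Param]) -> [Param] {
    original.map { Param(type: $0.type + ".TangentVector", convention: $0.convention) }
}

func linearApproximationResults(original: [Result]) -> [Result] {
    original.map { Result(type: $0.type + ".TangentVector", convention: $0.convention) }
}
```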
To support differentiable functions in SIL, we added the `autodiff_function` SIL instruction: it takes SIL values representing an original function and a derivative function, and returns a `@differentiable` function-typed value. The actual type of the derivative function value must match the expected derivative function type (which is computed from the original function type as shown above).
```
// Slightly simplified SIL for brevity.
%orig_fn = function_ref @orig_fn : $(@guaranteed T) -> @owned T
%jvp_fn = function_ref @jvp_fn : $(@guaranteed T) -> (@owned T, (@guaranteed T.TangentVector) -> @owned T.TangentVector)
// %diff_fn: $@differentiable (T) -> T
%diff_fn = autodiff_function [wrt 0] %orig_fn : $(@guaranteed T) -> @owned T
  with {%jvp_fn : $(@guaranteed T) -> (@owned T, (@guaranteed T.TangentVector) -> @owned T.TangentVector)}
```
Transformations like LoadableByAddress rewrite function types; this breaks invariants currently required by the `autodiff_function` instruction. For example, `Foo` may be a large loadable type while `Foo.TangentVector` is not. LoadableByAddress rewrites `Foo`-typed arguments to become `@in_constant Foo` while leaving `Foo.TangentVector`-typed arguments alone. This leads to a verification error: `autodiff_function` expects the derivative function to return a linear approximation function taking `@in_constant Foo.TangentVector`, but the actual returned linear approximation function takes a direct `Foo.TangentVector`.
```
SIL verification failed: JVP type does not match expected JVP type
  $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float)
  $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (@in_constant Foo.TangentVector) -> Float)
Verifying instruction:
     %21 = thin_to_thick_function %20 : $@convention(thin) (@in_constant Foo) -> Float to $@callee_guaranteed (@in_constant Foo) -> Float // user: %26
     %23 = thin_to_thick_function %22 : $@convention(thin) (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float) to $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float) // user: %26
->   %26 = autodiff_function [wrt 0] [order 1] %21 : $@callee_guaranteed (@in_constant Foo) -> Float with {%23 : $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float)}
```
## The crux of the issue
The crux of the issue is that derivative function types are tightly coupled with original function types, specifically the original parameter/result conventions. The `autodiff_function` instruction requires the type of the derivative function operand to exactly match the expected derivative function type computed from the original function type. Currently, this requirement is easily broken by LoadableByAddress, which does not ensure that the original/derivative function operands to `autodiff_function` remain compatible after transformation.
## Solution ideas
I think there are a few different flavors of solutions:
- Cop-out: disable LoadableByAddress for functions that become arguments to `autodiff_function` instructions.
  - There may be implementation challenges with selectively disabling LoadableByAddress. This approach doesn't scale to other transformations.
- Heavyweight: create a dedicated `@differentiable` function type that stores derivative function type information.
  - Currently, `@differentiable` is simply represented as a bit in `SILFunctionType::ExtInfo`. It's possible to create a new class `SILDifferentiableFunctionType` that stores derivative function types along with the original function type. However, this is heavyweight and affects much of the codebase - @rxwei started prototyping `SILDifferentiableFunctionType` but chose not to pursue it further.
- Standardization/simplification: change derivative function type calculation to not depend on the parameter/result conventions of the original function type.
  - The idea is to decouple the types of "returned linear approximation functions" from original function parameter/result conventions. Returned linear approximation functions are standardized to be maximally indirect: all parameters are `@in_guaranteed` and all results are `@out` (a sketch of this standardized calculation appears after this list).
  - We investigated this approach - it required some ad-hoc SILGen thunking and fixed LoadableByAddress + `autodiff_function` for most cases. However, one unhandled case (the only known unhandled case) is when the original function returns a large loadable type:

    ```
    // The returned linear approximation function types match (they are maximally indirect).
    // The only difference is the result convention of the original result.
    SIL verification failed: JVP type does not match expected JVP type
      $@convention(method) (...) -> (@owned RNNCellOutput<State>, @owned @callee_guaranteed (@in_guaranteed LSTMCell<τ_0_0>.TangentVector) -> @out RNNCellOutput<State>.TangentVector)
      $@convention(method) (...) -> (@out RNNCellOutput<State>, @owned @callee_guaranteed (@in_guaranteed LSTMCell<τ_0_0>.TangentVector) -> @out RNNCellOutput<State>.TangentVector)
    ```
  - It seems that LBA does not rewrite functions that return more than one result; i.e. it does not peer through tuple results to rewrite large loadable types. This means that LBA transforms the result of the original function to be indirect, but does not transform the "original function result" of the derivative function.
  - I'm not sure whether "LBA not handling multiple results" is intentional, or whether fundamental limitations prevent peering through tuple results. I don't believe there are fundamental limitations, but there is a design decision about how to handle tuples of large loadable types: should LBA transform a `(Large1, Large2)` result to `(@out Large1, @out Large2)` or to `@out (Large1, Large2)`? Changing LBA to peer through tuple results should fix LBA + `autodiff_function` for all known cases, though the implementation difficulty is not clear (LBA seems to assume "single result" in many places).
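To make the standardization idea concrete, here is a toy sketch in the same style as the earlier one (hypothetical types and names, not compiler code): the original conventions are ignored entirely when building the linear approximation type.

```swift
enum Convention { case guaranteed, owned, inGuaranteed, out }

struct Param { var type: String; var convention: Convention }
struct Result { var type: String; var convention: Convention }

// Maximally indirect: every linear approximation parameter is
// `@in_guaranteed` and every result is `@out`, regardless of the
// original parameter/result conventions.
func standardizedLinearApproximationParams(original: [Param]) -> [Param] {
    original.map { Param(type: $0.type + ".TangentVector", convention: .inGuaranteed) }
}

func standardizedLinearApproximationResults(original: [Result]) -> [Result] {
    original.map { Result(type: $0.type + ".TangentVector", convention: .out) }
}
```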
## Best solution?
If anyone has ideas for the best final solution, please reply.
The standardization/simplification approach above seems ideal to me. To handle the last unhandled case (when original function returns a large loadable type), I think there are two options:
- Further decouple derivative function types from original function types until original parameter/result conventions do not matter at all. This approach makes `autodiff_function` robust against any function type transformation. There are a few options for standardizing derivative function types, differing in indirectness:

  ```
  // Make derivative functions:

  // - Always return an indirect original result.
  //   This is the minimal necessary change.
  (T) -> (@out U, (@in_guaranteed T.TangentVector) -> @out U.TangentVector)

  // - Return a fully indirect tuple.
  //   I don't think this has any upside over the approach above.
  (T) -> @out (U, (@in_guaranteed T.TangentVector) -> @out U.TangentVector)

  // - Fully opaque, including original arguments.
  //
  //   This might simplify SILGen (enabling us to use SILGen reabstraction
  //   thunking infrastructure with an opaque output abstraction pattern),
  //   though SILGen may be simplifiable via other means, as Joe hinted at
  //   above.
  //
  //   However, this complicates the SIL differentiation transform
  //   considerably more than the approaches above.
  (@in_guaranteed T) -> (@out U, (@in_guaranteed T.TangentVector) -> @out U.TangentVector)
  ```
- Change LBA to peer through tuple results, transforming `(T) -> (Large, ...)` to `(T) -> (@out Large, ...)` (a sketch follows this list). This fixes the `autodiff_function` instruction in an ad-hoc way: original functions and derivative functions will be type-transformed consistently. This approach seems less ideal because derivative function types remain coupled with the original function's result conventions; as long as this coupling exists, every function type transformation must be able to transform the original and derivative function operands of `autodiff_function` consistently so that the instruction is not broken.
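For the second option, here is a minimal sketch of what peering through tuple results might look like, again over toy types (hypothetical names; the real LBA operates on `SILFunctionType` in C++ and would need the `(@out Large1, @out Large2)` vs. `@out (Large1, Large2)` design decision resolved - this sketch picks the former):

```swift
enum Convention { case owned, out }

struct Result { var type: String; var convention: Convention }

// Rewrite each large loadable element of a tuple result to be indirect,
// instead of bailing out on functions with more than one result.
func peerThroughTupleResults(
    _ results: [Result],
    isLargeLoadable: (String) -> Bool
) -> [Result] {
    results.map { result in
        isLargeLoadable(result.type)
            ? Result(type: result.type, convention: .out)
            : result
    }
}

// Example: (T) -> (Large, Small) becomes (T) -> (@out Large, Small).
let rewritten = peerThroughTupleResults(
    [Result(type: "Large", convention: .owned),
     Result(type: "Small", convention: .owned)],
    isLargeLoadable: { $0 == "Large" }
)
```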
Any ideas and suggestions are appreciated! We're far from familiar with all parts of the codebase, so it's not clear to us which solution fits best into Swift's infrastructure today.