Trouble with LoadableByAddress pass and Autodiff

[SR-9849] LoadableByAddress fails to handle functions that return a closure in a tuple (Issue #52260, apple/swift) illustrates a problem that we ran into while figuring out how the autodiff feature interoperates with the LoadableByAddress IRGen pass.

My understanding of this pass is that it rewrites function types so that large loadable values are passed by address.

For context, a differentiable function (original) also binds additional associated functions (like vjp) that can be extracted at runtime.

original: (T) -> R
vjp: (T) -> (R, (R.CotangentVector) -> T.CotangentVector)

Thus far, we've been updating the autodiff functions identically whenever the original function goes through some transformation (like abstraction thunks). However, because their types do not match, original and vjp can be transformed differently in the LoadableByAddress IRGen pass. As a result, our tensorflow branch version of LoadableByAddress tries to cast between functions that pass values via different conventions.

Any thoughts on the proper resolution would be appreciated.

To give a bit more background, here's a pointer to the autodiff_function and autodiff_function_extract instructions in the SIL language manual: SIL.rst on the tensorflow branch of apple/swift.

At a high level, the autodiff_function instruction turns a normal function (T) -> U into a differentiable function @differentiable (T) -> U. At the SIL level, this is achieved by forming a bundle of the original function and two derivative functions, which are called "JVP" and "VJP".

autodiff_function [wrt 0] [order 1] %original : $(T) -> U 
    with {%jvp : $(T) -> (U, (T.A) -> U.A), %vjp : $(T) -> (U, (U.B) -> T.B)}

The autodiff_function_extract instruction takes a @differentiable function and extracts the original or one of the derivative functions.

autodiff_function_extract [original] [order 1] %f : $@differentiable (T) -> U

Since the types of derivative functions are opaque in the @differentiable (T) -> U type, we must be able to infer the expected type of derivative functions from the original function type. Thus derivative functions' types must align directly with the original function in an autodiff_function instruction.

However, what makes this difficult is that the derivative functions have a few extra types involved: associated types T.A, T.B, U.A and U.B. LoadableByAddress doesn't respect this type correspondence requirement for autodiff_function's original and derivative functions yet. So if T happens to be a large loadable type but T.A isn't, LoadableByAddress would generate the following incorrect SIL, where we expect the T.A argument to have the same storage type.

// Wrong! We expect `T.A` to also become indirect.
autodiff_function [wrt 0] [order 1] %original : $(@in_guaranteed T) -> U 
    with {%jvp : $(@in_guaranteed T) -> (U, (T.A) -> U.A), %vjp : $(@in_guaranteed T) -> (U, (U.B) -> T.B)}
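
To make the mismatch concrete, here is a toy C++ model. It is not the real pass: ToyType, passedIndirectly, and the field-count threshold are all hypothetical names and assumptions. The only point is that a size check applied to each type separately can pick different conventions for T and T.A, while autodiff_function expects them to agree.

```cpp
#include <string>

// Hypothetical stand-in for a lowered type: a name and a count of
// integer-sized fields. Not a real compiler data structure.
struct ToyType {
    std::string name;
    int fields;
};

// Assumed rule: a type is passed indirectly once it reaches `threshold` fields.
bool passedIndirectly(const ToyType &type, int threshold) {
    return type.fields >= threshold;
}

// What autodiff_function expects in this model: the associated type follows
// the convention chosen for the original function's parameter type.
bool expectedAssociatedConvention(const ToyType &original, int threshold) {
    return passedIndirectly(original, threshold);
}

// True when a per-type decision leaves the original type and its associated
// type with different conventions.
bool conventionsDisagree(const ToyType &original, const ToyType &associated,
                         int threshold) {
    return expectedAssociatedConvention(original, threshold) !=
           passedIndirectly(associated, threshold);
}
```

With an assumed threshold of 4, a 5-field T and a 2-field T.A disagree, which corresponds to the incorrect SIL above.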

As you can see, autodiff_function is like a function conversion, except that it involves more function operands and requires the types of those operands to correspond. What would be a good way to fix this?

@shajrawi

Hmm, off the top of my head, I can't think of an easy way to do that, but here are a couple of options that should work (but would require modifying the pass):

  1. We decide whether a type should become indirect (that is, whether it is large loadable) in static bool isLargeLoadableType, specifically via nativeSchemaOrigParam.requiresIndirect(). requiresIndirect() is based on the calling convention and target triple: for example, armv7 would be 3 or more fields, while arm64 would be 4 or more integer-sized fields. However, nothing stops you from passing a smaller type indirectly, or a larger type directly, so taking the associated types' requirements into consideration (everything needs to be direct, or everything indirect) would work.

  2. The decision whether we can change a function signature is made in static bool modifiableFunction(CanSILFunctionType funcType). If we are calling C, for example, we bail and leave the signature as-is. I could see a SIL-level attribute being added for this use case, which would make us bail on transforming the function.
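
The two options above can be sketched in a toy C++ model. Everything here is an assumption made for illustration: GroupType, groupIsIndirect, ToyFunction, modifiableFunctionToy, and the hasKeepSignatureAttribute flag are hypothetical and do not exist in the pass. The real isLargeLoadableType and modifiableFunction operate on SIL types, not on field counts.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical model of one type in a group made of a type and its
// associated types (for example T, T.A and T.B).
struct GroupType {
    int fields;  // number of integer-sized fields (assumed size measure)
};

// Per-type rule, standing in for the role requiresIndirect() plays.
bool requiresIndirectToy(const GroupType &type, int threshold) {
    return type.fields >= threshold;
}

// Option 1: make one decision for the whole group. If any member needs to
// be indirect, every member is treated as indirect.
bool groupIsIndirect(const std::vector<GroupType> &group, int threshold) {
    return std::any_of(group.begin(), group.end(),
                       [threshold](const GroupType &type) {
                           return requiresIndirectToy(type, threshold);
                       });
}

// Hypothetical function description for option 2.
struct ToyFunction {
    bool isCCallingConvention;       // existing reason to bail
    bool hasKeepSignatureAttribute;  // hypothetical SIL-level attribute
};

// Option 2: leave the signature untouched when the function is marked.
bool modifiableFunctionToy(const ToyFunction &fn) {
    return !fn.isCCallingConvention && !fn.hasKeepSignatureAttribute;
}
```

Under option 1, a group containing a 5-field T and a 2-field T.A is lowered indirectly as a whole with an assumed threshold of 4. Under option 2, a function carrying the hypothetical attribute is skipped in the same way a C function is.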