How to best maximally reabstract a returned function type in SILGen

Hi SIL experts,

How can I reabstract values from type A to type B during SILGen?

A: $(Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
B: $(Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)

The main difficulty: SILGen reabstraction utilities operate on AST types and abstraction patterns, not SIL types. B has a corresponding AST type and abstraction pattern, but they are simply (Float) -> (Float, (Float) -> Float), which say nothing about the returned function being indirect.

I want to force the returned function type to be maximally indirect: (@in_guaranteed Float) -> @out Float. What is the best way to do this?


My ideas:

  • An opaque abstraction pattern is not appropriate because it makes the entire function type maximally indirect, not just the returned function type: $(@in_guaranteed Float) -> @out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float).
    • We could adapt the rest of our infrastructure to work with fully maximally indirect function types, but I'd like to avoid this if possible, for efficiency and simplicity.
    • The invariants are that the arguments of B can have the same abstraction as the arguments of A, and the returned tuple type of B can always be direct.
  • We could manually construct a generic function type <T>(Float) -> (Float, (T) -> T) and use that as an abstraction pattern for type B, but this feels quite hacky.
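To pin down the A → B mapping, here is a minimal C++ sketch over a toy model of SIL function types. `Conv`, `Value`, `FnType`, `maximallyIndirect` and `indirectReturnedFunctions` are hypothetical names made up for illustration; they are not the Swift compiler's `SILFunctionType` API. The sketch rewrites only the function-typed results and leaves the outer parameters and the tuple result direct, which is exactly the shape of B.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy model only: these are NOT the Swift compiler's SIL types.
enum class Conv { Direct, Owned, InGuaranteed, Out };

struct FnType;

// A parameter or result: a nominal type such as "Float", or a
// (thick, @callee_guaranteed) function type.
struct Value {
  Conv conv;
  std::string nominal;         // used when fn is null
  std::shared_ptr<FnType> fn;  // non-null for function-typed values
};

struct FnType {
  std::vector<Value> params;
  std::vector<Value> results;  // more than one result prints as a tuple
};

Value scalar(Conv c, std::string name) { return Value{c, std::move(name), nullptr}; }
Value fnValue(Conv c, FnType f) {
  return Value{c, "", std::make_shared<FnType>(std::move(f))};
}

std::string print(const FnType &f);

std::string print(const Value &v) {
  std::string prefix;
  switch (v.conv) {
  case Conv::Direct: break;
  case Conv::Owned: prefix = "@owned "; break;
  case Conv::InGuaranteed: prefix = "@in_guaranteed "; break;
  case Conv::Out: prefix = "@out "; break;
  }
  if (v.fn) return prefix + "@callee_guaranteed " + print(*v.fn);
  return prefix + v.nominal;
}

std::string printList(const std::vector<Value> &vs) {
  std::string s;
  for (const auto &v : vs) {
    if (!s.empty()) s += ", ";
    s += print(v);
  }
  return s;
}

std::string print(const FnType &f) {
  std::string results = printList(f.results);
  if (f.results.size() != 1) results = "(" + results + ")";
  return "(" + printList(f.params) + ") -> " + results;
}

// Shallow: every parameter becomes @in_guaranteed, every result @out.
FnType maximallyIndirect(FnType f) {
  for (auto &p : f.params) p.conv = Conv::InGuaranteed;
  for (auto &r : f.results) r.conv = Conv::Out;
  return f;
}

// A -> B: rewrite only the function-typed results; the outer parameters
// and the non-function results keep their conventions.
FnType indirectReturnedFunctions(FnType f) {
  for (auto &r : f.results)
    if (r.fn) r.fn = std::make_shared<FnType>(maximallyIndirect(*r.fn));
  return f;
}
```

In SILGen the conversion would still have to be expressed through AST types and abstraction patterns, as described above; the sketch only fixes the target shape.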

Context: I'm working on the differentiable programming project.

  • Derivative functions return a "linear approximation function".
  • We want to standardize all "returned linear approximation function types" to be maximally indirect in SIL to fix a sleuth of abstraction mismatch bugs - see issue description here for more details.

Any help would be appreciated!

cc @Joe_Groff @Slava_Pestov @John_McCall

("slew", not "sleuth")

I'm sorry, but that sounds like a cop-out to me. It sounds like you're having trouble applying abstraction patterns correctly, so you decided to try to define them away, and now, unsurprisingly, you're having trouble with even the first-order consequences of that decision.

Abstraction patterns can be structurally opaque; it's not all or nothing. You could construct an abstraction pattern by looking through function types in the original type and filling in the arguments and returns that aren't themselves function types with opaque patterns in those positions. Alternatively, maybe you want to introduce a new kind of abstraction pattern that has the abstraction behavior you want, similar to what we do with abstraction patterns for C/ObjC method signatures.
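A minimal sketch of the first suggestion, in a hypothetical toy model: `Type`, `Pattern`, `structurallyOpaque`, `opaqueReturnedFunctions` and `lowerSignature` are invented names, not the real `swift::Lowering::AbstractionPattern` API. The pattern keeps function structure and puts an opaque pattern at every non-function position; lowering then passes opaque positions indirectly. Applying it only to function-typed results keeps the outer signature direct, while applying it to the whole type is still not the same as a fully opaque pattern.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy AST type (hypothetical): nominal if `nominal` is non-empty,
// otherwise a function type with `params` and `results`.
struct Type {
  std::string nominal;
  std::vector<Type> params, results;
  bool isFunction() const { return nominal.empty(); }
};

// Toy abstraction pattern: Opaque (maximally abstract), Concrete (use the
// type's natural abstraction), or Function with per-position sub-patterns.
struct Pattern {
  enum Kind { Opaque, Concrete, Function } kind;
  std::vector<Pattern> params, results;
};

// Look through function types; every non-function position becomes opaque.
Pattern structurallyOpaque(const Type &t) {
  if (!t.isFunction()) return Pattern{Pattern::Opaque, {}, {}};
  Pattern p{Pattern::Function, {}, {}};
  for (const auto &x : t.params) p.params.push_back(structurallyOpaque(x));
  for (const auto &x : t.results) p.results.push_back(structurallyOpaque(x));
  return p;
}

// Keep the outer signature concrete; only function-typed results become
// structurally opaque.
Pattern opaqueReturnedFunctions(const Type &t) {
  Pattern p{Pattern::Function, {}, {}};
  p.params.assign(t.params.size(), Pattern{Pattern::Concrete, {}, {}});
  for (const auto &r : t.results)
    p.results.push_back(r.isFunction() ? structurallyOpaque(r)
                                       : Pattern{Pattern::Concrete, {}, {}});
  return p;
}

std::string lowerSignature(const Type &t, const Pattern &p);

// Nominal types at opaque positions are passed indirectly.
// (Ownership conventions of function values are omitted in this toy.)
std::string lowerValue(const Type &t, const Pattern &p, bool isResult) {
  if (t.isFunction()) return "@callee_guaranteed " + lowerSignature(t, p);
  if (p.kind != Pattern::Opaque) return t.nominal;
  return (isResult ? "@out " : "@in_guaranteed ") + t.nominal;
}

std::string lowerList(const std::vector<Type> &ts,
                      const std::vector<Pattern> &ps, bool isResult) {
  std::string s;
  for (std::size_t i = 0; i < ts.size(); ++i)
    s += (i ? ", " : "") + lowerValue(ts[i], ps[i], isResult);
  return s;
}

std::string lowerSignature(const Type &t, const Pattern &p) {
  Pattern sub = p;
  if (p.kind != Pattern::Function) {
    // Simplification: a non-function pattern at a function position is
    // propagated to every parameter and result.
    sub.kind = Pattern::Function;
    sub.params.assign(t.params.size(), Pattern{p.kind, {}, {}});
    sub.results.assign(t.results.size(), Pattern{p.kind, {}, {}});
  }
  std::string results = lowerList(t.results, sub.results, true);
  if (t.results.size() != 1) results = "(" + results + ")";
  return "(" + lowerList(t.params, sub.params, false) + ") -> " + results;
}
```

For `(Float) -> (Float, (Float) -> Float)`, `opaqueReturnedFunctions` lowers to `(Float) -> (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)`, and `structurallyOpaque` on the whole type makes each tuple element indirect individually rather than producing a single `@out` tuple.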


Thank you for the responses!

I'm sorry, please let me provide more context. The core issue is not reabstraction, but implementing derivative function type calculation in a way that is robust against function type transformations like LoadableByAddress.

Please see "The crux of the issue" for a summary of the core problem.

Context


Here are the typing rules for derivative functions:

// Given a function type `(T0, ...) -> U`,
// (where `T0`, ..., `U` all conform to the `Differentiable` protocol)
//
// The derivative function type is:
//                            
//  (T0, ...)      ->  (U,    (T0.TangentVector, ...) -> (U.TangentVector))
//   ^                  ^      ^~~~~~~~~~~~~~~~~~~~~      ^~~~~~~~~~~~~~~
//  original args   result      derivative wrt args    derivative wrt result
//
// The derivative function returns a tuple of:
// - The original result of type `U`.
// - A "linear approximation" function that takes the derivatives with respect
//   to the arguments and returns the derivative with respect to the result.
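The typing rule can be written as a tiny C++ function over type names. This assumes a hypothetical string spelling in which the tangent type of `X` is `X.TangentVector`, and it differentiates with respect to every parameter; it deliberately ignores SIL conventions, which are the subject of the rest of this post.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical spelling: the tangent type of X is written "X.TangentVector".
std::string tangent(const std::string &t) { return t + ".TangentVector"; }

// Comma-separated list in parentheses, e.g. "(T0, T1)".
std::string parenList(const std::vector<std::string> &ts) {
  std::string s = "(";
  for (std::size_t i = 0; i < ts.size(); ++i) s += (i ? ", " : "") + ts[i];
  return s + ")";
}

// (T0, ...) -> U  becomes  (T0, ...) -> (U, (T0.TangentVector, ...) -> U.TangentVector),
// differentiating with respect to every parameter.
std::string derivativeType(const std::vector<std::string> &params,
                           const std::string &result) {
  std::vector<std::string> tangents;
  for (const auto &p : params) tangents.push_back(tangent(p));
  std::string linear = parenList(tangents) + " -> " + tangent(result);
  return parenList(params) + " -> (" + result + ", " + linear + ")";
}
```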

In SIL, there are also parameter/result conventions: parameters and results may be direct or indirect.
Currently, the derivative function type calculation logic uses the parameter/result conventions of the original function type to compute the parameter/result conventions of the returned linear approximation function:

Original type: (@guaranteed Foo) -> @owned Foo
Derivative type: (@guaranteed Foo) -> (
   @owned Foo,
   @owned (@guaranteed Foo.TangentVector) -> @owned Foo.TangentVector
           ^~~~~~~~~~~                       ^~~~~~
// `@guaranteed Foo.TangentVector` because original argument `Foo` is `@guaranteed`.
// `@owned Foo.TangentVector` because original result `Foo` is `@owned`.

// However, `Foo` and `Foo.TangentVector` are unrelated types.
// Copying the convention may not make sense - it also tightly couples the
// original and derivative types.
)

To support differentiable functions in SIL, we added the autodiff_function SIL instruction: it takes SIL values representing an original function and a derivative function and returns a @differentiable function-typed value. The actual type of the derivative function value must match the expected derivative function type (which is computed from the original function type as shown above).

// Slightly simplified SIL for brevity.
%orig_fn = function_ref @orig_fn : $(@guaranteed T) -> @owned T
%jvp_fn = function_ref @jvp_fn : $(@guaranteed T) -> (@owned T, (@guaranteed T.TangentVector) -> @owned T.TangentVector)

// %diff_fn: $@differentiable (T) -> T
%diff_fn = autodiff_function [wrt 0] %orig_fn : $(@guaranteed T) -> @owned T
  with {%jvp_fn : $(@guaranteed T) -> (@owned T, (@guaranteed T.TangentVector) -> @owned T.TangentVector)}

Transformations like LoadableByAddress rewrite function types; this breaks invariants currently required by the autodiff_function instruction.

For example, Foo may be a large loadable type while Foo.TangentVector is not. LoadableByAddress rewrites Foo-typed arguments to become @in_constant Foo while leaving Foo.TangentVector-typed arguments alone.

This leads to a verification error: autodiff_function expects the derivative function to return a linear approximation function taking @in_constant Foo.TangentVector, but the actual returned linear approximation function takes Foo.TangentVector directly.
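The failure can be reproduced in a hypothetical toy model: the expected linear-function parameter conventions are copied from the original parameters, and a toy "LBA" rewrites only large types to `@in_constant`, one function at a time. `Param`, `expectedDerivative`, `lba` and the `isLarge` set are invented for illustration; this is not the real LoadableByAddress pass.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy model only; conventions are spelled as SIL attributes ("" = direct).
struct Param { std::string type, conv; };
bool operator==(const Param &a, const Param &b) {
  return a.type == b.type && a.conv == b.conv;
}

struct Original { std::vector<Param> params; };
struct Derivative {
  std::vector<Param> params;        // same as the original parameters
  std::vector<Param> linearParams;  // parameters of the returned linear function
};

// Hypothetical stand-in for "LBA considers this type large and loadable".
bool isLarge(const std::string &t) {
  static const std::set<std::string> large = {"Foo"};
  return large.count(t) != 0;
}

// Current (coupled) rule: linear parameter conventions are copied from the
// original parameter conventions.
Derivative expectedDerivative(const Original &o) {
  Derivative d{o.params, {}};
  for (const auto &p : o.params)
    d.linearParams.push_back({p.type + ".TangentVector", p.conv});
  return d;
}

// Toy LBA: each large parameter becomes @in_constant; others are untouched.
std::vector<Param> lba(std::vector<Param> ps) {
  for (auto &p : ps)
    if (isLarge(p.type)) p.conv = "@in_constant";
  return ps;
}
```

Before the pass, the JVP matches `expectedDerivative(orig)` by construction. Afterwards both functions take `@in_constant Foo`, but the expected type recomputed from the rewritten original wants `@in_constant Foo.TangentVector` while the rewritten JVP still has a direct `Foo.TangentVector`, which is the mismatch in the verifier output.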

SIL verification failed: JVP type does not match expected JVP type
  $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float)
  $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (@in_constant Foo.TangentVector) -> Float)

Verifying instruction:
     %21 = thin_to_thick_function %20 : $@convention(thin) (@in_constant Foo) -> Float to $@callee_guaranteed (@in_constant Foo) -> Float // user: %26
     %23 = thin_to_thick_function %22 : $@convention(thin) (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float) to $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float) // user: %26
->   %26 = autodiff_function [wrt 0] [order 1] %21 : $@callee_guaranteed (@in_constant Foo) -> Float with {%23 : $@callee_guaranteed (@in_constant Foo) -> (Float, @owned @callee_guaranteed (Foo.TangentVector) -> Float)}

The crux of the issue

The crux of the issue is that derivative function types are tightly coupled with original function types, specifically the original parameter/result conventions.

The autodiff_function instruction requires the type of the derivative function operand to exactly match the expected derivative function type computed from the original function type. Currently, this requirement is easily broken by LoadableByAddress, which does not ensure that the original and derivative function operands of autodiff_function remain compatible after transformation.


Solution ideas

I think there are a few different flavors of solutions:

  • Cop-out: disable LoadableByAddress for functions that become arguments to autodiff_function instructions.

    • There may be implementation challenges with selectively disabling LoadableByAddress. This approach doesn't scale to other transformations.
  • Heavyweight: create a dedicated @differentiable function type that stores derivative function type information.

    • Currently, @differentiable is simply represented as a bit in SILFunctionType::ExtInfo. It's possible to create a new class SILDifferentiableFunctionType that stores derivative function types along with the original function type. However, this is heavyweight and affects much of the codebase - @rxwei started prototyping SILDifferentiableFunctionType but chose not to pursue it further.
  • Standardization/simplification: change derivative function type calculation to not depend on the parameter/result conventions of the original function type.

    • The idea is to decouple the types of "returned linear approximation functions" from original function parameter/result conventions. Returned linear approximation functions are standardized to be maximally indirect: all parameters are @in_guaranteed and all results are @out.
    • We investigated this approach - it required some ad-hoc SILGen thunking and fixed LoadableByAddress + autodiff_function for most cases. However, one unhandled case (the only known unhandled case) is when the original function returns a large loadable type:
    // The returned linear approximation function types match (they are maximally indirect).
    // The only difference is the result convention of the original result.
    SIL verification failed: JVP type does not match expected JVP type
      $@convention(method) (...) -> (@owned RNNCellOutput<State>, @owned @callee_guaranteed (@in_guaranteed LSTMCell<τ_0_0>.TangentVector) -> @out RNNCellOutput<State>.TangentVector)
      $@convention(method) (...) -> (@out RNNCellOutput<State>, @owned @callee_guaranteed (@in_guaranteed LSTMCell<τ_0_0>.TangentVector) -> @out RNNCellOutput<State>.TangentVector)
    
    • It seems that LBA does not rewrite functions that return more than one result; i.e. it does not peer through tuple results to rewrite large loadable types. This means that LBA transforms the result of the original function to be indirect, but does not transform the "original function result" of the derivative function.
    • I'm not sure whether "LBA not handling multiple results" is intentional, or whether fundamental limitations prevent peering through tuple results. I don't believe there are fundamental limitations, but there is a design decision about how to handle tuples of large loadable types: should LBA transform a (Large1, Large2) result to (@out Large1, @out Large2) or to @out (Large1, Large2)? Changing LBA to peer through tuple results should fix LBA + autodiff_function for all known cases, though the implementation difficulty is unclear (LBA seems to assume a single result in many places).

Best solution?

If someone has ideas for a best final solution, please reply. :slightly_smiling_face:

The standardization/simplification approach above seems ideal to me. To handle the last unhandled case (when original function returns a large loadable type), I think there are two options:

  • Further decouple derivative type from original type until original parameter/result conventions do not matter at all. This approach makes autodiff_function robust against any function type transformation. There are a few options for standardizing derivative function types, differing in indirectness:
    // Make derivative functions:
    // - Always return an indirect original result.
    //   This is the minimal necessary change.
    (T) -> (@out U, (@in_guaranteed T.TangentVector) -> @out U.TangentVector)
    
    // - Return a fully indirect tuple.
    //   I don't think this has any upside over the approach above.
    (T) -> @out (U, (@in_guaranteed T.TangentVector) -> @out U.TangentVector)
    
    // - Fully opaque, including original arguments.
    //
    //   This might simplify SILGen (enabling us to use SILGen reabstraction
    //   thunking infrastructure with an opaque output abstraction pattern),
    //   though SILGen may be simplifiable via other means, as Joe hinted at
    //   above.
    //
    //   However, this complicates the SIL differentiation transform
    //   considerably more than the approaches above.
    (@in_guaranteed T) -> (@out U, (@in_guaranteed T.TangentVector) -> @out U.TangentVector)
    
  • Change LBA to peer through tuple results, transforming (T) -> (Large, ...) to (T) -> (@out Large, ...). This fixes the autodiff_function instruction in an ad-hoc way: original and derivative functions will be type-transformed consistently. This approach seems less ideal because derivative function types are still coupled with the original function's result convention; as long as this coupling exists, every function type transformation must be able to transform original and derivative functions consistently so that autodiff_function is not broken.
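The two options can be compared in a hypothetical toy model of result conventions. `Result`, `coupledDerivative`, `standardizedDerivative`, `lbaSingleResultOnly`, `lbaPeering` and the `isLarge` set are invented for illustration and only mimic the LBA behavior described above; they are not the real pass. Option 1 fixes the derivative's original-result convention to `@out`, so a pass that only rewrites single direct results leaves it alone; option 2 teaches the toy pass to peer through tuple results element by element.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy model (hypothetical): a result type and its convention spelling.
struct Result { std::string type, conv; };
bool operator==(const Result &a, const Result &b) {
  return a.type == b.type && a.conv == b.conv;
}
using Results = std::vector<Result>;

bool isDirect(const Result &r) { return r.conv != "@out"; }

// Stand-in for LBA's "large loadable type" heuristic.
bool isLarge(const std::string &t) {
  static const std::set<std::string> large = {"RNNCellOutput"};
  return large.count(t) != 0;
}

// The returned linear approximation function, already maximally indirect.
const Result kLinear{
    "@callee_guaranteed (@in_guaranteed T.TangentVector) -> @out U.TangentVector",
    "@owned"};

// Coupled rule: the derivative's original result copies the original convention.
Results coupledDerivative(const Result &orig) { return {orig, kLinear}; }

// Option 1: the derivative always returns its original result indirectly.
Results standardizedDerivative(const Result &orig) {
  return {Result{orig.type, "@out"}, kLinear};
}

// Toy LBA mirroring the behavior described above: only a function with a
// single, direct, large result has that result rewritten to @out.
Results lbaSingleResultOnly(Results rs) {
  if (rs.size() == 1 && isDirect(rs[0]) && isLarge(rs[0].type))
    rs[0].conv = "@out";
  return rs;
}

// Option 2: peer through tuple results, rewriting each large element.
Results lbaPeering(Results rs) {
  for (auto &r : rs)
    if (isDirect(r) && isLarge(r.type)) r.conv = "@out";
  return rs;
}
```

With the coupled rule and the single-result pass, the original becomes `@out RNNCellOutput` while the derivative keeps `@owned RNNCellOutput`, reproducing the mismatch quoted earlier; both options make the rewritten original and derivative agree. The peering variant here picks the per-element `(@out Large1, @out Large2)` design, which is only one of the two choices discussed above.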

Any ideas and suggestions appreciated! We're far from familiar with all parts of the codebase so it's not clear which solution holistically fits best into Swift infrastructure today.

Let me explain the original motivation for reabstraction: once derivative function types are standardized to have particular parameter/result conventions (e.g. @out for original result, @in_guaranteed parameters and @out result for returned linear approximation function), user-defined derivative functions (which may be "more direct" in abstraction) must be thunked during SILGen to these expected derivative SIL function types.

Oh thanks, I didn't know this! Looking at AbstractionPattern.h, I thought that it's only possible (with the existing abstraction pattern kinds) to form a function type AbstractionPattern via a fully concrete function type, not something like:

(T, U, V) -> (R, (<opaque>) -> (<opaque>, <opaque>, ...))
(T, U, V) -> (R, (<opaque>, <opaque>, ...) -> <opaque>)
// Is it possible to construct these abstraction patterns?
// If so, could you please point to the APIs?

Interesting, I hadn't thought hard about this approach until now.

A DerivativeFunctionType abstraction pattern could store the original function type. getFunctionParamType would not need a custom implementation, as derivative function arguments match original function arguments.

getFunctionResultType would need a custom implementation returning a pattern representing (R, (<opaque>) -> (<opaque>, <opaque>, ...)). Supporting that seems to require more abstraction pattern kinds (DerivativeFunctionType, etc.)? Please correct me if I'm wrong.
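A rough sketch of how such a pattern kind might answer the two queries, in a hypothetical toy model rather than the real `AbstractionPattern` class: `DerivativeFunctionPattern`, `FnType`, `Pattern` and these simplified `getFunctionParamType`/`getFunctionResultType` signatures are invented mirrors of the names discussed above, and the real methods differ. Parameter queries delegate to the stored original type; the result query synthesizes a tuple pattern whose first element is the original result and whose second is a structurally opaque function pattern, here in the `(R, (<opaque>, <opaque>, ...) -> <opaque>)` shape with one opaque parameter per original parameter.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy AST function type: parameter and result type names.
struct FnType {
  std::vector<std::string> params;
  std::string result;
};

// Toy pattern tree (hypothetical, not swift::Lowering::AbstractionPattern).
struct Pattern {
  enum Kind { Opaque, Concrete, Tuple, Function } kind;
  std::string type;               // for Concrete
  std::vector<Pattern> elements;  // Tuple elements or Function parameters
  std::vector<Pattern> results;   // Function results
};

// Hypothetical "derivative function" pattern that only stores the original type.
struct DerivativeFunctionPattern {
  FnType original;

  // Derivative parameters are the original parameters, so delegate.
  Pattern getFunctionParamType(std::size_t i) const {
    return Pattern{Pattern::Concrete, original.params[i], {}, {}};
  }

  // (R, (<opaque>, ...) -> <opaque>): the original result stays concrete and
  // the linear approximation function is structurally opaque.
  Pattern getFunctionResultType() const {
    Pattern linear{Pattern::Function, "", {}, {}};
    for (std::size_t i = 0; i < original.params.size(); ++i)
      linear.elements.push_back(Pattern{Pattern::Opaque, "", {}, {}});
    linear.results.push_back(Pattern{Pattern::Opaque, "", {}, {}});
    Pattern r{Pattern::Concrete, original.result, {}, {}};
    return Pattern{Pattern::Tuple, "", {r, linear}, {}};
  }
};

std::string print(const Pattern &p);

std::string printList(const std::vector<Pattern> &ps) {
  std::string s;
  for (std::size_t i = 0; i < ps.size(); ++i) s += (i ? ", " : "") + print(ps[i]);
  return s;
}

std::string print(const Pattern &p) {
  switch (p.kind) {
  case Pattern::Opaque: return "<opaque>";
  case Pattern::Concrete: return p.type;
  case Pattern::Tuple: return "(" + printList(p.elements) + ")";
  case Pattern::Function: {
    std::string res = printList(p.results);
    if (p.results.size() != 1) res = "(" + res + ")";
    return "(" + printList(p.elements) + ") -> " + res;
  }
  }
  return "";
}
```

In this toy only the result query needs new logic, and the nested tuple and function patterns it returns are ordinary kinds; whether the real class can express them without further kinds is the open question above.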
