Opaque abstraction patterns and `@differentiable` function thunking

Hi SIL experts,

I have a question about lowering opaque abstraction patterns. Sorry for the long post: context comes first, and the questions are at the bottom.


A @differentiable function-typed value is a bundle of multiple function-typed values: an original function and multiple derivative functions. @differentiable function-typed values are formed via the autodiff_function instruction.

To thunk a @differentiable function-typed value during SILGen, we extract every element from the bundle, thunk each one individually, and then build a new bundle from the thunked elements. The relevant logic is in SILGenPoly.cpp.
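The extract/thunk/rebundle strategy can be sketched as a toy model in plain Swift. Everything here is hypothetical and is not a SILGen API: `DifferentiableBundle` stands in for a `@differentiable (Float) -> Float` value (original plus JVP only; other derivatives are omitted), and `Box` loosely stands in for passing a value indirectly.

```swift
/// Hypothetical stand-in for a value passed indirectly (loosely like @in/@out).
final class Box<T> {
    let value: T
    init(_ value: T) { self.value = value }
}

/// Hypothetical stand-in for a @differentiable function value: the original
/// function bundled with its JVP. Other derivatives are omitted for brevity.
struct DifferentiableBundle<Arg, Value, Differential> {
    let original: (Arg) -> Value
    let jvp: (Arg) -> (Value, Differential)
}

typealias FloatLinearMap = (Float) -> Float

/// Extract each element, reabstract it on its own, then rebuild the bundle.
func thunkBundle(
    _ bundle: DifferentiableBundle<Float, Float, FloatLinearMap>
) -> DifferentiableBundle<Box<Float>, Box<Float>, FloatLinearMap> {
    let original = bundle.original
    let jvp = bundle.jvp
    return DifferentiableBundle(
        // Thunked original: unbox the argument, box the result.
        original: { arg in Box(original(arg.value)) },
        // Thunked JVP: same argument/value handling; the differential
        // is passed through unchanged.
        jvp: { arg in
            let (value, differential) = jvp(arg.value)
            return (Box(value), differential)
        }
    )
}

// Usage: f(x) = x * x, whose JVP returns (x * x, dx -> 2 * x * dx).
let square = DifferentiableBundle<Float, Float, FloatLinearMap>(
    original: { x in x * x },
    jvp: { x in (x * x, { dx in 2 * x * dx }) }
)
let thunked = thunkBundle(square)
let (y, dy) = thunked.jvp(Box(3))
print(thunked.original(Box(3)).value, y.value, dy(1)) // 9.0 9.0 6.0
```

Note that this sketch chooses to keep the thunked JVP's value and differential as separate results; whether the pair should instead travel as one indirect tuple is exactly the point of contention.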

However, thunking against an opaque output abstraction pattern is problematic: the lowered SIL type of a thunked derivative (1) does not match the derivative type that autodiff_function SIL verification expects (2).

(1) is computed by SGF.getTypeLowering from an opaque abstraction pattern and the derivative's AnyFunctionType:

// assocFnOutputOrigType (AbstractionPattern):
AP::Opaque

// assocFnOutputSubstType (AnyFunctionType):
(function_type escaping
  (input=function_params num_params=1
    (param
      (struct_type decl=Swift.(file).Float)))
  (output=tuple_type num_elements=2
    (tuple_type_elt
      (struct_type decl=Swift.(file).Float))
    (tuple_type_elt
      (function_type escaping
        (input=function_params num_params=1
          (param
            (struct_type decl=Swift.(file).Float)))
        (output=struct_type decl=Swift.(file).Float)))))

// (1): SGF.getTypeLowering(assocFnOutputOrigType, assocFnOutputSubstType)
$@callee_guaranteed (@in_guaranteed Float) -> @out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)

(2) is computed based on the type of the original function operand to autodiff_function:

// Original function type:
@callee_guaranteed (@in_guaranteed Float) -> @out Float

// (2): Derivative (JVP) type:
@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)

(1) and (2) have different result types, so SIL verification fails.

// Differing result types:
// (1):
@out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)
// (2):
(@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)

(1) is maximally indirect, presumably because it is lowered against an opaque abstraction pattern: the entire result tuple is returned through a single @out address. (2) is not maximally indirect: derivative types are currently computed to return (@out T, ...) rather than @out (T, ...).
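The mismatch can be restated as a comparison of two lists of lowered results. This is a toy Swift model: `ResultConvention`, `LoweredResult`, and both functions are hypothetical names that only mimic the printed SIL above; they are not TypeLowering or SILFunctionType APIs.

```swift
/// Hypothetical result conventions (only the two that appear above).
enum ResultConvention { case out, owned }

/// Hypothetical lowered result: a convention plus a printed type.
struct LoweredResult: Equatable {
    let convention: ResultConvention
    let type: String
}

/// Models (1): lowering the JVP's formal result `(T, Differential)` against an
/// opaque pattern returns the whole tuple through a single @out address.
func opaqueLoweredJVPResults(value: String, differential: String) -> [LoweredResult] {
    [LoweredResult(convention: .out, type: "(\(value), \(differential))")]
}

/// Models (2): reuse the original function's lowered results and append an
/// @owned differential, without consulting any abstraction pattern.
func expectedJVPResults(original: [LoweredResult], differential: String) -> [LoweredResult] {
    original + [LoweredResult(convention: .owned, type: differential)]
}

let differential = "@callee_guaranteed (@in_guaranteed Float) -> @out Float"
// The thunked original returns `@out Float`.
let originalResults = [LoweredResult(convention: .out, type: "Float")]

let actual = opaqueLoweredJVPResults(value: "Float", differential: differential)
let expected = expectedJVPResults(original: originalResults, differential: differential)
print(actual == expected)           // false: one @out tuple vs. @out Float plus @owned closure
print(actual.count, expected.count) // 1 2
```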


Questions:

  1. Is it expected and valid that SGF.getTypeLowering produces a SIL function type returning an @out (...) tuple, given an opaque abstraction pattern and a tuple-returning AnyFunctionType?

  2. How can we thunk @differentiable function-typed values in a way that works with opaque abstraction patterns and also passes verification (i.e. (1) matches (2) in all cases)?

    • One idea is to make derivative types maximally indirect (so that (2) matches (1)). I wonder if there are alternatives.

cc @Joe_Groff

Some more context: TF-123 tracks this issue.

Simple reproducer (calling print involves opaque abstraction patterns because print takes Any arguments):

let fn: @differentiable (Float) -> (Float) = { $0 }
print(fn)

// SIL verification failed: Unexpected JVP function type: expectedJVPType == jvpType
// Actual (1) vs expected (2):
// (sil_function_type type=@callee_guaranteed (@in_guaranteed Float) -> @out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float))
// (sil_function_type type=@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float))

Testing SGF.getTypeLowering with an opaque abstraction pattern and an ordinary tuple-returning function-typed value does not show a problem:

let fn2: (Float) -> (Float, Float) = { ($0, $0) }
print(fn2)
// (Function)

-emit-silgen shows that fn2 is thunked to a maximally indirect (@in_guaranteed Float) -> @out (Float, Float) value:

// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
sil shared [transparent] [serializable] [reabstraction_thunk] [ossa] @$sS3fIegydd_S2f_SftIegnr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> @out (Float, Float) {
// %0                                             // users: %5, %4
// %1                                             // user: %3
// %2                                             // user: %6
bb0(%0 : $*(Float, Float), %1 : $*Float, %2 : @guaranteed $@callee_guaranteed (Float) -> (Float, Float)):
  %3 = load [trivial] %1 : $*Float                // user: %6
  %4 = tuple_element_addr %0 : $*(Float, Float), 0 // user: %9
  %5 = tuple_element_addr %0 : $*(Float, Float), 1 // user: %10
  %6 = apply %2(%3) : $@callee_guaranteed (Float) -> (Float, Float) // user: %7
  (%7, %8) = destructure_tuple %6 : $(Float, Float) // users: %9, %10
  store %7 to [trivial] %4 : $*Float              // id: %9
  store %8 to [trivial] %5 : $*Float              // id: %10
  %11 = tuple ()                                  // user: %12
  return %11 : $()                                // id: %12
} // end sil function '$sS3fIegydd_S2f_SftIegnr_TR'
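For comparison, the body of that thunk can be approximated by a hand-written Swift function over raw pointers. This is only an analogy: `reabstract` is a hypothetical name, and an `UnsafePointer` argument plus an `UnsafeMutablePointer` out-parameter stand in for `@in_guaranteed` and `@out`.

```swift
/// Hand-written analogue of the reabstraction thunk above (not compiler output).
/// The argument arrives indirectly and the whole (Float, Float) result is
/// written through one out-pointer, mirroring `@out (Float, Float)`.
func reabstract(
    _ f: @escaping (Float) -> (Float, Float)
) -> (UnsafePointer<Float>, UnsafeMutablePointer<(Float, Float)>) -> Void {
    return { input, output in
        let x = input.pointee   // like `load [trivial] %1`
        let (a, b) = f(x)       // like `apply` + `destructure_tuple`
        output.pointee.0 = a    // like `store` to tuple_element_addr 0
        output.pointee.1 = b    // like `store` to tuple_element_addr 1
    }
}

// Usage with fn2's body: { ($0, $0) }.
let thunkedFn2 = reabstract { ($0, $0) }
var argument: Float = 3
var result: (Float, Float) = (0, 0)
withUnsafePointer(to: &argument) { input in
    withUnsafeMutablePointer(to: &result) { output in
        thunkedFn2(input, output)
    }
}
print(result) // (3.0, 3.0)
```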

So the @differentiable function thunking problem appears to be purely a type-calculation inconsistency between (1) and (2). Making (2) maximally indirect like (1) should fix it, but I wonder whether there are alternatives.

Yes.

Sounds like you need to make the computation of the differentiable function type sensitive to the abstraction pattern used.
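One way to read that suggestion, as a toy model: instead of always deriving the expected JVP results from the original's lowered results, choose their shape based on the abstraction pattern in use. `AbstractionPatternKind`, `LoweredResult`, and `expectedJVPResults(for:original:value:differential:)` are hypothetical names; the model distinguishes only an opaque case and a concrete case and is not how the compiler represents abstraction patterns.

```swift
/// Hypothetical, heavily simplified abstraction patterns.
enum AbstractionPatternKind { case opaque, concrete }

/// Hypothetical result conventions and lowered results.
enum ResultConvention { case out, owned }

struct LoweredResult: Equatable {
    let convention: ResultConvention
    let type: String
}

/// Expected JVP results, computed from the abstraction pattern so that the
/// verifier's expectation agrees with what lowering produces.
func expectedJVPResults(
    for pattern: AbstractionPatternKind,
    original: [LoweredResult],
    value: String,
    differential: String
) -> [LoweredResult] {
    switch pattern {
    case .opaque:
        // Maximally indirect: (value, differential) returned through one @out.
        return [LoweredResult(convention: .out, type: "(\(value), \(differential))")]
    case .concrete:
        // Current behavior: original results plus an @owned differential.
        return original + [LoweredResult(convention: .owned, type: differential)]
    }
}

let differential = "@callee_guaranteed (@in_guaranteed Float) -> @out Float"
let original = [LoweredResult(convention: .out, type: "Float")]
let opaque = expectedJVPResults(for: .opaque, original: original,
                                value: "Float", differential: differential)
let concrete = expectedJVPResults(for: .concrete, original: original,
                                  value: "Float", differential: differential)
print(opaque.count, concrete.count) // 1 2
```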
