Hi SIL experts,
I have a question about lowering opaque abstraction patterns. Sorry for the long post - context is below and questions are at the bottom.
A @differentiable function-typed value is a bundle of multiple function-typed values: an original function and multiple derivative functions. @differentiable function-typed values are formed via the autodiff_function instruction.
To thunk @differentiable function-typed values during SILGen, we extract all elements from the bundle, thunk them individually, then create a new bundle of the thunked elements. (The relevant logic is in SILGenPoly.cpp.)
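To make the extract/thunk/re-bundle step concrete, here is a toy C++ model. It is only a sketch and not the SILGen code: `DirectBundle`, `IndirectBundle`, `thunkDifferential`, `thunkBundle`, and `makeSquareBundle` are hypothetical names, the bundle holds just one derivative (a JVP), and "indirect" is approximated by passing arguments through `const float *` and writing results through out-pointers.

```cpp
#include <functional>
#include <utility>

// Toy stand-in for a differential: maps a tangent to a tangent.
using Differential = std::function<float(float)>;

// Toy bundle at a direct convention: original function plus a JVP that
// returns (value, differential).
struct DirectBundle {
  std::function<float(float)> original;
  std::function<std::pair<float, Differential>(float)> jvp;
};

// The same pieces at an indirect convention: arguments are read through
// pointers and results are written through out-pointers.
using IndirectDifferential = std::function<void(const float *, float *)>;
struct IndirectBundle {
  std::function<void(const float *, float *)> original;
  std::function<void(const float *, float *, IndirectDifferential *)> jvp;
};

// Thunk a single differential to the indirect convention.
IndirectDifferential thunkDifferential(Differential f) {
  return [f = std::move(f)](const float *in, float *out) { *out = f(*in); };
}

// Extract each element, thunk it individually, then form a new bundle.
IndirectBundle thunkBundle(const DirectBundle &bundle) {
  IndirectBundle result;
  result.original = [f = bundle.original](const float *x, float *out) {
    *out = f(*x);
  };
  result.jvp = [jvp = bundle.jvp](const float *x, float *valueOut,
                                  IndirectDifferential *dfOut) {
    auto [value, df] = jvp(*x);
    *valueOut = value;
    // The returned differential is itself a function, so it is thunked too.
    *dfOut = thunkDifferential(std::move(df));
  };
  return result;
}

// Example bundle for f(x) = x * x, whose JVP is (x * x, dx -> 2 * x * dx).
DirectBundle makeSquareBundle() {
  DirectBundle b;
  b.original = [](float x) { return x * x; };
  b.jvp = [](float x) {
    return std::make_pair(x * x, Differential([x](float dx) { return 2 * x * dx; }));
  };
  return b;
}
```

Note that the toy thunked JVP has to commit to a result shape (here, two separate out-parameters). That choice is exactly where the real thunk and the verifier's expectation can disagree.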
However, thunking with opaque output abstraction patterns is problematic. Basically, the lowered SIL type of thunked derivatives (1) does not match the expected derivative type computed during autodiff_function SIL verification (2).
(1) is computed via SGF.getTypeLowering given an opaque abstraction pattern and a derivative AnyFunctionType:
// assocFnOutputOrigType (AbstractionPattern):
AP::Opaque
// assocFnOutputSubstType (AnyFunctionType):
(function_type escaping
  (input=function_params num_params=1
    (param
      (struct_type decl=Swift.(file).Float)))
  (output=tuple_type num_elements=2
    (tuple_type_elt
      (struct_type decl=Swift.(file).Float))
    (tuple_type_elt
      (function_type escaping
        (input=function_params num_params=1
          (param
            (struct_type decl=Swift.(file).Float)))
        (output=struct_type decl=Swift.(file).Float)))))
// (1): SGF.getTypeLowering(assocFnOutputOrigType, assocFnOutputSubstType)
$@callee_guaranteed (@in_guaranteed Float) -> @out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)
(2) is computed based on the type of the original function operand to autodiff_function:
// Original function type:
@callee_guaranteed (@in_guaranteed Float) -> @out Float
// (2): Derivative (JVP) type:
@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
(1) and (2) don't match (the result types are different), leading to a SIL verification failure.
// Differing result types:
// (1):
@out (Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)
// (2):
(@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
(1) appears to be maximally indirect, ostensibly due to lowering an opaque abstraction pattern. But (2) is not maximally indirect: currently, derivative types are computed to return (@out T, ...) rather than @out (T, ...).
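The difference between the two result shapes can be sketched with a small C++ model. This is only an illustration of the shape comparison, not how the SIL verifier is implemented: `Result`, `Results`, `opaqueLoweredResults`, `expectedDerivativeResults`, and `resultsMatch` are hypothetical names, and types are represented as plain strings copied from (1) and (2) above.

```cpp
#include <string>
#include <vector>

// Simplified model of one lowered result: a convention plus a type.
struct Result {
  std::string convention; // e.g. "@out" or "@owned"
  std::string type;
  bool operator==(const Result &other) const {
    return convention == other.convention && type == other.type;
  }
};
using Results = std::vector<Result>;

// Shape (1): the whole tuple is returned through a single indirect result.
Results opaqueLoweredResults() {
  return {Result{"@out",
                 "(Float, @callee_guaranteed (@in_guaranteed Float) -> @out Float)"}};
}

// Shape (2): the tuple is exploded into one result per element.
Results expectedDerivativeResults() {
  return {Result{"@out", "Float"},
          Result{"@owned",
                 "@callee_guaranteed (@in_guaranteed Float) -> @out Float"}};
}

// Strict comparison: same number of results, same convention and type each.
bool resultsMatch(const Results &a, const Results &b) { return a == b; }
```

Under this model, `resultsMatch(opaqueLoweredResults(), expectedDerivativeResults())` is false because (1) has one result and (2) has two, which mirrors the mismatch described above.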
Questions:
- Is it expected/valid that SGF.getTypeLowering produces a SIL function type returning an @out (...) tuple type, given an opaque abstraction pattern and a tuple-returning AnyFunctionType?
- How can we thunk @differentiable function-typed values in a way that works with opaque abstraction patterns and that also passes verification (i.e. (1) matches (2) in all cases)?
  - One idea is to make derivative types maximally indirect (so that (2) matches (1)). I wonder if there are alternatives.
cc @Joe_Groff