Hello SIL experts,
I'm working on the automatic differentiation SIL transform and I'm wondering how best to reabstract a function in a SIL transform (post-SILGen, unlike in SILGen where it's easy to generate reabstraction thunks), given a source/target function type.
Some context: currently, differentiation only works for SIL functions with direct parameters/results, but now we're trying to support functions with indirect parameters/results too. This is crucial for differentiation of generic functions (whose parameters/results may be indirect).
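For example, taking the types from the SIL dump below: the pullback returned by `Tensor._vjpMultiply` is formed in a generic context, so its `τ_0_0.CotangentVector` result is lowered indirectly, and it stays indirect even after `Float` is substituted in. (The last, as-direct-as-possible type is my assumption about what the primal value struct field expects; see case 1 below.)

```
// Generic pullback type, as returned by Tensor._vjpMultiply (from the dump below):
$@callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)

// After substituting τ_0_0 := Float, the result convention is still indirect:
$@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)

// The as-direct-as-possible lowering (presumably what the struct field expects):
$@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, Float)
```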
There are at least two cases where we want to reabstract functions in the differentiation SIL transform:
- For each function-to-differentiate, differentiation produces a primal function, which returns a "primal values struct" whose members are functions that are as-direct-as-possible in their lowered type. However, we need to construct an instance of the struct using operands that have some indirect parameters/results, so we need to reabstract those operands to their as-direct-as-possible versions (see the SIL sketch after this list).
- For reference: without reabstraction, we get verification errors like this:
SIL verification failed: struct operand type does not match field type: (*opi)->getType() == loweredType
Verifying instruction:
%12 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 1 // user: %14
-> %14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%12 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) // user: %15
%15 = tuple (%14 : $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, %11 : $Tensor<Float>) // user: %16
In function:
// AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0
sil hidden @AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0 : $@convention(thin) (@guaranteed Tensor<Float>) -> (@owned _AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, @owned Tensor<Float>) {
// %0 // users: %10, %1
bb0(%0 : $Tensor<Float>):
debug_value %0 : $Tensor<Float>, let, name "x", argno 1 // id: %1
%2 = metatype $@thin Tensor<Float>.Type // user: %10
%3 = metatype $@thin Float.Type // user: %6
%4 = integer_literal $Builtin.IntLiteral, 1 // user: %6
// function_ref Float.init(_builtinIntegerLiteral:)
%5 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %6
%6 = apply %5(%4, %3) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %8
%7 = alloc_stack $Float // users: %13, %10, %8
store %6 to %7 : $*Float // id: %8
// function_ref static Tensor<>._vjpMultiply(lhs:rhs:)
%9 = function_ref @$s10TensorFlow0A0VAAs14DifferentiableRzSFRz15CotangentVectorsADPQzRszrlE12_vjpMultiply3lhs3rhsACyxG_AK_xtAKctAK_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : FloatingPoint, τ_0_0 : TensorFlowScalar, τ_0_0 == τ_0_0.CotangentVector> (@guaranteed Tensor<τ_0_0>, @in_guaranteed τ_0_0, @thin Tensor<τ_0_0>.Type) -> (@owned Tensor<τ_0_0>, @owned @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)) // user: %10
%10 = apply %9<Float>(%0, %7, %2) : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : FloatingPoint, τ_0_0 : TensorFlowScalar, τ_0_0 == τ_0_0.CotangentVector> (@guaranteed Tensor<τ_0_0>, @in_guaranteed τ_0_0, @thin Tensor<τ_0_0>.Type) -> (@owned Tensor<τ_0_0>, @owned @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)) // users: %12, %11
%11 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 0 // user: %15
%12 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 1 // user: %14
dealloc_stack %7 : $*Float // id: %13
%14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%12 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) // user: %15
%15 = tuple (%14 : $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, %11 : $Tensor<Float>) // user: %16
return %15 : $(_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, Tensor<Float>) // id: %16
} // end sil function 'AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0'
- Sometimes, differentiation may attempt to differentiate a reabstraction thunk instead of the underlying function that should be differentiated. One solution is to detect these cases, differentiate the underlying function (creating a `vjp`, i.e. a vector-Jacobian product function), and build a new reabstraction thunk for the `vjp` function.
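To make case 1 concrete, here is roughly the kind of thunk we think we'd have to emit, written out by hand as a sketch: the thunk name, linkage, and exact ownership conventions are assumptions on my part, but the function types are taken from the verifier dump above. The thunk loads the wrapped pullback's `@out Float` result so that the resulting function value matches the as-direct-as-possible field type:

```
// Hand-written sketch, not actual compiler output.
sil shared @primal_value_reabstraction_thunk : $@convention(thin) (@guaranteed Tensor<Float>, @guaranteed @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) -> (@owned Tensor<Float>, Float) {
bb0(%0 : $Tensor<Float>, %1 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)):
  // Buffer for the wrapped function's indirect Float result.
  %2 = alloc_stack $Float
  %3 = apply %1(%2, %0) : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)
  // Load the indirect result so the thunk can return it directly.
  %4 = load %2 : $*Float
  dealloc_stack %2 : $*Float
  %5 = tuple (%3 : $Tensor<Float>, %4 : $Float)
  return %5 : $(Tensor<Float>, Float)
}
```

The primal would then `partial_apply` this thunk over `%12` and construct the struct from the result, e.g. (reusing the value numbers from the dump):

```
  %thunk = function_ref @primal_value_reabstraction_thunk : $@convention(thin) (@guaranteed Tensor<Float>, @guaranteed @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) -> (@owned Tensor<Float>, Float)
  %reabstracted = partial_apply [callee_guaranteed] %thunk(%12) : $@convention(thin) (@guaranteed Tensor<Float>, @guaranteed @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) -> (@owned Tensor<Float>, Float)
  %14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%reabstracted : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, Float))
```

Case 2 needs the same kind of machinery, just pointed the other way: after differentiating the underlying function, we'd wrap its `vjp` in a fresh thunk whose type matches what callers of the original thunk expect.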
Either way, the ability to generate thunks seems important. What's the best way to create reabstraction thunks in a SIL transform (post-SILGen) simply given source/target function types?
Initial investigation:
- We found a `ReabstractionThunkGenerator` class in `lib/SILOptimizer/Utils/Generics.cpp`, but it seems to be designed for generic specialization. I'm not sure whether it directly fits our use case. (It seems to require an `ApplySite`, which we wouldn't have for case 2 above.)
- We started copying a bunch of code from SILGen (`SILGenPoly.cpp` and `SILGenThunk.cpp`) to validate the approach. But this seems to involve removing important SILGen-specific constructs like `ManagedValue`, and it would result in a lot of duplicated code.
Does anyone have advice? Any help/suggestions would be appreciated!
(Please reply if you think, e.g., sharing code with `ReabstractionThunkGenerator` is a good idea, or if you feel there's a better approach altogether.)
cc @rxwei
EDIT:
- Clarified case 1 above. Primal value struct members are not necessarily direct, just "as-direct-as-possible".