Hello SIL experts,
I'm working on the automatic differentiation SIL transform, and I'm wondering how best to reabstract a function, given source and target function types, in a SIL transform that runs post-SILGen (where generating reabstraction thunks is not as easy as it is during SILGen).
Some context: currently, differentiation only works for SIL functions with direct parameters/results, but now we're trying to support functions with indirect parameters/results too. This is crucial for differentiation of generic functions (whose parameters/results may be indirect).
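For concreteness, here's a minimal hand-written sketch of the direct vs. indirect distinction (the function names are made up; this isn't from the actual test case):

```sil
// A concrete, loadable type like Float is passed and returned directly.
sil hidden @concrete_identity : $@convention(thin) (Float) -> Float {
bb0(%0 : $Float):
  return %0 : $Float
}

// A generic function must handle values of unknown layout, so its parameter
// and result are lowered to indirect (address) conventions. The indirect
// result buffer comes first in the entry block's argument list.
sil hidden @generic_identity : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> @out τ_0_0 {
bb0(%0 : $*τ_0_0, %1 : $*τ_0_0):
  // %0 is the @out result buffer, %1 is the @in_guaranteed parameter.
  copy_addr %1 to [initialization] %0 : $*τ_0_0
  %2 = tuple ()
  return %2 : $()
}
```

Functions derived by differentiation inherit these conventions, which is where the reabstraction problem below comes from.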
There are at least two cases where we want to reabstract functions in the differentiation SIL transform:
- For each function to differentiate, differentiation produces a primal function, which returns a "primal values struct" whose members are functions that are as-direct-as-possible in their lowered type. However, we need to construct an instance of this struct using operands that have some indirect parameters/results, so we need to reabstract those operands to their as-direct-as-possible version (see the thunk sketch further below).
- For reference: without reabstraction, we get verification errors like this:
```
SIL verification failed: struct operand type does not match field type: (*opi)->getType() == loweredType
Verifying instruction:
%12 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 1 // user: %14
-> %14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%12 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) // user: %15
%15 = tuple (%14 : $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, %11 : $Tensor<Float>) // user: %16
In function:
// AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0
sil hidden @AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0 : $@convention(thin) (@guaranteed Tensor<Float>) -> (@owned _AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, @owned Tensor<Float>) {
// %0 // users: %10, %1
bb0(%0 : $Tensor<Float>):
debug_value %0 : $Tensor<Float>, let, name "x", argno 1 // id: %1
%2 = metatype $@thin Tensor<Float>.Type // user: %10
%3 = metatype $@thin Float.Type // user: %6
%4 = integer_literal $Builtin.IntLiteral, 1 // user: %6
// function_ref Float.init(_builtinIntegerLiteral:)
%5 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %6
%6 = apply %5(%4, %3) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %8
%7 = alloc_stack $Float // users: %13, %10, %8
store %6 to %7 : $*Float // id: %8
// function_ref static Tensor<>._vjpMultiply(lhs:rhs:)
%9 = function_ref @$s10TensorFlow0A0VAAs14DifferentiableRzSFRz15CotangentVectorsADPQzRszrlE12_vjpMultiply3lhs3rhsACyxG_AK_xtAKctAK_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : FloatingPoint, τ_0_0 : TensorFlowScalar, τ_0_0 == τ_0_0.CotangentVector> (@guaranteed Tensor<τ_0_0>, @in_guaranteed τ_0_0, @thin Tensor<τ_0_0>.Type) -> (@owned Tensor<τ_0_0>, @owned @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)) // user: %10
%10 = apply %9<Float>(%0, %7, %2) : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : FloatingPoint, τ_0_0 : TensorFlowScalar, τ_0_0 == τ_0_0.CotangentVector> (@guaranteed Tensor<τ_0_0>, @in_guaranteed τ_0_0, @thin Tensor<τ_0_0>.Type) -> (@owned Tensor<τ_0_0>, @owned @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)) // users: %12, %11
%11 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 0 // user: %15
%12 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 1 // user: %14
dealloc_stack %7 : $*Float // id: %13
%14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%12 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) // user: %15
%15 = tuple (%14 : $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, %11 : $Tensor<Float>) // user: %16
return %15 : $(_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, Tensor<Float>) // id: %16
} // end sil function 'AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0'
```
- Sometimes, differentiation may attempt to differentiate a reabstraction thunk instead of the underlying function that should be differentiated. One solution is to detect these cases, differentiate the underlying function (creating a VJP, i.e. vector-Jacobian product, function), and build a new reabstraction thunk for the VJP function.
Either way, the ability to generate thunks seems important. What's the best way to create reabstraction thunks in a SIL transform (post-SILGen), given only source and target function types?
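To make case 1 concrete, here's roughly the thunk we believe we'd have to emit, written by hand in non-ownership SIL using the types from the dump above (the thunk name and exact conventions are made up, and retain/release traffic is elided). It reabstracts the pullback from the more-indirect type to the as-direct-as-possible type expected by the struct field:

```sil
// Sketch: reabstract
//   $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)
// into the as-direct-as-possible
//   $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, Float)
sil shared @pullback_reabstraction_thunk : $@convention(thin) (@guaranteed Tensor<Float>, @guaranteed @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) -> (@owned Tensor<Float>, Float) {
bb0(%0 : $Tensor<Float>, %1 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)):
  // Temporary buffer for the callee's indirect @out Float result.
  %buf = alloc_stack $Float
  // Indirect results precede the formal arguments in the apply argument list.
  %tensor = apply %1(%buf, %0) : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)
  %float = load %buf : $*Float
  dealloc_stack %buf : $*Float
  // Repackage both results as direct results.
  %result = tuple (%tensor : $Tensor<Float>, %float : $Float)
  return %result : $(Tensor<Float>, Float)
}
```

The primal would then `partial_apply [callee_guaranteed]` this thunk over `%12` to get a function value of the as-direct-as-possible type and pass that to the `struct` instruction instead of `%12`. Emitting thunks like this by hand is exactly the kind of code we'd rather not duplicate from SILGen, hence the question above.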
Initial investigation:
- We found a `ReabstractionThunkGenerator` class in `lib/SILOptimizer/Utils/Generics.cpp`, but it seems to be designed for generic specialization, and I'm not sure whether it directly fits our use case. (It seems to require an `ApplySite`, which we wouldn't have for case 2 above.)
- We started copying a bunch of code from SILGen (`SILGenPoly.cpp` and `SILGenThunk.cpp`) to validate the approach, but this seems to involve removing important SILGen-specific constructs like `ManagedValue` and would result in a lot of code duplication.
Does anyone have advice? Any help/suggestions would be appreciated!
(Please reply if you think, e.g., sharing code with `ReabstractionThunkGenerator` is a good idea, or if you feel there's a better approach altogether.)
cc @rxwei
EDIT:
- Clarified case 1 above: primal value struct members are not necessarily direct, just "as-direct-as-possible".