[Help needed] How to reabstract functions in SIL passes

Hello SIL experts,

I'm working on the automatic differentiation SIL transform and I'm wondering how best to reabstract a function in a SIL transform (i.e. post-SILGen, rather than in SILGen, where generating reabstraction thunks is easy), given a source and a target function type.

Some context: currently, differentiation only works for SIL functions with direct parameters/results, but now we're trying to support functions with indirect parameters/results too. This is crucial for differentiation of generic functions (whose parameters/results may be indirect).

There are at least two cases where we want to reabstract functions in the differentiation SIL transform:

  1. For each function-to-differentiate, differentiation produces a primal function, which returns a "primal values struct" whose members are functions that are as-direct-as-possible in their lowered type. However, we need to construct an instance of the struct using operands with some indirect parameters/results. Thus, we need to reabstract these operands to their as-direct-as-possible version.
  • For reference: without reabstraction, we get verification errors like this:
```
SIL verification failed: struct operand type does not match field type: (*opi)->getType() == loweredType
Verifying instruction:
     %12 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 1 // user: %14
->   %14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%12 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) // user: %15
     %15 = tuple (%14 : $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, %11 : $Tensor<Float>) // user: %16
In function:
// AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0
sil hidden @AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0 : $@convention(thin) (@guaranteed Tensor<Float>) -> (@owned _AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, @owned Tensor<Float>) {
// %0                                             // users: %10, %1
bb0(%0 : $Tensor<Float>):
  debug_value %0 : $Tensor<Float>, let, name "x", argno 1 // id: %1
  %2 = metatype $@thin Tensor<Float>.Type         // user: %10
  %3 = metatype $@thin Float.Type                 // user: %6
  %4 = integer_literal $Builtin.IntLiteral, 1     // user: %6
  // function_ref Float.init(_builtinIntegerLiteral:)
  %5 = function_ref @$sSf22_builtinIntegerLiteralSfBI_tcfC : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %6
  %6 = apply %5(%4, %3) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // user: %8
  %7 = alloc_stack $Float                         // users: %13, %10, %8
  store %6 to %7 : $*Float                        // id: %8
  // function_ref static Tensor<>._vjpMultiply(lhs:rhs:)
  %9 = function_ref @$s10TensorFlow0A0VAAs14DifferentiableRzSFRz15CotangentVectorsADPQzRszrlE12_vjpMultiply3lhs3rhsACyxG_AK_xtAKctAK_xtFZ : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : FloatingPoint, τ_0_0 : TensorFlowScalar, τ_0_0 == τ_0_0.CotangentVector> (@guaranteed Tensor<τ_0_0>, @in_guaranteed τ_0_0, @thin Tensor<τ_0_0>.Type) -> (@owned Tensor<τ_0_0>, @owned @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)) // user: %10
  %10 = apply %9<Float>(%0, %7, %2) : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : FloatingPoint, τ_0_0 : TensorFlowScalar, τ_0_0 == τ_0_0.CotangentVector> (@guaranteed Tensor<τ_0_0>, @in_guaranteed τ_0_0, @thin Tensor<τ_0_0>.Type) -> (@owned Tensor<τ_0_0>, @owned @callee_guaranteed (@guaranteed Tensor<τ_0_0>) -> (@owned Tensor<τ_0_0>, @out τ_0_0.CotangentVector)) // users: %12, %11
  %11 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 0 // user: %15
  %12 = tuple_extract %10 : $(Tensor<Float>, @callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)), 1 // user: %14
  dealloc_stack %7 : $*Float                      // id: %13
  %14 = struct $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0 (%12 : $@callee_guaranteed (@guaranteed Tensor<Float>) -> (@owned Tensor<Float>, @out Float)) // user: %15
  %15 = tuple (%14 : $_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, %11 : $Tensor<Float>) // user: %16
  return %15 : $(_AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__Type__src_0_wrt_0, Tensor<Float>) // id: %16
} // end sil function 'AD__$s4ind47genericy10TensorFlow0C0VySfGAFF__primal_src_0_wrt_0'
```
  2. Sometimes, differentiation may attempt to differentiate a reabstraction thunk instead of the underlying function that should be differentiated. One solution is to detect these cases, differentiate the underlying function (creating a VJP, i.e. vector-Jacobian product, function), and build a new reabstraction thunk for the VJP function.
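To make case 1 concrete, here is a minimal Swift sketch of what the missing conversion does. It rests on assumptions: `Tensor<Float>` is replaced by plain scalars, and `genericVJPMultiply`, `PrimalValues` and `primal` are hypothetical stand-ins for `_vjpMultiply(lhs:rhs:)`, the generated primal values struct, and the primal for `x * 1` shown in the SIL above. When SILGen compiles Swift source like this, it inserts the reabstraction thunk itself at the call to the generic function; the explicit forwarding closure only makes that wrapper visible. A primal built directly in SIL by the AD pass has no SILGen step, so it has to produce the equivalent wrapper on its own.

```swift
// Hypothetical stand-in for a generic VJP such as `_vjpMultiply(lhs:rhs:)`:
// returns the product and a pullback mapping a cotangent `v` to the
// cotangents of both operands.
func genericVJPMultiply<T: FloatingPoint>(_ lhs: T, _ rhs: T) -> (T, (T) -> (T, T)) {
    return (lhs * rhs, { v in (v * rhs, v * lhs) })
}

// Hypothetical "primal values struct": its field has the concrete,
// as-direct-as-possible function type.
struct PrimalValues {
    var pullback: (Float) -> (Float, Float)
}

// Hypothetical primal for `x * 1`.
func primal(_ x: Float) -> (PrimalValues, Float) {
    let (value, pb) = genericVJPMultiply(x, Float(1))
    // Explicit forwarding closure that captures `pb`. In a SIL-level primal,
    // a reabstraction thunk plays this role: it accepts the concrete, direct
    // conventions the struct field expects and calls the generically
    // abstracted pullback with its indirect conventions.
    let reabstracted: (Float) -> (Float, Float) = { v in pb(v) }
    return (PrimalValues(pullback: reabstracted), value)
}
```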

Either way, the ability to generate thunks seems important. What's the best way to create reabstraction thunks in a SIL transform (post-SILGen) simply given source/target function types?

Initial investigation:

  • We found a ReabstractionThunkGenerator class in lib/SILOptimizer/Utils/Generics.cpp, but it seems to be designed for generic specialization, and I'm not sure whether it directly fits our use case. (It seems to require an ApplySite, which we wouldn't have for case 2 above.)
  • We started copying a bunch of code from SILGen (SILGenPoly.cpp and SILGenThunk.cpp) to validate the approach. But this seems to involve removing important SILGen-specific constructs like ManagedValue and would result in a lot of code duplication.

Does anyone have advice? Any help/suggestions would be appreciated!

(Please reply if you think e.g. sharing code with ReabstractionThunkGenerator is a good idea, or if you feel there's a better approach altogether)

cc @rxwei


EDIT:

  • Clarified case 1 above. Primal value struct members are not necessarily direct, just "as-direct-as-possible".

cc @Joe_Groff @Slava_Pestov

Reabstraction after SILGen isn't a thing, nor should it be a thing. It would be better to either emit your functions at the abstraction level necessary for your desired substitutions up front, which would allow generic substitution to just work without needing to introduce thunks in SIL passes, or wait for and/or contribute to the "opaque values" effort, which will make it so that "address only" types don't exist at the SILGen stage and get introduced later.


Right. We've discussed adding a reabstract_function SIL instruction to simplify some static and dynamic optimization, but I think we'd still expect that to have a reference to a function that could be used for the actual reabstraction.

One option would be to only differentiate functions where all parameters and results are indirect, and ensure functions that will be differentiated have an entry point that is reabstracted in this manner in SILGen.

As for your other question about detecting reabstraction thunks, you can check if F->isThunk() == IsReabstractionThunk. We already have a few optimizer peepholes that do this, because for the most part the behavior of a reabstraction thunk is determined entirely by its lowered function type.
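A toy Swift model of this check, applied to case 2 from the original post. It is not compiler code: `ThunkKind`, `FunctionModel` and `FunctionValue` are hypothetical, simplified stand-ins for the compiler's C++ `SILFunction`, its thunk classification, and a `partial_apply` that captures the function value a thunk wraps. The sketch peels off reabstraction thunks to reach the function that should actually be differentiated.

```swift
// Hypothetical, simplified thunk classification (not the compiler's API).
enum ThunkKind {
    case notThunk
    case thunk
    case reabstractionThunk
}

// Hypothetical stand-in for a SIL function.
struct FunctionModel {
    let name: String
    let thunkKind: ThunkKind
}

// A function value: either a direct reference to a function, or a partial
// application of a function (such as a thunk) to a captured function value.
indirect enum FunctionValue {
    case reference(FunctionModel)
    case partialApply(FunctionModel, captured: FunctionValue)
}

// Strips reabstraction thunks, in the spirit of checking
// `F->isThunk() == IsReabstractionThunk`, and returns the wrapped value.
func peelReabstractionThunks(_ value: FunctionValue) -> FunctionValue {
    var current = value
    while case let .partialApply(fn, captured) = current,
          fn.thunkKind == .reabstractionThunk {
        current = captured
    }
    return current
}

// Name of the function at the root of a function value.
func rootName(_ value: FunctionValue) -> String {
    switch value {
    case let .reference(fn): return fn.name
    case let .partialApply(fn, _): return fn.name
    }
}
```

Only reabstraction thunks are peeled; any other partially applied function is treated as the thing to differentiate.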

Let me clarify a few things.

  • The only case of reabstraction we'd like to do is reabstracting not-so-direct functions into their as-direct-as-possible form given their AST type, because we want to insert them into a struct. I don't think there's a way to avoid this by doing it in SILGen, because AD is highly flow-sensitive, and when SILGen emits a function it would not know whether it will be differentiated or not.
  • As for #2 that Dan listed, the AD pass already knows how to differentiate a reabstraction thunk when the closure argument is @differentiable. It's likely that more things will fall out, so we can ignore this case for now.
  • Opaque values would be hugely helpful for the entire AD project, but we are currently blocked on this immediate issue, and are not able to help with opaque values in the near term due to time constraints.

AD is highly flow-sensitive because we currently allow the following programming model:

```swift
func foo(_ x: Float) -> Float {
  return x
}

// `Float` conforms to `Differentiable`.
let d = Float(3.0).derivative { x in
  foo(x)
}
```

When foo(_:) and the trailing closure get SILGen'd, SILGen won't know they will be differentiated. The Differentiable.derivative(in:) method takes a @differentiable function, which triggers a function conversion that signals the AD pass to differentiate function bodies recursively.

We could, however, require a @differentiable attribute on every single FuncDecl we want to differentiate, and a @differentiable contextual type for every ClosureExpr we want to differentiate. While that would allow more modular compilation, it feels like a less flexible programming model.

We have the ability to reabstract totally opaque function values; I don't know why you think your programming model would need more than that. The place that's emitting a reference to the function will need to reabstract immediately if the function doesn't have the expected abstraction pattern.

I'm not familiar with this "ability to reabstract totally opaque function values"; could you please share some more info?


I mean that we can reabstract first-class function values without needing to know what functions they are statically. We just make a new closure that stores the original function value but has the right signature. All you need to do with differentiable functions is wrap all the functions in the bundle.
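A minimal Swift sketch of "wrap all the functions in the bundle", assuming a hypothetical `DifferentiableFunctionBundle` that pairs an original function with its VJP (the compiler's actual representation of `@differentiable` functions may differ). Each member becomes a fresh closure that captures the old function value and forwards to it. The pullback returned by the VJP is itself a function value, so it needs wrapping too. In Swift source the new closures have the same formal types as the old ones, so only the wrapping structure shows; in SIL, each forwarding closure is where the abstraction change would happen.

```swift
// Hypothetical bundle: an original function and its VJP, which returns the
// result together with a pullback.
struct DifferentiableFunctionBundle<A, R> {
    var original: (A) -> R
    var vjp: (A) -> (R, (R) -> A)
}

extension DifferentiableFunctionBundle {
    // Returns a bundle whose members are new closures that capture the old
    // function values and forward to them.
    func wrapped() -> DifferentiableFunctionBundle<A, R> {
        let original = self.original
        let vjp = self.vjp
        return DifferentiableFunctionBundle(
            original: { a in original(a) },
            vjp: { a in
                let (value, pullback) = vjp(a)
                // The returned pullback is wrapped as well.
                return (value, { v in pullback(v) })
            })
    }
}
```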

It sounds like this is exactly what @rxwei is trying to do, with the caveat that the new closure is emitted in an optimizer pass and not in SILGen.

I guess it wouldn't be impossible to refactor the code for emitting thunks so that it does not depend on the rest of SILGen, but it would be awkward because SILGenFunction and SILGenModule are basically implementing the "God object" anti-pattern.

Just a crazy idea (not sure if it makes sense)... but why couldn't we have reabstraction thunks emitted by a SIL pass? Then we could do it after simplifications and potentially eliminate thunks without generating them. Couldn't we just have an instruction like the one you guys are talking about?


To be clear, I am talking about having a reabstraction thunk instruction that SILGen and passes emit, with the actual thunk emitted later in the pipeline. This would let us eliminate all of the reabstraction thunk goop from SILGen and simplify it.

I like that idea a lot.

```sil
%g = reabstract_function %f : $(@guaranteed T) -> @out U as $(@in_guaranteed T) -> @out U
```
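One payoff of a `reabstract_function`-style instruction is that simplifications can run before any thunk exists. The following is a toy Swift model under loud assumptions: `FnValue`, `simplify` and `thunkCount` are hypothetical, abstraction levels are plain string labels, and the model ignores the abstraction pattern and formal type that a real instruction would also need. It cancels a reabstraction that exactly undoes the one beneath it and drops identity reabstractions, so those thunks never have to be emitted.

```swift
// Hypothetical model: a function value is either an original function or a
// pending reabstraction of another value between two abstraction levels.
indirect enum FnValue: Equatable {
    case function(String)
    case reabstract(FnValue, from: String, to: String)
}

// Peephole a late thunk-emission pass could run first.
func simplify(_ value: FnValue) -> FnValue {
    switch value {
    case .function:
        return value
    case let .reabstract(inner, from, to):
        let simplifiedInner = simplify(inner)
        // Identity reabstraction: nothing to do.
        if from == to { return simplifiedInner }
        // A reabstraction that undoes the one directly beneath it cancels out.
        if case let .reabstract(base, innerFrom, innerTo) = simplifiedInner,
           innerTo == from, innerFrom == to {
            return base
        }
        return .reabstract(simplifiedInner, from: from, to: to)
    }
}

// Number of thunks that would still have to be emitted for a value.
func thunkCount(_ value: FnValue) -> Int {
    switch value {
    case .function: return 0
    case let .reabstract(inner, _, _): return 1 + thunkCount(inner)
    }
}
```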

Do you think this is going to be throw-away work when opaque values become the default?

Opaque values change nothing about the need for reabstraction.


We could have something like this, sure. In practice it would be a bit more complicated, because you need the abstraction pattern and formal type too. Reabstraction thunks can perform AST-level type conversions as well (e.g., () -> Int is a subtype of () -> Any).
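A small, runnable illustration of the AST-level conversions mentioned here. Swift accepts both conversions implicitly. In each case the converted function cannot simply reuse the original's calling convention, because an `Int` has to be boxed into an `Any` existential on the way in or out. Performing that boxing is the job of the thunk the compiler generates.

```swift
// `() -> Int` converts to `() -> Any` (covariant result): the converted
// function wraps the returned `Int` in an `Any` existential.
let makeInt: () -> Int = { 42 }
let makeAny: () -> Any = makeInt

// `(Any) -> String` converts to `(Int) -> String` (contravariant parameter):
// the `Int` argument is boxed as `Any` before the original is called.
let describeAny: (Any) -> String = { "\($0)" }
let describeInt: (Int) -> String = describeAny
```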

Could reabstraction thunks also be used for parameter convention conversion?

A really interesting pattern we've found is that functions with (@owned) -> @owned convention are mathematically transposable so are naturally differentiable without special memory management. Differentiating (@guaranteed) -> @owned functions requires complicated cleanup machinery like ManagedValue -- I recently wrote this machinery in AD but don't feel comfortable with its mathematical weirdness :)

Yes, the same implementation mechanisms let you change conventions.
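A toy Swift model of what such a thunk does when only the ownership convention of a direct parameter changes. It is an assumption-laden sketch, not compiler code: `Convention` and `argumentSteps` are hypothetical, and the returned strings only name the work the thunk would do (`copy_value` and `destroy_value` are the OSSA instruction names). It shows why switching between @guaranteed and @owned needs extra copies or cleanups in the thunk.

```swift
// Hypothetical: the two direct parameter ownership conventions discussed above.
enum Convention {
    case guaranteed  // callee borrows; caller keeps ownership
    case owned       // callee consumes the value
}

// Work a thunk must do for one argument it receives with convention
// `received` and forwards to a callee that expects convention `expected`.
func argumentSteps(received: Convention, expected: Convention) -> [String] {
    switch (received, expected) {
    case (.guaranteed, .guaranteed), (.owned, .owned):
        return ["pass through"]
    case (.guaranteed, .owned):
        // The thunk only borrows the value but the callee consumes it,
        // so the thunk must pass an owned copy.
        return ["copy_value", "pass copy"]
    case (.owned, .guaranteed):
        // The thunk owns the value but the callee only borrows it,
        // so the thunk must destroy it after the call returns.
        return ["pass borrowed", "destroy_value after call"]
    }
}
```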