SILGen'ing a generic builtin that returns a tuple


(Richard Wei) #1

Hi,

I'm SILGen'ing a builtin used for automatic differentiation. I'm trying to understand SILGen more and see whether/how my solution can be improved.

Builtin.autodiffApplyJVP

This builtin applies the JVP function (representing a derivative) to the given argument, returning a tuple of the original function's result and a differential function. Here's the type signature:

func Builtin.autodiffApplyJVP<T, R>(f: @autodiff (T) -> R, x: T) -> (R, (T) -> R)

In the type signature, @autodiff represents a function bundle that has the original function and two functions that represent f's derivatives.
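
As a rough illustration of the intended semantics, here is a plain-Swift sketch. It is not the builtin or the compiler's representation: `DifferentiableFunction` and `applyJVP` are hypothetical names made up for this example, and the bundle carries only the original function and the JVP.

```swift
// Hypothetical stand-in for an `@autodiff` function bundle: the original
// function plus its JVP. This is NOT the compiler's actual representation.
struct DifferentiableFunction<T, R> {
    let original: (T) -> R
    // Returns the original result together with a differential function.
    let jvp: (T) -> (R, (T) -> R)
}

// Plain-Swift analogue of `Builtin.autodiffApplyJVP(f, x)` (hypothetical).
func applyJVP<T, R>(_ f: DifferentiableFunction<T, R>, _ x: T) -> (R, (T) -> R) {
    return f.jvp(x)
}

// Example: f(x) = x * x, whose differential at x maps dx to 2 * x * dx.
let square = DifferentiableFunction<Double, Double>(
    original: { x in x * x },
    jvp: { x in (x * x, { dx in 2 * x * dx }) }
)

let (value, differential) = applyJVP(square, 3.0)
```

Here `value` is the original result and `differential` is the second tuple element, which corresponds to the function value that is always returned directly in the SIL discussed later.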

autodiff_function_extract

To get an autodiff-associated function or the original function from an @autodiff function, we just use an autodiff_function_extract instruction on the @autodiff function.

SILGen

To bring things together, I want the following expression

Builtin.autodiffApplyJVP(f, x)

... to be SILGen'd to

%jvp = autodiff_function_extract [jvp] [order 1] %f 
    : <T, R> (@in_guaranteed T) -> (@out R, (@in_guaranteed T) -> @out R)
%first_result_buffer = alloc_stack $R
%second_result = apply %jvp(%first_result_buffer : $*R, %x : $*T)

The second result of apply %jvp is always returned directly, but the first result may be returned indirectly (through an @out buffer).

I sent out PR #21307, which uses SILGenFunction::getBufferForExprResult and SILGenFunction::manageBufferForExprResult. I noticed that the visitor has to return a single ManagedValue, so I allocated a buffer for the entire result tuple and stored the second result into the address of tuple element 1, even though that element does not need to be passed indirectly. The generated code looks inefficient, and even the optimizer didn't eliminate the store of the function value into the buffer:

  %7 = autodiff_function_extract [jvp] [order 1] %1 : $@autodiff @noescape @callee_guaranteed (@in_guaranteed T) -> @out U // user: %10
  %8 = alloc_stack $(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector) // users: %17, %14, %13, %11, %9
  %9 = tuple_element_addr %8 : $*(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector), 0 // user: %10
  %10 = apply %7(%9, %5) : $@noescape @callee_guaranteed (@in_guaranteed T) -> (@out U, @owned @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector) // user: %12
  %11 = tuple_element_addr %8 : $*(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector), 1 // user: %12
  store %10 to [init] %11 : $*@callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector // id: %12
  %13 = tuple_element_addr %8 : $*(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector), 0 // user: %16
  %14 = tuple_element_addr %8 : $*(U, @callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector), 1 // user: %15
  %15 = load [take] %14 : $*@callee_guaranteed (@in_guaranteed T) -> @out U.TangentVector // user: %20

Is there a way to improve this solution so that, when an expr returns a tuple, we do not have to allocate a buffer for the tuple elements that do not need to be passed indirectly?