Using gradient(at:in:) for a function with an inout parameter

I can get the gradient of a function with signature (Float) -> Float:

import _Differentiation

@differentiable
func timesTwo(a: Float) -> Float {
    return a * 2
}

let grad = gradient(at: 3, in: timesTwo)

But I haven't figured out how to get the gradient of a function with signature (inout Float) -> Void, which does the same math:

@differentiable
func timesTwoInout(a: inout Float) {
    a *= 2
}

let gradInout = gradient(at: 3, in: timesTwoInout) // compiler says "Cannot convert value of type '(inout Float) -> ()' to expected argument type '@differentiable (Float) -> ()'"

I realize I can wrap my inout version in a function with the (Float) -> Float signature, but that defeats the purpose of using an inout function for speed in my use case:

func timesTwoInoutWrapper(a: Float) -> Float {
    var a = a
    timesTwoInout(a: &a)
    return a
}

let gradientInoutWrapper = gradient(at: 3, in: timesTwoInoutWrapper)
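The wrapper does not have to be a named function. It can also be written inline as a closure at the call site. This is a minimal sketch that assumes a recent toolchain, where (as far as I know) the attribute is spelled @differentiable(reverse) and the operator is gradient(at:of:) rather than gradient(at:in:). The name gradInline is only for illustration. The closure copies its argument into a local var exactly as timesTwoInoutWrapper does.

```swift
import _Differentiation

// The in-place function from the question, using the newer attribute spelling
// (assumption: recent toolchain with @differentiable(reverse)).
@differentiable(reverse)
func timesTwoInout(a: inout Float) {
    a *= 2
}

// Functional wrapper written as a trailing closure: copy the argument into a
// local var, mutate it in place, and return it.
let gradInline = gradient(at: Float(3)) { (x: Float) -> Float in
    var y = x
    timesTwoInout(a: &y)
    return y
}

print(gradInline) // 2.0
```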

Since this wrapper works, it seems to be simply a question of what API gradient(at:in:) provides. Does anyone know a way to call gradient(at:in:) with a function of type (inout Float) -> Void?

Thanks!

I don't want to distract from the core of your question, which I cannot answer. But I did want to address the question of "speed". For trivial types like Float there is no guarantee that inout will be any faster than simply returning. Indeed, a quick check in Compiler Explorer suggests that these two functions compile to effectively identical code.

Thanks! I'm actually using non-trivial types in my use case.

The initially proposed differential operators are all functional: they don't accept closures with inout parameters. While it is possible to add this in the future, I don't know of concrete use cases where being able to do this is a must (as you said, you can wrap it in a functional closure).

Differentiating inout is useful when you need to perform an in-place operation as part of some bigger operation (e.g. a loop) that's being differentiated. Is it really useful to differentiate closures with inout parameters using a top-level differential operator?
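To illustrate in-place operations inside a bigger differentiated computation, here is a minimal sketch. It again assumes recent toolchain spellings (@differentiable(reverse), gradient(at:of:)). A loop calls the in-place timesTwoInout three times, and the gradient is taken only of the enclosing functional timesEight, which is a hypothetical name. No top-level operator ever sees the inout signature.

```swift
import _Differentiation

// The in-place function from the question, with the newer attribute spelling.
@differentiable(reverse)
func timesTwoInout(a: inout Float) {
    a *= 2
}

// Hypothetical enclosing function: the in-place update runs inside a loop,
// and only this functional (Float) -> Float signature is passed to gradient.
@differentiable(reverse)
func timesEight(_ x: Float) -> Float {
    var result = x
    for _ in 0..<3 {
        timesTwoInout(a: &result)   // result doubles on each iteration
    }
    return result
}

// d/dx (8x) = 8 at any point.
let gradLoop = gradient(at: Float(3), of: timesEight)
print(gradLoop) // 8.0
```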

Thanks @rxwei! After considering your reply, I realized I was going down the wrong design path.

My initial design was going to manually string together the pullbacks from a bunch of (inout SomeType) -> Void functions (because the surrounding code is not itself differentiable). But I realized that is unnecessary, because I'm only calling them within another differentiable function, like you said. There were too many details and I got tunnel vision until I read your reply 🙂
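The two designs can be compared with a sketch that uses two hypothetical in-place steps, stepA and stepB, which stand in for the (inout SomeType) -> Void functions. Chaining pullbacks by hand means collecting one pullback per step and applying them in reverse order. Differentiating the enclosing function gives the same result without that bookkeeping. The sketch assumes recent toolchain spellings (@differentiable(reverse), valueWithPullback(at:of:), gradient(at:of:)). The functional wrappers are needed only for the manual route, because those operators accept (T) -> R closures.

```swift
import _Differentiation

// Hypothetical in-place steps standing in for (inout SomeType) -> Void functions.
@differentiable(reverse)
func stepA(_ x: inout Float) { x *= 3 }      // x -> 3x

@differentiable(reverse)
func stepB(_ x: inout Float) { x = x * x }   // x -> x^2

// Functional wrappers, needed only for the manual route.
@differentiable(reverse)
func stepAFunctional(_ x: Float) -> Float { var y = x; stepA(&y); return y }

@differentiable(reverse)
func stepBFunctional(_ x: Float) -> Float { var y = x; stepB(&y); return y }

// Manual route: run the steps forward, keep each pullback,
// then apply the pullbacks in reverse order starting from a seed of 1.
let x0: Float = 2
let (y1, pullbackA) = valueWithPullback(at: x0, of: stepAFunctional)
let (_, pullbackB) = valueWithPullback(at: y1, of: stepBFunctional)
let manualGrad = pullbackA(pullbackB(1))

// Automatic route: call the in-place steps inside one differentiable function.
@differentiable(reverse)
func composed(_ x: Float) -> Float {
    var y = x
    stepA(&y)
    stepB(&y)
    return y
}
let autoGrad = gradient(at: x0, of: composed)

// composed(x) = 9x^2, so the derivative is 18x = 36 at x = 2.
print(manualGrad, autoGrad) // 36.0 36.0
```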
