I can get the gradient of a function with signature (Float) -> Float:
import _Differentiation

@differentiable
func timesTwo(a: Float) -> Float {
    return a * 2
}

let grad = gradient(at: 3, in: timesTwo)
But I haven't figured out how to get the gradient of a function with signature (inout Float) -> Void, which does the same math:
@differentiable
func timesTwoInout(a: inout Float) {
    a *= 2
}

// compiler says "Cannot convert value of type '(inout Float) -> ()' to expected argument type '@differentiable (Float) -> ()'"
let gradInout = gradient(at: 3, in: timesTwoInout)
I realize I can wrap my inout version in a function with the (Float) -> Float signature, but in my use case that gives up the speed advantage that is the reason for using an inout function in the first place:
func timesTwoInoutWrapper(a: Float) -> Float {
    var a = a
    timesTwoInout(a: &a)
    return a
}

let gradientInoutWrapper = gradient(at: 3, in: timesTwoInoutWrapper)
Since this wrapper works, it must simply be a question of available API for gradient(at:in:). Does anyone know a way to call gradient(at:in:) for a function of type (inout Float) -> Void?
Thanks!
lukasa (Cory Benfield):
I don't want to distract from the core of your question, which I cannot answer. But I did want to address the question of "speed". For trivial types like Float there is no guarantee that inout will be any faster than simply returning. Indeed, a quick use of Compiler Explorer seems to suggest that these two functions produce effectively identical code.
Thanks! I'm actually using non-trivial types in my use case.
rxwei (Richard Wei):
The initially proposed differential operators are all functional. They don't accept closures with inout parameters. While it is possible to add this in the future, I don't know concrete use cases where being able to do this is a must (as you said, you can wrap it in a functional closure).
Differentiating inout is useful when you need to perform an in-place operation as part of some bigger operation (e.g. a loop) that's being differentiated. Is it really useful to differentiate closures with inout parameters using a top-level differential operator?
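As a minimal sketch of the case described above, the inout function can be called in a loop inside an ordinary differentiable function, with the top-level operator applied only to that outer function. `timesEight` and `gradLoop` are hypothetical names introduced for illustration. The sketch assumes a recent toolchain, where (to my knowledge) the attribute is spelled `@differentiable(reverse)` and the operator `gradient(at:of:)`; on the toolchain used earlier in this thread the equivalents would be `@differentiable` and `gradient(at:in:)`.

```swift
import _Differentiation

// The in-place operation from the original question.
@differentiable(reverse)
func timesTwoInout(a: inout Float) {
    a *= 2
}

// Hypothetical outer function: it has a functional (Float) -> Float
// signature, but does its work by mutating a local variable in place.
@differentiable(reverse)
func timesEight(_ x: Float) -> Float {
    var result = x
    for _ in 0..<3 {
        // Each iteration doubles `result` in place.
        timesTwoInout(a: &result)
    }
    return result
}

// The top-level operator only ever sees the functional outer function.
let gradLoop = gradient(at: 3, of: timesEight)
```

Mathematically the outer function computes 8 * x, so the gradient should be 8 at any input. The inout call is differentiated as part of the enclosing function, so no inout-aware top-level operator is needed.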
Thanks @rxwei! After considering your reply, I realized I was going down the wrong design path.
My initial design was going to manually string together the pullbacks from a bunch of (inout SomeType) -> Void functions (this was because the surrounding code is not itself differentiable). But I realized that is unnecessary because I'm only calling them within another differentiable function, like you said. There were too many details, and I had tunnel vision until I read your reply.
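For comparison, here is a rough sketch of the two designs mentioned above: chaining pullbacks by hand through functional wrappers, versus calling the inout functions inside one differentiable function. It uses Float rather than a custom type, and `plusOneInout`, `timesTwoWrapped`, `plusOneWrapped` and `pipeline` are hypothetical names made up for illustration. It assumes a recent toolchain spelling (`@differentiable(reverse)`, `valueWithPullback(at:of:)`, `gradient(at:of:)`); older toolchains would use `@differentiable` and the `in:` label instead.

```swift
import _Differentiation

@differentiable(reverse)
func timesTwoInout(a: inout Float) {
    a *= 2
}

// Hypothetical second in-place step, added only for illustration.
@differentiable(reverse)
func plusOneInout(a: inout Float) {
    a += 1
}

// Design 1: wrap each inout function functionally, then chain pullbacks by hand.
@differentiable(reverse)
func timesTwoWrapped(_ x: Float) -> Float {
    var a = x
    timesTwoInout(a: &a)
    return a
}

@differentiable(reverse)
func plusOneWrapped(_ x: Float) -> Float {
    var a = x
    plusOneInout(a: &a)
    return a
}

let (y1, pullback1) = valueWithPullback(at: 3, of: timesTwoWrapped)
let (_, pullback2) = valueWithPullback(at: y1, of: plusOneWrapped)
// Reverse-mode chain rule: push the seed 1 through the pullbacks, last step first.
let manualGradient = pullback1(pullback2(1))

// Design 2: call the inout functions inside one differentiable function
// and let the compiler compose the pullbacks.
@differentiable(reverse)
func pipeline(_ x: Float) -> Float {
    var a = x
    timesTwoInout(a: &a)
    plusOneInout(a: &a)
    return a
}

let composedGradient = gradient(at: 3, of: pipeline)
```

Both designs compute the derivative of 2 * x + 1, so `manualGradient` and `composedGradient` should agree. The second avoids the per-step wrappers and the manual bookkeeping.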