Hi folks,
I'd like to share some ideas about differentiating stored properties attributed with property wrappers.
Ideas by: @rxwei, @bartchr808, @dan-zheng, @marcrasi
Implementation: apple/swift pull request #31173, "[AutoDiff] Support differentiation of wrapped properties", by dan-zheng.
Background
Differentiable programming in Swift uses the Differentiable protocol, which has compiler support for derived conformances. The compiler synthesizes TangentVector member structs for Differentiable-conforming types based on which stored properties conform to Differentiable. Details here.
Let "wrapped (stored) properties" refer to "stored properties attributed with property wrappers".
Let "tangent (stored) properties" refer to "the stored properties synthesized in TangentVector member structs, corresponding to original stored properties that conform to Differentiable".
struct Pair<T: Differentiable, U: Differentiable>: Differentiable {
    var first: T
    var second: U
    // Compiler synthesizes:
    // struct TangentVector: Differentiable & AdditiveArithmetic {
    //     var first: T.TangentVector // tangent property
    //     var second: U.TangentVector // tangent property
    // }
}
Current behavior
Differentiable conformance derivation currently computes tangent properties from wrapper backing stored properties instead of wrapped stored properties. This leads to some unexpected behavior.
Let's look at an example:
import _Differentiation
// Naive property wrapper.
@propertyWrapper
struct Wrapper<Value> {
    var wrappedValue: Value
}
struct Struct {
    @Wrapper var x: Float
    // Compiler generates:
    // var _x: Wrapper<Float>
    // var x: Float {
    //     get { _x.wrappedValue }
    //     set { _x.wrappedValue = newValue }
    // }
    @Wrapper @Wrapper var y: Float
}
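To make the desugaring concrete, here is a minimal sketch in plain Swift (no differentiation involved) that exposes the backing storage types. The explicit initializer and the xStorage/yStorage accessors are hypothetical helpers added only for illustration; they are not part of the pitch.

```swift
// Same naive wrapper as above.
@propertyWrapper
struct Wrapper<Value> {
    var wrappedValue: Value
}

struct Struct {
    @Wrapper var x: Float
    @Wrapper @Wrapper var y: Float

    // Hypothetical initializer: sets the backing storage directly to show its shape.
    init(x: Float, y: Float) {
        _x = Wrapper(wrappedValue: x)
        _y = Wrapper(wrappedValue: Wrapper(wrappedValue: y))
    }

    // Hypothetical accessors: the backing properties are private to the type,
    // so they are exposed here with their full types spelled out.
    var xStorage: Wrapper<Float> { _x }
    var yStorage: Wrapper<Wrapper<Float>> { _y }
}

let s = Struct(x: 3, y: 4)
print(s.x, s.xStorage.wrappedValue)               // wrapped value vs. backing storage
print(s.y, s.yStorage.wrappedValue.wrappedValue)  // composed wrappers nest
```

The wrapped properties x and y both have type Float, while the backing storage has type Wrapper<Float> and Wrapper<Wrapper<Float>> respectively.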
Currently, tangent properties are computed from wrapper backing stored properties, requiring wrapper types to conform to Differentiable:
// Wrappers must conform to `Differentiable`.
extension Wrapper: Differentiable where Value: Differentiable {}
struct Struct: Differentiable {
    @Wrapper var x: Float
    @Wrapper @Wrapper var y: Float
    // Compiler currently synthesizes:
    // struct TangentVector: Differentiable & AdditiveArithmetic {
    //     var x: Wrapper<Float>.TangentVector
    //     var y: Wrapper<Wrapper<Float>>.TangentVector
    //     ...
    // }
}
It seems weird that Wrapper<...>.TangentVector appears in the synthesized TangentVector struct, and that Wrapper must conform to Differentiable. Many property wrappers (e.g. @Lazy) are unrelated to differentiation, and it may not make sense to conform them to Differentiable.
Since the wrapped property Struct.x has type Float, one would expect the corresponding tangent property to have type Float, not Wrapper<Float>.TangentVector.
Idea
Instead, we can make Differentiable conformance derivation treat wrapped stored properties like normal stored properties, using them to compute tangent properties in TangentVector. This makes behavior consistent for normal stored properties and wrapped stored properties: one might say this is a fix rather than a new feature.
struct Struct: Differentiable {
    @Wrapper var x: Float
    @Wrapper @Wrapper var y: Float
    // New behavior:
    // struct TangentVector: Differentiable & AdditiveArithmetic {
    //     var x: Float
    //     var y: Float
    //     ...
    // }
}
This behavior seems desirable for all wrapped stored properties and property wrapper types. Whether wrapper types conform to Differentiable is now irrelevant: what matters is that the types of wrapped properties conform to Differentiable, just like normal stored properties.
Wrapper types are required to provide a setter for var wrappedValue, which is needed to synthesize mutating func move(along:). This is consistent with existing Differentiable conformance derivation requirements.
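The setter requirement can be seen in a hand-written stand-in for the synthesized code. This is only a sketch in plain Swift, not the compiler's actual output: the Tangent struct, the explicit initializer, and the body of move(along:) are assumptions that mimic the shape described above, and no Differentiable conformance is declared.

```swift
@propertyWrapper
struct Wrapper<Value> {
    var wrappedValue: Value  // `var` gives the wrapper both a getter and a setter
}

struct Struct {
    @Wrapper var x: Float
    @Wrapper @Wrapper var y: Float

    // Hypothetical initializer, written out to keep the sketch self-contained.
    init(x: Float, y: Float) {
        _x = Wrapper(wrappedValue: x)
        _y = Wrapper(wrappedValue: Wrapper(wrappedValue: y))
    }

    // Hypothetical stand-in for the synthesized `TangentVector`.
    struct Tangent {
        var x: Float
        var y: Float
    }

    // Hypothetical stand-in for the synthesized `move(along:)`.
    // Each line writes through the wrapper's `wrappedValue` setter. With a
    // get-only `wrappedValue`, `x` and `y` would be read-only and these
    // assignments would not compile.
    mutating func move(along direction: Tangent) {
        x += direction.x
        y += direction.y
    }
}

var s = Struct(x: 3, y: 4)
s.move(along: Struct.Tangent(x: 0.5, y: -1))
print(s.x, s.y)
```

Moving (3, 4) along (0.5, -1) leaves s.x at 3.5 and s.y at 3.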
Accesses to wrapped stored properties can be differentiated, as expected:
@differentiable
func multiply(_ s: Struct) -> Float {
    s.x * s.y
}
print(gradient(at: Struct(x: 3, y: 4), in: multiply))
// Struct.TangentVector(x: 4.0, y: 3.0)
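The expected gradient can be sanity-checked without any differentiation support: for f(x, y) = x * y, the partial derivatives are df/dx = y and df/dy = x, so at (3, 4) the gradient is (4, 3). The sketch below estimates the same values with central differences in plain Swift. The numericalGradient helper and its step parameter are hypothetical, and y uses a single wrapper for brevity.

```swift
@propertyWrapper
struct Wrapper<Value> {
    var wrappedValue: Value
}

struct Struct {
    @Wrapper var x: Float
    @Wrapper var y: Float
}

func multiply(_ s: Struct) -> Float {
    s.x * s.y
}

// Hypothetical helper: central-difference estimate of (df/dx, df/dy) at `s`.
func numericalGradient(
    of f: (Struct) -> Float, at s: Struct, step h: Float = 0.5
) -> (x: Float, y: Float) {
    var xPlus = s, xMinus = s, yPlus = s, yMinus = s
    xPlus.x += h; xMinus.x -= h
    yPlus.y += h; yMinus.y -= h
    return ((f(xPlus) - f(xMinus)) / (2 * h), (f(yPlus) - f(yMinus)) / (2 * h))
}

let g = numericalGradient(of: multiply, at: Struct(x: 3, y: 4))
print(g)  // should be close to (x: 4.0, y: 3.0)
```

Because multiply is linear in each argument, the central-difference estimate should agree with the analytic gradient up to floating-point rounding.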
Use cases
The new behavior makes differentiation work naturally for wrapped stored properties. Here's an example using two non-trivial property wrappers, @Lazy and @Clamping, from SE-0258:
// `@Lazy` and `@Clamping` from:
// https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md
struct Struct: Differentiable {
    @Lazy var x: Float = 10
    @Clamping(min: -10, max: 10)
    var y: Float = 5
}
@differentiable
func multiply(_ s: Struct) -> Float {
    return s.x * s.y
}
print(gradient(at: Struct(x: 3, y: 4), in: multiply))
// Struct.TangentVector(x: 4.0, y: 3.0)
Any comments are welcome!