One use case that often comes up in differentiable programming is defining a function whose differentiability depends on its parameters' protocol conformances. We've already made this possible for concrete functions, but not yet for protocol requirements because, as I understand it, that would require conditional entries in a witness table.
    public protocol Distribution {
        associatedtype Value
        @differentiable(wrt: self where Self: Differentiable)
        func logProbability(of value: Value) -> Tensor<Float>
    }
Question: If a conforming type Foo does not also conform to Differentiable, the derivative function for this function will not exist. In that case, what should we put in the derivative entry in the Foo : Distribution witness table?
There's also a subtle issue with conditionally differentiable protocol requirements: What if the conforming type gets retroactively extended to conform to Differentiable in a third module? My current thinking is that we should emit an error if the user tries to differentiate this function in the third module because the conformance to Differentiable is defined in a separate module after the conformance to Distribution.
Conditional requirements are really tricky, and something we were really hoping never to have to support. I would instead design the interface so that the differentiability requirement is introduced in a child protocol:
    public protocol Distribution {
        associatedtype Value
        func logProbability(of value: Value) -> Tensor<Float>
    }
    public protocol DifferentialDistribution: Distribution, Differentiable {
        @differentiable(wrt: self)
        func logProbability(of value: Value) -> Tensor<Float>
    }
That also seems to simplify the design of @differentiable, since it doesn't need its own set of conditional conformances.
Thanks Joe! That makes it easier. It's also great to know that conditional entries are to be avoided, so we won't head in that direction.
I tried this example:
    public protocol Distribution {
        associatedtype Value
        func logProbability(of value: Value) -> Float
    }
    public protocol DifferentiableDistribution: Distribution, Differentiable {
        @differentiable(wrt: self)
        func logProbability(of value: Value) -> Float
    }
    @differentiable
    func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
        x.logProbability(of: .zero)
    }
It looks like when logProbability(of:) is called, the call goes through Distribution, not DifferentiableDistribution:
That's interesting and it also shows why the error pops up. Shouldn't we always be choosing the most specific protocol that defines a function with a matching signature?
I'm not sure we really have a formal concept of requirement refinement. We have a hack to avoid introducing redundant wtable entries for identical requirements, and we have an overloading rule about picking the more specialized function signature, but I don't think we prefer one requirement over another in the abstract.
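For reference, the overloading rule mentioned here can be seen with ordinary generic functions. This is only a minimal sketch: `describe` is a made-up name, and the example shows ranking between two free-function overloads, not between protocol requirements. The ranking looks at the signatures and their generic constraints.

```swift
// Two overloads with the same shape; the second adds a generic constraint.
func describe<T>(_ value: T) -> String {
    "unconstrained"
}

func describe<T: Numeric>(_ value: T) -> String {
    "numeric"
}

// Int conforms to Numeric, so both overloads apply and the more
// constrained (more specialized) one is preferred.
let a = describe(42)

// String does not conform to Numeric, so only the unconstrained overload applies.
let b = describe("hello")
```

Here `a` is "numeric" and `b` is "unconstrained". Nothing in this example says how two otherwise identical requirements in a protocol and its refinement are ranked against each other, which is the open question in this thread.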
Since differentiability is tied to the function in the model, I'm not sure we can easily avoid conditional differentiability requirements.
Would it be easy to modify the overloading rule that picks the more specialized function signature so that it also considers the @differentiable attribute? And would such a change result in DifferentiableDistribution.logProbability(of:) being picked over Distribution.logProbability(of:) in the example above?
Would it be useful if we looked at where that hack is and tried to modify the redundancy check so it also takes into account the "@differentiable" attribute when checking for identical requirements?
I think there's plenty of precedent for "informal" refinement in the standard library, where refining protocols often redeclare methods with refined requirements. Differentiation doesn't feel so different from those other cases. I suspect that, to Richard's point, our logic that tries to detect exactly equivalent
What happens though if a method signature is exactly the same in the refined protocol, with the only exception being an additional attribute (in this case the "@differentiable" attribute)?
I guess, ideally, you'd still coalesce the basic requirement, but introduce the differentiation witnesses as additional requirements in the refined protocol.
I'm not familiar with how the differentiation witnesses are introduced. @rxwei do you think something may be wrong with the handling of the @differentiable attribute?
We should expect the Foo: DifferentiableDistribution witness table to contain derivative witnesses (jvp and vjp). @marcrasi might know some details here.
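To make the expected layout concrete, here is a rough model that treats witness tables as plain structs of closures. Everything in it is hypothetical and for illustration only: the struct names, the `Toy` type, and the simplification of using `Float` as the tangent type are my assumptions, not how the compiler actually represents witness tables. The idea it sketches is the one suggested above: the base requirement is shared, and the refined protocol's table adds the derivative entries.

```swift
// Hypothetical model: a witness table as a struct of function values.
// None of these types exist in the compiler or runtime.

/// Entries for `Distribution`: just the original function.
struct DistributionWitnesses<Conformer, Value> {
    var logProbability: (Conformer, Value) -> Float
}

/// Entries for `DifferentiableDistribution`: reuse the base entry and
/// add derivative entries with respect to `self`.
struct DifferentiableDistributionWitnesses<Conformer, Value, Tangent> {
    var base: DistributionWitnesses<Conformer, Value>
    // JVP: the value plus a differential from a tangent of `self` to a tangent of the result.
    var logProbabilityJVP: (Conformer, Value) -> (value: Float, differential: (Tangent) -> Float)
    // VJP: the value plus a pullback from a result cotangent to a tangent of `self`.
    var logProbabilityVJP: (Conformer, Value) -> (value: Float, pullback: (Float) -> Tangent)
}

// A toy conformer with logProbability = -(x - mean)^2,
// differentiated by hand with respect to `mean`: 2 * (x - mean).
struct Toy {
    var mean: Float
}

let baseTable = DistributionWitnesses<Toy, Float>(
    logProbability: { d, x in -(x - d.mean) * (x - d.mean) }
)

let refinedTable = DifferentiableDistributionWitnesses<Toy, Float, Float>(
    base: baseTable,
    logProbabilityJVP: { d, x in
        (value: baseTable.logProbability(d, x),
         differential: { dMean in 2 * (x - d.mean) * dMean })
    },
    logProbabilityVJP: { d, x in
        (value: baseTable.logProbability(d, x),
         pullback: { seed in 2 * (x - d.mean) * seed })
    }
)

// Calling through the refined table gives access to the derivative entries.
let toy = Toy(mean: 1)
let (value, pullback) = refinedTable.logProbabilityVJP(toy, 3)
let gradient = pullback(1)
```

With `mean = 1` and `x = 3`, `value` is -4 and `gradient` is 4. In this model, a call that goes only through `DistributionWitnesses` has no derivative entries to use, which matches the symptom in the example above.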
I think we should check for extra @differentiable attributes near or within OverrideMatcher::checkPotentialOverrides. I haven't worked on this part of the code so I'm not entirely sure about the consequences.