Conditional entries in witness tables for AutoDiff

One use case that often comes up in differentiable programming is defining a function whose differentiability depends on its parameters' protocol conformances. We've already made this possible for concrete functions, but not yet for protocol requirements because, as far as I understand, it would require conditional entries in a witness table.

In issue TF-637, @eaplatanios gave an example:

public protocol Distribution {
  associatedtype Value

  @differentiable(wrt: self where Self: Differentiable)
  func logProbability(of value: Value) -> Tensor<Float>
}

Question: if a conforming type Foo does not also conform to Differentiable, the derivative function for this requirement will not exist. In that case, what should we put in the derivative entry of the Foo: Distribution witness table?
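
To make the question concrete, here is a rough sketch that models a witness table as a plain Swift struct of closures. This is an assumption for illustration only; the compiler's real witness-table layout is different. The derivative entry is modelled as an Optional, so storing nil for a non-differentiable type is one possible answer (a trapping stub would be another). DistributionWitnessTable, Gaussian, and the Float-valued tangent are all hypothetical.

```swift
// Hypothetical model of a witness table as a struct of closures; the compiler's
// real witness-table layout is different. `Conformer` stands in for `Self`.
struct DistributionWitnessTable<Conformer, Value> {
    // Always present: the witness for `logProbability(of:)`.
    let logProbability: (Conformer, Value) -> Float
    // The conditional slot: a VJP with respect to `self`, with the tangent
    // simplified to `Float`. It is only meaningful when `Conformer` is
    // differentiable, so this model makes it Optional.
    let logProbabilityVJP: ((Conformer, Value) -> (value: Float, pullback: (Float) -> Float))?
}

// A differentiable conformer: unnormalized Gaussian log-density
// -(x - mean)^2 / 2, whose derivative with respect to `mean` is (x - mean).
struct Gaussian { var mean: Float }

let gaussianTable = DistributionWitnessTable<Gaussian, Float>(
    logProbability: { d, x in -(x - d.mean) * (x - d.mean) / 2 },
    logProbabilityVJP: { d, x in
        let value = -(x - d.mean) * (x - d.mean) / 2
        return (value: value, pullback: { v in v * (x - d.mean) })
    }
)

// A conformer that is not differentiable, like `Foo` in the question:
// there is no derivative to store, so this model puts `nil` in the slot.
struct Foo {}

let fooTable = DistributionWitnessTable<Foo, Float>(
    logProbability: { _, _ in 0 },
    logProbabilityVJP: nil
)
```

In this model every caller that wants the derivative has to check the slot at run time, which is the awkwardness the question is pointing at.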

There's also a subtle issue with conditionally differentiable protocol requirements: what if the conforming type is retroactively extended to conform to Differentiable in a third module? My current thinking is that we should emit an error if the user tries to differentiate this requirement in that third module, because the conformance to Differentiable is declared in a separate module, after the conformance to Distribution.

cc @Slava_Pestov @Joe_Groff @John_McCall

Conditional requirements are really tricky, and something we were hoping never to have to support. I would instead design the interface so that the differentiable requirement is introduced in a child protocol:

public protocol Distribution {
  associatedtype Value

  func logProbability(of value: Value) -> Tensor<Float>
}

public protocol DifferentiableDistribution: Distribution, Differentiable {
  @differentiable(wrt: self)
  func logProbability(of value: Value) -> Tensor<Float>
}

That also seems like it would simplify the design of @differentiable, since the attribute wouldn't need its own form of conditional conformance.
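
The child-protocol design can be approximated in plain Swift to show why it avoids conditional entries. Because @differentiable and Differentiable need the Swift for TensorFlow toolchain, this sketch replaces them with a hypothetical explicit requirement, logProbabilityDerivative(of:), and drops the Differentiable refinement; Gaussian and Constant are made-up conformers.

```swift
// Base protocol: every distribution can evaluate a log-probability.
protocol Distribution {
    associatedtype Value
    func logProbability(of value: Value) -> Float
}

// Hypothetical stand-in for the refined protocol. Instead of
// `@differentiable(wrt: self)`, it adds an explicit derivative requirement,
// so only types that opt in by conforming ever have to provide one.
protocol DifferentiableDistribution: Distribution {
    // Derivative of `logProbability(of:)` with respect to the distribution's
    // single parameter (a `Float` here, for simplicity).
    func logProbabilityDerivative(of value: Value) -> Float
}

// Conforms only to the base protocol, so no derivative is ever required.
struct Constant: Distribution {
    func logProbability(of value: Int) -> Float { 0 }
}

// Opts in to the refined protocol and supplies the derivative.
struct Gaussian: DifferentiableDistribution {
    var mean: Float
    // Unnormalized log-density: -(x - mean)^2 / 2.
    func logProbability(of value: Float) -> Float {
        -(value - mean) * (value - mean) / 2
    }
    // Derivative of the expression above with respect to `mean`.
    func logProbabilityDerivative(of value: Float) -> Float {
        value - mean
    }
}

// Generic code that only evaluates works with any distribution...
func evaluate<D: Distribution>(_ d: D, at value: D.Value) -> Float {
    d.logProbability(of: value)
}

// ...while code that needs derivatives constrains on the refined protocol.
func derivativeOfLogProbability<D: DifferentiableDistribution>(_ d: D, at value: D.Value) -> Float {
    d.logProbabilityDerivative(of: value)
}
```

Constant's conformance simply has no derivative slot, so there is nothing conditional to fill in. Using a distinct requirement name also sidesteps the harder case, where the refined protocol redeclares the very same requirement with only an added attribute.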

That is actually what I proposed as a workaround in TF-637, but it is not currently possible. @rxwei, maybe you know more about why?

Thanks, Joe! That makes things easier. It's also good to know that conditional entries are to be avoided, so we won't keep thinking in that direction.

I tried this example:

public protocol Distribution {
  associatedtype Value
  func logProbability(of value: Value) -> Float
}

public protocol DifferentiableDistribution: Distribution, Differentiable {
  @differentiable(wrt: self)
  func logProbability(of value: Value) -> Float
}

@differentiable
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
  x.logProbability(of: .zero)
}

It looks like when logProbability(of:) is called, it's dispatched through Distribution, not DifferentiableDistribution:

  %6 = witness_method $T, #Distribution.logProbability!1 : <Self where Self : Distribution> (Self) -> (Self.Value) -> Float : $@convention(witness_method: Distribution) <τ_0_0 where τ_0_0 : Distribution> (@in_guaranteed τ_0_0.Value, @in_guaranteed τ_0_0) -> Float // user: %7

I'm not sure how to best resolve this. Should we manually find any protocol override when performing the differentiation transformation?

That's interesting, and it also explains why the error pops up. Shouldn't we always choose the most specific protocol that declares a function with a matching signature?

@Douglas_Gregor or @xedin might be able to help you there. I would expect us to treat the requirement in the more refined protocol as more specific.
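
A rough, string-based model of the rule being suggested here: among the protocols a generic parameter conforms to, pick the most refined one that declares the requirement. ProtocolInfo, isRefinement(_:of:in:), and mostSpecificDeclaringProtocol are hypothetical names; the type checker and SILGen work on declarations rather than strings, and this is not how either is implemented.

```swift
// Hypothetical model of a protocol; not the compiler's representation.
struct ProtocolInfo {
    let refines: [String]          // protocols this one directly refines
    let requirements: Set<String>  // requirement names it declares
}

// True if `sub` is `sup` or (transitively) refines it.
func isRefinement(_ sub: String, of sup: String, in protocols: [String: ProtocolInfo]) -> Bool {
    if sub == sup { return true }
    guard let info = protocols[sub] else { return false }
    return info.refines.contains { isRefinement($0, of: sup, in: protocols) }
}

// Among the protocols a generic parameter conforms to, returns the one that
// declares `requirement` and refines every other protocol declaring it.
func mostSpecificDeclaringProtocol(
    of requirement: String,
    among conformances: [String],
    in protocols: [String: ProtocolInfo]
) -> String? {
    let candidates = conformances.filter {
        protocols[$0]?.requirements.contains(requirement) ?? false
    }
    return candidates.first { candidate in
        candidates.allSatisfy { isRefinement(candidate, of: $0, in: protocols) }
    }
}

// The protocols from the example above.
let protocols: [String: ProtocolInfo] = [
    "Distribution": ProtocolInfo(refines: [], requirements: ["logProbability(of:)"]),
    "Differentiable": ProtocolInfo(refines: [], requirements: []),
    "DifferentiableDistribution": ProtocolInfo(
        refines: ["Distribution", "Differentiable"],
        requirements: ["logProbability(of:)"]),
]
```

For T: DifferentiableDistribution this picks DifferentiableDistribution over Distribution, which is the dispatch the example expected.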

I'm not sure we really have a formal concept of requirement refinement. We have a hack to avoid introducing redundant witness table entries for identical requirements, and we have an overloading rule that picks the more specialized function signature, but I don't think we prefer one requirement over another in the abstract.

Since differentiability is tied to the function in the model, I'm not sure we can easily avoid conditional differentiability requirements.

Would it be easy to modify the overloading rule that picks the more specialized function signature so that it also considers the @differentiable attribute? And would such a change result in DifferentiableDistribution.logProbability(of:) being picked over Distribution.logProbability(of:) in the example above?

This is interesting, and I'm not sure the type checker is the culprit here. Consider this example:

protocol P {
  associatedtype Value
  func foo()
}

protocol Q : P {
  func foo()
}

func bar<T: Q>(_ x: T) where T.Value: AdditiveArithmetic {
  x.foo()
}

The type-checked expression AST looks like this:

(call_expr type='()' location=test.swift:11:5 range=[test.swift:11:3 - line:11:9] arg_labels=
  (dot_syntax_call_expr type='() -> ()' location=test.swift:11:5 range=[test.swift:11:3 - line:11:5]
    (declref_expr type='(T) -> () -> ()' location=test.swift:11:5 range=[test.swift:11:5 - line:11:5] decl=test.(file).Q.foo()@test.swift:7:8 [with (substitution_map generic_signature=<Self where Self : Q> (substitution Self -> T))] function_ref=single)
    (declref_expr type='T' location=test.swift:11:3 range=[test.swift:11:3 - line:11:3] decl=test.(file).bar(_:).x@test.swift:10:18 function_ref=unapplied))
  (tuple_expr type='()' location=test.swift:11:8 range=[test.swift:11:8 - line:11:9]))

So the substitutions are:

(substitution_map generic_signature=<Self where Self : Q> (substitution Self -> T))

But the generated SIL has this witness method:

  %2 = witness_method $T, #P.foo!1 : <Self where Self : P> (Self) -> () -> () : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> () // user: %3

I'm not sure how it switched from Q to P in this case.

Could that be the hack that John was referring to?

Would it be useful to find where that hack lives and modify the redundancy check so that it also takes the @differentiable attribute into account when checking for identical requirements?
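
A toy model of what such a modified redundancy check might look like. Requirement, isRedundantIgnoringAttributes, and isRedundant are hypothetical, and the assumption that the existing check compares only the name and signature is a simplification of the behavior described above, not a description of the actual compiler code.

```swift
// Hypothetical model of a protocol requirement, for illustration only.
struct Requirement {
    let name: String        // e.g. "logProbability(of:)"
    let signature: String   // e.g. "(Self.Value) -> Float"
    // Differentiability configurations from @differentiable attributes,
    // each modelled as the set of parameter names in its `wrt:` clause.
    let differentiableConfigs: Set<Set<String>>
}

// Assumed current behavior: identical name and signature means the refined
// requirement reuses the base requirement's witness table entry.
func isRedundantIgnoringAttributes(_ refined: Requirement, overriding base: Requirement) -> Bool {
    refined.name == base.name && refined.signature == base.signature
}

// Proposed tweak: the refined requirement is only redundant if it also
// declares no @differentiable configuration that the base lacks.
func isRedundant(_ refined: Requirement, overriding base: Requirement) -> Bool {
    isRedundantIgnoringAttributes(refined, overriding: base)
        && refined.differentiableConfigs.isSubset(of: base.differentiableConfigs)
}
```

With this rule, DifferentiableDistribution.logProbability(of:) would no longer be folded into Distribution's entry, because it adds a wrt: self configuration.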

I think there's plenty of precedent for "informal" refinement in the standard library, where refining protocols often redeclare methods with refined requirements. Differentiation doesn't feel so different from those other cases. I suspect that, to Richard's point, our logic that tries to detect exactly equivalent

What happens, though, if a method signature is exactly the same in the refined protocol, except for an additional attribute (in this case the @differentiable attribute)?

I guess, ideally, you'd still coalesce the basic requirement, but introduce the differentiation witnesses as additional requirements in the refined protocol.
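
One way to picture this, again modelling witness tables as Swift structs of closures (an illustrative assumption, not the compiler's real layout): the refined table keeps a reference to the base table, as base_protocol entries in a SIL witness table do, and adds its own unconditional derivative entries. Tangents are plain Float for simplicity, and DistributionTable, DifferentiableDistributionTable, and Gaussian are hypothetical.

```swift
// Base table: the single, coalesced entry for `logProbability(of:)`.
struct DistributionTable<Conformer> {
    let logProbability: (Conformer, Float) -> Float
}

// Refined table: a reference to the base table plus derivative entries that
// exist only here, so no type ever has a conditional derivative slot.
struct DifferentiableDistributionTable<Conformer> {
    let base: DistributionTable<Conformer>
    let logProbabilityJVP: (Conformer, Float) -> (value: Float, differential: (Float) -> Float)
    let logProbabilityVJP: (Conformer, Float) -> (value: Float, pullback: (Float) -> Float)
}

struct Gaussian { var mean: Float }

// Unnormalized log-density -(x - mean)^2 / 2; its derivative with respect
// to `mean` is (x - mean).
func gaussianLogProb(_ d: Gaussian, _ x: Float) -> Float {
    -(x - d.mean) * (x - d.mean) / 2
}

let table = DifferentiableDistributionTable<Gaussian>(
    base: DistributionTable(logProbability: gaussianLogProb),
    // Differential: maps a change in `mean` to a change in the output.
    logProbabilityJVP: { d, x in
        (value: gaussianLogProb(d, x), differential: { dMean in dMean * (x - d.mean) })
    },
    // Pullback: maps an output cotangent back to a cotangent for `mean`.
    logProbabilityVJP: { d, x in
        (value: gaussianLogProb(d, x), pullback: { v in v * (x - d.mean) })
    }
)
```

A plain call goes through table.base.logProbability, while differentiation would look up the JVP or VJP entries on the refined table.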

I'm not familiar with how the differentiation witnesses are introduced. @rxwei, do you think something may be wrong with the handling of the @differentiable attribute?

Yeah, maybe something's wrong here.

public protocol Distribution {
  associatedtype Value
  func logProbability(of value: Value) -> Float
}

public protocol DifferentiableDistribution: Distribution, Differentiable {
  @differentiable(wrt: self)
  func logProbability(of value: Value) -> Float
}

struct Foo: DifferentiableDistribution {
  @differentiable(wrt: self)
  func logProbability(of value: Float) -> Float { .zero }
}

The emitted witness tables are:

sil_witness_table hidden Foo: DifferentiableDistribution module witness {
  base_protocol Differentiable: Foo: Differentiable module witness
  base_protocol Distribution: Foo: Distribution module witness
}

sil_witness_table hidden Foo: Distribution module witness {
  associated_type Value: Float
  method #Distribution.logProbability!1: <Self where Self : Distribution> (Self) -> (Self.Value) -> Float : @$s7witness3FooVAA12DistributionA2aDP14logProbability2ofSf5ValueQz_tFTW     // protocol witness for Distribution.logProbability(of:) in conformance Foo
}

We would expect the Foo: DifferentiableDistribution witness table to contain derivative witnesses (JVP and VJP), but it only references its base protocols. @marcrasi might know some details here.

@rxwei, could it be this line?

Looks like it is. Maybe we should make getOverriddenDecls() include any protocol requirement override that declares additional @differentiable attributes.

Yeah, I was just looking into this. If you give me some pointers on what's involved, I can tackle it now.

I think we should check for extra @differentiable attributes near or within OverrideMatcher::checkPotentialOverrides. I haven't worked on this part of the code, so I'm not entirely sure about the consequences.