Conditional entries in witness tables for AutoDiff

Yes, that’s probably what it is.

It looks like, when checking for overrides of a protocol requirement (which is our case), the override is "allowed" only if the types match exactly. @rxwei, does adding the @differentiable attribute change the type of the function? I believe this check is done using getMemberTypeForComparison in TypeCheckDeclOverride.cpp.

Additional attributes don't change the type, so we might need to add some new checks.

I have a work-in-progress solution for this. I'll open a PR soon to get feedback, but before that: @rxwei, is there a quick way to compare two DifferentiableAttrs and check whether one "implies" the other (for example, @differentiable implies @differentiable(wrt: 0))?

attr->getParameterIndices() returns the parameter indices as an AutoDiffParameterIndices *, which contains a bit vector called parameters. However, this data structure is going to be replaced with AutoDiffIndexSubset. I would suggest converting the two bit vectors into two AutoDiffIndexSubsets first, and then calling AutoDiffIndexSubset::isSubsetOf(...).
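To illustrate the suggested comparison, here is a small Swift model of a bit-vector index subset and the resulting "implies" check. It is a sketch under stated assumptions: `ParameterIndexSubset` and `implies(provided:required:)` are hypothetical names, not the compiler's C++ AutoDiffIndexSubset API, and the example simply assumes a function with two differentiability-eligible parameters.

```swift
// Hypothetical stand-in for the compiler's AutoDiffIndexSubset; not the real API.
struct ParameterIndexSubset: Equatable {
    /// bits[i] == true means parameter i is a differentiability parameter.
    let bits: [Bool]

    init(capacity: Int, indices: [Int]) {
        var bits = Array(repeating: false, count: capacity)
        for i in indices {
            precondition(i >= 0 && i < capacity, "index out of range")
            bits[i] = true
        }
        self.bits = bits
    }

    /// True if every index set in `self` is also set in `other`.
    func isSubset(of other: ParameterIndexSubset) -> Bool {
        precondition(bits.count == other.bits.count, "capacities must match")
        return zip(bits, other.bits).allSatisfy { pair in !pair.0 || pair.1 }
    }
}

/// Hypothetical "implies" check: an attribute that is differentiable w.r.t.
/// `provided` also covers a requirement w.r.t. `required` when
/// `required` is a subset of `provided`.
func implies(provided: ParameterIndexSubset, required: ParameterIndexSubset) -> Bool {
    required.isSubset(of: provided)
}

// Assumed: a function with two parameters (indices 0 and 1).
let all = ParameterIndexSubset(capacity: 2, indices: [0, 1])  // like @differentiable
let first = ParameterIndexSubset(capacity: 2, indices: [0])   // like @differentiable(wrt: 0)
print(implies(provided: all, required: first))  // true
print(implies(provided: first, required: all))  // false
```

The subset direction is the key design point: differentiability with respect to more parameters implies differentiability with respect to any subset of them, not the other way around.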

Thanks, Richard! There are additional conditions, though. For example, if attribute x (just naming them for simplicity) has no VJP provided, but the overriding method defines a VJP, then the override should be allowed. I will add a new function that performs such a comparison, but you should probably check it later in case I missed something.

@differentiable on protocol requirements cannot define VJPs, and in fact the vjp: field will become obsolete once both @differentiating and @transposing are in place. So we don't need to worry about it.

It looks like the problem is not what we thought. After some debugging, it seems that SILWitnessVisitor::visitProtocolDecl is never called when compiling the following example:

public protocol Distribution {
  associatedtype Value
  func logProbability(of value: Value) -> Float
}

public protocol DifferentiableDistribution: Differentiable, Distribution {
  @differentiable(wrt: self)
  func logProbability(of value: Value) -> Float
}

@differentiable
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
  x.logProbability(of: .zero)
}

Why do you think this is happening?

This is because your code example doesn’t contain a conformance. Witness tables are only emitted for conformances.
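To make this concrete, here is a toy Swift model in which witness tables are keyed by (conforming type, protocol) pairs, so a protocol declaration with no conformances yields no tables at all. `ProtocolModel`, `ConformanceModel`, and `emitWitnessTables(for:)` are hypothetical names for illustration only; they are not SILGen's data structures.

```swift
// Toy model: one witness table per conformance, none per protocol declaration.
struct ProtocolModel {
    let name: String
    let requirements: [String]
}

struct ConformanceModel {
    let typeName: String
    let proto: ProtocolModel
}

/// Emits one "table" (a list of entry descriptions) per conformance.
func emitWitnessTables(for conformances: [ConformanceModel]) -> [String: [String]] {
    var tables: [String: [String]] = [:]
    for c in conformances {
        let key = "\(c.typeName): \(c.proto.name)"
        tables[key] = c.proto.requirements.map { "method #\(c.proto.name).\($0)" }
    }
    return tables
}

let distribution = ProtocolModel(name: "Distribution", requirements: ["logProbability"])

// Only the protocol declaration exists: nothing is emitted.
print(emitWitnessTables(for: []).count)  // 0

// Adding a conformance produces a table.
let tables = emitWitnessTables(for: [ConformanceModel(typeName: "Foo", proto: distribution)])
print(tables["Foo: Distribution"] ?? [])  // ["method #Distribution.logProbability"]
```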

Adding the following to your code example will trigger witness table emission.

But the other example should also not result in an error, right? For that one, I still get the same error saying that logProbability is not differentiable.

So, given this:

public protocol Distribution {
  associatedtype Value
  func logProbability(of value: Value) -> Float
}

public protocol DifferentiableDistribution: Differentiable, Distribution {
  @differentiable(wrt: self)
  func logProbability(of value: Value) -> Float
}

struct Foo: DifferentiableDistribution {
  @differentiable(wrt: self)
  func logProbability(of value: Float) -> Float {
    .zero
  }
}

The following compiles fine:

@differentiable
func blah(_ x: Foo) -> Float {
  x.logProbability(of: .zero)
}

However, this does not:

@differentiable
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
  x.logProbability(of: .zero)
}

This one does not invoke SILWitnessVisitor at all, so I am not sure where to look to fix it.

Yeah, this is expected. We should first make sure that the refined protocol requirement (the one with more @differentiable attributes) is treated as an overload during type checking, so that new witness entries are emitted into the refining protocol's witness table when there is a conformance.

I think it gets marked as an override right now. At least, it's marked as such the first time it goes through that check: the types are considered a match for overriding. Where should I look next?

What does the witness table for the following conformance look like (if you do swiftc -emit-silgen)?

struct Foo: DifferentiableDistribution {
  @differentiable(wrt: self)
  func logProbability(of value: Float) -> Float {
    .zero
  }
}

I'm currently traveling, so I'll look into this tonight.

sil_witness_table hidden Foo: Distribution module empty {
  associated_type Value: Float
  method #Distribution.logProbability!1: <Self where Self : Distribution> (Self) -> (Self.Value) -> Float : @$s5empty3FooVAA12DistributionA2aDP14logProbability2ofSf5ValueQz_tFTW	// protocol witness for Distribution.logProbability(of:) in conformance Foo
}

So I guess this is not good.

The one you showed is the witness table for Distribution, which looks correct because there’s no overload there. What does the one for DifferentiableDistribution look like?

Sorry, yes. That one looks like this:

sil_witness_table hidden Foo: DifferentiableDistribution module empty {
  base_protocol Differentiable: Foo: Differentiable module empty
  base_protocol Distribution: Foo: Distribution module empty
}

I'll try to look into it a bit more today, but I'm traveling until Sunday, so I can only work on this sporadically.

It looks like the witness entry is still not emitted as expected. Could you verify, in your implementation, whether requiresNewWitnessTableEntry() returns true when this witness table is emitted?

This is now fixed in this PR. :)

The fix relies on the fact that, to avoid redundancy, the Swift compiler does not add witness table entries for protocol requirements that override a requirement of a base protocol. However, a requirement that adds a @differentiable attribute should not be interpreted as an override, because it adds new functionality; it should result in new entries being added to the witness table. The PR adds this check to the override-checking code, which enables refining a protocol requirement with @differentiable as in the examples above.
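The rule described above can be sketched as a small Swift model. This is a minimal sketch of the decision as described in this post, not the PR's actual code: `Requirement`, `isCovered(_:by:)`, and `needsNewWitnessEntry(_:overriding:)` are hypothetical names, each @differentiable configuration is modeled as a set of parameter indices, and index 1 standing in for `self` is an arbitrary assumption.

```swift
// Hypothetical model of a protocol requirement and its @differentiable configs.
struct Requirement {
    let name: String
    /// Each element is one @differentiable configuration, written as the set of
    /// parameter indices it differentiates with respect to.
    let differentiableConfigs: [Set<Int>]
}

/// A config is covered if some base config differentiates w.r.t. a superset of it.
func isCovered(_ config: Set<Int>, by base: Requirement) -> Bool {
    base.differentiableConfigs.contains { config.isSubset(of: $0) }
}

/// Treat the refined requirement as a plain override (no new witness entry) only
/// if every @differentiable config it declares is already implied by the base.
func needsNewWitnessEntry(_ refined: Requirement, overriding base: Requirement) -> Bool {
    !refined.differentiableConfigs.allSatisfy { isCovered($0, by: base) }
}

// Distribution.logProbability(of:) has no @differentiable attribute.
let base = Requirement(name: "logProbability(of:)", differentiableConfigs: [])
// DifferentiableDistribution adds @differentiable(wrt: self); index 1 is assumed to be `self`.
let refined = Requirement(name: "logProbability(of:)", differentiableConfigs: [[1]])
// A redeclaration with no new attributes stays a plain override.
let plain = Requirement(name: "logProbability(of:)", differentiableConfigs: [])

print(needsNewWitnessEntry(refined, overriding: base))  // true
print(needsNewWitnessEntry(plain, overriding: base))    // false
```

In this model, the `Foo: DifferentiableDistribution` table would gain a `logProbability` entry only for the `refined` case, which is the behavior the thread was after.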

Thanks everyone for the help and the pointers that led to this fix!
