Hi type checker and generics experts,
I'm working on automatic differentiation and I have some questions about type-checking where clauses in the @differentiable attribute.
I'll cut to the chase and describe my problem and questions here. In a reply, I'll provide more information about `@differentiable`.
There are a few existing places where `where` clause requirements are type-checked:
```swift
// 1. Generic function declarations.
func foo<T>(x: T) -> T where T : Differentiable, T == T.CotangentVector { ... }
// 2. Protocol extensions.
extension FloatingPoint where Self : Differentiable, Self == Self.CotangentVector { ... }
// 3. `@_specialize` attribute.
@_specialize(where T == Int)
public func foo<T>(_ t: T) -> T { ... }
```
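For readers without a differentiation-enabled toolchain, here is a minimal, self-contained sketch of cases 1 and 2 that should compile with a stock Swift compiler. It assumes a hypothetical stand-in `Differentiable` protocol reduced to a single `CotangentVector` associated type (the real protocol has more requirements), an illustration-only `Double` conformance, and a hypothetical helper method `asCotangent()`. The point is that in both cases the `where` clause belongs to the declaration's own generic signature, so `T.CotangentVector` resolves inside the body.

```swift
// Hypothetical stand-in for the real `Differentiable` protocol,
// reduced to the one associated type this example needs.
protocol Differentiable {
    associatedtype CotangentVector
}

// Illustration-only conformance so the constrained generics can be called.
extension Double: Differentiable {
    typealias CotangentVector = Double
}

// 1. Generic function: the `where` clause is part of `foo`'s own generic
// signature, so `T.CotangentVector` is a valid type inside the body.
func foo<T>(x: T) -> T where T: Differentiable, T == T.CotangentVector {
    let cotangent: T.CotangentVector = x  // allowed because T == T.CotangentVector
    return cotangent
}

// 2. Protocol extension: the member only exists for types that also
// satisfy the `where` clause (hypothetical helper name).
extension FloatingPoint where Self: Differentiable, Self == Self.CotangentVector {
    func asCotangent() -> Self.CotangentVector {
        return self
    }
}
```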
My use case involves the `@differentiable` function attribute.
Specifically, I want to build a generic signature that:
- Includes all requirements from the generic signature of the original function `foo`.
- Also includes all requirements from the `@differentiable` `where` clause.
```swift
// example.swift
@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
func foo<T : Numeric>(x: T, y: T) -> T {
  return x + y
}

// I want to build the following generic signature:
// <T : Numeric, T : Differentiable, T == T.CotangentVector>
// The VJP function must match this combined generic signature.
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
  where T : Numeric & Differentiable, T == T.CotangentVector
{
  return (x + y, { v in (v, v) })
}
```
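To make the goal signature concrete, the following sketch spells out `vjpFoo` under the combined requirements `<T : Numeric, T : Differentiable, T == T.CotangentVector>` and calls its pullback. It assumes the same hypothetical, reduced `Differentiable` stand-in (with an illustration-only `Double` conformance) and omits the `@differentiable` attribute itself, since that needs a differentiation-enabled toolchain. It only shows that the VJP's signature is the union of `foo`'s requirements and the attribute's `where` clause.

```swift
// Hypothetical stand-in for `Differentiable`, reduced to one associated type.
protocol Differentiable {
    associatedtype CotangentVector
}

// Illustration-only conformance so the generic functions can be called.
extension Double: Differentiable {
    typealias CotangentVector = Double
}

// Original function: its own generic signature only has `T : Numeric`.
func foo<T: Numeric>(x: T, y: T) -> T {
    return x + y
}

// VJP under the combined signature: foo's requirement plus the where clause.
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
    where T: Numeric & Differentiable, T == T.CotangentVector
{
    // d(x + y)/dx = d(x + y)/dy = 1, so the pullback passes `v` to both inputs.
    return (foo(x: x, y: y), { v in (v, v) })
}

let (value, pullback) = vjpFoo(x: 1.0, y: 2.0)
let (dx, dy) = pullback(1.0)
```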
How can I do this?
Here's my original approach in TypeCheckAttr.cpp for building this combined generic signature. (I copied it from other code paths that handle `where` clause requirements.)
Below is a shortened version, edited for clarity:
```cpp
// Important contextual variables:
// FuncDecl *original = ...;       // original function to differentiate
// DifferentiableAttr *attr = ...; // `@differentiable` attribute declared on the original function

// Create a builder for the target generic signature.
GenericSignatureBuilder builder(ctx);

// 1. Add the generic signature of the original function.
builder.addGenericSignature(original->getGenericSignature());

// 2. Add requirements from the `@differentiable` where clause.
WhereClauseOwner owner(original, attr); // owner can access requirements
RequirementRequest::visitRequirements(
    owner, TypeResolutionStage::Interface,
    [&](const Requirement &req, RequirementRepr *reqRepr) {
      builder.addRequirement(req, reqRepr,
                             FloatingRequirementSource::forExplicit(reqRepr),
                             /*subMap*/ nullptr, original->getModuleContext());
      // `convertedRequirements` is declared outside this shortened excerpt.
      convertedRequirements.push_back(getCanonicalRequirement(req));
      return false;
    });

// 3. All requirements should have been added. Compute the goal generic signature.
auto goalGenericSig = std::move(builder).computeGenericSignature(
    attr->getLocation(), /*allowConcreteGenericParams=*/true);
```
Unfortunately, compilation of `example.swift` fails with this error:
```
example.swift:2:62: error: 'CotangentVector' is not a member type of 'T'
@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
                                                           ~ ^
```
I found that this diagnostic is generated in TypeResolution::resolveDependentMemberType in TypeCheckType.cpp. I added some debug print statements and found that the equivalence class of T does not contain a conformance to Differentiable:
```
// Output from `swiftc example.swift`:
BASE EQUIV CLASS
Equivalence class represented by τ_0_0:
Members: τ_0_0
Conformances:Numeric, AdditiveArithmetic, Equatable, ExpressibleByIntegerLiteral
```
However, when type checking a similar generic function declaration, the equivalence class does contain such a conformance:
```swift
// example2.swift
func foo<T : Numeric>(x: T) -> T where T : Differentiable, T == T.CotangentVector { return x }
```
```
// Output from `swiftc example2.swift`:
BASE EQUIV CLASS
Equivalence class represented by τ_0_0:
Members: τ_0_0, τ_0_0[.Differentiable].CotangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].TangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].TangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].TangentVector[.Differentiable].TangentVector
Conformances:Differentiable, Numeric, AdditiveArithmetic, Equatable, ExpressibleByIntegerLiteral
```
So something different is going on. My hypothesis is:
- In generic function declarations and protocol extensions with `where` clause constraints, the generic parameters are actually constrained by the `where` clause requirements.
  ```swift
  // `T` does actually conform to `Differentiable`.
  func foo<T : Numeric>(x: T) -> T where T : Differentiable, T == T.CotangentVector { return x }
  ```
- In generic functions whose `@differentiable` attribute has a `where` clause, the generic parameters are not themselves constrained by the `where` clause requirements.
  ```swift
  // `T` itself doesn't actually conform to `Differentiable`, from the
  // perspective of `foo`'s generic signature.
  // The where clause exists outside of `foo`.
  @differentiable(where T : Differentiable, T == T.CotangentVector)
  func foo<T : Numeric>(x: T, y: T) -> T { ... }
  ```
How can I work around this error and build a combined generic signature for `@differentiable`? I hope I explained the problem clearly. Any tips would be greatly appreciated!
cc @rxwei @Douglas_Gregor @Slava_Pestov
Some other ideas:
- I tried calling `RequirementRequest::visitRequirements` with `TypeResolutionStage::Structural`, which makes the compilation error go away. Type checking seems to work just fine.
- However, `TypeResolutionStage::Structural` isn't a robust fix, because it simply creates an unresolved `DependentMemberType` whose associated type decl isn't set. This causes serialization of the type to fail:
  ```
  // Debug statements added in
  SERIALIZER ADDING TYPE
  (dependent_member_type name=CotangentVector
    (base=generic_type_param_type depth=0 index=0))
  (dependent_member_type name=CotangentVector
    (base=generic_type_param_type depth=0 index=0))
  Assertion failed: (false && "we got a failure"), function addTypeRef, file /Users/dan/swift-build/swift/lib/Serialization/Serialization.cpp, line 682.
  ...
  swift::serialization::Serializer::addTypeRef(swift::Type) + 520
  swift::serialization::Serializer::writeGenericRequirements(llvm::ArrayRef<swift::Requirement>, std::__1::array<unsigned int, 256ul> const&) + 295
  swift::serialization::Serializer::writeDeclAttribute(swift::DeclAttribute const*) + 14849
  swift::serialization::Serializer::writeDecl(swift::Decl const*) + 1482
  swift::serialization::Serializer::writeAllDeclsAndTypes() + 2339
  swift::serialization::Serializer::writeAST(llvm::PointerUnion<swift::ModuleDecl*, swift::SourceFile*>, bool) + 4405
  ```
- I don't call `checkInheritanceClause` (from TypeCheckDecl.cpp), which is called by `checkGenericParams`, which is called by `visitFuncDecl` and `visitExtensionDecl`.
- `checkInheritanceClause` could be important because it's called in `checkGenericParams` (right before `visitRequirements(TypeResolutionStage::Interface)`) but not in my code. `checkInheritanceClause` seems to evaluate `InheritedTypeRequest`s, which would matter if information about `T` is added when evaluating `T : Differentiable`.
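The structural-versus-interface observations above can be contrasted in a deliberately simplified toy model. Everything here is hypothetical (`ToyType`, `ToyResolver`, and the protocol-to-associated-type table are illustration-only names, not `GenericSignatureBuilder` or serializer API), and it ignores equivalence classes entirely. It captures only the ordering constraint as I understand it: a structural-stage member type records just a name, a serializer-like check rejects it, and it can be resolved only against conformances that already include `T : Differentiable`. It is a sketch under these assumptions, not a claim about how the type checker actually sequences these steps.

```swift
// Hypothetical toy model; none of these names are compiler API.
enum ToyType: Equatable {
    case param(String)
    // Structural stage: the member is named, but no associated type decl is attached.
    case unresolvedMember(base: String, name: String)
    // Interface stage: the member is tied to the protocol that declares it.
    case resolvedMember(base: String, proto: String, name: String)
}

enum ToyResolver {
    // Assumed table of protocol -> associated type names.
    static let associatedTypes: [String: Set<String>] = [
        "Differentiable": ["TangentVector", "CotangentVector"],
        "Numeric": ["Magnitude"],
    ]

    // Resolve an unresolved member against the conformances known so far.
    // Returns nil if no known conformance declares the member.
    static func resolve(_ type: ToyType, conformances: [String: Set<String>]) -> ToyType? {
        guard case let .unresolvedMember(base, name) = type else { return type }
        for proto in conformances[base, default: []].sorted()
            where associatedTypes[proto, default: []].contains(name) {
            return .resolvedMember(base: base, proto: proto, name: name)
        }
        return nil
    }

    // Serializer-like check: unresolved members are rejected.
    static func isSerializable(_ type: ToyType) -> Bool {
        if case .unresolvedMember = type { return false }
        return true
    }
}

// Original signature of `foo`: only `T : Numeric`.
var conformances: [String: Set<String>] = ["T": ["Numeric"]]

// Structural stage for `T == T.CotangentVector`: the right-hand side is unresolved.
let rhs = ToyType.unresolvedMember(base: "T", name: "CotangentVector")

// Resolving against foo's own conformances fails (like the diagnostic above).
let early = ToyResolver.resolve(rhs, conformances: conformances)

// Record the where clause conformance first, then resolve.
conformances["T", default: []].insert("Differentiable")
let late = ToyResolver.resolve(rhs, conformances: conformances)
```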