Hi type checker and generics experts,

I'm working on automatic differentiation and I have some questions about type-checking where clauses in the @differentiable attribute.

I'll cut to the chase and describe my problem and questions here. In a reply, I'll provide more information about @differentiable.


Where clause requirements are already type-checked in a few places:

// 1. Generic function declarations.
func foo<T>(x: T) -> T where T : Differentiable, T == T.CotangentVector { ... }

// 2. Protocol extensions.
extension FloatingPoint where Self : Differentiable, Self == Self.CotangentVector { ... }

// 3. `@_specialize` attribute.
@_specialize(where T == Int)
public func foo<T>(_ t: T) -> T { ... }

My use-case involves the @differentiable function attribute.
Specifically, I want to build a generic signature that:

  • Includes all requirements from the generic signature of the original function foo.
  • Also includes all requirements from the @differentiable where clause.
// example.swift
@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
func foo<T : Numeric>(x: T, y: T) -> T {
  return x + y
}
// I want to build the following generic signature:
// <T : Numeric, T : Differentiable, T == T.CotangentVector>

// VJP function must match this combined generic signature.
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
  where T : Numeric & Differentiable, T == T.CotangentVector
{
  return (x + y, { v in (v, v) })
}

How can I do this?

Here's my original approach in TypeCheckAttr.cpp for building this combined generic signature. (I adapted it from existing code paths that handle where clause requirements.)
Below is a shortened version, edited for clarity:

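// Important contextual variables:
// ASTContext &ctx = ...; // AST context
// FuncDecl *original = ...; // original function to differentiate
// DifferentiableAttr *attr = ...; // `@differentiable` attribute declared on original function
// SmallVector<Requirement, 4> convertedRequirements; // resolved requirements, later stored in `attr`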

// Create builder for target generic signature.
GenericSignatureBuilder builder(ctx);

// 1. Add generic signature of original function.
builder.addGenericSignature(original->getGenericSignature());

// 2. Add requirements from `@differentiable` where clause.
WhereClauseOwner owner(original, attr); // owner can access requirements
RequirementRequest::visitRequirements(
    owner, TypeResolutionStage::Interface,
    [&](const Requirement &req, RequirementRepr *reqRepr) {
        builder.addRequirement(req, reqRepr,
                               FloatingRequirementSource::forExplicit(reqRepr),
                               /*subMap*/ nullptr, original->getModuleContext());
        convertedRequirements.push_back(getCanonicalRequirement(req));
        return false;
    });

// 3. All requirements should've been added. Create goal generic signature.
auto goalGenericSig = std::move(builder).computeGenericSignature(
    attr->getLocation(), /*allowConcreteGenericParams=*/true);

Unfortunately, compilation of example.swift fails with this error:

example.swift:2:62: error: 'CotangentVector' is not a member type of 'T'
@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
                                                           ~ ^

I found that this diagnostic is generated in TypeResolution::resolveDependentMemberType in TypeCheckType.cpp. I added some debug print statements and found that the equivalence class of T does not contain a conformance to Differentiable:

// Output from `swiftc example.swift`:
BASE EQUIV CLASS
Equivalence class represented by τ_0_0:
Members: τ_0_0
Conformances:Numeric, AdditiveArithmetic, Equatable, ExpressibleByIntegerLiteral

However, when type checking a similar generic function declaration, the equivalence class does contain such a conformance:

// example2.swift
func foo<T : Numeric>(x: T) -> T where T : Differentiable, T == T.CotangentVector { return x }
// Output from `swiftc example2.swift`:
BASE EQUIV CLASS
Equivalence class represented by τ_0_0:
Members: τ_0_0, τ_0_0[.Differentiable].CotangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].TangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].TangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].TangentVector[.Differentiable].TangentVector
Conformances:Differentiable, Numeric, AdditiveArithmetic, Equatable, ExpressibleByIntegerLiteral

So something different is going on. My hypothesis is:

  • In generic function declarations and protocol extensions with where clause constraints, the generic parameters are actually constrained by the where clause requirements.
    // `T` does actually conform to `Differentiable`.
    func foo<T : Numeric>(x: T) -> T where T : Differentiable, T == T.CotangentVector { return x }
    
  • In generic functions whose @differentiable attribute has a where clause, the generic parameters are not themselves constrained by the where clause requirements.
    // `T` itself doesn't actually conform to `Differentiable`, from the
    // perspective of `foo`'s generic signature.
    // The where clause exists outside of `foo`.
    @differentiable(where T : Differentiable, T == T.CotangentVector)
    func foo<T : Numeric>(x: T, y: T) -> T { ... }
    

How can I work around this error and build combined generic signatures for @differentiable? I hope I explained the problem clearly. Any tips would be greatly appreciated!

cc @rxwei @Douglas_Gregor @Slava_Pestov


Some other ideas:

  • I tried calling RequirementRequest::visitRequirements with TypeResolutionStage::Structural, which makes the compilation error go away. Type checking seems to work just fine.
    • However, TypeResolutionStage::Structural isn't a robust fix because it simply creates an unresolved DependentMemberType where the associated type decl isn't set. This causes serialization of the type to fail:
// Debug statements added in 
SERIALIZER ADDING TYPE
(dependent_member_type name=CotangentVector
  (base=generic_type_param_type depth=0 index=0))
(dependent_member_type name=CotangentVector
  (base=generic_type_param_type depth=0 index=0))
Assertion failed: (false && "we got a failure"), function addTypeRef, file /Users/dan/swift-build/swift/lib/Serialization/Serialization.cpp, line 682.
...
swift::serialization::Serializer::addTypeRef(swift::Type) + 520
swift::serialization::Serializer::writeGenericRequirements(llvm::ArrayRef<swift::Requirement>, std::__1::array<unsigned int, 256ul> const&) + 295
swift::serialization::Serializer::writeDeclAttribute(swift::DeclAttribute const*) + 14849
swift::serialization::Serializer::writeDecl(swift::Decl const*) + 1482
swift::serialization::Serializer::writeAllDeclsAndTypes() + 2339
swift::serialization::Serializer::writeAST(llvm::PointerUnion<swift::ModuleDecl*, swift::SourceFile*>, bool) + 4405
  • I don't call checkInheritanceClause (from TypeCheckDecl.cpp), which is called by checkGenericParams, which is called by visitFuncDecl and visitExtensionDecl.
    • checkInheritanceClause could matter: checkGenericParams calls it right before visitRequirements(TypeResolutionStage::Interface), but my code never calls it. It seems to evaluate InheritedTypeRequests, which would be important if information about T is recorded when T : Differentiable is evaluated.
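The difference between the two resolution stages described above can be sketched with a toy model. Everything here (`ToyMemberType`, `resolveStructural`, `resolveInterface`, `canSerialize`, and the tiny protocol table) is hypothetical and only mimics the behavior reported in this post; it is not how TypeResolution or the serializer are actually implemented.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical toy stand-in for a dependent member type `Base.Name`; it is not
// swift::DependentMemberType. An empty `assocProtocol` means the name was never
// bound to an associated type declaration.
struct ToyMemberType {
  bool valid;                // false if resolution emitted a diagnostic
  std::string base;          // e.g. "T"
  std::string name;          // e.g. "CotangentVector"
  std::string assocProtocol; // e.g. "Differentiable"; empty = unresolved
};

// Toy table: associated types declared by each protocol (small subset).
const std::map<std::string, std::set<std::string>> kAssocTypes = {
    {"Numeric", {"Magnitude"}},
    {"Differentiable", {"TangentVector", "CotangentVector"}},
};

// Toy "structural" stage: keep the spelled name without any lookup, so it
// never diagnoses, but it also never records an associated type.
ToyMemberType resolveStructural(const std::string &base,
                                const std::string &name) {
  return ToyMemberType{true, base, name, ""};
}

// Toy "interface" stage: bind the name to a protocol the base conforms to;
// if none declares it, report failure (the "not a member type" error).
ToyMemberType resolveInterface(const std::string &base, const std::string &name,
                               const std::set<std::string> &baseConformances) {
  for (const auto &proto : baseConformances) {
    auto it = kAssocTypes.find(proto);
    if (it != kAssocTypes.end() && it->second.count(name))
      return ToyMemberType{true, base, name, proto};
  }
  return ToyMemberType{false, base, name, ""};
}

// Toy serializer check: only types bound to an associated type can be written,
// loosely analogous to the assertion hit in Serializer::addTypeRef.
bool canSerialize(const ToyMemberType &type) {
  return type.valid && !type.assocProtocol.empty();
}
```

In this toy, the structural result for `T.CotangentVector` is accepted but cannot be serialized, while interface resolution fails unless `Differentiable` is among the conformances visible for `T`.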

Here's some more background information about @differentiable.

@differentiable is a function attribute that marks functions as being differentiable. One use of @differentiable is registering primitive derivative functions, like vector-Jacobian product functions (vjps), that specify how to differentiate the function.

Type-checking is necessary because primitive derivative functions must have a certain type based on the original function's type. Here's an example:

// `@differentiable` attribute declares `foo` as being differentiable with
// the primitive vjp function `vjpFoo`.
//
// If the original function has type:
// <T0 : Differentiable, T1 : Differentiable, ...>
// (T0, T1, ...) -> U
//
// Then the type of the `vjp` must be:
// <T0 : Differentiable, T1 : Differentiable,
//  *** extra `@differentiable` where clause requirements ***>;
// (T0, T1, ...)  ->  (U,     (U) -> (T0.CotangentVector, T1.CotangentVector, ...))
// ^~~~~~~~~~~         ^       ^      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// original args    result   vector            vector-Jacobian products

@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
func foo<T : Numeric>(x: T, y: T) -> T {
  return x + y
}

// VJPs should have generic signatures that:
// - Include all requirements from the generic signature of the original function
// - Include all requirements from the `@differentiable` where clause
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
  where T : Numeric & Differentiable, T == T.CotangentVector
{
  return (x + y, { v in (v, v) })
}
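As a rough illustration of the goal described in the comments above, the sketch below builds the combined signature as a string: the original function's requirements first, then the where clause requirements, skipping exact duplicates. `combinedSignature` is a hypothetical helper, not part of the compiler. It treats requirements as opaque strings, so it deliberately sidesteps the real difficulty in this thread (resolving `T.CotangentVector`), and unlike GenericSignatureBuilder it neither derives implied requirements nor minimizes the result.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy helper: requirements are plain strings such as
// "T : Numeric" or "T == T.CotangentVector". Returns "<req1, req2, ...>"
// with the original requirements first and exact duplicates dropped.
std::string combinedSignature(const std::vector<std::string> &original,
                              const std::vector<std::string> &whereClause) {
  std::vector<std::string> all;
  // Append a requirement unless an identical string is already present.
  auto addUnique = [&all](const std::string &req) {
    for (const auto &existing : all)
      if (existing == req)
        return;
    all.push_back(req);
  };
  for (const auto &req : original)
    addUnique(req);
  for (const auto &req : whereClause)
    addUnique(req);

  std::string out = "<";
  for (std::size_t i = 0; i < all.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += all[i];
  }
  return out + ">";
}
```

For the `foo` example, `combinedSignature({"T : Numeric"}, {"T : Differentiable", "T == T.CotangentVector"})` yields `<T : Numeric, T : Differentiable, T == T.CotangentVector>`, the signature `vjpFoo` is expected to match.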

I think you have the right idea using TypeResolutionStage::Structural, because you can't create a fully-resolved DependentMemberType until you've processed the other requirements in the where clause. However, you don't want to store these structural types in convertedRequirements: rather, you'll want to either (1) take another pass through to compute the requirement types using the goalGenericSig you built or (2) record goalGenericSig for later use rather than individual requirements.

Doug
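Doug's option (1) can be sketched with a toy model that assumes a single generic parameter `T` and treats requirements as strings. `ToyRequirement`, `resolveType`, `resolveWhereClause`, and the protocol table are hypothetical names, not compiler APIs. The first pass adds every conformance requirement, from the original signature and from the where clause, to the combined set; only then does the second pass resolve `T.CotangentVector`. Setting `twoPass` to false mimics the original approach, where the member type is resolved against the original function's signature alone.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy requirement on the single generic parameter `T`:
// either `T : target` (conformance) or `T == target` (same-type).
struct ToyRequirement {
  bool isConformance;
  std::string subject; // always "T" in this toy
  std::string target;  // protocol name, or a type such as "T.CotangentVector"
};

// Toy table: associated types declared by each protocol (small subset).
const std::map<std::string, std::set<std::string>> kAssocTypes = {
    {"Numeric", {"Magnitude"}},
    {"Differentiable", {"TangentVector", "CotangentVector"}},
};

// Resolves "Base.Member" using the protocols Base conforms to. Returns
// "Base[.Proto].Member" on success, "" on failure; plain names pass through.
std::string resolveType(const std::string &type,
                        const std::set<std::string> &conformances) {
  std::string::size_type dot = type.find('.');
  if (dot == std::string::npos)
    return type;
  std::string base = type.substr(0, dot);
  std::string member = type.substr(dot + 1);
  for (const auto &proto : conformances) {
    auto it = kAssocTypes.find(proto);
    if (it != kAssocTypes.end() && it->second.count(member))
      return base + "[." + proto + "]." + member;
  }
  return "";
}

// Returns the resolved right-hand side of each same-type requirement in the
// where clause ("" marks a resolution failure).
std::vector<std::string>
resolveWhereClause(const std::set<std::string> &originalConformances,
                   const std::vector<ToyRequirement> &whereClause,
                   bool twoPass) {
  // Pass 1 (two-pass mode only): build the combined conformance set from the
  // original signature plus the where clause's conformance requirements.
  std::set<std::string> visible = originalConformances;
  if (twoPass) {
    for (const auto &req : whereClause) {
      if (req.isConformance)
        visible.insert(req.target);
    }
  }
  // Pass 2: resolve same-type targets against whatever is visible.
  std::vector<std::string> resolved;
  for (const auto &req : whereClause) {
    if (!req.isConformance)
      resolved.push_back(resolveType(req.target, visible));
  }
  return resolved;
}
```

With `twoPass == false`, resolving `T.CotangentVector` fails as in the diagnostic earlier in the thread, because only `Numeric` is visible; with `twoPass == true` it resolves to `T[.Differentiable].CotangentVector`, matching the notation in the debug output.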


Thank you for the response Doug!

Your response enlightened me: manually keeping track of convertedRequirements was the problem. I had unnecessarily copied that code from the @_specialize attribute type-checking code.

The solution is simply to get requirements from the builder-computed generic signature:

+    // Compute generic signature and environment for autodiff associated
+    // functions.
     whereClauseGenSig = std::move(builder).computeGenericSignature(
         attr->getLocation(), /*allowConcreteGenericParams=*/true);
+    // Store the resolved requirements in the attribute.
+    attr->setRequirements(ctx, whereClauseGenSig->getRequirements());
-    attr->setRequirements(ctx, convertedRequirements);

For those interested, here's the commit with these changes. @differentiable where clauses now support requirements with dependent member types. Thanks again Doug!
