Need help type-checking where clause requirements


(Dan Zheng) #1

Hi type checker and generics experts,

I'm working on automatic differentiation and I have some questions about type-checking where clauses in the @differentiable attribute.

I'll cut to the chase and describe my problem and questions here. In a reply, I'll provide more information about @differentiable.

Where clause requirements are already type-checked in a few places:

// 1. Generic function declarations.
func foo<T>(x: T) -> T where T : Differentiable, T == T.CotangentVector { ... }

// 2. Protocol extensions.
extension FloatingPoint where Self : Differentiable, Self == Self.CotangentVector { ... }

// 3. `@_specialize` attribute.
@_specialize(where T == Int)
public func foo<T>(_ t: T) -> T { ... }
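For comparison, cases 1 and 2 can be reproduced with standard library protocols, without an autodiff toolchain. This is only an analogy: `Collection` and `SubSequence` play the roles of `Differentiable` and `CotangentVector`, and `firstHalf` and `tail` are hypothetical names. In both declarations, `T.SubSequence` resolves because the conformance it depends on appears in the same where clause that constrains the declaration.

```swift
// Analogy only: `Collection` / `SubSequence` stand in for
// `Differentiable` / `CotangentVector`.

// 1. Generic function: `T.SubSequence` resolves because `T : Collection`
//    is in the same where clause as the same-type requirement.
func firstHalf<T>(_ c: T) -> T where T: Collection, T == T.SubSequence {
  return c.prefix(c.count / 2)
}

// 2. Protocol extension: `Self.SubSequence` resolves through the
//    `Self : Collection` requirement added by the where clause.
extension Sequence where Self: Collection, Self == Self.SubSequence {
  func tail() -> Self {
    return dropFirst()
  }
}

let slice = ArraySlice([1, 2, 3, 4])
print(Array(firstHalf(slice))) // [1, 2]
print(Array(slice.tail()))     // [2, 3, 4]
```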

My use-case involves the @differentiable function attribute.
Specifically, I want to build a generic signature that:

  • Includes all requirements from the generic signature of the original function foo.
  • Also includes all requirements from the @differentiable where clause.
// example.swift
@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
func foo<T : Numeric>(x: T, y: T) -> T {
  return x + y
}
// I want to build the following generic signature:
// <T : Numeric, T : Differentiable, T == T.CotangentVector>

// VJP function must match this combined generic signature.
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
  where T : Numeric & Differentiable, T == T.CotangentVector {
  return (x + y, { v in (v, v) })
}

How can I do this?
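To make the goal concrete, here is a sketch that spells out both signatures by hand. It assumes a hypothetical, minimal `Differentiable` protocol declaring only the two associated types used here (the real protocol has more requirements), assumes `Double` is its own tangent and cotangent type, and omits the `@differentiable` attribute itself since that needs the autodiff toolchain. `vjpFoo` carries `foo`'s own `T : Numeric` plus the attribute's where clause requirements.

```swift
// Hypothetical, minimal stand-in for the autodiff `Differentiable` protocol:
// only the associated types this example needs.
protocol Differentiable {
  associatedtype TangentVector
  associatedtype CotangentVector
}

// Assumption for the sketch: `Double` is its own tangent/cotangent space.
extension Double: Differentiable {
  typealias TangentVector = Double
  typealias CotangentVector = Double
}

// Original function; its own generic signature is <T : Numeric>.
// (The `@differentiable(vjp: vjpFoo where ...)` attribute is omitted.)
func foo<T: Numeric>(x: T, y: T) -> T {
  return x + y
}

// VJP with the combined signature:
// <T : Numeric, T : Differentiable, T == T.CotangentVector>
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
  where T: Numeric & Differentiable, T == T.CotangentVector {
  // d(x + y)/dx = d(x + y)/dy = 1, so the pullback maps v to (v, v).
  return (foo(x: x, y: y), { v in (v, v) })
}

let (value, pullback) = vjpFoo(x: 2.0, y: 3.0)
print(value)         // 5.0
print(pullback(1.0)) // (1.0, 1.0)
```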

Here's my original approach in TypeCheckAttr.cpp for building this combined generic signature. (I copied it from other code paths that handle where clause requirements.)
Below is a shortened version, edited for clarity:

// Important contextual variables:
// FuncDecl *original = ... // original function to differentiate
// DifferentiableAttr *attr = ... // `@differentiable` attribute declared on original function

// Create builder for target generic signature.
GenericSignatureBuilder builder(ctx);

// 1. Add generic signature of original function.

// 2. Add requirements from `@differentiable` where clause.
WhereClauseOwner owner(original, attr); // owner can access requirements
RequirementRequest::visitRequirements(
    owner, TypeResolutionStage::Interface,
    [&](const Requirement &req, RequirementRepr *reqRepr) {
        builder.addRequirement(req, reqRepr,
                               /*subMap*/ nullptr, original->getModuleContext());
        return false;
    });

// 3. All requirements should've been added. Create goal generic signature.
auto goalGenericSig = std::move(builder).computeGenericSignature(
    attr->getLocation(), /*allowConcreteGenericParams=*/true);

Unfortunately, compilation of example.swift fails with this error:

example.swift:2:62: error: 'CotangentVector' is not a member type of 'T'
@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
                                                           ~ ^

I found that this diagnostic is generated in TypeResolution::resolveDependentMemberType in TypeCheckType.cpp. I added some debug print statements and found that the equivalence class of T does not contain a conformance to Differentiable:

// Output from `swiftc example.swift`:
Equivalence class represented by τ_0_0:
Members: τ_0_0
Conformances:Numeric, AdditiveArithmetic, Equatable, ExpressibleByIntegerLiteral

However, when type checking a similar generic function declaration, the equivalence class does contain such a conformance:

// example2.swift
func foo<T : Numeric>(x: T) -> T where T : Differentiable, T == T.CotangentVector { return x }
// Output from `swiftc example2.swift`:
Equivalence class represented by τ_0_0:
Members: τ_0_0, τ_0_0[.Differentiable].CotangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].TangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].CotangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].TangentVector[.Differentiable].TangentVector, τ_0_0[.Differentiable].CotangentVector[.Differentiable].TangentVector[.Differentiable].TangentVector
Conformances:Differentiable, Numeric, AdditiveArithmetic, Equatable, ExpressibleByIntegerLiteral

So something different is going on. My hypothesis is:

  • In generic function declarations and protocol extensions with where clause constraints, the generic parameters are actually constrained by the where clause requirements.
    // `T` does actually conform to `Differentiable`.
    func foo<T : Numeric>(x: T) -> T where T : Differentiable, T == T.CotangentVector { return x }
  • In generic functions whose @differentiable attribute has a where clause, the generic parameters are not themselves constrained by the where clause requirements.
    // `T` itself doesn't actually conform to `Differentiable`, from the
    // perspective of `foo`'s generic signature.
    // The where clause exists outside of `foo`.
    @differentiable(where T : Differentiable, T == T.CotangentVector)
    func foo<T : Numeric>(x: T, y: T) -> T { ... }
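The hypothesis can be illustrated with a toy model. This is not the compiler's GenericSignatureBuilder; `ProtocolInfo`, `ToySignature`, and `resolveMember` are hypothetical names, and the model ignores implied conformances such as `Numeric` implying `AdditiveArithmetic`. It captures only the lookup rule: a member type `T.Name` resolves through a conformance already recorded for `T`, so resolving against `foo`'s own signature fails, while a signature that already includes `T : Differentiable` succeeds.

```swift
// Toy model, not the compiler's GenericSignatureBuilder: a member type
// `T.Name` resolves only through a conformance already recorded for `T`.

struct ProtocolInfo {
  let name: String
  let associatedTypes: Set<String>
}

let numeric = ProtocolInfo(name: "Numeric", associatedTypes: ["Magnitude"])
let differentiable = ProtocolInfo(name: "Differentiable",
                                  associatedTypes: ["TangentVector", "CotangentVector"])

// Conformances of a single generic parameter `T`, like the
// "Conformances:" line of an equivalence class dump.
struct ToySignature {
  var conformances: [ProtocolInfo]

  // Returns "Proto.Name", or nil where the compiler would report
  // "'Name' is not a member type of 'T'".
  func resolveMember(_ name: String) -> String? {
    for proto in conformances where proto.associatedTypes.contains(name) {
      return "\(proto.name).\(name)"
    }
    return nil
  }
}

// `foo`'s own signature, <T : Numeric>: the attribute's where clause is not in it.
let original = ToySignature(conformances: [numeric])
print(original.resolveMember("CotangentVector") ?? "not a member type") // not a member type

// Once `T : Differentiable` is recorded, the same lookup succeeds.
var combined = original
combined.conformances.append(differentiable)
print(combined.resolveMember("CotangentVector") ?? "not a member type") // Differentiable.CotangentVector
```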

How can I work around this error and build combined generic signatures for @differentiable? I hope I explained the problem clearly. Any tips would be greatly appreciated!

cc @rxwei @Douglas_Gregor @Slava_Pestov

Some other ideas:

  • I tried calling RequirementRequest::visitRequirements with TypeResolutionStage::Structural, which makes the compilation error go away. Type checking seems to work just fine.
    • However, TypeResolutionStage::Structural isn't a robust fix because it simply creates an unresolved DependentMemberType where the associated type decl isn't set. This causes serialization of the type to fail:
// Debug statements added in 
(dependent_member_type name=CotangentVector
  (base=generic_type_param_type depth=0 index=0))
(dependent_member_type name=CotangentVector
  (base=generic_type_param_type depth=0 index=0))
Assertion failed: (false && "we got a failure"), function addTypeRef, file /Users/dan/swift-build/swift/lib/Serialization/Serialization.cpp, line 682.
swift::serialization::Serializer::addTypeRef(swift::Type) + 520
swift::serialization::Serializer::writeGenericRequirements(llvm::ArrayRef<swift::Requirement>, std::__1::array<unsigned int, 256ul> const&) + 295
swift::serialization::Serializer::writeDeclAttribute(swift::DeclAttribute const*) + 14849
swift::serialization::Serializer::writeDecl(swift::Decl const*) + 1482
swift::serialization::Serializer::writeAllDeclsAndTypes() + 2339
swift::serialization::Serializer::writeAST(llvm::PointerUnion<swift::ModuleDecl*, swift::SourceFile*>, bool) + 4405
  • I don't call checkInheritanceClause (from TypeCheckDecl.cpp), which is called by checkGenericParams, which is called by visitFuncDecl and visitExtensionDecl.
    • checkInheritanceClause could be important because it's called in checkGenericParams (right before visitRequirements(TypeResolutionStage::Interface)) but not in my code. It seems to evaluate InheritedTypeRequests, which would matter if information about T is recorded when evaluating T : Differentiable.

(Dan Zheng) #2

Here's some more background information about @differentiable.

@differentiable is a function attribute that marks functions as being differentiable. One use of @differentiable is registering primitive derivative functions, like vector-Jacobian product functions (VJPs), that specify how to differentiate the function.

Type-checking is necessary because primitive derivative functions must have a certain type based on the original function's type. Here's an example:

// `@differentiable` attribute declares `foo` as being differentiable with
// the primitive vjp function `vjpFoo`.
// If the original function has type:
// <T0 : Differentiable, T1 : Differentiable, ...>
// (T0, T1, ...) -> U
// Then the type of the `vjp` must be:
// <T0 : Differentiable, T1 : Differentiable,
//  *** extra `@differentiable` where clause requirements ***>
// (T0, T1, ...)  ->  (U,     (U) -> (T0.CotangentVector, T1.CotangentVector, ...))
// ^~~~~~~~~~~         ^       ^      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// original args    result   vector            vector-Jacobian products

@differentiable(vjp: vjpFoo where T : Differentiable, T == T.CotangentVector)
func foo<T : Numeric>(x: T, y: T) -> T {
  return x + y
}

// VJPs should have generic signatures that:
// - Include all requirements from the generic signature of the original function
// - Include all requirements from the `@differentiable` where clause
func vjpFoo<T>(x: T, y: T) -> (T, (T) -> (T, T))
  where T : Numeric & Differentiable, T == T.CotangentVector {
  return (x + y, { v in (v, v) })
}
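The VJP type rule in the comment above can also be checked on a concrete, non-generic function. This sketch uses a hypothetical `mul`/`vjpMul` pair (not from the thread), assumes `Double` is its own cotangent type so that `(T0.CotangentVector, T1.CotangentVector)` becomes `(Double, Double)`, and again omits the attribute. For f(x, y) = x * y the partial derivatives are (y, x), so the pullback scales them by the incoming vector `v`.

```swift
// Original function f(x, y) = x * y (hypothetical example).
func mul(x: Double, y: Double) -> Double {
  return x * y
}

// VJP shape: (T0, T1) -> (U, (U) -> (T0.CotangentVector, T1.CotangentVector)),
// with T0 = T1 = U = Double and Double as its own cotangent type (an assumption).
// The partial derivatives are (df/dx, df/dy) = (y, x), so the pullback
// scales them by the incoming cotangent v.
func vjpMul(x: Double, y: Double) -> (Double, (Double) -> (Double, Double)) {
  return (mul(x: x, y: y), { v in (v * y, v * x) })
}

let (product, pullback) = vjpMul(x: 3.0, y: 4.0)
print(product)       // 12.0
print(pullback(1.0)) // (4.0, 3.0)
```

Evaluating the pullback at v = 1 yields the gradient; other values of v scale it linearly.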

(Douglas Gregor) #3

I think you have the right idea with using TypeResolutionStage::Structural, because you can't create a fully-resolved DependentMemberType until you've processed the other requirements in the where clause. However, you don't want to store these structural types in convertedRequirements. Instead, you'll want to either (1) take another pass to compute the requirement types using the goalGenericSig you built, or (2) record goalGenericSig for later use rather than individual requirements.
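Option (1) can be sketched with a toy model; it is not real compiler API, and `ToyType`, `ToyRequirement`, `buildConformances`, and `resolveRequirement` are hypothetical names. Pass 1 records conformances from requirements whose member types are still structural (name only). Pass 2 re-resolves those member types against the signature that pass 1 produced. Only the resolved requirements are kept, which mirrors why name-only DependentMemberTypes should not reach serialization.

```swift
// Toy model of a two-pass approach (hypothetical names, not compiler API).

enum ToyType: Equatable {
  case param(String)                                              // T
  case unresolvedMember(base: String, name: String)               // structural: T.CotangentVector
  case resolvedMember(base: String, proto: String, name: String)  // T[.Differentiable].CotangentVector
}

enum ToyRequirement: Equatable {
  case conformance(ToyType, proto: String)
  case sameType(ToyType, ToyType)
}

// Assumed protocol table: protocol name -> associated type names.
let associatedTypes: [String: Set<String>] = [
  "Numeric": ["Magnitude"],
  "Differentiable": ["TangentVector", "CotangentVector"],
]

// Pass 1: record conformances of generic parameters. No member type has
// to be resolved yet, so requirement order does not matter.
func buildConformances(_ reqs: [ToyRequirement]) -> [String: [String]] {
  var table: [String: [String]] = [:]
  for case let .conformance(.param(p), proto) in reqs {
    table[p, default: []].append(proto)
  }
  return table
}

// Pass 2: resolve name-only member types against the built signature.
func resolveType(_ t: ToyType, in table: [String: [String]]) -> ToyType? {
  guard case let .unresolvedMember(base, name) = t else { return t }
  for proto in table[base, default: []]
    where associatedTypes[proto]?.contains(name) == true {
    return .resolvedMember(base: base, proto: proto, name: name)
  }
  return nil // where the compiler would report "not a member type"
}

func resolveRequirement(_ r: ToyRequirement,
                        in table: [String: [String]]) -> ToyRequirement? {
  switch r {
  case let .conformance(t, proto):
    guard let rt = resolveType(t, in: table) else { return nil }
    return .conformance(rt, proto: proto)
  case let .sameType(a, b):
    guard let ra = resolveType(a, in: table),
          let rb = resolveType(b, in: table) else { return nil }
    return .sameType(ra, rb)
  }
}

// foo's `T : Numeric` plus `where T : Differentiable, T == T.CotangentVector`.
let structural: [ToyRequirement] = [
  .conformance(.param("T"), proto: "Numeric"),
  .conformance(.param("T"), proto: "Differentiable"),
  .sameType(.param("T"), .unresolvedMember(base: "T", name: "CotangentVector")),
]
let signature = buildConformances(structural)
// Keep these rather than `structural`: no name-only member types remain.
let resolved = structural.compactMap { resolveRequirement($0, in: signature) }
print(resolved.count) // 3
```

Option (2) would correspond to keeping the computed signature itself instead of a list of requirements; the toy conformance table omits same-type requirements, so it only models option (1).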


(Dan Zheng) #4

Thank you for the response Doug!

Your response enlightened me: manually keeping track of convertedRequirements was the problem. I had copied that code unnecessarily from the @_specialize attribute type-checking.

The solution is simply to get requirements from the builder-computed generic signature:

+    // Compute generic signature and environment for autodiff associated
+    // functions.
     whereClauseGenSig = std::move(builder).computeGenericSignature(
         attr->getLocation(), /*allowConcreteGenericParams=*/true);
+    // Store the resolved requirements in the attribute.
+    attr->setRequirements(ctx, whereClauseGenSig->getRequirements());
-    attr->setRequirements(ctx, convertedRequirements);

For those interested, here's the commit with these changes. @differentiable where clauses now support requirements with dependent member types. Thanks again Doug!