Differentiable programming for gradient-based machine learning

Hello Swift community,

The development of differentiable programming in Swift (“Differentiable Swift”, “AutoDiff”) has come a long way since it began almost three years ago. Earlier this year, following the Core Team’s interest in evaluating whether to incorporate this capability into Swift, @dan-zheng and @marcrasi drove and completed a major effort to upstream the entire implementation to Swift’s main branch.

Today, we would like to take differentiable programming in Swift to the Pitch phase with a new proposal. This proposal is derived from the Differentiable Programming Manifesto, but has been scoped down to a forward-compatible (ABI-compatible and optimizable) subset of features that supports differentiable programming’s dominant use case, machine learning, as well as other forms of gradient-based numerical computing. We look forward to your feedback!

Full Proposal

Differentiable programming for gradient-based machine learning


Derivatives are a fundamental tool in calculus and have applications in many
domains, notably gradient-based machine learning (ML). As an easy-to-use,
high-performance language, Swift is a great fit for both highly expressive
algorithms and numerical computations. Meanwhile, ML is one of the fastest
growing technologies today, but mainstream ML development tools are
mostly based on dynamic languages, where it can be challenging for developers to
take advantage of software debugging tools and compile-time code diagnostics or
to maintain type safety in large-scale software.

As a compiled programming language with a modern type system, Swift has a unique
opportunity to develop its own numerical computing and ML ecosystem. Driven by
the growing needs of ML libraries and algorithms, we believe one key technology,
differentiable programming, will help push ML development experience and
developer productivity to a whole new level.

We propose adding differentiable programming as a first-class,
language-integrated feature in Swift, making Swift the first
general-purpose, statically-typed programming language to have automatic
differentiation.


At a glance, this feature includes the following additions:

  • A @differentiable(reverse) declaration attribute for declaring
    differentiable functions.
  • @differentiable(reverse) function types.
  • A @derivative(of:) attribute for defining custom derivatives.
  • A Differentiation module to be distributed in Swift releases, containing:
    • A Differentiable protocol, generalizing data structures that are
      differentiable.
    • Differential operators (e.g. gradient(of:)), for evaluating the
      derivatives of functions.

Differentiable programming is a new paradigm for programming in which programs
can be differentiated throughout. At a glance, differentiable programming lets
you take the derivative of functions whose parameters and results conform to the
Differentiable protocol.

import Differentiation

func f(_ x: SIMD32<Float>) -> Float {
    (x * x).sum()
}
let dfdx = gradient(of: f)
dfdx(SIMD32(repeating: 3)) // SIMD32([6, 6, 6, 6, ...])
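As a quick sanity check of the commented result above, the following sketch approximates the same gradient with central finite differences, using only the standard library. `numericalGradient(of:at:step:)` is a hypothetical helper written for this illustration, not part of the proposed Differentiation module; finite differences are only an approximation, which is what compiler-generated derivatives avoid.

```swift
// f(x) = sum of x_i^2, the same function as above.
func f(_ x: SIMD32<Float>) -> Float {
    (x * x).sum()
}

// Hypothetical helper: approximates the gradient of `g` at `x` using
// central finite differences, one coordinate at a time.
func numericalGradient(
    of g: (SIMD32<Float>) -> Float,
    at x: SIMD32<Float>,
    step h: Float = 1e-2
) -> SIMD32<Float> {
    var result = SIMD32<Float>(repeating: 0)
    for i in x.indices {
        var xPlus = x
        var xMinus = x
        xPlus[i] += h
        xMinus[i] -= h
        result[i] = (g(xPlus) - g(xMinus)) / (2 * h)
    }
    return result
}

let approx = numericalGradient(of: f, at: SIMD32(repeating: 3))
// Each entry is approximately 6, matching the analytic gradient 2 * x.
print(approx[0])
```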

The ability to get derivatives of programs enables a new world of numerical
computing applications, notably machine learning. With first-class support,
gradient-based learning algorithms can even be built using standard library
types such as Float and SIMD64<Float> and be differentiated using
protocol-oriented APIs such as valueWithGradient(at:in:).

import Differentiation

struct Perceptron: Differentiable {
    var weight: SIMD2<Float> = .random(in: -1..<1)
    var bias: Float = 0

    @differentiable(reverse)
    func callAsFunction(_ input: SIMD2<Float>) -> Float {
        (weight * input).sum() + bias
    }
}

var model = Perceptron()
let andGateData: [(x: SIMD2<Float>, y: Float)] = [
    (x: [0, 0], y: 0),
    (x: [0, 1], y: 0),
    (x: [1, 0], y: 0),
    (x: [1, 1], y: 1),
]
for _ in 0..<100 {
    let (loss, modelGradient) = valueWithGradient(at: model) { model -> Float in
        var loss: Float = 0
        for (x, y) in andGateData {
            let prediction = model(x)
            let error = y - prediction
            loss = loss + error * error / 2
        }
        return loss
    }
    print(loss)
    model.weight -= modelGradient.weight * 0.02
    model.bias -= modelGradient.bias * 0.02
}
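To make concrete what `valueWithGradient(at:in:)` returns in the loop above, here is a sketch that trains the same perceptron with gradients derived by hand from the chain rule, with no Differentiation module involved. `ManualPerceptron` and `lossAndGradient(_:_:)` are hypothetical names, and the fixed initial weight replaces `.random(in:)` only so that the run is reproducible.

```swift
// Hypothetical hand-written counterpart of `Perceptron`, with a fixed initial
// weight instead of `.random(in: -1..<1)` so that runs are reproducible.
struct ManualPerceptron {
    var weight: SIMD2<Float> = [0.5, -0.5]
    var bias: Float = 0

    func callAsFunction(_ input: SIMD2<Float>) -> Float {
        (weight * input).sum() + bias
    }
}

let andGateData: [(x: SIMD2<Float>, y: Float)] = [
    (x: [0, 0], y: 0),
    (x: [0, 1], y: 0),
    (x: [1, 0], y: 0),
    (x: [1, 1], y: 1),
]

// Returns the loss sum((y - prediction)^2 / 2) and its gradient with respect
// to `weight` and `bias`, derived by hand with the chain rule.
func lossAndGradient(
    _ model: ManualPerceptron,
    _ data: [(x: SIMD2<Float>, y: Float)]
) -> (loss: Float, weight: SIMD2<Float>, bias: Float) {
    var loss: Float = 0
    var dWeight = SIMD2<Float>(repeating: 0)
    var dBias: Float = 0
    for (x, y) in data {
        let error = y - model(x)
        loss += error * error / 2
        // d(loss)/d(prediction) = -error, and d(prediction)/d(weight) = x.
        dWeight += -error * x
        dBias += -error
    }
    return (loss, dWeight, dBias)
}

var model = ManualPerceptron()
let initialLoss = lossAndGradient(model, andGateData).loss
for _ in 0..<100 {
    let gradient = lossAndGradient(model, andGateData)
    model.weight -= gradient.weight * 0.02
    model.bias -= gradient.bias * 0.02
}
let finalLoss = lossAndGradient(model, andGateData).loss
print(initialLoss, finalLoss)
```

The hand-written `lossAndGradient` is the bookkeeping that differentiable programming removes: with `@differentiable(reverse)`, the compiler synthesizes equivalent derivative code from `callAsFunction` and the loss closure.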

Differentiable programming scales up from simple examples like this to
full-fledged machine learning models using neural networks. Neural networks are
similar to the Perceptron example above in that they contain trainable
parameters (commonly part of neural network layers), and each parameter can be
modified based on the gradient of a loss with respect to that parameter. Neural
network layers can be generalized by a protocol that inherits from
Differentiable:

// Example library:
public protocol Layer: Differentiable {
    associatedtype Input: Differentiable
    associatedtype Output: Differentiable

    @differentiable(reverse)
    func callAsFunction(_ input: Input) -> Output
}

public class Dense: Layer { ... }
public class Convolution: Layer { ... }
public struct NDArray<Scalar>: Differentiable { ... }

// Client code:
final class MyModel: Layer {
    let dense1: Dense
    let dense2: Dense

    @differentiable(reverse)
    func callAsFunction(_ input: NDArray<Float>) -> NDArray<Float> {
        dense2(dense1(input))
    }
}

While the differentiation APIs are flexible and fully dynamic, differentiation
is based on a program transformation that happens at compile time. This enables
many static analyses that not only help produce more efficient code but also
detect common numerical programming mistakes such as non-differentiable
functions and zero derivatives.

let grad = gradient(at: 1.0) { x in
    3.0.squareRoot()
}

test.swift:2:4: warning: result does not depend on differentiation arguments and will always have a zero derivative
test.swift:2:4: note: add 'withoutDerivative(at:)' to silence the warning if zero derivatives are intentional
    withoutDerivative(at:  )

Unlike library-based automatic differentiation, differentiable programming makes
many common runtime errors in machine learning directly debuggable using
LLDB, without library boundaries. Also, unlike library-based approaches,
differential operators offered in the Differentiation library can be used to
take the derivative of functions on any type that conforms to the
Differentiable protocol, such as Float, SIMD4<Double>, Complex<Double>,
[Float] and custom types. This enables programmers to integrate gradient-based
learning algorithms, physical simulations, and scientific experiments directly
in their applications without having to incorporate any embedded domain-specific
language or an automatic differentiation algorithm.

Example: Intelligent apps

One example of using gradient-based machine learning techniques to enhance the user
experience of an app is providing intelligence based on learned user behavior.
Intelligent apps can make predictions, provide suggestions, and learn user
preferences: all of these can be powered by differentiable programming.

The core of such an intelligent app is a function with real-valued "trainable
parameters". Differentiation can be used to systematically optimize (i.e. find
"good" values for) these parameters via gradient descent. (Optimizing these
parameters via conventional algorithms is typically difficult or intractable.)

Consider a podcast player that tries to automatically adjust the playback speed
based on the podcast type and the podcast section. We can define its business
logic as the following, as well as a "model" which contains real-valued
parameters that control how inputs get mapped onto outputs.

enum PodcastCategory: Int {
    case comedy
    case news
}

enum PodcastSection: Int {
    case advertisement
    case introduction
    case body
    case conclusion
}

struct PodcastState {
    let category: PodcastCategory
    let section: PodcastSection
}

struct PodcastSpeedModel: Differentiable {
    var minSpeed, maxSpeed: Float
    /// The multiplier for each podcast category.
    var categoryMultipliers: [Float]
    /// The multiplier for each podcast section.
    var sectionMultipliers: [Float]

    /// Returns a podcast speed multiplier prediction for the given podcast category
    /// and section.
    @differentiable(reverse)
    func prediction(for state: PodcastState) -> Float {
        let speed = categoryMultipliers[state.category.rawValue] * sectionMultipliers[state.section.rawValue]
        if speed < minSpeed { return minSpeed }
        if speed > maxSpeed { return maxSpeed }
        return speed
    }
}

Parameters in this podcast speed model, represented as stored properties in the
struct, determine how quickly the podcast should play under different
circumstances: minSpeed, maxSpeed, categoryMultipliers, and
sectionMultipliers. A priori, it is not clear what good parameter values are,
and different users may prefer different parameter values.

An intelligent application could determine personalized parameter values as
follows:

  1. Let the user set the speed manually, and record observations whenever the
    user changes the speed.

  2. After collecting enough observations, search for parameter values such that
    the model predicts speeds close to the user's preferred speed. If such
    values are found, offer to start automatically setting the speed.

"Gradient descent" is an algorithm that performs this search, and a language
that supports differentiable programming makes it easy to implement gradient
descent. Here is some pseudocode illustrating gradient descent.

First, we need an objective function for gradient descent to minimize.
Here, we use mean absolute error:

struct Observation {
    var podcastState: PodcastState
    var userSpeed: Float
}

func meanError(for model: PodcastSpeedModel, _ observations: [Observation]) -> Float {
    var error: Float = 0
    for observation in observations {
        error += abs(model.prediction(for: observation.podcastState) - observation.userSpeed)
    }
    return error / Float(observations.count)
}

Next, we implement the gradient descent algorithm. In the loop, we take the
gradient of the mean error with respect to the model (i.e. with respect to its
properties such as minSpeed and categoryMultipliers). After some iterations,
the mean error will be minimized and the model will produce more "correct"
results based on its learning.
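In symbols, writing θ for all of the model's stored properties, s_k and u_k for the podcast state and user speed of the k-th of n observations, p_θ for `prediction(for:)`, and η for the learning rate, the objective and the update that the loop performs are as follows (the sign rule for the derivative of |r| holds wherever the residual r is nonzero and the clamp is not at a boundary):

```latex
E(\theta) = \frac{1}{n} \sum_{k=1}^{n} \left| p_\theta(s_k) - u_k \right|,
\qquad
\nabla_\theta E(\theta) = \frac{1}{n} \sum_{k=1}^{n}
  \operatorname{sign}\!\left( p_\theta(s_k) - u_k \right) \nabla_\theta\, p_\theta(s_k),
\qquad
\theta \leftarrow \theta - \eta \, \nabla_\theta E(\theta)
```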

var model = PodcastSpeedModel()
let observations = storage.observations()
for _ in 0..<1000 {
    // The language differentiates `meanError` to get a "gradient", which is a value indicating
    // how to change `model` in order to decrease the value of `meanError`.
    let modelGradient = gradient(at: model) { meanError(for: $0, observations) }

    // Change `model` in the direction that decreases the value of `meanError`.
    let learningRate: Float = 0.01
    model.minSpeed -= learningRate * modelGradient.minSpeed
    model.maxSpeed -= learningRate * modelGradient.maxSpeed
    for i in model.categoryMultipliers.indices {
        model.categoryMultipliers[i] -= learningRate * modelGradient.categoryMultipliers[i]
    }
    for i in model.sectionMultipliers.indices {
        model.sectionMultipliers[i] -= learningRate * modelGradient.sectionMultipliers[i]
    }
}
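Because `storage.observations()` and the model's initial values are left unspecified in the pseudocode above, here is a self-contained sketch that runs the same loop with hypothetical observations and initial parameter values. It replaces the compiler-synthesized gradient with one derived by hand, including a branch-by-branch treatment of the clamp, to show the bookkeeping that `gradient(at:in:)` would otherwise generate. It is an illustration under those assumptions, not the proposal's code.

```swift
enum PodcastCategory: Int {
    case comedy, news
}

enum PodcastSection: Int {
    case advertisement, introduction, body, conclusion
}

struct PodcastState {
    let category: PodcastCategory
    let section: PodcastSection
}

struct Observation {
    var podcastState: PodcastState
    var userSpeed: Float
}

// Hypothetical initial parameter values, chosen only so the sketch runs.
struct PodcastSpeedModel {
    var minSpeed: Float = 0.5
    var maxSpeed: Float = 3
    var categoryMultipliers: [Float] = [1, 1]
    var sectionMultipliers: [Float] = [1, 1, 1, 1]

    func prediction(for state: PodcastState) -> Float {
        let speed = categoryMultipliers[state.category.rawValue]
            * sectionMultipliers[state.section.rawValue]
        if speed < minSpeed { return minSpeed }
        if speed > maxSpeed { return maxSpeed }
        return speed
    }
}

// Mean absolute error together with a hand-derived gradient. The gradient is
// stored in a `PodcastSpeedModel` value so it has the same shape as the model.
func meanErrorAndGradient(
    for model: PodcastSpeedModel, _ observations: [Observation]
) -> (error: Float, gradient: PodcastSpeedModel) {
    let n = Float(observations.count)
    var error: Float = 0
    var grad = PodcastSpeedModel(
        minSpeed: 0, maxSpeed: 0,
        categoryMultipliers: Array(repeating: 0, count: model.categoryMultipliers.count),
        sectionMultipliers: Array(repeating: 0, count: model.sectionMultipliers.count))
    for observation in observations {
        let c = observation.podcastState.category.rawValue
        let s = observation.podcastState.section.rawValue
        let speed = model.categoryMultipliers[c] * model.sectionMultipliers[s]
        let residual = model.prediction(for: observation.podcastState) - observation.userSpeed
        error += abs(residual) / n
        // d|r|/dr is sign(r); use 0 at r == 0.
        let sign: Float = residual > 0 ? 1 : (residual < 0 ? -1 : 0)
        // Route the gradient through whichever branch of the clamp was taken.
        if speed < model.minSpeed {
            grad.minSpeed += sign / n
        } else if speed > model.maxSpeed {
            grad.maxSpeed += sign / n
        } else {
            grad.categoryMultipliers[c] += sign * model.sectionMultipliers[s] / n
            grad.sectionMultipliers[s] += sign * model.categoryMultipliers[c] / n
        }
    }
    return (error, grad)
}

// Hypothetical observations of a user's preferred speeds.
let observations = [
    Observation(podcastState: PodcastState(category: .news, section: .body), userSpeed: 1.5),
    Observation(podcastState: PodcastState(category: .news, section: .advertisement), userSpeed: 2.5),
    Observation(podcastState: PodcastState(category: .comedy, section: .body), userSpeed: 1.0),
]

var model = PodcastSpeedModel()
let initialError = meanErrorAndGradient(for: model, observations).error
let learningRate: Float = 0.01
for _ in 0..<1000 {
    let modelGradient = meanErrorAndGradient(for: model, observations).gradient
    model.minSpeed -= learningRate * modelGradient.minSpeed
    model.maxSpeed -= learningRate * modelGradient.maxSpeed
    for i in model.categoryMultipliers.indices {
        model.categoryMultipliers[i] -= learningRate * modelGradient.categoryMultipliers[i]
    }
    for i in model.sectionMultipliers.indices {
        model.sectionMultipliers[i] -= learningRate * modelGradient.sectionMultipliers[i]
    }
}
let finalError = meanErrorAndGradient(for: model, observations).error
print(initialError, finalError)
```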

As we can see, differentiable programming enables developers to effortlessly
incorporate extremely lightweight gradient-based learning algorithms into
applications, while having derivative code synthesized automatically by Swift.

Language-integrated differentiable programming benefits not only ML
practitioners and app developers, but also developers of ML and scientific
computing frameworks. Relying on a single, language-integrated differentiable
programming feature eliminates the burden of separately maintaining an automatic
differentiation algorithm and a domain-specific language, easing the development
and maintenance overhead.


This section is abridged! Please follow the link above to see the full text.

Math introduction

This section is abridged! Please follow the link above to see the full text.

History of differentiation algorithms

This section is abridged! Please follow the link above to see the full text.

Proposed solution

To push Swift's capabilities to the next level in numerics and machine learning,
we introduce differentiable programming as a new language feature, which
includes standard library APIs and small additive changes to the type system.

The Differentiable protocol

Differentiable is a protocol defined in the standard library that generalizes
all data structures that can be a parameter or result of a differentiable
function. The compiler derives protocol requirement implementations when a
conformance is declared and when any implementation is missing.

extension Float: Differentiable {
    typealias TangentVector = Self
}

struct Perceptron: Differentiable {
    var weight: SIMD64<Float>
    var bias: Float
}
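As a rough picture of what a derived conformance for Perceptron looks like, here is a hand-written sketch against a simplified, hypothetical protocol named `DifferentiableSketch`. It is not the real Differentiable protocol, which also requires TangentVector itself to be Differentiable and provides zeroTangentVectorInitializer, and the code the compiler actually synthesizes may differ.

```swift
// Simplified, hypothetical stand-in for `Differentiable`: only `TangentVector`
// and `move(along:)` are modeled here.
protocol DifferentiableSketch {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
}

struct Perceptron: DifferentiableSketch {
    var weight: SIMD64<Float>
    var bias: Float

    // Roughly what a derived tangent type looks like: one stored property per
    // differentiable stored property of `Perceptron`.
    struct TangentVector: Equatable, AdditiveArithmetic {
        var weight: SIMD64<Float>
        var bias: Float

        static var zero: TangentVector {
            TangentVector(weight: SIMD64(repeating: 0), bias: 0)
        }

        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight + rhs.weight, bias: lhs.bias + rhs.bias)
        }

        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight - rhs.weight, bias: lhs.bias - rhs.bias)
        }
    }

    // Moving along a tangent vector adds it to each stored property.
    mutating func move(along direction: TangentVector) {
        weight += direction.weight
        bias += direction.bias
    }
}

var perceptron = Perceptron(weight: SIMD64(repeating: 1), bias: 0)
let step = Perceptron.TangentVector(weight: SIMD64(repeating: 0.5), bias: 2)
perceptron.move(along: step)
```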

The @differentiable(reverse) declaration attribute

The @differentiable(reverse) declaration attribute is an attribute that marks
function-like declarations (function declarations, initializers, properties, and
subscripts) as being differentiable.

@differentiable(reverse)
func cubed(_ x: Float) -> Float {
    x * x * x
}

extension Perceptron {
    @differentiable(reverse)
    func callAsFunction(_ input: SIMD64<Float>) -> Float {
        (weight * input).sum() + bias
    }
}

The Differentiable Programming Manifesto describes this attribute as
@differentiable, without (reverse). However, we choose not to use plain
@differentiable here because the initial set of proposed features does not
include forward-mode differentiation. Adding (reverse) leaves room for future
additions without ABI breakage.

@differentiable(reverse) function types

Differentiable functions are first-class values, identified by a
@differentiable(reverse) attribute in the function type. A @differentiable(reverse) function
type is a subtype of its corresponding normal function type (i.e. without a
@differentiable(reverse) attribute) with an extended ABI, which stores extra
information that allows its values to be differentiated anywhere the function
is passed. A normal function can be implicitly converted to a @differentiable(reverse)
function with appropriate compile-time checks.

func addOne(_ x: Float) -> Float { x + 1 }
let _: @differentiable(reverse) (Float) -> Float = addOne

@derivative attribute

The @derivative attribute is used for declaring custom derivative functions
for some other function declaration. This attribute can be used by libraries to
define differentiable functions that are "primitives", i.e. ones that the
compiler cannot differentiate automatically, or by the user to define special
behavior for debugging and performance tuning purposes.

The Differentiation library uses this attribute to define derivatives for math
functions, such as expf(_:) in the C standard library.

import Darwin // Or 'Glibc' on Linux

@derivative(of: expf)
func derivativeOfExpf(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let y = expf(x)
    return (value: y, pullback: { v in v * y })
}
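The (value, pullback) pair returned by a custom derivative can be checked numerically. The sketch below copies `derivativeOfExpf` into a plain function (hypothetically named `valueAndPullbackOfExpf`, because the `@derivative` attribute needs the Differentiation module) and compares `pullback(1)` with a central finite difference of `expf`.

```swift
#if canImport(Glibc)
import Glibc
#else
import Darwin
#endif

// Same body as `derivativeOfExpf` above, written as a plain function.
func valueAndPullbackOfExpf(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    let y = expf(x)
    // d/dx exp(x) = exp(x), so the pullback scales the incoming cotangent by y.
    return (value: y, pullback: { v in v * y })
}

let x: Float = 0.5
let (value, pullback) = valueAndPullbackOfExpf(x)

// Central finite difference of expf at x, for comparison.
let h: Float = 1e-3
let finiteDifference = (expf(x + h) - expf(x - h)) / (2 * h)
print(pullback(1), finiteDifference)
```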

Differential operators

Differential operators are standard library differentiation APIs that take @differentiable(reverse)
functions and return derivative functions or compute derivative values.

// In the standard library:
//     public func gradient<T, R: FloatingPoint>(
//       of body: @differentiable(reverse) (T) -> R
//     ) -> (T) -> T.TangentVector where R.TangentVector == R

func f(_ x: Float) -> Float {
    x * x
}
let dfdx = gradient(of: f)
dfdx(3) // 6

Detailed design

Differentiable data structures

This section is abridged! Please follow the link above to see the full text.

The Differentiable protocol

The Differentiable protocol defines operations and structures required for a
type to be differentiated.

public protocol Differentiable {
    /// A type that can be used to represent derivatives with respect to a
    /// value whose type is `Self`. Mathematically, this is equivalent to the
    /// tangent bundle of the differentiable manifold represented by the
    /// differentiable type.
    associatedtype TangentVector: Differentiable & AdditiveArithmetic
        where TangentVector == TangentVector.TangentVector

    /// Moves `self` along the given direction. In Riemannian geometry, this is
    /// equivalent to the exponential map, which moves `self` on the geodesic
    /// surface along the given tangent vector.
    mutating func move(along direction: TangentVector)
    /// A closure that produces a zero tangent vector and does not capture `self`.
    /// In some cases, the zero tangent vector of `self` is equal to
    /// `TangentVector.zero`. In other cases, the zero tangent vector depends on
    /// information in `self`, such as shape for an n-dimensional array type.
    /// For differentiable programming, it is more memory-efficient to define a
    /// custom `zeroTangentVectorInitializer` property which returns a closure
    /// that captures and uses only the necessary information to create a zero
    /// tangent vector. For example:
    /// ```swift
    /// struct Vector {
    ///     var scalars: [Float]
    ///     var count: Int { scalars.count }
    ///     init(repeating repeatedElement: Float, count: Int) { ... }
    /// }
    /// extension Vector: Differentiable {
    ///     typealias TangentVector = Vector
    ///     @noDerivative
    ///     var zeroTangentVectorInitializer: () -> TangentVector {
    ///         let count = self.count
    ///         return { TangentVector(repeating: 0, count: count) }
    ///     }
    /// }
    /// ```
    var zeroTangentVectorInitializer: () -> TangentVector { get }
}

extension Differentiable {
    /// A tangent vector such that `move(along: zeroTangentVector)` will not modify
    /// `self`.
    var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() }
}
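To illustrate why `zeroTangentVectorInitializer` should capture only what it needs, the sketch below implements the `Vector` example from the doc comment without the Differentiation module (so there is no protocol conformance and no `@noDerivative`), together with a hypothetical `makeZeroInitializer()` function. The returned closure keeps only `count` alive; if it captured `self`, it would also retain the large `scalars` buffer.

```swift
// Hypothetical standalone version of the `Vector` type from the doc comment.
struct Vector {
    var scalars: [Float]
    var count: Int { scalars.count }

    init(repeating repeatedElement: Float, count: Int) {
        scalars = Array(repeating: repeatedElement, count: count)
    }

    // Captures only `count`; capturing `self` would also keep `scalars`
    // (and its storage) alive for as long as the closure lives.
    var zeroTangentVectorInitializer: () -> Vector {
        let count = self.count
        return { Vector(repeating: 0, count: count) }
    }
}

// Hypothetical helper: the large vector goes out of scope when this returns,
// and only the captured `count` survives in the returned closure.
func makeZeroInitializer() -> () -> Vector {
    let big = Vector(repeating: 1, count: 1_000_000)
    return big.zeroTangentVectorInitializer
}

let makeZero = makeZeroInitializer()
let zero = makeZero()
print(zero.count) // 1000000
```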

This section is abridged! Please follow the link above to see the full text.

Differentiable function declarations

This section is abridged! Please follow the link above to see the full text.

Make a function differentiable using @derivative

This section is abridged! Please follow the link above to see the full text.

Differentiable function types

This section is abridged! Please follow the link above to see the full text.

Differential operators

The Differentiation module will provide APIs which developers can use to
obtain gradient functions, gradient vectors, and pullback closures, along with
efficiently-computed original results from a given @differentiable(reverse)
closure. These APIs are called "differential operators".


gradient(of:) is a higher-order function that behaves exactly like the 𝛁
(Del) operator in mathematics. It takes a
differentiable closure that returns a scalar and returns its gradient function, i.e. a
closure that accepts the same arguments as the input closure but returns
gradient vectors with respect to the input closure's parameter.
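For a scalar-valued function of n inputs, the gradient is the vector of partial derivatives. Equivalently, it is the transposed Jacobian (the pullback) applied to the scalar seed 1, which is one way to see why gradient(of:) requires a scalar result. The example on the right matches the earlier `dfdx(3) // 6` for f(x) = x * x:

```latex
\nabla f(x) = \left( \frac{\partial f}{\partial x_1}, \ldots, \frac{\partial f}{\partial x_n} \right)
            = J_f(x)^{\top} \, 1,
\qquad
f(x) = \sum_{i=1}^{n} x_i^2 \;\Longrightarrow\; \nabla f(x) = 2x
```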

/// Returns the gradient function of the given closure with respect to the argument.
/// - Parameters:
///   - body: A closure whose derivative function will be evaluated.
/// - Returns: A function that computes the gradient vector of `body` with respect to its argument.
func gradient<T: Differentiable, R: FloatingPoint & Differentiable>(
    of body: @escaping @differentiable(reverse) (T) -> R
) -> (T) -> T.TangentVector where R.TangentVector: FloatingPoint


gradient(at:in:) is the "uncurried" form of gradient(of:). It takes a value
and a differentiable closure that returns a scalar, and evaluates the closure's
gradient function on the value.

/// Returns the gradient vector with respect to the argument by evaluating the
/// provided closure's derivative at the argument.
/// - Parameters:
///   - x: An argument to be passed to `body`.
///   - body: A closure whose derivative function will be evaluated.
/// - Returns: A gradient vector with respect to `x`.
func gradient<T: Differentiable, R: FloatingPoint & Differentiable>(
    at x: T, in body: @differentiable(reverse) (T) -> R
) -> T.TangentVector where R.TangentVector: FloatingPoint

The call sites of this API read as if the call is feeding an argument into the
trailing closure, getting back a gradient vector. This API is consistent with
developers' mental model of taking the gradient of algorithms, and therefore
will be the most commonly used API. For example, a deep learning model's
training loop may look like the following.

for _ in 0..<1000 {
    // Differentiate the loss with respect to the model `classifier` itself, producing a
    // tangent vector `modelGradient` that represents partial derivatives with respect to
    // all trainable model parameters in the model.
    let modelGradient = gradient(at: classifier) { classifier in
        let prediction = classifier(x)
        let loss = softmaxCrossEntropy(logits: prediction, labels: y)
        print("Loss: \(loss)")
        return loss
    }
    optimizer.performStep(for: classifier, along: modelGradient)
}


Sometimes the developer needs to obtain both the original result and the
gradient vector. While it is possible for the developer to call the
differentiable closure and gradient(at:in:) separately, it would lead to
significant recomputation overhead, because computing the gradient vector of a
differentiable closure at a value will already compute the closure's original
result. valueWithGradient(at:in:) is an API for efficiently computing both the
original result and the gradient vector.

/// Returns the result and gradient vector with respect to the argument by evaluating the
/// provided closure's derivative at the argument.
/// - Parameters:
///   - x: An argument to be passed to `body`.
///   - body: A closure whose derivative function will be evaluated.
/// - Returns: The result of `body` evaluated on `x`, equivalent to `body(x)`, and
///   a gradient vector with respect to `x`.
func valueWithGradient<T: Differentiable, R: FloatingPoint & Differentiable>(
    at x: T, in body: @differentiable(reverse) (T) -> R
) -> (value: R, gradient: T.TangentVector) where R.TangentVector: FloatingPoint

// Example: Want both the result and the gradient of `foo(x)`.
func foo(_ x: Double) -> Double { ... }
let x = 2.0

// Slow way:
let y = foo(x)
let dydx = gradient(at: x, in: foo)

// Efficient way:
let (y, dydx) = valueWithGradient(at: x, in: foo)


valueWithPullback(at:in:) is the most general form of differential operator
for reverse-mode automatic differentiation. Unlike valueWithGradient(at:in:)
which directly computes the gradient vector, valueWithPullback(at:in:) returns
a pullback closure that represents a linear approximation of the input closure
at the given value. This formulation corresponds exactly to derivative functions
that are defined with @derivative, and enables the most flexibility and
composability. In fact, all other differential operators discussed above are
implemented in terms of valueWithPullback(at:in:).
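As a sketch of how the other operators can be layered on top of a (value, pullback) pair, the following plain Swift code uses a hand-written pullback for x³ in place of `valueWithPullback(at:in:)`. `ValueWithPullback`, `valueWithGradientSketch`, and `gradientSketch` are hypothetical names, and the actual Differentiation module implementations may differ. The gradient is the pullback applied to 1, and the value is reused rather than recomputed.

```swift
// Hypothetical stand-in for what valueWithPullback(at:in:) returns for a
// Float -> Float function: the result and a pullback closure.
typealias ValueWithPullback = (value: Float, pullback: (Float) -> Float)

// Hand-written (value, pullback) for g(x) = x * x * x.
func cubedWithPullback(_ x: Float) -> ValueWithPullback {
    let y = x * x * x
    return (value: y, pullback: { v in v * 3 * x * x })
}

// valueWithGradient: seed the pullback with 1 (the derivative of the output
// with respect to itself) and reuse the already-computed value.
func valueWithGradientSketch(
    at x: Float, in body: (Float) -> ValueWithPullback
) -> (value: Float, gradient: Float) {
    let (value, pullback) = body(x)
    return (value, pullback(1))
}

// gradient: the same as above, discarding the value.
func gradientSketch(at x: Float, in body: (Float) -> ValueWithPullback) -> Float {
    valueWithGradientSketch(at: x, in: body).gradient
}

let (y, dydx) = valueWithGradientSketch(at: 2, in: cubedWithPullback)
print(y, dydx) // 8.0 12.0
```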

/// Returns the result and pullback closure by evaluating the provided closure's
/// derivative at the argument.
/// - Parameters:
///   - x: An argument to be passed to `body`.
///   - body: A closure whose derivative function will be evaluated.
/// - Returns: The result of `body` evaluated on `x`, equivalent to `body(x)`, and
///   a pullback closure, which represents a transposed linear combination that
///   approximates `body` at `x`. When evaluated on a tangent vector, `pullback` evaluates
///   the linear combination on the tangent vector and returns a gradient vector with
///   respect to `x`.
func valueWithPullback<T: Differentiable, R: Differentiable>(
    at x: T, in body: @differentiable(reverse) (T) -> R
) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector)

Static analysis

Differentiable programming in Swift aims to provide the best static compiler
diagnostics to help users catch mistakes. Beyond error diagnostics, the compiler
and the standard library are equipped with static analyses and marker APIs that
help the user write differentiable code with explicit annotations about
non-obvious non-differentiable cases.

This section is abridged! Please follow the link above to see the full text.

Source compatibility

This feature does not change any existing APIs. While the addition of
@differentiable(reverse) function types changes the function implicit
conversion rules in the type checker, the relevant code paths are only triggered
when a @differentiable(reverse) function type is involved in a contextual

Effect on ABI stability

The proposed ABI changes are purely additive. Protocols with requirements marked
with @differentiable(reverse) will contain an extra entry storing the
corresponding derivative function, provided by conforming types. Similarly,
@differentiable(reverse) function types use a new function representation that
bundles two functions: the original function and the derivative function.

Effect on API resilience

This feature adds the Differentiable protocol and
differential operators to the standard library as
public APIs. These changes are purely additive.

Differentiable protocol

The Differentiable protocol contains all necessary requirements for a type to
be differentiated. Without breaking API, it will be possible to add extensions
to the Differentiable protocol and add new requirements with default
implementations.

Differential operators

Differential operators (e.g. derivative(of:) and gradient(of:)) are added to
the standard library as lightweight top-level higher-order functions. These APIs
can be renamed or moved under some namespace without breaking ABI.

Alternatives considered

Not support differentiable programming

We believe first-class differentiable programming is a big step towards making
Swift a real contender in the numerical computing and machine learning
landscape. Differentiable programming will enable intelligent applications,
machine learning models, scientific experiments, physical simulations, and more.

Use another language or framework for differentiable programming

Dynamic languages, like Python and Julia, have established library support for
differentiable programming. While it is possible to interoperate with these
libraries via Swift, we feel that first-class differentiable programming in
Swift is leaps ahead in expressivity, usability, and safety.

Other approaches to differentiable programming

See "Approaches to automatic differentiation"
above for an overview and comparison of automatic differentiation approaches.
First-class language support for differentiation will enable convenient,
extensible, and performant differentiable programming in Swift, more so than
library-based approaches.


The development of this feature started in early 2018 as part of the Swift for
TensorFlow project and has been pioneered by
engineers from Google. The authors would like to thank everybody involved. See
the Acknowledgements section of the manifesto.


This link in particular seems to be linking to the manifesto (and is the main thing people clicked). Did you mean to link to https://github.com/rxwei/swift-evolution/blob/autodiff/proposals/0000-differentiable-programming.md instead?

Link fixed!

As far as the scoping down goes, what features were taken out?

Here's a list of features mentioned in the manifesto but out of scope for this proposal:

The manifesto also mentioned a number of general features that aren't directly related to differentiable programming. These are out of scope for this proposal as well:

  • Anonymous functions, aka. func _ in the manifesto. This would be a general-purpose feature which I think can be pitched separately to support use cases beyond differentiation (e.g. dynamic method replacement).
  • Compiler-synthesized conformances for AdditiveArithmetic.
  • @memberwise attribute for triggering derived conformances. Explicit derivation is a general feature that I think deserves its own proposal so that other derivable protocols can adopt it.

It's really great to see this coming together! I have a couple questions about some minor details of the proposal:

  1. What happens if two or more modules use @derivative to retroactively assign derivatives to the same external function? I guess as long as all the implementations are correct it doesn't matter too much which gets chosen, but it seems like this could be a problem similar to retroactive protocol conformances. It might be good to specify the behavior here, even if it's just an arbitrary derivative function that gets picked.

  2. Is the Differentiation module intended to be implicitly imported if the proposal is accepted like the Swift (and presumably Concurrency) modules? I know it currently requires an explicit import on main, so leaving this as-is would make it the first standard library module that's not implicitly imported. I don't have a strong opinion either way, but it might be worth discussing.

The current implementation is that it will pick the first derivative found in the current module and then in imported modules. You are right that this is a very similar problem to retroactive conformances.

Differentiation is intended to be included in Swift distributions as a standalone module and will not be imported implicitly. Because this is a domain-specific feature and because there isn't a precedent in mainstream general-purpose languages, using a standalone module seems in line with progressive disclosure of complexity. Implicit import (or even merger into stdlib) can be proposed and assessed in the future. If we made it implicitly imported today then it would be difficult to revert that decision in future evolution.

Could you illustrate how Jacobian computation would fit in that scheme?

Since reverse-mode automatic differentiation efficiently computes vector-Jacobian products, to get a full real-valued Jacobian matrix you can call a pullback closure on individual basis tangent vectors to get all the rows of the Jacobian. The pullback closure itself is in fact a program representation of the linear transformation that the Jacobian matrix (transposed, to be accurate) represents.

// Let's say this is a differentiable function.
// Input size: 2
// Output size: 4
func someFunction(_ x: SIMD2<Double>) -> SIMD4<Double> { ... }

let x = SIMD2(1.0, 2.0)
let pb = pullback(at: x) { x in someFunction(x) }
pb([0, 0, 0, 1]) // [∂y4/∂x1, ∂y4/∂x2]
pb([0, 0, 1, 0]) // [∂y3/∂x1, ∂y3/∂x2]
pb([0, 1, 0, 0]) // [∂y2/∂x1, ∂y2/∂x2]
pb([1, 0, 0, 0]) // [∂y1/∂x1, ∂y1/∂x2]
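Only the input and output sizes of someFunction are given above. The sketch below picks a hypothetical body, y = (x1·x2, x1 + x2, x1², 3·x2), and hand-writes its pullback (which the compiler would otherwise synthesize) so that the row-by-row construction can be run and checked without the Differentiation module. `valueWithPullbackOfSomeFunction` is a hypothetical name.

```swift
// Hypothetical body for `someFunction`: input size 2, output size 4,
// y = (x1 * x2, x1 + x2, x1 * x1, 3 * x2).
func someFunction(_ x: SIMD2<Double>) -> SIMD4<Double> {
    SIMD4(x.x * x.y, x.x + x.y, x.x * x.x, 3 * x.y)
}

// Hand-written value and pullback. The pullback maps an output cotangent v
// to J^T * v, where J is the 4x2 Jacobian of `someFunction` at x.
func valueWithPullbackOfSomeFunction(
    at x: SIMD2<Double>
) -> (value: SIMD4<Double>, pullback: (SIMD4<Double>) -> SIMD2<Double>) {
    let pullback = { (v: SIMD4<Double>) -> SIMD2<Double> in
        SIMD2(v[0] * x.y + v[1] + 2 * x.x * v[2],
              v[0] * x.x + v[1] + 3 * v[3])
    }
    return (someFunction(x), pullback)
}

let x = SIMD2(1.0, 2.0)
let pb = valueWithPullbackOfSomeFunction(at: x).pullback

// Pulling back the i-th output basis vector yields row i of the Jacobian:
// [∂y_i/∂x1, ∂y_i/∂x2].
let jacobianRows: [SIMD2<Double>] = (0..<4).map { i -> SIMD2<Double> in
    var basis = SIMD4<Double>(repeating: 0)
    basis[i] = 1
    return pb(basis)
}
for row in jacobianRows {
    print(row)
}
```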

Adding to the answer above, here's a generic Jacobian implementation from SwiftFusion using reverse-mode differentiation:

func jacobian<A: Differentiable, B: Differentiable>(
  of f: @differentiable (A) -> B,
  at x: A,
  basisVectors: [B.TangentVector]
) -> [A.TangentVector] {
  let pb = pullback(at: x, in: f)
  return basisVectors.map { pb($0) }
}
This implementation maps pullback closures over the output type's basis vectors. A forward-mode differentiation implementation would map differential closures over the input type's basis vectors.
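To make the forward-mode contrast concrete, the sketch below hand-writes the differential (a Jacobian-vector product) of a hypothetical function y = (x1·x2, x1 + x2, x1², 3·x2) and maps it over the input basis vectors, producing Jacobian columns instead of rows. Forward-mode differentiation is outside the scope of this proposal, so this is plain Swift with no Differentiation APIs, and `differentialOfSomeFunction` is a hypothetical name.

```swift
// Hypothetical function y = (x1 * x2, x1 + x2, x1 * x1, 3 * x2), represented
// only by its hand-written differential: an input tangent v maps to J * v.
func differentialOfSomeFunction(at x: SIMD2<Double>) -> (SIMD2<Double>) -> SIMD4<Double> {
    return { v in
        SIMD4(x.y * v[0] + x.x * v[1],
              v[0] + v[1],
              2 * x.x * v[0],
              3 * v[1])
    }
}

let x = SIMD2(1.0, 2.0)
let differential = differentialOfSomeFunction(at: x)

// Pushing forward the j-th input basis vector yields column j of the Jacobian:
// [∂y1/∂xj, ∂y2/∂xj, ∂y3/∂xj, ∂y4/∂xj].
let jacobianColumns: [SIMD4<Double>] = (0..<2).map { j -> SIMD4<Double> in
    var basis = SIMD2<Double>(repeating: 0)
    basis[j] = 1
    return differential(basis)
}
for column in jacobianColumns {
    print(column)
}
```

With two inputs and four outputs, forward mode needs two evaluations here while reverse mode needs four, which is the usual reason to prefer one mode over the other for a given Jacobian shape.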


For many computation graph optimizations (CSE and mixed precision, to name a few), would those be implemented in the Swift compiler now, or in libraries (Differentiation?) outside of the core Swift compiler?

Can we get rid of Differentiable.zeroTangentVectorInitializer and keep only (customizable) Differentiable.zeroTangentVector?

It is specifically said in the proposal that the .zeroTangentVectorInitializer is not to capture self, so I do not see a scenario where this would make any expressibility difference.

Conforming implementations are not required to be marked with the @differentiable(reverse) attribute unless they are public.

I think we can omit the attribute even for public declarations. I think that'd be more in line with the recent inference rule in SE-0289 Result Builders.

There's this line in the example code:

@noDerivative
public var zeroTangentVectorInitializer: () -> TangentVector { ... }

What does @noDerivative do in the computed property? Shouldn't it apply only to stored property?

By CSE, are you referring to common subexpression elimination? If so, it is already part of the Swift compiler down the pipeline. Compiler-generated derivative functions are normal functions, and therefore are amenable to optimization just like normal code. (For the record, there's definitely a lot more work to be done to optimize differentiated control flow, but that can be done over time.) Moreover, the automatic differentiation algorithm, unlike symbolic differentiation, by definition does not produce common subexpressions in the derivative code that it generates.

Not sure what you mean by "mixed-precision". Differentiation (both the compiler part and APIs defined in the Differentiation module) works generically on all Differentiable-conforming types, and is not hard-coded to recognize any math operations. So libraries have the freedom to optimize for mixed-precision computation in their derivatives.

Thank you! Yeah, these two are just easy-to-pick examples of TensorFlow-specific optimizations done on the computation graph. There are of course others I would consider computation-graph / deep-learning-specific optimization passes:

1. optimal tensor placement, similar to register allocation: https://mxnet.apache.org/versions/1.7.0/api/architecture/note_memory.html#memory-allocation-algorithm
2. automatic data parallelization: https://keras.io/guides/distributed_training/#singlehost-multidevice-synchronous-training
3. automatic binomial checkpointing: https://openreview.net/forum?id=BkYYXJ9i-
4. automatic mixed precision: https://pytorch.org/docs/stable/notes/amp_examples.html

As you said, common subexpression elimination is not a good example, because it is already well handled at the compiler layer.

Where do you see these kinds of optimization passes residing in the future? Thanks!

Avoiding self capture is critical for ML in memory-constrained situations. Multi-dimensional arrays used for ML can be very large. When creating a zero value from a dynamically shaped mathematical object, we only need the shape and not the contents. Capturing self means that more memory will be consumed between the result computation (aka. "forward pass") and the gradient computation (aka. "backward pass").
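The memory effect of capturing self can be sketched in plain Swift without the Differentiation module. Everything here (`Storage`, `BigTensor`, the two initializer properties, `forwardPass`) is hypothetical and only mimics the role of zeroTangentVectorInitializer; a weak reference reports whether the large contents are still alive between the two passes.

```swift
// Hypothetical reference-backed storage, so its lifetime is observable with `weak`.
final class Storage {
    let scalars: [Float]
    init(count: Int) { scalars = Array(repeating: 1, count: count) }
}

// Hypothetical dynamically shaped tensor; not from any real library.
struct BigTensor {
    let shape: [Int]
    let storage: Storage

    // Captures `self`: the contents stay alive for as long as the closure does.
    var zeroInitializerCapturingSelf: () -> [Float] {
        return { Array(repeating: 0, count: self.storage.scalars.count) }
    }

    // Captures only the element count derived from the shape.
    var zeroInitializerCapturingShape: () -> [Float] {
        let count = shape.reduce(1, *)
        return { Array(repeating: 0, count: count) }
    }
}

final class WeakObserver { weak var storage: Storage? }

// Simulates a forward pass that hands only the zero initializer to the backward pass.
// @inline(never) keeps the returned closure opaque to the optimizer in this demo.
@inline(never)
func forwardPass(capturingSelf: Bool, observer: WeakObserver) -> () -> [Float] {
    let tensor = BigTensor(shape: [1_000, 1_000], storage: Storage(count: 1_000_000))
    observer.storage = tensor.storage
    return capturingSelf ? tensor.zeroInitializerCapturingSelf
                         : tensor.zeroInitializerCapturingShape
}

// Reports whether the large storage is still alive between the two passes.
func storageAliveBetweenPasses(capturingSelf: Bool) -> Bool {
    let observer = WeakObserver()
    let makeZero = forwardPass(capturingSelf: capturingSelf, observer: observer)
    let alive = observer.storage != nil
    let zero = makeZero()  // the "backward pass" materializes the zero
    precondition(zero.count == 1_000_000)
    return alive
}

print(storageAliveBetweenPasses(capturingSelf: true))   // true
print(storageAliveBetweenPasses(capturingSelf: false))  // false
```

Only the shape-capturing closure lets the million-element buffer be freed before the backward pass runs.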

A difference from result builders is that @differentiable(reverse) on a protocol requirement has ABI implications, i.e. it creates more entries in the protocol dispatch table, while @resultBuilder doesn't. I'm leaning towards requiring @differentiable(reverse) on public declarations for clarity and library evolution. Perhaps it can be proposed in a follow-up proposal instead to further assess the impact on library evolution.

@noDerivative helps silence a compiler error when zeroTangentVectorInitializer is used on a path that's being differentiated. It tells the compiler that we intentionally make this property have no derivative and that it shouldn't complain about it; any derivative through it behaves as if it were zero.

Here is an example where such an error would occur:

// Differentiating with respect to `x`, but `description` is not differentiable.
gradient(at: 1.0) { x in Double(x.description)! }
<location>: error: function is not differentiable
gradient(at: 1.0) { x in Double(x.description)! }
<location>: note: cannot differentiate through a non-differentiable result; do you want to add 'withoutDerivative(at:)'?
gradient(at: 1.0) { x in Double(x.description)! }

I don't think it would make any meaningful difference whether we create an empty tangent and pass it around or pass the initializer around. Types containing large data should be managed by CoW anyway, as with other storage types. This scenario is not unique to zero tangents.

If we want to avoid capturing self by accident (from an attempt to defer tangent initialization), we should instead remove zeroTangentVector, but I think the most Swifty design would encourage creating the zero tangent upfront and passing that tangent around.

From the Differentiability Parameters section:

func foo<T: Differentiable>(_ x: T, _ y: T, _ z: T) -> T { ... }

// Derivative with respect to all parameters.
@derivative(of: foo)
func derivativeOfFoo<T: Differentiable>(_ x: T, _ y: T, _ z: T)
  -> (
    value: T,
    pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector, T.TangentVector)
  ) { ... }

The pullback result looks confusing. What if we used a labeled tuple instead, and used wrt only for the single-parameter case?

@derivative(of: foo)
func derivativeOfFoo<T: Differentiable>(_ x: T, _ y: T, _ z: T)
  -> (
    value: T,
    pullback: (T.TangentVector) -> (x: T.TangentVector, y: T.TangentVector, z: T.TangentVector)
  ) { ... }

// Derivative with respect to `x` and `z`.
@derivative(of: foo)
func derivativeOfFoo<T: Differentiable>(_ x: T, _ y: T, _ z: T)
  -> (
    value: T,
    pullback: (T.TangentVector) -> (x: T.TangentVector, z: T.TangentVector)
  ) { ... }

Then we could also write things like pullback(4).x, which would be quite natural.
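A hand-written sketch of the suggested labeled-tuple pullback, assuming foo is specialized to x * y * z over Double. It does not use @derivative, so it only illustrates the calling convention, not the proposal's derivative registration mechanism.

```swift
// Hypothetical concrete instance of `foo`, specialized to Double.
func foo(_ x: Double, _ y: Double, _ z: Double) -> Double {
    return x * y * z
}

// Hand-written derivative in the suggested style: the pullback returns a
// tuple labeled with the parameter names. No @derivative registration here.
func derivativeOfFoo(_ x: Double, _ y: Double, _ z: Double)
    -> (value: Double, pullback: (Double) -> (x: Double, y: Double, z: Double))
{
    return (
        value: foo(x, y, z),
        pullback: { v in (x: v * y * z, y: v * x * z, z: v * x * y) }
    )
}

let (value, pullback) = derivativeOfFoo(2, 3, 4)
let gradient = pullback(1)       // (x: 12.0, y: 8.0, z: 6.0)
let scaledDx = pullback(4).x     // 4 * y * z = 48.0
print(value, gradient.x, gradient.y, gradient.z, scaledDx)
```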

We view these as mostly orthogonal issues. The differentiable programming language feature is not responsible for targeting heterogeneous compute, nor does it build a computation graph. It is responsible for producing derivative code, which client libraries can then run as they see fit, for example to produce a staged program representation (such as a computation graph) for optimization purposes. Most of these optimizations can be implemented entirely in client libraries, and the differentiable programming language feature doesn't have to know about them.

AD checkpointing is definitely a relevant topic. Manual checkpointing can be done entirely at the library level: you can define a custom higher-order function that makes a differentiable function's pullback recompute the original result.

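func discardingCheckpoints<T: Differentiable, R: Differentiable>(
    _ f: @escaping @differentiable(reverse) (T) -> R
) -> @differentiable(reverse) (T) -> R {
    func body(_ x: T) -> R { f(x) }
    @derivative(of: body)
    func derivative(_ x: T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) {
        // Forward pass: compute only the result, so no pullback (and none of
        // its saved intermediate values) stays alive until the backward pass.
        // Backward pass: rerun `f` to rebuild its pullback, then apply it to `v`.
        (value: f(x), pullback: { v in pullback(at: x, in: f)(v) })
    }
    return body
}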
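The same recompute-in-the-pullback idea can be modeled in plain Swift. In this sketch a hypothetical `DifferentiableFunction` struct stands in for a `@differentiable(reverse)` function value, and a counter shows that the original computation runs twice: once in the forward pass and once when the pullback is invoked. It illustrates the technique only and is not the Differentiation module's API.

```swift
// Hypothetical stand-in for a `@differentiable(reverse) (Double) -> Double`
// value: a plain forward function plus a value-and-pullback function.
struct DifferentiableFunction {
    var original: (Double) -> Double
    var valueWithPullback: (Double) -> (value: Double, pullback: (Double) -> Double)
}

// Counts how many times the original computation runs.
final class EvaluationCounter { var count = 0 }

// x^4 computed as (x^2)^2; a regular pullback keeps the intermediate `square`.
func fourthPower(counter: EvaluationCounter) -> DifferentiableFunction {
    return DifferentiableFunction(
        original: { x in
            counter.count += 1
            let square = x * x
            return square * square
        },
        valueWithPullback: { x in
            counter.count += 1
            let square = x * x
            // d/dx (square^2) = 2 * square * 2x
            return (value: square * square, pullback: { v in v * 2 * square * 2 * x })
        }
    )
}

// Mirrors `discardingCheckpoints`: the forward pass keeps only `x`;
// the pullback reruns the original function to rebuild its pullback.
func discardingCheckpoints(_ f: DifferentiableFunction) -> DifferentiableFunction {
    return DifferentiableFunction(
        original: f.original,
        valueWithPullback: { x in
            (value: f.original(x), pullback: { v in f.valueWithPullback(x).pullback(v) })
        }
    )
}

let counter = EvaluationCounter()
let checkpointed = discardingCheckpoints(fourthPower(counter: counter))
let (y, pb) = checkpointed.valueWithPullback(3)  // forward pass: 1 evaluation
let dydx = pb(1)                                 // backward pass recomputes: 2 in total
print(y, dydx, counter.count)
```

The price is one extra forward evaluation in exchange for not holding `square` between the passes.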

Automatic checkpointing has been an active research topic for decades. Supporting it automatically on any type would require additional attributes and additional protocol customization points in the Differentiable protocol (or an inheriting protocol), which the compiler could generate code to call in order to check whether an operation's inputs and results, as well as the operation itself, have characteristics that warrant recomputation. Something like binomial checkpointing could then be applied in the compiler. This is a possible future direction but is not required in the initial system.
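Binomial checkpointing (for example, Griewank's Revolve schedule) chooses checkpoint positions optimally. As a simpler illustration of the underlying memory/compute trade-off, the sketch below uses uniform checkpointing on a hypothetical chain of scalar steps: it stores only every k-th state, so roughly n/k + k states are live at once, and recomputes one segment at a time during the backward pass. The step function and parameter values are assumptions for illustration.

```swift
// One step of a hypothetical iterated computation x_{i+1} = step(x_i).
func step(_ x: Double) -> Double { return x - 0.1 * x * x }
func stepDerivative(_ x: Double) -> Double { return 1 - 0.2 * x }

// Baseline: keep all n intermediate states alive for the backward pass.
func gradientStoringAll(x0: Double, steps n: Int) -> Double {
    var states = [x0]
    for _ in 0..<n { states.append(step(states[states.count - 1])) }
    var grad = 1.0
    for i in stride(from: n - 1, through: 0, by: -1) {
        grad *= stepDerivative(states[i])  // chain rule, last step first
    }
    return grad
}

// Uniform checkpointing: keep only every k-th state, then recompute the
// states of one segment at a time during the backward pass.
func gradientWithCheckpoints(x0: Double, steps n: Int, every k: Int)
    -> (gradient: Double, stepCalls: Int)
{
    var stepCalls = 0
    var checkpoints: [Int: Double] = [0: x0]
    var x = x0
    for i in 0..<n {
        x = step(x)
        stepCalls += 1
        if (i + 1) % k == 0 { checkpoints[i + 1] = x }
    }
    var grad = 1.0
    var end = n
    while end > 0 {
        // Segment [start, end) begins at the nearest checkpoint below `end`.
        let start = ((end - 1) / k) * k
        var segment = [checkpoints[start]!]
        for _ in start..<(end - 1) {
            segment.append(step(segment[segment.count - 1]))
            stepCalls += 1
        }
        for state in segment.reversed() { grad *= stepDerivative(state) }
        end = start
    }
    return (grad, stepCalls)
}

let baseline = gradientStoringAll(x0: 0.5, steps: 10)
let checkpointed = gradientWithCheckpoints(x0: 0.5, steps: 10, every: 3)
print(baseline, checkpointed.gradient, checkpointed.stepCalls)
```

With n = 10 and k = 3, the forward pass makes 10 step calls and the backward pass 6 more, while the gradient is identical to the store-everything baseline.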

It does make a big difference in ML code. An empty/zero tangent can have a very large size. Either keeping self around for later zero creation or keeping a large zero tangent vector around can lead to increasing the memory footprint (it's not an issue of CoW; the value will stay alive because the pullback captures it and the pullback won't be executed/released until later). In contrast, a closure lets the differentiable type's implementer choose to capture all the necessary information needed for materializing a zero, and only that necessary information chosen by the type implementer will be kept between the forward pass and the backward pass.

To be clear, zeroTangentVectorInitializer is not to be called by the user or library developer directly. It is to be implemented by the type and be called by the compiler-generated derivative code of a function. Removing zeroTangentVector seems totally fine to me though.

OK, I can see that scenario now. Then I think we should remove zeroTangentVector. zeroTangentVector and zeroTangentVectorInitializer send different signals about whether we should eagerly create zero tangents, and if the target user of this functionality is the compiler, zeroTangentVector looks even less useful.

Also, maybe we can rename TangentVector to just Tangent?

Why is the distinction between forward- and reverse-mode AD significant at the type level? I'd think both should get the same treatment from a type-level perspective, with the AD mode being an implementation detail. So:

// instead of this
@differentiable(reverse) (Float, Float, Int) -> Float

// We use this
@differentiable (Float, Float, Int) -> Float

It seems that the types of the functions are:

// normal function
(Float, Float, Int) -> Float

// differentiable function
@differentiable (Float, Float, Int) -> Float

// partial differentiable function
@differentiable (@noDerivative Float, Float, Int) -> Float

Should the leading @differentiable instead just be sugar for applying @differentiable to all of its parameters?

// This
@differentiable (Float, Float, Int) -> Float
// is sugar for
(@differentiable Float, @differentiable Float, Int) -> Float

// This
@differentiable (@noDerivative Float, Float, Int) -> Float
// is sugar for
(Float, @differentiable Float, Int) -> Float

// This, if someone so wishes
@differentiable (@noDerivative Float, @noDerivative Float, Int) -> Float
// is sugar for normal function
(Float, Float, Int) -> Float

If that's already the case, I think it could be conveyed better in the proposal.


Thanks. The comments on automatic checkpointing make a lot of sense! The pullback mechanism sounds like an interesting approach to manual checkpointing.

I think this is where I'm missing some context (or it's just my own confusion). In the proposal and manifesto, how can a third-party library operate on the computation graph and perform such transformations if needed? It seems the hook point would be implementing the Differentiable protocol. That provides a mechanism to hook third-party operators into heterogeneous computation devices (GPUs). However, it doesn't help a third-party library get the full picture of the computation graph and do interesting things with it in its own architecture, does it? I remember that for S4TF, the computation was indeed lowered into a TensorFlow graph at some point (maybe in 1.x?). I'm just wondering how we plan to standardize those hook points, or whether they exist at all.
