Hi Swift community,
We have completed a comprehensive proposal for the differentiable programming feature we’ve been incubating over the last 1.5 years. We’ve gone through many iterations of the feature design, and have partially completed the implementation. Now we are ready to start a discussion on Swift Evolution, specifically on upstreaming and standardizing the feature.
Since this proposal is overly long (~60 pages), we hope to start by merging it into the docs/ directory in apple/swift via apple/swift#27034, and then draft bite-sized proposals that contain references to the mega-proposal.
We look forward to your feedback!
- Richard, Dan, Marc and Bart
Full text: See external markdown.
Abridged text: See below. This is to fit 115603 characters into the 32000-character limit on the forum.
Differentiable Programming Mega-Proposal
- Authors: Richard Wei, Dan Zheng, Marc Rasi, Bart Chrzaszcz
- Status: Partially implemented
Introduction
This proposal introduces first-class differentiable programming to Swift. First-class differentiable programming includes five core additions:
- The Differentiable protocol.
- @differentiable function types.
- The @differentiable declaration attribute for defining differentiable functions.
- The @differentiating and @transposing attributes for defining custom derivatives.
- Differential operators (e.g. derivative(of:)) in the standard library.
Differentiable programming is a new paradigm for programming in which programs can be differentiated throughout. At a glance, differentiable programming lets you take the derivative of functions whose parameters and results conform to the Differentiable protocol.
@differentiable
func f(_ x: Float) -> Float {
x * x
}
let dfdx = derivative(of: f)
dfdx(3) // 6
The ability to get derivatives of programs enables a new world of numerical computing applications, notably machine learning. With first-class support, gradient-based learning algorithms can even be built using standard library types such as Float and SIMD64<Float> and be differentiated using protocol-oriented APIs such as valueWithGradient(at:in:).
struct Perceptron: @memberwise Differentiable {
var weight: SIMD2<Float> = .random(in: -1..<1)
var bias: Float = 0
@differentiable
func callAsFunction(_ input: SIMD2<Float>) -> Float {
(weight * input).sum() + bias
}
}
var model = Perceptron()
let andGateData: [(x: SIMD2<Float>, y: Float)] = [
(x: [0, 0], y: 0),
(x: [0, 1], y: 0),
(x: [1, 0], y: 0),
(x: [1, 1], y: 1),
]
for _ in 0..<100 {
let (loss, 𝛁loss) = valueWithGradient(at: model) { model -> Float in
var loss: Float = 0
for (x, y) in andGateData {
let ŷ = model(x)
let error = y - ŷ
loss = loss + error * error / 2
}
return loss
}
print(loss)
model.weight -= 𝛁loss.weight * 0.02
model.bias -= 𝛁loss.bias * 0.02
}
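To make concrete what valueWithGradient(at:in:) computes in the loop above, the following plain-Swift sketch derives the same gradient by hand, with no compiler support for differentiation. It is an illustration under stated assumptions, not how the proposed feature is implemented: lossAndGradient is a hypothetical helper, and the weights start at zero rather than at random values so that runs are reproducible.

```swift
// Hypothetical hand-written counterpart of the training loop above.
struct Perceptron {
    var weight: SIMD2<Float> = [0, 0]  // fixed start (assumption) so results are reproducible
    var bias: Float = 0

    func callAsFunction(_ input: SIMD2<Float>) -> Float {
        (weight * input).sum() + bias
    }
}

let andGateData: [(x: SIMD2<Float>, y: Float)] = [
    (x: [0, 0], y: 0),
    (x: [0, 1], y: 0),
    (x: [1, 0], y: 0),
    (x: [1, 1], y: 1),
]

/// Returns the loss sum((y - ŷ)^2 / 2) and its partial derivatives
/// with respect to `weight` and `bias`, derived by hand.
func lossAndGradient(_ model: Perceptron) -> (loss: Float, dWeight: SIMD2<Float>, dBias: Float) {
    var loss: Float = 0
    var dWeight = SIMD2<Float>(0, 0)
    var dBias: Float = 0
    for (x, y) in andGateData {
        let error = y - model(x)
        loss += error * error / 2
        // d(error^2 / 2)/dŷ = -error, dŷ/dweight = x, dŷ/dbias = 1.
        dWeight += -error * x
        dBias += -error
    }
    return (loss, dWeight, dBias)
}

var model = Perceptron()
let initialLoss = lossAndGradient(model).loss
for _ in 0..<100 {
    let (_, dWeight, dBias) = lossAndGradient(model)
    // Same update rule and learning rate as the example above.
    model.weight -= dWeight * 0.02
    model.bias -= dBias * 0.02
}
let finalLoss = lossAndGradient(model).loss
print(initialLoss, finalLoss)
```

Every change to the model or the loss would require re-deriving lossAndGradient by hand; removing that bookkeeping is what compiler-generated derivatives are for.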
Differentiable programming scales up to full machine learning models, built with third-party libraries like TensorFlow.
import TensorFlow
struct Model: Layer {
var layer1 = Dense<Float>(inputSize: 784, outputSize: 100, activation: relu)
var layer2 = Dense<Float>(inputSize: 100, outputSize: 30, activation: relu)
var layer3 = Dense<Float>(inputSize: 30, outputSize: 3, activation: identity)
@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
input.sequenced(through: layer1, layer2, layer3)
}
}
var classifier = Model()
let optimizer = SGD(for: classifier, learningRate: 0.02)
Context.local.learningPhase = .training
let x: Tensor<Float> = ...
let y: Tensor<Int32> = ...
for _ in 0..<1000 {
let 𝛁model = gradient(at: classifier) { classifier -> Tensor<Float> in
let ŷ = classifier(x)
let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
print("Loss: \(loss)")
return loss
}
optimizer.update(&classifier, along: 𝛁model)
}
While the differentiation APIs are flexible and fully dynamic, differentiation is based on a program transformation that happens at compile time. This enables many static analyses that not only help produce more efficient programs, but also detect common numerical programming mistakes such as non-differentiable functions and zero derivatives.
let grad = gradient(at: 1.0) { x in
3.squareRoot()
}
test.swift:2:18: warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to add 'withoutDerivative(at:)' to make it explicit?
3.squareRoot()
^
withoutDerivative(at:)
With a first-class differentiable programming language, some of the most common runtime errors in machine learning become directly debuggable without library boundaries. Simply step through backpropagation using LLDB to debug derivatives.
Backpropagation debugging demo using LLDB.
Motivation
Background
In mathematics, a derivative of a function of a real variable is another function that computes the sensitivity to changes in the output of the original function with respect to changes in the original function's arguments. Differentiation is the process of computing derivatives. See the "Math Introduction" section below for more details.
Derivatives are a fundamental tool in calculus and have applications in many domains, notably deep learning.
Numerical computing in Swift
Swift is an expressive, high-performance language that is a great fit for numerical applications. Recent proposals have paved the way for low-level numerical computing in Swift: AdditiveArithmetic (SE-0233), SIMD (SE-0229, SE-0251), and generic math functions (SE-0246). However, high-level numerical computing applications, including machine learning and artificial intelligence, require more work.
We believe that first-class differentiable programming is a big step towards high-level numerical computing support and will make Swift a real contender in the numerical computing and machine learning landscape. Differentiable programming will enable intelligent applications, machine learning models, scientific experiments, physical simulations, and more.
Intelligent applications
Intelligent applications are smart: they use machine learning techniques to enhance user experiences. Intelligent applications can make predictions, provide suggestions, and learn user preferences: all of these can be powered by differentiable programming.
The core of an intelligent application is a function with real-valued parameters. Differentiation can be used to systematically optimize (i.e. find "good" values for) these parameters via gradient descent. (Optimizing these parameters via conventional algorithms is typically difficult or intractable.)
For example, consider a podcast player that tries to automatically adjust the playback speed based on the podcast type and the podcast section.
enum PodcastCategory {
case comedy
case news
...
}
enum PodcastSection {
case advertisement
case introduction
case body
case conclusion
}
struct PodcastState {
let category: PodcastCategory
let section: PodcastSection
}
struct PodcastSpeedModel {
var minSpeed, maxSpeed: Float
var categoryMultipliers: [PodcastCategory: Float]
var sectionMultipliers: [PodcastSection: Float]
/// Returns a podcast speed multiplier prediction for the given podcast category
/// and section.
func prediction(for state: PodcastState) -> Float {
// Categories or sections without a learned multiplier fall back to 1.
let speed = categoryMultipliers[state.category, default: 1] * sectionMultipliers[state.section, default: 1]
if speed < minSpeed { return minSpeed }
if speed > maxSpeed { return maxSpeed }
return speed
}
}
This podcast speed model has parameters that determine how quickly the podcast should play under different circumstances: minSpeed, maxSpeed, categoryMultipliers, and sectionMultipliers. A priori, it is not clear what good parameter values are, and different users may prefer different parameter values.
An intelligent application could determine personalized parameter values as follows:
- Let the user set the speed manually, and record observations whenever the user changes the speed.
- After collecting enough observations, search for parameter values such that the model predicts speeds close to the user's preferred speed. If such values are found, offer to start automatically setting the speed.
"Gradient descent" is an algorithm that performs this search, and a language that supports differentiable programming makes it easy to implement gradient descent. Here is some pseudocode illustrating gradient descent.
First, we need an objective function for gradient descent to minimize. Mean absolute error is used here:
struct Observation {
var podcastState: PodcastState
var userSpeed: Float
}
func meanError(for model: PodcastSpeedModel, _ observations: [Observation]) -> Float {
var error: Float = 0
for observation in observations {
error += abs(model.prediction(for: observation.podcastState) - observation.userSpeed)
}
return error / Float(observations.count)
}
Next, we implement the gradient descent algorithm.
var model = PodcastSpeedModel()
let observations = storage.observations()
for _ in 0..<1000 {
// The language differentiates `meanError` to get a "gradient", which is a value indicating
// how to change `model` in order to decrease the value of `meanError`.
let 𝛁model = gradient(at: model) { meanError(for: $0, observations) }
// Change `model` in the direction that decreases the value of `meanError`.
model -= 0.01 * 𝛁model
}
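As a rough, runnable companion to the pseudocode above, the following sketch assumes a drastically simplified, hypothetical model whose only parameter is a single speed multiplier k, and writes the gradient of the mean absolute error by hand. Because |k - u| has no derivative at k = u, the sketch uses the sign of k - u as a subgradient. The mean absolute error is minimized by the median of the observed speeds, so k should settle near 1.25 for the made-up data below.

```swift
// Hypothetical simplification: the model is a single speed multiplier `k`,
// and each observation is just the speed the user picked.
let userSpeeds: [Float] = [1.0, 1.25, 1.5]

/// Mean absolute error of predicting `k` for every observation.
func meanError(_ k: Float, _ speeds: [Float]) -> Float {
    speeds.map { abs(k - $0) }.reduce(0, +) / Float(speeds.count)
}

/// A subgradient of `meanError` with respect to `k`: the derivative of |k - u| is the
/// sign of (k - u), and 0 is used at k == u, where |k - u| is not differentiable.
func meanErrorGradient(_ k: Float, _ speeds: [Float]) -> Float {
    var total: Float = 0
    for u in speeds {
        if k > u { total += 1 } else if k < u { total -= 1 }
    }
    return total / Float(speeds.count)
}

var k: Float = 1.0
for _ in 0..<1000 {
    // Step against the (sub)gradient to decrease the mean error.
    k -= 0.01 * meanErrorGradient(k, userSpeeds)
}
print(k, meanError(k, userSpeeds))
```

With a fixed step size, k ends up oscillating within one step of the median rather than landing on it exactly.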
Type-safe machine learning
Today, machine learning is predominantly done in dynamically typed languages like Python: these languages are concise and easy to use. However, some people prefer safer programming: features like type checking and static diagnostics help catch errors early and improve productivity.
Differentiable programming in Swift enables safe, powerful machine learning. Custom differentiable data structures can be declared and checked at compile-time. Thanks to protocol-oriented programming, differentiable types are generalized by a protocol, enabling differential operators to be defined as higher-order functions constrained on such a protocol. Mathematical optimization algorithms such as neural network optimizers can also be defined generically over such a protocol and work with all differentiable types.
Calculus is fun
Calculus is fun, and differentiation in the Swift toolbox will let programmers explore that fun. Here are some interesting applications:
Animations
Easing functions specify the rate of change of parameters for animations. Differentiation enables easy manipulation of these functions.
Games
Physics equations can be modeled using differentiable functions in game engines. Intelligent agents in games can be trained using techniques like machine learning that are enabled by differentiation.
Simulations
Many simulation techniques for fluids and other physical processes are based on approximate solutions to equations defined in terms of derivatives, like the Euler equations and the Navier-Stokes equations. Being able to differentiate functions is an important building block for implementing algorithms to solve these equations.
Robotics
Control algorithms used in robotics and mechanical engineering rely on (often higher-order) derivatives of functions that model the behavior of joints and other physical systems. A language like Swift that can efficiently compute these derivatives without incurring the unpredictable runtime overhead of garbage collection may be well-placed to run aboard robots.
Rendering and ray tracing
Traditional rendering systems are black boxes that consume data structures with scene geometry and produce images, but the physical processes they simulate are made up of differentiable functions. Building a ray tracer out of differentiable building blocks unlocks applications like inverse rendering (going from an image to scene geometry). [1] [2]
History of differentiation algorithms
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Approaches to automatic differentiation
In practice, automatic differentiation is the most common differentiation algorithm because it is precise and efficient. This section summarizes approaches to automatic differentiation.
Embedded domain-specific languages
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Source code transformation tools
Source code transformation tools are another approach to differentiable programming. Tool users write code, select various differentiation configuration options (the name of the function to differentiate, the independent and dependent variables, etc.), and provide them to the tool. The tool analyzes the input code and generates output code that computes derivatives according to the options.
Historically, this is one of the oldest approaches for automatic differentiation. Tools like Tapenade and ADIC/ADIFOR compute derivatives of Fortran and C code.
An advantage of source code transformation tools is that they are essentially static compilers: they can perform static analyses on input code to generate optimized derivative-computing output code. For example, Tapenade performs "activity analysis" to determine variables that do not need a derivative and "TBR (to-be-recorded) analysis" to remove unnecessary intermediate variables during differentiation.
However, these tools are not ideal for usability: users must interact with an external GUI to specify inputs and they receive a textual program as output. This external workflow is an extra indirection that takes users out of their natural programming environment. Exposing the tool-provided differentiation features within a language would be more ergonomic.
Figure: Tapenade web interface. The user specifies an input program and configuration options, and Tapenade generates a derivative-computing output program.
First-class language support
Another class of differentiable programming approaches integrates the differentiation semantics and code transformations into a programming language to some degree. While there are no mainstream programming languages that support differentiable programming, research systems like Stalin∇ add first-class differential operators (e.g. grad) to the language and the reverse-mode automatic differentiation transformation to the compiler.
First-class language support for differentiation can reap the benefits of source code transformation techniques (e.g. language coverage, performant derivative code) without requiring programmers to use an external tool. Well-designed, powerful differentiation primitives enable users to define their own custom differentiation APIs that would otherwise not be possible in differentiation libraries.
Why bake differentiation into Swift?
First-class language support for differentiation will enable convenient, extensible, and performant differentiable programming in Swift.
Maximal coverage of Swift language features
First-class support for differentiation in Swift enables differentiation to work nicely with a maximal number of Swift language features, including mutation and control flow. Users of differentiable programming do not need to write in a restricted subset of Swift: just write normal code and use differentiation.
Extensibility
First-class language support enables an extensible differentiable programming system.
Custom types can be extended to be differentiable with minimal boilerplate. Custom derivative functions can be retroactively registered for existing functions. Users can define custom differentiation APIs using the powerful primitive operators defined in the standard library and supported by the type system.
Static warnings and errors
Some functions perform non-differentiable operations (on the path from parameters to result) and thus cannot be differentiated. Functions that do not use their parameters to compute the result are technically differentiable, but the derivative is trivially always zero.
With language support for differentiation, the compiler can identify these cases statically via data flow analysis and produce a non-differentiability error or warning. These diagnostics improve productivity and help users catch errors ahead of time. Library-based differentiation approaches cannot generally provide these diagnostics.
For details on static warnings and errors, see the "Static analysis" section in the detailed design below.
The pursuit for user-defined code transformations
The key code transformation enabling differentiable programming is "derivative code generation". Derivative code generation implements automatic differentiation: given an "original function" to differentiate, a derivative function is generated by replacing function applications in the original function with corresponding derivative function applications. The algorithm is described in detail in the Swift Differentiable Programming Implementation Overview document.
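As a source-level illustration of this idea only (not of how the compiler actually performs the transformation), the following sketch assumes each differentiable function is paired with a JVP that returns its value together with a differential. The names square, squareJVP, original, and originalJVP are hypothetical. The hand-written originalJVP mirrors original line by line: each call to square becomes a call to squareJVP, and the differentials are composed by the chain rule.

```swift
/// A building block and its hand-written JVP (value plus differential).
func square(_ x: Double) -> Double { x * x }

func squareJVP(_ x: Double) -> (value: Double, differential: (Double) -> Double) {
    // d(x * x)/dx = 2x, so a tangent dx maps to 2x * dx.
    (value: x * x, differential: { dx in 2 * x * dx })
}

/// An "original function" built from `square`: x^4 + x^2.
func original(_ x: Double) -> Double {
    let a = square(x)
    let b = square(a)
    return b + a
}

/// What a generated derivative could look like: each call to `square` is replaced by
/// `squareJVP`, and the returned differentials are chained together.
func originalJVP(_ x: Double) -> (value: Double, differential: (Double) -> Double) {
    let (a, da) = squareJVP(x)
    let (b, db) = squareJVP(a)
    return (value: b + a, differential: { dx in
        let aTangent = da(dx)            // tangent of a = x^2
        return db(aTangent) + aTangent   // tangent of b + a
    })
}

let (value, differential) = originalJVP(2)
print(value, differential(1)) // 20.0 36.0
```

At x = 2, the derivative of x^4 + x^2 is 4 * 2^3 + 2 * 2 = 36, which is what the composed differential returns for a unit tangent.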
Some languages provide the ability to define custom code transformations:
-
Macros enable syntax-based code transformations at compile time. Hygienic macros (macro systems that avoid accidental variable capture) are available in a variety of languages, including Lisp, Julia, Rust, and Scala, to name a few. As an example: generated type-safe schema wrappers can be implemented using hygienic macros in Scala.
-
Compiler plugin systems enable programmers to write plugins that extend the behavior of a compiler. Compiler plugins are more popular in bootstrapped languages, like Haskell, Rust and Scala, where the plugin can be written in the language itself. As an example: a continuation-passing-style code transformation can be implemented as a compiler plugin in Scala.
One might make the case that derivative code generation for differentiation is better implemented as a custom code transformation. While that may be true in theory, Swift does not yet support custom code transformations in practice. This proposal presents differentiable programming as a system of high-level language features and semantics; derivative code generation is an implementation detail. If a system for custom code transformations is added to Swift one day, it may be possible to reimplement derivative code generation using that system without changing the high-level differentiable programming features proposed here.
Math introduction
What is a derivative?
The derivative of a function f measures how quickly the function's output changes when you make small changes to the function's input. The value of this measurement depends on the input x that you start with, and we call the value of the measurement starting at that input "the derivative of f at x".
For a single-variable real function (a function with a single real input and a single real output), the derivative of f at x can be summarized as a single real number f'(x) such that f(x + ε) ≈ f(x) + f'(x) * ε. In other words, changing the input by a tiny amount ε changes the output by approximately f'(x) * ε.
f(x) = x changes by exactly ε whenever you change its input by ε, so its derivative is 1 everywhere.
Near x = 0, f(x) = x^2 changes very little when you change its input, so its derivative at x = 0 is 0 (see orange line).
Near x = 1, f(x) = x^2 changes by approximately 2*ε when you change its input by ε, so its derivative at x = 1 is 2 (see green line).
In general, the derivative of f(x) = x^2 at x is 2*x.
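These statements can be checked numerically. The sketch below compares f'(x) = 2*x against a central finite-difference estimate, and measures the error of the linear approximation f(x + ε) ≈ f(x) + f'(x) * ε, which for x^2 is exactly ε^2 in exact arithmetic. Finite differences serve here only as an independent, approximate check; they are not the automatic differentiation technique this proposal describes, and numericalDerivative is a hypothetical helper.

```swift
/// f(x) = x^2 and its derivative f'(x) = 2 * x.
func f(_ x: Double) -> Double { x * x }
func fPrime(_ x: Double) -> Double { 2 * x }

/// Central finite-difference estimate of the derivative of `g` at `x`.
func numericalDerivative(_ g: (Double) -> Double, at x: Double, epsilon: Double = 1e-6) -> Double {
    (g(x + epsilon) - g(x - epsilon)) / (2 * epsilon)
}

for point in [0.0, 1.0, 3.0] {
    print("x = \(point): f'(x) = \(fPrime(point)), estimate = \(numericalDerivative(f, at: point))")
}

// Error of the linear approximation f(x + ε) ≈ f(x) + f'(x) * ε; for x^2 it is ε^2.
let x0 = 1.0
let ε = 0.01
let approximationError = f(x0 + ε) - (f(x0) + fPrime(x0) * ε)
print(approximationError)
```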
Iterative optimization
Iterative optimization algorithms use derivatives to optimize functions (i.e. find the inputs that minimize or maximize the output of the function). For example, the simple "gradient descent" algorithm starts with an arbitrary input x and uses the derivative of the function at x to determine whether it needs to increase or decrease x to decrease the output of the function. Then it mutates x slightly along the appropriate direction and repeats until the output stops decreasing.
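A minimal sketch of this loop for a single real variable follows. The function g(x) = (x - 3)^2, its hand-written derivative, the learning rate of 0.1, and the step cap are illustrative choices, not part of the proposal. The loop stops as soon as a step no longer decreases the output, matching the description above.

```swift
/// Function to minimize (illustrative choice): g(x) = (x - 3)^2, minimized at x = 3.
func g(_ x: Double) -> Double { (x - 3) * (x - 3) }

/// Its derivative, written by hand: g'(x) = 2 * (x - 3).
func gPrime(_ x: Double) -> Double { 2 * (x - 3) }

func gradientDescent(start: Double, learningRate: Double = 0.1, maxSteps: Int = 10_000) -> Double {
    var x = start
    var value = g(x)
    for _ in 0..<maxSteps {
        // The sign of the derivative says whether to increase or decrease x.
        let candidate = x - learningRate * gPrime(x)
        let candidateValue = g(candidate)
        // Stop once the output no longer decreases.
        if candidateValue >= value { break }
        x = candidate
        value = candidateValue
    }
    return x
}

let xMin = gradientDescent(start: 0)
print(xMin)
```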
Derivatives of functions with arbitrary inputs
Real world programs deal with data more complicated than single real variables. Fortunately, there are mathematical theories that extend derivatives to functions with nearly arbitrary inputs and outputs.
Recall our original description of derivative: "The derivative of a function f measures how quickly the function's output changes when you make small changes to the function's input." This makes sense for arbitrary input and output types, as long as we can describe small changes in them.
It is easy to describe small changes in nested structures of real numbers: they are just small changes in all the components' real numbers. For example, consider:
struct Point {
var x, y: Float
}
struct PointPair {
var p1, p2: Point
}
A small change in Point might be "add 0.01 to x and add 0.02 to y". A small change in PointPair might be "add 0.01 to p1.x and add 0.01 to p2.x".
We can define new types that capture the values of these small changes. We call these types "tangent vectors", a term from math. For example:
extension Point {
struct TangentVector {
// `dx` and `dy` are small changes in `x` and `y`, respectively.
var dx, dy: Float
}
}
extension PointPair {
struct TangentVector {
// `dp1` and `dp2` are small changes in `p1` and `p2`, respectively.
var dp1, dp2: Point.TangentVector
}
}
In terms of these tangent vectors, the small changes that we described in words above would be:
Point.TangentVector(dx: 0.01, dy: 0.02)
PointPair.TangentVector(
dp1: Point.TangentVector(dx: 0.01, dy: 0),
dp2: Point.TangentVector(dx: 0.01, dy: 0))
In terms of tangent vectors, the derivative of a function f: (A) -> B is a function df: (A, A.TangentVector) -> B.TangentVector. In other words, df takes a starting value of type A and a small change A.TangentVector and tells you what the resulting small change in B is.
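As a small worked example of this definition, take the illustrative function f(p) = p.x * p.y on Point. Its derivative in the sense above is df(p, v) = p.y * v.dx + p.x * v.dy, and f(p) + df(p, v) should closely predict the value of f at the moved point. The sketch repeats the Point and Point.TangentVector declarations so that it compiles on its own; f and df are hypothetical names.

```swift
struct Point {
    var x, y: Float
}

extension Point {
    struct TangentVector {
        // `dx` and `dy` are small changes in `x` and `y`, respectively.
        var dx, dy: Float
    }
}

/// Illustrative function f: (Point) -> Float.
func f(_ p: Point) -> Float { p.x * p.y }

/// Its derivative df: (Point, Point.TangentVector) -> Float.TangentVector,
/// where the tangent vector of Float is taken to be Float itself.
func df(_ p: Point, _ v: Point.TangentVector) -> Float {
    p.y * v.dx + p.x * v.dy
}

let p = Point(x: 2, y: 3)
let v = Point.TangentVector(dx: 0.01, dy: 0.02)
let predicted = f(p) + df(p, v)                      // first-order prediction
let actual = f(Point(x: p.x + v.dx, y: p.y + v.dy))  // exact value at the moved point
print(predicted, actual)
```

Here df predicts a change of 3 * 0.01 + 2 * 0.02 = 0.07, while the actual change is 2.01 * 3.02 - 6 = 0.0702; the leftover 0.0002 is the second-order term v.dx * v.dy, which shrinks much faster than v itself.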
The gradient descent iterative optimization algorithm can run on any function f: (A) -> Float as long as A is a type for which we can define a tangent vector. It iteratively walks around different values of A, searching for a value that minimizes the output of f.
Proposed solution
To push Swift's capabilities to the next level in numerics and machine learning, we introduce differentiable programming as a new language feature, which includes standard library APIs and small additive changes to the type system.
The Differentiable protocol
Differentiable is a standard library protocol that generalizes all data structures that can be a parameter or result of a differentiable function. The compiler derives protocol requirement implementations when a @memberwise conformance is declared.
extension Float: Differentiable {
typealias TangentVector = Self
}
struct Perceptron: @memberwise Differentiable {
var weight: SIMD64<Float>
var bias: Float
}
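To give a feel for what a derived conformance might contain, the sketch below writes one out by hand for a two-element version of Perceptron. It uses a hypothetical, simplified protocol named DifferentiableSketch that only assumes an associated TangentVector type and a move(along:) method; the real requirements of Differentiable are specified in the full proposal and may differ. The hand-written TangentVector mirrors the stored properties one for one, which is the kind of boilerplate a @memberwise conformance is meant to remove.

```swift
/// Hypothetical, simplified stand-in for the proposed `Differentiable` protocol.
protocol DifferentiableSketch {
    associatedtype TangentVector: AdditiveArithmetic
    /// Moves `self` by the small change `direction`.
    mutating func move(along direction: TangentVector)
}

extension Float: DifferentiableSketch {
    typealias TangentVector = Float
    mutating func move(along direction: Float) { self += direction }
}

struct Perceptron {
    var weight: SIMD2<Float>
    var bias: Float
}

// Hand-written version of the kind of conformance a `@memberwise` declaration is meant to derive.
extension Perceptron: DifferentiableSketch {
    /// One stored property per stored property of `Perceptron`.
    struct TangentVector: AdditiveArithmetic {
        var weight: SIMD2<Float>
        var bias: Float

        static var zero: TangentVector { TangentVector(weight: .zero, bias: 0) }
        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight + rhs.weight, bias: lhs.bias + rhs.bias)
        }
        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight - rhs.weight, bias: lhs.bias - rhs.bias)
        }
    }

    // Moving a Perceptron moves each stored property by the matching component.
    mutating func move(along direction: TangentVector) {
        weight += direction.weight
        bias.move(along: direction.bias)
    }
}

var perceptron = Perceptron(weight: [0.5, -0.5], bias: 0)
perceptron.move(along: Perceptron.TangentVector(weight: [0.25, 0.25], bias: 1))
print(perceptron.weight, perceptron.bias)
```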
The @differentiable declaration attribute
The @differentiable declaration attribute marks function-like declarations (function declarations, initializers, properties, and subscripts) as being differentiable.
@differentiable
func cubed(_ x: Float) -> Float {
x * x * x
}
extension Perceptron {
@differentiable
func callAsFunction(_ input: SIMD64<Float>) -> Float {
(weight * input).sum() + bias
}
}
@differentiable function types
@differentiable function types are a subtype of normal function types with a different runtime representation, which stores metadata that allows their values to be differentiated anywhere.
func addOne(_ x: Float) -> Float { x + 1 }
let _: @differentiable (Float) -> Float = addOne
let _: @differentiable(linear) (Float) -> Float = addOne
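One way to picture this different runtime representation is a value that bundles the original function with a derivative function. The following sketch assumes that the extra metadata is a single JVP-style derivative; DifferentiableFunctionSketch and differentiableAddOne are hypothetical names, and the actual representation is described in the full proposal.

```swift
/// Hypothetical model of a `@differentiable (Float) -> Float` value: the original function
/// stored together with a derivative function (here a JVP returning a value and a differential).
struct DifferentiableFunctionSketch {
    let original: (Float) -> Float
    let jvp: (Float) -> (value: Float, differential: (Float) -> Float)

    // Calling the bundle behaves like calling the original function.
    func callAsFunction(_ x: Float) -> Float { original(x) }
}

func addOne(_ x: Float) -> Float { x + 1 }

// Building the bundle requires supplying the derivative: d(x + 1)/dx = 1,
// so the differential passes tangents through unchanged.
let differentiableAddOne = DifferentiableFunctionSketch(
    original: addOne,
    jvp: { x in (value: addOne(x), differential: { dx in dx }) }
)

print(differentiableAddOne(2), differentiableAddOne.jvp(2).differential(5))
```

Anything that only needs to call the function can use original alone, which is one way to see why a @differentiable function type can be treated as a subtype of the corresponding normal function type.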
@differentiating and @transposing attributes
The @differentiating and @transposing attributes are used to declare custom derivative functions for other function declarations.
#if canImport(Glibc)
import Glibc
#else
import Darwin
#endif
@differentiating(expf)
func _(_ x: Float) -> (value: Float,
differential: @differentiable(linear) (Float) -> Float) {
let y = expf(x)
return (value: y, differential: { v in v * y })
}
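The differential above multiplies the incoming tangent v by y because the derivative of e^x is e^x itself. The plain-Swift sketch below checks that reasoning numerically without the @differentiating attribute: it uses Double and Foundation's exp in place of expf, and expWithDifferential is a hypothetical name.

```swift
import Foundation

/// Plain-Swift counterpart of the custom derivative above: value plus a differential
/// that scales a tangent by y = e^x, since d/dx e^x = e^x.
func expWithDifferential(_ x: Double) -> (value: Double, differential: (Double) -> Double) {
    let y = exp(x)
    return (value: y, differential: { v in v * y })
}

let x0 = 0.5
let (value, differential) = expWithDifferential(x0)

// Compare with a central finite-difference estimate of the derivative at x0.
let h = 1e-6
let estimate = (exp(x0 + h) - exp(x0 - h)) / (2 * h)
print(value, differential(1), estimate)
```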
Differential operators
Differential operators are standard library differentiation APIs that take @differentiable functions and return derivative functions or compute derivative values.
// In the standard library:
//
// func derivative<T: FloatingPoint, R>(
// of body: @escaping @differentiable (T) -> R
// ) -> (T) -> R where T.TangentVector: FloatingPoint
@differentiable
func f(_ x: Float) -> Float {
x * x
}
let dfdx = derivative(of: f)
dfdx(3) // 6
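For intuition about how an operator like derivative(of:) relates to derivative functions, the following sketch builds a derivative from a JVP by evaluating the differential on a unit tangent: for a single real input, differential(1) equals f'(x). It is a hypothetical stand-in that takes the JVP explicitly, since plain Swift cannot recover one from an ordinary closure; derivativeSketch and squareJVP are made-up names, not standard library APIs.

```swift
/// Hypothetical sketch of a `derivative(of:)`-style operator for functions that come
/// with a JVP: evaluating the differential on a unit tangent yields f'(x).
func derivativeSketch(
    of jvp: @escaping (Float) -> (value: Float, differential: (Float) -> Float)
) -> (Float) -> Float {
    return { x in jvp(x).differential(1) }
}

// Hand-written JVP of f(x) = x * x.
func squareJVP(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
    (value: x * x, differential: { dx in 2 * x * dx })
}

let dfdx = derivativeSketch(of: squareJVP)
print(dfdx(3)) // 6.0
```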
Detailed design
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Examples of differentiable programming
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Future directions
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Source compatibility
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Alternatives considered
This section is abridged! Please see the corresponding section in the external Markdown linked above.
Acknowledgements
This section is abridged! Please see the corresponding section in the external Markdown linked above.