I’ve been thinking about adding support for reverse-mode automatic differentiation in Scala and have been running into some issues on which you may be able to offer advice. The core problem I want to discuss here is the following: given a function of type T => R, we want to synthesize a function of type T => (R, TangentR => TangentT) that returns its result together with a function that computes its vector-Jacobian product (VJP). We want to be able to use the synthesized function from within the same project. I have considered three ways of doing that in dotty:
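To make the target signature concrete, here is a hand-written instance of what such a synthesized function could look like for the example function f(x, y) = x * y + sin(x), which is my own choice of example. The type alias `Vjp` and the names `f` and `fVjp` are illustrative only, and all tangents are plain `Double`s.

```scala
object VjpSignature {
  // Target shape: T => (R, TangentR => TangentT).
  // Here T = (Double, Double), R = Double, TangentR = Double,
  // and TangentT = (Double, Double).
  type Vjp[T, R, TangentR, TangentT] = T => (R, TangentR => TangentT)

  // Original function: f(x, y) = x * y + sin(x).
  def f(in: (Double, Double)): Double = {
    val (x, y) = in
    x * y + math.sin(x)
  }

  // Hand-written version of what an AD system would synthesize from `f`.
  val fVjp: Vjp[(Double, Double), Double, Double, (Double, Double)] = { in =>
    val (x, y) = in
    val result = x * y + math.sin(x)
    // Pullback: maps an output cotangent `dr` to the input cotangents
    // (dr * df/dx, dr * df/dy) = (dr * (y + cos x), dr * x).
    val pullback = (dr: Double) => (dr * (y + math.cos(x)), dr * x)
    (result, pullback)
  }
}
```

Returning the pullback as a closure, rather than writing gradients into mutable fields, is what makes the result composable: the pullback of a composite function is just the reversed composition of the pullbacks of its parts.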
Wrapper Types and Operator Overloading: We can create a wrapper type Variable[T] that contains an accumulator for the gradient of this variable, along with overloaded operators for operations on variables. This approach cannot really return a VJP function, because it relies on side effects to compute gradients; for the same reason, it also cannot be used to compute higher-order derivatives. I also implemented some early prototypes of this approach, and their performance was significantly worse than that of the alternative presented next.
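A minimal sketch of the wrapper-type approach, assuming scalar `Double` values only (so it is not generic in `T` like the `Variable[T]` above). The fields, the `backward` method and the companion `sin` are hypothetical. Each node records the local partial derivatives with respect to its inputs, and the gradient lives in a mutable `grad` field, which is exactly the side effect that prevents returning a pure VJP function.

```scala
import scala.collection.mutable

// Hypothetical wrapper: a value, a mutable gradient accumulator, and the
// local partial derivatives with respect to each parent node.
final class Variable(val value: Double, val parents: List[(Variable, Double)] = Nil) {
  var grad: Double = 0.0

  def +(that: Variable): Variable =
    new Variable(value + that.value, List((this, 1.0), (that, 1.0)))

  def *(that: Variable): Variable =
    new Variable(value * that.value, List((this, that.value), (that, value)))

  // Reverse sweep: order the graph topologically (DFS post-order), then walk
  // it backwards, adding each node's contribution to its parents' `grad`.
  def backward(): Unit = {
    val order = mutable.ArrayBuffer.empty[Variable]
    val seen = mutable.Set.empty[Variable]
    def visit(v: Variable): Unit =
      if (!seen(v)) {
        seen += v
        v.parents.foreach(p => visit(p._1))
        order += v
      }
    visit(this)
    grad = 1.0
    order.reverseIterator.foreach { v =>
      v.parents.foreach { case (p, local) => p.grad += local * v.grad }
    }
  }
}

object Variable {
  // New primitives are ordinary Scala functions on the wrapper type.
  def sin(v: Variable): Variable =
    new Variable(math.sin(v.value), List((v, math.cos(v.value))))
}
```

For example, with `x = new Variable(2.0)` and `y = new Variable(3.0)`, calling `(x * x + x * y).backward()` leaves `x.grad == 7.0` and `y.grad == 2.0`. Calling `backward()` a second time would add to those values instead of recomputing them, which illustrates why this style is awkward for repeated evaluation and higher-order derivatives.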
Source Code Transformation: This would require adding methods to an existing class/object, which cannot be done with macros in dotty. Furthermore, the transformation needs to happen before type-checking so that the synthesized methods can be used by code in the same project. This makes the solution impossible without forking the dotty compiler, which is not desirable. If you have any advice on this, that would also be greatly appreciated.
Staging: This is a solution that may be plausible, but I do not really know how to go about it. Say we define a function vjp(fun: T => R, at: T): (R, TangentR => TangentT) that traces the execution of the provided function at the provided point and returns its result along with a VJP function. The VJP function would be computed by transforming the expression obtained from tracing the execution of fun. The resulting expression would then need to be staged and compiled at runtime. Ideally, however, we would compile that VJP expression only once for a given function fun and value at. I’m not sure whether this is currently possible in dotty. I’m also not sure how different values of at should be handled, since ideally we do not want to compile the gradient function multiple times. However, given that the control flow of the traced function may depend on the value of at, re-tracing and recompiling the VJP may be desirable. In that case, how expensive would the tracing and recompiling be? I don’t really have a sense of the cost and wonder whether it is worth it. @biboudis, @smarter mentioned you may be able to advise on this.
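The tracing half of the staging idea could look roughly like the sketch below. It assumes `fun` is written against a dedicated `Traced` type, so that running it at `at` records a linear trace (a Wengert list). `Traced`, `Tape`, `Op` and this `vjp` signature are hypothetical; `vjp` specializes T to `Vector[Double]` and R to `Double`. The staging and runtime-compilation step (for example via `scala.quoted.staging`) is deliberately left out: the trace is interpreted here instead of being compiled.

```scala
import scala.collection.mutable.ArrayBuffer

object Tracing {
  // One recorded operation; operands are indices into the tape.
  sealed trait Op
  final case class Add(a: Int, b: Int) extends Op
  final case class Mul(a: Int, b: Int) extends Op
  final case class Sin(a: Int) extends Op

  // The tape stores each intermediate value and the op that produced it
  // (None marks an input).
  final class Tape {
    val values = ArrayBuffer.empty[Double]
    val ops = ArrayBuffer.empty[Option[Op]]
    def push(v: Double, op: Option[Op]): Traced = {
      values += v
      ops += op
      Traced(values.size - 1, this)
    }
  }

  // Traced values compute eagerly and record what they did.
  final case class Traced(id: Int, tape: Tape) {
    def value: Double = tape.values(id)
    def +(o: Traced): Traced = tape.push(value + o.value, Some(Add(id, o.id)))
    def *(o: Traced): Traced = tape.push(value * o.value, Some(Mul(id, o.id)))
  }
  def sin(t: Traced): Traced = t.tape.push(math.sin(t.value), Some(Sin(t.id)))

  // Traces `fun` at `at` and returns the result plus a pullback derived
  // from the trace. A staged version would compile the trace instead.
  def vjp(fun: Vector[Traced] => Traced, at: Vector[Double]): (Double, Double => Vector[Double]) = {
    val tape = new Tape
    val inputs = at.map(v => tape.push(v, None))
    val out = fun(inputs)
    val values = tape.values.toVector
    val ops = tape.ops.toVector
    val pullback = (seed: Double) => {
      val adj = Array.fill(values.size)(0.0)
      adj(out.id) = seed
      // Ops only refer to earlier entries, so reverse index order is a
      // valid reverse topological order.
      for (i <- (values.size - 1) to 0 by -1) {
        ops(i) match {
          case Some(Add(a, b)) =>
            adj(a) += adj(i)
            adj(b) += adj(i)
          case Some(Mul(a, b)) =>
            adj(a) += adj(i) * values(b)
            adj(b) += adj(i) * values(a)
          case Some(Sin(a)) =>
            adj(a) += adj(i) * math.cos(values(a))
          case None => ()
        }
      }
      inputs.map(t => adj(t.id))
    }
    (out.value, pullback)
  }
}
```

Because `fun` runs on concrete values, any branch inside it is resolved at trace time, so the recorded trace is only valid for inputs that take the same branches. A cache keyed only on `fun` would have to account for that, which is the re-tracing trade-off raised above.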
Any other ideas or thoughts on this would be greatly appreciated. Source code transformations for automatic differentiation have previously been implemented successfully in Swift (see e.g., here for details).
Use strategy 2, but instead of insisting that the whole method must be synthesized, you require the developer to explicitly write down its signature, while its body is just a call to a macro that will perform the automatic derivation and synthesize the code.
I like the approach explained in 1). You might be interested in having a look at Noel Welsh’s implementation: https://github.com/noelwelsh/fdl
If performance is a concern, I recommend using staging on top of it, as you mentioned in 3), to generate specialized code for a fixed function fun and value at. You can draw inspiration from the paper Demystifying Differentiable Programming (Fei Wang et al.) and the project Lantern.
I like this proposal. The one issue I see is that, while synthesizing the body of the method, I would need to know the names of the VJPs of all functions used in that method. This means that the naming scheme for these synthesized methods should be standardized. Is it OK to expect the developer to specify the correct name? If so, can I check the name of the method during macro expansion and throw an error if it is not in the standard format?
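To illustrate the naming question, here is a hand-written sketch of code that such a macro might produce, assuming a convention in which the VJP of a method `foo` is called `fooVjp`. `derive` is a hypothetical macro name, not an existing API, and the expansion of `gVjp` below is written by hand. Regarding the name check itself: in a Scala 3 macro, `quotes.reflect.Symbol.spliceOwner` and its `owner` chain should give access to the enclosing method, whose name could then be validated with `report.error`. I have not tested this, so treat it as an assumption.

```scala
// Assumed convention: the VJP of method `foo` is named `fooVjp`.
object NamingConvention {
  def f(x: Double): Double = math.sin(x)

  // Hand-written (or previously synthesized) VJP of `f`.
  def fVjp(x: Double): (Double, Double => Double) =
    (math.sin(x), (dr: Double) => dr * math.cos(x))

  def g(x: Double): Double = f(x) * x

  // What a hypothetical macro might generate for
  //   def gVjp(x: Double): (Double, Double => Double) = derive(g)
  // The call to `f` inside `g` is replaced by a call to `fVjp`,
  // found purely through the naming convention.
  def gVjp(x: Double): (Double, Double => Double) = {
    val (a, aPullback) = fVjp(x)
    val result = a * x
    val pullback = (dr: Double) => {
      val da = dr * x        // d(a * x)/da
      val dxDirect = dr * a  // d(a * x)/dx through the direct use of x
      // x flows into the result along two paths; their contributions add up.
      aPullback(da) + dxDirect
    }
    (result, pullback)
  }
}
```

The convention matters because the macro only sees a call to `f` in the body of `g`; without a predictable name it has no way to find the corresponding VJP.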
I like this method because, in principle, I can later add a code preprocessing/generation phase that, for each method annotated for differentiation, generates a new method that calls the relevant macro.
Thanks for your response! I have provided some replies inline.
I just skimmed through the code of this project, and it seems to me that the reverse-mode auto-diff code is not correct. Specifically, consider a case where a variable is used in multiple parts of an expression (e.g., x*x + x*y). How are the gradients of that variable accumulated over the whole expression? This example also shows why this approach typically requires some mutable state in the variable itself to accumulate its gradients.
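As a point of comparison, the accumulation itself does not strictly require mutable state inside the variable. In the sketch below, a toy expression tree of my own with hypothetical names, `backward` returns a map of gradient contributions, and contributions to the same variable are summed when the maps are merged. For x*x + x*y at x = 2, y = 3, this gives a gradient of 2x + y = 7 with respect to x and x = 2 with respect to y.

```scala
object PureAccumulation {
  sealed trait Expr
  final case class Var(name: String) extends Expr
  final case class Add(a: Expr, b: Expr) extends Expr
  final case class Mul(a: Expr, b: Expr) extends Expr

  def eval(e: Expr, env: Map[String, Double]): Double = e match {
    case Var(n)    => env(n)
    case Add(a, b) => eval(a, env) + eval(b, env)
    case Mul(a, b) => eval(a, env) * eval(b, env)
  }

  // Merge two gradient maps, summing contributions for the same variable.
  private def merge(m1: Map[String, Double], m2: Map[String, Double]): Map[String, Double] =
    m2.foldLeft(m1) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0.0) + v) }

  // Propagate the output cotangent `seed` down the tree and return the
  // gradient contributions for every variable reached.
  def backward(e: Expr, env: Map[String, Double], seed: Double): Map[String, Double] = e match {
    case Var(n)    => Map(n -> seed)
    case Add(a, b) => merge(backward(a, env, seed), backward(b, env, seed))
    case Mul(a, b) =>
      merge(backward(a, env, seed * eval(b, env)), backward(b, env, seed * eval(a, env)))
  }
}
```

This naive tree walk re-evaluates shared subexpressions, so it can become exponential on graphs with heavy sharing; a practical version would operate on a DAG or a tape.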
As you may have noticed, I did look into Lantern before. However, the main issue there is that it relies on the delimited continuations plugin, which is no longer maintained. Also, there is no delimited continuations plugin for dotty at all. I could implement continuations manually using function callbacks, but this approach still requires shared mutable state in the variable wrapper types, which is not ideal because it can lead to data races when used in a parallel setting.
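For reference, the callback variant mentioned above could look like the following sketch: Lantern-style CPS reverse mode with explicit continuations instead of shift/reset. `Num`, `plus`, `times` and `grad` are hypothetical names. Each operation runs the rest of the computation `k` first and propagates gradients once `k` returns; the mutable `d` field is the shared state referred to above.

```scala
object CpsReverse {
  // Value `x` plus a mutable adjoint `d`.
  final class Num(val x: Double) {
    var d: Double = 0.0

    def plus(that: Num)(k: Num => Unit): Unit = {
      val y = new Num(x + that.x)
      k(y)            // run the rest of the program forward (and back)
      this.d += y.d   // then propagate y's adjoint to the operands
      that.d += y.d
    }

    def times(that: Num)(k: Num => Unit): Unit = {
      val y = new Num(x * that.x)
      k(y)
      this.d += that.x * y.d
      that.d += this.x * y.d
    }
  }

  // Derivative of a scalar function written in CPS: the final
  // continuation seeds the output adjoint with 1.0.
  def grad(f: Num => (Num => Unit) => Unit)(x: Double): Double = {
    val in = new Num(x)
    f(in)(out => out.d = 1.0)
    in.d
  }
}
```

For example, f(x) = x * x + x is written as `x => k => x.times(x) { xx => xx.plus(x)(k) }`, and `CpsReverse.grad` of it at 3.0 evaluates to 2 * 3 + 1 = 7.0.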
I also want to point out that approach 1) makes it hard to work out a solution for higher-order differentiation, which is something that I would like to support.
For my project, I’ve opted for 1), the wrapper type. This makes it easy to add new functions and/or data types, as these can be handled in Scala code rather than macro code. These then compose easily, whereas for source code transformation the whole graph needs to be accessible to the macro.
Maybe a mixture could yield the best of both, though: a macro that allows one to create such functions out of a handful of predefined operations. That should certainly be good for performance, as the number of intermediate objects can be greatly reduced.
A good AD library will need not only AD but also operator fusion, especially when running on a GPU.
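As a rough illustration of what operator fusion buys, the CPU-only sketch below composes elementwise operations and applies them in a single loop, instead of allocating an intermediate array for each operation. The names `Lazy` and `lift` are hypothetical and unrelated to the API of any library mentioned in this thread. GPU fusion applies the same idea to kernels: one launch instead of several, with no intermediate buffers in device memory.

```scala
object Fusion {
  // A pending elementwise computation: source data plus a fused function.
  final case class Lazy(src: Array[Double], f: Double => Double) {
    // Fuse another elementwise op by function composition; nothing is computed yet.
    def map(g: Double => Double): Lazy = Lazy(src, f.andThen(g))

    // Materialize once: a single pass over the data, a single output array.
    def run(): Array[Double] = {
      val out = new Array[Double](src.length)
      var i = 0
      while (i < src.length) {
        out(i) = f(src(i))
        i += 1
      }
      out
    }
  }

  def lift(a: Array[Double]): Lazy = Lazy(a, x => x)
}
```

For instance, `Fusion.lift(Array(1.0, 2.0, 3.0)).map(_ * 2).map(_ + 1).run()` performs both operations in one loop and yields 3.0, 5.0, 7.0.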
I created DeepLearning.scala to perform compile-time reverse AD based on RAII Monads, with the help of Each, which performs general CPS translation. I also created Compute.scala for operator fusion at runtime (aka JIT).