I’ve been thinking about adding support for reverse-mode automatic differentiation in Scala and have been running into some issues on which you may be able to provide advice. The core problem I want to discuss here is the following: given a function of type `T => R`, we want to synthesize a function of type `T => (R, TangentR => TangentT)` that returns the original result together with a function that computes its vector-Jacobian product (VJP). We want to be able to use the synthesized function from the same project. I have considered a couple of ways of doing that in dotty:
- Wrapper Types and Operator Overloading: We can create a wrapper type `Variable[T]` that contains an accumulator for the gradient of the variable, along with overloaded operators for operations on variables. This approach cannot really return a VJP function because it relies on side effects to compute gradients, and for the same reason it cannot be used to compute higher-order derivatives. I also implemented some early prototypes of this approach, and its performance is significantly worse than that of the alternative presented next.
- Source Code Transformation: This would require adding methods to an existing class/object, which cannot be done with macros in dotty. Furthermore, the transformation needs to happen before type checking so that the synthesized methods can be used by code in the same project. This makes the solution impossible without forking the dotty compiler, which is not desirable. If you have any advice on this, it would also be greatly appreciated.
- Staging: This is a solution that may be plausible, but I do not really know how to go about it. Say we define a function `vjp(fun: T => R, at: T): (R, TangentR => TangentT)` that traces the execution of the provided function at the provided point and returns its result along with a VJP function. The VJP function would be computed by taking the expression that results from tracing the execution of `fun` and transforming it. The resulting expression would need to be staged and compiled at runtime. However, ideally we only want to compile that VJP expression once for a given function and point `at`. I’m not sure if this is currently possible in dotty. I’m also not sure how different values of `at` should be handled, as ideally we do not want to compile the gradient function multiple times. Given, however, that the control flow of the traced function may depend on the value of `at`, re-tracing and recompiling the VJP may be desirable. In that case, how expensive would the tracing and recompiling be? I don’t really have an idea of the cost and wonder if it is worth it. @biboudis, @smarter mentioned you may be able to advise on this.
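To make the first approach concrete, here is a minimal sketch of the wrapper-type idea for scalar `Double`s, using a tape of closures as the side-effecting gradient accumulator. All names here (`Autodiff`, `Variable`, `backward`) are illustrative, not an existing API:

```scala
import scala.collection.mutable.ArrayBuffer

object Autodiff {
  // Tape of closures; each one propagates gradients from an op's output to its inputs.
  private val tape = ArrayBuffer.empty[() => Unit]

  final class Variable(val value: Double) {
    var grad: Double = 0.0  // gradient accumulator, mutated during the backward pass

    def +(that: Variable): Variable = {
      val out = new Variable(value + that.value)
      tape += (() => { this.grad += out.grad; that.grad += out.grad })
      out
    }

    def *(that: Variable): Variable = {
      val out = new Variable(value * that.value)
      tape += (() => {
        this.grad += out.grad * that.value  // d(x*y)/dx = y
        that.grad += out.grad * this.value  // d(x*y)/dy = x
      })
      out
    }
  }

  // Seed the output gradient and replay the tape in reverse.
  def backward(output: Variable): Unit = {
    output.grad = 1.0
    tape.reverseIterator.foreach(_.apply())
    tape.clear()
  }
}
```

For `y = x * x + x` at `x = 3`, calling `Autodiff.backward(y)` leaves `x.grad == 7.0`. Because the gradient lives in mutable state rather than in a returned `TangentR => TangentT` function, there is nothing left to differentiate a second time, which is why higher-order derivatives are out of reach with this design.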
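To make the staging discussion concrete as well, here is a sketch of the tracing step over a tiny scalar expression IR. A real implementation would splice the transformed expression into quoted code and compile it, ideally once per function; here the expressions are simply interpreted, and all names (`Expr`, `pullback`, etc.) are made up for illustration:

```scala
sealed trait Expr
case object X extends Expr     // the traced input
case object Seed extends Expr  // the upstream cotangent (the v in the VJP)
final case class Const(v: Double) extends Expr
final case class Add(a: Expr, b: Expr) extends Expr
final case class Mul(a: Expr, b: Expr) extends Expr

def eval(e: Expr, x: Double, seed: Double): Double = e match {
  case X          => x
  case Seed       => seed
  case Const(v)   => v
  case Add(a, b)  => eval(a, x, seed) + eval(b, x, seed)
  case Mul(a, b)  => eval(a, x, seed) * eval(b, x, seed)
}

// Reverse-mode transform: builds an Expr computing upstream * d(e)/dX.
def pullback(e: Expr, upstream: Expr): Expr = e match {
  case X               => upstream
  case Seed | Const(_) => Const(0.0)
  case Add(a, b)       => Add(pullback(a, upstream), pullback(b, upstream))
  case Mul(a, b)       => Add(pullback(a, Mul(upstream, b)),
                              pullback(b, Mul(upstream, a)))
}

// Shape of the vjp function discussed above, for T = R = Double.
def vjp(fun: Expr => Expr, at: Double): (Double, Double => Double) = {
  val traced = fun(X)                  // tracing: run fun once on the IR
  val grad   = pullback(traced, Seed)  // the expression we would stage/compile
  (eval(traced, at, 0.0), v => eval(grad, at, v))
}
```

For `vjp(x => Mul(x, x), 3.0)` this yields the value `9.0` and a pullback mapping `v` to `6.0 * v`. Note that because `fun` here operates on `Expr` rather than on concrete values, the trace cannot depend on the value of `at`; capturing `at`-dependent control flow is exactly what would force the re-tracing and recompilation discussed above.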
Any other ideas/thoughts on this would be greatly appreciated. Source code transformations for automatic differentiation have previously been implemented successfully in Swift (see, e.g., here for details).