Take the gradient of any Wasm function. (Warning: work in progress.)
Floretta uses reverse-mode automatic differentiation to transform a Wasm module, converting every function into two:
- The forward pass, which performs the same computation as the original function, but stores some extra data along the way, referred to as the tape.
- The backward pass, which uses the tape stored by the forward pass to retrace its steps in reverse.
Together, these comprise the vector-Jacobian product (VJP), which can then be used to compute the gradient of any function that returns a scalar.
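To make the forward/backward split concrete, here is a hand-written JavaScript analogue for f(x, y) = x * y + x * x. It is a sketch only. It assumes the tape behaves like a stack. The names forward, backward, and tape are hypothetical, and this is not the code Floretta emits.

```javascript
// Hypothetical, hand-written analogue of a reverse-mode AD transform for
// f(x, y) = x * y + x * x. Not Floretta's output; names are illustrative.

const tape = []; // values the forward pass saves for the backward pass

// Forward pass: computes f(x, y) and records the operands that the
// multiplications will need later.
function forward(x, y) {
  tape.push(x, y);
  const a = x * y;
  const b = x * x;
  return a + b;
}

// Backward pass: takes the adjoint of the output, replays the operations
// in reverse order while popping the tape, and returns [dx, dy].
function backward(dOut) {
  const y = tape.pop();
  const x = tape.pop();
  // out = a + b: both summands receive the output adjoint unchanged.
  const dA = dOut;
  const dB = dOut;
  // b = x * x: x is used twice, so both contributions accumulate.
  let dX = dB * x + dB * x;
  // a = x * y
  dX += dA * y;
  const dY = dA * x;
  return [dX, dY];
}

console.log(forward(3, 4)); // 21
console.log(backward(1));   // [ 10, 3 ]
```

Seeding the backward pass with 1 turns the VJP into the gradient of this scalar-valued function. Here it yields (y + 2x, x) = (10, 3) at (3, 4).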
For every memory in the original Wasm module, Floretta adds an additional memory in the transformed module to store the derivative of each scalar in the original memory. Floretta also adds two more memories to store the tape: one for f32 values and one for f64 values.
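The role of these extra memories can be pictured with typed arrays standing in for Wasm linear memories. In this sketch all names are hypothetical, and the actual layout Floretta uses is not described here. `dMemory` shadows `memory` cell for cell, and `tapeF64` plays the part of the f64 tape memory. An f32 tape would work the same way with a Float32Array.

```javascript
// Conceptual sketch (hypothetical names, not Floretta's generated code):
// an f64 "linear memory", the derivative memory that shadows it cell for
// cell, and a separate f64 tape memory with a stack pointer.

const memory = new Float64Array(4);   // original memory
const dMemory = new Float64Array(4);  // derivative of each scalar in memory
const tapeF64 = new Float64Array(16); // tape for f64 values
let tapeTop = 0;                      // next free slot on the tape

// Forward pass: sum of squares of memory[0..n), saving each loaded value.
function sumSquaresForward(n) {
  let acc = 0;
  for (let i = 0; i < n; i++) {
    const x = memory[i];
    tapeF64[tapeTop++] = x; // record the value the multiply will need
    acc += x * x;
  }
  return acc;
}

// Backward pass: walk the loop in reverse, popping the tape and adding each
// contribution into the derivative memory (+= because several loads could
// read the same address).
function sumSquaresBackward(n, dOut) {
  for (let i = n - 1; i >= 0; i--) {
    const x = tapeF64[--tapeTop];
    dMemory[i] += 2 * x * dOut;
  }
}

memory.set([1, 2, 3, 4]);
console.log(sumSquaresForward(4)); // 30
sumSquaresBackward(4, 1);
console.log(Array.from(dMemory)); // [ 2, 4, 6, 8 ]
```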
The easiest way to use Floretta is via the command line. If you have Rust installed, you can build the latest version of Floretta from source:
$ cargo install --locked floretta-cli
Use the --help flag to see all available CLI arguments:
$ floretta --help
For example, if you create a file called square.wat with these contents:
(module
  (func (export "square") (param f64) (result f64)
    (f64.mul (local.get 0) (local.get 0))))
Then you can use Floretta to generate the backward pass of the "square" function and export it under the name "backprop":
$ floretta square.wat --export square backprop --output gradient.wasm
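For this particular function, the backward pass can be worked out by hand. This is the standard reverse-mode derivation, not something read off Floretta's output. Write $\bar{y}$ for the adjoint passed to the backward pass (the argument of "backprop"). Because f64.mul reads x through both of its operands a and b, reverse mode adds one contribution per operand:

```latex
y = a \cdot b, \quad a = b = x
\qquad\Longrightarrow\qquad
\bar{a} = \bar{y}\, b, \quad \bar{b} = \bar{y}\, a, \quad
\bar{x} = \bar{a} + \bar{b} = 2x\,\bar{y}
```

So for x = 3 and $\bar{y} = 1$ the backward pass should return 6. It also needs the value of x from the forward pass, which is the kind of data the tape holds.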
Finally, if you have a Wasm engine, you can use it to compute a gradient with the emitted Wasm binary by running the forward pass followed by the backward pass. For instance, if you have Node.js installed, you can create a file called gradient.mjs with these contents:
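import fs from "node:fs/promises";

// Load and instantiate the binary emitted by Floretta.
const wasm = await fs.readFile("gradient.wasm");
const { instance } = await WebAssembly.instantiate(wasm);
const { square, backprop } = instance.exports;

console.log(square(3)); // forward pass: 3 * 3
console.log(backprop(1)); // backward pass seeded with 1: d(x * x)/dx at x = 3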
And run it like this:
$ node gradient.mjs
9
6
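Because a gradient is just "forward, then backward seeded with 1", the pair of exports can be wrapped in a small helper and checked numerically. The helper makeGrad and the finite-difference check are hypothetical and not part of Floretta. Pure-JS stand-ins replace the Wasm exports so the sketch runs on its own.

```javascript
// Hypothetical helper (not part of Floretta): turns a forward/backward
// pair for a one-argument scalar function into a gradient function by
// running the forward pass and then seeding the backward pass with 1.
function makeGrad(forward, backward) {
  return (x) => {
    forward(x);         // runs the computation and fills the tape
    return backward(1); // output adjoint of 1 gives df/dx
  };
}

// Central finite difference, a cheap numerical sanity check.
function finiteDifference(f, x, h = 1e-6) {
  return (f(x + h) - f(x - h)) / (2 * h);
}

// Pure-JS stand-ins that mimic the "square"/"backprop" exports, so this
// sketch runs without gradient.wasm.
let saved = 0;
function square(x) {
  saved = x; // plays the role of the tape
  return x * x;
}
function backprop(dy) {
  return 2 * saved * dy;
}

const grad = makeGrad(square, backprop);
console.log(grad(3)); // 6
console.log(finiteDifference((x) => x * x, 3)); // approximately 6
```

With the real module, you would pass the square and backprop exports loaded in gradient.mjs instead of the stand-ins. This assumes, as in the example above, that each backward call uses the tape left by the immediately preceding forward call.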
Floretta is licensed under the MIT License.