
samestep/floretta


Floretta

Take the gradient of any Wasm function. (Warning: work in progress.)

Floretta uses reverse-mode automatic differentiation to transform a Wasm module, converting every function into two:

  1. The forward pass, which performs the same computation as the original function, but stores some extra data along the way, referred to as the tape.
  2. The backward pass, which uses the tape stored by the forward pass to retrace its steps in reverse.
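To make the two passes concrete, here is a minimal hand-written sketch in JavaScript that assumes a simple array-as-stack tape. It is not code that Floretta generates, and the names `tape`, `forward`, and `backward` are hypothetical. The function computes x^4 as (x*x)*(x*x). The forward pass pushes the operand of each multiplication, and the backward pass pops those operands in reverse order while applying the chain rule.

```javascript
// Conceptual sketch only: a hand-written forward/backward pair in plain
// JavaScript, not Floretta output. `tape` is a hypothetical stack of f64 values.
const tape = [];

// Forward pass for f(x) = (x * x) * (x * x), i.e. x^4.
// Computes the same value as the original function, but pushes the operand
// each multiplication needs for its derivative onto the tape.
function forward(x) {
  tape.push(x); // operand of the first multiplication
  const y = x * x;
  tape.push(y); // operand of the second multiplication
  return y * y;
}

// Backward pass: takes the derivative of the output (dz) and pops the tape
// in reverse order to propagate it back to the input.
function backward(dz) {
  const y = tape.pop(); // pushed last, so popped first
  const dy = dz * 2 * y; // d(y * y)/dy = 2y
  const x = tape.pop();
  const dx = dy * 2 * x; // d(x * x)/dx = 2x
  return dx;
}
```

For example, `forward(2)` returns 16, and a subsequent `backward(1)` returns 32 (the derivative 4x^3 at x = 2) and leaves the tape empty.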

Together, these comprise the vector-Jacobian product (VJP), which can then be used to compute the gradient of any function that returns a scalar.
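The sketch below, again plain JavaScript rather than Floretta output, illustrates why a VJP is enough to get a gradient. When a function returns a scalar, its Jacobian is a single row, so the VJP with seed 1 is exactly the gradient. The functions `forward`, `backward`, and `gradient` are hypothetical names chosen for illustration.

```javascript
// Conceptual sketch (not Floretta output): f(x0, x1) = x0 * x1 + x0.
// Its Jacobian is the 1x2 row [x1 + 1, x0], so the vector-Jacobian product
// with scalar seed v is [v * (x1 + 1), v * x0].
const tape = [];

// Forward pass: same result as f, plus the operands the backward pass needs.
function forward(x0, x1) {
  tape.push(x0, x1);
  return x0 * x1 + x0;
}

// Backward pass: returns v^T J for the most recent forward call.
function backward(v) {
  const x1 = tape.pop();
  const x0 = tape.pop();
  return [v * (x1 + 1), v * x0];
}

// Gradient of a scalar-valued function: run forward, then seed backward with 1.
function gradient(x0, x1) {
  forward(x0, x1);
  return backward(1);
}
```

Here `gradient(3, 5)` returns `[6, 3]`, matching df/dx0 = x1 + 1 and df/dx1 = x0. Passing a different seed to `backward` scales the result, which is what makes it a vector-Jacobian product rather than only a gradient.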

For every memory in the original Wasm module, Floretta adds a corresponding memory to the transformed module to store the derivative of each scalar in the original memory. Floretta also adds two more memories to store the tape: one for f32 values and one for f64 values.
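The following is a rough JavaScript model of that memory layout. It rests on several assumptions: each memory is viewed as an array of f64 slots, the derivative of `memory[i]` lives at the same index of `dMemory`, and the f64 tape is a separate array with a stack pointer. Floretta's actual layout and generated code may differ, and all names here are hypothetical. The example covers a single store of a squared load and shows how the backward pass moves the derivative of the destination slot back to the source slot.

```javascript
// Hypothetical model in plain JavaScript, not Floretta's actual codegen.
// `memory` stands in for an original Wasm memory viewed as f64 slots,
// `dMemory` for the added derivative memory (assumed: same index = derivative
// of the same scalar), and `tapeF64` for the added f64 tape memory.
const SLOTS = 4;
const memory = new Float64Array(SLOTS);
const dMemory = new Float64Array(SLOTS);
const tapeF64 = new Float64Array(16);
let tapeTop = 0; // index of the next free tape slot

// Forward pass of: memory[dst] = memory[src] * memory[src]
function storeSquareForward(src, dst) {
  const x = memory[src];
  tapeF64[tapeTop++] = x; // save the loaded value for the backward pass
  memory[dst] = x * x;
}

// Backward pass: propagate the derivative of memory[dst] to memory[src].
function storeSquareBackward(src, dst) {
  const x = tapeF64[--tapeTop]; // pop the value saved by the forward pass
  const dOut = dMemory[dst];
  dMemory[dst] = 0; // the store overwrote memory[dst], so its old value no longer matters
  dMemory[src] += dOut * 2 * x; // accumulate d(x * x)/dx = 2x
}
```

Setting `memory[0] = 3`, calling `storeSquareForward(0, 1)`, seeding `dMemory[1] = 1`, and then calling `storeSquareBackward(0, 1)` leaves `dMemory[0]` equal to 6. Reading and zeroing `dMemory[dst]` before accumulating into `dMemory[src]` keeps the result correct even when `src` and `dst` are the same slot.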

Usage

The easiest way to use Floretta is via the command line. If you have Rust installed, you can build the latest version of Floretta from source:

$ cargo install --locked floretta-cli

Use the --help flag to see all available CLI arguments:

$ floretta --help

For example, if you create a file called square.wat with these contents:

(module
  (func (export "square") (param f64) (result f64)
    (f64.mul (local.get 0) (local.get 0))))

Then you can use Floretta to generate the backward pass of the "square" function and export it under the name "backprop":

$ floretta square.wat --export square backprop --output gradient.wasm

Finally, if you have a Wasm engine, you can use it to compute a gradient with the emitted Wasm binary by running the forward pass followed by the backward pass. For instance, if you have Node.js installed, you can create a file called gradient.mjs with these contents:

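import fs from "node:fs/promises";
const wasm = await fs.readFile("gradient.wasm");
const module = await WebAssembly.instantiate(wasm);
const { square, backprop } = module.instance.exports;
// Forward pass: computes 3 * 3 and stores the tape.
console.log(square(3));
// Backward pass: seeding with 1 yields the derivative of square at 3.
console.log(backprop(1));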

And run it like this:

$ node gradient.mjs
9
6
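One way to sanity-check the printed derivative is to compare it with a central finite difference of the same computation. This sketch uses a plain JavaScript copy of the function rather than the Wasm export, so it never touches Floretta's tape; `centralDifference` and `squareJs` are hypothetical helpers, not part of Floretta.

```javascript
// Central finite difference: approximates f'(x) as (f(x + h) - f(x - h)) / 2h.
function centralDifference(f, x, h = 1e-6) {
  return (f(x + h) - f(x - h)) / (2 * h);
}

// Plain-JavaScript version of the same computation as square.wat.
const squareJs = (x) => x * x;

const estimate = centralDifference(squareJs, 3);
// The backward pass printed 6 above; the estimate should agree closely.
console.log(Math.abs(estimate - 6) < 1e-6); // true
```

A central difference is used instead of a one-sided one because its truncation error shrinks with h^2 rather than h; for x^2 specifically it is exact up to floating-point rounding.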

License

Floretta is licensed under the MIT License.

About

Automatic differentiation for WebAssembly.
