Skip to content

Commit

Permalink
Validate transformed modules
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Feb 23, 2024
1 parent 7e1a3b1 commit 2c2708c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.tgz
/dist/
/src/test/out/
node_modules/
13 changes: 9 additions & 4 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ class Autodiff {
fwdTape,
this.mod.tuple.extract(call, results.length + resultsGrad.length),
),
this.mod.tuple.make(
util.tupleMake(
this.mod,
results.map((_, i) =>
this.mod.tuple.extract(this.fwdGet(tuple), i),
),
Expand Down Expand Up @@ -416,7 +417,7 @@ export const autodiff = (mod: binaryen.Module): Gradient[] => {
null,
[
mod.local.set(out, fwdBody),
mod.tuple.make([
util.tupleMake(mod, [
...ad.results.map((_, i) => mod.tuple.extract(ad.fwdGet(out), i)),
...ad.resultsGrad.map(() => mod.f64.const(0)),
util.structNew(
Expand Down Expand Up @@ -447,12 +448,16 @@ export const autodiff = (mod: binaryen.Module): Gradient[] => {
),
mod.local.set(
gradResults,
mod.tuple.make(
util.tupleMake(
mod,
ad.resultsGrad.map((_, i) => ad.get(ad.paramsGrad.length + i)),
),
),
...ad.bwd.reverse(),
mod.tuple.make(ad.paramsGrad.map((_, index) => ad.get(index))),
util.tupleMake(
mod,
ad.paramsGrad.map((_, index) => ad.get(index)),
),
],
bwdResult,
),
Expand Down
31 changes: 26 additions & 5 deletions src/test/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
import binaryen from "binaryen";
import { expect, test } from "vitest";
import path from "path";
import { beforeAll, expect, test } from "vitest";
import * as wasmad from "../index.js";
import * as util from "../util.js";
import { slurp } from "./util.js";
import { dir, mkdirp, rmrf, slurp, spit } from "./util.js";

const out = "out";

beforeAll(async () => {
await rmrf(out);
await mkdirp(out);
});

const features =
binaryen.Features.Multivalue |
binaryen.Features.ReferenceTypes |
binaryen.Features.GC;

const wat = async (text: string): Promise<Uint8Array> => {
const mod = binaryen.parseText(text);
try {
mod.setFeatures(binaryen.Features.GC);
mod.setFeatures(features);
return mod.emitBinary();
} finally {
mod.dispose();
Expand Down Expand Up @@ -57,12 +70,13 @@ interface Names {
const autodiff = async <T extends AD>(filename: string): Promise<T> => {
const mod = binaryen.parseText(await slurp(filename));
try {
mod.setFeatures(binaryen.Features.Multivalue);
mod.setFeatures(binaryen.Features.GC);
mod.setFeatures(features);

const set = new util.Names();
const n = mod.getNumExports();
for (let i = 0; i < n; ++i)
set.add(binaryen.getExportInfo(mod.getExportByIndex(i)).name);

const names = wasmad.autodiff(mod).map((f, i): Names => {
const name = binaryen.getFunctionInfo(mod.getFunctionByIndex(i)).name;
const fwdName = binaryen.getFunctionInfo(f.fwd).name;
Expand All @@ -75,8 +89,15 @@ const autodiff = async <T extends AD>(filename: string): Promise<T> => {
mod.addFunctionExport(bwdName, bwd);
return { name, orig, fwd, bwd };
});

const filepath = path.join(out, filename);
await spit(filepath, mod.emitText());
if (!mod.validate())
throw Error(`invalid module: ${path.join(dir, filepath)}`);

const binary = mod.emitBinary();
const exports = await compile<any>(binary);

const origs: WebAssembly.Exports = {};
const fwds: WebAssembly.Exports = {};
const bwds: WebAssembly.Exports = {};
Expand Down
12 changes: 10 additions & 2 deletions src/test/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@ import path from "path";
import url from "url";
import { expect, test } from "vitest";

const dir = path.dirname(url.fileURLToPath(import.meta.url));
export const dir = path.dirname(url.fileURLToPath(import.meta.url));

export const rmrf = async (dirname: string): Promise<void> => {
await fs.rm(path.join(dir, dirname), { recursive: true, force: true });
};

export const mkdirp = async (dirname: string): Promise<void> => {
await fs.mkdir(path.join(dir, dirname), { recursive: true });
};

export const slurp = async (filename: string): Promise<string> =>
await fs.readFile(path.join(dir, filename), "utf8");

const spit = async (filename: string, data: string): Promise<void> =>
export const spit = async (filename: string, data: string): Promise<void> =>
await fs.writeFile(path.join(dir, filename), data, "utf8");

export const goldenfile = (
Expand Down
13 changes: 13 additions & 0 deletions src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ export const funcIndicesByName = (
return indices;
};

/** Like `tuple.make` on `binaryen.Module`, but accepts singletons. */
export const tupleMake = (
mod: binaryen.Module,
elements: binaryen.ExpressionRef[],
): binaryen.ExpressionRef => {
switch (elements.length) {
case 1:
return elements[0];
default:
return mod.tuple.make(elements);
}
};

type Bool = number;

type Int = number;
Expand Down

0 comments on commit 2c2708c

Please sign in to comment.