From 2c2708cb31db0ed82d2685f5abc729e6e8adc632 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 23 Feb 2024 09:26:04 -0500 Subject: [PATCH] Validate transformed modules --- .gitignore | 1 + src/index.ts | 13 +++++++++---- src/test/index.test.ts | 31 ++++++++++++++++++++++++++----- src/test/util.ts | 12 ++++++++++-- src/util.ts | 13 +++++++++++++ 5 files changed, 59 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index f091efc..30e2686 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.tgz /dist/ +/src/test/out/ node_modules/ diff --git a/src/index.ts b/src/index.ts index e10d7b4..f59cc69 100644 --- a/src/index.ts +++ b/src/index.ts @@ -234,7 +234,8 @@ class Autodiff { fwdTape, this.mod.tuple.extract(call, results.length + resultsGrad.length), ), - this.mod.tuple.make( + util.tupleMake( + this.mod, results.map((_, i) => this.mod.tuple.extract(this.fwdGet(tuple), i), ), @@ -416,7 +417,7 @@ export const autodiff = (mod: binaryen.Module): Gradient[] => { null, [ mod.local.set(out, fwdBody), - mod.tuple.make([ + util.tupleMake(mod, [ ...ad.results.map((_, i) => mod.tuple.extract(ad.fwdGet(out), i)), ...ad.resultsGrad.map(() => mod.f64.const(0)), util.structNew( @@ -447,12 +448,16 @@ export const autodiff = (mod: binaryen.Module): Gradient[] => { ), mod.local.set( gradResults, - mod.tuple.make( + util.tupleMake( + mod, ad.resultsGrad.map((_, i) => ad.get(ad.paramsGrad.length + i)), ), ), ...ad.bwd.reverse(), - mod.tuple.make(ad.paramsGrad.map((_, index) => ad.get(index))), + util.tupleMake( + mod, + ad.paramsGrad.map((_, index) => ad.get(index)), + ), ], bwdResult, ), diff --git a/src/test/index.test.ts b/src/test/index.test.ts index ac660a6..2738491 100644 --- a/src/test/index.test.ts +++ b/src/test/index.test.ts @@ -1,13 +1,26 @@ import binaryen from "binaryen"; -import { expect, test } from "vitest"; +import path from "path"; +import { beforeAll, expect, test } from "vitest"; import * as wasmad from "../index.js"; import * as util from "../util.js"; -import { slurp } from "./util.js"; +import { dir, mkdirp, rmrf, slurp, spit } from "./util.js"; + +const out = "out"; + +beforeAll(async () => { + await rmrf(out); + await mkdirp(out); +}); + +const features = + binaryen.Features.Multivalue | + binaryen.Features.ReferenceTypes | + binaryen.Features.GC; const wat = async (text: string): Promise => { const mod = binaryen.parseText(text); try { - mod.setFeatures(binaryen.Features.GC); + mod.setFeatures(features); return mod.emitBinary(); } finally { mod.dispose(); @@ -57,12 +70,13 @@ interface Names { const autodiff = async (filename: string): Promise => { const mod = binaryen.parseText(await slurp(filename)); try { - mod.setFeatures(binaryen.Features.Multivalue); - mod.setFeatures(binaryen.Features.GC); + mod.setFeatures(features); + const set = new util.Names(); const n = mod.getNumExports(); for (let i = 0; i < n; ++i) set.add(binaryen.getExportInfo(mod.getExportByIndex(i)).name); + const names = wasmad.autodiff(mod).map((f, i): Names => { const name = binaryen.getFunctionInfo(mod.getFunctionByIndex(i)).name; const fwdName = binaryen.getFunctionInfo(f.fwd).name; @@ -75,8 +89,15 @@ const autodiff = async (filename: string): Promise => { mod.addFunctionExport(bwdName, bwd); return { name, orig, fwd, bwd }; }); + + const filepath = path.join(out, filename); + await spit(filepath, mod.emitText()); + if (!mod.validate()) + throw Error(`invalid module: ${path.join(dir, filepath)}`); + const binary = mod.emitBinary(); const exports = await compile(binary); + const origs: WebAssembly.Exports = {}; const fwds: WebAssembly.Exports = {}; const bwds: WebAssembly.Exports = {}; diff --git a/src/test/util.ts b/src/test/util.ts index cae9dd1..3ec4a7f 100644 --- a/src/test/util.ts +++ b/src/test/util.ts @@ -3,12 +3,20 @@ import path from "path"; import url from "url"; import { expect, test } from "vitest"; -const dir = path.dirname(url.fileURLToPath(import.meta.url)); +export const dir = path.dirname(url.fileURLToPath(import.meta.url)); + +export const rmrf = async (dirname: string): Promise => { + await fs.rm(path.join(dir, dirname), { recursive: true, force: true }); +}; + +export const mkdirp = async (dirname: string): Promise => { + await fs.mkdir(path.join(dir, dirname), { recursive: true }); +}; export const slurp = async (filename: string): Promise => await fs.readFile(path.join(dir, filename), "utf8"); -const spit = async (filename: string, data: string): Promise => +export const spit = async (filename: string, data: string): Promise => await fs.writeFile(path.join(dir, filename), data, "utf8"); export const goldenfile = ( diff --git a/src/util.ts b/src/util.ts index e6dfab6..87d1f78 100644 --- a/src/util.ts +++ b/src/util.ts @@ -57,6 +57,19 @@ export const funcIndicesByName = ( return indices; }; +/** Like `tuple.make` on `binaryen.Module`, but accepts singletons. */ +export const tupleMake = ( + mod: binaryen.Module, + elements: binaryen.ExpressionRef[], +): binaryen.ExpressionRef => { + switch (elements.length) { + case 1: + return elements[0]; + default: + return mod.tuple.make(elements); + } +}; + type Bool = number; type Int = number;