diff --git a/runtime/applet.go b/runtime/applet.go index c515515f69..67e667fdc1 100644 --- a/runtime/applet.go +++ b/runtime/applet.go @@ -274,6 +274,7 @@ func (a *Applet) Call(ctx context.Context, callable *starlark.Function, args ... }() t := a.newThread(ctx) + defer starlarkutil.RunOnExitFuncs(t) context.AfterFunc(ctx, func() { t.Cancel(context.Cause(ctx).Error()) @@ -357,6 +358,7 @@ func (a *Applet) ensureLoaded(fsys fs.FS, path string, currentlyLoading ...strin } thread := a.newThread(context.Background()) + defer starlarkutil.RunOnExitFuncs(thread) // override loader to allow loading starlark files thread.Load = func(thread *starlark.Thread, module string) (starlark.StringDict, error) { @@ -382,53 +384,64 @@ func (a *Applet) ensureLoaded(fsys fs.FS, path string, currentlyLoading ...strin return a.loadModule(thread, module) } - globals, err := starlark.ExecFileOptions( - &syntax.FileOptions{ - Set: true, - Recursion: true, - }, - thread, - a.ID, - src, - predeclared, - ) - if err != nil { - return fmt.Errorf("starlark.ExecFile: %v", err) - } - a.globals[path] = globals - - // if the file is in the root directory, check for the main function - // and schema function - mainFun, _ := globals["main"].(*starlark.Function) - if mainFun != nil { - if a.mainFile != "" { - return fmt.Errorf("multiple files with a main() function:\n- %s\n- %s", path, a.mainFile) + switch filepath.Ext(path) { + case ".star": + globals, err := starlark.ExecFileOptions( + &syntax.FileOptions{ + Set: true, + Recursion: true, + }, + thread, + a.ID, + src, + predeclared, + ) + if err != nil { + return fmt.Errorf("starlark.ExecFile: %v", err) } + a.globals[path] = globals + + // if the file is in the root directory, check for the main function + // and schema function + mainFun, _ := globals["main"].(*starlark.Function) + if mainFun != nil { + if a.mainFile != "" { + return fmt.Errorf("multiple files with a main() function:\n- %s\n- %s", path, a.mainFile) + } - a.mainFile = path - a.mainFun = mainFun - } - - schemaFun, _ := globals[schema.SchemaFunctionName].(*starlark.Function) - if schemaFun != nil { - if a.schemaFile != "" { - return fmt.Errorf("multiple files with a %s() function:\n- %s\n- %s", schema.SchemaFunctionName, path, a.schemaFile) + a.mainFile = path + a.mainFun = mainFun } - a.schemaFile = path - schemaVal, err := a.Call(context.Background(), schemaFun) - if err != nil { - return fmt.Errorf("calling schema function for %s: %w", a.ID, err) - } + schemaFun, _ := globals[schema.SchemaFunctionName].(*starlark.Function) + if schemaFun != nil { + if a.schemaFile != "" { + return fmt.Errorf("multiple files with a %s() function:\n- %s\n- %s", schema.SchemaFunctionName, path, a.schemaFile) + } + a.schemaFile = path - a.schema, err = schema.FromStarlark(schemaVal, globals) - if err != nil { - return fmt.Errorf("parsing schema for %s: %w", a.ID, err) + schemaVal, err := a.Call(context.Background(), schemaFun) + if err != nil { + return fmt.Errorf("calling schema function for %s: %w", a.ID, err) + } + + a.schema, err = schema.FromStarlark(schemaVal, globals) + if err != nil { + return fmt.Errorf("parsing schema for %s: %w", a.ID, err) + } + + a.schemaJSON, err = json.Marshal(a.schema) + if err != nil { + return fmt.Errorf("serializing schema to JSON for %s: %w", a.ID, err) + } } - a.schemaJSON, err = json.Marshal(a.schema) - if err != nil { - return fmt.Errorf("serializing schema to JSON for %s: %w", a.ID, err) + default: + a.globals[path] = starlark.StringDict{ + "file": File{ + fsys: fsys, + path: path, + }.Struct(), } } diff --git a/runtime/file.go b/runtime/file.go new file mode 100644 index 0000000000..b74dde4d88 --- /dev/null +++ b/runtime/file.go @@ -0,0 +1,106 @@ +package runtime + +import ( + "fmt" + "io" + "io/fs" + + "go.starlark.net/starlark" + "go.starlark.net/starlarkstruct" + "tidbyt.dev/pixlet/starlarkutil" +) + +type File struct { + fsys fs.FS + path string +} + +func (f File) Struct() *starlarkstruct.Struct { + return starlarkstruct.FromStringDict(starlark.String("File"), starlark.StringDict{ + "path": starlark.String(f.path), + "open": starlark.NewBuiltin("open", f.open), + }) +} + +func (f File) open(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var mode starlark.String + if err := starlark.UnpackArgs("open", args, kwargs, "mode?", &mode); err != nil { + return nil, err + } + + var binaryMode bool + switch mode.GoString() { + case "", "r", "rt": + binaryMode = false + + case "rb": + binaryMode = true + + default: + return nil, fmt.Errorf("unsupported mode: %s", mode) + } + + fl, err := f.fsys.Open(f.path) + if err != nil { + return nil, err + } else { + starlarkutil.AddOnExit(thread, func() { fl.Close() }) + } + + return Reader{fl, binaryMode}.Struct(), nil +} + +type Reader struct { + io.ReadCloser + binaryMode bool +} + +func (r Reader) Struct() *starlarkstruct.Struct { + return starlarkstruct.FromStringDict(starlark.String("Reader"), starlark.StringDict{ + "read": starlark.NewBuiltin("read", r.read), + "close": starlark.NewBuiltin("close", func(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + return nil, r.Close() + }), + }) +} + +// read reads the contents of the file. The Starlark signature is: +// +// read(size=-1) -> bytes +func (r Reader) read(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + starlarkSize := starlark.MakeInt(-1) + if err := starlark.UnpackArgs("read", args, kwargs, "size?", &starlarkSize); err != nil { + return nil, err + } + + var size int + if err := starlark.AsInt(starlarkSize, &size); err != nil { + return nil, fmt.Errorf("size is not an int") + } + + returnType := func(buf []byte) starlark.Value { + if r.binaryMode { + return starlark.Bytes(buf) + } else { + return starlark.String(buf) + } + } + + if size < 0 { + // read and return all bytes + buf, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + return returnType(buf), nil + } else { + // read and return size bytes + buf := make([]byte, size) + _, err := r.Read(buf) + if err != nil { + return nil, err + } + return returnType(buf), nil + } +} diff --git a/runtime/file_test.go b/runtime/file_test.go new file mode 100644 index 0000000000..d2b1ef344d --- /dev/null +++ b/runtime/file_test.go @@ -0,0 +1,41 @@ +package runtime + +import ( + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" +) + +func TestReadFile(t *testing.T) { + src := ` +load("hello.txt", hello = "file") + +def assert_eq(message, actual, expected): + if not expected == actual: + fail(message, "-", "expected", expected, "actual", actual) + +def test_read(): + f = hello.open() + assert_eq("read", f.read(), "hello world") + +def test_read_binary(): + f = hello.open(mode="rb") + assert_eq("read", f.read(), b"hello world") + +def main(): + pass + +` + + helloTxt := `hello world` + + vfs := &fstest.MapFS{ + "main.star": {Data: []byte(src)}, + "hello.txt": {Data: []byte(helloTxt)}, + } + + app, err := NewAppletFromFS("test_read_file", vfs) + require.NoError(t, err) + app.RunTests(t) +} diff --git a/starlarkutil/onexit.go b/starlarkutil/onexit.go new file mode 100644 index 0000000000..d3ee2a6cdc --- /dev/null +++ b/starlarkutil/onexit.go @@ -0,0 +1,27 @@ +package starlarkutil + +import "go.starlark.net/starlark" + +const ( + // ThreadOnExitKey is the key used to store functions that should be called + // when a thread exits. + ThreadOnExitKey = "tidbyt.dev/pixlet/runtime/on_exit" +) + +type threadOnExitFunc func() + +func AddOnExit(thread *starlark.Thread, fn threadOnExitFunc) { + if onExit, ok := thread.Local(ThreadOnExitKey).(*[]threadOnExitFunc); ok { + *onExit = append(*onExit, fn) + } else { + thread.SetLocal(ThreadOnExitKey, &[]threadOnExitFunc{fn}) + } +} + +func RunOnExitFuncs(thread *starlark.Thread) { + if onExit, ok := thread.Local(ThreadOnExitKey).(*[]threadOnExitFunc); ok { + for _, fn := range *onExit { + fn() + } + } +}