From 8bcda51fb59e0e7d9cc43328b5796a90e9cb24c0 Mon Sep 17 00:00:00 2001 From: Brennan Lamey Date: Thu, 28 Nov 2024 13:34:42 -0700 Subject: [PATCH] adds all-scalar flag --- cmd/kwil-cli/cmds/database/batch.go | 7 +++++- cmd/kwil-cli/cmds/database/call.go | 31 ++++++++++++++++++++------- cmd/kwil-cli/cmds/database/execute.go | 9 +++++++- cmd/kwil-cli/cmds/database/flags.go | 9 ++++++++ 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/cmd/kwil-cli/cmds/database/batch.go b/cmd/kwil-cli/cmds/database/batch.go index a20cc4852..42ce131dd 100644 --- a/cmd/kwil-cli/cmds/database/batch.go +++ b/cmd/kwil-cli/cmds/database/batch.go @@ -68,6 +68,11 @@ func batchCmd() *cobra.Command { return display.PrintErr(cmd, fmt.Errorf("error getting selected action or procedure: %w", err)) } + allScalar, err := getAllScalarsFlag(cmd) + if err != nil { + return display.PrintErr(cmd, fmt.Errorf("error getting all scalar flag: %w", err)) + } + fileType, err := getFileType(filePath) if err != nil { return display.PrintErr(cmd, fmt.Errorf("error getting file type: %w", err)) @@ -87,7 +92,7 @@ func batchCmd() *cobra.Command { return display.PrintErr(cmd, fmt.Errorf("error building inputs: %w", err)) } - tuples, err := buildExecutionInputs(ctx, cl, dbid, action, inputs) + tuples, err := buildExecutionInputs(ctx, cl, dbid, action, inputs, allScalar) if err != nil { return display.PrintErr(cmd, fmt.Errorf("error creating action inputs: %w", err)) } diff --git a/cmd/kwil-cli/cmds/database/call.go b/cmd/kwil-cli/cmds/database/call.go index a24feca72..6be702f8a 100644 --- a/cmd/kwil-cli/cmds/database/call.go +++ b/cmd/kwil-cli/cmds/database/call.go @@ -70,12 +70,17 @@ func callCmd() *cobra.Command { return display.PrintErr(cmd, fmt.Errorf("error getting selected action or procedure: %w", err)) } + allScalar, err := getAllScalarsFlag(cmd) + if err != nil { + return display.PrintErr(cmd, fmt.Errorf("error getting all scalar flag: %w", err)) + } + inputs, err := parseInputs(args) if err != nil { return display.PrintErr(cmd, fmt.Errorf("error getting inputs: %w", err)) } - tuples, err := buildExecutionInputs(ctx, clnt, dbid, action, inputs) + tuples, err := buildExecutionInputs(ctx, clnt, dbid, action, inputs, allScalar) if err != nil { return display.PrintErr(cmd, fmt.Errorf("error creating action/procedure inputs: %w", err)) } @@ -144,7 +149,9 @@ func (r *respCall) MarshalText() (text []byte, err error) { // buildProcedureInputs will build the inputs for either // an action or procedure executon/call. -func buildExecutionInputs(ctx context.Context, client clientType.Client, dbid string, proc string, inputs []map[string]string) ([][]any, error) { +// If skipArr is true, it will treat all values as a scalar value if it can't detect +// what the expected type is (which is the case for an action). +func buildExecutionInputs(ctx context.Context, client clientType.Client, dbid string, proc string, inputs []map[string]string, skipArr bool) ([][]any, error) { schema, err := client.GetSchema(ctx, dbid) if err != nil { return nil, fmt.Errorf("error getting schema: %w", err) @@ -152,7 +159,7 @@ func buildExecutionInputs(ctx context.Context, client clientType.Client, dbid st for _, a := range schema.Actions { if strings.EqualFold(a.Name, proc) { - return buildActionInputs(a, inputs) + return buildActionInputs(a, inputs, skipArr) } } @@ -189,7 +196,10 @@ func decodeMany(inputs []string) ([][]byte, bool) { return b64Arr, b64Ok } -func buildActionInputs(a *types.Action, inputs []map[string]string) ([][]any, error) { +// buildActionInputs will build the inputs for an action execution/call. +// if skipArr is true, it will treat all values as a scalar value. +// This is useful within CSV, where we do not expected arrays +func buildActionInputs(a *types.Action, inputs []map[string]string, skipArr bool) ([][]any, error) { tuples := [][]any{} for _, input := range inputs { newTuple := []any{} @@ -199,15 +209,20 @@ func buildActionInputs(a *types.Action, inputs []map[string]string) ([][]any, er val, ok := input[inputField] if !ok { - fmt.Println(len(newTuple)) // if not found, we should just add nil newTuple = append(newTuple, nil) continue } - split, err := splitIgnoringQuotedCommas(val) - if err != nil { - return nil, err + var split []string + if !skipArr { + var err error + split, err = splitIgnoringQuotedCommas(val) + if err != nil { + return nil, err + } + } else { + split = []string{val} } // attempt to decode base64 encoded values diff --git a/cmd/kwil-cli/cmds/database/execute.go b/cmd/kwil-cli/cmds/database/execute.go index 6bfff8391..a2c49fe78 100644 --- a/cmd/kwil-cli/cmds/database/execute.go +++ b/cmd/kwil-cli/cmds/database/execute.go @@ -52,12 +52,19 @@ func executeCmd() *cobra.Command { return display.PrintErr(cmd, fmt.Errorf("error getting selected action or procedure: %w", err)) } + allScalar, err := getAllScalarsFlag(cmd) + if err != nil { + return display.PrintErr(cmd, fmt.Errorf("error getting all scalar flag: %w", err)) + } + + action = strings.ToLower(action) + parsedArgs, err := parseInputs(args) if err != nil { return display.PrintErr(cmd, fmt.Errorf("error parsing inputs: %w", err)) } - inputs, err := buildExecutionInputs(ctx, cl, dbid, action, parsedArgs) + inputs, err := buildExecutionInputs(ctx, cl, dbid, action, parsedArgs, allScalar) if err != nil { return display.PrintErr(cmd, fmt.Errorf("error getting inputs: %w", err)) } diff --git a/cmd/kwil-cli/cmds/database/flags.go b/cmd/kwil-cli/cmds/database/flags.go index a16a66e11..1dfcdf1be 100644 --- a/cmd/kwil-cli/cmds/database/flags.go +++ b/cmd/kwil-cli/cmds/database/flags.go @@ -81,6 +81,7 @@ func getSelectedDbid(cmd *cobra.Command, conf *config.KwilCliConfig) (string, er // This includes the `execute`, `call`, and `batch` commands. func bindFlagsTargetingProcedureOrAction(cmd *cobra.Command) { bindFlagsTargetingDatabase(cmd) + bindAllScalarsFlag(cmd) cmd.Flags().StringP(actionNameFlag, "a", "", "the target action name") err := cmd.Flags().MarkDeprecated(actionNameFlag, "pass the action name as the first argument") if err != nil { @@ -120,3 +121,11 @@ func bindFlagsTargetingDatabase(cmd *cobra.Command) { cmd.Flags().StringP(ownerFlag, "o", "", "the target database owner") cmd.Flags().StringP(dbidFlag, "i", "", "the target database id") } + +func bindAllScalarsFlag(cmd *cobra.Command) { + cmd.Flags().Bool("all-scalars", false, "informs the client that all values should be scalar and never be treated as arrays") +} + +func getAllScalarsFlag(cmd *cobra.Command) (bool, error) { + return cmd.Flags().GetBool("all-scalars") +}