Skip to content

Commit

Permalink
Fixed return types for sql commands
Browse files Browse the repository at this point in the history
  • Loading branch information
stuioco committed Aug 27, 2024
1 parent 71d0bd9 commit 33f692b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 44 deletions.
55 changes: 22 additions & 33 deletions core/templating/datasource_sql_over_csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package templating
import (
"errors"
"regexp"
"strconv"
"strings"

log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -56,7 +55,7 @@ func parseSqlCommand(query string, datasource *TemplateDataSource) (SQLStatement
}
columnsPart = matches[1]
dataSourceName = matches[2]
if !dataSourceExists(datasource, dataSourceName) {
if !datasource.DataSourceExists(dataSourceName) {
return SQLStatement{}, errors.New("data source does not exist")
}

Expand Down Expand Up @@ -87,15 +86,18 @@ func parseSqlCommand(query string, datasource *TemplateDataSource) (SQLStatement
return SQLStatement{}, errors.New("invalid UPDATE query format")
}
dataSourceName = matches[1]
if !dataSourceExists(datasource, dataSourceName) {
if !datasource.DataSourceExists(dataSourceName) {
return SQLStatement{}, errors.New("data source does not exist")
}
setPart := matches[2]
if len(matches) == 4 {
wherePart = matches[3]
}

setClauses := parseSetClauses(setPart)
headers := datasource.DataSources[dataSourceName].Data[0]
setClauses, err := parseSetClauses(setPart, headers)
if err != nil {
return SQLStatement{}, err
}
conditions, err := parseConditions(wherePart)
if err != nil {
return SQLStatement{}, err
Expand All @@ -114,7 +116,7 @@ func parseSqlCommand(query string, datasource *TemplateDataSource) (SQLStatement
return SQLStatement{}, errors.New("invalid DELETE query format")
}
dataSourceName = matches[1]
if !dataSourceExists(datasource, dataSourceName) {
if !datasource.DataSourceExists(dataSourceName) {
return SQLStatement{}, errors.New("data source does not exist")
}
if len(matches) == 3 {
Expand Down Expand Up @@ -179,16 +181,18 @@ func trimQuotes(s string) string {
}

// parseSetClauses parses the SET part of an UPDATE query
func parseSetClauses(setPart string) map[string]string {
func parseSetClauses(setPart string, headers []string) (map[string]string, error) {
setClauses := make(map[string]string)
parts := strings.Split(setPart, ",")
for _, part := range parts {
keyValue := strings.Split(strings.TrimSpace(part), "=")
if len(keyValue) == 2 {
if !stringExists(headers, strings.TrimSpace(keyValue[0])) {
return nil, errors.New("invalid column provided: " + strings.TrimSpace(keyValue[0]))
} else if len(keyValue) == 2 {
setClauses[strings.TrimSpace(keyValue[0])] = trimQuotes(strings.TrimSpace(keyValue[1]))
}
}
return setClauses
return setClauses, nil
}

// parseConditions parses the WHERE part of the query into a slice of Conditions and returns an error if any issues are found.
Expand Down Expand Up @@ -230,14 +234,10 @@ func executeSqlSelectQuery(data *[][]string, query SQLStatement) []RowMap {
}

// ExecuteUpdateQuery executes an UPDATE query and modifies the data in-place
func executeSqlUpdateCommand(data *[][]string, query SQLStatement) []RowMap {
func executeSqlUpdateCommand(data *[][]string, query SQLStatement) int {
if len(*data) < 2 {
log.Error("no data available to update")
return []RowMap{
{
"rowsAffected": "0",
},
}
log.Debug("no data available to update")
return 0
}

headers := (*data)[0]
Expand All @@ -256,23 +256,16 @@ func executeSqlUpdateCommand(data *[][]string, query SQLStatement) []RowMap {
}
}
}
return []RowMap{
{
"rowsAffected": strconv.Itoa(rowsAffected),
},
}
return rowsAffected
}

// executeSqlDeleteCommand executes a DELETE query and modifies the data in-place
func executeSqlDeleteCommand(data *[][]string, query SQLStatement) []RowMap {
func executeSqlDeleteCommand(data *[][]string, query SQLStatement) int {
if len(*data) < 2 {
log.Println("no data available to delete")
return []RowMap{
{
"rowsAffected": "0",
},
}
log.Debug("no data available to delete")
return 0
}

headers := (*data)[0]
conditions := query.Conditions
rowsAffected := 0
Expand All @@ -285,11 +278,7 @@ func executeSqlDeleteCommand(data *[][]string, query SQLStatement) []RowMap {
rowsAffected++
}
}
return []RowMap{
{
"rowsAffected": strconv.Itoa(rowsAffected),
},
}
return rowsAffected
}

// removeRow removes the row at index rowIndex from the data.
Expand Down
27 changes: 19 additions & 8 deletions core/templating/datasource_sql_over_csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,30 @@ func TestTrimQuotes(t *testing.T) {
}
}

func TestParseSetClauses(t *testing.T) {
func TestParseSetClauses_ValidInput(t *testing.T) {
input := "age = '35', department = 'Engineering'"
inputHeaders := []string{"age", "department"}
expected := map[string]string{
"age": "35",
"department": "Engineering",
}

result := parseSetClauses(input)
result, _ := parseSetClauses(input, inputHeaders)
if !reflect.DeepEqual(result, expected) {
t.Errorf("expected %v, got %v", expected, result)
}
}

func TestParseSetClauses_InvalidInput(t *testing.T) {
input := "age = '35', department = 'Engineering'"
inputHeaders := []string{"fruit", "category"}

_, err := parseSetClauses(input, inputHeaders)
if err == nil {
t.Errorf("expected error but got none.")
}
}

func TestParseConditions_ValidInput(t *testing.T) {
wherePart := "id == '1' AND name != 'John' AND age >= '30'"
expected := []Condition{
Expand Down Expand Up @@ -247,12 +258,12 @@ func TestExecuteSqlUpdateCommand_RowCountResult(t *testing.T) {
DataSourceName: "employees",
}

expected := "1"
expected := 1

result := executeSqlUpdateCommand(&data, query)

if result[0]["rowsAffected"] != expected {
t.Errorf("expected %v, got %v", expected, result[0]["rowsAffected"])
if result != expected {
t.Errorf("expected %v, got %v", expected, result)
}
}

Expand Down Expand Up @@ -294,11 +305,11 @@ func TestExecuteSqlDeleteCommand_RowCountResult(t *testing.T) {
DataSourceName: "employees",
}

expected := "1"
expected := 1

result := executeSqlDeleteCommand(&data, query)

if result[0]["rowsAffected"] != expected {
t.Errorf("expected %v, got %v", expected, result[0]["rowsAffected"])
if result != expected {
t.Errorf("expected %v, got %v", expected, result)
}
}
2 changes: 1 addition & 1 deletion core/templating/template_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (templateDataSource *TemplateDataSource) GetAllDataSources() map[string]*Da
return templateDataSource.DataSources
}

func dataSourceExists(templateDataSource *TemplateDataSource, name string) bool {
func (templateDataSource *TemplateDataSource) DataSourceExists(name string) bool {
templateDataSource.RWMutex.Lock()
defer templateDataSource.RWMutex.Unlock()

Expand Down
8 changes: 6 additions & 2 deletions core/templating/template_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,13 @@ func (t templateHelpers) csvSqlCommand(commandString string) []RowMap {
case "SELECT":
results = executeSqlSelectQuery(&source.Data, command)
case "UPDATE":
results = executeSqlUpdateCommand(&source.Data, command)
rowsAffected := executeSqlUpdateCommand(&source.Data, command)
log.Debug(strconv.Itoa(rowsAffected) + " rows affected by " + commandString)
return nil
case "DELETE":
results = executeSqlDeleteCommand(&source.Data, command)
rowsAffected := executeSqlDeleteCommand(&source.Data, command)
log.Debug(strconv.Itoa(rowsAffected) + " rows affected by " + commandString)
return nil
default:
log.Error(fmt.Errorf("unsupported query type %s", command.Type))
return nil
Expand Down

0 comments on commit 33f692b

Please sign in to comment.