diff --git a/engine/transform.go b/engine/transform.go index b9ab2f6bf..6282dd83a 100644 --- a/engine/transform.go +++ b/engine/transform.go @@ -3,6 +3,8 @@ package engine import ( "encoding/json" "strings" + + "go.mongodb.org/mongo-driver/v2/bson" ) type Input struct { @@ -11,13 +13,7 @@ type Input struct { Rows [][]interface{} `json:"rows"` } -// TransformResponse for raw queries -func TransformResponse(data []byte) ([]byte, error) { - // TODO properly detect a json response - if !strings.HasPrefix(string(data), `{"columns":[`) { - return data, nil - } - +func TransformSQLResponse(data []byte) ([]byte, error) { var input Input err := json.Unmarshal(data, &input) if err != nil { @@ -41,3 +37,54 @@ func TransformResponse(data []byte) ([]byte, error) { return o, nil } + +func TransformMongoResponse(data []byte) ([]byte, error) { + var result []map[string]interface{} + + if err := bson.UnmarshalExtJSON(data, false, &result); err != nil { + return nil, err + } + + for _, doc := range result { + if doc["id"] == nil { + doc["id"] = doc["_id"] + } + } + + o, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return o, nil +} + +// TransformResponse for raw queries +func TransformResponse(data []byte) ([]byte, error) { + // TODO properly detect a json response + switch { + case strings.HasPrefix(string(data), `{"columns":[`): + return TransformSQLResponse(data) + + // https://github.com/mongodb/mongo-go-driver/blob/91abd887f6b44ab56f47e58430f57b1be1996ceb/bson/extjson_wrappers.go#L18 + case strings.Contains(string(data), `{"$oid":`), + strings.Contains(string(data), `{"$date":`), + strings.Contains(string(data), `{"$numberInt":`), + strings.Contains(string(data), `{"$numberLong":`), + strings.Contains(string(data), `{"$symbol":`), + strings.Contains(string(data), `{"$numberDouble":`), + strings.Contains(string(data), `{"$numberDecimal":`), + strings.Contains(string(data), `{"$binary":`), + strings.Contains(string(data), `{"$code":`), + strings.Contains(string(data), `{"$scope":`), + strings.Contains(string(data), `{"$timestamp":`), + strings.Contains(string(data), `{"$regularExpression":`), + strings.Contains(string(data), `{"$dbPointer":`), + strings.Contains(string(data), `{"$minKey":`), + strings.Contains(string(data), `{"$maxKey":`), + strings.Contains(string(data), `{"$undefined":`): + return TransformMongoResponse(data) + } + + return data, nil +} diff --git a/engine/transform_test.go b/engine/transform_test.go index f4b0286e5..6fe29bf38 100644 --- a/engine/transform_test.go +++ b/engine/transform_test.go @@ -20,7 +20,14 @@ func Test_transformResponse(t *testing.T) { data: []byte(`{"columns":["id","email","username","str","strOpt","date","dateOpt","int","intOpt","float","floatOpt","bool","boolOpt"],"types":["string","string","string","string","string","datetime","datetime","int","int","double","double","int","int"],"rows":[["id1","email1","a","str","strOpt","2020-01-01T00:00:00+00:00","2020-01-01T00:00:00+00:00",5,5,5.5,5.5,1,0],["id2","email2","b","str","strOpt","2020-01-01T00:00:00+00:00","2020-01-01T00:00:00+00:00",5,5,5.5,5.5,1,0]]}`), }, want: []byte(`[{"bool":1,"boolOpt":0,"date":"2020-01-01T00:00:00+00:00","dateOpt":"2020-01-01T00:00:00+00:00","email":"email1","float":5.5,"floatOpt":5.5,"id":"id1","int":5,"intOpt":5,"str":"str","strOpt":"strOpt","username":"a"},{"bool":1,"boolOpt":0,"date":"2020-01-01T00:00:00+00:00","dateOpt":"2020-01-01T00:00:00+00:00","email":"email2","float":5.5,"floatOpt":5.5,"id":"id2","int":5,"intOpt":5,"str":"str","strOpt":"strOpt","username":"b"}]`), - }} + }, + { + name: "transform mongo raw response", + args: args{ + data: []byte(`[{"_id":{"$oid":"67347ee4a18fa09750c1085a"},"createdAt":{"$date":"2024-11-13T10:26:44.246Z"},"firstName":"Trua Nguyen","lastName":"Van","email":"truanv@gmail"},{"_id":{"$oid":"67348094597e341917026845"},"email":"truanv@gmail","firstName":"Trua Nguyen","lastName":"Van"},{"_id":{"$oid":"673480d6597e341917026dea"},"email":"truanv@gmail","firstName":"Trua Nguyen","lastName":"Van"},{"_id":{"$oid":"67348265597e34191702904f"},"firstName":"Trua Nguyen ","lastName":"Van","email":"truanv@gmail"},{"_id":{"$oid":"6734827b597e34191702923d"},"email":"truanv@gmail","firstName":"Trua Nguyen ","lastName":"Van"}]`), + }, + want: []byte(`[{"_id":"67347ee4a18fa09750c1085a","createdAt":"2024-11-13T10:26:44.246Z","email":"truanv@gmail","firstName":"Trua Nguyen","id":"67347ee4a18fa09750c1085a","lastName":"Van"},{"_id":"67348094597e341917026845","email":"truanv@gmail","firstName":"Trua Nguyen","id":"67348094597e341917026845","lastName":"Van"},{"_id":"673480d6597e341917026dea","email":"truanv@gmail","firstName":"Trua Nguyen","id":"673480d6597e341917026dea","lastName":"Van"},{"_id":"67348265597e34191702904f","email":"truanv@gmail","firstName":"Trua Nguyen ","id":"67348265597e34191702904f","lastName":"Van"},{"_id":"6734827b597e34191702923d","email":"truanv@gmail","firstName":"Trua Nguyen ","id":"6734827b597e34191702923d","lastName":"Van"}]`), + }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := TransformResponse(tt.args.data) diff --git a/generator/ast/dmmf/dmmf.go b/generator/ast/dmmf/dmmf.go index 26845738f..758ef6442 100644 --- a/generator/ast/dmmf/dmmf.go +++ b/generator/ast/dmmf/dmmf.go @@ -60,18 +60,20 @@ type Mappings struct { } type ModelOperation struct { - Model types.String `json:"model"` - Aggregate types.String `json:"aggregate"` - CreateOne types.String `json:"createOne"` - DeleteMany types.String `json:"deleteMany"` - DeleteOne types.String `json:"deleteOne"` - FindFirst types.String `json:"findFirst"` - FindMany types.String `json:"findMany"` - FindUnique types.String `json:"findUnique"` - GroupBy types.String `json:"groupBy"` - UpdateMany types.String `json:"updateMany"` - UpdateOne types.String `json:"updateOne"` - UpsertOne types.String `json:"upsertOne"` + Model types.String `json:"model"` + Aggregate types.String `json:"aggregate"` + CreateOne types.String `json:"createOne"` + DeleteMany types.String `json:"deleteMany"` + DeleteOne types.String `json:"deleteOne"` + FindFirst types.String `json:"findFirst"` + FindMany types.String `json:"findMany"` + FindUnique types.String `json:"findUnique"` + GroupBy types.String `json:"groupBy"` + UpdateMany types.String `json:"updateMany"` + UpdateOne types.String `json:"updateOne"` + UpsertOne types.String `json:"upsertOne"` + FindRaw types.String `json:"findRaw"` // MongoDB only + AggregateRaw types.String `json:"aggregateRaw"` // MongoDB only } func (m *ModelOperation) Namespace() string { diff --git a/generator/run.go b/generator/run.go index 861faf63c..f1ee161f5 100644 --- a/generator/run.go +++ b/generator/run.go @@ -87,6 +87,7 @@ func generateClient(input *Root) error { "actions/find", "actions/transaction", "actions/upsert", + "actions/raw", } var templates []*template.Template diff --git a/generator/templates/_header.gotpl b/generator/templates/_header.gotpl index 3c769e9e0..7b357afb9 100644 --- a/generator/templates/_header.gotpl +++ b/generator/templates/_header.gotpl @@ -10,6 +10,7 @@ import ( "os" "slices" "testing" + "fmt" // no-op import for go modules _ "github.com/joho/godotenv" diff --git a/generator/templates/actions/raw.gotpl b/generator/templates/actions/raw.gotpl new file mode 100644 index 000000000..07a447b5e --- /dev/null +++ b/generator/templates/actions/raw.gotpl @@ -0,0 +1,88 @@ +{{- /*gotype:github.com/steebchen/prisma-client-go/generator.Root*/ -}} + +{{ range $model := $.DMMF.Datamodel.Models }} + {{ $name := $model.Name.GoLowerCase }} + {{ $ns := (print $name "Actions") }} + {{ $result := (print $name "AggregateRaw") }} + + type {{ $result }} struct { + query builder.Query + } + + func (r {{ $result }}) getQuery() builder.Query { + return r.query + } + + func (r {{ $result }}) ExtractQuery() builder.Query { + return r.query + } + + func (r {{ $result }}) with() {} + func (r {{ $result }}) {{ $model.Name.GoLowerCase }}Model() {} + func (r {{ $result }}) {{ $model.Name.GoLowerCase }}Relation() {} + + func (r {{ $ns }}) FindRaw(filter interface{}, options ...interface{}) {{ $result }} { + var v {{ $result }} + v.query = builder.NewQuery() + v.query.Engine = r.client + v.query.Method = "findRaw" + v.query.Operation = "query" + v.query.Model = "{{ $model.Name.String }}" + + v.query.Inputs = append(v.query.Inputs, builder.Input{ + Name: "filter", + Value: fmt.Sprintf("%v", filter), + }) + + if len(options) > 0 { + v.query.Inputs = append(v.query.Inputs, builder.Input{ + Name: "options", + Value: fmt.Sprintf("%v", options[0]), + }) + } + return v + } + + func (r {{ $ns }}) AggregateRaw(pipeline []interface{}, options ...interface{}) {{ $result }} { + var v {{ $result }} + v.query = builder.NewQuery() + v.query.Engine = r.client + v.query.Method = "aggregateRaw" + v.query.Operation = "query" + v.query.Model = "{{ $model.Name.String }}" + + parsedPip := []interface{}{} + for _, p := range pipeline { + parsedPip = append(parsedPip, fmt.Sprintf("%v", p)) + } + + v.query.Inputs = append(v.query.Inputs, builder.Input{ + Name: "pipeline", + Value: parsedPip, + }) + + if len(options) > 0 { + v.query.Inputs = append(v.query.Inputs, builder.Input{ + Name: "options", + Value: fmt.Sprintf("%v", options[0]), + }) + } + return v + } + + func (r {{ $result }}) Exec(ctx context.Context) ([]{{ $model.Name.GoCase }}Model, error) { + var v []{{ $model.Name.GoCase }}Model + if err := r.query.Exec(ctx, &v); err != nil { + return nil, err + } + return v, nil + } + + func (r {{ $result }}) ExecInner(ctx context.Context) ([]Inner{{ $model.Name.GoCase }}, error) { + var v []Inner{{ $model.Name.GoCase }} + if err := r.query.Exec(ctx, &v); err != nil { + return nil, err + } + return v, nil + } +{{ end }} diff --git a/go.mod b/go.mod index 802a2635f..2d5575fb6 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + go.mongodb.org/mongo-driver/v2 v2.0.0-beta2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4fe7b3c19..73101b9e4 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.mongodb.org/mongo-driver/v2 v2.0.0-beta2 h1:PRtbRKwblE8ZfI8qOhofcjn9y8CmKZI7trS5vDMeJX0= +go.mongodb.org/mongo-driver/v2 v2.0.0-beta2/go.mod h1:UGLb3ZgEzaY0cCbJpH9UFt9B6gEXiTPzsnJS38nBeoU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/runtime/builder/builder.go b/runtime/builder/builder.go index af5d82e9c..4687cbc0b 100644 --- a/runtime/builder/builder.go +++ b/runtime/builder/builder.go @@ -12,6 +12,20 @@ import ( "github.com/steebchen/prisma-client-go/logger" ) +type MethodFormat string + +const ( + FindRaw MethodFormat = "findRaw" + AggregateRaw MethodFormat = "aggregateRaw" +) + +var ( + MethodFormatMaping = map[MethodFormat]string{ + FindRaw: "find%sRaw", // find{Model}Raw + AggregateRaw: "aggregate%sRaw", // aggregate{Model}Raw + } +) + type Input struct { Name string Fields []Field @@ -100,8 +114,14 @@ func (q Query) Build() (string, error) { func (q Query) BuildInner() (string, error) { var builder strings.Builder - - builder.WriteString(q.Method + q.Model) + switch MethodFormat(q.Method) { + case FindRaw: + builder.WriteString(fmt.Sprintf(MethodFormatMaping[FindRaw], q.Model)) + case AggregateRaw: + builder.WriteString(fmt.Sprintf(MethodFormatMaping[AggregateRaw], q.Model)) + default: + builder.WriteString(q.Method + q.Model) + } if len(q.Inputs) > 0 { str, err := q.buildInputs(q.Inputs) @@ -129,7 +149,7 @@ func (q Query) buildInputs(inputs []Input) (string, error) { builder.WriteString("(") - for _, i := range inputs { + for index, i := range inputs { builder.WriteString(i.Name) builder.WriteString(":") @@ -150,7 +170,9 @@ func (q Query) buildInputs(inputs []Input) (string, error) { } } - builder.WriteString(",") + if index < len(inputs)-1 { + builder.WriteString(",") + } } builder.WriteString(")") diff --git a/runtime/raw/raw.go b/runtime/raw/raw.go index c15526d03..236e7c702 100644 --- a/runtime/raw/raw.go +++ b/runtime/raw/raw.go @@ -49,6 +49,20 @@ func doRaw(engine engine.Engine, action string, query string, params ...interfac return q } +func doCommandRaw(engine engine.Engine, action string, cmd string) builder.Query { + q := builder.NewQuery() + q.Engine = engine + q.Operation = "mutation" + q.Method = action + + q.Inputs = append(q.Inputs, builder.Input{ + Name: "command", + Value: cmd, + }) + + return q +} + func convertType(input interface{}) string { data, err := json.Marshal(input) if err != nil { diff --git a/runtime/raw/run_command_raw.go b/runtime/raw/run_command_raw.go new file mode 100644 index 000000000..01f8897ae --- /dev/null +++ b/runtime/raw/run_command_raw.go @@ -0,0 +1,37 @@ +package raw + +import ( + "context" + "fmt" + + "github.com/steebchen/prisma-client-go/runtime/builder" +) + +func (r Raw) RunCommandRaw(cmd interface{}) RunCommandExec { + return RunCommandExec{ + query: doCommandRaw(r.Engine, "runCommandRaw", fmt.Sprintf("%v", cmd)), + } +} + +type RunCommandExec struct { + query builder.Query +} + +func (r RunCommandExec) ExtractQuery() builder.Query { + return r.query +} + +func (r RunCommandExec) Tx() TxQueryResult { + v := NewTxQueryResult() + v.query = r.query + v.query.TxResult = make(chan []byte, 1) + return v +} + +func (r RunCommandExec) Exec(ctx context.Context, into interface{}) error { + if err := r.query.Exec(ctx, &into); err != nil { + return err + } + + return nil +}