Skip to content

Commit

Permalink
feat(mongodb): add findRaw, aggregateRaw and runCommandRaw (#1410)
Browse files Browse the repository at this point in the history
Co-authored-by: Luca Steeb <[email protected]>
  • Loading branch information
truanguyenvan and steebchen authored Nov 24, 2024
1 parent bbe8959 commit 6b25039
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 24 deletions.
61 changes: 54 additions & 7 deletions engine/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package engine
import (
"encoding/json"
"strings"

"go.mongodb.org/mongo-driver/v2/bson"
)

type Input struct {
Expand All @@ -11,13 +13,7 @@ type Input struct {
Rows [][]interface{} `json:"rows"`
}

// TransformResponse for raw queries
func TransformResponse(data []byte) ([]byte, error) {
// TODO properly detect a json response
if !strings.HasPrefix(string(data), `{"columns":[`) {
return data, nil
}

func TransformSQLResponse(data []byte) ([]byte, error) {
var input Input
err := json.Unmarshal(data, &input)
if err != nil {
Expand All @@ -41,3 +37,54 @@ func TransformResponse(data []byte) ([]byte, error) {

return o, nil
}

func TransformMongoResponse(data []byte) ([]byte, error) {
var result []map[string]interface{}

if err := bson.UnmarshalExtJSON(data, false, &result); err != nil {
return nil, err
}

for _, doc := range result {
if doc["id"] == nil {
doc["id"] = doc["_id"]
}
}

o, err := json.Marshal(result)
if err != nil {
return nil, err
}

return o, nil
}

// TransformResponse for raw queries
func TransformResponse(data []byte) ([]byte, error) {
// TODO properly detect a json response
switch {
case strings.HasPrefix(string(data), `{"columns":[`):
return TransformSQLResponse(data)

// https://github.com/mongodb/mongo-go-driver/blob/91abd887f6b44ab56f47e58430f57b1be1996ceb/bson/extjson_wrappers.go#L18
case strings.Contains(string(data), `{"$oid":`),
strings.Contains(string(data), `{"$date":`),
strings.Contains(string(data), `{"$numberInt":`),
strings.Contains(string(data), `{"$numberLong":`),
strings.Contains(string(data), `{"$symbol":`),
strings.Contains(string(data), `{"$numberDouble":`),
strings.Contains(string(data), `{"$numberDecimal":`),
strings.Contains(string(data), `{"$binary":`),
strings.Contains(string(data), `{"$code":`),
strings.Contains(string(data), `{"$scope":`),
strings.Contains(string(data), `{"$timestamp":`),
strings.Contains(string(data), `{"$regularExpression":`),
strings.Contains(string(data), `{"$dbPointer":`),
strings.Contains(string(data), `{"$minKey":`),
strings.Contains(string(data), `{"$maxKey":`),
strings.Contains(string(data), `{"$undefined":`):
return TransformMongoResponse(data)
}

return data, nil
}
9 changes: 8 additions & 1 deletion engine/transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ func Test_transformResponse(t *testing.T) {
data: []byte(`{"columns":["id","email","username","str","strOpt","date","dateOpt","int","intOpt","float","floatOpt","bool","boolOpt"],"types":["string","string","string","string","string","datetime","datetime","int","int","double","double","int","int"],"rows":[["id1","email1","a","str","strOpt","2020-01-01T00:00:00+00:00","2020-01-01T00:00:00+00:00",5,5,5.5,5.5,1,0],["id2","email2","b","str","strOpt","2020-01-01T00:00:00+00:00","2020-01-01T00:00:00+00:00",5,5,5.5,5.5,1,0]]}`),
},
want: []byte(`[{"bool":1,"boolOpt":0,"date":"2020-01-01T00:00:00+00:00","dateOpt":"2020-01-01T00:00:00+00:00","email":"email1","float":5.5,"floatOpt":5.5,"id":"id1","int":5,"intOpt":5,"str":"str","strOpt":"strOpt","username":"a"},{"bool":1,"boolOpt":0,"date":"2020-01-01T00:00:00+00:00","dateOpt":"2020-01-01T00:00:00+00:00","email":"email2","float":5.5,"floatOpt":5.5,"id":"id2","int":5,"intOpt":5,"str":"str","strOpt":"strOpt","username":"b"}]`),
}}
},
{
name: "transform mongo raw response",
args: args{
data: []byte(`[{"_id":{"$oid":"67347ee4a18fa09750c1085a"},"createdAt":{"$date":"2024-11-13T10:26:44.246Z"},"firstName":"Trua Nguyen","lastName":"Van","email":"truanv@gmail"},{"_id":{"$oid":"67348094597e341917026845"},"email":"truanv@gmail","firstName":"Trua Nguyen","lastName":"Van"},{"_id":{"$oid":"673480d6597e341917026dea"},"email":"truanv@gmail","firstName":"Trua Nguyen","lastName":"Van"},{"_id":{"$oid":"67348265597e34191702904f"},"firstName":"Trua Nguyen ","lastName":"Van","email":"truanv@gmail"},{"_id":{"$oid":"6734827b597e34191702923d"},"email":"truanv@gmail","firstName":"Trua Nguyen ","lastName":"Van"}]`),
},
want: []byte(`[{"_id":"67347ee4a18fa09750c1085a","createdAt":"2024-11-13T10:26:44.246Z","email":"truanv@gmail","firstName":"Trua Nguyen","id":"67347ee4a18fa09750c1085a","lastName":"Van"},{"_id":"67348094597e341917026845","email":"truanv@gmail","firstName":"Trua Nguyen","id":"67348094597e341917026845","lastName":"Van"},{"_id":"673480d6597e341917026dea","email":"truanv@gmail","firstName":"Trua Nguyen","id":"673480d6597e341917026dea","lastName":"Van"},{"_id":"67348265597e34191702904f","email":"truanv@gmail","firstName":"Trua Nguyen ","id":"67348265597e34191702904f","lastName":"Van"},{"_id":"6734827b597e34191702923d","email":"truanv@gmail","firstName":"Trua Nguyen ","id":"6734827b597e34191702923d","lastName":"Van"}]`),
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := TransformResponse(tt.args.data)
Expand Down
26 changes: 14 additions & 12 deletions generator/ast/dmmf/dmmf.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,20 @@ type Mappings struct {
}

type ModelOperation struct {
Model types.String `json:"model"`
Aggregate types.String `json:"aggregate"`
CreateOne types.String `json:"createOne"`
DeleteMany types.String `json:"deleteMany"`
DeleteOne types.String `json:"deleteOne"`
FindFirst types.String `json:"findFirst"`
FindMany types.String `json:"findMany"`
FindUnique types.String `json:"findUnique"`
GroupBy types.String `json:"groupBy"`
UpdateMany types.String `json:"updateMany"`
UpdateOne types.String `json:"updateOne"`
UpsertOne types.String `json:"upsertOne"`
Model types.String `json:"model"`
Aggregate types.String `json:"aggregate"`
CreateOne types.String `json:"createOne"`
DeleteMany types.String `json:"deleteMany"`
DeleteOne types.String `json:"deleteOne"`
FindFirst types.String `json:"findFirst"`
FindMany types.String `json:"findMany"`
FindUnique types.String `json:"findUnique"`
GroupBy types.String `json:"groupBy"`
UpdateMany types.String `json:"updateMany"`
UpdateOne types.String `json:"updateOne"`
UpsertOne types.String `json:"upsertOne"`
FindRaw types.String `json:"findRaw"` // MongoDB only
AggregateRaw types.String `json:"aggregateRaw"` // MongoDB only
}

func (m *ModelOperation) Namespace() string {
Expand Down
1 change: 1 addition & 0 deletions generator/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ func generateClient(input *Root) error {
"actions/find",
"actions/transaction",
"actions/upsert",
"actions/raw",
}

var templates []*template.Template
Expand Down
1 change: 1 addition & 0 deletions generator/templates/_header.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"slices"
"testing"
"fmt"

// no-op import for go modules
_ "github.com/joho/godotenv"
Expand Down
88 changes: 88 additions & 0 deletions generator/templates/actions/raw.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{{- /*gotype:github.com/steebchen/prisma-client-go/generator.Root*/ -}}

{{ range $model := $.DMMF.Datamodel.Models }}
{{ $name := $model.Name.GoLowerCase }}
{{ $ns := (print $name "Actions") }}
{{ $result := (print $name "AggregateRaw") }}

type {{ $result }} struct {
query builder.Query
}

func (r {{ $result }}) getQuery() builder.Query {
return r.query
}

func (r {{ $result }}) ExtractQuery() builder.Query {
return r.query
}

func (r {{ $result }}) with() {}
func (r {{ $result }}) {{ $model.Name.GoLowerCase }}Model() {}
func (r {{ $result }}) {{ $model.Name.GoLowerCase }}Relation() {}

func (r {{ $ns }}) FindRaw(filter interface{}, options ...interface{}) {{ $result }} {
var v {{ $result }}
v.query = builder.NewQuery()
v.query.Engine = r.client
v.query.Method = "findRaw"
v.query.Operation = "query"
v.query.Model = "{{ $model.Name.String }}"

v.query.Inputs = append(v.query.Inputs, builder.Input{
Name: "filter",
Value: fmt.Sprintf("%v", filter),
})

if len(options) > 0 {
v.query.Inputs = append(v.query.Inputs, builder.Input{
Name: "options",
Value: fmt.Sprintf("%v", options[0]),
})
}
return v
}

func (r {{ $ns }}) AggregateRaw(pipeline []interface{}, options ...interface{}) {{ $result }} {
var v {{ $result }}
v.query = builder.NewQuery()
v.query.Engine = r.client
v.query.Method = "aggregateRaw"
v.query.Operation = "query"
v.query.Model = "{{ $model.Name.String }}"

parsedPip := []interface{}{}
for _, p := range pipeline {
parsedPip = append(parsedPip, fmt.Sprintf("%v", p))
}

v.query.Inputs = append(v.query.Inputs, builder.Input{
Name: "pipeline",
Value: parsedPip,
})

if len(options) > 0 {
v.query.Inputs = append(v.query.Inputs, builder.Input{
Name: "options",
Value: fmt.Sprintf("%v", options[0]),
})
}
return v
}

func (r {{ $result }}) Exec(ctx context.Context) ([]{{ $model.Name.GoCase }}Model, error) {
var v []{{ $model.Name.GoCase }}Model
if err := r.query.Exec(ctx, &v); err != nil {
return nil, err
}
return v, nil
}

func (r {{ $result }}) ExecInner(ctx context.Context) ([]Inner{{ $model.Name.GoCase }}, error) {
var v []Inner{{ $model.Name.GoCase }}
if err := r.query.Exec(ctx, &v); err != nil {
return nil, err
}
return v, nil
}
{{ end }}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
go.mongodb.org/mongo-driver/v2 v2.0.0-beta2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.mongodb.org/mongo-driver/v2 v2.0.0-beta2 h1:PRtbRKwblE8ZfI8qOhofcjn9y8CmKZI7trS5vDMeJX0=
go.mongodb.org/mongo-driver/v2 v2.0.0-beta2/go.mod h1:UGLb3ZgEzaY0cCbJpH9UFt9B6gEXiTPzsnJS38nBeoU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
30 changes: 26 additions & 4 deletions runtime/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ import (
"github.com/steebchen/prisma-client-go/logger"
)

type MethodFormat string

const (
FindRaw MethodFormat = "findRaw"
AggregateRaw MethodFormat = "aggregateRaw"
)

var (
MethodFormatMaping = map[MethodFormat]string{
FindRaw: "find%sRaw", // find{Model}Raw
AggregateRaw: "aggregate%sRaw", // aggregate{Model}Raw
}
)

type Input struct {
Name string
Fields []Field
Expand Down Expand Up @@ -100,8 +114,14 @@ func (q Query) Build() (string, error) {

func (q Query) BuildInner() (string, error) {
var builder strings.Builder

builder.WriteString(q.Method + q.Model)
switch MethodFormat(q.Method) {
case FindRaw:
builder.WriteString(fmt.Sprintf(MethodFormatMaping[FindRaw], q.Model))
case AggregateRaw:
builder.WriteString(fmt.Sprintf(MethodFormatMaping[AggregateRaw], q.Model))
default:
builder.WriteString(q.Method + q.Model)
}

if len(q.Inputs) > 0 {
str, err := q.buildInputs(q.Inputs)
Expand Down Expand Up @@ -129,7 +149,7 @@ func (q Query) buildInputs(inputs []Input) (string, error) {

builder.WriteString("(")

for _, i := range inputs {
for index, i := range inputs {
builder.WriteString(i.Name)

builder.WriteString(":")
Expand All @@ -150,7 +170,9 @@ func (q Query) buildInputs(inputs []Input) (string, error) {
}
}

builder.WriteString(",")
if index < len(inputs)-1 {
builder.WriteString(",")
}
}

builder.WriteString(")")
Expand Down
14 changes: 14 additions & 0 deletions runtime/raw/raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ func doRaw(engine engine.Engine, action string, query string, params ...interfac
return q
}

func doCommandRaw(engine engine.Engine, action string, cmd string) builder.Query {
q := builder.NewQuery()
q.Engine = engine
q.Operation = "mutation"
q.Method = action

q.Inputs = append(q.Inputs, builder.Input{
Name: "command",
Value: cmd,
})

return q
}

func convertType(input interface{}) string {
data, err := json.Marshal(input)
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions runtime/raw/run_command_raw.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package raw

import (
"context"
"fmt"

"github.com/steebchen/prisma-client-go/runtime/builder"
)

func (r Raw) RunCommandRaw(cmd interface{}) RunCommandExec {
return RunCommandExec{
query: doCommandRaw(r.Engine, "runCommandRaw", fmt.Sprintf("%v", cmd)),
}
}

type RunCommandExec struct {
query builder.Query
}

func (r RunCommandExec) ExtractQuery() builder.Query {
return r.query
}

func (r RunCommandExec) Tx() TxQueryResult {
v := NewTxQueryResult()
v.query = r.query
v.query.TxResult = make(chan []byte, 1)
return v
}

func (r RunCommandExec) Exec(ctx context.Context, into interface{}) error {
if err := r.query.Exec(ctx, &into); err != nil {
return err
}

return nil
}

0 comments on commit 6b25039

Please sign in to comment.