diff --git a/internal/mopertest/aggregation_test.go b/internal/mopertest/aggregation_test.go index 342b39a..736c464 100644 --- a/internal/mopertest/aggregation_test.go +++ b/internal/mopertest/aggregation_test.go @@ -2,55 +2,75 @@ package mopertest import ( "context" + "encoding/json" "fmt" "testing" "github.com/func25/mongofunc/mocom" "github.com/func25/mongofunc/moper" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) -func TestAggregationTest(t *testing.T) { - matchStage := bson.D{ - { - Key: "$match", - Value: moper.D{}.InEll("damage", 1, 2, 3, 4, 5, 6, 7, 8, 9), - }, +func TestAggregation(t *testing.T) { + intArr := []int{1, 2, 3, 4, 5, 6, 7, 8, 9} + + matchStage := moper.D{}.MatchD(moper.D{}.InArray("damage", intArr)) + groupStage := moper.D{}.Group( + moper.P{K: "_id", V: nil}, + moper.P{K: "total", V: moper.D{}.Sum("damage")}, + ) + + req := &mocom.AggregationRequest[Hero]{ + Pipeline: []moper.D{matchStage, groupStage}, + Options: []*options.AggregateOptions{}, + } + result, err := mocom.Aggregate(context.Background(), req) + if err != nil { + t.Error(err) + return } - groupStage := bson.D{ - { - Key: "$group", - Value: bson.D{ - { - Key: "_id", - Value: nil, - }, - { - Key: "total", - Value: bson.D{ - { - Key: "$sum", - Value: "$damage", - }, - }, - }, - }, - }, + expect := 0 + for _, v := range intArr { + expect += v * v } + if int(result[0]["total"].(int32)) != expect { + t.Error("wrong result", result[0]["total"], expect) + } +} + +func TestLookup(t *testing.T) { + intArr := []int{1} + matchStage := moper.D{}.MatchD(moper.D{}.InArray("damage", intArr)) + + lookupStage := moper.D{}.LookUp(). + From(Weapon{}.CollName()). + LocalField("damage"). + ForeignField("damage"). + As("weapon") + + unwindStage := moper.D{}.Equal("$unwind", moper.D{}.Equal("path", "$weapon").Equal("preserveNullAndEmptyArrays", false)) + req := &mocom.AggregationRequest[Hero]{ - Pipeline: mongo.Pipeline{matchStage, groupStage}, + Pipeline: []moper.D{matchStage, lookupStage.D(), unwindStage}, Options: []*options.AggregateOptions{}, - Result: []primitive.M{}, } - if err := mocom.Aggregate(context.Background(), req); err != nil { + result, err := mocom.Aggregate(context.Background(), req) + if err != nil { t.Error(err) return } - fmt.Println(req.Result) + x, err := json.Marshal(result) + fmt.Println(string(x)) + + // expect := 0 + // for _, v := range intArr { + // expect += v * v + // } + + // if result[0]["total"] != expect { + // t.Error("wrong result", result[0]["total"], expect) + // } } diff --git a/internal/mopertest/element_test.go b/internal/mopertest/element_test.go index a8a2a31..91ebf50 100644 --- a/internal/mopertest/element_test.go +++ b/internal/mopertest/element_test.go @@ -10,7 +10,6 @@ import ( func TestExists(t *testing.T) { ctx := context.Background() - filter := moper.D{}.Exists("omit", true) if count, err := mocom.Count[Hero](ctx, filter); err != nil { t.Error("[TestExists]", err) diff --git a/internal/mopertest/seed.go b/internal/mopertest/seed.go index 7c8aae5..19404d2 100644 --- a/internal/mopertest/seed.go +++ b/internal/mopertest/seed.go @@ -6,18 +6,28 @@ import ( "strconv" "github.com/func25/mongofunc/mocom" - "go.mongodb.org/mongo-driver/bson/primitive" ) -type Hero struct { - mocom.ID[primitive.ObjectID] `bson:",inline"` - Name string `bson:"name"` - Damage int `bson:"damage"` - SkillIds []int `bson:"skillIds"` - Omit bool `bson:"omit,omitempty"` +// weapon +type Weapon struct { + mocom.ID `bson:",inline"` + Type int `json:"type" bson:"type"` + Damage int `json:"damage" bson:"damage"` +} + +func (Weapon) CollName() string { + return "Weapons" } -const COLLECTION_NAME = "Heroes" +// hero +type Hero struct { + mocom.ID `bson:",inline"` + WeaponID interface{} `bson:"weaponId"` + Name string `bson:"name"` + Damage int `bson:"damage"` + SkillIds []int `bson:"skillIds"` + Omit bool `bson:"omit,omitempty"` +} var ( ROUND = 10 @@ -50,6 +60,18 @@ func init() { //Seed create 1 hero has 1 damage, 2 heroes have 2 damages,... until n (n == 10) func Seed(ctx context.Context, n int) error { count := 0 + weapons := []*Weapon{} + for i := 0; i < 3; i++ { + x := &Weapon{ + Type: i, + Damage: 1, + } + err := mocom.CreateWithID(ctx, x) + if err != nil { + return err + } + weapons = append(weapons, x) + } for i := 0; i < n; i++ { for j := 0; j <= i; j, count = j+1, count+1 { @@ -58,6 +80,7 @@ func Seed(ctx context.Context, n int) error { Damage: i + 1, SkillIds: []int{1, 2, 3, 4, 5}, Omit: j == i, + WeaponID: weapons[j%3].ID.ID, }) if err != nil { return err @@ -70,5 +93,6 @@ func Seed(ctx context.Context, n int) error { func Clear(ctx context.Context) error { _, err := mocom.Flush[Hero](ctx) + _, err = mocom.Flush[Weapon](ctx) return err } diff --git a/mocom/create.go b/mocom/create.go index a8ba48b..f9784fa 100644 --- a/mocom/create.go +++ b/mocom/create.go @@ -6,11 +6,21 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -func Create[T MongoModel](ctx context.Context, model *T, opts ...*options.InsertOneOptions) (interface{}, error) { - col := collWrite((*model).CollName()) +func Create[T Model](ctx context.Context, model T, opts ...*options.InsertOneOptions) (interface{}, error) { + col := collWrite(model.CollName()) if result, err := col.InsertOne(ctx, model, opts...); err != nil { return nil, err } else { return result.InsertedID, nil } } + +func CreateWithID[T IDModel](ctx context.Context, model T, opts ...*options.InsertOneOptions) error { + col := collWrite(model.CollName()) + if result, err := col.InsertOne(ctx, model, opts...); err != nil { + return err + } else { + model.SetID(result.InsertedID) + return nil + } +} diff --git a/mocom/find.go b/mocom/find.go index cdd314e..533ebc5 100644 --- a/mocom/find.go +++ b/mocom/find.go @@ -6,7 +6,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -func Find[T MongoModel](ctx context.Context, filter interface{}, opts ...*options.FindOptions) (res []T, err error) { +func Find[T Model](ctx context.Context, filter interface{}, opts ...*options.FindOptions) (res []T, err error) { var t T cur, err := collRead(t.CollName()).Find(ctx, filter, opts...) if err != nil { @@ -17,7 +17,7 @@ func Find[T MongoModel](ctx context.Context, filter interface{}, opts ...*option return res, err } -func FindOne[T MongoModel](ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) (res *T, err error) { +func FindOne[T Model](ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) (res *T, err error) { res = new(T) cur := collRead((*res).CollName()).FindOne(ctx, filter, opts...) err = cur.Decode(&res) diff --git a/mocom/mocom.go b/mocom/mocom.go index 5344ba1..6525d2e 100644 --- a/mocom/mocom.go +++ b/mocom/mocom.go @@ -6,39 +6,39 @@ import ( "time" "github.com/func25/mongofunc/moper" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" ) -func Count[T MongoModel](ctx context.Context, filter interface{}, opts ...*options.CountOptions) (int64, error) { +func Count[T Model](ctx context.Context, filter interface{}, opts ...*options.CountOptions) (int64, error) { var t T return collRead(t.CollName()).CountDocuments(ctx, filter, opts...) } -func EstimatedCount[T MongoModel](ctx context.Context, opts ...*options.EstimatedDocumentCountOptions) (int64, error) { +func EstimatedCount[T Model](ctx context.Context, opts ...*options.EstimatedDocumentCountOptions) (int64, error) { var t T return collRead(t.CollName()).EstimatedDocumentCount(ctx, opts...) } -func Aggregate[T MongoModel](ctx context.Context, req *AggregationRequest[T]) error { +func Aggregate[T Model](ctx context.Context, req *AggregationRequest[T]) (res []bson.M, err error) { var t T col := db.Collection(t.CollName()) cursor, err := col.Aggregate(ctx, req.Pipeline, req.Options...) if err != nil { - return err + return nil, err } - // fmt.Println(cursor) - - err = cursor.All(ctx, &req.Result) + err = cursor.All(ctx, &res) - return err + return res, err } //Flush clears all records of collection and return number of deleted records -func Flush[T MongoModel](ctx context.Context) (int64, error) { +func Flush[T Model](ctx context.Context) (int64, error) { var t T result, err := db.Collection(t.CollName()).DeleteMany(ctx, moper.D{}) if err != nil { @@ -63,14 +63,17 @@ func Tx(ctx context.Context, cfg TransactionConfig) (interface{}, error) { return session.WithTransaction(ctx, cfg.Func, cfg.Options) } -// TxOptimal will do the transaction with majority write-concern and local read-concern, client default read pref +// TxOptimal will do the transaction with majority write-concern +// snapshot read-concern, primary read preference +// +// This should be used when transaction does not contain any read func TxOptimal(ctx context.Context, f func(ctx mongo.SessionContext) (interface{}, error)) (interface{}, error) { if client == nil { return nil, errors.New("client is nil, please using mocom to create connection to mongo server or using your own client connection") } wc := writeconcern.New(writeconcern.WMajority(), writeconcern.WTimeout(5*time.Second)) - opts := options.Transaction().SetReadConcern(readconcern.Local()).SetWriteConcern(wc) + opts := options.Transaction().SetReadConcern(readconcern.Snapshot()).SetWriteConcern(wc).SetReadPreference(readpref.Primary()) session, err := client.StartSession() if err != nil { diff --git a/mocom/model.go b/mocom/model.go index 03930ae..064dc69 100644 --- a/mocom/model.go +++ b/mocom/model.go @@ -1,22 +1,30 @@ package mocom import ( - "go.mongodb.org/mongo-driver/bson" + "github.com/func25/mongofunc/moper" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) -type MongoModel interface { +type Model interface { CollName() string } -type ID[T any] struct { - ID T `json:"id" bson:"_id,omitempty"` +type IDModel interface { + Model + SetID(t interface{}) } -type AggregationRequest[T MongoModel] struct { - Result []bson.M - Pipeline mongo.Pipeline +type ID struct { + ID interface{} `json:"id" bson:"_id,omitempty"` +} + +func (id *ID) SetID(t interface{}) { + id.ID = t +} + +type AggregationRequest[T Model] struct { + Pipeline []moper.D Options []*options.AggregateOptions } diff --git a/mocom/update.go b/mocom/update.go index e321e7b..ccfea15 100644 --- a/mocom/update.go +++ b/mocom/update.go @@ -7,12 +7,12 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -func UpdateOne[T MongoModel](ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { +func UpdateOne[T Model](ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { var t T return collWrite(t.CollName()).UpdateOne(ctx, filter, update, opts...) } -func UpdateAndReturn[T MongoModel](ctx context.Context, filter interface{}, update interface{}, opts ...*options.FindOneAndUpdateOptions) (ptrT *T, err error) { +func UpdateAndReturn[T Model](ctx context.Context, filter interface{}, update interface{}, opts ...*options.FindOneAndUpdateOptions) (ptrT *T, err error) { ptrT = new(T) res := collWrite((*ptrT).CollName()).FindOneAndUpdate(ctx, filter, update, opts...) @@ -24,7 +24,7 @@ func UpdateAndReturn[T MongoModel](ctx context.Context, filter interface{}, upda return } -func UpdateMany[T MongoModel](ctx context.Context, filter interface{}, update interface{}) (*mongo.UpdateResult, error) { +func UpdateMany[T Model](ctx context.Context, filter interface{}, update interface{}) (*mongo.UpdateResult, error) { var t T return collWrite(t.CollName()).UpdateMany(ctx, filter, update) } diff --git a/moper/aggregate.go b/moper/aggregate.go new file mode 100644 index 0000000..1c766b0 --- /dev/null +++ b/moper/aggregate.go @@ -0,0 +1,21 @@ +package moper + +func (d D) Match(pairs ...P) D { + return d.Equal("$match", toPair(pairs)) +} + +func (d D) MatchD(pair D) D { + return d.Equal("$match", pair) +} + +func (d D) Group(pairs ...P) D { + return d.Equal("$group", toPair(pairs)) +} + +func (d D) GroupD(pair D) D { + return d.Equal("$group", pair) +} + +func (d D) Sum(fieldName string) D { + return d.Equal("$sum", "$"+fieldName) +} diff --git a/moper/aggregate_lookup.go b/moper/aggregate_lookup.go new file mode 100644 index 0000000..dab9128 --- /dev/null +++ b/moper/aggregate_lookup.go @@ -0,0 +1,50 @@ +package moper + +type LU struct { + d D +} + +// Specifies the foreign collection in the same database to join to the local collection. +// The foreign collection cannot be sharded +func (l *LU) From(collName string) *LU { + l.d = l.d.Equal("from", collName) + return l +} + +// Specifies the local documents' localField to perform an equality match with the foreign documents' foreignField. +func (l *LU) LocalField(field string) *LU { + l.d = l.d.Equal("localField", field) + return l +} + +// Specifies the foreign documents' foreignField to perform an equality match with the local documents' localField. +// If a foreign document does not contain a foreignField value, the $lookup uses a null value for the match. +func (l *LU) ForeignField(field string) *LU { + l.d = l.d.Equal("foreignField", field) + return l +} + +func (l *LU) Custom(p P) *LU { + l.d = l.d.Equal(p.K, p.V) + return l +} + +func (l *LU) D() D { + return D{}.Equal("$lookup", l.d) +} + +// Optional, Specifies the variables to use in the pipeline stages. +// Use the variable expressions to access the document fields that are input to the pipeline. +// func (l *LookUp) Let(collName string) *LookUp { +// l.d = l.d.Equal("let", collName) +// return l +// } + +func (l *LU) As(field string) *LU { + l.d = l.d.Equal("as", field) + return l +} + +func (d D) LookUp() *LU { + return &LU{d: d} +}