diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 00188b9921c..b301722667f 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -131,6 +131,9 @@ type Generate struct { // will be stored as a list within the PlanValue. New // values will be generated based on how many were not // supplied (NULL). + + // Need to distiguish between sequence or snowflake type + Type string Values sqltypes.PlanValue } @@ -289,6 +292,27 @@ func shouldGenerate(v sqltypes.Value) bool { return false } +const ( + TimestampLength uint8 = 41 + MachineIDLength uint8 = 10 + SequenceLength uint8 = 12 + MaxSequence int64 = 1<> int64(SequenceLength+MachineIDLength)) + SnowflakeStartTime.UTC().UnixNano()/1e6 + sequence := cur & int64(MaxSequence) + machineID := (cur & (int64(MaxMachineID) << SequenceLength)) >> SequenceLength + for i, v := range resolved { + fmt.Println(fmt.Sprintf("Generating Snowflake, %s id %d", ins.GetTableName(), cur)) + if shouldGenerate(v) { + bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.Int64BindVariable(cur) + // calculate next id and advance ts and sequence + totalInc := sequence + 1 + ts := ts + totalInc/MaxSequence + sequence = totalInc % MaxSequence + // TODO: generate next id properly for snowflake + df := elapsedTime(ts, SnowflakeStartTime) + cur = int64((uint64(df) << uint64(timestampMoveLength)) | (uint64(machineID) << uint64(machineIDMoveLength)) | uint64(sequence)) + } else { + bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.ValueBindVariable(v) + } + } + return insertID, nil + } + // For Sequence cur := insertID for i, v := range resolved { if shouldGenerate(v) { @@ -609,8 +658,8 @@ func (ins *Insert) processUnowned(vcursor VCursor, vindexColumnsKeys [][]sqltype return nil } -//InsertVarName returns a name for the bind var for this column. This method is used by the planner and engine, -//to make sure they both produce the same names +// InsertVarName returns a name for the bind var for this column. This method is used by the planner and engine, +// to make sure they both produce the same names func InsertVarName(col sqlparser.ColIdent, rowNum int) string { return fmt.Sprintf("_%s_%d", col.CompliantName(), rowNum) } diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index d8c5d69a6a7..0a5c4a7fb01 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -29,6 +29,62 @@ import ( vschemapb "vitess.io/vitess/go/vt/proto/vschema" ) +func TestInsertUnshardedSnowflakeGenerate(t *testing.T) { + ins := NewQueryInsert( + InsertUnsharded, + &vindexes.Keyspace{ + Name: "ks", + Sharded: false, + }, + "dummy_insert", + ) + ins.Generate = &Generate{ + Keyspace: &vindexes.Keyspace{ + Name: "ks2", + Sharded: false, + }, + Query: "dummy_generate", + Type: "snowflake", + Values: sqltypes.PlanValue{ + Values: []sqltypes.PlanValue{ + {Value: sqltypes.NewInt64(1)}, + {Value: sqltypes.NULL}, + {Value: sqltypes.NewInt64(2)}, + {Value: sqltypes.NULL}, + {Value: sqltypes.NewInt64(3)}, + }, + }, + } + + vc := newDMLTestVCursor("0") + vc.results = []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "nextval", + "int64", + ), + "4", + ), + {InsertID: 1}, + } + + result, err := ins.Execute(vc, map[string]*querypb.BindVariable{}, false) + if err != nil { + t.Fatal(err) + } + vc.ExpectLog(t, []string{ + // Fetch two sequence value. + `ResolveDestinations ks2 [] Destinations:DestinationAnyShard()`, + `ExecuteStandalone dummy_generate n: type:INT64 value:"2" ks2 0`, + // Fill those values into the insert. + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: dummy_insert {__seq0: type:INT64 value:"1" __seq1: type:INT64 value:"4" __seq2: type:INT64 value:"2" __seq3: type:INT64 value:"5" __seq4: type:INT64 value:"3"} true true`, + }) + + // The insert id returned by ExecuteMultiShard should be overwritten by processGenerate. + expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 4}) +} + func TestInsertUnsharded(t *testing.T) { ins := NewQueryInsert( InsertUnsharded, diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index ab08f72951e..cda6c2113ed 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -249,7 +249,7 @@ func Exists(m Match, p Primitive) bool { return Find(m, p) != nil } -//MarshalJSON serializes the plan into a JSON representation. +// MarshalJSON serializes the plan into a JSON representation. func (p *Plan) MarshalJSON() ([]byte, error) { var instructions *PrimitiveDescription if p.Instructions != nil { diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 2364944d316..52c21a90c14 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -178,7 +178,7 @@ const ( // SelectScatter is for routing a scatter query // to all shards of a keyspace. SelectScatter - // SelectNext is for fetching from a sequence. + // SelectNext is for fetching from a sequence or snowflake. SelectNext // SelectDBA is for executing a DBA statement. SelectDBA diff --git a/go/vt/vtgate/planbuilder/from.go b/go/vt/vtgate/planbuilder/from.go index 54fb99a2852..9123267e85b 100644 --- a/go/vt/vtgate/planbuilder/from.go +++ b/go/vt/vtgate/planbuilder/from.go @@ -255,6 +255,9 @@ func (pb *primitiveBuilder) buildTablePrimitive(tableExpr *sqlparser.AliasedTabl switch { case vschemaTable.Type == vindexes.TypeSequence: eroute = engine.NewSimpleRoute(engine.SelectNext, vschemaTable.Keyspace) + // TODO: snowflake + case vschemaTable.Type == vindexes.TypeSnowflake: + eroute = engine.NewSimpleRoute(engine.SelectNext, vschemaTable.Keyspace) case vschemaTable.Type == vindexes.TypeReference: eroute = engine.NewSimpleRoute(engine.SelectReference, vschemaTable.Keyspace) case !vschemaTable.Keyspace.Sharded: diff --git a/go/vt/vtgate/planbuilder/insert.go b/go/vt/vtgate/planbuilder/insert.go index f0af6ae618e..3427b50d631 100644 --- a/go/vt/vtgate/planbuilder/insert.go +++ b/go/vt/vtgate/planbuilder/insert.go @@ -239,9 +239,11 @@ func modifyForAutoinc(ins *sqlparser.Insert, eins *engine.Insert) error { row[colNum] = sqlparser.NewArgument(engine.SeqVarName + strconv.Itoa(rowNum)) } + // TODO: Here query is generated for Snoflake eins.Generate = &engine.Generate{ Keyspace: eins.Table.AutoIncrement.Sequence.Keyspace, Query: fmt.Sprintf("select next :n values from %s", sqlparser.String(eins.Table.AutoIncrement.Sequence.Name)), + Type: eins.Table.AutoIncrement.Sequence.Type, Values: autoIncValues, } return nil diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index dc687c1b828..17300872f77 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -425,10 +425,10 @@ type ( ) /* - The greedy planner will plan a query by finding first finding the best route plan for every table. - Then, iteratively, it finds the cheapest join that can be produced between the remaining plans, - and removes the two inputs to this cheapest plan and instead adds the join. - As an optimization, it first only considers joining tables that have predicates defined between them + The greedy planner will plan a query by finding first finding the best route plan for every table. + Then, iteratively, it finds the cheapest join that can be produced between the remaining plans, + and removes the two inputs to this cheapest plan and instead adds the join. + As an optimization, it first only considers joining tables that have predicates defined between them */ func greedySolve(qg *abstract.QueryGraph, semTable *semantics.SemTable, vschema ContextVSchema) (joinTree, error) { joinTrees, err := seedPlanList(qg, semTable, vschema) @@ -604,6 +604,9 @@ func createRoutePlan(table *abstract.QueryTable, solves semantics.TableSet, vsch switch { case vschemaTable.Type == vindexes.TypeSequence: plan.routeOpCode = engine.SelectNext + // TODO: Snowflake + case vschemaTable.Type == vindexes.TypeSnowflake: + plan.routeOpCode = engine.SelectNext case vschemaTable.Type == vindexes.TypeReference: plan.routeOpCode = engine.SelectReference case !vschemaTable.Keyspace.Sharded: diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index ec07b042480..f95e3376b09 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -51,6 +51,7 @@ var TabletTypeSuffix = map[topodatapb.TabletType]string{ // The following constants represent table types. const ( TypeSequence = "sequence" + TypeSnowflake = "snowflake" TypeReference = "reference" ) @@ -156,6 +157,7 @@ func (ks *KeyspaceSchema) MarshalJSON() ([]byte, error) { } // AutoIncrement contains the auto-inc information for a table. +// TODO: We reuse same field for Snowflake and Sequence tables. type AutoIncrement struct { Column sqlparser.ColIdent `json:"column"` Sequence *Table `json:"sequence"` @@ -177,7 +179,7 @@ func BuildVSchema(source *vschemapb.SrvVSchema) (vschema *VSchema) { } // BuildKeyspaceSchema builds the vschema portion for one keyspace. -// The build ignores sequence references because those dependencies can +// The build ignores sequence/snowflake references because those dependencies can // go cross-keyspace. func BuildKeyspaceSchema(input *vschemapb.Keyspace, keyspace string) (*KeyspaceSchema, error) { if input == nil { @@ -199,7 +201,7 @@ func BuildKeyspaceSchema(input *vschemapb.Keyspace, keyspace string) (*KeyspaceS } // ValidateKeyspace ensures that the keyspace vschema is valid. -// External references (like sequence) are not validated. +// External references (like sequence/snowflake) are not validated. func ValidateKeyspace(input *vschemapb.Keyspace) error { _, err := BuildKeyspaceSchema(input, "") return err @@ -252,6 +254,11 @@ func buildTables(ks *vschemapb.Keyspace, vschema *VSchema, ksvschema *KeyspaceSc return fmt.Errorf("sequence table has to be in an unsharded keyspace or must be pinned: %s", tname) } t.Type = table.Type + case TypeSnowflake: + if keyspace.Sharded && table.Pinned == "" { + return fmt.Errorf("snowflake table has to be in an unsharded keyspace or must be pinned: %s", tname) + } + t.Type = table.Type default: return fmt.Errorf("unidentified table type %s", table.Type) } @@ -354,6 +361,7 @@ func resolveAutoIncrement(source *vschemapb.SrvVSchema, vschema *VSchema) { if t == nil || table.AutoIncrement == nil { continue } + // TODO: Should I check for type here seqks, seqtab, err := sqlparser.ParseTable(table.AutoIncrement.Sequence) var seq *Table if err == nil { diff --git a/go/vt/vttablet/tabletserver/planbuilder/builder.go b/go/vt/vttablet/tabletserver/planbuilder/builder.go index 528825ef3d2..f7f7f8c7d21 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/builder.go +++ b/go/vt/vttablet/tabletserver/planbuilder/builder.go @@ -17,6 +17,7 @@ limitations under the License. package planbuilder import ( + "fmt" "strings" "vitess.io/vitess/go/vt/sqlparser" @@ -44,17 +45,32 @@ func analyzeSelect(sel *sqlparser.Select, tables map[string]*schema.Table) (plan // Check if it's a NEXT VALUE statement. if nextVal, ok := sel.SelectExprs[0].(*sqlparser.Nextval); ok { - if plan.Table == nil || plan.Table.Type != schema.Sequence { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%s is not a sequence", sqlparser.ToString(sel.From)) + if plan.Table == nil || (plan.Table.Type != schema.Sequence && plan.Table.Type != schema.Snowflake) { + fmt.Println("plan.Table.Type", plan.Table.Type, schema.Snowflake) + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%s is not a sequence or snowflake", sqlparser.ToString(sel.From)) } - plan.PlanID = PlanNextval - v, err := sqlparser.NewPlanValue(nextVal.Expr) - if err != nil { - return nil, err + + switch plan.Table.Type { + case schema.Sequence: + plan.PlanID = PlanNextval + v, err := sqlparser.NewPlanValue(nextVal.Expr) + if err != nil { + return nil, err + } + plan.NextCount = v + plan.FieldQuery = nil + plan.FullQuery = nil + case schema.Snowflake: + // should be different from sequence? + plan.PlanID = PlanNextval + v, err := sqlparser.NewPlanValue(nextVal.Expr) + if err != nil { + return nil, err + } + plan.NextCount = v + plan.FieldQuery = nil + plan.FullQuery = nil } - plan.NextCount = v - plan.FieldQuery = nil - plan.FullQuery = nil } return plan, nil } diff --git a/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt b/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt index e7ec2cd817d..795a0532f4f 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt +++ b/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt @@ -139,6 +139,52 @@ "FullQuery": "select :bv from a where 1 != 1 limit :#maxLimit" } +# single value snowflake +"select next value from snow" +{ + "PlanID": "Nextval", + "TableName": "snow", + "Permissions": [ + { + "TableName": "snow", + "Role": 0 + } + ], + "NextCount": "1" +} + +# snowflake with number +"select next 10 values from snow" +{ + "PlanID": "Nextval", + "TableName": "snow", + "Permissions": [ + { + "TableName": "snow", + "Role": 0 + } + ], + "NextCount": "10" +} + +# snowflake with bindvar +"select next :a values from snow" +{ + "PlanID": "Nextval", + "TableName": "snow", + "Permissions": [ + { + "TableName": "snow", + "Role": 0 + } + ], + "NextCount": "\":a\"" +} + +# snowflake with bad value +"select next 12345667852342342342323423423 values from snow" +"strconv.ParseUint: parsing "12345667852342342342323423423": value out of range" + # single value sequence "select next value from seq" { @@ -188,11 +234,11 @@ # nextval on non-sequence table "select next value from a" -"a is not a sequence" +"a is not a sequence or snowflake" # nextval on non-existent table "select next value from id" -"id is not a sequence" +"id is not a sequence or snowflake" # for update "select eid from a for update" diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index c4c43d6c79e..58b7af51cb1 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -91,6 +91,13 @@ var sequenceFields = []*querypb.Field{ }, } +var snowflakeFields = []*querypb.Field{ + { + Name: "nextval", + Type: sqltypes.Int64, + }, +} + func (qre *QueryExecutor) shouldConsolidate() bool { cm := qre.tsv.qe.consolidatorMode.Get() return cm == tabletenv.Enable || (cm == tabletenv.NotOnMaster && qre.tabletType != topodatapb.TabletType_MASTER) @@ -519,70 +526,115 @@ func (*QueryExecutor) BeginAgain(ctx context.Context, dc *StatefulConnection) er return nil } +// currentMillis get current millisecond. +func currentMillis() int64 { + return time.Now().UTC().UnixNano() / 1e6 +} + func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { inc, err := resolveNumber(qre.plan.NextCount, qre.bindVars) if err != nil { return nil, err } tableName := qre.plan.TableName() - if inc < 1 { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid increment for sequence %s: %d", tableName, inc) - } + // check if snowflake - t := qre.plan.Table - t.SequenceInfo.Lock() - defer t.SequenceInfo.Unlock() - if t.SequenceInfo.NextVal == 0 || t.SequenceInfo.NextVal+inc > t.SequenceInfo.LastVal { - _, err := qre.execAsTransaction(func(conn *StatefulConnection) (*sqltypes.Result, error) { - query := fmt.Sprintf("select next_id, cache from %s where id = 0 for update", sqlparser.String(tableName)) - qr, err := qre.execStatefulConn(conn, query, false) + if inc < 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid increment for sequence or snowflake %s: %d", tableName, inc) + } + + var ret int64 + if qre.plan.Table.SequenceInfo != nil { + t := qre.plan.Table + t.SequenceInfo.Lock() + defer t.SequenceInfo.Unlock() + if t.SequenceInfo.NextVal == 0 || t.SequenceInfo.NextVal+inc > t.SequenceInfo.LastVal { + _, err := qre.execAsTransaction(func(conn *StatefulConnection) (*sqltypes.Result, error) { + query := fmt.Sprintf("select next_id, cache from %s where id = 0 for update", sqlparser.String(tableName)) + qr, err := qre.execStatefulConn(conn, query, false) + if err != nil { + return nil, err + } + if len(qr.Rows) != 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected rows from reading sequence %s (possible mis-route): %d", tableName, len(qr.Rows)) + } + nextID, err := evalengine.ToInt64(qr.Rows[0][0]) + if err != nil { + return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) + } + // If LastVal does not match next ID, then either: + // VTTablet just started, and we're initializing the cache, or + // Someone reset the id underneath us. + if t.SequenceInfo.LastVal != nextID { + if nextID < t.SequenceInfo.LastVal { + log.Warningf("Sequence next ID value %v is below the currently cached max %v, updating it to max", nextID, t.SequenceInfo.LastVal) + nextID = t.SequenceInfo.LastVal + } + t.SequenceInfo.NextVal = nextID + t.SequenceInfo.LastVal = nextID + } + cache, err := evalengine.ToInt64(qr.Rows[0][1]) + if err != nil { + return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) + } + if cache < 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid cache value for sequence %s: %d", tableName, cache) + } + newLast := nextID + cache + for newLast < t.SequenceInfo.NextVal+inc { + newLast += cache + } + query = fmt.Sprintf("update %s set next_id = %d where id = 0", sqlparser.String(tableName), newLast) + conn.TxProperties().RecordQuery(query) + _, err = qre.execStatefulConn(conn, query, false) + if err != nil { + return nil, err + } + t.SequenceInfo.LastVal = newLast + return nil, nil + }) if err != nil { return nil, err } - if len(qr.Rows) != 1 { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected rows from reading sequence %s (possible mis-route): %d", tableName, len(qr.Rows)) - } - nextID, err := evalengine.ToInt64(qr.Rows[0][0]) - if err != nil { - return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) - } - // If LastVal does not match next ID, then either: - // VTTablet just started, and we're initializing the cache, or - // Someone reset the id underneath us. - if t.SequenceInfo.LastVal != nextID { - if nextID < t.SequenceInfo.LastVal { - log.Warningf("Sequence next ID value %v is below the currently cached max %v, updating it to max", nextID, t.SequenceInfo.LastVal) - nextID = t.SequenceInfo.LastVal + } + ret = t.SequenceInfo.NextVal + t.SequenceInfo.NextVal += inc + + } else if qre.plan.Table.SnowflakeInfo != nil { + t := qre.plan.Table + t.SnowflakeInfo.Lock() + defer t.SnowflakeInfo.Unlock() + + // if MachineID is 0, then we need to initialize snowflake + if t.SnowflakeInfo.MachineID == 0 { + _, err := qre.execAsTransaction(func(conn *StatefulConnection) (*sqltypes.Result, error) { + query := fmt.Sprintf("select machine_id from %s where id = 0", sqlparser.String(tableName)) + qr, err := qre.execStatefulConn(conn, query, false) + if err != nil { + return nil, err } - t.SequenceInfo.NextVal = nextID - t.SequenceInfo.LastVal = nextID - } - cache, err := evalengine.ToInt64(qr.Rows[0][1]) - if err != nil { - return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) - } - if cache < 1 { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid cache value for sequence %s: %d", tableName, cache) - } - newLast := nextID + cache - for newLast < t.SequenceInfo.NextVal+inc { - newLast += cache - } - query = fmt.Sprintf("update %s set next_id = %d where id = 0", sqlparser.String(tableName), newLast) - conn.TxProperties().RecordQuery(query) - _, err = qre.execStatefulConn(conn, query, false) + if len(qr.Rows) != 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected rows from reading snowflake %s (possible mis-route): %d", tableName, len(qr.Rows)) + } + machineID, err := evalengine.ToInt64(qr.Rows[0][0]) + if err != nil { + return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) + } + t.SnowflakeInfo.SetMachineID(machineID) + return nil, nil + }) if err != nil { return nil, err } - t.SequenceInfo.LastVal = newLast - return nil, nil - }) + } + + // Generate new id here, return it and update last val with overflow + nextID, err := t.SnowflakeInfo.NextNID(inc, currentMillis()) if err != nil { - return nil, err + return nil, vterrors.Wrapf(err, "error generating snowflake with NextNID(%d) %s", inc, tableName) } + ret = int64(nextID) } - ret := t.SequenceInfo.NextVal - t.SequenceInfo.NextVal += inc return &sqltypes.Result{ Fields: sequenceFields, Rows: [][]sqltypes.Value{{ diff --git a/go/vt/vttablet/tabletserver/query_executor_test.go b/go/vt/vttablet/tabletserver/query_executor_test.go index 8683d63b279..615557dd04d 100644 --- a/go/vt/vttablet/tabletserver/query_executor_test.go +++ b/go/vt/vttablet/tabletserver/query_executor_test.go @@ -43,6 +43,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" + "vitess.io/vitess/go/vt/vttablet/tabletserver/schema" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" querypb "vitess.io/vitess/go/vt/proto/query" @@ -654,6 +655,66 @@ func TestQueryExecutorPlanNextval(t *testing.T) { } } +func compareSnowflake(t *testing.T, got *sqltypes.Result, wantTimestamp, wantMachineID, wantSequence int64) { + wantFields := []*querypb.Field{{ + Name: "nextval", + Type: sqltypes.Int64, + }} + assert.Equal(t, wantFields, got.Fields) + id, _ := got.Rows[0][0].ToInt64() + gotTimestamp := (id >> int64(schema.SequenceLength+schema.MachineIDLength)) + schema.SnowflakeStartTime.UTC().UnixNano()/1e6 + gotSequence := id & int64(schema.MaxSequence) + gotMachineID := (id & (int64(schema.MaxMachineID) << schema.SequenceLength)) >> schema.SequenceLength + fmt.Println(gotTimestamp, gotSequence, gotMachineID) + assert.Equal(t, wantSequence, gotSequence) + assert.Equal(t, wantMachineID, gotMachineID) + // this is a flaky test + assert.Equal(t, wantTimestamp, gotTimestamp) + +} +func TestQueryExecutorSnowflakePlanNextval(t *testing.T) { + db := setUpQueryExecutorTest(t) + defer db.Close() + selQuery := "select machine_id from snow where id = 0" + db.AddQuery(selQuery, &sqltypes.Result{ + Fields: []*querypb.Field{ + {Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(1), + }}, + }) + ctx := context.Background() + tsv := newTestTabletServer(ctx, noFlags, db) + defer tsv.StopService() + + // test single value + qre := newTestQueryExecutor(ctx, tsv, "select next value from snow", 0) + assert.Equal(t, planbuilder.PlanNextval, qre.plan.PlanID) + currentMS := currentMillis() + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + compareSnowflake(t, got, currentMS, 1, 0) + + // test overflow with multiple values + currentMS = currentMillis() + qre = newTestQueryExecutor(ctx, tsv, "select next 5000 values from snow", 0) + assert.Equal(t, planbuilder.PlanNextval, qre.plan.PlanID) + got, err = qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + compareSnowflake(t, got, currentMS, 1, 1) + got, err = qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + compareSnowflake(t, got, currentMS+1, 1, 906) + +} + func TestQueryExecutorMessageStreamACL(t *testing.T) { aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63()) tableacl.Register(aclName, &simpleacl.Factory{}) @@ -1204,6 +1265,7 @@ func initQueryExecutorTestDB(db *fakesqldb.DB) { Rows: [][]sqltypes.Value{ mysql.BaseShowTablesRow("test_table", false, ""), mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), mysql.BaseShowTablesRow("msg", false, "vitess_message,vt_ack_wait=30,vt_purge_after=120,vt_batch_size=1,vt_cache_size=10,vt_poller_interval=30"), }, }) @@ -1297,6 +1359,7 @@ func getQueryExecutorSupportedQueries() map[string]*sqltypes.Result { Rows: [][]sqltypes.Value{ mysql.ShowPrimaryRow("test_table", "pk"), mysql.ShowPrimaryRow("seq", "id"), + mysql.ShowPrimaryRow("snow", "id"), mysql.ShowPrimaryRow("msg", "id"), }, }, @@ -1327,6 +1390,15 @@ func getQueryExecutorSupportedQueries() map[string]*sqltypes.Result { Type: sqltypes.Int64, }}, }, + "select * from snow where 1 != 1": { + Fields: []*querypb.Field{{ + Name: "id", + Type: sqltypes.Int32, + }, { + Name: "machine_id", + Type: sqltypes.Int64, + }}, + }, "select * from msg where 1 != 1": { Fields: []*querypb.Field{{ Name: "id", diff --git a/go/vt/vttablet/tabletserver/schema/cached_size.go b/go/vt/vttablet/tabletserver/schema/cached_size.go index e01571ce58b..5670ce3ef2e 100644 --- a/go/vt/vttablet/tabletserver/schema/cached_size.go +++ b/go/vt/vttablet/tabletserver/schema/cached_size.go @@ -44,6 +44,17 @@ func (cached *SequenceInfo) CachedSize(alloc bool) int64 { } return size } + +func (cached *SnowflakeInfo) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) // TODO: not sure about this + } + return size +} func (cached *Table) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -67,6 +78,8 @@ func (cached *Table) CachedSize(alloc bool) int64 { } // field SequenceInfo *vitess.io/vitess/go/vt/vttablet/tabletserver/schema.SequenceInfo size += cached.SequenceInfo.CachedSize(true) + // field SequenceInfo *vitess.io/vitess/go/vt/vttablet/tabletserver/schema.SnowflakeInfo + size += cached.SnowflakeInfo.CachedSize(true) // field MessageInfo *vitess.io/vitess/go/vt/vttablet/tabletserver/schema.MessageInfo size += cached.MessageInfo.CachedSize(true) return size diff --git a/go/vt/vttablet/tabletserver/schema/engine.go b/go/vt/vttablet/tabletserver/schema/engine.go index 82c6e328502..fd407b54e5e 100644 --- a/go/vt/vttablet/tabletserver/schema/engine.go +++ b/go/vt/vttablet/tabletserver/schema/engine.go @@ -253,6 +253,13 @@ func (se *Engine) MakeNonMaster() { t.SequenceInfo.LastVal = 0 t.SequenceInfo.Unlock() } + if t.SnowflakeInfo != nil { + t.SnowflakeInfo.Lock() + // We don't care about this, since each tablet has its own machine ID. + // t.SnowflakeInfo.NextVal = 0 + // t.SnowflakeInfo.LastVal = 0 + t.SnowflakeInfo.Unlock() + } } } diff --git a/go/vt/vttablet/tabletserver/schema/engine_test.go b/go/vt/vttablet/tabletserver/schema/engine_test.go index 6d7bb9b0a1a..a7aebc16513 100644 --- a/go/vt/vttablet/tabletserver/schema/engine_test.go +++ b/go/vt/vttablet/tabletserver/schema/engine_test.go @@ -63,6 +63,7 @@ func TestOpenAndReload(t *testing.T) { mysql.BaseShowTablesRow("test_table_02", false, ""), mysql.BaseShowTablesRow("test_table_03", false, ""), mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), mysql.BaseShowTablesRow("msg", false, "vitess_message,vt_ack_wait=30,vt_purge_after=120,vt_batch_size=1,vt_cache_size=10,vt_poller_interval=30"), }, SessionStateChanges: "", @@ -114,6 +115,7 @@ func TestOpenAndReload(t *testing.T) { // test_table_04 will in spite of older timestamp because it doesn't exist yet. mysql.BaseShowTablesRow("test_table_04", false, ""), mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), }, }) db.AddQuery("select * from test_table_03 where 1 != 1", &sqltypes.Result{ @@ -143,6 +145,7 @@ func TestOpenAndReload(t *testing.T) { mysql.ShowPrimaryRow("test_table_03", "pk2"), mysql.ShowPrimaryRow("test_table_04", "pk"), mysql.ShowPrimaryRow("seq", "id"), + mysql.ShowPrimaryRow("snow", "id"), }, }) secondReadRowsValue := 123 @@ -153,7 +156,7 @@ func TestOpenAndReload(t *testing.T) { if firstTime { firstTime = false sort.Strings(created) - assert.Equal(t, []string{"dual", "msg", "seq", "test_table_01", "test_table_02", "test_table_03"}, created) + assert.Equal(t, []string{"dual", "msg", "seq", "snow", "test_table_01", "test_table_02", "test_table_03"}, created) assert.Equal(t, []string(nil), altered) assert.Equal(t, []string(nil), dropped) } else { @@ -225,6 +228,7 @@ func TestOpenAndReload(t *testing.T) { mysql.BaseShowTablesRow("test_table_02", false, ""), mysql.BaseShowTablesRow("test_table_04", false, ""), mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), }, }) db.AddQuery(mysql.BaseShowPrimary, &sqltypes.Result{ @@ -234,6 +238,7 @@ func TestOpenAndReload(t *testing.T) { mysql.ShowPrimaryRow("test_table_02", "pk"), mysql.ShowPrimaryRow("test_table_04", "pk"), mysql.ShowPrimaryRow("seq", "id"), + mysql.ShowPrimaryRow("snow", "id"), }, }) err = se.ReloadAt(context.Background(), pos1) @@ -262,6 +267,7 @@ func TestReloadWithSwappedTables(t *testing.T) { mysql.BaseShowTablesRow("test_table_02", false, ""), mysql.BaseShowTablesRow("test_table_03", false, ""), mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), mysql.BaseShowTablesRow("msg", false, "vitess_message,vt_ack_wait=30,vt_purge_after=120,vt_batch_size=1,vt_cache_size=10,vt_poller_interval=30"), }, SessionStateChanges: "", @@ -299,6 +305,7 @@ func TestReloadWithSwappedTables(t *testing.T) { sqltypes.MakeTrusted(sqltypes.Int64, []byte("256")), // allocated_size }, mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), mysql.BaseShowTablesRow("msg", false, "vitess_message,vt_ack_wait=30,vt_purge_after=120,vt_batch_size=1,vt_cache_size=10,vt_poller_interval=30"), }, }) @@ -316,6 +323,7 @@ func TestReloadWithSwappedTables(t *testing.T) { mysql.ShowPrimaryRow("test_table_03", "pk"), mysql.ShowPrimaryRow("test_table_04", "mypk"), mysql.ShowPrimaryRow("seq", "id"), + mysql.ShowPrimaryRow("snow", "id"), mysql.ShowPrimaryRow("msg", "id"), }, }) @@ -358,6 +366,7 @@ func TestReloadWithSwappedTables(t *testing.T) { }, mysql.BaseShowTablesRow("test_table_04", false, ""), mysql.BaseShowTablesRow("seq", false, "vitess_sequence"), + mysql.BaseShowTablesRow("snow", false, "vitess_snowflake"), mysql.BaseShowTablesRow("msg", false, "vitess_message,vt_ack_wait=30,vt_purge_after=120,vt_batch_size=1,vt_cache_size=10,vt_poller_interval=30"), }, }) @@ -381,6 +390,7 @@ func TestReloadWithSwappedTables(t *testing.T) { mysql.ShowPrimaryRow("test_table_03", "mypk"), mysql.ShowPrimaryRow("test_table_04", "pk"), mysql.ShowPrimaryRow("seq", "id"), + mysql.ShowPrimaryRow("snow", "id"), mysql.ShowPrimaryRow("msg", "id"), }, }) @@ -569,6 +579,27 @@ func initialSchema() map[string]*Table { AllocatedSize: 0x96, SequenceInfo: &SequenceInfo{}, }, + "snow": { + Name: sqlparser.NewTableIdent("snow"), + Type: Snowflake, + Fields: []*querypb.Field{{ + Name: "id", + Type: sqltypes.Int32, + }, { + Name: "machine_id", + Type: sqltypes.Int64, + }}, + PKColumns: []int{0}, + CreateTime: 1427325875, + FileSize: 0x64, + AllocatedSize: 0x96, + SnowflakeInfo: &SnowflakeInfo{ + MachineID: 0, + Sequence: 0, + LastTimestamp: 0, + LastVal: 0, + }, + }, "msg": { Name: sqlparser.NewTableIdent("msg"), Type: Message, diff --git a/go/vt/vttablet/tabletserver/schema/load_table.go b/go/vt/vttablet/tabletserver/schema/load_table.go index 3d32dbd9075..507893d64c6 100644 --- a/go/vt/vttablet/tabletserver/schema/load_table.go +++ b/go/vt/vttablet/tabletserver/schema/load_table.go @@ -34,10 +34,15 @@ func LoadTable(conn *connpool.DBConn, tableName string, comment string) (*Table, if err := fetchColumns(ta, conn, sqlTableName); err != nil { return nil, err } + fmt.Println("fff comment", comment, "tableName", tableName) switch { case strings.Contains(comment, "vitess_sequence"): ta.Type = Sequence ta.SequenceInfo = &SequenceInfo{} + case strings.Contains(comment, "vitess_snowflake"): + ta.Type = Snowflake + ta.SnowflakeInfo = &SnowflakeInfo{} + fmt.Println("loaded snowflake table: ", tableName) case strings.Contains(comment, "vitess_message"): if err := loadMessageInfo(ta, comment); err != nil { return nil, err diff --git a/go/vt/vttablet/tabletserver/schema/load_table_test.go b/go/vt/vttablet/tabletserver/schema/load_table_test.go index 8d7e784ff0b..5d1fa72996a 100644 --- a/go/vt/vttablet/tabletserver/schema/load_table_test.go +++ b/go/vt/vttablet/tabletserver/schema/load_table_test.go @@ -84,6 +84,33 @@ func TestLoadTableSequence(t *testing.T) { } } +func TestLoadTableSnowflake(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + for query, result := range getSnowflakeTableQueries() { + db.AddQuery(query, result) + } + table, err := newTestLoadTable("USER_TABLE", "vitess_snowflake", db) + if err != nil { + t.Fatal(err) + } + want := &Table{ + Name: sqlparser.NewTableIdent("test_table"), + Type: Snowflake, + SnowflakeInfo: &SnowflakeInfo{ + MachineID: 0, + Sequence: 0, + LastTimestamp: 0, + LastVal: 0, + }, + } + table.Fields = nil + table.PKColumns = nil + if !reflect.DeepEqual(table, want) { + t.Errorf("Table:\n%#v, want\n%#v", table.SnowflakeInfo, want.SnowflakeInfo) + } +} + func TestLoadTableMessage(t *testing.T) { db := fakesqldb.New(t) defer db.Close() @@ -193,6 +220,20 @@ func getTestLoadTableQueries() map[string]*sqltypes.Result { } } +func getSnowflakeTableQueries() map[string]*sqltypes.Result { + return map[string]*sqltypes.Result{ + "select * from test_table where 1 != 1": { + Fields: []*querypb.Field{{ + Name: "id", + Type: sqltypes.Int64, + }, { + Name: "machine_id", + Type: sqltypes.Int64, + }}, + }, + } +} + func getMessageTableQueries() map[string]*sqltypes.Result { return map[string]*sqltypes.Result{ "select * from test_table where 1 != 1": { diff --git a/go/vt/vttablet/tabletserver/schema/schema.go b/go/vt/vttablet/tabletserver/schema/schema.go index a46377caa3e..a5d1a597e67 100644 --- a/go/vt/vttablet/tabletserver/schema/schema.go +++ b/go/vt/vttablet/tabletserver/schema/schema.go @@ -17,6 +17,7 @@ limitations under the License. package schema import ( + "fmt" "sync" "time" @@ -29,6 +30,7 @@ import ( const ( NoType = iota Sequence + Snowflake Message ) @@ -37,6 +39,7 @@ const ( var TypeNames = []string{ "none", "sequence", + "snowflake", "message", } @@ -50,6 +53,9 @@ type Table struct { // SequenceInfo contains info for sequence tables. SequenceInfo *SequenceInfo + // SnowflakeInfo contains info for snowflake tables. + SnowflakeInfo *SnowflakeInfo + // MessageInfo contains info for message tables. MessageInfo *MessageInfo @@ -69,6 +75,101 @@ type SequenceInfo struct { LastVal int64 } +// These constants are the bit lengths of snowflake ID parts. +const ( + TimestampLength uint8 = 41 + MachineIDLength uint8 = 10 + SequenceLength uint8 = 12 + MaxSequence int64 = 1< currentTimestamp { + fmt.Println("current timestamp is less than last timestamp, so we are overflowing again") + currentTimestamp = s.LastTimestamp + } else { + fmt.Println("Same timestamp", currentTimestamp) + } + // calculate first id values + firstInc := s.Sequence + 1 + firstTimestamp = currentTimestamp + firstInc/MaxSequence // add overflow to timestamp as ms + firstSequence = firstInc % MaxSequence // set first sequence + // calculate last id values + lastInc := s.Sequence + inc + s.LastTimestamp = currentTimestamp + lastInc/MaxSequence // add overflow to timestamp as ms + s.Sequence = lastInc % MaxSequence // set last sequence + } + fmt.Println("firstSequence", firstSequence, "firstTimestamp", firstTimestamp) + fmt.Println("lastSequence", s.Sequence, "lastTimestamp", s.LastTimestamp) + + firstDF := elapsedTime(firstTimestamp, SnowflakeStartTime) + firstId := (uint64(firstDF) << uint64(timestampMoveLength)) | (uint64(s.MachineID) << uint64(machineIDMoveLength)) | uint64(firstSequence) + return int64(firstId), nil +} + +// SetMachineID specify the machine ID. It will panic when machined > max limit for 2^10-1. +// This function is thread-unsafe, recommended you call him in the main function. +func (s *SnowflakeInfo) SetMachineID(m int64) error { + if m > MaxMachineID { + return fmt.Errorf("the machineID cannot be greater than 1023: %d", m) + } + s.MachineID = m + return nil +} + +// // ParseID parse snowflake it to SID struct. +// func ParseSnowflakeID(id uint64) SnowflakeID { +// t := id >> uint64(SequenceLength+MachineIDLength) +// sequence := id & uint64(MaxSequence) +// mID := (id & (uint64(MaxMachineID) << SequenceLength)) >> SequenceLength + +// return SnowflakeID{ +// ID: id, +// Sequence: sequence, +// MachineID: mID, +// Timestamp: t, +// } +// } + // MessageInfo contains info specific to message tables. type MessageInfo struct { // Fields stores the field info to be diff --git a/go/vt/vttablet/tabletserver/schema/schema_test.go b/go/vt/vttablet/tabletserver/schema/schema_test.go new file mode 100644 index 00000000000..3a7d48178e1 --- /dev/null +++ b/go/vt/vttablet/tabletserver/schema/schema_test.go @@ -0,0 +1,55 @@ +package schema + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func compareSnowflake(t *testing.T, id, wantTimestamp int64, wantSequence int64, wantMachineID int64) { + gotTimestamp := (id >> int64(SequenceLength+MachineIDLength)) + SnowflakeStartTime.UTC().UnixNano()/1e6 + gotSequence := id & int64(MaxSequence) + gotMachineID := (id & (int64(MaxMachineID) << SequenceLength)) >> SequenceLength + fmt.Println("got ", gotTimestamp, gotSequence, gotMachineID) + assert.Equal(t, wantSequence, gotSequence) + // assert.Equal(t, wantMachineID, gotMachineID) + // this is a flaky test + assert.Equal(t, wantTimestamp, gotTimestamp) +} + +func TestNextNID(t *testing.T) { + snow := &SnowflakeInfo{} + snow.SetMachineID(1) + ts := int64(1732711077200) + + gotId, err := snow.NextNID(1, ts) + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + // assert.Equal(t, gotId, snow.LastVal) + compareSnowflake(t, gotId, ts, 0, 1) + + // test multiple values within same ms (flaky) + for i := 1; i <= 4; i++ { + gotId, err := snow.NextNID(1, ts) + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + compareSnowflake(t, gotId, ts, int64(i), 1) + } + + // test ms overflow by 1 with high inc number + gotId, err = snow.NextNID(5000, ts) + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + compareSnowflake(t, gotId, ts, 5, 1) + gotId, err = snow.NextNID(1, ts) + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + compareSnowflake(t, gotId, ts+1, 910, 1) + + assert.Equal(t, 1, 2) +} diff --git a/go/vt/vttablet/tabletserver/schema/schematest/schematest.go b/go/vt/vttablet/tabletserver/schema/schematest/schematest.go index 236ed3d024f..6cc40e0eaa9 100644 --- a/go/vt/vttablet/tabletserver/schema/schematest/schematest.go +++ b/go/vt/vttablet/tabletserver/schema/schematest/schematest.go @@ -68,6 +68,7 @@ func Queries() map[string]*sqltypes.Result { mysql.ShowPrimaryRow("test_table_02", "pk"), mysql.ShowPrimaryRow("test_table_03", "pk"), mysql.ShowPrimaryRow("seq", "id"), + mysql.ShowPrimaryRow("snow", "id"), mysql.ShowPrimaryRow("msg", "id"), }, }, @@ -104,6 +105,15 @@ func Queries() map[string]*sqltypes.Result { Type: sqltypes.Int64, }}, }, + "select * from snow where 1 != 1": { + Fields: []*querypb.Field{{ + Name: "id", + Type: sqltypes.Int32, + }, { + Name: "machine_id", + Type: sqltypes.Int64, + }}, + }, "select * from msg where 1 != 1": { Fields: []*querypb.Field{{ Name: "id",