From 762d902d1ddc3e89abf0d9abd83111e413649e72 Mon Sep 17 00:00:00 2001 From: Vetcher Date: Mon, 12 Feb 2018 12:02:14 +0300 Subject: [PATCH] Fix #3 (#4) * feat(all): add options mechanism, fix #3 * fix(all): fix errors for correct work, add fields to lazy-update params * fix(callbacks): fix getLastRecord method --- callbacks.go | 77 ++++++++++++++++++++++++++----------------- loggable.go | 2 +- options.go | 13 ++++++++ util.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 154 insertions(+), 31 deletions(-) create mode 100644 options.go create mode 100644 util.go diff --git a/callbacks.go b/callbacks.go index 52e7dd3..b1b0f7c 100644 --- a/callbacks.go +++ b/callbacks.go @@ -18,17 +18,24 @@ type LoggablePlugin interface { GetRecords(objectId string) ([]*ChangeLog, error) } -type loggablePlugin struct { - db *gorm.DB - mu sync.Mutex +type Option func(options *options) + +type plugin struct { + db *gorm.DB + mu sync.Mutex + opts options } -func Register(db *gorm.DB) (LoggablePlugin, error) { +func Register(db *gorm.DB, opts ...Option) (LoggablePlugin, error) { err := db.AutoMigrate(&ChangeLog{}).Error if err != nil { return nil, err } - r := &loggablePlugin{db: db} + o := options{} + for _, option := range opts { + option(&o) + } + r := &plugin{db: db, opts: o} callback := db.Callback() callback.Create().After("gorm:after_create").Register("loggable:create", r.addCreated) callback.Update().After("gorm:after_update").Register("loggable:update", r.addUpdated) @@ -36,17 +43,18 @@ func Register(db *gorm.DB) (LoggablePlugin, error) { return r, nil } -func (r *loggablePlugin) GetRecords(objectId string) ([]*ChangeLog, error) { +func (r *plugin) GetRecords(objectId string) ([]*ChangeLog, error) { var changes []*ChangeLog - err := r.db.Find(&changes).Where("object_id = ?", objectId).Error - if err != nil { - return nil, err - } - return changes, nil + return changes, r.db.Where("object_id = ?", objectId).Find(&changes).Error +} + +func (r *plugin) getLastRecord(objectId string) (*ChangeLog, error) { + var change ChangeLog + return &change, r.db.Where("object_id = ?", objectId).Order("created_at DESC").Limit(1).Find(&change).Error } // Deprecated: Use SetUserAndWhere instead. -func (r *loggablePlugin) SetUser(user string) *gorm.DB { +func (r *plugin) SetUser(user string) *gorm.DB { r.mu.Lock() db := r.db.Set("loggable:user", user) r.mu.Unlock() @@ -54,20 +62,20 @@ func (r *loggablePlugin) SetUser(user string) *gorm.DB { } // Deprecated: Use SetUserAndWhere instead. -func (r *loggablePlugin) SetWhere(where string) *gorm.DB { +func (r *plugin) SetWhere(where string) *gorm.DB { r.mu.Lock() db := r.db.Set("loggable:where", where) r.mu.Unlock() return db } -func (r *loggablePlugin) SetUserAndWhere(user, where string) *gorm.DB { +func (r *plugin) SetUserAndWhere(user, where string) *gorm.DB { r.mu.Lock() defer r.mu.Unlock() return r.db.Set("loggable:user", user).Set("loggable:where", where) } -func (r *loggablePlugin) addRecord(scope *gorm.Scope, action string) error { +func (r *plugin) addRecord(scope *gorm.Scope, action string) error { var jsonObject JSONB j, err := json.Marshal(scope.Value) if err != nil { @@ -77,14 +85,7 @@ func (r *loggablePlugin) addRecord(scope *gorm.Scope, action string) error { if err != nil { return err } - user, ok := scope.DB().Get("loggable:user") - if !ok { - user = "" - } - where, ok := scope.DB().Get("loggable:where") - if !ok { - where = "" - } + user, where := getUserAndWhere(scope) cl := ChangeLog{ ID: uuid.NewV4().String(), @@ -95,11 +96,19 @@ func (r *loggablePlugin) addRecord(scope *gorm.Scope, action string) error { ObjectType: scope.GetModelStruct().ModelType.Name(), Object: jsonObject, } - err = scope.DB().Create(&cl).Error - if err != nil { - return err + return scope.DB().Create(&cl).Error +} + +func getUserAndWhere(scope *gorm.Scope) (interface{}, interface{}) { + user, ok := scope.DB().Get("loggable:user") + if !ok { + user = "" + } + where, ok := scope.DB().Get("loggable:where") + if !ok { + where = "" } - return nil + return user, where } func isLoggable(scope *gorm.Scope) (isLoggable bool) { @@ -110,19 +119,27 @@ func isLoggable(scope *gorm.Scope) (isLoggable bool) { return } -func (r *loggablePlugin) addCreated(scope *gorm.Scope) { +func (r *plugin) addCreated(scope *gorm.Scope) { if isLoggable(scope) { r.addRecord(scope, "create") } } -func (r *loggablePlugin) addUpdated(scope *gorm.Scope) { +func (r *plugin) addUpdated(scope *gorm.Scope) { if isLoggable(scope) { + if r.opts.lazyUpdate { + record, err := r.getLastRecord(scope.PrimaryKeyValue().(string)) + if err == nil { + if isEqual(record.Object, scope.Value, r.opts.lazyUpdateFields...) { + return + } + } + } r.addRecord(scope, "update") } } -func (r *loggablePlugin) addDeleted(scope *gorm.Scope) { +func (r *plugin) addDeleted(scope *gorm.Scope) { if isLoggable(scope) { r.addRecord(scope, "delete") } diff --git a/loggable.go b/loggable.go index 33d7a3d..3fdf841 100644 --- a/loggable.go +++ b/loggable.go @@ -45,7 +45,7 @@ func (j *JSONB) Scan(value interface{}) error { } s, ok := value.([]byte) if !ok { - errors.New("Scan source was not string") + return errors.New("Scan source was not string") } *j = append((*j)[0:0], s...) return nil diff --git a/options.go b/options.go new file mode 100644 index 0000000..b074bdf --- /dev/null +++ b/options.go @@ -0,0 +1,13 @@ +package loggable + +type options struct { + lazyUpdate bool + lazyUpdateFields []string +} + +func LazyUpdateOption(fields ...string) func(options *options) { + return func(options *options) { + options.lazyUpdate = true + options.lazyUpdateFields = fields + } +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..960ac2b --- /dev/null +++ b/util.go @@ -0,0 +1,93 @@ +package loggable + +import ( + "encoding/json" + "reflect" + "strings" + "unicode" +) + +func isEqual(item1, item2 interface{}, except ...string) bool { + except = StringMap(except, ToSnakeCase) + m1, m2 := somethingToMapStringInterface(item1), somethingToMapStringInterface(item2) + if len(m1) != len(m2) { + return false + } + for k, v := range m1 { + if isInStringSlice(ToSnakeCase(k), except) { + continue + } + v2, ok := m2[k] + if !ok || !reflect.DeepEqual(v, v2) { + return false + } + } + return true +} + +func somethingToMapStringInterface(item interface{}) map[string]interface{} { + if item == nil { + return nil + } + switch raw := item.(type) { + case JSONB: + return somethingToMapStringInterface([]byte(raw)) + case string: + return somethingToMapStringInterface([]byte(raw)) + case []byte: + var m map[string]interface{} + err := json.Unmarshal(raw, &m) + if err != nil { + return nil + } + return m + default: + data, err := json.Marshal(item) + if err != nil { + return nil + } + return somethingToMapStringInterface(data) + } + return nil +} + +func isInStringSlice(what string, where []string) bool { + for i := range where { + if what == where[i] { + return true + } + } + return false +} + +var ToSnakeCase = toSomeCase("_") + +func toSomeCase(sep string) func(string) string { + return func(s string) string { + for i := range s { + if unicode.IsUpper(rune(s[i])) { + if i != 0 { + s = strings.Join([]string{s[:i], ToLowerFirst(s[i:])}, sep) + } else { + s = ToLowerFirst(s) + } + } + } + return s + } +} + +func ToLowerFirst(s string) string { + if len(s) == 0 { + return "" + } + return strings.ToLower(string(s[0])) + s[1:] +} + +func StringMap(strs []string, fn func(string) string) []string { + res := make([]string, len(strs)) + for i := range strs { + res[i] = fn(strs[i]) + } + return res +}