diff --git a/integration_test.go b/integration_test.go index 22d9ec2..44008f8 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1239,6 +1239,56 @@ func TestAssociation(t *testing.T) { }, }, }, + // Test case for replacement + ModifyReplace{ + "cn=A1", "ou=Groups", + M{ + "member": A{ + "uid=user1,ou=Users," + testServer.GetSuffix(), + }, + }, + &AssertEntry{ + expectAttrs: M{ + "member": A{ + "uid=user1,ou=Users," + testServer.GetSuffix(), + }, + }, + }, + }, + ModifyReplace{ + "cn=A1", "ou=Groups", + M{ + "member": A{ + "uid=user1,ou=Users," + testServer.GetSuffix(), + "uid=user2,ou=Users," + testServer.GetSuffix(), + }, + }, + &AssertEntry{ + expectAttrs: M{ + "member": A{ + "uid=user1,ou=Users," + testServer.GetSuffix(), + "uid=user2,ou=Users," + testServer.GetSuffix(), + }, + }, + }, + }, + ModifyReplace{ + "cn=A1", "ou=Groups", + M{ + "member": A{ + "uid=user2,ou=Users," + testServer.GetSuffix(), + "uid=user3,ou=Users," + testServer.GetSuffix(), + }, + }, + &AssertEntry{ + expectAttrs: M{ + "member": A{ + "uid=user2,ou=Users," + testServer.GetSuffix(), + "uid=user3,ou=Users," + testServer.GetSuffix(), + }, + }, + }, + }, // Test case for duplicate members Add{ "cn=A2", "ou=Groups", diff --git a/modify_entry.go b/modify_entry.go index a2ef6e6..3710831 100644 --- a/modify_entry.go +++ b/modify_entry.go @@ -13,19 +13,28 @@ type ModifyEntry struct { hasSub bool path string AddChangeLog map[string]*SchemaValue - ReplaceChangeLog map[string]*SchemaValue - DelChangeLog map[string]*SchemaValue + ReplaceChangeLog struct { + old map[string]*SchemaValue + new map[string]*SchemaValue + } + DelChangeLog map[string]*SchemaValue } func NewModifyEntry(schemaMap *SchemaMap, dn *DN, attrsOrig map[string][]string) (*ModifyEntry, error) { // TODO modifyEntry := &ModifyEntry{ - schemaMap: schemaMap, - dn: dn, - attributes: map[string]*SchemaValue{}, - AddChangeLog: map[string]*SchemaValue{}, - ReplaceChangeLog: map[string]*SchemaValue{}, - DelChangeLog: map[string]*SchemaValue{}, + schemaMap: schemaMap, + dn: dn, + attributes: map[string]*SchemaValue{}, + AddChangeLog: map[string]*SchemaValue{}, + ReplaceChangeLog: struct { + old map[string]*SchemaValue + new map[string]*SchemaValue + }{ + old: map[string]*SchemaValue{}, + new: map[string]*SchemaValue{}, + }, + DelChangeLog: map[string]*SchemaValue{}, } for k, v := range attrsOrig { @@ -178,13 +187,16 @@ func (j *ModifyEntry) Replace(attrName string, attrValue []string) error { } } + // Record old attribute into changelog + j.ReplaceChangeLog.old[sv.Name()] = j.attributes[sv.Name()] + // Apply change if err := j.replacesv(sv); err != nil { return err } - // Record changelog - j.ReplaceChangeLog[sv.Name()] = sv + // Record new attribute into changelog + j.ReplaceChangeLog.new[sv.Name()] = sv return nil } @@ -315,13 +327,19 @@ func (j *ModifyEntry) Attrs() (map[string][]interface{}, map[string][]string) { func (e *ModifyEntry) Clone() *ModifyEntry { clone := &ModifyEntry{ - schemaMap: e.schemaMap, - dn: e.dn, - attributes: map[string]*SchemaValue{}, - AddChangeLog: map[string]*SchemaValue{}, - ReplaceChangeLog: map[string]*SchemaValue{}, - DelChangeLog: map[string]*SchemaValue{}, - dbEntryID: e.dbEntryID, + schemaMap: e.schemaMap, + dn: e.dn, + attributes: map[string]*SchemaValue{}, + AddChangeLog: map[string]*SchemaValue{}, + ReplaceChangeLog: struct { + old map[string]*SchemaValue + new map[string]*SchemaValue + }{ + old: map[string]*SchemaValue{}, + new: map[string]*SchemaValue{}, + }, + DelChangeLog: map[string]*SchemaValue{}, + dbEntryID: e.dbEntryID, } for k, v := range e.attributes { clone.attributes[k] = v.Clone() diff --git a/pass_through.go b/pass_through.go index a057039..51c5770 100644 --- a/pass_through.go +++ b/pass_through.go @@ -82,7 +82,11 @@ func (c *LDAPPassThroughClient) Authenticate(domain, user, password string) (boo ) sr, err := l.Search(search) if err != nil { - return false, xerrors.Errorf("Failed to search an user for pass-through. domain:%s, uid: %s, filter: %s, err: %w", domain, user, filter, err) + if !ldap.IsErrorWithCode(err, 32) { + return false, xerrors.Errorf("Failed to search an user for pass-through. domain:%s, uid: %s, filter: %s, err: %w", domain, user, filter, err) + } + // LDAP Result Code 32 "No Such Object + return false, xerrors.Errorf("No such object. domain: %s, uid: %s", domain, user) } if len(sr.Entries) == 0 { diff --git a/repo_hybrid.go b/repo_hybrid.go index e304876..88e71ef 100644 --- a/repo_hybrid.go +++ b/repo_hybrid.go @@ -32,7 +32,6 @@ var ( // repo_read for update findEntryIDByDNWithShareLock *sqlx.NamedStmt - findEntryIDByDNWithUpdateLock *sqlx.NamedStmt findEntryByDNWithShareLock *sqlx.NamedStmt findEntryByDNWithUpdateLock *sqlx.NamedStmt findEntryWithAssociationByDNWithUpdateLock *sqlx.NamedStmt @@ -233,13 +232,6 @@ func (r *HybridRepository) Init() error { return xerrors.Errorf("Failed to initialize prepared statement: %w", err) } - findEntryIDByDNWithUpdateLock, err = db.PrepareNamed(findEntryIDByDN + ` - FOR UPDATE - `) - if err != nil { - return xerrors.Errorf("Failed to initialize prepared statement: %w", err) - } - insertContainerStmtWithUpdateLock, err = db.PrepareNamed(`INSERT INTO ldap_container (id, dn_norm, dn_orig) VALUES (:id, :dn_norm, :dn_orig) -- Lock the record without change if already exists @@ -947,14 +939,14 @@ func (r HybridRepository) DeleteByDN(ctx context.Context, dn *DN) error { return err } - // Step 1: fetch the target entry and parent container with lock for update + // Step 1: fetch the target entry and parent container with lock for share fetchedEntry := struct { ID int64 `db:"id"` ParentID int64 `db:"parent_id"` HasSub bool `db:"has_sub"` }{} - err = r.get(tx, findEntryIDByDNWithUpdateLock, &fetchedEntry, map[string]interface{}{ + err = r.get(tx, findEntryIDByDNWithShareLock, &fetchedEntry, map[string]interface{}{ "rdn_norm": dn.RDNNormStr(), "parent_dn_norm": dn.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix), }) @@ -2342,44 +2334,143 @@ func (r *HybridRepository) modifyEntryToDBEntry(ctx context.Context, tx *sqlx.Tx // Convert the value of member, uniqueMamber and memberOf attributes, DN => int64 addAssociation := map[string][]int64{} + delAssociation := map[string][]int64{} + + // Replace + if v, ok := entry.ReplaceChangeLog.new["member"]; ok { + add, del := diff(entry.ReplaceChangeLog.old["member"].NormStr(), v.NormStr()) + addsv, _ := NewSchemaValue(r.server.schemaMap, "member", add) + delsv, _ := NewSchemaValue(r.server.schemaMap, "member", del) + + addMember, err := r.dnArrayToIDArray(tx, map[string][]interface{}{"member": addsv.norm}, "member") + if err != nil { + return nil, nil, nil, err + } + delMember, err := r.dnArrayToIDArray(tx, map[string][]interface{}{"member": delsv.norm}, "member") + if err != nil { + return nil, nil, nil, err + } + if addIDs, ok := addAssociation["member"]; !ok { + addAssociation["member"] = addMember + } else { + addIDs = append(addIDs, addMember...) + } + if delIDs, ok := delAssociation["member"]; !ok { + delAssociation["member"] = delMember + } else { + delIDs = append(delIDs, delMember...) + } + } + if v, ok := entry.ReplaceChangeLog.new["uniqueMember"]; ok { + add, del := diff(entry.ReplaceChangeLog.old["uniqueMember"].NormStr(), v.NormStr()) + addsv, _ := NewSchemaValue(r.server.schemaMap, "uniqueMember", add) + delsv, _ := NewSchemaValue(r.server.schemaMap, "uniqueMember", del) + + addMember, err := r.dnArrayToIDArray(tx, map[string][]interface{}{"uniqueMember": addsv.norm}, "uniqueMember") + if err != nil { + return nil, nil, nil, err + } + delMember, err := r.dnArrayToIDArray(tx, map[string][]interface{}{"uniqueMember": delsv.norm}, "uniqueMember") + if err != nil { + return nil, nil, nil, err + } + if addIDs, ok := addAssociation["uniqueMember"]; !ok { + addAssociation["uniqueMember"] = addMember + } else { + addIDs = append(addIDs, addMember...) + } + if delIDs, ok := delAssociation["uniqueMember"]; !ok { + delAssociation["uniqueMember"] = delMember + } else { + delIDs = append(delIDs, delMember...) + } + } + if v, ok := entry.ReplaceChangeLog.new["memberOf"]; ok { + add, del := diff(entry.ReplaceChangeLog.old["memberOf"].NormStr(), v.NormStr()) + addsv, _ := NewSchemaValue(r.server.schemaMap, "memberOf", add) + delsv, _ := NewSchemaValue(r.server.schemaMap, "memberOf", del) + addMember, err := r.dnArrayToIDArray(tx, map[string][]interface{}{"memberOf": addsv.norm}, "memberOf") + if err != nil { + return nil, nil, nil, err + } + delMember, err := r.dnArrayToIDArray(tx, map[string][]interface{}{"memberOf": delsv.norm}, "memberOf") + if err != nil { + return nil, nil, nil, err + } + if addIDs, ok := addAssociation["memberOf"]; !ok { + addAssociation["memberOf"] = addMember + } else { + addIDs = append(addIDs, addMember...) + } + if delIDs, ok := delAssociation["memberOf"]; !ok { + delAssociation["memberOf"] = delMember + } else { + delIDs = append(delIDs, delMember...) + } + } + + // Add member, err := r.schemaValueToIDArray(tx, entry.AddChangeLog, "member") if err != nil { return nil, nil, nil, err } - addAssociation["member"] = member + if addIDs, ok := addAssociation["member"]; !ok { + addAssociation["member"] = member + } else { + addIDs = append(addIDs, member...) + } uniqueMember, err := r.schemaValueToIDArray(tx, entry.AddChangeLog, "uniqueMember") if err != nil { return nil, nil, nil, err } - addAssociation["uniqueMember"] = uniqueMember + if addIDs, ok := addAssociation["uniqueMember"]; !ok { + addAssociation["uniqueMember"] = uniqueMember + } else { + addIDs = append(addIDs, uniqueMember...) + } memberOf, err := r.schemaValueToIDArray(tx, entry.AddChangeLog, "memberOf") if err != nil { return nil, nil, nil, err } - addAssociation["memberOf"] = memberOf - - delAssociation := map[string][]int64{} + if addIDs, ok := addAssociation["memberOf"]; !ok { + addAssociation["memberOf"] = memberOf + } else { + addIDs = append(addIDs, memberOf...) + } + // Delete member, err = r.schemaValueToIDArray(tx, entry.DelChangeLog, "member") if err != nil { return nil, nil, nil, err } - delAssociation["member"] = member + if delIDs, ok := delAssociation["member"]; !ok { + delAssociation["member"] = member + } else { + delIDs = append(delIDs, member...) + } uniqueMember, err = r.schemaValueToIDArray(tx, entry.DelChangeLog, "uniqueMember") if err != nil { return nil, nil, nil, err } - delAssociation["uniqueMember"] = uniqueMember + if delIDs, ok := delAssociation["uniqueMember"]; !ok { + delAssociation["uniqueMember"] = uniqueMember + } else { + delIDs = append(delIDs, uniqueMember...) + } memberOf, err = r.schemaValueToIDArray(tx, entry.DelChangeLog, "memberOf") if err != nil { return nil, nil, nil, err } - delAssociation["memberOf"] = memberOf + if delIDs, ok := delAssociation["memberOf"]; !ok { + delAssociation["memberOf"] = memberOf + } else { + delIDs = append(delIDs, memberOf...) + } // Remove attributes to reduce attrs_orig column size r.dropAssociationAttrs(norm, orig) @@ -2717,6 +2808,9 @@ func (r *HybridRepository) exec(tx *sqlx.Tx, stmt *sqlx.NamedStmt, params map[st debugSQL(r.server.config.LogLevel, stmt.QueryString, params) result, err := tx.NamedStmt(stmt).Exec(params) errorSQL(err, stmt.QueryString, params) + if isForeignKeyError(err) { + return nil, NewRetryError(err) + } return result, err } diff --git a/util.go b/util.go index 1691451..40a7525 100644 --- a/util.go +++ b/util.go @@ -652,3 +652,30 @@ func mergeIndex(m1, m2 map[string]struct{}) map[string]struct{} { } return m } + +func diff(a, b []string) ([]string, []string) { + ma := make(map[string]struct{}, len(a)) + for _, x := range a { + ma[x] = struct{}{} + } + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + + add := []string{} + del := []string{} + + for k, _ := range mb { + if _, ok := ma[k]; !ok { + add = append(add, k) + } + } + for k, _ := range ma { + if _, ok := mb[k]; !ok { + del = append(del, k) + } + } + + return add, del +} diff --git a/util_test.go b/util_test.go index 4d67e80..5235128 100644 --- a/util_test.go +++ b/util_test.go @@ -3,6 +3,7 @@ package main import ( + "reflect" "testing" ) @@ -148,3 +149,55 @@ func TestSortObjectClassesAndVerifyChain(t *testing.T) { } } } + +func TestDiff(t *testing.T) { + testcases := []struct { + a []string + b []string + AddExpected []string + DelExpected []string + }{ + { + []string{"a"}, + []string{"b"}, + []string{"b"}, + []string{"a"}, + }, + { + []string{"a", "b"}, + []string{"c", "d"}, + []string{"c", "d"}, + []string{"a", "b"}, + }, + { + []string{"a", "b"}, + []string{"b", "c"}, + []string{"c"}, + []string{"a"}, + }, + { + []string{}, + []string{"a"}, + []string{"a"}, + []string{}, + }, + { + []string{"a"}, + []string{}, + []string{}, + []string{"a"}, + }, + } + + for i, tc := range testcases { + add, del := diff(tc.a, tc.b) + if !reflect.DeepEqual(add, tc.AddExpected) { + t.Errorf("Unexpected error on %d:\nadd '%v' expected, got '%v'\n", i, tc.AddExpected, add) + continue + } + if !reflect.DeepEqual(del, tc.DelExpected) { + t.Errorf("Unexpected error on %d:\ndel '%v' expected, got '%v'\n", i, tc.DelExpected, del) + continue + } + } +}