diff --git a/handler_search_generic.go b/handler_search_generic.go index e7cf1a4..787a1cb 100644 --- a/handler_search_generic.go +++ b/handler_search_generic.go @@ -81,19 +81,18 @@ func handleSearch(s *Server, w ldap.ResponseWriter, m *ldap.Message) { } // Phase 4: execute SQL and return entries - // TODO configurable default pageSize - var pageSize int32 = 500 + var pageSize int32 = s.config.DefaultPageSize if pageControl != nil { pageSize = pageControl.Size() } sessionMap := getPageSession(m) - var offset int32 + var cusor int64 if pageControl != nil { reqCookie := pageControl.Cookie() if reqCookie != "" { var ok bool - if offset, ok = sessionMap[reqCookie]; ok { + if cusor, ok = sessionMap[reqCookie]; ok { log.Printf("debug: paged results cookie is ok") // clear cookie @@ -114,13 +113,13 @@ func handleSearch(s *Server, w ldap.ResponseWriter, m *ldap.Message) { Scope: scope, Filter: r.Filter(), PageSize: pageSize, - Offset: offset, + Cursor: &cusor, RequestedAssocation: getRequestedMemberAttrs(r), IsMemberOfRequested: isMemberOfRequested(r), IsHasSubordinatesRequested: isHasSubOrdinatesRequested(r), } - maxCount, limittedCount, err := s.Repo().Search(ctx, baseDN, option, func(searchEntry *SearchEntry) error { + count, nextId, err := s.Repo().Search(ctx, baseDN, option, func(searchEntry *SearchEntry) error { responseEntry(s, w, m, r, searchEntry) return nil }) @@ -129,7 +128,7 @@ func handleSearch(s *Server, w ldap.ResponseWriter, m *ldap.Message) { return } - if maxCount == 0 { + if count == 0 { log.Printf("debug: Not found") // Must return success if no hit @@ -140,19 +139,19 @@ func handleSearch(s *Server, w ldap.ResponseWriter, m *ldap.Message) { var nextCookie string - if limittedCount+offset < maxCount { + if count == option.PageSize+1 { uuid, _ := uuid.NewRandom() nextCookie = uuid.String() sessionMap := getPageSession(m) - sessionMap[nextCookie] = offset + pageSize + sessionMap[nextCookie] = nextId } res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess) if pageControl != nil { // https://www.ietf.org/rfc/rfc2696.txt - control := message.NewSimplePagedResultsControl(maxCount, false, nextCookie) + control := message.NewSimplePagedResultsControl(0, false, nextCookie) var controls message.Controls = []message.Control{control} w.WriteControls(res, &controls) diff --git a/integration_test.go b/integration_test.go index d24928e..2142c37 100644 --- a/integration_test.go +++ b/integration_test.go @@ -5,6 +5,7 @@ package main import ( "os" "testing" + "time" "github.com/go-ldap/ldap/v3" ) @@ -86,6 +87,94 @@ func TestParallel(t *testing.T) { runTestCases(t, tcs) } +func TestParallelByNonRootUsers(t *testing.T) { + type A []string + type M map[string][]string + + tcs := []Command{ + Conn{}, + Bind{"cn=Manager", "secret", &AssertResponse{}}, + AddDC("com").SetAssert(&AssertResponse{53}), + AddDC("example", "dc=com"), + AddOU("Users"), + Add{ + "uid=op1", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"op1"}, + "sn": A{"op1"}, + "userPassword": A{SSHA("password1")}, + }, + &AssertEntry{}, + }, + Add{ + "uid=op2", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"op2"}, + "sn": A{"op2"}, + "userPassword": A{SSHA256("password2")}, + }, + &AssertEntry{}, + }, + Parallel{ + 100, + [][]Command{ + { + Conn{}, + Bind{"uid=op1,ou=Users", "password1", &AssertResponse{}}, + Add{ + "uid=user1", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"user1"}, + "sn": A{"user1"}, + }, + &AssertEntry{}, + }, + ModifyAdd{ + "uid=user1", "ou=Users", + M{ + "givenName": A{"user1"}, + }, + &AssertEntry{}, + }, + Delete{ + "uid=user1", "ou=Users", + &AssertNoEntry{}, + }, + }, + { + Conn{}, + Bind{"uid=op2,ou=Users", "password2", &AssertResponse{}}, + Add{ + "uid=user2", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"user2"}, + "sn": A{"user2"}, + }, + &AssertEntry{}, + }, + ModifyAdd{ + "uid=user2", "ou=Users", + M{ + "givenName": A{"user2"}, + }, + &AssertEntry{}, + }, + Delete{ + "uid=user2", "ou=Users", + &AssertNoEntry{}, + }, + }, + }, + }, + } + + runTestCases(t, tcs) +} + func TestDeadlock(t *testing.T) { type A []string type M map[string][]string @@ -323,6 +412,70 @@ func TestBind(t *testing.T) { runTestCases(t, tcs) } +func TestBindWithAccountLock(t *testing.T) { + type A []string + type M map[string][]string + + testServer.config.MigrationEnabled = true + testServer.LoadSchema() + + tcs := []Command{ + Conn{}, + Bind{"cn=Manager", "secret", &AssertResponse{}}, + AddDC("example", "dc=com"), + AddOU("Users"), + AddOU("Policies"), + Add{ + "uid=op1", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"op1"}, + "sn": A{"op1"}, + "userPassword": A{SSHA("password1")}, + }, + &AssertEntry{}, + }, + Add{ + "cn=standard-policy", "ou=Policies", + M{ + "objectClass": A{"top", "device", "pwdPolicy"}, + "pwdAttribute": A{"userPassword"}, + "pwdLockout": A{"TRUE"}, + "pwdMaxFailure": A{"2"}, + "pwdlockoutDuration": A{"10"}, + }, + nil, + }, + Bind{"uid=op1,ou=users", "password1", &AssertResponse{}}, + Bind{ + "uid=op1,ou=Users", + "invalid", + &AssertResponse{49}, + }, + Bind{"uid=op1,ou=users", "password1", &AssertResponse{}}, + Bind{ + "uid=op1,ou=Users", + "invalid", + &AssertResponse{49}, + }, + Bind{ + "uid=op1,ou=Users", + "invalid", + &AssertResponse{49}, + }, + // Account Locked + Bind{"uid=op1,ou=users", "password1", &AssertResponse{49}}, + // still locked + Wait{time.Second * 5}, + Bind{"uid=op1,ou=users", "password1", &AssertResponse{49}}, + // Unlocked + Wait{time.Second * 5}, + Bind{"uid=op1,ou=users", "password1", &AssertResponse{}}, + } + + runTestCases(t, tcs) +} + func TestSearchSpecialCharacters(t *testing.T) { type A []string type M map[string][]string @@ -489,6 +642,117 @@ func TestSearch(t *testing.T) { runTestCases(t, tcs) } +func TestSearchWithPaging(t *testing.T) { + type A []string + type M map[string][]string + + tcs := []Command{ + Conn{}, + Bind{"cn=Manager", "secret", &AssertResponse{}}, + AddDC("example", "dc=com"), + AddOU("Users"), + Add{ + "uid=user1", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"user1"}, + "sn": A{"user1"}, + "userPassword": A{SSHA("password1")}, + "employeeNumber": A{"emp1"}, + }, + &AssertEntry{}, + }, + Add{ + "uid=user2", "ou=Users", + M{ + "objectClass": A{"inetOrgPerson"}, + "cn": A{"user2"}, + "sn": A{"user2"}, + "userPassword": A{SSHA("password2")}, + "employeeNumber": A{"emp2"}, + }, + &AssertEntry{}, + }, + SearchWithPaging{ + Search: Search{ + "ou=Users," + testServer.GetSuffix(), + "uid=*", + ldap.ScopeWholeSubtree, + A{"*"}, + &AssertEntries{ + ExpectEntry{ + "uid=user1", + "ou=Users", + M{ + "sn": A{"user1"}, + }, + }, + ExpectEntry{ + "uid=user2", + "ou=Users", + M{ + "sn": A{"user2"}, + }, + }, + }, + }, + limit: 1, + }, + SearchWithPaging{ + Search: Search{ + "ou=Users," + testServer.GetSuffix(), + "uid=*", + ldap.ScopeWholeSubtree, + A{"*"}, + &AssertEntries{ + ExpectEntry{ + "uid=user1", + "ou=Users", + M{ + "sn": A{"user1"}, + }, + }, + ExpectEntry{ + "uid=user2", + "ou=Users", + M{ + "sn": A{"user2"}, + }, + }, + }, + }, + limit: 2, + }, + SearchWithPaging{ + Search: Search{ + "ou=Users," + testServer.GetSuffix(), + "uid=*", + ldap.ScopeWholeSubtree, + A{"*"}, + &AssertEntries{ + ExpectEntry{ + "uid=user1", + "ou=Users", + M{ + "sn": A{"user1"}, + }, + }, + ExpectEntry{ + "uid=user2", + "ou=Users", + M{ + "sn": A{"user2"}, + }, + }, + }, + }, + limit: 3, + }, + } + + runTestCases(t, tcs) +} + func TestScopeSearch(t *testing.T) { type A []string type M map[string][]string @@ -544,6 +808,21 @@ func TestScopeSearch(t *testing.T) { &AssertEntry{}, }, // base for container + Search{ + testServer.GetSuffix(), + "objectclass=*", + ldap.ScopeBaseObject, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "", + testServer.GetSuffix(), + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + }, + }, Search{ "ou=Users," + testServer.GetSuffix(), "objectclass=*", @@ -559,7 +838,79 @@ func TestScopeSearch(t *testing.T) { }, }, }, + Search{ + "ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "objectclass=*", + ldap.ScopeBaseObject, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "ou=SubUsers,ou=Users", + "", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + }, + }, // sub for container + Search{ + testServer.GetSuffix(), + "objectclass=*", + ldap.ScopeWholeSubtree, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "", + testServer.GetSuffix(), + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + ExpectEntry{ + "ou=Users", + "", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + ExpectEntry{ + "ou=SubUsers", + "ou=Users", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + ExpectEntry{ + "uid=user1", + "ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user2", + "ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user3", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user4", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + }, + }, Search{ "ou=Users," + testServer.GetSuffix(), "objectclass=*", @@ -610,7 +961,51 @@ func TestScopeSearch(t *testing.T) { }, }, }, + Search{ + "ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "objectclass=*", + ldap.ScopeWholeSubtree, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "ou=SubUsers", + "ou=Users", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + ExpectEntry{ + "uid=user3", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user4", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + }, + }, // one for container + Search{ + testServer.GetSuffix(), + "objectclass=*", + ldap.ScopeSingleLevel, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "ou=Users", + "", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + }, + }, Search{ "ou=Users," + testServer.GetSuffix(), "objectclass=*", @@ -640,7 +1035,79 @@ func TestScopeSearch(t *testing.T) { }, }, }, + Search{ + "ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "objectclass=*", + ldap.ScopeSingleLevel, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "uid=user3", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user4", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + }, + }, // children for container + Search{ + testServer.GetSuffix(), + "objectclass=*", + 3, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "ou=Users", + "", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + ExpectEntry{ + "ou=SubUsers", + "ou=Users", + M{ + "hasSubordinates": A{"TRUE"}, + }, + }, + ExpectEntry{ + "uid=user1", + "ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user2", + "ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user3", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user4", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + }, + }, Search{ "ou=Users," + testServer.GetSuffix(), "objectclass=*", @@ -684,6 +1151,28 @@ func TestScopeSearch(t *testing.T) { }, }, }, + Search{ + "ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "objectclass=*", + 3, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "uid=user3", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + ExpectEntry{ + "uid=user4", + "ou=SubUsers,ou=Users", + M{ + "hasSubordinates": A{"FALSE"}, + }, + }, + }, + }, // base for not container(admin virtual entry) Search{ "cn=Manager," + testServer.GetSuffix(), @@ -735,6 +1224,22 @@ func TestScopeSearch(t *testing.T) { }, }, }, + Search{ + "uid=user3,ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "uid=user3", + ldap.ScopeWholeSubtree, + A{"*", "+"}, + &AssertEntries{ + ExpectEntry{ + "uid=user3", + "ou=SubUsers,ou=Users", + M{ + "sn": A{"user3"}, + "hasSubordinates": A{"FALSE"}, + }, + }, + }, + }, // one for not container Search{ "uid=user1,ou=Users," + testServer.GetSuffix(), @@ -743,6 +1248,13 @@ func TestScopeSearch(t *testing.T) { A{"*", "+"}, &AssertEntries{}, }, + Search{ + "uid=user3,ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "uid=user3", + ldap.ScopeSingleLevel, + A{"*", "+"}, + &AssertEntries{}, + }, // children for not container Search{ "uid=user1,ou=Users," + testServer.GetSuffix(), @@ -751,6 +1263,42 @@ func TestScopeSearch(t *testing.T) { A{"*", "+"}, &AssertEntries{}, }, + Search{ + "uid=user3,ou=SubUsers,ou=Users," + testServer.GetSuffix(), + "uid=user3", + 3, + A{"*", "+"}, + &AssertEntries{}, + }, + // search for parent dc of the server suffix + Search{ + "dc=com", + "objectclass=*", + ldap.ScopeBaseObject, + A{"*", "+"}, + &AssertEntries{}, + }, + Search{ + "dc=com", + "objectclass=*", + ldap.ScopeSingleLevel, + A{"*", "+"}, + &AssertEntries{}, + }, + Search{ + "dc=com", + "objectclass=*", + ldap.ScopeWholeSubtree, + A{"*", "+"}, + &AssertEntries{}, + }, + Search{ + "dc=com", + "objectclass=*", + 3, + A{"*", "+"}, + &AssertEntries{}, + }, } runTestCases(t, tcs) @@ -1060,7 +1608,7 @@ func TestOperationalAttributesMigration(t *testing.T) { runTestCases(t, tcs) } -func TesPwdFailureTimeNano(t *testing.T) { +func TestPwdFailureTimeNano(t *testing.T) { type A []string type M map[string][]string diff --git a/main.go b/main.go index 4e4b13b..91bd056 100644 --- a/main.go +++ b/main.go @@ -143,6 +143,11 @@ var ( "", "DN of the default password policy entry (e.g. cn=standard-policy,ou=Policies,dc=example,dc=com)", ) + defaultPageSize = fs.Int( + "default-page-size", + 500, + "Default page size for search (default 500)", + ) ) type arrayFlags []string @@ -226,6 +231,7 @@ func main() { QueryTranslator: "default", SimpleACL: acl, DefaultPPolicyDN: *defaultPPolicyDN, + DefaultPageSize: int32(*defaultPageSize), }) go server.Start() diff --git a/repo.go b/repo.go index 0453e4f..833e133 100644 --- a/repo.go +++ b/repo.go @@ -80,7 +80,7 @@ type Repository interface { // Search handles search request by filter. // This is used for SEARCH operation. - Search(ctx context.Context, baseDN *DN, option *SearchOption, handler func(entry *SearchEntry) error) (int32, int32, error) + Search(ctx context.Context, baseDN *DN, option *SearchOption, handler func(entry *SearchEntry) error) (int32, int64, error) // Update modifies the entry by specified change data. // This is used for MOD operation. @@ -101,7 +101,7 @@ type SearchOption struct { Scope int Filter message.Filter PageSize int32 - Offset int32 + Cursor *int64 RequestedAssocation []string IsMemberOfRequested bool IsHasSubordinatesRequested bool diff --git a/repo_hybrid.go b/repo_hybrid.go index 1f2644d..4b49891 100644 --- a/repo_hybrid.go +++ b/repo_hybrid.go @@ -101,8 +101,8 @@ func (r *HybridRepository) Init() error { REFERENCES ldap_entry (id) ON DELETE RESTRICT ON UPDATE RESTRICT ); - CREATE INDEX IF NOT EXISTS idx_ldap_association_id ON ldap_association(name, id); - CREATE INDEX IF NOT EXISTS idx_ldap_association_member_id ON ldap_association(name, member_id); + CREATE INDEX IF NOT EXISTS idx_ldap_association_id ON ldap_association(id, name); + CREATE INDEX IF NOT EXISTS idx_ldap_association_member_id ON ldap_association(member_id, name); `) if err != nil { return xerrors.Errorf("Failed to initialize prepared statement: %w", err) @@ -1139,7 +1139,7 @@ func (e *HybridFetchedDBEntry) AttrsOrig() map[string][]string { return jsonMap } -func (r *HybridRepository) Search(ctx context.Context, baseDN *DN, option *SearchOption, handler func(entry *SearchEntry) error) (int32, int32, error) { +func (r *HybridRepository) Search(ctx context.Context, baseDN *DN, option *SearchOption, handler func(entry *SearchEntry) error) (int32, int64, error) { tx, err := r.beginReadonly(ctx) if err != nil { return 0, 0, nil @@ -1153,8 +1153,7 @@ func (r *HybridRepository) Search(ctx context.Context, baseDN *DN, option *Searc filterJoin := []string{} filterWhere := []string{} params := map[string]interface{}{ - "pageSize": option.PageSize, - "offset": option.Offset, + "pageSize": option.PageSize + 1, } r.collectScopeWhereSQL(baseDN, option, &scopeWhere, params) r.collectFilterWhereSQL(baseDN, option, &filterJoin, &filterWhere, params) @@ -1166,14 +1165,20 @@ func (r *HybridRepository) Search(ctx context.Context, baseDN *DN, option *Searc // r.collectAssociationSQLPlanB(option, &proj, &join, params) r.collectHasSubordinatesSQL(option, &proj, &join) + pagingFilter := "" + if option.Cursor != nil { + pagingFilter = `-- paging + AND e. id >= :cursor` + params["cursor"] = *option.Cursor + } + q := fmt.Sprintf(`WITH filtered_entry AS NOT MATERIALIZED ( SELECT e.id, e.parent_id, e.rdn_orig || ',' || dnc.dn_orig AS dn_orig, - e.attrs_orig, - count(e.id) over() AS count + e.attrs_orig FROM ldap_entry e -- DN join @@ -1185,25 +1190,27 @@ func (r *HybridRepository) Search(ctx context.Context, baseDN *DN, option *Searc AND -- ldap filter (%s) - ORDER BY e.id - LIMIT :pageSize OFFSET :offset + %s + ORDER BY e.id ASC + LIMIT :pageSize ) SELECT fe.id, fe.parent_id, fe.dn_orig, - fe.attrs_orig, - fe.count + fe.attrs_orig %s FROM filtered_entry fe %s - `, strings.Join(filterJoin, ""), scopeWhere.String(), strings.Join(filterWhere, " AND "), proj.String(), join.String()) + `, strings.Join(filterJoin, ""), scopeWhere.String(), strings.Join(filterWhere, " AND "), pagingFilter, proj.String(), join.String()) start := time.Now() rows, err := r.namedQuery(tx, q, params) end := time.Now() + log.Printf("info: Executed DB search: %d [ms], limit: %d, cursor: %d", end.Sub(start).Milliseconds(), option.PageSize, *option.Cursor) + if err != nil { if isNoResult(err) { // Need to return successful response @@ -1212,8 +1219,8 @@ FROM return 0, 0, xerrors.Errorf("Unexpected search query error. err: %w", err) } - var maxCount int32 = 0 var count int32 = 0 + var nextId int64 = 0 var dbEntry HybridFetchedDBEntry for rows.Next() { @@ -1221,9 +1228,12 @@ FROM if err != nil { return 0, 0, xerrors.Errorf("Unexpected struct scan error. err: %w", err) } - if maxCount == 0 { - maxCount = dbEntry.Count - log.Printf("info: Executed DB search: %d [ms], count: %d", end.Sub(start).Milliseconds(), maxCount) + + // Detected remaining next page + if option.PageSize == count { + nextId = dbEntry.ID + count++ + break } readEntry := r.toSearchEntry(&dbEntry) @@ -1238,7 +1248,7 @@ FROM dbEntry.Clear() } - return maxCount, count, nil + return count, nextId, nil } func (r *HybridRepository) toSearchEntry(dbEntry *HybridFetchedDBEntry) *SearchEntry { @@ -1391,31 +1401,40 @@ LEFT JOIN LATERAL ( } func (r *HybridRepository) collectScopeWhereSQL(baseDN *DN, option *SearchOption, where *strings.Builder, params map[string]interface{}) { + // Always return not found for parents of the server suffix + if baseDN.IsDC() && !baseDN.Equal(r.server.Suffix) { + where.WriteString(`FALSE`) + return + } + // Scope handling // 0: base (only base) // 1: one (only one level, not include base) // 2: sub (subtree, include base) // 3: children (subtree, not include base) - if option.Scope == 0 || option.Scope == 1 { - var col string - if option.Scope == 0 { - col = "id" - } else { - col = "parent_id" - } - where.WriteString(`e.`) - where.WriteString(col) - where.WriteString(` = (SELECT - e.id - FROM - ldap_entry e - LEFT JOIN ldap_container c ON e.parent_id = c.id - WHERE - e.rdn_norm = :rdn_norm - AND c.dn_norm = :parent_dn_norm)`) + if option.Scope == 0 { + where.WriteString(`e.rdn_norm = :rdn_norm AND dnc.dn_norm = :parent_dn_norm`) params["rdn_norm"] = baseDN.RDNNormStr() params["parent_dn_norm"] = baseDN.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix) + } else if option.Scope == 1 { + if baseDN.Equal(r.server.Suffix) { + where.WriteString(`e.parent_id = (SELECT + c.id + FROM + ldap_container c + WHERE + c.id != 0 AND c.dn_norm = :parent_dn_norm)`) + params["parent_dn_norm"] = baseDN.DNNormStrWithoutSuffix(r.server.Suffix) + } else { + where.WriteString(`e.parent_id = (SELECT + c.id + FROM + ldap_container c + WHERE + c.dn_norm = :parent_dn_norm)`) + params["parent_dn_norm"] = baseDN.DNNormStrWithoutSuffix(r.server.Suffix) + } } else { var subWhere string if baseDN.Equal(r.server.Suffix) { @@ -1432,10 +1451,9 @@ func (r *HybridRepository) collectScopeWhereSQL(baseDN *DN, option *SearchOption if baseDN.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix) == "" { subWhere = ` e.parent_id IN (SELECT - e.parent_id + c.id FROM - ldap_entry e - LEFT JOIN ldap_container c ON e.parent_id = c.id + ldap_container c WHERE c.dn_norm = :dn_norm OR @@ -1444,10 +1462,9 @@ func (r *HybridRepository) collectScopeWhereSQL(baseDN *DN, option *SearchOption } else { subWhere = ` e.parent_id IN (SELECT - e.parent_id + c.id FROM - ldap_entry e - LEFT JOIN ldap_container c ON e.parent_id = c.id + ldap_container c WHERE c.dn_norm = :dn_norm OR @@ -1457,14 +1474,7 @@ func (r *HybridRepository) collectScopeWhereSQL(baseDN *DN, option *SearchOption } if option.Scope == 2 { - where.WriteString(`e.id IN (SELECT - e.id - FROM - ldap_entry e - LEFT JOIN ldap_container c ON e.parent_id = c.id - WHERE - e.rdn_norm = :rdn_norm - AND c.dn_norm = :parent_dn_norm + where.WriteString(`(e.rdn_norm = :rdn_norm AND dnc.dn_norm = :parent_dn_norm OR`) where.WriteString(subWhere) where.WriteString(` @@ -1473,6 +1483,7 @@ func (r *HybridRepository) collectScopeWhereSQL(baseDN *DN, option *SearchOption params["parent_dn_norm"] = baseDN.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix) params["dn_norm"] = baseDN.DNNormStrWithoutSuffix(r.server.Suffix) } else { + // scope == 3 where.WriteString(subWhere) params["parent_dn_norm"] = baseDN.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix) params["dn_norm"] = baseDN.DNNormStrWithoutSuffix(r.server.Suffix) @@ -2602,16 +2613,28 @@ func (r *HybridRepository) Bind(ctx context.Context, dn *DN, callback func(curre } } else { // Record authTimestamp, also remove pwdAccountLockedTime and pwdFailureTime - n, o := nowTimeToJSONAttrs(TIMESTAMP_FORMAT) + go func() { + tx2, err := r.begin(ctx) + if err != nil { + log.Printf("error: Failed to begin tx after bind success. id: %d, err: %v", dest.ID, err) + return + } + n, o := nowTimeToJSONAttrs(TIMESTAMP_FORMAT) - if _, err := r.exec(tx, updateAfterBindSuccessByDN, map[string]interface{}{ - "id": dest.ID, - "auth_timestamp_norm": n, - "auth_timestamp_orig": o, - }); err != nil { - rollback(tx) - return xerrors.Errorf("Failed to update entry after bind success. id: %d, err: %w", dest.ID, err) - } + if _, err := r.exec(tx2, updateAfterBindSuccessByDN, map[string]interface{}{ + "id": dest.ID, + "auth_timestamp_norm": n, + "auth_timestamp_orig": o, + }); err != nil { + rollback(tx2) + log.Printf("error: Failed to update entry after bind success. id: %d, err: %v", dest.ID, err) + return + } + if err := commit(tx2); err != nil { + log.Printf("error: Failed to commit tx after bind success. id: %d, err: %v", dest.ID, err) + return + } + }() } if err := commit(tx); err != nil { diff --git a/server.go b/server.go index 02f4196..6934e59 100644 --- a/server.go +++ b/server.go @@ -46,6 +46,7 @@ type ServerConfig struct { QueryTranslator string SimpleACL []string DefaultPPolicyDN string + DefaultPageSize int32 } type Server struct { diff --git a/test_util.go b/test_util.go index 04eeccc..95d620e 100644 --- a/test_util.go +++ b/test_util.go @@ -56,6 +56,15 @@ type Command interface { Run(t *testing.T, conn *ldap.Conn) (*ldap.Conn, error) } +type Wait struct { + duration time.Duration +} + +func (w Wait) Run(t *testing.T, unused *ldap.Conn) (*ldap.Conn, error) { + time.Sleep(w.duration) + return unused, nil +} + type Parallel struct { count int ops [][]Command @@ -238,6 +247,38 @@ func (s Search) Run(t *testing.T, conn *ldap.Conn) (*ldap.Conn, error) { return conn, nil } +type SearchWithPaging struct { + Search + limit uint32 +} + +func (s SearchWithPaging) Run(t *testing.T, conn *ldap.Conn) (*ldap.Conn, error) { + search := ldap.NewSearchRequest( + s.baseDN, + s.scope, + ldap.NeverDerefAliases, + 0, // Size Limit + 0, // Time Limit + false, + "("+s.filter+")", // The filter to apply + s.attrs, // A list attributes to retrieve + nil, + ) + sr, err := conn.SearchWithPaging(search, s.limit) + if err != nil { + return conn, err + } + + if s.assert != nil { + err = s.assert.AssertEntries(conn, err, sr) + if err != nil { + return conn, err + } + } + + return conn, nil +} + func resolveDN(rdn, baseDN string) string { dn := rdn if baseDN != "" { @@ -672,22 +713,25 @@ func setupLDAPServer() *Server { // "objectClasses: ( 2.5.6.9 NAME 'groupOfNames' DESC 'RFC2256: a group of names (DNs)' SUP top STRUCTURAL MUST cn MAY ( businessCategory $ seeAlso $ owner $ ou $ o $ description $ member $ uniqueMember $ displayName ) )", // } testServer = NewServer(&ServerConfig{ - DBHostName: "localhost", - DBPort: testPGPort, - DBName: "ldap", - DBSchema: "public", - DBUser: "dev", - DBPassword: "dev", - DBMaxOpenConns: 2, - DBMaxIdleConns: 1, - Suffix: "dc=example,dc=com", - RootDN: "cn=Manager,dc=example,dc=com", - RootPW: "secret", - BindAddress: "127.0.0.1:8389", - LogLevel: "warn", - PProfServer: "127.0.0.1:10000", - GoMaxProcs: 0, - QueryTranslator: "default", + DBHostName: "localhost", + DBPort: testPGPort, + DBName: "ldap", + DBSchema: "public", + DBUser: "dev", + DBPassword: "dev", + DBMaxOpenConns: 2, + DBMaxIdleConns: 1, + Suffix: "dc=example,dc=com", + RootDN: "cn=Manager,dc=example,dc=com", + RootPW: "secret", + BindAddress: "127.0.0.1:8389", + LogLevel: "warn", + PProfServer: "127.0.0.1:10000", + GoMaxProcs: 0, + QueryTranslator: "default", + DefaultPPolicyDN: "cn=standard-policy,ou=Policies,dc=examle,dc=com", + DefaultPageSize: 500, + SimpleACL: []string{"uid=op1,ou=users,dc=example,dc=com:RW:", "uid=op2,ou=users,dc=example,dc=com:RW:"}, }) go testServer.Start() diff --git a/util.go b/util.go index ded2a9f..221f675 100644 --- a/util.go +++ b/util.go @@ -55,12 +55,12 @@ func getAuthSession(m *ldap.Message) *AuthSession { } } -func getPageSession(m *ldap.Message) map[string]int32 { +func getPageSession(m *ldap.Message) map[string]int64 { session := getSession(m) if pageSession, ok := session["page"]; ok { - return pageSession.(map[string]int32) + return pageSession.(map[string]int64) } else { - pageSession := map[string]int32{} + pageSession := map[string]int64{} session["page"] = pageSession return pageSession } @@ -107,13 +107,13 @@ func isHasSubOrdinatesRequested(r message.SearchRequest) bool { func getRequestedMemberAttrs(r message.SearchRequest) []string { if len(r.Attributes()) == 0 { - return []string{"member", "uniqueMember"} + return getAllMemberAttrs() } list := []string{} for _, attr := range r.Attributes() { if string(attr) == "*" { // TODO move to schema - return []string{"member", "uniqueMember"} + return getAllMemberAttrs() } a := string(attr) @@ -128,6 +128,10 @@ func getRequestedMemberAttrs(r message.SearchRequest) []string { return list } +func getAllMemberAttrs() []string { + return []string{"member", "uniqueMember"} +} + func responseUnsupportedSearch(w ldap.ResponseWriter, r message.SearchRequest) { log.Printf("warn: Unsupported search filter: %s", r.FilterString()) res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess)