Skip to content

Commit

Permalink
Merge pull request #13 from openstandia/fix
Browse files Browse the repository at this point in the history
Improvement some error handling
  • Loading branch information
wadahiro authored Jul 1, 2021
2 parents 6c14dcc + 7f397d5 commit fb221b8
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 23 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ PostgreSQL 12 or later.

## Install

### From bainary
### From binary

Please download it from [release page](../../releases).

### From source

`ldap-pg` is written by Golang. Install Golang then build `ldap-pg`:
`ldap-pg` is written by Go. Install Go then build `ldap-pg`:

```
make
Expand Down
14 changes: 13 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,20 @@ func (e *RetryError) Unwrap() error {
return e.err
}

func NewRetryError(err error) *RetryError {
func NewRetryError(err error) error {
return &RetryError{
err: err,
}
}

type InvalidDNError struct {
dnNorm string
}

func NewInvalidDNError(dnNorm string) error {
return &InvalidDNError{dnNorm}
}

func (e *InvalidDNError) Error() string {
return fmt.Sprintf("InvalidDNError. dn_norm: %s", e.dnNorm)
}
50 changes: 50 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,56 @@ func TestMemberOf(t *testing.T) {
},
},
},
Add{
"uid=user2", "ou=Users",
M{
"objectClass": A{"inetOrgPerson"},
"cn": A{"user2"},
"sn": A{"user2"},
"userPassword": A{SSHA("password1")},
},
&AssertEntry{},
},
// Test case for duplicate members
Add{
"cn=A1", "ou=Groups",
M{
"objectClass": A{"groupOfNames"},
"member": A{
"uid=user2,ou=Users," + testServer.GetSuffix(),
"uid=user2,ou=Users," + testServer.GetSuffix(),
},
},
&AssertLDAPError{
expectErrorCode: ldap.LDAPResultAttributeOrValueExists,
},
},
// Test case for adding a non-existent member
Add{
"cn=A1", "ou=Groups",
M{
"objectClass": A{"groupOfNames"},
"member": A{
"uid=notfound,ou=Users," + testServer.GetSuffix(),
},
},
&AssertLDAPError{
expectErrorCode: ldap.LDAPResultInvalidAttributeSyntax,
},
},
Add{
"cn=A1", "ou=Groups",
M{
"objectClass": A{"groupOfNames"},
"member": A{
"uid=user2,ou=Users," + testServer.GetSuffix(),
"uid=notfound,ou=Users," + testServer.GetSuffix(),
},
},
&AssertLDAPError{
expectErrorCode: ldap.LDAPResultInvalidAttributeSyntax,
},
},
}

runTestCases(t, tcs)
Expand Down
84 changes: 64 additions & 20 deletions repo_hybrid.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ func (r *HybridRepository) Init() error {
db := r.db

_, err = db.Exec(`
CREATE EXTENSION IF NOT EXISTS pgcrypto;
CREATE TABLE IF NOT EXISTS ldap_container (
id BIGINT PRIMARY KEY,
dn_norm VARCHAR(512) NOT NULL, -- cache
Expand Down Expand Up @@ -99,6 +97,9 @@ func (r *HybridRepository) Init() error {
CREATE INDEX IF NOT EXISTS idx_ldap_association_id ON ldap_association(name, id);
CREATE INDEX IF NOT EXISTS idx_ldap_association_member_id ON ldap_association(name, member_id);
`)
if err != nil {
return xerrors.Errorf("Failed to initialize prepared statement: %w", err)
}

findCredByDN, err = db.PrepareNamed(`SELECT
e.id, e.attrs_orig->'userPassword' AS credential
Expand Down Expand Up @@ -2036,7 +2037,9 @@ func (r *HybridRepository) schemaValueToIDArray(tx *sqlx.Tx, schemaValueMap map[
return rtn, nil
}

dnMap := map[string][]string{}
dnMap := map[string]StringSet{}
indexMap := map[string]int{} // key: dn_norm, value: index

for i, v := range schemaValue.Orig() {
dn, err := NormalizeDN(r.server.schemaMap, v)
if err != nil {
Expand All @@ -2045,10 +2048,23 @@ func (r *HybridRepository) schemaValueToIDArray(tx *sqlx.Tx, schemaValueMap map[
}

parentDNNorm := dn.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix)
dnMap[parentDNNorm] = append(dnMap[parentDNNorm], dn.RDNNormStr())
if set, ok := dnMap[parentDNNorm]; ok {
set.Add(dn.RDNNormStr())
} else {
set = NewStringSet(dn.RDNNormStr())
dnMap[parentDNNorm] = set
}
}

return r.resolveDNMap(tx, dnMap)
ids, err := r.resolveDNMap(tx, dnMap)
if err != nil {
if dnErr, ok := err.(*InvalidDNError); ok {
index := indexMap[dnErr.dnNorm]
return nil, NewInvalidPerSyntax(attrName, index)
}
}

return ids, err
}

func (r *HybridRepository) dnArrayToIDArray(tx *sqlx.Tx, norm map[string]interface{}, attrName string) ([]int64, error) {
Expand All @@ -2059,59 +2075,77 @@ func (r *HybridRepository) dnArrayToIDArray(tx *sqlx.Tx, norm map[string]interfa
return rtn, nil
}

dnMap := map[string][]string{}
dnMap := map[string]StringSet{}
indexMap := map[string]int{} // key: dn_norm, value: index

for i, v := range dnArray {
dn, err := NormalizeDN(r.server.schemaMap, v)
if err != nil {
log.Printf("warn: Failed to normalize DN: %s", v)
return nil, NewInvalidPerSyntax(attrName, i)
}

indexMap[dn.DNNormStrWithoutSuffix(r.server.Suffix)] = i

parentDNNorm := dn.ParentDN().DNNormStrWithoutSuffix(r.server.Suffix)
dnMap[parentDNNorm] = append(dnMap[parentDNNorm], dn.RDNNormStr())
if set, ok := dnMap[parentDNNorm]; ok {
set.Add(dn.RDNNormStr())
} else {
set = NewStringSet(dn.RDNNormStr())
dnMap[parentDNNorm] = set
}
}

return r.resolveDNMap(tx, dnMap)
ids, err := r.resolveDNMap(tx, dnMap)
if err != nil {
if dnErr, ok := err.(*InvalidDNError); ok {
index := indexMap[dnErr.dnNorm]
return nil, NewInvalidPerSyntax(attrName, index)
}
}

return ids, err
}

// resolveDNMap resolves Map(key: rdn_norm, value: parent_dn_norm) to the entry's ids.
func (r *HybridRepository) resolveDNMap(tx *sqlx.Tx, dnMap map[string][]string) ([]int64, error) {
func (r *HybridRepository) resolveDNMap(tx *sqlx.Tx, dnMap map[string]StringSet) ([]int64, error) {
rtn := []int64{}

bq := `SELECT
e.id
e.id, e.rdn_norm
FROM
ldap_entry e
LEFT JOIN ldap_container c ON e.parent_id = c.id
WHERE
e.rdn_norm IN (:rdn_norm)
e.rdn_norm IN (:rdn_norms)
AND c.dn_norm = :parent_dn_norm
FOR SHARE
`

for k, v := range dnMap {
rdnNorms := v.Values()
q, params, err := sqlx.Named(bq, map[string]interface{}{
"rdn_norm": v,
"rdn_norms": rdnNorms,
"parent_dn_norm": k,
})
if err != nil {
log.Printf("error: Unexpected named query error. rdn_norm: %s, parent_dn_norm: %v, err: %v", k, v, err)
log.Printf("error: Unexpected named query error. rdn_norms: %v, parent_dn_norm: %s, err: %v", k, rdnNorms, err)
// System error
return nil, NewUnavailable()
}

q, params, err = sqlx.In(q, params...)
if err != nil {
log.Printf("error: Unexpected expand IN error. rdn_norm: %s, parent_dn_norm: %v, err: %v", k, v, err)
log.Printf("error: Unexpected expand IN error. rdn_norms: %v, parent_dn_norm: %s, err: %v", k, rdnNorms, err)
// System error
return nil, NewUnavailable()
}

q = tx.Rebind(q)

rows, err := tx.Query(q, params...)
rows, err := tx.Queryx(q, params...)
if err != nil {
log.Printf("error: Unexpected execute query error. rdn_norm: %s, parent_dn_norm: %v, err: %v", k, v, err)
log.Printf("error: Unexpected execute query error. rdn_norms: %v, parent_dn_norm: %s, err: %v", k, rdnNorms, err)
// System error
return nil, NewUnavailable()
}
Expand All @@ -2120,14 +2154,24 @@ func (r *HybridRepository) resolveDNMap(tx *sqlx.Tx, dnMap map[string][]string)

var ids []int64
for rows.Next() {
var id int64
err = rows.Scan(&id)
var entry struct {
Id int64 `db:"id"`
RDNNorm string `db:"rdn_norm"`
}
err = rows.StructScan(&entry)
if err != nil {
log.Printf("error: Unexpected query result scan error. rdn_norm: %s, parent_dn_norm: %v, err: %v", k, v, err)
log.Printf("error: Unexpected query result scan error. rdn_norms: %v, parent_dn_norm: %s, err: %v", k, rdnNorms, err)
// System error
return nil, NewUnavailable()
}
ids = append(ids, id)

delete(v, entry.RDNNorm)
ids = append(ids, entry.Id)
}

if v.Size() > 0 {
log.Printf("warn: Detected non-existent DN for association. rdn_norms: %v, parent_dn_norm: %s", v.Values(), rdnNorms)
return nil, NewInvalidDNError(v.First() + "," + k)
}

rtn = append(rtn, ids...)
Expand Down
41 changes: 41 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,44 @@ func verifyChainedObjectClasses(s *SchemaMap, objectClasses []*ObjectClass) *LDA

return nil
}

type StringSet map[string]struct{}

func NewStringSet(str ...string) StringSet {
set := StringSet{}
for _, v := range str {
set.Add(v)
}
return set
}

func (s StringSet) Add(str string) {
s[str] = struct{}{}
}

func (s StringSet) Size() int {
return len(s)
}

func (s StringSet) First() string {
// TODO Store the order of the map
for k, _ := range s {
return k
}
return ""
}

func (s StringSet) Contains(str string) bool {
_, ok := s[str]
return ok
}

func (s StringSet) Values() []string {
rtn := make([]string, s.Size())
i := 0
for k, _ := range s {
rtn[i] = k
i++
}
return rtn
}

0 comments on commit fb221b8

Please sign in to comment.