Skip to content

Commit

Permalink
fix Query.One + Filter behavior (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed Dec 16, 2024
1 parent cb20568 commit a1cf2dc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 26 deletions.
44 changes: 18 additions & 26 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,34 +239,19 @@ func (q *Query) One(ctx context.Context, out interface{}) error {
}

// If not, try a Query.
req := q.queryInput()

var res *dynamodb.QueryOutput
err := q.table.db.retry(ctx, func() error {
var err error
res, err = q.table.db.client.Query(ctx, req)
q.cc.incRequests()
if err != nil {
return err
}

switch {
case len(res.Items) == 0:
return ErrNotFound
case len(res.Items) > 1 && q.limit != 1:
return ErrTooMany
case res.LastEvaluatedKey != nil && q.searchLimit != 0:
return ErrTooMany
}

return nil
})
if err != nil {
iter := q.Iter().(*queryIter)
ok := iter.Next(ctx, out)
if err := iter.Err(); err != nil {
return err
}
q.cc.add(res.ConsumedCapacity)

return unmarshalItem(res.Items[0], out)
if !ok {
return ErrNotFound
}
// Best effort: do we have any pending unused items?
if iter.hasMore() {
return ErrTooMany
}
return nil
}

// Count executes this request, returning the number of results.
Expand Down Expand Up @@ -422,6 +407,13 @@ func (itr *queryIter) Next(ctx context.Context, out interface{}) bool {
return itr.err == nil
}

func (itr *queryIter) hasMore() bool {
if itr.query.limit > 0 && itr.n == itr.query.limit {
return false
}
return itr.output != nil && itr.idx < len(itr.output.Items)
}

// Err returns the error encountered, if any.
// You should check this after Next is finished.
func (itr *queryIter) Err() error {
Expand Down
25 changes: 25 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dynamo

import (
"context"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -111,6 +112,30 @@ func TestGetAllCount(t *testing.T) {
t.Errorf("bad result for get one. %v ≠ %v", one, item)
}

// trigger ErrTooMany
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).One(ctx, &one)
if !errors.Is(err, ErrTooMany) {
t.Errorf("bad error from get one. %v ≠ %v", err, ErrTooMany)
}

// suppress ErrTooMany with Limit(1)
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).Limit(1).One(ctx, &one)
if err != nil {
t.Error("unexpected error:", err)
}
if one.UserID == 0 {
t.Errorf("bad result for get one: %v", one)
}

// trigger ErrNotFound via SearchLimit + Filter + One
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Filter("Msg = ?", item.Msg).Consistent(true).SearchLimit(1).One(ctx, &one)
if !errors.Is(err, ErrNotFound) {
t.Errorf("bad error from get one. %v ≠ %v", err, ErrNotFound)
}

// GetItem + Project
one = widget{}
projected := widget{
Expand Down

0 comments on commit a1cf2dc

Please sign in to comment.