diff --git a/dig/dig.go b/dig/dig.go index 3a759c3..49a9c39 100644 --- a/dig/dig.go +++ b/dig/dig.go @@ -386,51 +386,91 @@ type Filter struct { } func (f Filter) Accept(ctx context.Context, pgmut *sync.Mutex, pg wpg.Conn, d any) (bool, error) { - var val []byte + if len(f.Arg) == 0 { + return true, nil + } switch v := d.(type) { - case []byte: - val = []byte(v) case eth.Bytes: - val = []byte(v) - default: - return true, nil + d = []byte(v) + case eth.Uint64: + d = uint64(v) } - - switch { - case strings.HasSuffix(f.Op, "contains"): - var res bool + switch v := d.(type) { + case []byte: switch { - case len(f.Ref.Table) > 0: - q := fmt.Sprintf( - `select true from %s where %s = $1`, - f.Ref.Table, - f.Ref.Column, - ) - pgmut.Lock() - defer pgmut.Unlock() - err := pg.QueryRow(ctx, q, val).Scan(&res) + case strings.HasSuffix(f.Op, "contains"): + var res bool switch { - case errors.Is(err, pgx.ErrNoRows): - res = false - case err != nil: - const tag = "filter using reference (%s %s): %w" - return false, fmt.Errorf(tag, f.Ref.Table, f.Ref.Column, err) - } - default: - for i := range f.Arg { - if bytes.Contains(val, eth.DecodeHex(f.Arg[i])) { - res = true - break + case len(f.Ref.Table) > 0: + q := fmt.Sprintf( + `select true from %s where %s = $1`, + f.Ref.Table, + f.Ref.Column, + ) + pgmut.Lock() + defer pgmut.Unlock() + err := pg.QueryRow(ctx, q, v).Scan(&res) + switch { + case errors.Is(err, pgx.ErrNoRows): + res = false + case err != nil: + const tag = "filter using reference (%s %s): %w" + return false, fmt.Errorf(tag, f.Ref.Table, f.Ref.Column, err) } + default: + for i := range f.Arg { + if bytes.Contains(v, eth.DecodeHex(f.Arg[i])) { + res = true + break + } + } + } + if strings.HasPrefix(f.Op, "!") { + return !res, nil } + return res, nil + default: + return true, nil } - if strings.HasPrefix(f.Op, "!") { - return !res, nil + case string: + switch f.Op { + case "eq": + return v == f.Arg[0], nil + case "ne": + return v != f.Arg[0], nil } - return res, nil - default: - return true, nil - } + case uint64: + i, err := strconv.ParseUint(f.Arg[0], 10, 64) + if err != nil { + return false, fmt.Errorf("unable to convert filter arg to int: %q", f.Arg[0]) + } + switch f.Op { + case "eq": + return v == i, nil + case "ne": + return v != i, nil + case "gt": + return v > i, nil + case "lt": + return v < i, nil + } + case *uint256.Int: + i := &uint256.Int{} + if err := i.SetFromDecimal(f.Arg[0]); err != nil { + return false, fmt.Errorf("unable to convert filter arg dec to uint256: %q", f.Arg[0]) + } + switch f.Op { + case "eq": + return v.Cmp(i) == 0, nil + case "ne": + return v.Cmp(i) != 0, nil + case "gt": + return v.Cmp(i) == 1, nil + case "lt": + return v.Cmp(i) == -1, nil + } + } + return true, nil } func parseArray(elm atype, s string) atype { diff --git a/dig/dig_test.go b/dig/dig_test.go index f21e619..1d48041 100644 --- a/dig/dig_test.go +++ b/dig/dig_test.go @@ -1,16 +1,37 @@ package dig import ( + "context" + "database/sql" "encoding/hex" "reflect" "strings" + "sync" "testing" + "blake.io/pqx/pqxtest" + "github.com/holiman/uint256" "github.com/indexsupply/x/bint" + "github.com/indexsupply/x/eth" + "github.com/indexsupply/x/tc" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" "kr.dev/diff" ) +func TestMain(m *testing.M) { + sql.Register("postgres", stdlib.GetDefaultDriver()) + pqxtest.TestMain(m) +} + +func testpg(t *testing.T) *pgxpool.Pool { + pqxtest.CreateDB(t, "") + pg, err := pgxpool.New(context.Background(), pqxtest.DSNForTest(t)) + tc.NoErr(t, err) + return pg +} + func TestHasStatic(t *testing.T) { cases := []struct { t atype @@ -368,3 +389,48 @@ func TestNumIndexed(t *testing.T) { } diff.Test(t, t.Errorf, 3, event.numIndexed()) } + +func TestFilter(t *testing.T) { + dec2uint256 := func(s string) *uint256.Int { + i, _ := uint256.FromDecimal(s) + return i + } + pg := testpg(t) + mt := new(sync.Mutex) + cases := []struct { + f Filter + d any + want bool + }{ + { + Filter{Op: "gt", Arg: []string{"1"}}, + eth.Uint64(0), + false, + }, + { + Filter{Op: "gt", Arg: []string{"1"}}, + eth.Uint64(2), + true, + }, + { + Filter{Op: "eq", Arg: []string{"340282366920938463463374607431768211456"}}, + dec2uint256("340282366920938463463374607431768211456"), + true, + }, + { + Filter{Op: "eq", Arg: []string{"foo"}}, + "foo", + true, + }, + { + Filter{Op: "ne", Arg: []string{"bar"}}, + "foo", + true, + }, + } + for _, c := range cases { + got, err := c.f.Accept(context.Background(), mt, pg, c.d) + tc.NoErr(t, err) + tc.WantGot(t, c.want, got) + } +}