Skip to content

Commit

Permalink
jrpc2: refactor test helper
Browse files Browse the repository at this point in the history
The current testing mechanics relies on the order of the switch
statement. This new helper makes the tests less brittle and more
precise!
  • Loading branch information
ryandotsmith committed May 3, 2024
1 parent 3af910d commit 48f6187
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions jrpc2/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package jrpc2
import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"slices"
"sort"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -108,13 +109,32 @@ var (
logs1000001JSON string
)

func methodsMatch(t *testing.T, body []byte, want ...string) bool {
var req []request

if err := json.Unmarshal(body, &req); err != nil {
var r request
if err := json.Unmarshal(body, &r); err != nil {
t.Fatal("unable to decode json into a request or []request")
}
req = append(req, r)
}

var methods []string
for i := range req {
methods = append(methods, req[i].Method)
}
t.Logf("methods=%#v", methods)
return slices.Equal(methods, want)
}

func TestLatest_Cached(t *testing.T) {
var counter int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber"):
switch counter {
case 0:
_, err := w.Write([]byte(`{"result": {
Expand Down Expand Up @@ -215,7 +235,7 @@ func TestValidate_Blocks(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber", "eth_getBlockByNumber"):
_, err := w.Write([]byte(`[
{
"result": {
Expand Down Expand Up @@ -250,7 +270,7 @@ func TestValidate_Logs(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getLogs"):
case methodsMatch(t, body, "eth_getLogs"):
_, err := w.Write([]byte(`{"result": [
{
"address": "0x0000000000000000000000000000000000000000",
Expand Down Expand Up @@ -285,7 +305,7 @@ func TestError(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber"):
_, err := w.Write([]byte(`
[{
"jsonrpc": "2.0",
Expand Down Expand Up @@ -337,11 +357,11 @@ func TestGet_Cached(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber"):
atomic.AddUint64(&reqCount, 1)
_, err := w.Write([]byte(block18000000JSON))
diff.Test(t, t.Fatalf, nil, err)
case strings.Contains(string(body), "eth_getLogs"):
case methodsMatch(t, body, "eth_getLogs"):
for ; reqCount == 0; time.Sleep(time.Second) {
}
_, err := w.Write([]byte(logs18000000JSON))
Expand Down Expand Up @@ -388,7 +408,7 @@ func TestGet_Cached_Pruned(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber"):
atomic.AddInt32(&n, 1)
_, err := w.Write([]byte(block18000000JSON))
diff.Test(t, t.Fatalf, nil, err)
Expand Down Expand Up @@ -417,10 +437,10 @@ func TestNoLogs(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber"):
_, err := w.Write([]byte(block1000001JSON))
diff.Test(t, t.Fatalf, nil, err)
case strings.Contains(string(body), "eth_getLogs"):
case methodsMatch(t, body, "eth_getLogs"):
_, err := w.Write([]byte(logs1000001JSON))
diff.Test(t, t.Fatalf, nil, err)
}
Expand All @@ -444,10 +464,10 @@ func TestLatest(t *testing.T) {
body, err := io.ReadAll(r.Body)
diff.Test(t, t.Fatalf, nil, err)
switch {
case strings.Contains(string(body), "eth_getBlockByNumber"):
case methodsMatch(t, body, "eth_getBlockByNumber"):
_, err := w.Write([]byte(block18000000JSON))
diff.Test(t, t.Fatalf, nil, err)
case strings.Contains(string(body), "eth_getLogs"):
case methodsMatch(t, body, "eth_getLogs"):
_, err := w.Write([]byte(logs18000000JSON))
diff.Test(t, t.Fatalf, nil, err)
}
Expand Down

0 comments on commit 48f6187

Please sign in to comment.