From 555ee1e52cd294de6a95ae901362d19cdf9ee542 Mon Sep 17 00:00:00 2001 From: Viacheslav Gonkivskyi Date: Thu, 28 Sep 2023 17:14:30 +0300 Subject: [PATCH] rework subscription --- nodebuilder/fraud/fraud.go | 14 +++++++------- nodebuilder/fraud/mocks/api.go | 4 ++-- nodebuilder/tests/fraud_test.go | 20 +++++++++++++++----- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/nodebuilder/fraud/fraud.go b/nodebuilder/fraud/fraud.go index 45c3863d6f..46af4af67d 100644 --- a/nodebuilder/fraud/fraud.go +++ b/nodebuilder/fraud/fraud.go @@ -18,7 +18,7 @@ var _ Module = (*API)(nil) //go:generate mockgen -destination=mocks/api.go -package=mocks . Module type Module interface { // Subscribe allows to subscribe on a Proof pub sub topic by its type. - Subscribe(context.Context, fraud.ProofType) (<-chan Proof, error) + Subscribe(context.Context, fraud.ProofType) (<-chan *Proof, error) // Get fetches fraud proofs from the disk by its type. Get(context.Context, fraud.ProofType) ([]Proof, error) } @@ -27,12 +27,12 @@ type Module interface { // TODO(@distractedm1nd): These structs need to be autogenerated. type API struct { Internal struct { - Subscribe func(context.Context, fraud.ProofType) (<-chan Proof, error) `perm:"public"` - Get func(context.Context, fraud.ProofType) ([]Proof, error) `perm:"public"` + Subscribe func(context.Context, fraud.ProofType) (<-chan *Proof, error) `perm:"public"` + Get func(context.Context, fraud.ProofType) ([]Proof, error) `perm:"public"` } } -func (api *API) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-chan Proof, error) { +func (api *API) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-chan *Proof, error) { return api.Internal.Subscribe(ctx, proofType) } @@ -49,12 +49,12 @@ type module struct { fraud.Service[*header.ExtendedHeader] } -func (s *module) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-chan Proof, error) { +func (s *module) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-chan *Proof, error) { subscription, err := s.Service.Subscribe(proofType) if err != nil { return nil, err } - proofs := make(chan Proof) + proofs := make(chan *Proof) go func() { defer close(proofs) defer subscription.Cancel() @@ -69,7 +69,7 @@ func (s *module) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-ch select { case <-ctx.Done(): return - case proofs <- Proof{Proof: proof}: + case proofs <- &Proof{Proof: proof}: } } }() diff --git a/nodebuilder/fraud/mocks/api.go b/nodebuilder/fraud/mocks/api.go index 399f8746e1..10111b81a8 100644 --- a/nodebuilder/fraud/mocks/api.go +++ b/nodebuilder/fraud/mocks/api.go @@ -53,10 +53,10 @@ func (mr *MockModuleMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { } // Subscribe mocks base method. -func (m *MockModule) Subscribe(arg0 context.Context, arg1 fraud0.ProofType) (<-chan fraud.Proof, error) { +func (m *MockModule) Subscribe(arg0 context.Context, arg1 fraud0.ProofType) (<-chan *fraud.Proof, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Subscribe", arg0, arg1) - ret0, _ := ret[0].(<-chan fraud.Proof) + ret0, _ := ret[0].(<-chan *fraud.Proof) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/nodebuilder/tests/fraud_test.go b/nodebuilder/tests/fraud_test.go index cf8b3bfce4..d6abf89eac 100644 --- a/nodebuilder/tests/fraud_test.go +++ b/nodebuilder/tests/fraud_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/celestiaorg/go-fraud" "github.com/ipfs/go-datastore" ds_sync "github.com/ipfs/go-datastore/sync" "github.com/libp2p/go-libp2p/core/host" @@ -13,6 +14,7 @@ import ( "github.com/tendermint/tendermint/types" "go.uber.org/fx" + "github.com/celestiaorg/celestia-node/header" headerfraud "github.com/celestiaorg/celestia-node/header/headertest/fraud" "github.com/celestiaorg/celestia-node/nodebuilder" "github.com/celestiaorg/celestia-node/nodebuilder/core" @@ -98,12 +100,21 @@ func TestFraudProofHandling(t *testing.T) { select { case p := <-subscr: require.Equal(t, 10, int(p.Height())) + t.Log("Caught the proof....") subCancel() case <-ctx.Done(): subCancel() t.Fatal("full node did not receive a fraud proof in time") } + getCtx, getCancel := context.WithTimeout(ctx, time.Second) + proofs, err := fullClient.Fraud.Get(getCtx, byzantine.BadEncoding) + getCancel() + + require.NoError(t, err) + require.NotNil(t, proofs) + require.Len(t, proofs, 1) + require.True(t, proofs[0].Type() == byzantine.BadEncoding) // This is an obscure way to check if the Syncer was stopped. // If we cannot get a height header within a timeframe it means the syncer was stopped // FIXME: Eventually, this should be a check on service registry managing and keeping @@ -149,11 +160,10 @@ func TestFraudProofHandling(t *testing.T) { // 9. fN := sw.NewNodeWithStore(node.Full, store) - require.Error(t, fN.Start(ctx)) - fNClient := getAdminClient(ctx, fN, t) - proofs, err := fNClient.Fraud.Get(ctx, byzantine.BadEncoding) - require.NoError(t, err) - require.NotNil(t, proofs) + err = fN.Start(ctx) + require.Error(t, err) + var fpExist *fraud.ErrFraudExists[*header.ExtendedHeader] + require.ErrorAs(t, err, &fpExist) sw.StopNode(ctx, bridge) sw.StopNode(ctx, full)