diff --git a/activation/interface.go b/activation/interface.go index f788ca4888..8de82ac4e9 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -105,5 +105,6 @@ type postService interface { } type PostClient interface { + Info(ctx context.Context) (*types.PostInfo, error) Proof(ctx context.Context, challenge []byte) (*types.Post, *types.PostMetadata, error) } diff --git a/activation/mocks.go b/activation/mocks.go index efafbecd2c..c7075cc775 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1888,6 +1888,45 @@ func (m *MockPostClient) EXPECT() *MockPostClientMockRecorder { return m.recorder } +// Info mocks base method. +func (m *MockPostClient) Info(ctx context.Context) (*types.PostInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Info", ctx) + ret0, _ := ret[0].(*types.PostInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Info indicates an expected call of Info. +func (mr *MockPostClientMockRecorder) Info(ctx any) *PostClientInfoCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockPostClient)(nil).Info), ctx) + return &PostClientInfoCall{Call: call} +} + +// PostClientInfoCall wrap *gomock.Call +type PostClientInfoCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *PostClientInfoCall) Return(arg0 *types.PostInfo, arg1 error) *PostClientInfoCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *PostClientInfoCall) Do(f func(context.Context) (*types.PostInfo, error)) *PostClientInfoCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *PostClientInfoCall) DoAndReturn(f func(context.Context) (*types.PostInfo, error)) *PostClientInfoCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Proof mocks base method. func (m *MockPostClient) Proof(ctx context.Context, challenge []byte) (*types.Post, *types.PostMetadata, error) { m.ctrl.T.Helper() diff --git a/api/grpcserver/post_client.go b/api/grpcserver/post_client.go index 525ac04b4b..1591566e31 100644 --- a/api/grpcserver/post_client.go +++ b/api/grpcserver/post_client.go @@ -29,82 +29,132 @@ func newPostClient(con chan<- postCommand) *postClient { } } +func (pc *postClient) Info(ctx context.Context) (*types.PostInfo, error) { + req := &pb.NodeRequest{ + Kind: &pb.NodeRequest_Metadata{ + Metadata: &pb.MetadataRequest{}, + }, + } + resp, err := pc.send(ctx, req) + if err != nil { + return nil, err + } + metadataResp := resp.GetMetadata() + if metadataResp == nil { + return nil, fmt.Errorf("unexpected response of type: %T", resp.GetKind()) + } + meta := metadataResp.GetMeta() + if meta == nil { + return nil, fmt.Errorf("post metadata is nil") + } + var nonce *types.VRFPostIndex + if meta.Nonce != nil { + nonce = new(types.VRFPostIndex) + *nonce = types.VRFPostIndex(meta.GetNonce()) + } + return &types.PostInfo{ + NodeID: types.BytesToNodeID(meta.GetNodeId()), + CommitmentATX: types.BytesToATXID(meta.GetCommitmentAtxId()), + Nonce: nonce, + + NumUnits: meta.GetNumUnits(), + LabelsPerUnit: meta.GetLabelsPerUnit(), + }, nil +} + func (pc *postClient) Proof(ctx context.Context, challenge []byte) (*types.Post, *types.PostMetadata, error) { - resp := make(chan *pb.ServiceResponse, 1) - cmd := postCommand{ - req: &pb.NodeRequest{ - Kind: &pb.NodeRequest_GenProof{ - GenProof: &pb.GenProofRequest{ - Challenge: challenge, - }, + req := &pb.NodeRequest{ + Kind: &pb.NodeRequest_GenProof{ + GenProof: &pb.GenProofRequest{ + Challenge: challenge, }, }, - resp: resp, } + var proofResp *pb.GenProofResponse for { - // send command - select { - case <-pc.closed: - return nil, nil, fmt.Errorf("post client closed") - case <-ctx.Done(): - return nil, nil, ctx.Err() - case pc.con <- cmd: + resp, err := pc.send(ctx, req) + if err != nil { + return nil, nil, err + } + + proofResp = resp.GetGenProof() + if proofResp == nil { + return nil, nil, fmt.Errorf("unexpected response of type: %T", resp.GetKind()) + } + + switch proofResp.GetStatus() { + case pb.GenProofStatus_GEN_PROOF_STATUS_ERROR: + return nil, nil, fmt.Errorf("error generating proof: %s", proofResp) + case pb.GenProofStatus_GEN_PROOF_STATUS_UNSPECIFIED: + return nil, nil, fmt.Errorf("unspecified error generating proof: %s", proofResp) + case pb.GenProofStatus_GEN_PROOF_STATUS_OK: + default: + return nil, nil, fmt.Errorf("unknown status: %s", proofResp) + } + + if proofResp.GetProof() != nil { + break } - // receive response select { - case <-pc.closed: - return nil, nil, fmt.Errorf("post client closed") case <-ctx.Done(): return nil, nil, ctx.Err() - case resp := <-resp: - proofResp := resp.GetGenProof() - if proofResp == nil { - return nil, nil, fmt.Errorf("unexpected response of type: %T", resp.GetKind()) - } - switch proofResp.GetStatus() { - case pb.GenProofStatus_GEN_PROOF_STATUS_OK: - if proofResp.GetProof() == nil { - select { - case <-ctx.Done(): - return nil, nil, ctx.Err() - case <-time.After(2 * time.Second): - // TODO(mafa): make polling interval configurable - continue - } - } - - proof := proofResp.GetProof() - proofMeta := proofResp.GetMetadata() - if proofMeta == nil { - return nil, nil, fmt.Errorf("proof metadata is nil") - } - - if !bytes.Equal(proofMeta.GetChallenge(), challenge) { - return nil, nil, fmt.Errorf("unexpected challenge: %x", proofMeta.GetChallenge()) - } - - postMeta := proofMeta.GetMeta() - if postMeta == nil { - return nil, nil, fmt.Errorf("post metadata is nil") - } - - return &types.Post{ - Nonce: proof.GetNonce(), - Indices: proof.GetIndices(), - Pow: proof.GetPow(), - }, &types.PostMetadata{ - Challenge: proofMeta.GetChallenge(), - LabelsPerUnit: postMeta.GetLabelsPerUnit(), - }, nil - case pb.GenProofStatus_GEN_PROOF_STATUS_ERROR: - return nil, nil, fmt.Errorf("error generating proof: %s", proofResp) - case pb.GenProofStatus_GEN_PROOF_STATUS_UNSPECIFIED: - return nil, nil, fmt.Errorf("unspecified error generating proof: %s", proofResp) - } + case <-time.After(2 * time.Second): + // TODO(mafa): make polling interval configurable + continue } } + + proof := proofResp.GetProof() + metadata := proofResp.GetMetadata() + if metadata == nil { + return nil, nil, fmt.Errorf("proof metadata is nil") + } + if !bytes.Equal(metadata.GetChallenge(), challenge) { + return nil, nil, fmt.Errorf("unexpected challenge: %x", metadata.GetChallenge()) + } + proofMeta := metadata.GetMeta() + if proofMeta == nil { + return nil, nil, fmt.Errorf("post metadata is nil") + } + post := &types.Post{ + Nonce: proof.GetNonce(), + Indices: proof.GetIndices(), + Pow: proof.GetPow(), + } + postMeta := &types.PostMetadata{ + Challenge: metadata.GetChallenge(), + LabelsPerUnit: proofMeta.GetLabelsPerUnit(), + } + return post, postMeta, nil +} + +func (pc *postClient) send(ctx context.Context, req *pb.NodeRequest) (*pb.ServiceResponse, error) { + resp := make(chan *pb.ServiceResponse, 1) + cmd := postCommand{ + req: req, + resp: resp, + } + + // send command + select { + case <-pc.closed: + return nil, fmt.Errorf("post client closed") + case <-ctx.Done(): + return nil, ctx.Err() + case pc.con <- cmd: + } + + // receive response + select { + case <-pc.closed: + return nil, fmt.Errorf("post client closed") + case <-ctx.Done(): + return nil, ctx.Err() + case resp := <-resp: + return resp, nil + } } func (pc *postClient) Close() error { diff --git a/api/grpcserver/post_service_test.go b/api/grpcserver/post_service_test.go index f687ba760c..e07c33a3a7 100644 --- a/api/grpcserver/post_service_test.go +++ b/api/grpcserver/post_service_test.go @@ -23,7 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) -func initPost(tb testing.TB, log *zap.Logger, opts activation.PostSetupOpts) types.NodeID { +func initPost(tb testing.TB, log *zap.Logger, opts activation.PostSetupOpts) (types.NodeID, types.ATXID) { tb.Helper() cfg := activation.DefaultPostConfig() @@ -32,7 +32,7 @@ func initPost(tb testing.TB, log *zap.Logger, opts activation.PostSetupOpts) typ require.NoError(tb, err) id := sig.NodeID() - goldenATXID := types.ATXID{2, 3, 4} + goldenATXID := types.RandomATXID() cdb := datastore.NewCachedDB(sql.InMemory(), logtest.New(tb)) mgr, err := activation.NewPostSetupManager(id, cfg, log.Named("manager"), cdb, goldenATXID) @@ -69,7 +69,7 @@ func initPost(tb testing.TB, log *zap.Logger, opts activation.PostSetupOpts) typ require.NoError(tb, mgr.StartSession(context.Background())) require.NoError(tb, eg.Wait()) require.Equal(tb, activation.PostSetupStateComplete, mgr.Status().State) - return id + return id, goldenATXID } func launchPostSupervisor(tb testing.TB, log *zap.Logger, cfg Config, postOpts activation.PostSetupOpts) func() { @@ -122,7 +122,7 @@ func Test_GenerateProof(t *testing.T) { opts.DataDir = t.TempDir() opts.ProviderID.SetInt64(int64(initialization.CPUProviderID())) opts.Scrypt.N = 2 // Speedup initialization in tests. - id := initPost(t, log.Named("post"), opts) + id, _ := initPost(t, log.Named("post"), opts) postCleanup := launchPostSupervisor(t, log.Named("supervisor"), cfg, opts) t.Cleanup(postCleanup) @@ -165,7 +165,7 @@ func Test_GenerateProof_TLS(t *testing.T) { opts.DataDir = t.TempDir() opts.ProviderID.SetInt64(int64(initialization.CPUProviderID())) opts.Scrypt.N = 2 // Speedup initialization in tests. - id := initPost(t, log.Named("post"), opts) + id, _ := initPost(t, log.Named("post"), opts) postCleanup := launchPostSupervisorTLS(t, log.Named("supervisor"), cfg, opts) t.Cleanup(postCleanup) @@ -198,7 +198,7 @@ func Test_GenerateProof_TLS(t *testing.T) { require.Nil(t, meta) } -func Test_Cancel_GenerateProof(t *testing.T) { +func Test_GenerateProof_Cancel(t *testing.T) { log := zaptest.NewLogger(t) svc := NewPostService(log) cfg, cleanup := launchServer(t, svc) @@ -208,7 +208,7 @@ func Test_Cancel_GenerateProof(t *testing.T) { opts.DataDir = t.TempDir() opts.ProviderID.SetInt64(int64(initialization.CPUProviderID())) opts.Scrypt.N = 2 // Speedup initialization in tests. - id := initPost(t, log.Named("post"), opts) + id, _ := initPost(t, log.Named("post"), opts) t.Cleanup(launchPostSupervisor(t, log.Named("supervisor"), cfg, opts)) var client activation.PostClient @@ -233,6 +233,46 @@ func Test_Cancel_GenerateProof(t *testing.T) { require.Nil(t, meta) } +func Test_Metadata(t *testing.T) { + log := zaptest.NewLogger(t) + svc := NewPostService(log) + cfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) + + opts := activation.DefaultPostSetupOpts() + opts.DataDir = t.TempDir() + opts.ProviderID.SetInt64(int64(initialization.CPUProviderID())) + opts.Scrypt.N = 2 // Speedup initialization in tests. + id, commitmentAtxId := initPost(t, log.Named("post"), opts) + postCleanup := launchPostSupervisor(t, log.Named("supervisor"), cfg, opts) + t.Cleanup(postCleanup) + + var client activation.PostClient + require.Eventually(t, func() bool { + var err error + client, err = svc.Client(id) + return err == nil + }, 10*time.Second, 100*time.Millisecond, "timed out waiting for connection") + + meta, err := client.Info(context.Background()) + require.NoError(t, err) + require.NotNil(t, meta) + require.Equal(t, id, meta.NodeID) + require.Equal(t, commitmentAtxId, meta.CommitmentATX) + require.NotNil(t, meta.Nonce) + require.Equal(t, opts.NumUnits, meta.NumUnits) + + // drop connection + postCleanup() + require.Eventually(t, func() bool { + meta, err = client.Info(context.Background()) + return err != nil + }, 5*time.Second, 100*time.Millisecond) + + require.ErrorContains(t, err, "post client closed") + require.Nil(t, meta) +} + func Test_GenerateProof_MultipleServices(t *testing.T) { log := zaptest.NewLogger(t) svc := NewPostService(log) @@ -245,7 +285,7 @@ func Test_GenerateProof_MultipleServices(t *testing.T) { opts.Scrypt.N = 2 // Speedup initialization in tests. // all but one should not be able to register to the node (i.e. open a stream to it). - id := initPost(t, log.Named("post1"), opts) + id, _ := initPost(t, log.Named("post1"), opts) t.Cleanup(launchPostSupervisor(t, log.Named("supervisor1"), cfg, opts)) opts.DataDir = t.TempDir() diff --git a/common/types/post.go b/common/types/post.go new file mode 100644 index 0000000000..b88de59bde --- /dev/null +++ b/common/types/post.go @@ -0,0 +1,11 @@ +package types + +// PostInfo contains information about the PoST as returned by the service. +type PostInfo struct { + NodeID NodeID + CommitmentATX ATXID + Nonce *VRFPostIndex + + NumUnits uint32 + LabelsPerUnit uint64 +}