From 512c9deb15b139246c2d2b4b28f7e8729f559452 Mon Sep 17 00:00:00 2001 From: Jun Kimura Date: Tue, 16 Jul 2024 11:05:17 +0900 Subject: [PATCH 1/3] add test for `splitIntoMultiBatch` Signed-off-by: Jun Kimura --- relay/prover_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 relay/prover_test.go diff --git a/relay/prover_test.go b/relay/prover_test.go new file mode 100644 index 0000000..31e159e --- /dev/null +++ b/relay/prover_test.go @@ -0,0 +1,96 @@ +package relay + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSplitIntoMultiBatch(t *testing.T) { + var M = func(n uint64) []byte { + return []byte(fmt.Sprintf("message-%d", n)) + } + var S = func(n uint64) []byte { + return []byte(fmt.Sprintf("signature-%d", n)) + } + var Signer = func(n uint64) []byte { + return []byte(fmt.Sprintf("signer-%d", n)) + } + var cases = []struct { + Messages [][]byte + Signatures [][]byte + BatchSizes []int + Signer []byte + MessageBatchSize uint64 + Error bool + }{ + // Messages.len = 1 is invalid + { + Messages: [][]byte{M(0)}, + Signatures: [][]byte{S(0)}, + BatchSizes: []int{1}, + Signer: Signer(0), + MessageBatchSize: 1, + Error: true, + }, + { + Messages: [][]byte{M(0), M(1)}, + Signatures: [][]byte{S(0), S(1)}, + BatchSizes: []int{2}, + Signer: Signer(0), + MessageBatchSize: 2, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1), M(2)}, + Signatures: [][]byte{S(0), S(1), S(2)}, + BatchSizes: []int{3}, + Signer: Signer(0), + MessageBatchSize: 3, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1), M(2)}, + Signatures: [][]byte{S(0), S(1), S(2)}, + BatchSizes: []int{2, 1}, + Signer: Signer(0), + MessageBatchSize: 2, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1), M(2), M(3)}, + Signatures: [][]byte{S(0), S(1), S(2), S(3)}, + BatchSizes: []int{3, 1}, + Signer: Signer(0), + MessageBatchSize: 3, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1), M(2), M(3), M(4)}, + Signatures: [][]byte{S(0), S(1), S(2), S(3), S(4)}, + BatchSizes: []int{3, 2}, + Signer: Signer(0), + MessageBatchSize: 3, + Error: false, + }, + } + for i, c := range cases { + t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { + require := require.New(t) + batches, err := splitIntoMultiBatch(c.Messages, c.Signatures, c.Signer, c.MessageBatchSize) + if c.Error { + require.Error(err) + return + } else { + require.NoError(err) + } + require.Len(batches, len(c.BatchSizes)) + for i, size := range c.BatchSizes { + require.Equal(batches[i].Signer, c.Signer) + require.Len(batches[i].Messages, size) + require.Len(batches[i].Signatures, size) + } + }) + } +} From 5c814d68c676721dbe88d8269adadd5550269e31 Mon Sep 17 00:00:00 2001 From: Jun Kimura Date: Tue, 16 Jul 2024 11:34:35 +0900 Subject: [PATCH 2/3] add tests for `aggregateMessages` Signed-off-by: Jun Kimura --- relay/prover.go | 21 +++++++--- relay/prover_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 6 deletions(-) diff --git a/relay/prover.go b/relay/prover.go index fb2ede0..7f1de0a 100644 --- a/relay/prover.go +++ b/relay/prover.go @@ -193,7 +193,7 @@ func (pr *Prover) SetupHeadersForUpdate(dstChain core.FinalityAwareChain, latest // NOTE: assume that the messages length and the signatures length are the same if pr.config.MessageAggregation { pr.getLogger().Info("aggregate messages", "num_messages", len(messages)) - update, err := pr.aggregateMessages(messages, signatures, pr.activeEnclaveKey.EnclaveKeyAddress) + update, err := aggregateMessages(pr.getLogger(), pr.config.GetMessageAggregationBatchSize(), pr.lcpServiceClient.AggregateMessages, messages, signatures, pr.activeEnclaveKey.EnclaveKeyAddress) if err != nil { return nil, err } @@ -210,14 +210,23 @@ func (pr *Prover) SetupHeadersForUpdate(dstChain core.FinalityAwareChain, latest return updates, nil } -func (pr *Prover) aggregateMessages(messages [][]byte, signatures [][]byte, signer []byte) (*lcptypes.UpdateClientMessage, error) { +type MessageAggregator func(ctx context.Context, in *elc.MsgAggregateMessages, opts ...grpc.CallOption) (*elc.MsgAggregateMessagesResponse, error) + +func aggregateMessages( + logger *log.RelayLogger, + batchSize uint64, + messageAggregator MessageAggregator, + messages [][]byte, + signatures [][]byte, + signer []byte, +) (*lcptypes.UpdateClientMessage, error) { if len(messages) == 0 { return nil, fmt.Errorf("aggregateMessages: messages must not be empty") } else if len(messages) != len(signatures) { return nil, fmt.Errorf("aggregateMessages: messages and signatures must have the same length: messages=%v signatures=%v", len(messages), len(signatures)) } for { - batches, err := splitIntoMultiBatch(messages, signatures, signer, pr.config.GetMessageAggregationBatchSize()) + batches, err := splitIntoMultiBatch(messages, signatures, signer, batchSize) if err != nil { return nil, err } @@ -235,7 +244,7 @@ func (pr *Prover) aggregateMessages(messages [][]byte, signatures [][]byte, sign Messages: batches[0].Messages, Signatures: batches[0].Signatures, } - resp, err := pr.lcpServiceClient.AggregateMessages(context.TODO(), &m) + resp, err := messageAggregator(context.TODO(), &m) if err != nil { return nil, fmt.Errorf("failed to aggregate messages: msg=%v %w", m, err) } @@ -247,7 +256,7 @@ func (pr *Prover) aggregateMessages(messages [][]byte, signatures [][]byte, sign } else if n == 0 { return nil, fmt.Errorf("unexpected error: batches must not be empty") } else { - pr.getLogger().Info("aggregateMessages", "num_batches", n) + logger.Info("aggregateMessages", "num_batches", n) } messages = nil signatures = nil @@ -257,7 +266,7 @@ func (pr *Prover) aggregateMessages(messages [][]byte, signatures [][]byte, sign Messages: b.Messages, Signatures: b.Signatures, } - resp, err := pr.lcpServiceClient.AggregateMessages(context.TODO(), &m) + resp, err := messageAggregator(context.TODO(), &m) if err != nil { return nil, fmt.Errorf("failed to aggregate messages: i=%v msg=%v %w", i, m, err) } diff --git a/relay/prover_test.go b/relay/prover_test.go index 31e159e..7d54099 100644 --- a/relay/prover_test.go +++ b/relay/prover_test.go @@ -1,10 +1,14 @@ package relay import ( + "context" "fmt" "testing" + "github.com/datachainlab/lcp-go/relay/elc" + "github.com/hyperledger-labs/yui-relayer/log" "github.com/stretchr/testify/require" + "google.golang.org/grpc" ) func TestSplitIntoMultiBatch(t *testing.T) { @@ -94,3 +98,95 @@ func TestSplitIntoMultiBatch(t *testing.T) { }) } } + +func TestAggregateMessages(t *testing.T) { + var M = func(n uint64) []byte { + return []byte(fmt.Sprintf("message-%d", n)) + } + var S = func(n uint64) []byte { + return []byte(fmt.Sprintf("signature-%d", n)) + } + var Signer = func(n uint64) []byte { + return []byte(fmt.Sprintf("signer-%d", n)) + } + + err := log.InitLogger("DEBUG", "text", "stdout") + require.NoError(t, err) + logger := log.GetLogger() + + var cases = []struct { + Messages [][]byte + Signatures [][]byte + Signer []byte + BatchSize uint64 + Error bool + }{ + // Messages.len = 0 is invalid + { + Messages: [][]byte{}, + Signatures: [][]byte{}, + Signer: Signer(0), + BatchSize: 2, + Error: true, + }, + { + Messages: [][]byte{M(0)}, + Signatures: [][]byte{S(0)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1)}, + Signatures: [][]byte{S(0), S(1)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, + // BatchSize = 1 is invalid + { + Messages: [][]byte{M(0), M(1)}, + Signatures: [][]byte{S(0), S(1)}, + Signer: Signer(0), + BatchSize: 1, + Error: true, + }, + { + Messages: [][]byte{M(0), M(1), M(2)}, + Signatures: [][]byte{S(0), S(1), S(2)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { + require := require.New(t) + _, err := aggregateMessages(logger, c.BatchSize, mockMessageAggregator, c.Messages, c.Signatures, c.Signer) + if c.Error { + require.Error(err) + return + } else { + require.NoError(err) + } + }) + } +} + +func mockMessageAggregator(_ context.Context, in *elc.MsgAggregateMessages, _ ...grpc.CallOption) (*elc.MsgAggregateMessagesResponse, error) { + var res elc.MsgAggregateMessagesResponse + if len(in.Messages) != len(in.Signatures) { + return nil, fmt.Errorf("messages and signatures must have the same length") + } + if len(in.Messages) == 0 { + return nil, fmt.Errorf("messages.len = 0 is invalid") + } else if len(in.Messages) == 1 { + return nil, fmt.Errorf("messages.len = 1 is invalid") + } + for i := 0; i < len(in.Messages); i++ { + res.Message = append(res.Message, in.Messages[i]...) + res.Signature = append(res.Signature, in.Signatures[i]...) + } + return &res, nil +} From 1bf2f75fe4f4024a53f54a23eb943c2c8e0fafb5 Mon Sep 17 00:00:00 2001 From: Jun Kimura Date: Tue, 16 Jul 2024 11:59:08 +0900 Subject: [PATCH 3/3] fix a bug when only a single message contained in batch Signed-off-by: Jun Kimura --- relay/prover.go | 26 ++++++++++++++--------- relay/prover_test.go | 50 ++++++++++++++++++++++++++++++++------------ 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/relay/prover.go b/relay/prover.go index 7f1de0a..51ef11c 100644 --- a/relay/prover.go +++ b/relay/prover.go @@ -261,17 +261,23 @@ func aggregateMessages( messages = nil signatures = nil for i, b := range batches { - m := elc.MsgAggregateMessages{ - Signer: b.Signer, - Messages: b.Messages, - Signatures: b.Signatures, - } - resp, err := messageAggregator(context.TODO(), &m) - if err != nil { - return nil, fmt.Errorf("failed to aggregate messages: i=%v msg=%v %w", i, m, err) + logger.Info("aggregateMessages", "batch_index", i, "num_messages", len(b.Messages)) + if len(b.Messages) == 1 { + messages = append(messages, b.Messages[0]) + signatures = append(signatures, b.Signatures[0]) + } else { + m := elc.MsgAggregateMessages{ + Signer: b.Signer, + Messages: b.Messages, + Signatures: b.Signatures, + } + resp, err := messageAggregator(context.TODO(), &m) + if err != nil { + return nil, fmt.Errorf("failed to aggregate messages: batch_index=%v msg=%v %w", i, m, err) + } + messages = append(messages, resp.Message) + signatures = append(signatures, resp.Signature) } - messages = append(messages, resp.Message) - signatures = append(signatures, resp.Signature) } } } diff --git a/relay/prover_test.go b/relay/prover_test.go index 7d54099..eb4ffac 100644 --- a/relay/prover_test.go +++ b/relay/prover_test.go @@ -12,14 +12,14 @@ import ( ) func TestSplitIntoMultiBatch(t *testing.T) { - var M = func(n uint64) []byte { - return []byte(fmt.Sprintf("message-%d", n)) + var M = func(n uint8) []byte { + return []byte(fmt.Sprintf("message-%03d", n)) } - var S = func(n uint64) []byte { - return []byte(fmt.Sprintf("signature-%d", n)) + var S = func(n uint8) []byte { + return []byte(fmt.Sprintf("signature-%03d", n)) } - var Signer = func(n uint64) []byte { - return []byte(fmt.Sprintf("signer-%d", n)) + var Signer = func(n uint8) []byte { + return []byte(fmt.Sprintf("signer-%03d", n)) } var cases = []struct { Messages [][]byte @@ -100,14 +100,14 @@ func TestSplitIntoMultiBatch(t *testing.T) { } func TestAggregateMessages(t *testing.T) { - var M = func(n uint64) []byte { - return []byte(fmt.Sprintf("message-%d", n)) + var M = func(n uint8) []byte { + return []byte(fmt.Sprintf("message-%03d", n)) } - var S = func(n uint64) []byte { - return []byte(fmt.Sprintf("signature-%d", n)) + var S = func(n uint8) []byte { + return []byte(fmt.Sprintf("signature-%03d", n)) } - var Signer = func(n uint64) []byte { - return []byte(fmt.Sprintf("signer-%d", n)) + var Signer = func(n uint8) []byte { + return []byte(fmt.Sprintf("signer-%03d", n)) } err := log.InitLogger("DEBUG", "text", "stdout") @@ -158,22 +158,46 @@ func TestAggregateMessages(t *testing.T) { BatchSize: 2, Error: false, }, + { + Messages: [][]byte{M(0), M(1), M(2), M(3)}, + Signatures: [][]byte{S(0), S(1), S(2), S(3)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1), M(2), M(3), M(4)}, + Signatures: [][]byte{S(0), S(1), S(2), S(3), S(4)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, } for i, c := range cases { t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { require := require.New(t) - _, err := aggregateMessages(logger, c.BatchSize, mockMessageAggregator, c.Messages, c.Signatures, c.Signer) + res, err := aggregateMessages(logger, c.BatchSize, mockMessageAggregator, c.Messages, c.Signatures, c.Signer) if c.Error { require.Error(err) return } else { require.NoError(err) } + require.Equal(res.ProxyMessage, concatBytes(c.Messages)) + require.Equal(res.Signatures[0], concatBytes(c.Signatures)) }) } } +func concatBytes(bzs [][]byte) []byte { + var res []byte + for _, b := range bzs { + res = append(res, b...) + } + return res +} + func mockMessageAggregator(_ context.Context, in *elc.MsgAggregateMessages, _ ...grpc.CallOption) (*elc.MsgAggregateMessagesResponse, error) { var res elc.MsgAggregateMessagesResponse if len(in.Messages) != len(in.Signatures) {