diff --git a/relay/prover.go b/relay/prover.go index 7f1de0a..51ef11c 100644 --- a/relay/prover.go +++ b/relay/prover.go @@ -261,17 +261,23 @@ func aggregateMessages( messages = nil signatures = nil for i, b := range batches { - m := elc.MsgAggregateMessages{ - Signer: b.Signer, - Messages: b.Messages, - Signatures: b.Signatures, - } - resp, err := messageAggregator(context.TODO(), &m) - if err != nil { - return nil, fmt.Errorf("failed to aggregate messages: i=%v msg=%v %w", i, m, err) + logger.Info("aggregateMessages", "batch_index", i, "num_messages", len(b.Messages)) + if len(b.Messages) == 1 { + messages = append(messages, b.Messages[0]) + signatures = append(signatures, b.Signatures[0]) + } else { + m := elc.MsgAggregateMessages{ + Signer: b.Signer, + Messages: b.Messages, + Signatures: b.Signatures, + } + resp, err := messageAggregator(context.TODO(), &m) + if err != nil { + return nil, fmt.Errorf("failed to aggregate messages: batch_index=%v msg=%v %w", i, m, err) + } + messages = append(messages, resp.Message) + signatures = append(signatures, resp.Signature) } - messages = append(messages, resp.Message) - signatures = append(signatures, resp.Signature) } } } diff --git a/relay/prover_test.go b/relay/prover_test.go index 7d54099..eb4ffac 100644 --- a/relay/prover_test.go +++ b/relay/prover_test.go @@ -12,14 +12,14 @@ import ( ) func TestSplitIntoMultiBatch(t *testing.T) { - var M = func(n uint64) []byte { - return []byte(fmt.Sprintf("message-%d", n)) + var M = func(n uint8) []byte { + return []byte(fmt.Sprintf("message-%03d", n)) } - var S = func(n uint64) []byte { - return []byte(fmt.Sprintf("signature-%d", n)) + var S = func(n uint8) []byte { + return []byte(fmt.Sprintf("signature-%03d", n)) } - var Signer = func(n uint64) []byte { - return []byte(fmt.Sprintf("signer-%d", n)) + var Signer = func(n uint8) []byte { + return []byte(fmt.Sprintf("signer-%03d", n)) } var cases = []struct { Messages [][]byte @@ -100,14 +100,14 @@ func TestSplitIntoMultiBatch(t *testing.T) { } func TestAggregateMessages(t *testing.T) { - var M = func(n uint64) []byte { - return []byte(fmt.Sprintf("message-%d", n)) + var M = func(n uint8) []byte { + return []byte(fmt.Sprintf("message-%03d", n)) } - var S = func(n uint64) []byte { - return []byte(fmt.Sprintf("signature-%d", n)) + var S = func(n uint8) []byte { + return []byte(fmt.Sprintf("signature-%03d", n)) } - var Signer = func(n uint64) []byte { - return []byte(fmt.Sprintf("signer-%d", n)) + var Signer = func(n uint8) []byte { + return []byte(fmt.Sprintf("signer-%03d", n)) } err := log.InitLogger("DEBUG", "text", "stdout") @@ -158,22 +158,46 @@ func TestAggregateMessages(t *testing.T) { BatchSize: 2, Error: false, }, + { + Messages: [][]byte{M(0), M(1), M(2), M(3)}, + Signatures: [][]byte{S(0), S(1), S(2), S(3)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, + { + Messages: [][]byte{M(0), M(1), M(2), M(3), M(4)}, + Signatures: [][]byte{S(0), S(1), S(2), S(3), S(4)}, + Signer: Signer(0), + BatchSize: 2, + Error: false, + }, } for i, c := range cases { t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { require := require.New(t) - _, err := aggregateMessages(logger, c.BatchSize, mockMessageAggregator, c.Messages, c.Signatures, c.Signer) + res, err := aggregateMessages(logger, c.BatchSize, mockMessageAggregator, c.Messages, c.Signatures, c.Signer) if c.Error { require.Error(err) return } else { require.NoError(err) } + require.Equal(res.ProxyMessage, concatBytes(c.Messages)) + require.Equal(res.Signatures[0], concatBytes(c.Signatures)) }) } } +func concatBytes(bzs [][]byte) []byte { + var res []byte + for _, b := range bzs { + res = append(res, b...) + } + return res +} + func mockMessageAggregator(_ context.Context, in *elc.MsgAggregateMessages, _ ...grpc.CallOption) (*elc.MsgAggregateMessagesResponse, error) { var res elc.MsgAggregateMessagesResponse if len(in.Messages) != len(in.Signatures) {