Skip to content

Commit

Permalink
Connect codec interface tests and refactor codec to interface like EV…
Browse files Browse the repository at this point in the history
…M one
  • Loading branch information
ilija42 committed Dec 12, 2024
1 parent a9ea8fd commit 28a8b69
Show file tree
Hide file tree
Showing 24 changed files with 1,901 additions and 337 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ require (
go.uber.org/zap v1.27.0
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
golang.org/x/sync v0.8.0
golang.org/x/text v0.18.0
)

require (
Expand Down Expand Up @@ -124,6 +123,7 @@ require (
golang.org/x/net v0.29.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/term v0.24.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
Expand Down
90 changes: 90 additions & 0 deletions pkg/solana/codec/codec_entry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package codec

import (
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec"
commonencodings "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
)

type Entry interface {
Encode(value any, into []byte) ([]byte, error)
Decode(encoded []byte) (any, []byte, error)
GetCodecType() commonencodings.TypeCodec
GetType() reflect.Type

Modifier() codec.Modifier
}

// TODO this can also be an event entry, but anchor-go defines events differently, maybe just have a separate struct and method that satisfy entry interface for events.
func NewEntry(idlAccount IdlTypeDef, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) {
refs := &codecRefs{
builder: builder,
codecs: make(map[string]commonencodings.TypeCodec),
typeDefs: idlTypes,
dependencies: make(map[string][]string),
}

if mod == nil {
mod = codec.MultiModifier{}
}

_, accCodec, err := createCodecType(idlAccount, refs, false)
if err != nil {
return nil, err
}

entry := &codecEntry{name: idlAccount.Name, includeDiscriminator: includeDiscriminator, codecType: accCodec, typ: accCodec.GetType(), mod: mod}
if entry.includeDiscriminator {
entry.Discriminator = commonencodings.NamedTypeCodec{Name: "Discriminator" + idlAccount.Name, Codec: NewDiscriminator(idlAccount.Name)}
}

return entry, nil
}

type codecEntry struct {
name string
includeDiscriminator bool
Discriminator commonencodings.NamedTypeCodec
typ reflect.Type
codecType commonencodings.TypeCodec
mod codec.Modifier
}

func (entry *codecEntry) GetType() reflect.Type {
return entry.typ
}

func (entry *codecEntry) GetCodecType() commonencodings.TypeCodec {
return entry.codecType
}

func (entry *codecEntry) Encode(value any, into []byte) ([]byte, error) {
encodedVal, err := entry.codecType.Encode(value, into)
if err != nil {
return nil, err
}

if entry.includeDiscriminator {
var byt []byte
disc := NewDiscriminator(entry.name)
encodedDisc, err := disc.Encode(&disc.hashPrefix, byt)
if err != nil {
return nil, err
}
return append(encodedDisc, encodedVal...), nil
}

return encodedVal, nil
}

func (entry *codecEntry) Decode(encoded []byte) (any, []byte, error) {
if entry.includeDiscriminator {
encoded = encoded[discriminatorLength:]
}
return entry.codecType.Decode(encoded)
}

func (entry *codecEntry) Modifier() codec.Modifier {
return entry.mod
}
135 changes: 135 additions & 0 deletions pkg/solana/codec/codec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package codec_test

import (
"bytes"
_ "embed"
"slices"
"testing"

bin "github.com/gagliardetto/binary"
"github.com/gagliardetto/solana-go"
ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types"
"github.com/stretchr/testify/require"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
clcommontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
. "github.com/smartcontractkit/chainlink-common/pkg/types/interfacetests" //nolint common practice to import test mods with .
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils/test_item_type"
)

const anyExtraValue = 3

func TestCodec(t *testing.T) {
tester := &codecInterfaceTester{}
RunCodecInterfaceTests(t, tester)
//RunCodecInterfaceTests(t, looptestutils.WrapCodecTesterForLoop(tester))
}

type codecInterfaceTester struct {
TestSelectionSupport
}

func (it *codecInterfaceTester) Setup(_ *testing.T) {}

func (it *codecInterfaceTester) GetAccountBytes(i int) []byte {
pk, _ := solana.NewRandomPrivateKey()
return pk.PublicKey().Bytes()
}

func (it *codecInterfaceTester) GetAccountString(i int) string {
return solana.PublicKeyFromBytes(it.GetAccountBytes(i)).String()
}

func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte {
if request.TestOn == testutils.TestItemType {
return encodeFieldsOnItem(t, request)
}

return encodeFieldsOnSliceOrArray(t, request)
}

func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report {
buf := new(bytes.Buffer)
if err := testutils.EncodeRequestToTestStruct(request).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
return buf.Bytes()
}

func encodeFieldsOnSliceOrArray(_ *testing.T, request *EncodeRequest) []byte {
args := make([]any, 1)
switch request.TestOn {
case testutils.TestItemArray1Type:
args[0] = [1]test_item_type.TestStruct{testutils.ToInternalType(request.TestStructs[0])}
case testutils.TestItemArray2Type:
args[0] = [2]test_item_type.TestStruct{testutils.ToInternalType(request.TestStructs[0]), testutils.ToInternalType(request.TestStructs[1])}
default:
tmp := make([]test_item_type.TestStruct, len(request.TestStructs))
for i, ts := range request.TestStructs {
tmp[i] = testutils.ToInternalType(ts)
}
args[0] = tmp
}

return []byte{}
}

func (it *codecInterfaceTester) GetCodec(t *testing.T) clcommontypes.Codec {
codecConfig := codec.Config{Configs: map[string]codec.ChainConfig{}}
testStruct := CreateTestStruct[*testing.T](0, it)
for k, v := range testutils.CodecDefs {
entry := codecConfig.Configs[k]
entry.IDL = v

if slices.Contains([]string{testutils.TestItemSliceType, testutils.TestItemArray1Type, testutils.TestItemArray2Type}, k) {
entry.ModifierConfigs = commoncodec.ModifiersConfig{
&commoncodec.RenameModifierConfig{Fields: map[string]string{"Items.NestedDynamicStruct.Inner.IntVal": "I"}},
&commoncodec.RenameModifierConfig{Fields: map[string]string{"Items.NestedStaticStruct.Inner.IntVal": "I"}},
&commoncodec.AddressBytesToStringModifierConfig{
Fields: []string{"Items.AccountStruct.AccountStr"},
Modifier: codec.SolanaAddressModifier{},
},
&commoncodec.WrapperModifierConfig{Fields: map[string]string{"Items.NestedStaticStruct.Inner.IntVal": "I"}},
}
} else if k != testutils.SizeItemType && k != testutils.NilType {
entry.ModifierConfigs = commoncodec.ModifiersConfig{
&commoncodec.RenameModifierConfig{Fields: map[string]string{"NestedDynamicStruct.Inner.IntVal": "I"}},
&commoncodec.RenameModifierConfig{Fields: map[string]string{"NestedStaticStruct.Inner.IntVal": "I"}},
}
}

if slices.Contains([]string{testutils.TestItemType, testutils.TestItemWithConfigExtra}, k) {
addressByteModifier := &commoncodec.AddressBytesToStringModifierConfig{
Fields: []string{"AccountStruct.AccountStr"},
Modifier: codec.SolanaAddressModifier{},
}
entry.ModifierConfigs = append(entry.ModifierConfigs, addressByteModifier)
}

if k == testutils.TestItemWithConfigExtra {
hardCode := &commoncodec.HardCodeModifierConfig{
OnChainValues: map[string]any{
"BigField": testStruct.BigField.String(),
"AccountStruct.Account": solana.PublicKeyFromBytes(testStruct.AccountStruct.Account),
},
OffChainValues: map[string]any{"ExtraField": anyExtraValue},
}
entry.ModifierConfigs = append(entry.ModifierConfigs, hardCode)
}
codecConfig.Configs[k] = entry
}

c, err := codec.NewCodec(codecConfig)
require.NoError(t, err)

return c
}

func (it *codecInterfaceTester) IncludeArrayEncodingSizeEnforcement() bool {
return true
}
func (it *codecInterfaceTester) Name() string {
return "Solana"
}
42 changes: 42 additions & 0 deletions pkg/solana/codec/decoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package codec

import (
"context"
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec"
commontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
)

type decoder struct {
Definitions map[string]Entry
}

var _ commontypes.Decoder = &decoder{}

func (m *decoder) Decode(_ context.Context, raw []byte, into any, itemType string) (err error) {
item, ok := m.Definitions[itemType]
if !ok {
return fmt.Errorf("%w: cannot find type %s", commontypes.ErrInvalidType, itemType)
}

val, remaining, err := item.Decode(raw)
if err != nil {
return err
}

if len(remaining) != 0 {
return fmt.Errorf("%w: remaining bytes after decoding %s", commontypes.ErrInvalidEncoding, itemType)
}

return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil)
}

func (m *decoder) GetMaxDecodingSize(_ context.Context, n int, itemType string) (int, error) {
entry, ok := m.Definitions[itemType]
if !ok {
return 0, fmt.Errorf("%w: nil entry", commontypes.ErrInvalidType)
}
return entry.GetCodecType().Size(n)
}
16 changes: 8 additions & 8 deletions pkg/solana/codec/discriminator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (

const discriminatorLength = 8

func NewDiscriminator(name string) encodings.TypeCodec {
func NewDiscriminator(name string) *Discriminator {
sum := sha256.Sum256([]byte("account:" + name))
return &discriminator{hashPrefix: sum[:discriminatorLength]}
return &Discriminator{hashPrefix: sum[:discriminatorLength]}
}

type discriminator struct {
type Discriminator struct {
hashPrefix []byte
}

func (d discriminator) Encode(value any, into []byte) ([]byte, error) {
func (d Discriminator) Encode(value any, into []byte) ([]byte, error) {
if value == nil {
return append(into, d.hashPrefix...), nil
}
Expand All @@ -44,7 +44,7 @@ func (d discriminator) Encode(value any, into []byte) ([]byte, error) {
return append(into, *raw...), nil
}

func (d discriminator) Decode(encoded []byte) (any, []byte, error) {
func (d Discriminator) Decode(encoded []byte) (any, []byte, error) {
raw, remaining, err := encodings.SafeDecode(encoded, discriminatorLength, func(raw []byte) []byte { return raw })
if err != nil {
return nil, nil, err
Expand All @@ -57,15 +57,15 @@ func (d discriminator) Decode(encoded []byte) (any, []byte, error) {
return &raw, remaining, nil
}

func (d discriminator) GetType() reflect.Type {
func (d Discriminator) GetType() reflect.Type {
// Pointer type so that nil can inject values and so that the NamedCodec won't wrap with no-nil pointer.
return reflect.TypeOf(&[]byte{})
}

func (d discriminator) Size(_ int) (int, error) {
func (d Discriminator) Size(_ int) (int, error) {
return discriminatorLength, nil
}

func (d discriminator) FixedSize() (int, error) {
func (d Discriminator) FixedSize() (int, error) {
return discriminatorLength, nil
}
51 changes: 51 additions & 0 deletions pkg/solana/codec/encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package codec

import (
"context"
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec"
commontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
)

type encoder struct {
Definitions map[string]Entry
}

var _ commontypes.Encoder = &encoder{}

func (e *encoder) Encode(_ context.Context, item any, itemType string) (res []byte, err error) {
info, ok := e.Definitions[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find definition for %s", commontypes.ErrInvalidType, itemType)
}

if item != nil {
rItem := reflect.ValueOf(item)
myType := info.GetCodecType().GetType()
if rItem.Kind() == reflect.Pointer && myType.Kind() != reflect.Pointer {
rItem = reflect.Indirect(rItem)
}

if !rItem.IsZero() && rItem.Type() != myType {
tmp := reflect.New(myType)
if err := codec.Convert(rItem, tmp, nil); err != nil {
return nil, err
}
item = tmp.Elem().Interface()
} else {
item = rItem.Interface()
}
}

return info.Encode(item, nil)
}

func (e *encoder) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) {
entry, ok := e.Definitions[itemType]
if !ok {
return 0, fmt.Errorf("%w: nil entry", commontypes.ErrInvalidType)
}
return entry.GetCodecType().Size(n)
}
Loading

0 comments on commit 28a8b69

Please sign in to comment.