diff --git a/client/proposalmsgs_test.go b/client/proposalmsgs_test.go index 8c1f42d6..6acb2ef5 100644 --- a/client/proposalmsgs_test.go +++ b/client/proposalmsgs_test.go @@ -24,7 +24,6 @@ import ( "perun.network/go-perun/channel/test" "perun.network/go-perun/client" clienttest "perun.network/go-perun/client/test" - _ "perun.network/go-perun/wire/perunio/serializer" // wire serialzer init peruniotest "perun.network/go-perun/wire/perunio/test" pkgtest "polycry.pt/poly-go/test" ) diff --git a/wire/controlmsgs_test.go b/wire/controlmsgs_test.go index 2d8b1ff2..415f1072 100644 --- a/wire/controlmsgs_test.go +++ b/wire/controlmsgs_test.go @@ -17,7 +17,6 @@ package wire_test import ( "testing" - _ "perun.network/go-perun/wire/perunio/serializer" // wire serialzer init peruniotest "perun.network/go-perun/wire/perunio/test" wiretest "perun.network/go-perun/wire/test" ) diff --git a/wire/encode.go b/wire/encode.go index ec724041..eec724d2 100644 --- a/wire/encode.go +++ b/wire/encode.go @@ -24,17 +24,6 @@ import ( "perun.network/go-perun/wire/perunio" ) -var envelopeSerializer EnvelopeSerializer - -// SetEnvelopeSerializer sets the global envelope serializer instance. Must not -// be called directly but through importing the needed backend. -func SetEnvelopeSerializer(e EnvelopeSerializer) { - if envelopeSerializer != nil { - panic("envelope serializer already set") - } - envelopeSerializer = e -} - type ( // Msg is the top-level abstraction for all messages sent between Perun // nodes. @@ -59,18 +48,6 @@ type ( } ) -// EncodeEnvelope serializes the envelope into the writer, using the global -// envelope serialzer instance. -func EncodeEnvelope(w io.Writer, env *Envelope) error { - return envelopeSerializer.Encode(w, env) -} - -// DecodeEnvelope deserializes an envelope from the reader, using the global -// envelope serialzer instance. -func DecodeEnvelope(r io.Reader) (*Envelope, error) { - return envelopeSerializer.Decode(r) -} - // EncodeMsg encodes a message into an io.Writer. It also encodes the message // type whereas the Msg.Encode implementation is assumed not to write the type. func EncodeMsg(msg Msg, w io.Writer) error { diff --git a/wire/net/bus.go b/wire/net/bus.go index c07c11c2..6d89b1ac 100644 --- a/wire/net/bus.go +++ b/wire/net/bus.go @@ -45,14 +45,14 @@ const ( // NewBus creates a new network bus. The dialer and listener are used to // establish new connections internally, while id is this node's identity. -func NewBus(id wire.Account, d Dialer) *Bus { +func NewBus(id wire.Account, d Dialer, s wire.EnvelopeSerializer) *Bus { b := &Bus{ mainRecv: wire.NewReceiver(), recvs: make(map[wallet.AddrKey]wire.Consumer), } onNewEndpoint := func(wire.Address) wire.Consumer { return b.mainRecv } - b.reg = NewEndpointRegistry(id, onNewEndpoint, d) + b.reg = NewEndpointRegistry(id, onNewEndpoint, d, s) go b.dispatchMsgs() return b diff --git a/wire/net/bus_test.go b/wire/net/bus_test.go index 9519d8be..7a5d46c2 100644 --- a/wire/net/bus_test.go +++ b/wire/net/bus_test.go @@ -22,6 +22,7 @@ import ( "perun.network/go-perun/wire" "perun.network/go-perun/wire/net" nettest "perun.network/go-perun/wire/net/test" + perunio "perun.network/go-perun/wire/perunio/serializer" wiretest "perun.network/go-perun/wire/test" ) @@ -32,7 +33,7 @@ func TestBus(t *testing.T) { var hub nettest.ConnHub wiretest.GenericBusTest(t, func(acc wire.Account) wire.Bus { - bus := net.NewBus(acc, hub.NewNetDialer()) + bus := net.NewBus(acc, hub.NewNetDialer(), perunio.Serializer()) hub.OnClose(func() { bus.Close() }) go bus.Listen(hub.NewNetListener(acc.Address())) return bus diff --git a/wire/net/dialer.go b/wire/net/dialer.go index 4373b989..967fc48c 100644 --- a/wire/net/dialer.go +++ b/wire/net/dialer.go @@ -23,13 +23,15 @@ import ( // Dialer is an interface that allows creating a connection to a peer via its // Perun address. The established connections are not authenticated yet. type Dialer interface { - // Dial creates a connection to a peer. - // The passed context is used to abort the dialing process. The returned - // connection might not belong to the requested address. + // Dial creates a connection to a peer. The passed context is used to abort + // the dialing process. The returned connection might not belong to the + // requested address. + // + // `ser` is used for message serialization. // // Dial needs to be reentrant, and concurrent calls to Close() must abort // any ongoing Dial() calls. - Dial(ctx context.Context, addr wire.Address) (Conn, error) + Dial(ctx context.Context, addr wire.Address, ser wire.EnvelopeSerializer) (Conn, error) // Close aborts any ongoing calls to Dial(). // // Close() needs to be reentrant, and repeated calls to Close() need to diff --git a/wire/net/endpoint_internal_test.go b/wire/net/endpoint_internal_test.go index 57237765..f5d3b81a 100644 --- a/wire/net/endpoint_internal_test.go +++ b/wire/net/endpoint_internal_test.go @@ -27,6 +27,7 @@ import ( _ "perun.network/go-perun/backend/sim" // backend init wallettest "perun.network/go-perun/wallet/test" "perun.network/go-perun/wire" + perunio "perun.network/go-perun/wire/perunio/serializer" wiretest "perun.network/go-perun/wire/test" "polycry.pt/poly-go/test" ) @@ -53,7 +54,7 @@ func makeSetup(rng *rand.Rand) *setup { } // Dial simulates creating a connection to a. -func (s *setup) Dial(ctx context.Context, addr wire.Address) (Conn, error) { +func (s *setup) Dial(ctx context.Context, addr wire.Address, _ wire.EnvelopeSerializer) (Conn, error) { s.mutex.RLock() defer s.mutex.RUnlock() @@ -99,7 +100,7 @@ func makeClient(conn Conn, rng *rand.Rand, dialer Dialer) *client { receiver := wire.NewReceiver() registry := NewEndpointRegistry(wallettest.NewRandomAccount(rng), func(wire.Address) wire.Consumer { return receiver - }, dialer) + }, dialer, perunio.Serializer()) return &client{ endpoint: registry.addEndpoint(wallettest.NewRandomAddress(rng), conn, true), diff --git a/wire/net/endpoint_registry.go b/wire/net/endpoint_registry.go index 39937d8d..41facf71 100644 --- a/wire/net/endpoint_registry.go +++ b/wire/net/endpoint_registry.go @@ -67,6 +67,7 @@ type EndpointRegistry struct { id wire.Account // The identity of the node. dialer Dialer // Used for dialing peers. onNewEndpoint func(wire.Address) wire.Consumer // Selects Consumer for new Endpoints' receive loop. + ser wire.EnvelopeSerializer endpoints map[wallet.AddrKey]*fullEndpoint // The list of all of all established Endpoints. dialing map[wallet.AddrKey]*dialingEndpoint @@ -81,11 +82,17 @@ const exchangeAddrsTimeout = 10 * time.Second // NewEndpointRegistry creates a new registry. // The provided callback is used to set up new peer's subscriptions and it is // called before the peer starts receiving messages. -func NewEndpointRegistry(id wire.Account, onNewEndpoint func(wire.Address) wire.Consumer, dialer Dialer) *EndpointRegistry { +func NewEndpointRegistry( + id wire.Account, + onNewEndpoint func(wire.Address) wire.Consumer, + dialer Dialer, + ser wire.EnvelopeSerializer, +) *EndpointRegistry { return &EndpointRegistry{ id: id, onNewEndpoint: onNewEndpoint, dialer: dialer, + ser: ser, endpoints: make(map[wallet.AddrKey]*fullEndpoint), dialing: make(map[wallet.AddrKey]*dialingEndpoint), @@ -138,7 +145,7 @@ func (r *EndpointRegistry) Listen(listener Listener) { // Start listener and accept all incoming peer connections, writing them to // the registry. for { - conn, err := listener.Accept() + conn, err := listener.Accept(r.ser) if err != nil { r.Log().Debugf("EndpointRegistry.Listen: Accept() loop: %v", err) return @@ -240,7 +247,7 @@ func (r *EndpointRegistry) authenticatedDial( close(de.created) }() - conn, err := r.dialer.Dial(ctx, addr) + conn, err := r.dialer.Dial(ctx, addr, r.ser) if err != nil { return nil, errors.WithMessage(err, "failed to dial") } diff --git a/wire/net/endpoint_registry_external_test.go b/wire/net/endpoint_registry_external_test.go index 42350c0a..057bc8c9 100644 --- a/wire/net/endpoint_registry_external_test.go +++ b/wire/net/endpoint_registry_external_test.go @@ -26,6 +26,7 @@ import ( "perun.network/go-perun/wire" "perun.network/go-perun/wire/net" nettest "perun.network/go-perun/wire/net/test" + perunio "perun.network/go-perun/wire/perunio/serializer" ctxtest "polycry.pt/poly-go/context/test" "polycry.pt/poly-go/sync" "polycry.pt/poly-go/test" @@ -43,8 +44,8 @@ func TestEndpointRegistry_Get_Pair(t *testing.T) { var hub nettest.ConnHub dialerID := wallettest.NewRandomAccount(rng) listenerID := wallettest.NewRandomAccount(rng) - dialerReg := net.NewEndpointRegistry(dialerID, nilConsumer, hub.NewNetDialer()) - listenerReg := net.NewEndpointRegistry(listenerID, nilConsumer, nil) + dialerReg := net.NewEndpointRegistry(dialerID, nilConsumer, hub.NewNetDialer(), perunio.Serializer()) + listenerReg := net.NewEndpointRegistry(listenerID, nilConsumer, nil, perunio.Serializer()) listener := hub.NewNetListener(listenerID.Address()) done := make(chan struct{}) @@ -88,8 +89,8 @@ func TestEndpointRegistry_Get_Multiple(t *testing.T) { t.Logf("subscribing %s\n", addr) return nil } - dialerReg := net.NewEndpointRegistry(dialerID, logPeer, dialer) - listenerReg := net.NewEndpointRegistry(listenerID, logPeer, nil) + dialerReg := net.NewEndpointRegistry(dialerID, logPeer, dialer, perunio.Serializer()) + listenerReg := net.NewEndpointRegistry(listenerID, logPeer, nil, perunio.Serializer()) listener := hub.NewNetListener(listenerID.Address()) done := make(chan struct{}) diff --git a/wire/net/endpoint_registry_internal_test.go b/wire/net/endpoint_registry_internal_test.go index d412209c..4a66a032 100644 --- a/wire/net/endpoint_registry_internal_test.go +++ b/wire/net/endpoint_registry_internal_test.go @@ -28,6 +28,7 @@ import ( "perun.network/go-perun/wallet" wallettest "perun.network/go-perun/wallet/test" "perun.network/go-perun/wire" + perunio "perun.network/go-perun/wire/perunio/serializer" wiretest "perun.network/go-perun/wire/test" ctxtest "polycry.pt/poly-go/context/test" "polycry.pt/poly-go/sync/atomic" @@ -52,7 +53,7 @@ func (d *mockDialer) Close() error { return nil } -func (d *mockDialer) Dial(ctx context.Context, addr wire.Address) (Conn, error) { +func (d *mockDialer) Dial(ctx context.Context, addr wire.Address, _ wire.EnvelopeSerializer) (Conn, error) { d.mutex.Lock() defer d.mutex.Unlock() @@ -85,8 +86,8 @@ type mockListener struct { dialer mockDialer } -func (l *mockListener) Accept() (Conn, error) { - return l.dialer.Dial(context.Background(), nil) +func (l *mockListener) Accept(ser wire.EnvelopeSerializer) (Conn, error) { + return l.dialer.Dial(context.Background(), nil, ser) } func (l *mockListener) Close() error { @@ -123,7 +124,7 @@ func TestRegistry_Get(t *testing.T) { t.Parallel() dialer := newMockDialer() - r := NewEndpointRegistry(id, nilConsumer, dialer) + r := NewEndpointRegistry(id, nilConsumer, dialer, perunio.Serializer()) existing := newEndpoint(peerAddr, newMockConn()) r.endpoints[wallet.Key(peerAddr)] = newFullEndpoint(existing) @@ -138,7 +139,7 @@ func TestRegistry_Get(t *testing.T) { t.Parallel() dialer := newMockDialer() - r := NewEndpointRegistry(id, nilConsumer, dialer) + r := NewEndpointRegistry(id, nilConsumer, dialer, perunio.Serializer()) dialer.Close() ctxtest.AssertTerminates(t, timeout, func() { @@ -154,7 +155,7 @@ func TestRegistry_Get(t *testing.T) { t.Parallel() dialer := newMockDialer() - r := NewEndpointRegistry(id, nilConsumer, dialer) + r := NewEndpointRegistry(id, nilConsumer, dialer, perunio.Serializer()) ct := test.NewConcurrent(t) a, b := newPipeConnPair() @@ -182,7 +183,7 @@ func TestRegistry_authenticatedDial(t *testing.T) { rng := test.Prng(t) id := wallettest.NewRandomAccount(rng) d := &mockDialer{dial: make(chan Conn)} - r := NewEndpointRegistry(id, nilConsumer, d) + r := NewEndpointRegistry(id, nilConsumer, d, perunio.Serializer()) remoteID := wallettest.NewRandomAccount(rng) remoteAddr := remoteID.Address() @@ -267,7 +268,7 @@ func TestRegistry_setupConn(t *testing.T) { t.Run("ExchangeAddrs fail", func(t *testing.T) { d := &mockDialer{dial: make(chan Conn)} - r := NewEndpointRegistry(id, nilConsumer, d) + r := NewEndpointRegistry(id, nilConsumer, d, perunio.Serializer()) a, b := newPipeConnPair() go func() { err := b.Send(&wire.Envelope{ @@ -286,7 +287,7 @@ func TestRegistry_setupConn(t *testing.T) { t.Run("ExchangeAddrs success (peer already exists)", func(t *testing.T) { d := &mockDialer{dial: make(chan Conn)} - r := NewEndpointRegistry(id, nilConsumer, d) + r := NewEndpointRegistry(id, nilConsumer, d, perunio.Serializer()) a, b := newPipeConnPair() go func() { err := ExchangeAddrsActive(context.Background(), remoteID, id.Address(), b) @@ -303,7 +304,7 @@ func TestRegistry_setupConn(t *testing.T) { t.Run("ExchangeAddrs success (peer did not exist)", func(t *testing.T) { d := &mockDialer{dial: make(chan Conn)} - r := NewEndpointRegistry(id, nilConsumer, d) + r := NewEndpointRegistry(id, nilConsumer, d, perunio.Serializer()) a, b := newPipeConnPair() go func() { err := ExchangeAddrsActive(context.Background(), remoteID, id.Address(), b) @@ -331,7 +332,7 @@ func TestRegistry_Listen(t *testing.T) { d := newMockDialer() l := newMockListener() - r := NewEndpointRegistry(id, nilConsumer, d) + r := NewEndpointRegistry(id, nilConsumer, d, perunio.Serializer()) go func() { // Listen() will only terminate if the listener is closed. @@ -365,7 +366,12 @@ func TestRegistry_addEndpoint_Subscribe(t *testing.T) { t.Parallel() rng := test.Prng(t) called := false - r := NewEndpointRegistry(wallettest.NewRandomAccount(rng), func(wire.Address) wire.Consumer { called = true; return nil }, nil) + r := NewEndpointRegistry( + wallettest.NewRandomAccount(rng), + func(wire.Address) wire.Consumer { called = true; return nil }, + nil, + perunio.Serializer(), + ) assert.False(t, called, "onNewEndpoint must not have been called yet") r.addEndpoint(wallettest.NewRandomAddress(rng), newMockConn(), false) @@ -378,7 +384,12 @@ func TestRegistry_Close(t *testing.T) { rng := test.Prng(t) t.Run("double close error", func(t *testing.T) { - r := NewEndpointRegistry(wallettest.NewRandomAccount(rng), nilConsumer, nil) + r := NewEndpointRegistry( + wallettest.NewRandomAccount(rng), + nilConsumer, + nil, + perunio.Serializer(), + ) r.Close() assert.Error(t, r.Close()) }) @@ -386,7 +397,12 @@ func TestRegistry_Close(t *testing.T) { t.Run("dialer close error", func(t *testing.T) { d := &mockDialer{dial: make(chan Conn)} d.Close() - r := NewEndpointRegistry(wallettest.NewRandomAccount(rng), nilConsumer, d) + r := NewEndpointRegistry( + wallettest.NewRandomAccount(rng), + nilConsumer, + d, + perunio.Serializer(), + ) assert.Error(t, r.Close()) }) @@ -395,5 +411,6 @@ func TestRegistry_Close(t *testing.T) { // newPipeConnPair creates endpoints that are connected via pipes. func newPipeConnPair() (a Conn, b Conn) { c0, c1 := net.Pipe() - return NewIoConn(c0), NewIoConn(c1) + ser := perunio.Serializer() + return NewIoConn(c0, ser), NewIoConn(c1, ser) } diff --git a/wire/net/exchange_addr_internal_test.go b/wire/net/exchange_addr_internal_test.go index a8a8fd8c..78995c2e 100644 --- a/wire/net/exchange_addr_internal_test.go +++ b/wire/net/exchange_addr_internal_test.go @@ -23,7 +23,6 @@ import ( wallettest "perun.network/go-perun/wallet/test" "perun.network/go-perun/wire" - _ "perun.network/go-perun/wire/protobuf" // wire serialzer init wiretest "perun.network/go-perun/wire/test" ctxtest "polycry.pt/poly-go/context/test" "polycry.pt/poly-go/test" diff --git a/wire/net/ioconn.go b/wire/net/ioconn.go index 83fc87ea..f55a60f0 100644 --- a/wire/net/ioconn.go +++ b/wire/net/ioconn.go @@ -27,19 +27,21 @@ var _ Conn = (*ioConn)(nil) // ioConn is a connection that communicates its messages over an io stream. type ioConn struct { - closed atomic.Bool - conn io.ReadWriteCloser + closed atomic.Bool + conn io.ReadWriteCloser + serializer wire.EnvelopeSerializer } // NewIoConn creates a peer message connection from an io stream. -func NewIoConn(conn io.ReadWriteCloser) Conn { +func NewIoConn(conn io.ReadWriteCloser, serializer wire.EnvelopeSerializer) Conn { return &ioConn{ - conn: conn, + conn: conn, + serializer: serializer, } } func (c *ioConn) Send(e *wire.Envelope) error { - if err := wire.EncodeEnvelope(c.conn, e); err != nil { + if err := c.serializer.Encode(c.conn, e); err != nil { c.conn.Close() return err } @@ -47,7 +49,7 @@ func (c *ioConn) Send(e *wire.Envelope) error { } func (c *ioConn) Recv() (*wire.Envelope, error) { - e, err := wire.DecodeEnvelope(c.conn) + e, err := c.serializer.Decode(c.conn) if err != nil { c.conn.Close() return nil, err diff --git a/wire/net/listener.go b/wire/net/listener.go index 4df47d5d..5050e222 100644 --- a/wire/net/listener.go +++ b/wire/net/listener.go @@ -14,17 +14,21 @@ package net +import "perun.network/go-perun/wire" + // Listener is an interface that allows listening for peer incoming connections. // The accepted connections still need to be authenticated. type Listener interface { // Accept accepts an incoming connection, which still has to perform // authentication to exchange addresses. // + // `ser` specifies the message serialization format. + // // This function does not have to be reentrant, but concurrent calls to // Close() must abort ongoing Accept() calls. Accept() must only return // errors after Close() was called or an unrecoverable fatal error occurred // in the Listener and it is closed. - Accept() (Conn, error) + Accept(ser wire.EnvelopeSerializer) (Conn, error) // Close closes the listener and aborts any ongoing Accept() call. Close() error } diff --git a/wire/net/simple/dialer.go b/wire/net/simple/dialer.go index c7f18b39..3f5e617b 100644 --- a/wire/net/simple/dialer.go +++ b/wire/net/simple/dialer.go @@ -44,6 +44,7 @@ var _ wirenet.Dialer = (*Dialer)(nil) // attempts. Leaving the timeout as 0 will result in no timeouts. Standard OS // timeouts may still apply even when no timeout is selected. The network string // controls the type of connection that the dialer can dial. +// `serializer` defines the message encoding. func NewNetDialer(network string, defaultTimeout time.Duration) *Dialer { return &Dialer{ peers: make(map[wallet.AddrKey]string), @@ -71,7 +72,7 @@ func (d *Dialer) host(key wallet.AddrKey) (string, bool) { } // Dial implements Dialer.Dial(). -func (d *Dialer) Dial(ctx context.Context, addr wire.Address) (wirenet.Conn, error) { +func (d *Dialer) Dial(ctx context.Context, addr wire.Address, ser wire.EnvelopeSerializer) (wirenet.Conn, error) { done := make(chan struct{}) defer close(done) @@ -97,7 +98,7 @@ func (d *Dialer) Dial(ctx context.Context, addr wire.Address) (wirenet.Conn, err return nil, errors.Wrap(err, "failed to dial peer") } - return wirenet.NewIoConn(conn), nil + return wirenet.NewIoConn(conn, ser), nil } // Register registers a network address for a peer address. diff --git a/wire/net/simple/dialer_internal_test.go b/wire/net/simple/dialer_internal_test.go index 3cf409c6..bc34bf44 100644 --- a/wire/net/simple/dialer_internal_test.go +++ b/wire/net/simple/dialer_internal_test.go @@ -25,7 +25,7 @@ import ( simwallet "perun.network/go-perun/backend/sim/wallet" "perun.network/go-perun/wallet" "perun.network/go-perun/wire" - _ "perun.network/go-perun/wire/protobuf" // wire serialzer init + perunio "perun.network/go-perun/wire/perunio/serializer" ctxtest "polycry.pt/poly-go/context/test" "polycry.pt/poly-go/test" ) @@ -66,6 +66,7 @@ func TestDialer_Dial(t *testing.T) { require.NoError(t, err) defer l.Close() + ser := perunio.Serializer() d := NewTCPDialer(timeout) d.Register(laddr, lhost) daddr := simwallet.NewRandomAddress(rng) @@ -79,7 +80,7 @@ func TestDialer_Dial(t *testing.T) { } ct := test.NewConcurrent(t) go ct.Stage("accept", func(rt test.ConcT) { - conn, err := l.Accept() + conn, err := l.Accept(ser) assert.NoError(t, err) require.NotNil(rt, conn) @@ -90,7 +91,7 @@ func TestDialer_Dial(t *testing.T) { ct.Stage("dial", func(rt test.ConcT) { ctxtest.AssertTerminates(t, timeout, func() { - conn, err := d.Dial(context.Background(), laddr) + conn, err := d.Dial(context.Background(), laddr, ser) assert.NoError(t, err) require.NotNil(rt, conn) @@ -105,7 +106,7 @@ func TestDialer_Dial(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() ctxtest.AssertTerminates(t, timeout, func() { - conn, err := d.Dial(ctx, laddr) + conn, err := d.Dial(ctx, laddr, ser) assert.Nil(t, conn) assert.Error(t, err) }) @@ -116,7 +117,7 @@ func TestDialer_Dial(t *testing.T) { d.Register(noHostAddr, "no such host") ctxtest.AssertTerminates(t, timeout, func() { - conn, err := d.Dial(context.Background(), noHostAddr) + conn, err := d.Dial(context.Background(), noHostAddr, ser) assert.Nil(t, conn) assert.Error(t, err) }) @@ -125,7 +126,7 @@ func TestDialer_Dial(t *testing.T) { t.Run("unknown address", func(t *testing.T) { ctxtest.AssertTerminates(t, timeout, func() { unkownAddr := simwallet.NewRandomAddress(rng) - conn, err := d.Dial(context.Background(), unkownAddr) + conn, err := d.Dial(context.Background(), unkownAddr, ser) assert.Error(t, err) assert.Nil(t, conn) }) diff --git a/wire/net/simple/listener.go b/wire/net/simple/listener.go index 09d2c6f6..34aed557 100644 --- a/wire/net/simple/listener.go +++ b/wire/net/simple/listener.go @@ -18,6 +18,7 @@ import ( "net" "github.com/pkg/errors" + "perun.network/go-perun/wire" wirenet "perun.network/go-perun/wire/net" ) @@ -50,11 +51,11 @@ func NewUnixListener(address string) (*Listener, error) { } // Accept implements peer.Dialer.Accept(). -func (l *Listener) Accept() (wirenet.Conn, error) { +func (l *Listener) Accept(ser wire.EnvelopeSerializer) (wirenet.Conn, error) { conn, err := l.Listener.Accept() if err != nil { return nil, errors.Wrap(err, "accept failed") } - return wirenet.NewIoConn(conn), nil + return wirenet.NewIoConn(conn, ser), nil } diff --git a/wire/net/simple/listener_internal_test.go b/wire/net/simple/listener_internal_test.go index 90300af5..c37a6470 100644 --- a/wire/net/simple/listener_internal_test.go +++ b/wire/net/simple/listener_internal_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + perunio "perun.network/go-perun/wire/perunio/serializer" "polycry.pt/poly-go/context/test" ) @@ -73,6 +74,7 @@ func TestNewListener(t *testing.T) { func TestListener_Accept(t *testing.T) { // Happy case already tested in TestDialer_Dial. + ser := perunio.Serializer() timeout := 100 * time.Millisecond t.Run("timeout", func(t *testing.T) { l, err := NewTCPListener(addr) @@ -80,7 +82,7 @@ func TestListener_Accept(t *testing.T) { defer l.Close() test.AssertNotTerminates(t, timeout, func() { - l.Accept() //nolint:errcheck + l.Accept(ser) //nolint:errcheck }) }) @@ -90,7 +92,7 @@ func TestListener_Accept(t *testing.T) { l.Close() test.AssertTerminates(t, timeout, func() { - conn, err := l.Accept() + conn, err := l.Accept(ser) assert.Nil(t, conn) assert.Error(t, err) }) diff --git a/wire/net/test/connhub.go b/wire/net/test/connhub.go index 2df09c87..1c958b78 100644 --- a/wire/net/test/connhub.go +++ b/wire/net/test/connhub.go @@ -65,7 +65,7 @@ func (h *ConnHub) NewNetDialer() *Dialer { panic("ConnHub already closed") } - dialer := &Dialer{hub: h} + dialer := NewDialer(h) h.dialers.insert(dialer) dialer.OnClose(func() { h.dialers.erase(dialer) //nolint:errcheck diff --git a/wire/net/test/connhub_internal_test.go b/wire/net/test/connhub_internal_test.go index e83e0736..9e3e13bc 100644 --- a/wire/net/test/connhub_internal_test.go +++ b/wire/net/test/connhub_internal_test.go @@ -24,7 +24,7 @@ import ( _ "perun.network/go-perun/backend/sim" // backend init wallettest "perun.network/go-perun/wallet/test" "perun.network/go-perun/wire" - _ "perun.network/go-perun/wire/protobuf" // wire serialzer init + perunio "perun.network/go-perun/wire/perunio/serializer" wiretest "perun.network/go-perun/wire/test" ctxtest "polycry.pt/poly-go/context/test" "polycry.pt/poly-go/sync" @@ -33,6 +33,7 @@ import ( func TestConnHub_Create(t *testing.T) { rng := pkgtest.Prng(t) + ser := perunio.Serializer() t.Run("create and dial existing", func(t *testing.T) { assert := assert.New(t) @@ -45,7 +46,7 @@ func TestConnHub_Create(t *testing.T) { ct := pkgtest.NewConcurrent(t) go ctxtest.AssertTerminates(t, timeout, func() { ct.Stage("accept", func(rt pkgtest.ConcT) { - conn, err := l.Accept() + conn, err := l.Accept(ser) assert.NoError(err) require.NotNil(rt, conn) assert.NoError(conn.Send(wiretest.NewRandomEnvelope(rng, wire.NewPingMsg()))) @@ -54,7 +55,7 @@ func TestConnHub_Create(t *testing.T) { ctxtest.AssertTerminates(t, timeout, func() { ct.Stage("dial", func(rt pkgtest.ConcT) { - conn, err := d.Dial(context.Background(), addr) + conn, err := d.Dial(context.Background(), addr, ser) assert.NoError(err) require.NotNil(rt, conn) m, err := conn.Recv() @@ -85,7 +86,7 @@ func TestConnHub_Create(t *testing.T) { d := c.NewNetDialer() ctxtest.AssertTerminates(t, timeout, func() { - conn, err := d.Dial(context.Background(), wallettest.NewRandomAddress(rng)) + conn, err := d.Dial(context.Background(), wallettest.NewRandomAddress(rng), ser) assert.Nil(conn) assert.Error(err) }) diff --git a/wire/net/test/dialer.go b/wire/net/test/dialer.go index 08e27b20..b6bec073 100644 --- a/wire/net/test/dialer.go +++ b/wire/net/test/dialer.go @@ -36,8 +36,15 @@ type Dialer struct { var _ wirenet.Dialer = (*Dialer)(nil) +// NewDialer creates a new test dialer. +func NewDialer(hub *ConnHub) *Dialer { + return &Dialer{ + hub: hub, + } +} + // Dial tries to connect to a wire. -func (d *Dialer) Dial(ctx context.Context, address wire.Address) (wirenet.Conn, error) { +func (d *Dialer) Dial(ctx context.Context, address wire.Address, ser wire.EnvelopeSerializer) (wirenet.Conn, error) { if d.IsClosed() { return nil, errors.New("dialer closed") } @@ -54,13 +61,13 @@ func (d *Dialer) Dial(ctx context.Context, address wire.Address) (wirenet.Conn, } local, remote := net.Pipe() - if !l.Put(ctx, wirenet.NewIoConn(remote)) { + if !l.Put(ctx, wirenet.NewIoConn(remote, ser)) { local.Close() remote.Close() return nil, errors.New("Put() failed") } atomic.AddInt32(&d.dialed, 1) - return wirenet.NewIoConn(local), nil + return wirenet.NewIoConn(local, ser), nil } // Close closes a connection. diff --git a/wire/net/test/dialer_internal_test.go b/wire/net/test/dialer_internal_test.go index dcce7179..db8066a6 100644 --- a/wire/net/test/dialer_internal_test.go +++ b/wire/net/test/dialer_internal_test.go @@ -21,17 +21,19 @@ import ( "github.com/stretchr/testify/assert" "perun.network/go-perun/wallet/test" + perunio "perun.network/go-perun/wire/perunio/serializer" pkgtest "polycry.pt/poly-go/test" ) func TestDialer_Dial(t *testing.T) { rng := pkgtest.Prng(t) + ser := perunio.Serializer() // Closed dialer must always fail. t.Run("closed", func(t *testing.T) { var d Dialer d.Close() - conn, err := d.Dial(context.Background(), test.NewRandomAddress(rng)) + conn, err := d.Dial(context.Background(), test.NewRandomAddress(rng), ser) assert.Nil(t, conn) assert.Error(t, err) }) @@ -41,7 +43,7 @@ func TestDialer_Dial(t *testing.T) { var d Dialer ctx, cancel := context.WithCancel(context.Background()) cancel() - conn, err := d.Dial(ctx, test.NewRandomAddress(rng)) + conn, err := d.Dial(ctx, test.NewRandomAddress(rng), ser) assert.Nil(t, conn) assert.Error(t, err) }) diff --git a/wire/net/test/listener.go b/wire/net/test/listener.go index 87387e3d..0ebc201b 100644 --- a/wire/net/test/listener.go +++ b/wire/net/test/listener.go @@ -20,6 +20,7 @@ import ( "github.com/pkg/errors" + "perun.network/go-perun/wire" wirenet "perun.network/go-perun/wire/net" "polycry.pt/poly-go/sync" ) @@ -47,7 +48,7 @@ func NewNetListener() *Listener { // Accept returns the next connection that is enqueued via Put(). This function // blocks until either Put() is called or until the listener is closed. -func (l *Listener) Accept() (wirenet.Conn, error) { +func (l *Listener) Accept(wire.EnvelopeSerializer) (wirenet.Conn, error) { if l.IsClosed() { return nil, errors.New("listener closed") } diff --git a/wire/net/test/listener_internal_test.go b/wire/net/test/listener_internal_test.go index d2d990f4..87b253cd 100644 --- a/wire/net/test/listener_internal_test.go +++ b/wire/net/test/listener_internal_test.go @@ -23,6 +23,7 @@ import ( "perun.network/go-perun/wire" wirenet "perun.network/go-perun/wire/net" + perunio "perun.network/go-perun/wire/perunio/serializer" ctxtest "polycry.pt/poly-go/context/test" ) @@ -47,7 +48,7 @@ func TestListener_Accept_Put(t *testing.T) { defer close(done) ctxtest.AssertTerminates(t, timeout, func() { - conn, err := l.Accept() + conn, err := l.Accept(perunio.Serializer()) assert.NoError(t, err, "Accept must not fail") assert.Same(t, connection, conn, "Accept must receive connection from Put") @@ -67,12 +68,13 @@ func TestListener_Accept_Put(t *testing.T) { func TestListener_Accept_Close(t *testing.T) { t.Parallel() + ser := perunio.Serializer() t.Run("close before accept", func(t *testing.T) { l := NewNetListener() l.Close() ctxtest.AssertTerminates(t, timeout, func() { - conn, err := l.Accept() + conn, err := l.Accept(ser) assert.Error(t, err, "Accept must fail") assert.Nil(t, conn) assert.Zero(t, l.NumAccepted()) @@ -87,7 +89,7 @@ func TestListener_Accept_Close(t *testing.T) { }() ctxtest.AssertTerminates(t, 2*timeout, func() { - conn, err := l.Accept() + conn, err := l.Accept(ser) assert.Error(t, err, "Accept must fail") assert.Nil(t, conn) assert.Zero(t, l.NumAccepted()) @@ -117,7 +119,7 @@ func TestListener_Put(t *testing.T) { // Closed listener must abort Put() calls. assert.False(t, l.Put(context.Background(), connection)) // Accept() must always fail when closed. - conn, err := l.Accept() + conn, err := l.Accept(perunio.Serializer()) assert.Nil(t, conn) assert.Error(t, err) assert.Zero(t, l.NumAccepted()) diff --git a/wire/net/test/pipeconn.go b/wire/net/test/pipeconn.go index 2eaf74c1..bd95f36d 100644 --- a/wire/net/test/pipeconn.go +++ b/wire/net/test/pipeconn.go @@ -21,6 +21,7 @@ import ( "perun.network/go-perun/wire" wirenet "perun.network/go-perun/wire/net" + perunio "perun.network/go-perun/wire/perunio/serializer" "polycry.pt/poly-go/sync/atomic" ) @@ -63,5 +64,6 @@ func (c *Conn) IsClosed() bool { func NewTestConnPair() (a wirenet.Conn, b wirenet.Conn) { closed := new(atomic.Bool) c0, c1 := net.Pipe() - return &Conn{closed, wirenet.NewIoConn(c0)}, &Conn{closed, wirenet.NewIoConn(c1)} + ser := perunio.Serializer() + return &Conn{closed, wirenet.NewIoConn(c0, ser)}, &Conn{closed, wirenet.NewIoConn(c1, ser)} } diff --git a/wire/perunio/serializer/serializer.go b/wire/perunio/serializer/serializer.go index b9c2a32a..ff86ae6d 100644 --- a/wire/perunio/serializer/serializer.go +++ b/wire/perunio/serializer/serializer.go @@ -22,12 +22,13 @@ import ( "perun.network/go-perun/wire/perunio" ) -type serializer struct{} - -func init() { - wire.SetEnvelopeSerializer(serializer{}) +// Serializer returns a perunio serializer. +func Serializer() wire.EnvelopeSerializer { + return serializer{} } +type serializer struct{} + // Encode encodes the envelope into the wire using perunio encoding format. func (serializer) Encode(w io.Writer, env *wire.Envelope) error { if err := perunio.Encode(w, env.Sender, env.Recipient); err != nil { diff --git a/wire/perunio/test/msgtest.go b/wire/perunio/test/msgtest.go index 7df1664f..86a55cb5 100644 --- a/wire/perunio/test/msgtest.go +++ b/wire/perunio/test/msgtest.go @@ -20,6 +20,7 @@ import ( "testing" "perun.network/go-perun/wire" + perunio "perun.network/go-perun/wire/perunio/serializer" wiretest "perun.network/go-perun/wire/test" pkgtest "polycry.pt/poly-go/test" ) @@ -30,17 +31,21 @@ type serializableEnvelope struct { env *wire.Envelope } +var serializer = perunio.Serializer() + func (e *serializableEnvelope) Encode(writer io.Writer) error { - return wire.EncodeEnvelope(writer, e.env) + return serializer.Encode(writer, e.env) } func (e *serializableEnvelope) Decode(reader io.Reader) (err error) { - e.env, err = wire.DecodeEnvelope(reader) + e.env, err = serializer.Decode(reader) return err } func newSerializableEnvelope(rng *rand.Rand, msg wire.Msg) *serializableEnvelope { - return &serializableEnvelope{env: wiretest.NewRandomEnvelope(rng, msg)} + return &serializableEnvelope{ + env: wiretest.NewRandomEnvelope(rng, msg), + } } // MsgSerializerTest performs generic serializer tests on a wire.Msg object. @@ -48,6 +53,6 @@ func newSerializableEnvelope(rng *rand.Rand, msg wire.Msg) *serializableEnvelope // and the registration of the corresponding decoders. func MsgSerializerTest(t *testing.T, msg wire.Msg) { t.Helper() - - GenericSerializerTest(t, newSerializableEnvelope(pkgtest.Prng(t), msg)) + e := newSerializableEnvelope(pkgtest.Prng(t), msg) + GenericSerializerTest(t, e) } diff --git a/wire/protobuf/serializer.go b/wire/protobuf/serializer.go index 370114e7..d5b7222c 100644 --- a/wire/protobuf/serializer.go +++ b/wire/protobuf/serializer.go @@ -27,8 +27,9 @@ import ( type serializer struct{} -func init() { - wire.SetEnvelopeSerializer(serializer{}) +// Serializer returns a protobuf serializer. +func Serializer() wire.EnvelopeSerializer { + return serializer{} } // Encode encodes an envelope from the reader using protocol buffers diff --git a/wire/protobuf/serializer_test.go b/wire/protobuf/serializer_test.go index 4da1fe5f..b03b1b0e 100644 --- a/wire/protobuf/serializer_test.go +++ b/wire/protobuf/serializer_test.go @@ -20,7 +20,6 @@ import ( _ "perun.network/go-perun/backend/sim/channel" _ "perun.network/go-perun/backend/sim/wallet" clienttest "perun.network/go-perun/client/test" - _ "perun.network/go-perun/wire/protobuf" protobuftest "perun.network/go-perun/wire/protobuf/test" wiretest "perun.network/go-perun/wire/test" ) diff --git a/wire/protobuf/test/serializertest.go b/wire/protobuf/test/serializertest.go index 6872951f..9e4b4cf4 100644 --- a/wire/protobuf/test/serializertest.go +++ b/wire/protobuf/test/serializertest.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/require" wallettest "perun.network/go-perun/wallet/test" "perun.network/go-perun/wire" + "perun.network/go-perun/wire/protobuf" pkgtest "polycry.pt/poly-go/test" ) @@ -36,9 +37,10 @@ func MsgSerializerTest(t *testing.T, msg wire.Msg) { envelope.Msg = msg var buff bytes.Buffer - require.NoError(t, wire.EncodeEnvelope(&buff, envelope)) + ser := protobuf.Serializer() + require.NoError(t, ser.Encode(&buff, envelope)) - gotEnvelope, err := wire.DecodeEnvelope(&buff) + gotEnvelope, err := ser.Decode(&buff) require.NoError(t, err) assert.EqualValues(t, envelope, gotEnvelope) } diff --git a/wire/test/serializinglocalbus.go b/wire/test/serializinglocalbus.go index 5da06e93..0a224617 100644 --- a/wire/test/serializinglocalbus.go +++ b/wire/test/serializinglocalbus.go @@ -19,17 +19,20 @@ import ( "context" "perun.network/go-perun/wire" + perunio "perun.network/go-perun/wire/perunio/serializer" ) // SerializingLocalBus is a local bus that also serializes messages for testing. type SerializingLocalBus struct { *wire.LocalBus + ser wire.EnvelopeSerializer } // NewSerializingLocalBus creates a new serializing local bus. func NewSerializingLocalBus() *SerializingLocalBus { return &SerializingLocalBus{ LocalBus: wire.NewLocalBus(), + ser: perunio.Serializer(), } } @@ -38,12 +41,12 @@ func (b *SerializingLocalBus) Publish(ctx context.Context, e *wire.Envelope) (er // Serialize and deserialize the envelope before publishing it on the local // bus, to simulate envelope serialization. var buf bytes.Buffer - err = wire.EncodeEnvelope(&buf, e) + err = b.ser.Encode(&buf, e) if err != nil { return } - deserializedEnvelope, err := wire.DecodeEnvelope(&buf) + deserializedEnvelope, err := b.ser.Decode(&buf) if err != nil { return }