Skip to content

Commit

Permalink
Fix: "tree signing session not found" error (ark-network#323)
Browse files Browse the repository at this point in the history
* failing test

* fix duplicate input register

* fix btc-embedded coin selection

* rename test

* add checks in failing test case

* fixes GetEventStream

* add TODO comment in createPoolTx

* update with master changes

* fix server unit test

* increase liquidity of testing ASP

* simplify AliceSeveralPaymentsBob test
  • Loading branch information
louisinger authored Sep 20, 2024
1 parent 5c2065a commit 9e3d667
Show file tree
Hide file tree
Showing 13 changed files with 310 additions and 104 deletions.
11 changes: 7 additions & 4 deletions pkg/client-sdk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
filestore "github.com/ark-network/ark/pkg/client-sdk/wallet/singlekey/store/file"
inmemorystore "github.com/ark-network/ark/pkg/client-sdk/wallet/singlekey/store/inmemory"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -236,11 +237,13 @@ func (a *arkClient) ping(
ticker := time.NewTicker(5 * time.Second)

go func(t *time.Ticker) {
// nolint
a.client.Ping(ctx, paymentID)
if _, err := a.client.Ping(ctx, paymentID); err != nil {
logrus.Warnf("failed to ping asp: %s", err)
}
for range t.C {
// nolint
a.client.Ping(ctx, paymentID)
if _, err := a.client.Ping(ctx, paymentID); err != nil {
logrus.Warnf("failed to ping asp: %s", err)
}
}
}(ticker)

Expand Down
2 changes: 1 addition & 1 deletion pkg/client-sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type ASPClient interface {
) error
GetEventStream(
ctx context.Context, paymentID string,
) (<-chan RoundEventChannel, error)
) (<-chan RoundEventChannel, func(), error)
Ping(ctx context.Context, paymentID string) (RoundEvent, error)
FinalizePayment(
ctx context.Context, signedForfeitTxs []string, signedRoundTx string,
Expand Down
50 changes: 33 additions & 17 deletions pkg/client-sdk/client/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/ark-network/ark/pkg/client-sdk/internal/utils"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
Expand All @@ -23,7 +24,6 @@ import (
type grpcClient struct {
conn *grpc.ClientConn
svc arkv1.ArkServiceClient
eventsCh chan client.RoundEventChannel
treeCache *utils.Cache[tree.CongestionTree]
}

Expand All @@ -48,10 +48,9 @@ func NewClient(aspUrl string) (client.ASPClient, error) {
}

svc := arkv1.NewArkServiceClient(conn)
eventsCh := make(chan client.RoundEventChannel)
treeCache := utils.NewCache[tree.CongestionTree]()

return &grpcClient{conn, svc, eventsCh, treeCache}, nil
return &grpcClient{conn, svc, treeCache}, nil
}

func (c *grpcClient) Close() {
Expand All @@ -61,34 +60,47 @@ func (c *grpcClient) Close() {

func (a *grpcClient) GetEventStream(
ctx context.Context, paymentID string,
) (<-chan client.RoundEventChannel, error) {
) (<-chan client.RoundEventChannel, func(), error) {
req := &arkv1.GetEventStreamRequest{}
stream, err := a.svc.GetEventStream(ctx, req)
if err != nil {
return nil, err
return nil, nil, err
}

eventsCh := make(chan client.RoundEventChannel)

go func() {
defer close(a.eventsCh)
defer close(eventsCh)

for {
resp, err := stream.Recv()
if err != nil {
a.eventsCh <- client.RoundEventChannel{Err: err}
select {
case <-stream.Context().Done():
return
}
default:
resp, err := stream.Recv()
if err != nil {
eventsCh <- client.RoundEventChannel{Err: err}
return
}

ev, err := event{resp}.toRoundEvent()
if err != nil {
a.eventsCh <- client.RoundEventChannel{Err: err}
return
}
ev, err := event{resp}.toRoundEvent()
if err != nil {
eventsCh <- client.RoundEventChannel{Err: err}
return
}

a.eventsCh <- client.RoundEventChannel{Event: ev}
eventsCh <- client.RoundEventChannel{Event: ev}
}
}
}()

return a.eventsCh, nil
closeFn := func() {
if err := stream.CloseSend(); err != nil {
logrus.Warnf("failed to close stream: %v", err)
}
}

return eventsCh, closeFn, nil
}

func (a *grpcClient) GetInfo(ctx context.Context) (*client.Info, error) {
Expand Down Expand Up @@ -184,6 +196,10 @@ func (a *grpcClient) Ping(
return nil, err
}

if resp.GetEvent() == nil {
return nil, nil
}

return event{resp}.toRoundEvent()
}

Expand Down
26 changes: 17 additions & 9 deletions pkg/client-sdk/client/rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (

type restClient struct {
svc ark_service.ClientService
eventsCh chan client.RoundEventChannel
requestTimeout time.Duration
treeCache *utils.Cache[tree.CongestionTree]
}
Expand All @@ -39,41 +38,46 @@ func NewClient(aspUrl string) (client.ASPClient, error) {
if err != nil {
return nil, err
}
eventsCh := make(chan client.RoundEventChannel)
reqTimeout := 15 * time.Second
treeCache := utils.NewCache[tree.CongestionTree]()

return &restClient{svc, eventsCh, reqTimeout, treeCache}, nil
return &restClient{svc, reqTimeout, treeCache}, nil
}

func (c *restClient) Close() {}

func (a *restClient) GetEventStream(
ctx context.Context, paymentID string,
) (<-chan client.RoundEventChannel, error) {
) (<-chan client.RoundEventChannel, func(), error) {
eventsCh := make(chan client.RoundEventChannel)
stopCh := make(chan struct{})

go func(payID string) {
defer close(a.eventsCh)
defer close(eventsCh)
defer close(stopCh)

timeout := time.After(a.requestTimeout)

for {
select {
case <-stopCh:
return
case <-timeout:
a.eventsCh <- client.RoundEventChannel{
eventsCh <- client.RoundEventChannel{
Err: fmt.Errorf("timeout reached"),
}
return
default:
event, err := a.Ping(ctx, payID)
if err != nil {
a.eventsCh <- client.RoundEventChannel{
eventsCh <- client.RoundEventChannel{
Err: err,
}
return
}

if event != nil {
a.eventsCh <- client.RoundEventChannel{
eventsCh <- client.RoundEventChannel{
Event: event,
}
}
Expand All @@ -83,7 +87,11 @@ func (a *restClient) GetEventStream(
}
}(paymentID)

return a.eventsCh, nil
close := func() {
stopCh <- struct{}{}
}

return eventsCh, close, nil
}

func (a *restClient) GetInfo(
Expand Down
7 changes: 5 additions & 2 deletions pkg/client-sdk/covenant_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ func (a *covenantArkClient) handleRoundStream(
boardingDescriptor string,
receivers []client.Output,
) (string, error) {
eventsCh, err := a.client.GetEventStream(ctx, paymentID)
eventsCh, close, err := a.client.GetEventStream(ctx, paymentID)
if err != nil {
return "", err
}
Expand All @@ -1021,7 +1021,10 @@ func (a *covenantArkClient) handleRoundStream(
pingStop = a.ping(ctx, paymentID)
}

defer pingStop()
defer func() {
pingStop()
close()
}()

for {
select {
Expand Down
15 changes: 10 additions & 5 deletions pkg/client-sdk/covenantless_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ func (a *covenantlessArkClient) handleRoundStream(
receivers []client.Output,
roundEphemeralKey *secp256k1.PrivateKey,
) (string, error) {
eventsCh, err := a.client.GetEventStream(ctx, paymentID)
eventsCh, close, err := a.client.GetEventStream(ctx, paymentID)
if err != nil {
return "", err
}
Expand All @@ -1104,7 +1104,10 @@ func (a *covenantlessArkClient) handleRoundStream(
pingStop = a.ping(ctx, paymentID)
}

defer pingStop()
defer func() {
pingStop()
close()
}()

var signerSession bitcointree.SignerSession

Expand All @@ -1120,14 +1123,16 @@ func (a *covenantlessArkClient) handleRoundStream(
for {
select {
case <-ctx.Done():
return "", ctx.Err()
return "", fmt.Errorf("context done %s", ctx.Err())
case notify := <-eventsCh:
if notify.Err != nil {
return "", err
return "", notify.Err
}

switch event := notify.Event; event.(type) {
case client.RoundFinalizedEvent:
if step != roundFinalization {
continue
}
return event.(client.RoundFinalizedEvent).Txid, nil
case client.RoundFailedEvent:
return "", fmt.Errorf("round failed: %s", event.(client.RoundFailedEvent).Reason)
Expand Down
10 changes: 6 additions & 4 deletions server/internal/core/application/covenantless.go
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,6 @@ func (s *covenantlessService) RegisterCosignerNonces(
if err != nil {
return fmt.Errorf("failed to decode nonces: %s", err)
}

session.lock.Lock()
defer session.lock.Unlock()

Expand All @@ -654,7 +653,9 @@ func (s *covenantlessService) RegisterCosignerNonces(
session.nonces[pubkey] = nonces

if len(session.nonces) == session.nbCosigners-1 { // exclude the ASP
session.nonceDoneC <- struct{}{}
go func() {
session.nonceDoneC <- struct{}{}
}()
}

return nil
Expand Down Expand Up @@ -683,7 +684,9 @@ func (s *covenantlessService) RegisterCosignerSignatures(
session.signatures[pubkey] = signatures

if len(session.signatures) == session.nbCosigners-1 { // exclude the ASP
session.sigDoneC <- struct{}{}
go func() {
session.sigDoneC <- struct{}{}
}()
}

return nil
Expand Down Expand Up @@ -1078,7 +1081,6 @@ func (s *covenantlessService) finalizeRound() {
txid, err := s.wallet.BroadcastTransaction(ctx, signedRoundTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err))
log.WithError(err).Warn("failed to broadcast pool tx")
return
}

Expand Down
22 changes: 21 additions & 1 deletion server/internal/core/application/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,27 @@ func (m *paymentsMap) push(payment domain.Payment, boardingInputs []ports.Boardi
defer m.lock.Unlock()

if _, ok := m.payments[payment.Id]; ok {
return fmt.Errorf("duplicated inputs")
return fmt.Errorf("duplicated payment %s", payment.Id)
}

for _, input := range payment.Inputs {
for _, pay := range m.payments {
for _, pInput := range pay.Inputs {
if input.VtxoKey.Txid == pInput.VtxoKey.Txid && input.VtxoKey.VOut == pInput.VtxoKey.VOut {
return fmt.Errorf("duplicated input, %s:%d already used by payment %s", input.VtxoKey.Txid, input.VtxoKey.VOut, pay.Id)
}
}
}
}

for _, input := range boardingInputs {
for _, pay := range m.payments {
for _, pBoardingInput := range pay.boardingInputs {
if input.Txid == pBoardingInput.Txid && input.VOut == pBoardingInput.VOut {
return fmt.Errorf("duplicated boarding input, %s:%d already used by payment %s", input.Txid, input.VOut, pay.Id)
}
}
}
}

m.payments[payment.Id] = &timedPayment{payment, boardingInputs, time.Now(), time.Time{}}
Expand Down
Loading

0 comments on commit 9e3d667

Please sign in to comment.