From 4cca41df7d5839401eec604cb1ec9dc946b7a304 Mon Sep 17 00:00:00 2001 From: Dmitry Kolesnikov Date: Tue, 13 Aug 2024 20:07:30 +0300 Subject: [PATCH 1/3] (fix) allow fmap stream out data and produce errors effectively --- pipe/fork/fork.go | 22 +++++++--------------- pipe/fork/fork_test.go | 9 +++++---- pipe/pipe.go | 22 +++++++--------------- pipe/pipe_test.go | 9 +++++---- pipe/version.go | 2 +- 5 files changed, 25 insertions(+), 39 deletions(-) diff --git a/pipe/fork/fork.go b/pipe/fork/fork.go index 44508c7..e93bf61 100644 --- a/pipe/fork/fork.go +++ b/pipe/fork/fork.go @@ -92,7 +92,7 @@ func ForEach[A any](ctx context.Context, par int, in <-chan A, f func(A)) <-chan // FMap applies function over channel messages, flatten the output channel and // emits it result to new channel. -func FMap[A, B any](ctx context.Context, par int, in <-chan A, fmap func(A) (<-chan B, error)) (<-chan B, <-chan error) { +func FMap[A, B any](ctx context.Context, par int, in <-chan A, fmap func(context.Context, A, chan<- B) error) (<-chan B, <-chan error) { var wg sync.WaitGroup out := make(chan B, par) exx := make(chan error, par) @@ -100,25 +100,17 @@ func FMap[A, B any](ctx context.Context, par int, in <-chan A, fmap func(A) (<-c pmap := func() { defer wg.Done() - var ( - a A - ch <-chan B - err error - ) - + var a A for a = range in { - ch, err = fmap(a) - if err != nil { + if err := fmap(ctx, a, out); err != nil { exx <- err return } - for x := range ch { - select { - case out <- x: - case <-ctx.Done(): - return - } + select { + case <-ctx.Done(): + return + default: } } } diff --git a/pipe/fork/fork_test.go b/pipe/fork/fork_test.go index fd6a482..15c14c2 100644 --- a/pipe/fork/fork_test.go +++ b/pipe/fork/fork_test.go @@ -83,8 +83,9 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := fork.Seq(1, 2, 3, 4, 5) out := fork.StdErr(fork.FMap(ctx, par, seq, - func(x int) (<-chan string, error) { - return fork.Seq(strconv.Itoa(x)), nil + func(ctx context.Context, x int, ch chan<- string) error { + ch <- strconv.Itoa(x) + return nil }), ) @@ -99,8 +100,8 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := fork.Seq(1, 2, 3, 4, 5) _, exx := fork.FMap(ctx, par, seq, - func(x int) (<-chan string, error) { - return nil, fmt.Errorf("fail") + func(ctx context.Context, x int, ch chan<- string) error { + return fmt.Errorf("fail") }, ) diff --git a/pipe/pipe.go b/pipe/pipe.go index 1a52e5d..eca8a58 100644 --- a/pipe/pipe.go +++ b/pipe/pipe.go @@ -99,7 +99,7 @@ func ForEach[A any](ctx context.Context, in <-chan A, f func(A)) <-chan struct{} // FMap applies function over channel messages, flatten the output channel and // emits it result to new channel. -func FMap[A, B any](ctx context.Context, in <-chan A, fmap func(A) (<-chan B, error)) (<-chan B, <-chan error) { +func FMap[A, B any](ctx context.Context, in <-chan A, fmap func(context.Context, A, chan<- B) error) (<-chan B, <-chan error) { out := make(chan B, cap(in)) exx := make(chan error, 1) @@ -107,25 +107,17 @@ func FMap[A, B any](ctx context.Context, in <-chan A, fmap func(A) (<-chan B, er defer close(out) defer close(exx) - var ( - a A - ch <-chan B - err error - ) - + var a A for a = range in { - ch, err = fmap(a) - if err != nil { + if err := fmap(ctx, a, out); err != nil { exx <- err return } - for x := range ch { - select { - case out <- x: - case <-ctx.Done(): - return - } + select { + case <-ctx.Done(): + return + default: } } }() diff --git a/pipe/pipe_test.go b/pipe/pipe_test.go index ac52d27..942b5d7 100644 --- a/pipe/pipe_test.go +++ b/pipe/pipe_test.go @@ -88,8 +88,9 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := pipe.Seq(1, 2, 3, 4, 5) out := pipe.StdErr(pipe.FMap(ctx, seq, - func(x int) (<-chan string, error) { - return pipe.Seq(strconv.Itoa(x)), nil + func(ctx context.Context, x int, ch chan<- string) error { + ch <- strconv.Itoa(x) + return nil }), ) @@ -104,8 +105,8 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := pipe.Seq(1, 2, 3, 4, 5) _, exx := pipe.FMap(ctx, seq, - func(x int) (<-chan string, error) { - return nil, fmt.Errorf("fail") + func(ctx context.Context, x int, ch chan<- string) error { + return fmt.Errorf("fail") }, ) diff --git a/pipe/version.go b/pipe/version.go index 4130402..01f5c67 100644 --- a/pipe/version.go +++ b/pipe/version.go @@ -8,4 +8,4 @@ package pipe -const Version = "pipe/v1.1.0" +const Version = "pipe/v1.1.1" From 25de7d59a5b3ffc24b8f380cd69e9037d56026dd Mon Sep 17 00:00:00 2001 From: Dmitry Kolesnikov Date: Tue, 13 Aug 2024 22:05:48 +0300 Subject: [PATCH 2/3] add cancel test --- pipe/fork/fork_test.go | 24 ++++++++++++++++++++++++ pipe/pipe_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/pipe/fork/fork_test.go b/pipe/fork/fork_test.go index 15c14c2..760acea 100644 --- a/pipe/fork/fork_test.go +++ b/pipe/fork/fork_test.go @@ -111,6 +111,30 @@ func TestFMap(t *testing.T) { close() }) + + t.Run("Cancel", func(t *testing.T) { + acc := 0 + emit := func() (int, error) { + acc++ + return acc, nil + } + + ctx, close := context.WithCancel(context.Background()) + seq := fork.StdErr(fork.Emit(ctx, 1000, 10*time.Microsecond, emit)) + out := fork.StdErr(fork.FMap(ctx, par, seq, + func(ctx context.Context, x int, ch chan<- int) error { + ch <- x + return nil + }), + ) + + time.Sleep(100 * time.Microsecond) + close() + + it.Then(t).Should( + it.Seq(fork.ToSeq(out)).Contain().AllOf(1, 2, 3, 4), + ) + }) } func TestFold(t *testing.T) { diff --git a/pipe/pipe_test.go b/pipe/pipe_test.go index 942b5d7..2541e6d 100644 --- a/pipe/pipe_test.go +++ b/pipe/pipe_test.go @@ -116,6 +116,30 @@ func TestFMap(t *testing.T) { close() }) + + t.Run("Cancel", func(t *testing.T) { + acc := 0 + emit := func() (int, error) { + acc++ + return acc, nil + } + + ctx, close := context.WithCancel(context.Background()) + seq := pipe.StdErr(pipe.Emit(ctx, 1000, 10*time.Microsecond, emit)) + out := pipe.StdErr(pipe.FMap(ctx, seq, + func(ctx context.Context, x int, ch chan<- int) error { + ch <- x + return nil + }), + ) + + time.Sleep(100 * time.Microsecond) + close() + + it.Then(t).Should( + it.Seq(pipe.ToSeq(out)).Contain().AllOf(1, 2, 3, 4), + ) + }) } func TestFold(t *testing.T) { From 2a0ad0f3b7a17a4438c3d7c259887ad9d2f4fc82 Mon Sep 17 00:00:00 2001 From: Dmitry Kolesnikov Date: Tue, 13 Aug 2024 22:12:10 +0300 Subject: [PATCH 3/3] increase test coverage --- pipe/fork/fork_test.go | 4 ++-- pipe/pipe_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pipe/fork/fork_test.go b/pipe/fork/fork_test.go index 760acea..bf85e61 100644 --- a/pipe/fork/fork_test.go +++ b/pipe/fork/fork_test.go @@ -128,11 +128,11 @@ func TestFMap(t *testing.T) { }), ) - time.Sleep(100 * time.Microsecond) + vals := fork.ToSeq(fork.Take(ctx, out, 10)) close() it.Then(t).Should( - it.Seq(fork.ToSeq(out)).Contain().AllOf(1, 2, 3, 4), + it.Seq(vals).Contain().AllOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), ) }) } diff --git a/pipe/pipe_test.go b/pipe/pipe_test.go index 2541e6d..335b1b6 100644 --- a/pipe/pipe_test.go +++ b/pipe/pipe_test.go @@ -133,11 +133,11 @@ func TestFMap(t *testing.T) { }), ) - time.Sleep(100 * time.Microsecond) + vals := pipe.ToSeq(pipe.Take(ctx, out, 10)) close() it.Then(t).Should( - it.Seq(pipe.ToSeq(out)).Contain().AllOf(1, 2, 3, 4), + it.Seq(vals).Contain().AllOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), ) }) }