diff --git a/pipe/fork/fork.go b/pipe/fork/fork.go index 44508c7..e93bf61 100644 --- a/pipe/fork/fork.go +++ b/pipe/fork/fork.go @@ -92,7 +92,7 @@ func ForEach[A any](ctx context.Context, par int, in <-chan A, f func(A)) <-chan // FMap applies function over channel messages, flatten the output channel and // emits it result to new channel. -func FMap[A, B any](ctx context.Context, par int, in <-chan A, fmap func(A) (<-chan B, error)) (<-chan B, <-chan error) { +func FMap[A, B any](ctx context.Context, par int, in <-chan A, fmap func(context.Context, A, chan<- B) error) (<-chan B, <-chan error) { var wg sync.WaitGroup out := make(chan B, par) exx := make(chan error, par) @@ -100,25 +100,17 @@ func FMap[A, B any](ctx context.Context, par int, in <-chan A, fmap func(A) (<-c pmap := func() { defer wg.Done() - var ( - a A - ch <-chan B - err error - ) - + var a A for a = range in { - ch, err = fmap(a) - if err != nil { + if err := fmap(ctx, a, out); err != nil { exx <- err return } - for x := range ch { - select { - case out <- x: - case <-ctx.Done(): - return - } + select { + case <-ctx.Done(): + return + default: } } } diff --git a/pipe/fork/fork_test.go b/pipe/fork/fork_test.go index fd6a482..bf85e61 100644 --- a/pipe/fork/fork_test.go +++ b/pipe/fork/fork_test.go @@ -83,8 +83,9 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := fork.Seq(1, 2, 3, 4, 5) out := fork.StdErr(fork.FMap(ctx, par, seq, - func(x int) (<-chan string, error) { - return fork.Seq(strconv.Itoa(x)), nil + func(ctx context.Context, x int, ch chan<- string) error { + ch <- strconv.Itoa(x) + return nil }), ) @@ -99,8 +100,8 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := fork.Seq(1, 2, 3, 4, 5) _, exx := fork.FMap(ctx, par, seq, - func(x int) (<-chan string, error) { - return nil, fmt.Errorf("fail") + func(ctx context.Context, x int, ch chan<- string) error { + return fmt.Errorf("fail") }, ) @@ -110,6 +111,30 @@ func TestFMap(t *testing.T) { close() }) + + t.Run("Cancel", func(t *testing.T) { + acc := 0 + emit := func() (int, error) { + acc++ + return acc, nil + } + + ctx, close := context.WithCancel(context.Background()) + seq := fork.StdErr(fork.Emit(ctx, 1000, 10*time.Microsecond, emit)) + out := fork.StdErr(fork.FMap(ctx, par, seq, + func(ctx context.Context, x int, ch chan<- int) error { + ch <- x + return nil + }), + ) + + vals := fork.ToSeq(fork.Take(ctx, out, 10)) + close() + + it.Then(t).Should( + it.Seq(vals).Contain().AllOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + ) + }) } func TestFold(t *testing.T) { diff --git a/pipe/pipe.go b/pipe/pipe.go index 1a52e5d..eca8a58 100644 --- a/pipe/pipe.go +++ b/pipe/pipe.go @@ -99,7 +99,7 @@ func ForEach[A any](ctx context.Context, in <-chan A, f func(A)) <-chan struct{} // FMap applies function over channel messages, flatten the output channel and // emits it result to new channel. -func FMap[A, B any](ctx context.Context, in <-chan A, fmap func(A) (<-chan B, error)) (<-chan B, <-chan error) { +func FMap[A, B any](ctx context.Context, in <-chan A, fmap func(context.Context, A, chan<- B) error) (<-chan B, <-chan error) { out := make(chan B, cap(in)) exx := make(chan error, 1) @@ -107,25 +107,17 @@ func FMap[A, B any](ctx context.Context, in <-chan A, fmap func(A) (<-chan B, er defer close(out) defer close(exx) - var ( - a A - ch <-chan B - err error - ) - + var a A for a = range in { - ch, err = fmap(a) - if err != nil { + if err := fmap(ctx, a, out); err != nil { exx <- err return } - for x := range ch { - select { - case out <- x: - case <-ctx.Done(): - return - } + select { + case <-ctx.Done(): + return + default: } } }() diff --git a/pipe/pipe_test.go b/pipe/pipe_test.go index ac52d27..335b1b6 100644 --- a/pipe/pipe_test.go +++ b/pipe/pipe_test.go @@ -88,8 +88,9 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := pipe.Seq(1, 2, 3, 4, 5) out := pipe.StdErr(pipe.FMap(ctx, seq, - func(x int) (<-chan string, error) { - return pipe.Seq(strconv.Itoa(x)), nil + func(ctx context.Context, x int, ch chan<- string) error { + ch <- strconv.Itoa(x) + return nil }), ) @@ -104,8 +105,8 @@ func TestFMap(t *testing.T) { ctx, close := context.WithCancel(context.Background()) seq := pipe.Seq(1, 2, 3, 4, 5) _, exx := pipe.FMap(ctx, seq, - func(x int) (<-chan string, error) { - return nil, fmt.Errorf("fail") + func(ctx context.Context, x int, ch chan<- string) error { + return fmt.Errorf("fail") }, ) @@ -115,6 +116,30 @@ func TestFMap(t *testing.T) { close() }) + + t.Run("Cancel", func(t *testing.T) { + acc := 0 + emit := func() (int, error) { + acc++ + return acc, nil + } + + ctx, close := context.WithCancel(context.Background()) + seq := pipe.StdErr(pipe.Emit(ctx, 1000, 10*time.Microsecond, emit)) + out := pipe.StdErr(pipe.FMap(ctx, seq, + func(ctx context.Context, x int, ch chan<- int) error { + ch <- x + return nil + }), + ) + + vals := pipe.ToSeq(pipe.Take(ctx, out, 10)) + close() + + it.Then(t).Should( + it.Seq(vals).Contain().AllOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + ) + }) } func TestFold(t *testing.T) { diff --git a/pipe/version.go b/pipe/version.go index 4130402..01f5c67 100644 --- a/pipe/version.go +++ b/pipe/version.go @@ -8,4 +8,4 @@ package pipe -const Version = "pipe/v1.1.0" +const Version = "pipe/v1.1.1"