Skip to content

Commit

Permalink
cleanup: refactor and use bicopy everywhere
Browse files Browse the repository at this point in the history
There are 3 instances of this pattern and each does this slightly
differently. Clean up the implementation to return errors using
`errors.Join` (which wasn't available when the original was written) and
use it everywhere.

This doesn't change behavior because the error return is always just
logged (see the only called of `(*pseudoLoopbackForwarder).forward`.

Note that the removal of the special handling of `io.EOF` returned from
`io.Copy` doesn't change behavior because it can never happen per the
latter's documentation.

Signed-off-by: Tamir Duberstein <[email protected]>
  • Loading branch information
tamird committed Nov 28, 2024
1 parent 10280b6 commit a2a75a4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 108 deletions.
115 changes: 53 additions & 62 deletions pkg/bicopy/bicopy.go
Original file line number Diff line number Diff line change
@@ -1,81 +1,72 @@
// From https://raw.githubusercontent.com/norouter/norouter/v0.6.5/pkg/agent/bicopy/bicopy.go
/*
Copyright (C) NoRouter authors.
Copyright (C) libnetwork authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package bicopy

import (
"errors"
"fmt"
"io"
"sync"

"github.com/sirupsen/logrus"
)

// Bicopy is from https://github.com/rootless-containers/rootlesskit/blob/v0.10.1/pkg/port/builtin/parent/tcp/tcp.go#L73-L104
// (originally from libnetwork, Apache License 2.0).
func Bicopy(x, y io.ReadWriter, quit <-chan struct{}) {
type closeReader interface {
func closeRead(c io.ReadWriter) error {
if c, ok := c.(interface {
CloseRead() error
}); ok {
return c.CloseRead()
}
type closeWriter interface {
return nil
}

func closeWrite(c io.ReadWriter) error {
if c, ok := c.(interface {
CloseWrite() error
}); ok {
return c.CloseWrite()
}
var wg sync.WaitGroup
broker := func(to, from io.ReadWriter) {
if _, err := io.Copy(to, from); err != nil {
logrus.WithError(err).Debug("failed to call io.Copy")
}
if fromCR, ok := from.(closeReader); ok {
if err := fromCR.CloseRead(); err != nil {
logrus.WithError(err).Debug("failed to call CloseRead")
}
}
if toCW, ok := to.(closeWriter); ok {
if err := toCW.CloseWrite(); err != nil {
logrus.WithError(err).Debug("failed to call CloseWrite")
}
}
wg.Done()
return nil
}

// Avoid shadowing the built-in `close`.
func ioClose(c io.ReadWriter) error {
if c, ok := c.(io.Closer); ok {
return c.Close()
}
return nil
}

func broker(dst, src io.ReadWriter, dstName, srcName string) error {
_, errCopy := io.Copy(dst, src)
if errCopy != nil {
errCopy = fmt.Errorf("io.Copy(%s, %s): %w", srcName, dstName, errCopy)
}
errCloseRead := closeRead(src)
if errCloseRead != nil {
errCloseRead = fmt.Errorf("closeRead(%s): %w", srcName, errCloseRead)
}
errCloseWrite := closeWrite(dst)
if errCloseWrite != nil {
errCloseWrite = fmt.Errorf("closeWrite(%s): %w", dstName, errCloseWrite)
}
return errors.Join(errCopy, errCloseRead, errCloseWrite)
}

func Bicopy(left, right io.ReadWriter, leftName, rightName string) error {
var wg sync.WaitGroup
wg.Add(2)
go broker(x, y)
go broker(y, x)
finish := make(chan struct{})
var errLeft, errRight error
go func() {
wg.Wait()
close(finish)
errLeft = broker(left, right, leftName, rightName)
}()

select {
case <-quit:
case <-finish:
}
if xCloser, ok := x.(io.Closer); ok {
if err := xCloser.Close(); err != nil {
logrus.WithError(err).Debug("failed to call xCloser.Close")
}
go func() {
errRight = broker(right, left, rightName, leftName)
}()
wg.Wait()
errCloseLeft := ioClose(left)
if errCloseLeft != nil {
errCloseLeft = fmt.Errorf("ioClose(%s): %w", leftName, errCloseLeft)
}
if yCloser, ok := y.(io.Closer); ok {
if err := yCloser.Close(); err != nil {
logrus.WithError(err).Debug("failed to call yCloser.Close")
}
errCloseRight := ioClose(right)
if errCloseRight != nil {
errCloseRight = fmt.Errorf("ioClose(%s): %w", rightName, errCloseRight)
}
<-finish
// TODO: return copied bytes
return errors.Join(errLeft, errRight, errCloseLeft, errCloseRight)
}
17 changes: 8 additions & 9 deletions pkg/hostagent/port_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,22 @@ func (plf *pseudoLoopbackForwarder) Serve() error {
ac.Close()
continue
}
go func(ac *net.TCPConn) {
if fErr := plf.forward(ac); fErr != nil {
logrus.Error(fErr)
}
}(ac)
go plf.forward(ac)
}
}

func (plf *pseudoLoopbackForwarder) forward(ac *net.TCPConn) error {
func (plf *pseudoLoopbackForwarder) forward(ac *net.TCPConn) {
defer ac.Close()
unixConn, err := net.DialUnix("unix", nil, plf.unixAddr)
if err != nil {
return err
logrus.WithError(err).Errorf("pseudoloopback forwarder: failed to dial %q", plf.unixAddr)
return
}
defer unixConn.Close()
bicopy.Bicopy(ac, unixConn, nil)
return nil

if err := bicopy.Bicopy(ac, unixConn, "tcp", "unix"); err != nil {
logrus.WithError(err).Error("pseudoloopback forwarder: failed to forward")
}
}

func (plf *pseudoLoopbackForwarder) Close() error {
Expand Down
21 changes: 3 additions & 18 deletions pkg/portfwd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net"

"github.com/lima-vm/lima/pkg/bicopy"
"github.com/lima-vm/lima/pkg/guestagent/api"
guestagentclient "github.com/lima-vm/lima/pkg/guestagent/api/client"
"github.com/sirupsen/logrus"
Expand All @@ -24,26 +25,10 @@ func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgen
return
}

g, _ := errgroup.WithContext(ctx)

rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr}
g.Go(func() error {
_, err := io.Copy(rw, conn)
if errors.Is(err, io.EOF) {
return nil
}
return err
})
g.Go(func() error {
_, err := io.Copy(conn, rw)
if errors.Is(err, io.EOF) {
return nil
}
return err
})

if err := g.Wait(); err != nil {
logrus.Debugf("error in tcp tunnel for id: %s error:%v", id, err)
if err := bicopy.Bicopy(rw, conn, guestAddr, id); err != nil {
logrus.WithError(err).Error("failed to forward packets")
}
}

Expand Down
23 changes: 4 additions & 19 deletions pkg/vz/network_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import (
"io"
"net"
"os"
"sync"
"syscall"
"time"

"github.com/balajiv113/fd"
"github.com/lima-vm/lima/pkg/bicopy"

"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -66,24 +66,9 @@ func forwardPackets(qemuConn *qemuPacketConn, vzConn *packetConn) {
defer qemuConn.Close()
defer vzConn.Close()

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
if _, err := io.Copy(qemuConn, vzConn); err != nil {
logrus.Errorf("Failed to forward packets from VZ to VMNET: %s", err)
}
}()

go func() {
defer wg.Done()
if _, err := io.Copy(vzConn, qemuConn); err != nil {
logrus.Errorf("Failed to forward packets from VMNET to VZ: %s", err)
}
}()

wg.Wait()
if err := bicopy.Bicopy(qemuConn, vzConn, "VMNET", "VZ"); err != nil {
logrus.WithError(err).Error("failed to forward packets")
}
}

// qemuPacketConn converts raw network packet to a QEMU supported network packet.
Expand Down

0 comments on commit a2a75a4

Please sign in to comment.