Skip to content

Commit

Permalink
fix(wsl-pro-service): Validate port read from the address file (#632)
Browse files Browse the repository at this point in the history
The port in the address file can be negative which would cause errors
down the line. We also want to avoid port 0 as that usually means "any
port" in go land.


Closes #629 

UDENG-2426
  • Loading branch information
EduardGomezEscandell authored Mar 7, 2024
2 parents e1b0699 + d4e4c83 commit a6784f5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 28 deletions.
39 changes: 33 additions & 6 deletions wsl-pro-service/internal/controlstream/controlstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net"
"os"
"path/filepath"
"strconv"

agentapi "github.com/canonical/ubuntu-pro-for-wsl/agentapi/go"
"github.com/canonical/ubuntu-pro-for-wsl/common"
Expand Down Expand Up @@ -125,20 +126,46 @@ func (cs ControlStream) address(ctx context.Context) (string, error) {
return "", SystemError{err}
}

/*
We parse the port from the file written by the windows agent.
*/
// Parse the port from the file written by the windows agent.
addr, err := os.ReadFile(cs.addrPath)
if err != nil {
return "", fmt.Errorf("could not read agent port file %q: %v", cs.addrPath, err)
}

_, port, err := net.SplitHostPort(string(addr))
port, err := splitPort(string(addr))
if err != nil {
return "", fmt.Errorf("could not parse port from %q: %v", addr, err)
return "", err
}

return net.JoinHostPort(windowsLocalhost.String(), port), nil
// Join the address and port, and validate it.
address := net.JoinHostPort(windowsLocalhost.String(), fmt.Sprint(port))

return address, nil
}

// splitPort splits the port from the address, and validates that the port is a strictly positive integer.
func splitPort(addr string) (p int, err error) {
defer decorate.OnError(&err, "could not parse port from %q", addr)

_, port, err := net.SplitHostPort(addr)
if err != nil {
return 0, fmt.Errorf("could not split address: %v", err)
}

p, err = strconv.Atoi(port)
if err != nil {
return 0, fmt.Errorf("could not parse port as an integer: %v", err)
}

if p == 0 {
return 0, errors.New("port cannot be zero")
}

if p < 0 {
return 0, errors.New("port cannot be negative")
}

return p, nil
}

// ReservedPort returns the port assigned to this distro.
Expand Down
78 changes: 56 additions & 22 deletions wsl-pro-service/internal/controlstream/controlstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,37 +217,71 @@ func TestConnect(t *testing.T) {
func TestSend(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
testCases := map[string]struct {
forcePort string

wantErr bool
}{
"Success": {},

"Error when port is wrong": {forcePort: "123", wantErr: true},
"Error when port is empty": {forcePort: "-", wantErr: true},
"Error when port is not a number": {forcePort: "abc", wantErr: true},
"Error when port is negative": {forcePort: "-1", wantErr: true},
"Error when port is zero": {forcePort: "0", wantErr: true},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

system, mock := testutils.MockSystem(t)
system, mock := testutils.MockSystem(t)

portFile := mock.DefaultAddrFile()
_, agentMetaData := testutils.MockWindowsAgent(t, ctx, portFile)
portFile := mock.DefaultAddrFile()
_, agentMetaData := testutils.MockWindowsAgent(t, ctx, portFile)

// Override the port file with a different port
if tc.forcePort != "" {
newAddr := "127.0.0.1"
if tc.forcePort != "-" {
newAddr = fmt.Sprintf("%s:%s", newAddr, tc.forcePort)
}
err := os.WriteFile(portFile, []byte(newAddr), 0600)
require.NoError(t, err, "Setup: could not overwrite new address file")
}

cs, err := controlstream.New(ctx, system)
require.NoError(t, err, "New should return no error")
cs, err := controlstream.New(ctx, system)
require.NoError(t, err, "New should return no error")

err = cs.Connect(ctx)
require.NoError(t, err, "Connect should have returned no error")
defer cs.Disconnect()
err = cs.Connect(ctx)
if tc.wantErr {
require.Error(t, err, "Connect should have returned an error")
return
}
require.NoError(t, err, "Connect should have returned no error")
defer cs.Disconnect()

require.Equal(t, int32(1), agentMetaData.ConnectionCount.Load(), "The agent should have received one connection via the control stream")
require.Equal(t, int32(1), agentMetaData.RecvCount.Load(), "The agent should have received one message via the control stream")
require.Equal(t, int32(1), agentMetaData.ConnectionCount.Load(), "The agent should have received one connection via the control stream")
require.Equal(t, int32(1), agentMetaData.RecvCount.Load(), "The agent should have received one message via the control stream")

var c net.ListenConfig
l, err := c.Listen(ctx, "tcp4", fmt.Sprintf("localhost:%d", cs.ReservedPort()))
require.NoError(t, err, "could not serve assigned port")
defer l.Close()
var c net.ListenConfig
l, err := c.Listen(ctx, "tcp4", fmt.Sprintf("localhost:%d", cs.ReservedPort()))
require.NoError(t, err, "could not serve assigned port")
defer l.Close()

err = cs.Send(&agentapi.DistroInfo{WslName: "HELLO"})
require.NoError(t, err, "Send should return no error")
err = cs.Send(&agentapi.DistroInfo{WslName: "HELLO"})
require.NoError(t, err, "Send should return no error")

require.Eventually(t, func() bool {
return agentMetaData.RecvCount.Load() > 1
}, 20*time.Second, time.Second, "The agent should have received another message via the control stream")
require.Eventually(t, func() bool {
return agentMetaData.RecvCount.Load() > 1
}, 20*time.Second, time.Second, "The agent should have received another message via the control stream")

require.Equal(t, int32(2), agentMetaData.RecvCount.Load(), "The agent should have received exactly two messages via the control stream")
require.Equal(t, int32(2), agentMetaData.RecvCount.Load(), "The agent should have received exactly two messages via the control stream")
})
}
}

func TestReconnection(t *testing.T) {
Expand Down

0 comments on commit a6784f5

Please sign in to comment.