From 9d213e0b54dc3c9005590b376d8c6d465ae42f09 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 12 Mar 2024 18:05:41 +0100 Subject: [PATCH 01/89] Add fallback retry to daemon (#1690) This change adds a fallback retry to the daemon service. this retry has a larger interval with a shorter max retry run time then others retries --- client/server/server.go | 127 +++++++++++++++++++++++++--- client/server/server_test.go | 157 +++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 12 deletions(-) create mode 100644 client/server/server_test.go diff --git a/client/server/server.go b/client/server/server.go index fc1e4cc2642..90b5bcb642c 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -3,11 +3,15 @@ package server import ( "context" "fmt" + "os" "os/exec" "runtime" + "strconv" "sync" "time" + "github.com/cenkalti/backoff/v4" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/system" @@ -23,7 +27,17 @@ import ( "github.com/netbirdio/netbird/version" ) -const probeThreshold = time.Second * 5 +const ( + probeThreshold = time.Second * 5 + retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME" + maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME" + maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME" + retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER" + defaultInitialRetryTime = 14 * 24 * time.Hour + defaultMaxRetryInterval = 60 * time.Minute + defaultMaxRetryTime = 14 * 24 * time.Hour + defaultRetryMultiplier = 1.7 +) // Server for service control. type Server struct { @@ -125,16 +139,110 @@ func (s *Server) Start() error { } if !config.DisableAutoConnect { - go func() { - if err := internal.RunClientWithProbes(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil { - log.Errorf("init connections: %v", err) - } - }() + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) } return nil } +// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional +// mechanism to keep the client connected even when the connection is lost. +// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. +func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, + mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe) { + backOff := getConnectWithBackoff(ctx) + retryStarted := false + + go func() { + t := time.NewTicker(24 * time.Hour) + for { + select { + case <-ctx.Done(): + t.Stop() + return + case <-t.C: + if retryStarted { + + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + retryStarted = false + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) + } + } + } + } + }() + + runOperation := func() error { + log.Tracef("running client connection") + err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) + if err != nil { + log.Debugf("run client connection exited with error: %v. Will retry in the background", err) + } + + if config.DisableAutoConnect { + return backoff.Permanent(err) + } + + if !retryStarted { + retryStarted = true + backOff.Reset() + } + + log.Tracef("client connection exited") + return fmt.Errorf("client connection exited") + } + + err := backoff.Retry(runOperation, backOff) + if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { + log.Errorf("received an error when trying to connect: %v", err) + } else { + log.Tracef("retry canceled") + } +} + +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries +func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } + } + + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + } + return defaultDuration +} + // loginAttempt attempts to login using the provided information. it returns a status in case something fails func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { var status internal.StatusType @@ -445,12 +553,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) - go func() { - if err := internal.RunClientWithProbes(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil { - log.Errorf("run client connection: %v", err) - return - } - }() + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) return &proto.UpResponse{}, nil } diff --git a/client/server/server_test.go b/client/server/server_test.go new file mode 100644 index 00000000000..79a22002311 --- /dev/null +++ b/client/server/server_test.go @@ -0,0 +1,157 @@ +package server + +import ( + "context" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/signal/proto" + signalServer "github.com/netbirdio/netbird/signal/server" +) + +var ( + kaep = keepalive.EnforcementPolicy{ + MinTime: 15 * time.Second, + PermitWithoutStream: true, + } + + kasp = keepalive.ServerParameters{ + MaxConnectionIdle: 15 * time.Second, + MaxConnectionAgeGrace: 5 * time.Second, + Time: 5 * time.Second, + Timeout: 2 * time.Second, + } +) + +// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables +// we will use a management server started via to simulate the server and capture the number of retries +func TestConnectWithRetryRuns(t *testing.T) { + // start the signal server + _, signalAddr, err := startSignal() + if err != nil { + t.Fatalf("failed to start signal server: %v", err) + } + + counter := 0 + // start the management server + _, mgmtAddr, err := startManagement(t, signalAddr, &counter) + if err != nil { + t.Fatalf("failed to start management server: %v", err) + } + + ctx := internal.CtxInitState(context.Background()) + + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second)) + defer cancel() + // create new server + s := New(ctx, t.TempDir()+"/config.json", "debug") + s.latestConfigInput.ManagementURL = "http://" + mgmtAddr + config, err := internal.UpdateOrCreateConfig(s.latestConfigInput) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } + s.config = config + + s.statusRecorder = peer.NewRecorder(config.ManagementURL.String()) + t.Setenv(retryInitialIntervalVar, "1s") + t.Setenv(maxRetryIntervalVar, "2s") + t.Setenv(maxRetryTimeVar, "5s") + t.Setenv(retryMultiplierVar, "1") + + s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + if counter < 3 { + t.Fatalf("expected counter > 2, got %d", counter) + } +} + +type mockServer struct { + mgmtProto.ManagementServiceServer + counter *int +} + +func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { + *m.counter++ + return m.ManagementServiceServer.Login(ctx, req) +} + +func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) { + t.Helper() + dataDir := t.TempDir() + + config := &server.Config{ + Stuns: []*server.Host{}, + TURNConfig: &server.TURNConfig{}, + Signal: &server.Host{ + Proto: "http", + URI: signalAddr, + }, + Datadir: dataDir, + HttpConfig: nil, + } + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + return nil, "", err + } + s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) + store, err := server.NewStoreFromJson(config.Datadir, nil) + if err != nil { + return nil, "", err + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + eventStore := &activity.InMemoryEventStore{} + if err != nil { + return nil, "", err + } + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false) + if err != nil { + return nil, "", err + } + turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) + mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + if err != nil { + return nil, "", err + } + mock := &mockServer{ + ManagementServiceServer: mgmtServer, + counter: counter, + } + mgmtProto.RegisterManagementServiceServer(s, mock) + go func() { + if err = s.Serve(lis); err != nil { + log.Fatalf("failed to serve: %v", err) + } + }() + + return s, lis.Addr().String(), nil +} + +func startSignal() (*grpc.Server, string, error) { + s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + + proto.RegisterSignalExchangeServer(s, signalServer.NewServer()) + + go func() { + if err = s.Serve(lis); err != nil { + log.Fatalf("failed to serve: %v", err) + } + }() + + return s, lis.Addr().String(), nil +} From ba33572ec9bb761491b81470f244ccdeea69bb57 Mon Sep 17 00:00:00 2001 From: Krzysztof Nazarewski Date: Tue, 12 Mar 2024 18:29:19 +0100 Subject: [PATCH 02/89] add --service/-s flag for specifying system service name (#1691) --- client/cmd/root.go | 8 ++++++++ client/cmd/service.go | 8 +------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index b3a924016f9..c3ff0a3c876 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -61,6 +61,7 @@ var ( serverSSHAllowed bool interfaceName string wireguardPort uint16 + serviceName string autoConnectDisabled bool rootCmd = &cobra.Command{ Use: "netbird", @@ -100,9 +101,16 @@ func init() { if runtime.GOOS == "windows" { defaultDaemonAddr = "tcp://127.0.0.1:41731" } + + defaultServiceName := "netbird" + if runtime.GOOS == "windows" { + defaultServiceName = "Netbird" + } + rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL)) rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL)) + rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout") diff --git a/client/cmd/service.go b/client/cmd/service.go index 18fe5d6212b..5c60744f96c 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -2,8 +2,6 @@ package cmd import ( "context" - "runtime" - "github.com/kardianos/service" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -24,12 +22,8 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program { } func newSVCConfig() *service.Config { - name := "netbird" - if runtime.GOOS == "windows" { - name = "Netbird" - } return &service.Config{ - Name: name, + Name: serviceName, DisplayName: "Netbird", Description: "A WireGuard-based mesh network that connects your devices into a single private network.", Option: make(service.KeyValue), From 4a1aee1ae0188191fc97bb732727820260770d7c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 12 Mar 2024 19:06:16 +0100 Subject: [PATCH 03/89] Add routes and dns servers to status command (#1680) * Add routes (client and server) to status command * Add DNS servers to status output --- .github/workflows/mobile-build-validation.yml | 6 +- client/cmd/status.go | 126 ++++++-- client/cmd/status_test.go | 122 ++++++- client/internal/dns/server.go | 114 ++++++- client/internal/dns/server_test.go | 16 +- client/internal/dns/upstream.go | 36 ++- client/internal/dns/upstream_ios.go | 12 +- client/internal/dns/upstream_nonios.go | 12 +- client/internal/dns/upstream_test.go | 4 +- client/internal/engine.go | 13 +- client/internal/peer/status.go | 36 +++ client/internal/routemanager/client.go | 20 ++ client/internal/routemanager/manager.go | 2 +- .../internal/routemanager/server_android.go | 3 +- .../routemanager/server_nonandroid.go | 40 ++- client/proto/daemon.pb.go | 305 ++++++++++++------ client/proto/daemon.proto | 10 + client/server/server.go | 18 +- go.mod | 2 + go.sum | 4 + 20 files changed, 722 insertions(+), 179 deletions(-) diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index 8cb0eed74c3..85296484229 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -11,7 +11,7 @@ concurrency: cancel-in-progress: true jobs: - andrloid_build: + android_build: runs-on: ubuntu-latest steps: - name: Checkout repository @@ -41,7 +41,7 @@ jobs: run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda - name: gomobile init run: gomobile init - - name: build android nebtird lib + - name: build android netbird lib run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android env: CGO_ENABLED: 0 @@ -59,7 +59,7 @@ jobs: run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda - name: gomobile init run: gomobile init - - name: build iOS nebtird lib + - name: build iOS netbird lib run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK env: CGO_ENABLED: 0 \ No newline at end of file diff --git a/client/cmd/status.go b/client/cmd/status.go index fded7dff8fe..4c7218fde94 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -35,6 +35,7 @@ type peerStateDetailOutput struct { TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` TransferSent int64 `json:"transferSent" yaml:"transferSent"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` + Routes []string `json:"routes" yaml:"routes"` } type peersStateOutput struct { @@ -72,19 +73,28 @@ type iceCandidateType struct { Remote string `json:"remote" yaml:"remote"` } +type nsServerGroupStateOutput struct { + Servers []string `json:"servers" yaml:"servers"` + Domains []string `json:"domains" yaml:"domains"` + Enabled bool `json:"enabled" yaml:"enabled"` + Error string `json:"error" yaml:"error"` +} + type statusOutputOverview struct { - Peers peersStateOutput `json:"peers" yaml:"peers"` - CliVersion string `json:"cliVersion" yaml:"cliVersion"` - DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"` - ManagementState managementStateOutput `json:"management" yaml:"management"` - SignalState signalStateOutput `json:"signal" yaml:"signal"` - Relays relayStateOutput `json:"relays" yaml:"relays"` - IP string `json:"netbirdIp" yaml:"netbirdIp"` - PubKey string `json:"publicKey" yaml:"publicKey"` - KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"` - FQDN string `json:"fqdn" yaml:"fqdn"` - RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` - RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` + Peers peersStateOutput `json:"peers" yaml:"peers"` + CliVersion string `json:"cliVersion" yaml:"cliVersion"` + DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"` + ManagementState managementStateOutput `json:"management" yaml:"management"` + SignalState signalStateOutput `json:"signal" yaml:"signal"` + Relays relayStateOutput `json:"relays" yaml:"relays"` + IP string `json:"netbirdIp" yaml:"netbirdIp"` + PubKey string `json:"publicKey" yaml:"publicKey"` + KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"` + FQDN string `json:"fqdn" yaml:"fqdn"` + RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` + RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` + Routes []string `json:"routes" yaml:"routes"` + NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` } var ( @@ -168,7 +178,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { case yamlFlag: statusOutputString, err = parseToYAML(outputInformationHolder) default: - statusOutputString = parseGeneralSummary(outputInformationHolder, false, false) + statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false) } if err != nil { @@ -268,6 +278,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), + Routes: pbFullStatus.GetLocalPeerState().GetRoutes(), + NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), } return overview @@ -299,6 +311,19 @@ func mapRelays(relays []*proto.RelayState) relayStateOutput { } } +func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput { + mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers)) + for _, pbNsGroupServer := range servers { + mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{ + Servers: pbNsGroupServer.GetServers(), + Domains: pbNsGroupServer.GetDomains(), + Enabled: pbNsGroupServer.GetEnabled(), + Error: pbNsGroupServer.GetError(), + }) + } + return mappedNSGroups +} + func mapPeers(peers []*proto.PeerState) peersStateOutput { var peersStateDetail []peerStateDetailOutput localICE := "" @@ -352,6 +377,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { TransferReceived: transferReceived, TransferSent: transferSent, RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), + Routes: pbPeerState.GetRoutes(), } peersStateDetail = append(peersStateDetail, peerState) @@ -401,8 +427,7 @@ func parseToYAML(overview statusOutputOverview) (string, error) { return string(yamlBytes), nil } -func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool) string { - +func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string { var managementConnString string if overview.ManagementState.Connected { managementConnString = "Connected" @@ -438,7 +463,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays interfaceIP = "N/A" } - var relayAvailableString string + var relaysString string if showRelays { for _, relay := range overview.Relays.Details { available := "Available" @@ -447,15 +472,46 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays available = "Unavailable" reason = fmt.Sprintf(", reason: %s", relay.Error) } - relayAvailableString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) - + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { + relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) + } - relayAvailableString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) + routes := "-" + if len(overview.Routes) > 0 { + sort.Strings(overview.Routes) + routes = strings.Join(overview.Routes, ", ") } - peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) + var dnsServersString string + if showNameServers { + for _, nsServerGroup := range overview.NSServerGroups { + enabled := "Available" + if !nsServerGroup.Enabled { + enabled = "Unavailable" + } + errorString := "" + if nsServerGroup.Error != "" { + errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error) + errorString = strings.TrimSpace(errorString) + } + + domainsString := strings.Join(nsServerGroup.Domains, ", ") + if domainsString == "" { + domainsString = "." // Show "." for the default zone + } + dnsServersString += fmt.Sprintf( + "\n [%s] for [%s] is %s%s", + strings.Join(nsServerGroup.Servers, ", "), + domainsString, + enabled, + errorString, + ) + } + } else { + dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups)) + } rosenpassEnabledStatus := "false" if overview.RosenpassEnabled { @@ -465,26 +521,32 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays } } + peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) + summary := fmt.Sprintf( "Daemon version: %s\n"+ "CLI version: %s\n"+ "Management: %s\n"+ "Signal: %s\n"+ "Relays: %s\n"+ + "Nameservers: %s\n"+ "FQDN: %s\n"+ "NetBird IP: %s\n"+ "Interface type: %s\n"+ "Quantum resistance: %s\n"+ + "Routes: %s\n"+ "Peers count: %s\n", overview.DaemonVersion, version.NetbirdVersion(), managementConnString, signalConnString, - relayAvailableString, + relaysString, + dnsServersString, overview.FQDN, interfaceIP, interfaceTypeString, rosenpassEnabledStatus, + routes, peersCountString, ) return summary @@ -492,7 +554,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays func parseToFullDetailSummary(overview statusOutputOverview) string { parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) - summary := parseGeneralSummary(overview, true, true) + summary := parseGeneralSummary(overview, true, true, true) return fmt.Sprintf( "Peers detail:"+ @@ -556,6 +618,12 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo } } + routes := "-" + if len(peerState.Routes) > 0 { + sort.Strings(peerState.Routes) + routes = strings.Join(peerState.Routes, ", ") + } + peerString := fmt.Sprintf( "\n %s:\n"+ " NetBird IP: %s\n"+ @@ -569,7 +637,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Last connection update: %s\n"+ " Last WireGuard handshake: %s\n"+ " Transfer status (received/sent) %s/%s\n"+ - " Quantum resistance: %s\n", + " Quantum resistance: %s\n"+ + " Routes: %s\n", peerState.FQDN, peerState.IP, peerState.PubKey, @@ -585,6 +654,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo toIEC(peerState.TransferReceived), toIEC(peerState.TransferSent), rosenpassEnabledStatus, + routes, ) peersString += peerString @@ -638,3 +708,13 @@ func toIEC(b int64) string { return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp]) } + +func countEnabled(dnsServers []nsServerGroupStateOutput) int { + count := 0 + for _, server := range dnsServers { + if server.Enabled { + count++ + } + } + return count +} diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index b5db576e47b..ea6980c3df7 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -42,6 +42,9 @@ var resp = &proto.StatusResponse{ LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), BytesRx: 200, BytesTx: 100, + Routes: []string{ + "10.1.0.0/24", + }, }, { IP: "192.168.178.102", @@ -87,6 +90,31 @@ var resp = &proto.StatusResponse{ PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", + Routes: []string{ + "10.10.0.0/24", + }, + }, + DnsServers: []*proto.NSGroupState{ + { + Servers: []string{ + "8.8.8.8:53", + }, + Domains: nil, + Enabled: true, + Error: "", + }, + { + Servers: []string{ + "1.1.1.1:53", + "2.2.2.2:53", + }, + Domains: []string{ + "example.com", + "example.net", + }, + Enabled: false, + Error: "timeout", + }, }, }, DaemonVersion: "0.14.1", @@ -116,6 +144,9 @@ var overview = statusOutputOverview{ LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC), TransferReceived: 200, TransferSent: 100, + Routes: []string{ + "10.1.0.0/24", + }, }, { IP: "192.168.178.102", @@ -171,6 +202,31 @@ var overview = statusOutputOverview{ PubKey: "Some-Pub-Key", KernelInterface: true, FQDN: "some-localhost.awesome-domain.com", + NSServerGroups: []nsServerGroupStateOutput{ + { + Servers: []string{ + "8.8.8.8:53", + }, + Domains: nil, + Enabled: true, + Error: "", + }, + { + Servers: []string{ + "1.1.1.1:53", + "2.2.2.2:53", + }, + Domains: []string{ + "example.com", + "example.net", + }, + Enabled: false, + Error: "timeout", + }, + }, + Routes: []string{ + "10.10.0.0/24", + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -232,7 +288,10 @@ func TestParsingToJSON(t *testing.T) { "lastWireguardHandshake": "2001-01-01T01:01:02Z", "transferReceived": 200, "transferSent": 100, - "quantumResistance":false + "quantumResistance": false, + "routes": [ + "10.1.0.0/24" + ] }, { "fqdn": "peer-2.awesome-domain.com", @@ -253,7 +312,8 @@ func TestParsingToJSON(t *testing.T) { "lastWireguardHandshake": "2002-02-02T02:02:03Z", "transferReceived": 2000, "transferSent": 1000, - "quantumResistance":false + "quantumResistance": false, + "routes": null } ] }, @@ -289,8 +349,33 @@ func TestParsingToJSON(t *testing.T) { "publicKey": "Some-Pub-Key", "usesKernelInterface": true, "fqdn": "some-localhost.awesome-domain.com", - "quantumResistance":false, - "quantumResistancePermissive":false + "quantumResistance": false, + "quantumResistancePermissive": false, + "routes": [ + "10.10.0.0/24" + ], + "dnsServers": [ + { + "servers": [ + "8.8.8.8:53" + ], + "domains": null, + "enabled": true, + "error": "" + }, + { + "servers": [ + "1.1.1.1:53", + "2.2.2.2:53" + ], + "domains": [ + "example.com", + "example.net" + ], + "enabled": false, + "error": "timeout" + } + ] }` // @formatter:on @@ -325,6 +410,8 @@ func TestParsingToYAML(t *testing.T) { transferReceived: 200 transferSent: 100 quantumResistance: false + routes: + - 10.1.0.0/24 - fqdn: peer-2.awesome-domain.com netbirdIp: 192.168.178.102 publicKey: Pubkey2 @@ -342,6 +429,7 @@ func TestParsingToYAML(t *testing.T) { transferReceived: 2000 transferSent: 1000 quantumResistance: false + routes: [] cliVersion: development daemonVersion: 0.14.1 management: @@ -368,6 +456,22 @@ usesKernelInterface: true fqdn: some-localhost.awesome-domain.com quantumResistance: false quantumResistancePermissive: false +routes: + - 10.10.0.0/24 +dnsServers: + - servers: + - 8.8.8.8:53 + domains: [] + enabled: true + error: "" + - servers: + - 1.1.1.1:53 + - 2.2.2.2:53 + domains: + - example.com + - example.net + enabled: false + error: timeout ` assert.Equal(t, expectedYAML, yaml) @@ -391,6 +495,7 @@ func TestParsingToDetail(t *testing.T) { Last WireGuard handshake: 2001-01-01 01:01:02 Transfer status (received/sent) 200 B/100 B Quantum resistance: false + Routes: 10.1.0.0/24 peer-2.awesome-domain.com: NetBird IP: 192.168.178.102 @@ -405,6 +510,7 @@ func TestParsingToDetail(t *testing.T) { Last WireGuard handshake: 2002-02-02 02:02:03 Transfer status (received/sent) 2.0 KiB/1000 B Quantum resistance: false + Routes: - Daemon version: 0.14.1 CLI version: development @@ -413,10 +519,14 @@ Signal: Connected to my-awesome-signal.com:443 Relays: [stun:my-awesome-stun.com:3478] is Available [turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded +Nameservers: + [8.8.8.8:53] for [.] is Available + [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout FQDN: some-localhost.awesome-domain.com NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false +Routes: 10.10.0.0/24 Peers count: 2/2 Connected ` @@ -424,7 +534,7 @@ Peers count: 2/2 Connected } func TestParsingToShortVersion(t *testing.T) { - shortVersion := parseGeneralSummary(overview, false, false) + shortVersion := parseGeneralSummary(overview, false, false, false) expectedString := `Daemon version: 0.14.1 @@ -432,10 +542,12 @@ CLI version: development Management: Connected Signal: Connected Relays: 1/2 Available +Nameservers: 1/2 Available FQDN: some-localhost.awesome-domain.com NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false +Routes: 10.10.0.0/24 Peers count: 2/2 Connected ` diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 9986f632ec5..dff44f01d1d 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "strings" "sync" "github.com/miekg/dns" @@ -11,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" ) @@ -59,6 +61,8 @@ type DefaultServer struct { // make sense on mobile only searchDomainNotifier *notifier iosDnsManager IosDnsManager + + statusRecorder *peer.Status } type handlerWithStop interface { @@ -73,7 +77,12 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) { +func NewDefaultServer( + ctx context.Context, + wgInterface WGIface, + customAddress string, + statusRecorder *peer.Status, +) (*DefaultServer, error) { var addrPort *netip.AddrPort if customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -90,13 +99,20 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems -func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer { +func NewDefaultServerPermanentUpstream( + ctx context.Context, + wgInterface WGIface, + hostsDnsList []string, + config nbdns.Config, + listener listener.NetworkChangeListener, + statusRecorder *peer.Status, +) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface)) + ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds.permanent = true ds.hostsDnsList = hostsDnsList ds.addHostRootZone() @@ -108,13 +124,18 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, } // NewDefaultServerIos returns a new dns server. It optimized for ios -func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface)) +func NewDefaultServerIos( + ctx context.Context, + wgInterface WGIface, + iosDnsManager IosDnsManager, + statusRecorder *peer.Status, +) *DefaultServer { + ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer { +func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, @@ -124,7 +145,8 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi localResolver: &localResolver{ registeredMap: make(registrationMap), }, - wgInterface: wgInterface, + wgInterface: wgInterface, + statusRecorder: statusRecorder, } return defaultServer @@ -299,6 +321,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) } + s.updateNSGroupStates(update.NameServerGroups) + return nil } @@ -338,7 +362,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam continue } - handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network) + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface.Name(), + s.wgInterface.Address().IP, + s.wgInterface.Address().Network, + s.statusRecorder, + ) if err != nil { return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) } @@ -460,14 +490,14 @@ func getNSHostPort(ns nbdns.NameServer) string { func (s *DefaultServer) upstreamCallbacks( nsGroup *nbdns.NameServerGroup, handler dns.Handler, -) (deactivate func(), reactivate func()) { +) (deactivate func(error), reactivate func()) { var removeIndex map[string]int - deactivate = func() { + deactivate = func(err error) { s.mux.Lock() defer s.mux.Unlock() l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("temporary deactivate nameservers group due timeout") + l.Info("Temporarily deactivating nameservers group due to timeout") removeIndex = make(map[string]int) for _, domain := range nsGroup.Domains { @@ -486,8 +516,11 @@ func (s *DefaultServer) upstreamCallbacks( } } if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("fail to apply nameserver deactivation on the host") + l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } + + s.updateNSState(nsGroup, err, false) + } reactivate = func() { s.mux.Lock() @@ -510,12 +543,20 @@ func (s *DefaultServer) upstreamCallbacks( if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } + + s.updateNSState(nsGroup, nil, true) } return } func (s *DefaultServer) addHostRootZone() { - handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network) + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface.Name(), + s.wgInterface.Address().IP, + s.wgInterface.Address().Network, + s.statusRecorder, + ) if err != nil { log.Errorf("unable to create a new upstream resolver, error: %v", err) return @@ -535,7 +576,50 @@ func (s *DefaultServer) addHostRootZone() { handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString) } - handler.deactivate = func() {} + handler.deactivate = func(error) {} handler.reactivate = func() {} s.service.RegisterMux(nbdns.RootZone, handler) } + +func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { + var states []peer.NSGroupState + + for _, group := range groups { + var servers []string + for _, ns := range group.NameServers { + servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) + } + + state := peer.NSGroupState{ + ID: generateGroupKey(group), + Servers: servers, + Domains: group.Domains, + // The probe will determine the state, default enabled + Enabled: true, + Error: nil, + } + states = append(states, state) + } + s.statusRecorder.UpdateDNSStates(states) +} + +func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) { + states := s.statusRecorder.GetDNSStates() + id := generateGroupKey(nsGroup) + for i, state := range states { + if state.ID == id { + states[i].Enabled = enabled + states[i].Error = err + break + } + } + s.statusRecorder.UpdateDNSStates(states) +} + +func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { + var servers []string + for _, ns := range nsGroup.NameServers { + servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) + } + return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ",")) +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 68c4992d816..f3282f1f49c 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -15,6 +15,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" @@ -274,7 +275,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "") + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) if err != nil { t.Fatal(err) } @@ -375,7 +376,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "") + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -470,7 +471,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) if err != nil { t.Fatalf("%v", err) } @@ -541,6 +542,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { {false, "domain2", false}, }, }, + statusRecorder: &peer.Status{}, } var domainsUpdate string @@ -563,7 +565,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { }, }, nil) - deactivate() + deactivate(nil) expected := "domain0,domain2" domains := []string{} for _, item := range server.currentConfig.Domains { @@ -601,7 +603,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { var dnsList []string dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -625,7 +627,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -717,7 +719,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 9fd524700a5..cc31559fab1 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -11,8 +11,11 @@ import ( "time" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" ) const ( @@ -45,12 +48,13 @@ type upstreamResolverBase struct { reactivatePeriod time.Duration upstreamTimeout time.Duration - deactivate func() - reactivate func() + deactivate func(error) + reactivate func() + statusRecorder *peer.Status } -func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase { - ctx, cancel := context.WithCancel(parentCTX) +func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase { + ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ ctx: ctx, @@ -58,6 +62,7 @@ func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase { upstreamTimeout: upstreamTimeout, reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, + statusRecorder: statusRecorder, } } @@ -68,7 +73,10 @@ func (u *upstreamResolverBase) stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - defer u.checkUpstreamFails() + var err error + defer func() { + u.checkUpstreamFails(err) + }() log.WithField("question", r.Question[0]).Trace("received an upstream question") @@ -81,7 +89,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { for _, upstream := range u.upstreamServers { var rm *dns.Msg var t time.Duration - var err error func() { ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) @@ -132,7 +139,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If fails count is greater that failsTillDeact, upstream resolving // will be disabled for reactivatePeriod, after that time period fails counter // will be reset and upstream will be reactivated. -func (u *upstreamResolverBase) checkUpstreamFails() { +func (u *upstreamResolverBase) checkUpstreamFails(err error) { u.mutex.Lock() defer u.mutex.Unlock() @@ -146,7 +153,7 @@ func (u *upstreamResolverBase) checkUpstreamFails() { default: } - u.disable() + u.disable(err) } // probeAvailability tests all upstream servers simultaneously and @@ -165,13 +172,16 @@ func (u *upstreamResolverBase) probeAvailability() { var mu sync.Mutex var wg sync.WaitGroup + var errors *multierror.Error for _, upstream := range u.upstreamServers { upstream := upstream wg.Add(1) go func() { defer wg.Done() - if err := u.testNameserver(upstream); err != nil { + err := u.testNameserver(upstream) + if err != nil { + errors = multierror.Append(errors, err) log.Warnf("probing upstream nameserver %s: %s", upstream, err) return } @@ -186,7 +196,7 @@ func (u *upstreamResolverBase) probeAvailability() { // didn't find a working upstream server, let's disable and try later if !success { - u.disable() + u.disable(errors.ErrorOrNil()) } } @@ -245,15 +255,15 @@ func isTimeout(err error) bool { return false } -func (u *upstreamResolverBase) disable() { +func (u *upstreamResolverBase) disable(err error) { if u.disabled { return } // todo test the deactivation logic, it seems to affect the client if runtime.GOOS != "ios" { - log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod) - u.deactivate() + log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) + u.deactivate(err) u.disabled = true go u.waitUntilResponse() } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 33937d8d88d..c9d3bb942b4 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -11,6 +11,8 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/peer" ) type upstreamResolverIOS struct { @@ -20,8 +22,14 @@ type upstreamResolverIOS struct { iIndex int } -func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(parentCTX) +func newUpstreamResolver( + ctx context.Context, + interfaceName string, + ip net.IP, + net *net.IPNet, + statusRecorder *peer.Status, +) (*upstreamResolverIOS, error) { + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) index, err := getInterfaceIndex(interfaceName) if err != nil { diff --git a/client/internal/dns/upstream_nonios.go b/client/internal/dns/upstream_nonios.go index 93e523c4ef3..22bd24ca989 100644 --- a/client/internal/dns/upstream_nonios.go +++ b/client/internal/dns/upstream_nonios.go @@ -8,14 +8,22 @@ import ( "time" "github.com/miekg/dns" + + "github.com/netbirdio/netbird/client/internal/peer" ) type upstreamResolverNonIOS struct { *upstreamResolverBase } -func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(parentCTX) +func newUpstreamResolver( + ctx context.Context, + _ string, + _ net.IP, + _ *net.IPNet, + statusRecorder *peer.Status, +) (*upstreamResolverNonIOS, error) { + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) nonIOS := &upstreamResolverNonIOS{ upstreamResolverBase: upstreamResolverBase, } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 13610df4188..77851dd9d64 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}) + resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil) resolver.upstreamServers = testCase.InputServers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { @@ -131,7 +131,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { } failed := false - resolver.deactivate = func() { + resolver.deactivate = func(error) { failed = true } diff --git a/client/internal/engine.go b/client/internal/engine.go index e4f0f236d4c..9e52cea4446 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1188,14 +1188,21 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { if err != nil { return nil, nil, err } - dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener) + dnsServer := dns.NewDefaultServerPermanentUpstream( + e.ctx, + e.wgInterface, + e.mobileDep.HostDNSAddresses, + *dnsConfig, + e.mobileDep.NetworkChangeListener, + e.statusRecorder, + ) go e.mobileDep.DnsReadyListener.OnReady() return routes, dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) return nil, dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) if err != nil { return nil, nil, err } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 87338f646d7..1e252c5dd48 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -29,6 +29,7 @@ type State struct { BytesTx int64 BytesRx int64 RosenpassEnabled bool + Routes map[string]struct{} } // LocalPeerState contains the latest state of the local peer @@ -37,6 +38,7 @@ type LocalPeerState struct { PubKey string KernelInterface bool FQDN string + Routes map[string]struct{} } // SignalState contains the latest state of a signal connection @@ -59,6 +61,16 @@ type RosenpassState struct { Permissive bool } +// NSGroupState represents the status of a DNS server group, including associated domains, +// whether it's enabled, and the last error message encountered during probing. +type NSGroupState struct { + ID string + Servers []string + Domains []string + Enabled bool + Error error +} + // FullStatus contains the full state held by the Status instance type FullStatus struct { Peers []State @@ -67,6 +79,7 @@ type FullStatus struct { LocalPeerState LocalPeerState RosenpassState RosenpassState Relays []relay.ProbeResult + NSGroupStates []NSGroupState } // Status holds a state of peers, signal, management connections and relays @@ -86,6 +99,7 @@ type Status struct { notifier *notifier rosenpassEnabled bool rosenpassPermissive bool + nsGroupStates []NSGroupState // To reduce the number of notification invocation this bool will be true when need to call the notification // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events @@ -174,6 +188,10 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } + if receivedState.Routes != nil { + peerState.Routes = receivedState.Routes + } + skipNotification := shouldSkipNotify(receivedState, peerState) if receivedState.ConnStatus != peerState.ConnStatus { @@ -278,6 +296,13 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { return ch } +// GetLocalPeerState returns the local peer state +func (d *Status) GetLocalPeerState() LocalPeerState { + d.mux.Lock() + defer d.mux.Unlock() + return d.localPeer +} + // UpdateLocalPeerState updates local peer status func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.mux.Lock() @@ -364,6 +389,12 @@ func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { d.relayStates = relayResults } +func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { + d.mux.Lock() + defer d.mux.Unlock() + d.nsGroupStates = dnsStates +} + func (d *Status) GetRosenpassState() RosenpassState { return RosenpassState{ d.rosenpassEnabled, @@ -409,6 +440,10 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { return d.relayStates } +func (d *Status) GetDNSStates() []NSGroupState { + return d.nsGroupStates +} + // GetFullStatus gets full status func (d *Status) GetFullStatus() FullStatus { d.mux.Lock() @@ -420,6 +455,7 @@ func (d *Status) GetFullStatus() FullStatus { LocalPeerState: d.localPeer, Relays: d.GetRelayStates(), RosenpassState: d.GetRosenpassState(), + NSGroupStates: d.GetDNSStates(), } for _, status := range d.peers { diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index ee98d503de5..f7ead582720 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -160,6 +160,12 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { if err != nil { return err } + + delete(state.Routes, c.network.String()) + if err := c.statusRecorder.UpdatePeerState(state); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } + if state.ConnStatus != peer.StatusConnected { return nil } @@ -225,6 +231,20 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } c.chosenRoute = c.routes[chosen] + + state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer) + if err != nil { + log.Errorf("Failed to get peer state: %v", err) + } else { + if state.Routes == nil { + state.Routes = map[string]struct{}{} + } + state.Routes[c.network.String()] = struct{}{} + if err := c.statusRecorder.UpdatePeerState(state); err != nil { + log.Warnf("Failed to update peer state: %v", err) + } + } + err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) if err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e8a4bd1341f..fde943757f2 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -58,7 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error - m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall) + m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) if err != nil { return err } diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 1918c7f6f12..b4065bca69d 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -7,9 +7,10 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" ) -func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) { +func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 20e500e7944..19236787772 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -10,24 +10,27 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) type defaultServerRouter struct { - mux sync.Mutex - ctx context.Context - routes map[string]*route.Route - firewall firewall.Manager - wgInterface *iface.WGIface + mux sync.Mutex + ctx context.Context + routes map[string]*route.Route + firewall firewall.Manager + wgInterface *iface.WGIface + statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) { +func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { return &defaultServerRouter{ - ctx: ctx, - routes: make(map[string]*route.Route), - firewall: firewall, - wgInterface: wgInterface, + ctx: ctx, + routes: make(map[string]*route.Route), + firewall: firewall, + wgInterface: wgInterface, + statusRecorder: statusRecorder, }, nil } @@ -88,6 +91,11 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error return err } delete(m.routes, route.ID) + + state := m.statusRecorder.GetLocalPeerState() + delete(state.Routes, route.Network.String()) + m.statusRecorder.UpdateLocalPeerState(state) + return nil } } @@ -105,6 +113,14 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { return err } m.routes[route.ID] = route + + state := m.statusRecorder.GetLocalPeerState() + if state.Routes == nil { + state.Routes = map[string]struct{}{} + } + state.Routes[route.Network.String()] = struct{}{} + m.statusRecorder.UpdateLocalPeerState(state) + return nil } } @@ -117,6 +133,10 @@ func (m *defaultServerRouter) cleanUp() { if err != nil { log.Warnf("failed to remove clean up route: %s", r.ID) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index a1c3aef1164..869eceee550 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -772,6 +772,7 @@ type PeerState struct { BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` } func (x *PeerState) Reset() { @@ -911,18 +912,26 @@ func (x *PeerState) GetRosenpassEnabled() bool { return false } +func (x *PeerState) GetRoutes() []string { + if x != nil { + return x.Routes + } + return nil +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` - PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` - KernelInterface bool `protobuf:"varint,3,opt,name=kernelInterface,proto3" json:"kernelInterface,omitempty"` - Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` - RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` + PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` + KernelInterface bool `protobuf:"varint,3,opt,name=kernelInterface,proto3" json:"kernelInterface,omitempty"` + Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + Routes []string `protobuf:"bytes,7,rep,name=routes,proto3" json:"routes,omitempty"` } func (x *LocalPeerState) Reset() { @@ -999,6 +1008,13 @@ func (x *LocalPeerState) GetRosenpassPermissive() bool { return false } +func (x *LocalPeerState) GetRoutes() []string { + if x != nil { + return x.Routes + } + return nil +} + // SignalState contains the latest state of a signal connection type SignalState struct { state protoimpl.MessageState @@ -1191,6 +1207,77 @@ func (x *RelayState) GetError() string { return "" } +type NSGroupState struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Servers []string `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"` + Domains []string `protobuf:"bytes,2,rep,name=domains,proto3" json:"domains,omitempty"` + Enabled bool `protobuf:"varint,3,opt,name=enabled,proto3" json:"enabled,omitempty"` + Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` +} + +func (x *NSGroupState) Reset() { + *x = NSGroupState{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *NSGroupState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NSGroupState) ProtoMessage() {} + +func (x *NSGroupState) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[17] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead. +func (*NSGroupState) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{17} +} + +func (x *NSGroupState) GetServers() []string { + if x != nil { + return x.Servers + } + return nil +} + +func (x *NSGroupState) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *NSGroupState) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +func (x *NSGroupState) GetError() string { + if x != nil { + return x.Error + } + return "" +} + // FullStatus contains the full state held by the Status instance type FullStatus struct { state protoimpl.MessageState @@ -1202,12 +1289,13 @@ type FullStatus struct { LocalPeerState *LocalPeerState `protobuf:"bytes,3,opt,name=localPeerState,proto3" json:"localPeerState,omitempty"` Peers []*PeerState `protobuf:"bytes,4,rep,name=peers,proto3" json:"peers,omitempty"` Relays []*RelayState `protobuf:"bytes,5,rep,name=relays,proto3" json:"relays,omitempty"` + DnsServers []*NSGroupState `protobuf:"bytes,6,rep,name=dns_servers,json=dnsServers,proto3" json:"dns_servers,omitempty"` } func (x *FullStatus) Reset() { *x = FullStatus{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1220,7 +1308,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1233,7 +1321,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{17} + return file_daemon_proto_rawDescGZIP(), []int{18} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -1271,6 +1359,13 @@ func (x *FullStatus) GetRelays() []*RelayState { return nil } +func (x *FullStatus) GetDnsServers() []*NSGroupState { + if x != nil { + return x.DnsServers + } + return nil +} + var File_daemon_proto protoreflect.FileDescriptor var file_daemon_proto_rawDesc = []byte{ @@ -1380,7 +1475,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x22, 0x81, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, + 0x22, 0x99, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, @@ -1420,20 +1515,23 @@ var file_daemon_proto_rawDesc = []byte{ 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x22, 0xd4, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, - 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, - 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, - 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, - 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, + 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0xec, 0x01, 0x0a, + 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, + 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, + 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, - 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0x53, 0x0a, 0x0b, 0x53, + 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, + 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, + 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, @@ -1449,50 +1547,61 @@ var file_daemon_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x9b, 0x02, - 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, - 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, - 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, - 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, - 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, - 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, - 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, - 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, - 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, + 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, + 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, + 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, + 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, + 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, + 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, + 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, + 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, + 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, + 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, + 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, + 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, + 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, + 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( @@ -1507,7 +1616,7 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 18) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19) var file_daemon_proto_goTypes = []interface{}{ (*LoginRequest)(nil), // 0: daemon.LoginRequest (*LoginResponse)(nil), // 1: daemon.LoginResponse @@ -1526,35 +1635,37 @@ var file_daemon_proto_goTypes = []interface{}{ (*SignalState)(nil), // 14: daemon.SignalState (*ManagementState)(nil), // 15: daemon.ManagementState (*RelayState)(nil), // 16: daemon.RelayState - (*FullStatus)(nil), // 17: daemon.FullStatus - (*timestamp.Timestamp)(nil), // 18: google.protobuf.Timestamp + (*NSGroupState)(nil), // 17: daemon.NSGroupState + (*FullStatus)(nil), // 18: daemon.FullStatus + (*timestamp.Timestamp)(nil), // 19: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 17, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 18, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 18, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 15, // 3: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 14, // 4: daemon.FullStatus.signalState:type_name -> daemon.SignalState 13, // 5: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 12, // 6: daemon.FullStatus.peers:type_name -> daemon.PeerState 16, // 7: daemon.FullStatus.relays:type_name -> daemon.RelayState - 0, // 8: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 2, // 9: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 4, // 10: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 6, // 11: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 8, // 12: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 10, // 13: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 1, // 14: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 3, // 15: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 5, // 16: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 7, // 17: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 9, // 18: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 11, // 19: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 14, // [14:20] is the sub-list for method output_type - 8, // [8:14] is the sub-list for method input_type - 8, // [8:8] is the sub-list for extension type_name - 8, // [8:8] is the sub-list for extension extendee - 0, // [0:8] is the sub-list for field type_name + 17, // 8: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 0, // 9: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 2, // 10: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 4, // 11: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 6, // 12: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 8, // 13: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 10, // 14: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 1, // 15: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 3, // 16: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 5, // 17: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 7, // 18: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 9, // 19: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 11, // 20: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 15, // [15:21] is the sub-list for method output_type + 9, // [9:15] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -1768,6 +1879,18 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*NSGroupState); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*FullStatus); i { case 0: return &v.state @@ -1787,7 +1910,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 0, - NumMessages: 18, + NumMessages: 19, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 2858ba2e050..bdb1cb83eea 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -141,6 +141,7 @@ message PeerState { int64 bytesRx = 13; int64 bytesTx = 14; bool rosenpassEnabled = 15; + repeated string routes = 16; } // LocalPeerState contains the latest state of the local peer @@ -151,6 +152,7 @@ message LocalPeerState { string fqdn = 4; bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; + repeated string routes = 7; } // SignalState contains the latest state of a signal connection @@ -174,6 +176,13 @@ message RelayState { string error = 3; } +message NSGroupState { + repeated string servers = 1; + repeated string domains = 2; + bool enabled = 3; + string error = 4; +} + // FullStatus contains the full state held by the Status instance message FullStatus { ManagementState managementState = 1; @@ -181,4 +190,5 @@ message FullStatus { LocalPeerState localPeerState = 3; repeated PeerState peers = 4; repeated RelayState relays = 5; + repeated NSGroupState dns_servers = 6; } \ No newline at end of file diff --git a/client/server/server.go b/client/server/server.go index 90b5bcb642c..5f1bf0100a4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cenkalti/backoff/v4" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/system" @@ -670,7 +671,6 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { SignalState: &proto.SignalState{}, LocalPeerState: &proto.LocalPeerState{}, Peers: []*proto.PeerState{}, - Relays: []*proto.RelayState{}, } pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL @@ -691,6 +691,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled + pbFullStatus.LocalPeerState.Routes = maps.Keys(fullStatus.LocalPeerState.Routes) for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ @@ -709,6 +710,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, + Routes: maps.Keys(peerState.Routes), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) } @@ -724,6 +726,20 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) } + for _, dnsState := range fullStatus.NSGroupStates { + var err string + if dnsState.Error != nil { + err = dnsState.Error.Error() + } + pbDnsState := &proto.NSGroupState{ + Servers: dnsState.Servers, + Domains: dnsState.Domains, + Enabled: dnsState.Enabled, + Error: err, + } + pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState) + } + return &pbFullStatus } diff --git a/go.mod b/go.mod index d435e4eb8c3..6aba599f810 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 + github.com/hashicorp/go-multierror v1.1.0 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 github.com/libp2p/go-netroute v0.2.0 @@ -123,6 +124,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect github.com/gopacket/gopacket v1.1.1 // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/go.sum b/go.sum index cc7a52ed636..ca10cd55367 100644 --- a/go.sum +++ b/go.sum @@ -289,6 +289,10 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= +github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= From 042141db061a812baa8e12d93901ca6c948945dc Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 14 Mar 2024 14:17:22 +0100 Subject: [PATCH 04/89] Update account attributes only when there is a domain (#1701) add log for when a domain is not present --- management/server/account.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 9450c95b47c..f307ef27429 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1361,16 +1361,21 @@ func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) e func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { - account.IsDomainPrimaryAccount = primaryDomain - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory + if claims.Domain != "" { + account.IsDomainPrimaryAccount = primaryDomain + + lowerDomain := strings.ToLower(claims.Domain) + userObj := account.Users[claims.UserId] + if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { + account.Domain = lowerDomain + } + // prevent updating category for different domain until admin logs in + if account.Domain == lowerDomain { + account.DomainCategory = claims.DomainCategory + } + } else { + log.Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) } err := am.Store.SaveAccount(account) From 0b3b50c7051c27c3d8edcfe98d3acc7e331e5b67 Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Thu, 14 Mar 2024 21:31:21 +0100 Subject: [PATCH 05/89] Remove deprecated Rules API endpoints (#1523) --- management/server/account.go | 3 - management/server/file_store_test.go | 13 - management/server/http/api/openapi.yml | 210 ------------ management/server/http/api/types.gen.go | 66 ---- management/server/http/handler.go | 10 - .../server/http/policies_handler_test.go | 19 -- management/server/http/rules_handler.go | 305 ------------------ management/server/http/rules_handler_test.go | 265 --------------- management/server/mock_server/account_mock.go | 27 -- management/server/policy.go | 24 +- management/server/rule.go | 100 ------ management/server/sqlite_store.go | 2 +- 12 files changed, 12 insertions(+), 1032 deletions(-) delete mode 100644 management/server/http/rules_handler.go delete mode 100644 management/server/http/rules_handler_test.go delete mode 100644 management/server/rule.go diff --git a/management/server/account.go b/management/server/account.go index f307ef27429..a800ea97d32 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -227,9 +227,6 @@ type Account struct { PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` - // deprecated on store and api level - Rules map[string]*Rule `json:"-" gorm:"-"` - RulesG []Rule `json:"-" gorm:"-"` } type UserInfo struct { diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index e0868fb49b6..d8575a3bfed 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -258,18 +258,6 @@ func TestStore(t *testing.T) { t.Errorf("failed to restore a FileStore file - missing Group all") } - if restoredAccount.Rules["all"] == nil { - t.Errorf("failed to restore a FileStore file - missing Rule all") - return - } - - if restoredAccount.Rules["dmz"] == nil { - t.Errorf("failed to restore a FileStore file - missing Rule dmz") - return - } - assert.Equal(t, account.Rules["all"], restoredAccount.Rules["all"], "failed to restore a FileStore file - missing Rule all") - assert.Equal(t, account.Rules["dmz"], restoredAccount.Rules["dmz"], "failed to restore a FileStore file - missing Rule dmz") - if len(restoredAccount.Policies) != 2 { t.Errorf("failed to restore a FileStore file - missing Policies") return @@ -411,7 +399,6 @@ func TestFileStore_GetAccount(t *testing.T) { assert.Len(t, account.Peers, len(expected.Peers)) assert.Len(t, account.Users, len(expected.Users)) assert.Len(t, account.SetupKeys, len(expected.SetupKeys)) - assert.Len(t, account.Rules, len(expected.Rules)) assert.Len(t, account.Routes, len(expected.Routes)) assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index b2ddfd5ccc9..f4c8910bc7a 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -17,8 +17,6 @@ tags: description: Interact with and view information about setup keys. - name: Groups description: Interact with and view information about groups. - - name: Rules - description: Interact with and view information about rules. - name: Policies description: Interact with and view information about policies. - name: Posture Checks @@ -621,73 +619,6 @@ components: $ref: '#/components/schemas/PeerMinimum' required: - peers - RuleMinimum: - type: object - properties: - name: - description: Rule name identifier - type: string - example: Default - description: - description: Rule friendly description - type: string - example: This is a default rule that allows connections between all the resources - disabled: - description: Rules status - type: boolean - example: false - flow: - description: Rule flow, currently, only "bidirect" for bi-directional traffic is accepted - type: string - example: bidirect - required: - - name - - description - - disabled - - flow - RuleRequest: - allOf: - - $ref: '#/components/schemas/RuleMinimum' - - type: object - properties: - sources: - type: array - description: List of source group IDs - items: - type: string - example: "ch8i4ug6lnn4g9hqv7m1" - destinations: - type: array - description: List of destination group IDs - items: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" - Rule: - allOf: - - type: object - properties: - id: - description: Rule ID - type: string - example: ch8i4ug6lnn4g9hqv7mg - required: - - id - - $ref: '#/components/schemas/RuleMinimum' - - type: object - properties: - sources: - description: Rule source group IDs - type: array - items: - $ref: '#/components/schemas/GroupMinimum' - destinations: - description: Rule destination group IDs - type: array - items: - $ref: '#/components/schemas/GroupMinimum' - required: - - sources - - destinations PolicyRuleMinimum: type: object properties: @@ -2035,147 +1966,6 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" - /api/rules: - get: - summary: List all Rules - description: Returns a list of all rules. This will be deprecated in favour of `/api/policies`. - tags: [ Rules ] - deprecated: true - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - responses: - '200': - description: A JSON Array of Rules - content: - application/json: - schema: - type: array - items: - $ref: '#/components/schemas/Rule' - '400': - "$ref": "#/components/responses/bad_request" - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - post: - summary: Create a Rule - description: Creates a rule. This will be deprecated in favour of `/api/policies`. - deprecated: true - tags: [ Rules ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - requestBody: - description: New Rule request - content: - 'application/json': - schema: - $ref: '#/components/schemas/RuleRequest' - responses: - '200': - description: A Rule Object - content: - application/json: - schema: - $ref: '#/components/schemas/Rule' - /api/rules/{ruleId}: - get: - summary: Retrieve a Rule - description: Get information about a rules. This will be deprecated in favour of `/api/policies/{policyID}`. - deprecated: true - tags: [ Rules ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - parameters: - - in: path - name: ruleId - required: true - schema: - type: string - description: The unique identifier of a rule - responses: - '200': - description: A Rule object - content: - application/json: - schema: - $ref: '#/components/schemas/Rule' - '400': - "$ref": "#/components/responses/bad_request" - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - put: - summary: Update a Rule - description: Update/Replace a rule. This will be deprecated in favour of `/api/policies/{policyID}`. - deprecated: true - tags: [ Rules ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - parameters: - - in: path - name: ruleId - required: true - schema: - type: string - description: The unique identifier of a rule - requestBody: - description: Update Rule request - content: - 'application/json': - schema: - $ref: '#/components/schemas/RuleRequest' - responses: - '200': - description: A Rule object - content: - application/json: - schema: - $ref: '#/components/schemas/Rule' - '400': - "$ref": "#/components/responses/bad_request" - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - delete: - summary: Delete a Rule - description: Delete a rule. This will be deprecated in favour of `/api/policies/{policyID}`. - deprecated: true - tags: [ Rules ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - parameters: - - in: path - name: ruleId - required: true - schema: - type: string - description: The unique identifier of a rule - responses: - '200': - description: Delete status code - content: { } - '400': - "$ref": "#/components/responses/bad_request" - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" /api/policies: get: summary: List all Policies diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index c007663a476..a4c492bb870 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -976,66 +976,6 @@ type RouteRequest struct { PeerGroups *[]string `json:"peer_groups,omitempty"` } -// Rule defines model for Rule. -type Rule struct { - // Description Rule friendly description - Description string `json:"description"` - - // Destinations Rule destination group IDs - Destinations []GroupMinimum `json:"destinations"` - - // Disabled Rules status - Disabled bool `json:"disabled"` - - // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted - Flow string `json:"flow"` - - // Id Rule ID - Id string `json:"id"` - - // Name Rule name identifier - Name string `json:"name"` - - // Sources Rule source group IDs - Sources []GroupMinimum `json:"sources"` -} - -// RuleMinimum defines model for RuleMinimum. -type RuleMinimum struct { - // Description Rule friendly description - Description string `json:"description"` - - // Disabled Rules status - Disabled bool `json:"disabled"` - - // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted - Flow string `json:"flow"` - - // Name Rule name identifier - Name string `json:"name"` -} - -// RuleRequest defines model for RuleRequest. -type RuleRequest struct { - // Description Rule friendly description - Description string `json:"description"` - - // Destinations List of destination group IDs - Destinations *[]string `json:"destinations,omitempty"` - - // Disabled Rules status - Disabled bool `json:"disabled"` - - // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted - Flow string `json:"flow"` - - // Name Rule name identifier - Name string `json:"name"` - - // Sources List of source group IDs - Sources *[]string `json:"sources,omitempty"` -} - // SetupKey defines model for SetupKey. type SetupKey struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key @@ -1219,12 +1159,6 @@ type PostApiRoutesJSONRequestBody = RouteRequest // PutApiRoutesRouteIdJSONRequestBody defines body for PutApiRoutesRouteId for application/json ContentType. type PutApiRoutesRouteIdJSONRequestBody = RouteRequest -// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType. -type PostApiRulesJSONRequestBody = RuleRequest - -// PutApiRulesRuleIdJSONRequestBody defines body for PutApiRulesRuleId for application/json ContentType. -type PutApiRulesRuleIdJSONRequestBody = RuleRequest - // PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType. type PostApiSetupKeysJSONRequestBody = SetupKeyRequest diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 4aab513a7a1..d035ae0b750 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -85,7 +85,6 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa api.addUsersEndpoint() api.addUsersTokensEndpoint() api.addSetupKeysEndpoint() - api.addRulesEndpoint() api.addPoliciesEndpoint() api.addGroupsEndpoint() api.addRoutesEndpoint() @@ -158,15 +157,6 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() { apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") } -func (apiHandler *apiHandler) addRulesEndpoint() { - rulesHandler := NewRulesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/rules", rulesHandler.GetAllRules).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/rules", rulesHandler.CreateRule).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.GetRule).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS") -} - func (apiHandler *apiHandler) addPoliciesEndpoint() { policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS") diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 86665848b53..e6b858036b7 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -3,7 +3,6 @@ package http import ( "bytes" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" @@ -44,24 +43,6 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return nil }, - SaveRuleFunc: func(_, _ string, rule *server.Rule) error { - if !strings.HasPrefix(rule.ID, "id-") { - rule.ID = "id-was-set" - } - return nil - }, - GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) { - if ruleID != "idoftherule" { - return nil, fmt.Errorf("not found") - } - return &server.Rule{ - ID: "idoftherule", - Name: "Rule", - Source: []string{"idofsrcrule"}, - Destination: []string{"idofdestrule"}, - Flow: server.TrafficFlowBidirect, - }, nil - }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ diff --git a/management/server/http/rules_handler.go b/management/server/http/rules_handler.go deleted file mode 100644 index bd501acf998..00000000000 --- a/management/server/http/rules_handler.go +++ /dev/null @@ -1,305 +0,0 @@ -package http - -import ( - "encoding/json" - "net/http" - - "github.com/gorilla/mux" - "github.com/rs/xid" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server/status" -) - -// RulesHandler is a handler that returns rules of the account -type RulesHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor -} - -// NewRulesHandler creates a new RulesHandler HTTP handler -func NewRulesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RulesHandler { - return &RulesHandler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } -} - -// GetAllRules list for the account -func (h *RulesHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - rules := []*api.Rule{} - for _, policy := range accountPolicies { - for _, r := range policy.Rules { - rules = append(rules, toRuleResponse(account, r.ToRule())) - } - } - - util.WriteJSONObject(w, rules) -} - -// UpdateRule handles update to a rule identified by a given ID -func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - vars := mux.Vars(r) - ruleID := vars["ruleId"] - if len(ruleID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) - return - } - - policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - var req api.PutApiRulesRuleIdJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) - if err != nil { - util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) - } - - if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w) - return - } - - var reqSources []string - if req.Sources != nil { - reqSources = *req.Sources - } - - var reqDestinations []string - if req.Destinations != nil { - reqDestinations = *req.Destinations - } - - if len(policy.Rules) != 1 { - util.WriteError(status.Errorf(status.Internal, "policy should contain exactly one rule"), w) - return - } - - policy.Name = req.Name - policy.Description = req.Description - policy.Enabled = !req.Disabled - policy.Rules[0].ID = ruleID - policy.Rules[0].Name = req.Name - policy.Rules[0].Sources = reqSources - policy.Rules[0].Destinations = reqDestinations - policy.Rules[0].Enabled = !req.Disabled - policy.Rules[0].Description = req.Description - - switch req.Flow { - case server.TrafficFlowBidirectString: - policy.Rules[0].Action = server.PolicyTrafficActionAccept - default: - util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w) - return - } - - err = h.accountManager.SavePolicy(account.Id, user.Id, policy) - if err != nil { - util.WriteError(err, w) - return - } - - resp := toRuleResponse(account, policy.Rules[0].ToRule()) - - util.WriteJSONObject(w, &resp) -} - -// CreateRule handles rule creation request -func (h *RulesHandler) CreateRule(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - var req api.PostApiRulesJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) - if err != nil { - util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) - return - } - - if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w) - return - } - - var reqSources []string - if req.Sources != nil { - reqSources = *req.Sources - } - - var reqDestinations []string - if req.Destinations != nil { - reqDestinations = *req.Destinations - } - - rule := server.Rule{ - ID: xid.New().String(), - Name: req.Name, - Source: reqSources, - Destination: reqDestinations, - Disabled: req.Disabled, - Description: req.Description, - } - - switch req.Flow { - case server.TrafficFlowBidirectString: - rule.Flow = server.TrafficFlowBidirect - default: - util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w) - return - } - - policy, err := server.RuleToPolicy(&rule) - if err != nil { - util.WriteError(err, w) - return - } - err = h.accountManager.SavePolicy(account.Id, user.Id, policy) - if err != nil { - util.WriteError(err, w) - return - } - - resp := toRuleResponse(account, &rule) - - util.WriteJSONObject(w, &resp) -} - -// DeleteRule handles rule deletion request -func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - aID := account.Id - - rID := mux.Vars(r)["ruleId"] - if len(rID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) - return - } - - err = h.accountManager.DeletePolicy(aID, rID, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - util.WriteJSONObject(w, emptyObject{}) -} - -// GetRule handles a group Get request identified by ID -func (h *RulesHandler) GetRule(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - switch r.Method { - case http.MethodGet: - ruleID := mux.Vars(r)["ruleId"] - if len(ruleID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) - return - } - - policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - util.WriteJSONObject(w, toRuleResponse(account, policy.Rules[0].ToRule())) - default: - util.WriteError(status.Errorf(status.NotFound, "method not found"), w) - } -} - -func toRuleResponse(account *server.Account, rule *server.Rule) *api.Rule { - cache := make(map[string]api.GroupMinimum) - gr := api.Rule{ - Id: rule.ID, - Name: rule.Name, - Description: rule.Description, - Disabled: rule.Disabled, - } - - switch rule.Flow { - case server.TrafficFlowBidirect: - gr.Flow = server.TrafficFlowBidirectString - default: - gr.Flow = "unknown" - } - - for _, gid := range rule.Source { - _, ok := cache[gid] - if ok { - continue - } - - if group, ok := account.Groups[gid]; ok { - minimum := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - - gr.Sources = append(gr.Sources, minimum) - cache[gid] = minimum - } - } - - for _, gid := range rule.Destination { - cachedMinimum, ok := cache[gid] - if ok { - gr.Destinations = append(gr.Destinations, cachedMinimum) - continue - } - if group, ok := account.Groups[gid]; ok { - minimum := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - gr.Destinations = append(gr.Destinations, minimum) - cache[gid] = minimum - } - } - - return &gr -} diff --git a/management/server/http/rules_handler_test.go b/management/server/http/rules_handler_test.go deleted file mode 100644 index 27a308a0a56..00000000000 --- a/management/server/http/rules_handler_test.go +++ /dev/null @@ -1,265 +0,0 @@ -package http - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/gorilla/mux" - - "github.com/netbirdio/netbird/management/server/jwtclaims" - - "github.com/magiconair/properties/assert" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/mock_server" -) - -func initRulesTestData(rules ...*server.Rule) *RulesHandler { - testPolicies := make(map[string]*server.Policy, len(rules)) - for _, rule := range rules { - policy, err := server.RuleToPolicy(rule) - if err != nil { - panic(err) - } - testPolicies[policy.ID] = policy - } - return &RulesHandler{ - accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) { - policy, ok := testPolicies[policyID] - if !ok { - return nil, status.Errorf(status.NotFound, "policy not found") - } - return policy, nil - }, - SavePolicyFunc: func(_, _ string, policy *server.Policy) error { - if !strings.HasPrefix(policy.ID, "id-") { - policy.ID = "id-was-set" - } - return nil - }, - SaveRuleFunc: func(_, _ string, rule *server.Rule) error { - if !strings.HasPrefix(rule.ID, "id-") { - rule.ID = "id-was-set" - } - return nil - }, - GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) { - if ruleID != "idoftherule" { - return nil, fmt.Errorf("not found") - } - return &server.Rule{ - ID: "idoftherule", - Name: "Rule", - Source: []string{"idofsrcrule"}, - Destination: []string{"idofdestrule"}, - Flow: server.TrafficFlowBidirect, - }, nil - }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Rules: map[string]*server.Rule{"id-existed": {ID: "id-existed"}}, - Groups: map[string]*server.Group{ - "F": {ID: "F"}, - "G": {ID: "G"}, - }, - Users: map[string]*server.User{ - "test_user": user, - }, - }, user, nil - }, - }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), - } -} - -func TestRulesGetRule(t *testing.T) { - tt := []struct { - name string - expectedStatus int - expectedBody bool - requestType string - requestPath string - requestBody io.Reader - }{ - { - name: "GetRule OK", - expectedBody: true, - requestType: http.MethodGet, - requestPath: "/api/rules/idoftherule", - expectedStatus: http.StatusOK, - }, - { - name: "GetRule not found", - requestType: http.MethodGet, - requestPath: "/api/rules/notexists", - expectedStatus: http.StatusNotFound, - }, - } - - rule := &server.Rule{ - ID: "idoftherule", - Name: "Rule", - } - - p := initRulesTestData(rule) - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - recorder := httptest.NewRecorder() - req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - - router := mux.NewRouter() - router.HandleFunc("/api/rules/{ruleId}", p.GetRule).Methods("GET") - router.ServeHTTP(recorder, req) - - res := recorder.Result() - defer res.Body.Close() - - if status := recorder.Code; status != tc.expectedStatus { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tc.expectedStatus) - return - } - - if !tc.expectedBody { - return - } - - content, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("I don't know what I expected; %v", err) - } - - var got api.Rule - if err = json.Unmarshal(content, &got); err != nil { - t.Fatalf("Sent content is not in correct json format; %v", err) - } - - assert.Equal(t, got.Id, rule.ID) - assert.Equal(t, got.Name, rule.Name) - }) - } -} - -func TestRulesWriteRule(t *testing.T) { - tt := []struct { - name string - expectedStatus int - expectedBody bool - expectedRule *api.Rule - requestType string - requestPath string - requestBody io.Reader - }{ - { - name: "WriteRule POST OK", - requestType: http.MethodPost, - requestPath: "/api/rules", - requestBody: bytes.NewBuffer( - []byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)), - expectedStatus: http.StatusOK, - expectedBody: true, - expectedRule: &api.Rule{ - Id: "id-was-set", - Name: "Default POSTed Rule", - Flow: server.TrafficFlowBidirectString, - }, - }, - { - name: "WriteRule POST Invalid Name", - requestType: http.MethodPost, - requestPath: "/api/rules", - requestBody: bytes.NewBuffer( - []byte(`{"Name":"","Flow":"bidirect"}`)), - expectedStatus: http.StatusUnprocessableEntity, - expectedBody: false, - }, - { - name: "WriteRule PUT OK", - requestType: http.MethodPut, - requestPath: "/api/rules/id-existed", - requestBody: bytes.NewBuffer( - []byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)), - expectedStatus: http.StatusOK, - expectedBody: true, - expectedRule: &api.Rule{ - Id: "id-existed", - Name: "Default POSTed Rule", - Flow: server.TrafficFlowBidirectString, - }, - }, - { - name: "WriteRule PUT Invalid Name", - requestType: http.MethodPut, - requestPath: "/api/rules/id-existed", - requestBody: bytes.NewBuffer( - []byte(`{"Name":"","Flow":"bidirect"}`)), - expectedStatus: http.StatusUnprocessableEntity, - }, - } - - p := initRulesTestData(&server.Rule{ - ID: "id-existed", - Name: "Default POSTed Rule", - Flow: server.TrafficFlowBidirect, - }) - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - recorder := httptest.NewRecorder() - req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - - router := mux.NewRouter() - router.HandleFunc("/api/rules", p.CreateRule).Methods("POST") - router.HandleFunc("/api/rules/{ruleId}", p.UpdateRule).Methods("PUT") - router.ServeHTTP(recorder, req) - - res := recorder.Result() - defer res.Body.Close() - - content, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("I don't know what I expected; %v", err) - } - - if status := recorder.Code; status != tc.expectedStatus { - t.Errorf("handler returned wrong status code: got %v want %v, content: %s", - status, tc.expectedStatus, string(content)) - return - } - - if !tc.expectedBody { - return - } - - got := &api.Rule{} - if err = json.Unmarshal(content, &got); err != nil { - t.Fatalf("Sent content is not in correct json format; %v", err) - } - tc.expectedRule.Id = got.Id - - assert.Equal(t, got, tc.expectedRule) - }) - } -} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2df5ef08685..f518372ed95 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,10 +38,7 @@ type MockAccountManager struct { ListGroupsFunc func(accountID string) ([]*server.Group, error) GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error - GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error) - SaveRuleFunc func(accountID, userID string, rule *server.Rule) error DeleteRuleFunc func(accountID, ruleID, userID string) error - ListRulesFunc func(accountID, userID string) ([]*server.Rule, error) GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) SavePolicyFunc func(accountID, userID string, policy *server.Policy) error DeletePolicyFunc func(accountID, policyID, userID string) error @@ -302,22 +299,6 @@ func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented") } -// GetRule mock implementation of GetRule from server.AccountManager interface -func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) { - if am.GetRuleFunc != nil { - return am.GetRuleFunc(accountID, ruleID, userID) - } - return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented") -} - -// SaveRule mock implementation of SaveRule from server.AccountManager interface -func (am *MockAccountManager) SaveRule(accountID, userID string, rule *server.Rule) error { - if am.SaveRuleFunc != nil { - return am.SaveRuleFunc(accountID, userID, rule) - } - return status.Errorf(codes.Unimplemented, "method SaveRule is not implemented") -} - // DeleteRule mock implementation of DeleteRule from server.AccountManager interface func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error { if am.DeleteRuleFunc != nil { @@ -326,14 +307,6 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") } -// ListRules mock implementation of ListRules from server.AccountManager interface -func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) { - if am.ListRulesFunc != nil { - return am.ListRulesFunc(accountID, userID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented") -} - // GetPolicy mock implementation of GetPolicy from server.AccountManager interface func (am *MockAccountManager) GetPolicy(accountID, policyID, userID string) (*server.Policy, error) { if am.GetPolicyFunc != nil { diff --git a/management/server/policy.go b/management/server/policy.go index 291a4f1f766..8265dabb51c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -52,6 +52,17 @@ const ( PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") ) +const ( + // DefaultRuleName is a name for the Default rule that is created for every account + DefaultRuleName = "Default" + // DefaultRuleDescription is a description for the Default rule that is created for every account + DefaultRuleDescription = "This is a default rule that allows connections between all the resources" + // DefaultPolicyName is a name for the Default policy that is created for every account + DefaultPolicyName = "Default" + // DefaultPolicyDescription is a description for the Default policy that is created for every account + DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" +) + const ( firewallRuleDirectionIN = 0 firewallRuleDirectionOUT = 1 @@ -119,19 +130,6 @@ func (pm *PolicyRule) Copy() *PolicyRule { return rule } -// ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility) -func (pm *PolicyRule) ToRule() *Rule { - return &Rule{ - ID: pm.ID, - Name: pm.Name, - Description: pm.Description, - Disabled: !pm.Enabled, - Flow: TrafficFlowBidirect, - Destination: pm.Destinations, - Source: pm.Sources, - } -} - // Policy of the Rego query type Policy struct { // ID of the policy' diff --git a/management/server/rule.go b/management/server/rule.go deleted file mode 100644 index 19085840cc7..00000000000 --- a/management/server/rule.go +++ /dev/null @@ -1,100 +0,0 @@ -package server - -import "fmt" - -// TrafficFlowType defines allowed direction of the traffic in the rule -type TrafficFlowType int - -const ( - // TrafficFlowBidirect allows traffic to both direction - TrafficFlowBidirect TrafficFlowType = iota - // TrafficFlowBidirectString allows traffic to both direction - TrafficFlowBidirectString = "bidirect" - // DefaultRuleName is a name for the Default rule that is created for every account - DefaultRuleName = "Default" - // DefaultRuleDescription is a description for the Default rule that is created for every account - DefaultRuleDescription = "This is a default rule that allows connections between all the resources" - // DefaultPolicyName is a name for the Default policy that is created for every account - DefaultPolicyName = "Default" - // DefaultPolicyDescription is a description for the Default policy that is created for every account - DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" -) - -// Rule of ACL for groups -type Rule struct { - // ID of the rule - ID string - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name of the rule visible in the UI - Name string - - // Description of the rule visible in the UI - Description string - - // Disabled status of rule in the system - Disabled bool - - // Source list of groups IDs of peers - Source []string `gorm:"serializer:json"` - - // Destination list of groups IDs of peers - Destination []string `gorm:"serializer:json"` - - // Flow of the traffic allowed by the rule - Flow TrafficFlowType -} - -func (r *Rule) Copy() *Rule { - rule := &Rule{ - ID: r.ID, - Name: r.Name, - Description: r.Description, - Disabled: r.Disabled, - Source: make([]string, len(r.Source)), - Destination: make([]string, len(r.Destination)), - Flow: r.Flow, - } - copy(rule.Source, r.Source) - copy(rule.Destination, r.Destination) - return rule -} - -// EventMeta returns activity event meta related to this rule -func (r *Rule) EventMeta() map[string]any { - return map[string]any{"name": r.Name} -} - -// ToPolicyRule converts a Rule to a PolicyRule object -func (r *Rule) ToPolicyRule() *PolicyRule { - if r == nil { - return nil - } - return &PolicyRule{ - ID: r.ID, - Name: r.Name, - Enabled: !r.Disabled, - Description: r.Description, - Destinations: r.Destination, - Sources: r.Source, - Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, - } -} - -// RuleToPolicy converts a Rule to a Policy query object -func RuleToPolicy(rule *Rule) (*Policy, error) { - if rule == nil { - return nil, fmt.Errorf("rule is empty") - } - return &Policy{ - ID: rule.ID, - Name: rule.Name, - Description: rule.Description, - Enabled: !rule.Disabled, - Rules: []*PolicyRule{rule.ToPolicyRule()}, - }, nil -} diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index eff43a31b6f..f6a6f92a726 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -64,7 +64,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, sql.SetMaxOpenConns(conns) // TODO: make it configurable err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{}, + &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, ) From 52a3ac6b06bf26f50040a6284a67d6b89edf7c39 Mon Sep 17 00:00:00 2001 From: Aaron Turner Date: Fri, 15 Mar 2024 02:32:51 -0700 Subject: [PATCH 06/89] Add support for inviting/deleting users via Zitadel (#1572) This fixes the "Invite User" button in Dashboard v2.0.0 and enables the usage of the --user-delete-from-idp flag for Zitadel. Unlike the NetBird SaaS solution, we rely on Zitadel to send the emails on our behalf. --- management/server/idp/zitadel.go | 127 +++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 6 deletions(-) diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 926f078b208..c09d362d8a7 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -75,6 +75,27 @@ type zitadelProfile struct { Human *zitadelUser `json:"human"` } +// zitadelUserDetails represents the metadata for the new user that was created +type zitadelUserDetails struct { + Sequence string `json:"sequence"` // uint64 as a string + CreationDate string `json:"creationDate"` // ISO format + ChangeDate string `json:"changeDate"` // ISO format + ResourceOwner string +} + +// zitadelPasswordlessRegistration represents the information for the user to complete signup +type zitadelPasswordlessRegistration struct { + Link string `json:"link"` + Expiration string `json:"expiration"` // ex: 3600s +} + +// zitadelUser represents an zitadel create user response +type zitadelUserResponse struct { + UserId string `json:"userId"` + Details zitadelUserDetails `json:"details"` + PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"` +} + // NewZitadelManager creates a new instance of the ZitadelManager. func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() @@ -224,9 +245,57 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { return zc.jwtToken, nil } -// CreateUser creates a new user in zitadel Idp and sends an invite. -func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) { - return nil, fmt.Errorf("method CreateUser not implemented") +// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel. +func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.SplitN(name, " ", 2) + + var addUser = map[string]any{ + "userName": email, + "profile": map[string]string{ + "firstName": firstLast[0], + "lastName": firstLast[0], + "displayName": name, + }, + "email": map[string]any{ + "email": email, + "isEmailVerified": false, + }, + "passwordChangeRequired": true, + "requestPasswordlessRegistration": false, // let Zitadel send the invite for us + } + + payload, err := zm.helper.Marshal(addUser) + if err != nil { + return nil, err + } + + body, err := zm.post("users/human/_import", string(payload)) + if err != nil { + return nil, err + } + + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountCreateUser() + } + + var newUser zitadelUserResponse + err = zm.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.UserId, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil } // GetUserByEmail searches users with a given email. @@ -354,10 +423,25 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { return nil } +type inviteUserRequest struct { + Email string `json:"email"` +} + // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (zm *ZitadelManager) InviteUserByID(_ string) error { - return fmt.Errorf("method InviteUserByID not implemented") +func (zm *ZitadelManager) InviteUserByID(userID string) error { + inviteUser := inviteUserRequest{ + Email: userID, + } + + payload, err := zm.helper.Marshal(inviteUser) + if err != nil { + return err + } + + // don't care about the body in the response + _, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload)) + return err } // DeleteUser from Zitadel @@ -411,7 +495,38 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { } // delete perform Delete requests. -func (zm *ZitadelManager) delete(_ string) error { +func (zm *ZitadelManager) delete(resource string) error { + jwtToken, err := zm.credentials.Authenticate() + if err != nil { + return err + } + + reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + resp, err := zm.httpClient.Do(req) + if err != nil { + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountRequestError() + } + + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + } + return nil } From fc7c1e397f56753feaabb445fe0e84effe3266ad Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 15 Mar 2024 10:50:02 +0100 Subject: [PATCH 07/89] Disable force jsonfile variable (#1611) This enables windows management tests Added another DNS server to the dns server tests --- .github/workflows/golang-test-windows.yml | 1 - client/internal/dns/server_test.go | 5 +++++ management/server/store.go | 9 ++++++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index a50c8191891..6027d36269f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -44,7 +44,6 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - - run: "[Environment]::SetEnvironmentVariable('NETBIRD_STORE_ENGINE', 'jsonfile', 'Machine')" - name: test run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index f3282f1f49c..22966d89ca2 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -750,6 +750,11 @@ func TestDNSPermanent_matchOnly(t *testing.T) { NSType: nbdns.UDPNameServerType, Port: 53, }, + { + IP: netip.MustParseAddr("9.9.9.9"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, }, Domains: []string{"customdomain.com"}, Primary: false, diff --git a/management/server/store.go b/management/server/store.go index 7ef090a67c4..77b8d0dadbb 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -103,7 +103,14 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro return nil, err } - switch kind := getStoreEngineFromEnv(); kind { + // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE + kind := getStoreEngineFromEnv() + if kind == "" { + // NETBIRD_STORE_ENGINE is not set we evaluate default based on dataDir + kind = getStoreEngineFromDatadir(dataDir) + } + + switch kind { case FileStoreEngine: return fstore, nil case SqliteStoreEngine: From 416f04c27a10309f70f1f376a0e8db0ed57f1aad Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 15 Mar 2024 18:57:18 +0100 Subject: [PATCH 08/89] Unblock ACL apply filtering because of dns probes (#1711) moved the e.dnsServer.ProbeAvailability() to run after ACL apply filtering run the probes in parallel --- client/internal/dns/server.go | 8 +++++++- client/internal/engine.go | 9 +++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index dff44f01d1d..b9608b6f288 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -278,9 +278,15 @@ func (s *DefaultServer) SearchDomains() []string { // ProbeAvailability tests each upstream group's servers for availability // and deactivates the group if no server responds func (s *DefaultServer) ProbeAvailability() { + var wg sync.WaitGroup for _, mux := range s.dnsMuxMap { - mux.probeAvailability() + wg.Add(1) + go func(mux handlerWithStop) { + defer wg.Done() + mux.probeAvailability() + }(mux) } + wg.Wait() } func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { diff --git a/client/internal/engine.go b/client/internal/engine.go index 9e52cea4446..78d26f0b8fb 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -698,15 +698,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } - // Test received (upstream) servers for availability right away instead of upon usage. - // If no server of a server group responds this will disable the respective handler and retry later. - e.dnsServer.ProbeAvailability() - if e.acl != nil { e.acl.ApplyFiltering(networkMap) } + e.networkSerial = serial + // Test received (upstream) servers for availability right away instead of upon usage. + // If no server of a server group responds this will disable the respective handler and retry later. + e.dnsServer.ProbeAvailability() + return nil } From abd57d11914031ab559c4c8a60490518031e9fd8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 17 Mar 2024 11:13:39 +0100 Subject: [PATCH 09/89] Avoid creating duplicate groups with the same name (#1579) Avoid creating groups with the same name via API calls. JWT and integrations still allowed to register groups with duplicated names --- management/server/account.go | 13 +++++++-- management/server/group.go | 32 +++++++++++++++++++- management/server/group_test.go | 37 +++++++++++++++++++++++- management/server/http/api/openapi.yml | 7 +++-- management/server/http/groups_handler.go | 3 -- 5 files changed, 82 insertions(+), 10 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index a800ea97d32..8b326d93a60 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -40,9 +40,6 @@ const ( PublicCategory = "public" PrivateCategory = "private" UnknownCategory = "unknown" - GroupIssuedAPI = "api" - GroupIssuedJWT = "jwt" - GroupIssuedIntegration = "integration" CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour @@ -556,6 +553,16 @@ func (a *Account) FindUser(userID string) (*User, error) { return user, nil } +// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. +func (a *Account) FindGroupByName(groupName string) (*Group, error) { + for _, group := range a.Groups { + if group.Name == groupName { + return group, nil + } + } + return nil, status.Errorf(status.NotFound, "group %s not found", groupName) +} + // FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { key := a.SetupKeys[setupKey] diff --git a/management/server/group.go b/management/server/group.go index be8d3fb0e2d..43d48e6227f 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -3,6 +3,7 @@ package server import ( "fmt" + "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" @@ -18,6 +19,12 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } +const ( + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" + GroupIssuedIntegration = "integration" +) + // Group of the peers for ACL type Group struct { // ID of the group @@ -29,7 +36,7 @@ type Group struct { // Name visible in the UI Name string - // Issued of the group + // Issued defines how this group was created (enum of "api", "integration" or "jwt") Issued string // Peers list of the group @@ -116,6 +123,29 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G return err } + if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI { + + existingGroup, err := account.FindGroupByName(newGroup.Name) + if err != nil { + s, ok := status.FromError(err) + if !ok || s.ErrorType != status.NotFound { + return err + } + } + + // avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are + // coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + for _, peerID := range newGroup.Peers { if account.Peers[peerID] == nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) diff --git a/management/server/group_test.go b/management/server/group_test.go index e2051a65611..3a2195c889d 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -13,6 +13,41 @@ const ( groupAdminUserID = "testingAdminUser" ) +func TestDefaultAccountManager_CreateGroup(t *testing.T) { + am, err := createManager(t) + if err != nil { + t.Error("failed to create account manager") + } + + account, err := initTestGroupAccount(am) + if err != nil { + t.Error("failed to init testing account") + } + for _, group := range account.Groups { + group.Issued = GroupIssuedIntegration + err = am.SaveGroup(account.Id, groupAdminUserID, group) + if err != nil { + t.Errorf("should allow to create %s groups", GroupIssuedIntegration) + } + } + + for _, group := range account.Groups { + group.Issued = GroupIssuedJWT + err = am.SaveGroup(account.Id, groupAdminUserID, group) + if err != nil { + t.Errorf("should allow to create %s groups", GroupIssuedJWT) + } + } + for _, group := range account.Groups { + group.Issued = GroupIssuedAPI + group.ID = "" + err = am.SaveGroup(account.Id, groupAdminUserID, group) + if err == nil { + t.Errorf("should not create api group with the same name, %s", group.Name) + } + } +} + func TestDefaultAccountManager_DeleteGroup(t *testing.T) { am, err := createManager(t) if err != nil { @@ -137,7 +172,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { groupForIntegration := &Group{ ID: "grp-for-integration", AccountID: "account-id", - Name: "Group for users", + Name: "Group for users integration", Issued: GroupIssuedIntegration, Peers: make([]string, 0), } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index f4c8910bc7a..7ec2310afe6 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -585,7 +585,10 @@ components: type: integer example: 2 issued: - description: How group was issued by API or from JWT token + description: How the group was issued (api, integration, jwt) + type: string + enum: ["api", "integration", "jwt"] + example: api type: string example: api required: @@ -1246,7 +1249,7 @@ paths: /api/accounts/{accountId}: delete: summary: Delete an Account - description: Deletes an account and all its resources. Only administrators and account owners can delete accounts. + description: Deletes an account and all its resources. Only account owners can delete accounts. tags: [ Accounts ] security: - BearerAuth: [ ] diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c06445690db..b37f4fd2f46 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -8,8 +8,6 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -151,7 +149,6 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { peers = *req.Peers } group := server.Group{ - ID: xid.New().String(), Name: req.Name, Peers: peers, Issued: server.GroupIssuedAPI, From 9b0fe2c8e581319cc0ba391dd2de6f82d260d3b6 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Mar 2024 10:12:46 +0100 Subject: [PATCH 10/89] Validate authentik issuer url (#1723) * Validate authentik issuer url * test issuer * adjust test times on windows --- management/server/idp/authentik.go | 4 ++++ management/server/idp/authentik_test.go | 16 ++++++++++++-- management/server/scheduler_test.go | 28 ++++++++++++++++++------- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 4bbf094045a..b39f2b5cbf1 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -76,6 +76,10 @@ func NewAuthentikManager(config AuthentikClientConfig, return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing") } + if config.Issuer == "" { + return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing") + } + if config.GrantType == "" { return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing") } diff --git a/management/server/idp/authentik_test.go b/management/server/idp/authentik_test.go index c70a84efd07..342e16384db 100644 --- a/management/server/idp/authentik_test.go +++ b/management/server/idp/authentik_test.go @@ -7,9 +7,10 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" ) func TestNewAuthentikManager(t *testing.T) { @@ -25,6 +26,7 @@ func TestNewAuthentikManager(t *testing.T) { Username: "username", Password: "password", TokenEndpoint: "https://localhost:8080/application/o/token/", + Issuer: "https://localhost:8080/application/o/netbird/", GrantType: "client_credentials", } @@ -75,7 +77,17 @@ func TestNewAuthentikManager(t *testing.T) { assertErrFuncMessage: "should return error when field empty", } - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { + testCase6Config := defaultTestConfig + testCase6Config.Issuer = "" + + testCase6 := test{ + name: "Missing Issuer Configuration", + inputConfig: testCase6Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + } + + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} { t.Run(testCase.name, func(t *testing.T) { _, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{}) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go index 4b2c2e30d4e..9dd73e2695f 100644 --- a/management/server/scheduler_test.go +++ b/management/server/scheduler_test.go @@ -3,6 +3,7 @@ package server import ( "fmt" "math/rand" + "runtime" "sync" "testing" "time" @@ -25,7 +26,13 @@ func TestScheduler_Performance(t *testing.T) { return 0, false }) } - failed := waitTimeout(wg, 3*time.Second) + timeout := 3 * time.Second + if runtime.GOOS == "windows" { + // sleep and ticker are slower on windows see https://github.com/golang/go/issues/44343 + timeout = 5 * time.Second + } + + failed := waitTimeout(wg, timeout) if failed { t.Fatal("timed out while waiting for test to finish") return @@ -39,22 +46,29 @@ func TestScheduler_Cancel(t *testing.T) { scheduler := NewDefaultScheduler() tChan := make(chan struct{}) p := []string{jobID1, jobID2} - scheduler.Schedule(2*time.Millisecond, jobID1, func() (nextRunIn time.Duration, reschedule bool) { + scheduletime := 2 * time.Millisecond + sleepTime := 4 * time.Millisecond + if runtime.GOOS == "windows" { + // sleep and ticker are slower on windows see https://github.com/golang/go/issues/44343 + sleepTime = 20 * time.Millisecond + } + + scheduler.Schedule(scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) { tt := p[0] <-tChan t.Logf("job %s", tt) - return 2 * time.Millisecond, true + return scheduletime, true }) - scheduler.Schedule(2*time.Millisecond, jobID2, func() (nextRunIn time.Duration, reschedule bool) { - return 2 * time.Millisecond, true + scheduler.Schedule(scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { + return scheduletime, true }) - time.Sleep(4 * time.Millisecond) + time.Sleep(sleepTime) assert.Len(t, scheduler.jobs, 2) scheduler.Cancel([]string{jobID1}) close(tChan) p = []string{} - time.Sleep(4 * time.Millisecond) + time.Sleep(sleepTime) assert.Len(t, scheduler.jobs, 1) assert.NotNil(t, scheduler.jobs[jobID2]) } From f0672b87bc7440d921c8a66f24bfc3a2c6fe3e16 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Mar 2024 12:25:01 +0100 Subject: [PATCH 11/89] Add missing dns domain to tests to avoid verbose test logs (#1724) --- client/cmd/testutil.go | 2 +- client/internal/engine_test.go | 10 +++++----- client/server/server_test.go | 2 +- management/client/client_test.go | 2 +- management/server/http/peers_handler_test.go | 3 +++ management/server/management_proto_test.go | 2 +- management/server/management_test.go | 2 +- management/server/nameserver_test.go | 2 +- management/server/route_test.go | 2 +- 9 files changed, 15 insertions(+), 12 deletions(-) diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index cba47326f84..2cfc934159e 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -78,7 +78,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index ee0380db7f5..952b3c90cfb 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -70,10 +70,10 @@ func TestEngine_SSH(t *testing.T) { defer cancel() engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ - WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, + WgIfaceName: "utun101", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, ServerSSHAllowed: true, }, MobileDependency{}, peer.NewRecorder("https://mgm")) @@ -1050,7 +1050,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 79a22002311..7f8310c903b 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -114,7 +114,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) if err != nil { return nil, "", err } diff --git a/management/client/client_test.go b/management/client/client_test.go index 0a57fda72a7..f30ae0cfd66 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -60,7 +60,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) if err != nil { t.Fatal(err) } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 27978c48754..e43c4375e92 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -55,6 +55,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, + GetDNSDomainFunc: func() string { + return "netbird.selfhosted" + }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index e5457db0211..6ea902003de 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) } peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "", + accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) if err != nil { return nil, "", err diff --git a/management/server/management_test.go b/management/server/management_test.go index f4535487764..fb3f74cb9fa 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { } peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) if err != nil { log.Fatalf("failed creating a manager: %v", err) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 3327869b444..d04ac1a20a1 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -759,7 +759,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) } func createNSStore(t *testing.T) (Store, error) { diff --git a/management/server/route_test.go b/management/server/route_test.go index a5db2ca07b4..5a56eaa8bd9 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1014,7 +1014,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) } func createRouterStore(t *testing.T) (Store, error) { From 6cba9c0818c718d678ee67db5be47b010c54d8f1 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 12:32:07 +0100 Subject: [PATCH 12/89] Remove context niling (#1729) --- client/internal/routemanager/manager.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index fde943757f2..b624d8c34ce 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -71,7 +71,6 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } - m.ctx = nil } // UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps From 846871913dccc7d80ddfa9b0fa90c2c225b94b0b Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 20 Mar 2024 11:18:34 +0100 Subject: [PATCH 13/89] Add latency checks to peer connection and status output (#1725) * adding peer healthcheck * generate proto file * fix return in udp mux and replace with continue * use ice agent for latency checks * fix status output * remove some logs * fix status test * revert bind and ebpf code * fix error handling on binding response callback * extend error handling on binding response callback --------- Co-authored-by: Maycon Santos --- client/cmd/status.go | 6 +- client/cmd/status_test.go | 11 ++ client/internal/peer/conn.go | 18 ++ client/internal/peer/status.go | 17 ++ client/proto/daemon.pb.go | 320 +++++++++++++++++---------------- client/proto/daemon.proto | 2 + client/server/server.go | 3 + go.mod | 18 +- go.sum | 42 ++--- 9 files changed, 252 insertions(+), 185 deletions(-) diff --git a/client/cmd/status.go b/client/cmd/status.go index 4c7218fde94..2840cc6c941 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -34,6 +34,7 @@ type peerStateDetailOutput struct { LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` TransferSent int64 `json:"transferSent" yaml:"transferSent"` + Latency time.Duration `json:"latency" yaml:"latency"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` Routes []string `json:"routes" yaml:"routes"` } @@ -376,6 +377,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { LastWireguardHandshake: lastHandshake, TransferReceived: transferReceived, TransferSent: transferSent, + Latency: pbPeerState.GetLatency().AsDuration(), RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), Routes: pbPeerState.GetRoutes(), } @@ -638,7 +640,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Last WireGuard handshake: %s\n"+ " Transfer status (received/sent) %s/%s\n"+ " Quantum resistance: %s\n"+ - " Routes: %s\n", + " Routes: %s\n"+ + " Latency: %s\n", peerState.FQDN, peerState.IP, peerState.PubKey, @@ -655,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo toIEC(peerState.TransferSent), rosenpassEnabledStatus, routes, + peerState.Latency.String(), ) peersString += peerString diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index ea6980c3df7..cc0cce134e9 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/client/proto" @@ -45,6 +46,7 @@ var resp = &proto.StatusResponse{ Routes: []string{ "10.1.0.0/24", }, + Latency: durationpb.New(time.Duration(10000000)), }, { IP: "192.168.178.102", @@ -61,6 +63,7 @@ var resp = &proto.StatusResponse{ LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)), BytesRx: 2000, BytesTx: 1000, + Latency: durationpb.New(time.Duration(10000000)), }, }, ManagementState: &proto.ManagementState{ @@ -147,6 +150,7 @@ var overview = statusOutputOverview{ Routes: []string{ "10.1.0.0/24", }, + Latency: time.Duration(10000000), }, { IP: "192.168.178.102", @@ -167,6 +171,7 @@ var overview = statusOutputOverview{ LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC), TransferReceived: 2000, TransferSent: 1000, + Latency: time.Duration(10000000), }, }, }, @@ -288,6 +293,7 @@ func TestParsingToJSON(t *testing.T) { "lastWireguardHandshake": "2001-01-01T01:01:02Z", "transferReceived": 200, "transferSent": 100, + "latency": 10000000, "quantumResistance": false, "routes": [ "10.1.0.0/24" @@ -312,6 +318,7 @@ func TestParsingToJSON(t *testing.T) { "lastWireguardHandshake": "2002-02-02T02:02:03Z", "transferReceived": 2000, "transferSent": 1000, + "latency": 10000000, "quantumResistance": false, "routes": null } @@ -409,6 +416,7 @@ func TestParsingToYAML(t *testing.T) { lastWireguardHandshake: 2001-01-01T01:01:02Z transferReceived: 200 transferSent: 100 + latency: 10ms quantumResistance: false routes: - 10.1.0.0/24 @@ -428,6 +436,7 @@ func TestParsingToYAML(t *testing.T) { lastWireguardHandshake: 2002-02-02T02:02:03Z transferReceived: 2000 transferSent: 1000 + latency: 10ms quantumResistance: false routes: [] cliVersion: development @@ -496,6 +505,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 200 B/100 B Quantum resistance: false Routes: 10.1.0.0/24 + Latency: 10ms peer-2.awesome-domain.com: NetBird IP: 192.168.178.102 @@ -511,6 +521,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 2.0 KiB/1000 B Quantum resistance: false Routes: - + Latency: 10ms Daemon version: 0.14.1 CLI version: development diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index b7db310e6e1..c180e8f032b 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -133,6 +133,9 @@ type Conn struct { adapter iface.TunAdapter iFaceDiscover stdnet.ExternalIFaceDiscover sentExtraSrflx bool + + remoteEndpoint *net.UDPAddr + remoteConn *ice.Conn } // meta holds meta information about a connection @@ -234,6 +237,17 @@ func (conn *Conn) reCreateAgent() error { return err } + err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) { + err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency()) + if err != nil { + log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err) + return + } + }) + if err != nil { + return fmt.Errorf("failed setting binding response callback: %w", err) + } + return nil } @@ -348,6 +362,9 @@ func (conn *Conn) Open() error { if remoteOfferAnswer.WgListenPort != 0 { remoteWgPort = remoteOfferAnswer.WgListenPort } + + conn.remoteConn = remoteConn + // the ice connection has been established successfully so we are ready to start the proxy remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey, remoteOfferAnswer.RosenpassAddr) @@ -397,6 +414,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem } endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) + conn.remoteEndpoint = endpointUdpAddr err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 1e252c5dd48..ca97c3ea497 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -28,6 +28,7 @@ type State struct { LastWireguardHandshake time.Time BytesTx int64 BytesRx int64 + Latency time.Duration RosenpassEnabled bool Routes map[string]struct{} } @@ -410,6 +411,22 @@ func (d *Status) GetManagementState() ManagementState { } } +func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error { + if latency <= 0 { + return nil + } + + d.mux.Lock() + defer d.mux.Unlock() + peerState, ok := d.peers[pubKey] + if !ok { + return errors.New("peer doesn't exist") + } + peerState.Latency = latency + d.peers[pubKey] = peerState + return nil +} + // IsLoginRequired determines if a peer's login has expired. func (d *Status) IsLoginRequired() bool { d.mux.Lock() diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 869eceee550..81998b115d3 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,16 +1,17 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.12.4 +// protoc v4.24.3 // source: daemon.proto package proto import ( - _ "github.com/golang/protobuf/protoc-gen-go/descriptor" - timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + _ "google.golang.org/protobuf/types/descriptorpb" + durationpb "google.golang.org/protobuf/types/known/durationpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -757,22 +758,23 @@ type PeerState struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` - PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` - ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` - ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` - Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` - Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` - LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` - RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` - Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` - LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` - RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` - LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` - BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` - BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` - RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` + PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` + ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` + ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` + Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` + Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` + LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` + RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` + Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` + RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` + LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` + BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` + BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` + RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` } func (x *PeerState) Reset() { @@ -828,7 +830,7 @@ func (x *PeerState) GetConnStatus() string { return "" } -func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp { +func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp { if x != nil { return x.ConnStatusUpdate } @@ -884,7 +886,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string { return "" } -func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp { +func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp { if x != nil { return x.LastWireguardHandshake } @@ -919,6 +921,13 @@ func (x *PeerState) GetRoutes() []string { return nil } +func (x *PeerState) GetLatency() *durationpb.Duration { + if x != nil { + return x.Latency + } + return nil +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState @@ -1374,7 +1383,9 @@ var file_daemon_proto_rawDesc = []byte{ 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x06, 0x0a, 0x0c, 0x4c, 0x6f, + 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x06, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -1475,7 +1486,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x22, 0x99, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, + 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, @@ -1516,92 +1527,95 @@ var file_daemon_proto_rawDesc = []byte{ 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0xec, 0x01, 0x0a, - 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, - 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, - 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, - 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, - 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, - 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, - 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, - 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, - 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, - 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, - 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, - 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, - 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, - 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, - 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, - 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, - 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, - 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, - 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, - 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, - 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, - 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, - 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, - 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, - 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, - 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, - 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, - 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, - 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, - 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, - 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, + 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, + 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, + 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, + 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, + 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, + 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, + 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, + 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, + 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, + 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, + 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, + 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, + 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, + 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, + 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, + 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, + 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, + 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, + 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -1618,54 +1632,56 @@ func file_daemon_proto_rawDescGZIP() []byte { var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19) var file_daemon_proto_goTypes = []interface{}{ - (*LoginRequest)(nil), // 0: daemon.LoginRequest - (*LoginResponse)(nil), // 1: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 4: daemon.UpRequest - (*UpResponse)(nil), // 5: daemon.UpResponse - (*StatusRequest)(nil), // 6: daemon.StatusRequest - (*StatusResponse)(nil), // 7: daemon.StatusResponse - (*DownRequest)(nil), // 8: daemon.DownRequest - (*DownResponse)(nil), // 9: daemon.DownResponse - (*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse - (*PeerState)(nil), // 12: daemon.PeerState - (*LocalPeerState)(nil), // 13: daemon.LocalPeerState - (*SignalState)(nil), // 14: daemon.SignalState - (*ManagementState)(nil), // 15: daemon.ManagementState - (*RelayState)(nil), // 16: daemon.RelayState - (*NSGroupState)(nil), // 17: daemon.NSGroupState - (*FullStatus)(nil), // 18: daemon.FullStatus - (*timestamp.Timestamp)(nil), // 19: google.protobuf.Timestamp + (*LoginRequest)(nil), // 0: daemon.LoginRequest + (*LoginResponse)(nil), // 1: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 4: daemon.UpRequest + (*UpResponse)(nil), // 5: daemon.UpResponse + (*StatusRequest)(nil), // 6: daemon.StatusRequest + (*StatusResponse)(nil), // 7: daemon.StatusResponse + (*DownRequest)(nil), // 8: daemon.DownRequest + (*DownResponse)(nil), // 9: daemon.DownResponse + (*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse + (*PeerState)(nil), // 12: daemon.PeerState + (*LocalPeerState)(nil), // 13: daemon.LocalPeerState + (*SignalState)(nil), // 14: daemon.SignalState + (*ManagementState)(nil), // 15: daemon.ManagementState + (*RelayState)(nil), // 16: daemon.RelayState + (*NSGroupState)(nil), // 17: daemon.NSGroupState + (*FullStatus)(nil), // 18: daemon.FullStatus + (*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 20: google.protobuf.Duration } var file_daemon_proto_depIdxs = []int32{ 18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 15, // 3: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 14, // 4: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 13, // 5: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 12, // 6: daemon.FullStatus.peers:type_name -> daemon.PeerState - 16, // 7: daemon.FullStatus.relays:type_name -> daemon.RelayState - 17, // 8: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 0, // 9: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 2, // 10: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 4, // 11: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 6, // 12: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 8, // 13: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 10, // 14: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 1, // 15: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 3, // 16: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 5, // 17: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 7, // 18: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 9, // 19: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 11, // 20: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 15, // [15:21] is the sub-list for method output_type - 9, // [9:15] is the sub-list for method input_type - 9, // [9:9] is the sub-list for extension type_name - 9, // [9:9] is the sub-list for extension extendee - 0, // [0:9] is the sub-list for field type_name + 20, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 15, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 14, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 13, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 12, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState + 16, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState + 17, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 0, // 10: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 2, // 11: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 4, // 12: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 6, // 13: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 8, // 14: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 10, // 15: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 1, // 16: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 3, // 17: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 5, // 18: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 7, // 19: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 9, // 20: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 11, // 21: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 16, // [16:22] is the sub-list for method output_type + 10, // [10:16] is the sub-list for method input_type + 10, // [10:10] is the sub-list for extension type_name + 10, // [10:10] is the sub-list for extension extendee + 0, // [0:10] is the sub-list for field type_name } func init() { file_daemon_proto_init() } diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index bdb1cb83eea..8f9148d68af 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -2,6 +2,7 @@ syntax = "proto3"; import "google/protobuf/descriptor.proto"; import "google/protobuf/timestamp.proto"; +import "google/protobuf/duration.proto"; option go_package = "/proto"; @@ -142,6 +143,7 @@ message PeerState { int64 bytesTx = 14; bool rosenpassEnabled = 15; repeated string routes = 16; + google.protobuf.Duration latency = 17; } // LocalPeerState contains the latest state of the local peer diff --git a/client/server/server.go b/client/server/server.go index 5f1bf0100a4..481ef0f7cc6 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -13,6 +13,8 @@ import ( "github.com/cenkalti/backoff/v4" "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/durationpb" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/system" @@ -711,6 +713,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, Routes: maps.Keys(peerState.Routes), + Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) } diff --git a/go.mod b/go.mod index 6aba599f810..ce3da619e5a 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cloudflare/circl v1.3.3 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.3 - github.com/google/uuid v1.3.1 + github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 github.com/onsi/ginkgo v1.16.5 @@ -21,8 +21,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 - golang.org/x/crypto v0.17.0 - golang.org/x/sys v0.15.0 + golang.org/x/crypto v0.18.0 + golang.org/x/sys v0.16.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -81,10 +81,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 - golang.org/x/net v0.17.0 + golang.org/x/net v0.20.0 golang.org/x/oauth2 v0.8.0 golang.org/x/sync v0.3.0 - golang.org/x/term v0.15.0 + golang.org/x/term v0.16.0 google.golang.org/api v0.126.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/sqlite v1.5.3 @@ -137,10 +137,10 @@ require ( github.com/nxadm/tail v1.4.8 // indirect github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect - github.com/pion/dtls/v2 v2.2.7 // indirect - github.com/pion/mdns v0.0.9 // indirect + github.com/pion/dtls/v2 v2.2.10 // indirect + github.com/pion/mdns v0.0.12 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/transport/v2 v2.2.1 // indirect + github.com/pion/transport/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.37.0 // indirect @@ -175,3 +175,5 @@ replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-202 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 + +replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e diff --git a/go.sum b/go.sum index ca10cd55367..e304e3191dd 100644 --- a/go.sum +++ b/go.sum @@ -271,8 +271,8 @@ github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc= github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k= github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -380,6 +380,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= +github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= +github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA= @@ -423,20 +425,20 @@ github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY github.com/pegasus-kv/thrift v0.13.0/go.mod h1:Gl9NT/WHG6ABm6NsrbfE8LiJN0sAyneCrvB4qN4NPqQ= github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= -github.com/pion/dtls/v2 v2.2.7 h1:cSUBsETxepsCSFSxC3mc/aDo14qQLMSL+O6IjG28yV8= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= -github.com/pion/ice/v3 v3.0.2 h1:dNQnKsjLvOWz+PaI4tw1VnLYTp9adihC1HIASFGajmI= -github.com/pion/ice/v3 v3.0.2/go.mod h1:q3BDzTsxbqP0ySMSHrFuw2MYGUx/AC3WQfRGC5F/0Is= +github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= +github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.9 h1:7Ue5KZsqq8EuqStnpPWV33vYYEH0+skdDN5L7EiEsI4= -github.com/pion/mdns v0.0.9/go.mod h1:2JA5exfxwzXiCihmxpTKgFUpiQws2MnipoPK09vecIc= +github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= +github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= -github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= +github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= +github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= @@ -580,10 +582,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -671,9 +671,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -765,20 +764,16 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -792,7 +787,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 2475473227f02bf9a362e23ed9a289c6ca97bc7b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 21 Mar 2024 16:49:28 +0100 Subject: [PATCH 14/89] Support client default routes for Linux (#1667) All routes are now installed in a custom netbird routing table. Management and wireguard traffic is now marked with a custom fwmark. When the mark is present the traffic is routed via the main routing table, bypassing the VPN. When the mark is absent the traffic is routed via the netbird routing table, if: - there's no match in the main routing table - it would match the default route in the routing table IPv6 traffic is blocked when a default route IPv4 route is configured to avoid leakage. --- .github/workflows/golang-test-linux.yml | 16 +- .github/workflows/golangci-lint.yml | 2 +- client/internal/engine.go | 22 +- client/internal/relay/relay.go | 24 +- client/internal/routemanager/client.go | 63 ++- client/internal/routemanager/manager.go | 52 +- client/internal/routemanager/manager_test.go | 28 +- client/internal/routemanager/mock.go | 4 + .../routemanager/server_nonandroid.go | 51 +- .../routemanager/systemops_android.go | 4 +- client/internal/routemanager/systemops_bsd.go | 1 - .../routemanager/systemops_bsd_nonios.go | 13 + client/internal/routemanager/systemops_ios.go | 6 +- .../internal/routemanager/systemops_linux.go | 447 ++++++++++++++--- .../routemanager/systemops_linux_test.go | 469 ++++++++++++++++++ .../routemanager/systemops_nonandroid.go | 138 ++++-- .../routemanager/systemops_nonandroid_test.go | 142 +++--- .../routemanager/systemops_nonlinux.go | 27 +- .../routemanager/systemops_nonlinux_test.go | 80 +++ .../routemanager/systemops_windows.go | 15 +- client/internal/stdnet/dialer.go | 24 + client/internal/stdnet/listener.go | 20 + client/internal/wgproxy/portlookup.go | 6 +- client/internal/wgproxy/proxy_ebpf.go | 32 +- client/internal/wgproxy/proxy_userspace.go | 4 +- go.mod | 4 +- go.sum | 4 +- iface/address.go | 18 + iface/wg_configurer_kernel.go | 4 +- iface/wg_configurer_usp.go | 11 +- management/client/grpc.go | 2 + sharedsock/sock_linux.go | 59 ++- signal/client/grpc.go | 2 + util/grpc/dialer_generic.go | 9 + util/grpc/dialer_linux.go | 18 + util/net/dialer_generic.go | 19 + util/net/dialer_linux.go | 60 +++ util/net/listener_generic.go | 13 + util/net/listener_linux.go | 30 ++ util/net/net.go | 6 + util/net/net_linux.go | 35 ++ 41 files changed, 1632 insertions(+), 352 deletions(-) create mode 100644 client/internal/routemanager/systemops_bsd_nonios.go create mode 100644 client/internal/routemanager/systemops_linux_test.go create mode 100644 client/internal/routemanager/systemops_nonlinux_test.go create mode 100644 client/internal/stdnet/dialer.go create mode 100644 client/internal/stdnet/listener.go create mode 100644 util/grpc/dialer_generic.go create mode 100644 util/grpc/dialer_linux.go create mode 100644 util/net/dialer_generic.go create mode 100644 util/net/dialer_linux.go create mode 100644 util/net/listener_generic.go create mode 100644 util/net/listener_linux.go create mode 100644 util/net/net.go create mode 100644 util/net/net_linux.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index f6fab80c527..42f740e9b54 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -14,8 +14,8 @@ jobs: test: strategy: matrix: - arch: ['386','amd64'] - store: ['jsonfile', 'sqlite'] + arch: [ '386','amd64' ] + store: [ 'jsonfile', 'sqlite' ] runs-on: ubuntu-latest steps: - name: Install Go @@ -36,7 +36,11 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - name: Install modules run: go mod tidy @@ -67,7 +71,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - name: Install modules run: go mod tidy @@ -82,7 +86,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... + run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... @@ -109,7 +113,7 @@ jobs: - name: Run Engine tests in docker with file store run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - + - name: Run Engine tests in docker with sqlite store run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9f543c74c45..13228250d59 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index 78d26f0b8fb..7f7b5ef55ba 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -230,8 +230,8 @@ func (e *Engine) Start() error { wgIface, err := e.newWgIface() if err != nil { - log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error()) - return err + log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) + return fmt.Errorf("new wg interface: %w", err) } e.wgInterface = wgIface @@ -244,29 +244,33 @@ func (e *Engine) Start() error { } e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName) if err != nil { - return err + return fmt.Errorf("create rosenpass manager: %w", err) } err := e.rpManager.Run() if err != nil { - return err + return fmt.Errorf("run rosenpass manager: %w", err) } } initialRoutes, dnsServer, err := e.newDnsServer() if err != nil { e.close() - return err + return fmt.Errorf("create dns server: %w", err) } e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) + if err := e.routeManager.Init(); err != nil { + e.close() + return fmt.Errorf("init route manager: %w", err) + } e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() if err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) e.close() - return err + return fmt.Errorf("create wg interface: %w", err) } e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) @@ -278,7 +282,7 @@ func (e *Engine) Start() error { err = e.routeManager.EnableServerRouter(e.firewall) if err != nil { e.close() - return err + return fmt.Errorf("enable server router: %w", err) } } @@ -286,7 +290,7 @@ func (e *Engine) Start() error { if err != nil { log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) e.close() - return err + return fmt.Errorf("up wg interface: %w", err) } if e.firewall != nil { @@ -296,7 +300,7 @@ func (e *Engine) Start() error { err = e.dnsServer.Initialize() if err != nil { e.close() - return err + return fmt.Errorf("initialize dns server: %w", err) } e.receiveSignalEvents() diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 1d8e6846d4e..84fd72e49c9 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -10,6 +10,9 @@ import ( "github.com/pion/stun/v2" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -27,7 +30,15 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() - client, err := stun.DialURI(uri, &stun.DialConfig{}) + net, err := stdnet.NewNet(nil) + if err != nil { + probeErr = fmt.Errorf("new net: %w", err) + return + } + + client, err := stun.DialURI(uri, &stun.DialConfig{ + Net: net, + }) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return @@ -85,14 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - conn, err = net.ListenPacket("udp", "") + conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - dialer := net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) + tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return @@ -109,12 +119,18 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() + net, err := stdnet.NewNet(nil) + if err != nil { + probeErr = fmt.Errorf("new net: %w", err) + return + } cfg := &turn.ClientConfig{ STUNServerAddr: turnServerAddr, TURNServerAddr: turnServerAddr, Conn: conn, Username: uri.Username, Password: uri.Password, + Net: net, } client, err := turn.NewClient(cfg) if err != nil { diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index f7ead582720..b2dff7f08cf 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,6 +41,7 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) + client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -72,6 +73,18 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } +// getBestRouteFromStatuses determines the most optimal route from the available routes +// within a clientNetwork, taking into account peer connection status, route metrics, and +// preference for non-relayed and direct connections. +// +// It follows these prioritization rules: +// * Connected peers: Only routes with connected peers are considered. +// * Metric: Routes with lower metrics (better) are prioritized. +// * Non-relayed: Routes without relays are preferred. +// * Direct connections: Routes with direct peer connections are favored. +// * Stability: In case of equal scores, the currently active route (if any) is maintained. +// +// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -158,7 +171,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return err + return fmt.Errorf("get peer state: %v", err) } delete(state.Routes, c.network.String()) @@ -172,7 +185,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -180,30 +193,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } - err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) - if err != nil { - return fmt.Errorf("couldn't remove route %s from system, err: %v", - c.network, err) + + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route: %v", err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { - - var err error - routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + + // If no route is chosen, remove the route from the peer and system if chosen == "" { - err = c.removeRouteFromPeerAndSystem() - if err != nil { - return err + if err := c.removeRouteFromPeerAndSystem(); err != nil { + return fmt.Errorf("remove route from peer and system: %v", err) } c.chosenRoute = nil @@ -211,6 +220,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } + // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -218,13 +228,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + // If a previous route exists, remove it from the peer + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route from peer: %v", err) } } else { - err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) - if err != nil { + // otherwise add the route to the system + if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -245,8 +255,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) - if err != nil { + if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -287,21 +296,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system: %v", err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("received a routes update with smaller serial number, ignoring it") + log.Warnf("Received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("received a new client network route update for %s", c.network) + log.Debugf("Received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -309,7 +318,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b624d8c34ce..6a0d954da09 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,6 +2,8 @@ package routemanager import ( "context" + "fmt" + "net/netip" "runtime" "sync" @@ -15,8 +17,14 @@ import ( "github.com/netbirdio/netbird/version" ) +var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + +// nolint:unused +var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + // Manager is a route manager interface type Manager interface { + Init() error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -56,6 +64,19 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } +// Init sets up the routing +func (m *DefaultManager) Init() error { + if err := cleanupRouting(); err != nil { + log.Warnf("Failed cleaning up routing: %v", err) + } + + if err := setupRouting(); err != nil { + return fmt.Errorf("setup routing: %w", err) + } + log.Info("Routing setup complete") + return nil +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -71,9 +92,15 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } + if err := cleanupRouting(); err != nil { + log.Errorf("Error cleaning up routing: %v", err) + } else { + log.Info("Routing cleanup complete") + } + m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -91,7 +118,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return err + return fmt.Errorf("update routes: %w", err) } } @@ -156,11 +183,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < minRangeBits { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", - version.NetbirdVersion(), newRoute.Network) + if !isPrefixSupported(newRoute.Network) { continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -178,3 +201,18 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } + +func isPrefixSupported(prefix netip.Prefix) bool { + if runtime.GOOS == "linux" { + return true + } + + // If prefix is too small, lets assume it is a possible default prefix which is not yet supported + // we skip this prefix management + if prefix.Bits() < minRangeBits { + log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", + version.NetbirdVersion(), prefix) + return false + } + return true +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2e5cf6649d8..9d92bf90d2f 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedLinux int }{ { name: "Should create 2 client networks", @@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedLinux: 1, }, { name: "Remove 1 Client Route", @@ -415,6 +417,8 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) + err = routeManager.Init() + require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -429,7 +433,11 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") + expectedWatchers := testCase.clientNetworkWatchersExpected + if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedLinux + } + require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index a1214cbb9ec..e812b3a85b6 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -16,6 +16,10 @@ type MockManager struct { StopFunc func() } +func (m *MockManager) Init() error { + return nil +} + // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 19236787772..00df735fb8a 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,6 +4,7 @@ package routemanager import ( "context" + "fmt" "net/netip" "sync" @@ -48,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -62,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -81,15 +82,22 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not removing from server network because context is done") + log.Infof("Not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.RemoveRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("remove routing rules: %w", err) } + delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -103,15 +111,22 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not adding to server network because context is done") + log.Infof("Not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.InsertRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("insert routing rules: %w", err) } + m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -129,9 +144,15 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) + if err != nil { + log.Errorf("Failed to convert route to router pair: %v", err) + continue + } + + err = m.firewall.RemoveRoutingRules(routerPair) if err != nil { - log.Warnf("failed to remove clean up route: %s", r.ID) + log.Errorf("Failed to remove cleanup route: %v", err) } state := m.statusRecorder.GetLocalPeerState() @@ -139,13 +160,15 @@ func (m *defaultServerRouter) cleanUp() { m.statusRecorder.UpdateLocalPeerState(state) } } - -func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { - parsed := netip.MustParsePrefix(source).Masked() +func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { + parsed, err := netip.ParsePrefix(source) + if err != nil { + return firewall.RouterPair{}, err + } return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - } + }, nil } diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 950a268434c..291826780af 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -4,10 +4,10 @@ import ( "net/netip" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index b2da8075cfa..173e7c0e847 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_bsd_nonios.go b/client/internal/routemanager/systemops_bsd_nonios.go new file mode 100644 index 00000000000..f60c7afc3a0 --- /dev/null +++ b/client/internal/routemanager/systemops_bsd_nonios.go @@ -0,0 +1,13 @@ +//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios + +package routemanager + +import "net/netip" + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + return genericAddToRouteTableIfNoExists(prefix, addr, intf) +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index aae0f8dc8f2..291826780af 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,15 +1,13 @@ -//go:build ios - package routemanager import ( "net/netip" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 0562826a55d..192509992c7 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,142 +3,298 @@ package routemanager import ( + "bufio" + "errors" + "fmt" "net" "net/netip" "os" "syscall" - "unsafe" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + + nbnet "github.com/netbirdio/netbird/util/net" ) -// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html -// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. -type routeInfoInMemory struct { - Family byte - DstLen byte - SrcLen byte - TOS byte +const ( + // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. + NetbirdVPNTableID = 0x1BD0 + // NetbirdVPNTableName is the name of the custom routing table used by Netbird. + NetbirdVPNTableName = "netbird" + + // rtTablesPath is the path to the file containing the routing table names. + rtTablesPath = "/etc/iproute2/rt_tables" - Table byte - Protocol byte - Scope byte - Type byte + // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. + ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +) - Flags uint32 +var ErrTableIDExists = errors.New("ID exists with different name") + +type ruleParams struct { + fwmark int + tableID int + family int + priority int + invert bool + suppressPrefix int + description string } -const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +func getSetupRules() []ruleParams { + return []ruleParams{ + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"}, + } +} -func addToRouteTable(prefix netip.Prefix, addr string) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return err +// setupRouting establishes the routing configuration for the VPN, including essential rules +// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. +// +// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over +// potential routes received and configured for the VPN. This rule is skipped for the default route and routes +// that are not in the main table. +// +// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. +// This table is where a default route or other specific routes received from the management server are configured, +// enabling VPN connectivity. +// +// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. +func setupRouting() (err error) { + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + defer func() { + if err != nil { + if cleanErr := cleanupRouting(); cleanErr != nil { + log.Errorf("Error cleaning up routing: %v", cleanErr) + } + } + }() + + rules := getSetupRules() + for _, rule := range rules { + if err := addRule(rule); err != nil { + return fmt.Errorf("%s: %w", rule.description, err) + } } - ip, _, err := net.ParseCIDR(addr + addrMask) - if err != nil { - return err + return nil +} + +// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. +// It systematically removes the three rules and any associated routing table entries to ensure a clean state. +// The function uses error aggregation to report any errors encountered during the cleanup process. +func cleanupRouting() error { + var result *multierror.Error + + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) + } + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) } + rules := getSetupRules() + for _, rule := range rules { + if err := removeAllRules(rule); err != nil { + result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) + } + } + + return result.ErrorOrNil() +} + +func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error { + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 + + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + return fmt.Errorf("add blackhole: %w", err) + } + } + if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + return fmt.Errorf("add route: %w", err) + } + return nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error { + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + return fmt.Errorf("remove unreachable route: %w", err) + } + } + if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil +} + +func getRoutesFromTable() ([]netip.Prefix, error) { + return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4) +} + +// addRoute adds a route to a specific routing table identified by tableID. +func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: family, } - err = netlink.RouteAdd(route) - if err != nil { - return err + if prefix != nil { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + route.Dst = ipNet + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("netlink add route: %w", err) } return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { +// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. +// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. +// tableID specifies the routing table to which the unreachable route will be added. +func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: ipFamily, + Dst: ipNet, } - ip, _, err := net.ParseCIDR(addr + addrMask) + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("netlink add unreachable route: %w", err) + } + + return nil +} + +func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: ipFamily, + Dst: ipNet, } - err = netlink.RouteDel(route) - if err != nil { - return err + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + return fmt.Errorf("netlink remove unreachable route: %w", err) } return nil + } -func getRoutesFromTable() ([]netip.Prefix, error) { - tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) +// removeRoute removes a route from a specific routing table identified by tableID. +func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return nil, err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } - msgs, err := syscall.ParseNetlinkMessage(tab) + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: family, + Dst: ipNet, + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + return fmt.Errorf("netlink remove route: %w", err) + } + + return nil +} + +func flushRoutes(tableID, family int) error { + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) if err != nil { - return nil, err + return fmt.Errorf("list routes from table %d: %w", tableID, err) } - var prefixList []netip.Prefix -loop: - for _, m := range msgs { - switch m.Header.Type { - case syscall.NLMSG_DONE: - break loop - case syscall.RTM_NEWROUTE: - rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) - msg := m - attrs, err := syscall.ParseNetlinkRouteAttr(&msg) - if err != nil { - return nil, err + + var result *multierror.Error + for i := range routes { + route := routes[i] + // unreachable default routes don't come back with Dst set + if route.Gw == nil && route.Src == nil && route.Dst == nil { + if family == netlink.FAMILY_V4 { + routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} + } else { + routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } - if rt.Family != syscall.AF_INET { - continue loop + } + if err := netlink.RouteDel(&routes[i]); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) + } + } + + return result.ErrorOrNil() +} + +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { + var prefixList []netip.Prefix + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) } - for _, attr := range attrs { - if attr.Attr.Type == syscall.RTA_DST { - addr, ok := netip.AddrFromSlice(attr.Value) - if !ok { - continue - } - mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) - cidr, _ := mask.Size() - routePrefix := netip.PrefixFrom(addr, cidr) - if routePrefix.IsValid() && routePrefix.Addr().Is4() { - prefixList = append(prefixList, routePrefix) - } - } + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) } } } + return prefixList, nil } func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return err + return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) } // check if it is already enabled @@ -147,5 +303,142 @@ func enableIPForwarding() error { return nil } - return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec + //nolint:gosec + if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { + return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) + } + return nil +} + +// entryExists checks if the specified ID or name already exists in the rt_tables file +// and verifies if existing names start with "netbird_". +func entryExists(file *os.File, id int) (bool, error) { + if _, err := file.Seek(0, 0); err != nil { + return false, fmt.Errorf("seek rt_tables: %w", err) + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + var existingID int + var existingName string + if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { + if existingID == id { + if existingName != NetbirdVPNTableName { + return true, ErrTableIDExists + } + return true, nil + } + } + } + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("scan rt_tables: %w", err) + } + return false, nil +} + +// addRoutingTableName adds human-readable names for custom routing tables. +func addRoutingTableName() error { + file, err := os.Open(rtTablesPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("open rt_tables: %w", err) + } + defer func() { + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables: %v", err) + } + }() + + exists, err := entryExists(file, NetbirdVPNTableID) + if err != nil { + return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) + } + if exists { + return nil + } + + // Reopen the file in append mode to add new entries + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables before appending: %v", err) + } + file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("open rt_tables for appending: %w", err) + } + + if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { + return fmt.Errorf("append entry to rt_tables: %w", err) + } + + return nil +} + +// addRule adds a routing rule to a specific routing table identified by tableID. +func addRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Priority = params.priority + rule.Invert = params.invert + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleAdd(rule); err != nil { + return fmt.Errorf("add routing rule: %w", err) + } + + return nil +} + +// removeRule removes a routing rule from a specific routing table identified by tableID. +func removeRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Invert = params.invert + rule.Priority = params.priority + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleDel(rule); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + + return nil +} + +func removeAllRules(params ruleParams) error { + for { + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) { + break + } + return err + } + } + return nil +} + +// addNextHop adds the gateway and device to the route. +func addNextHop(addr *string, intf *string, route *netlink.Route) error { + if addr != nil { + ip := net.ParseIP(*addr) + if ip == nil { + return fmt.Errorf("parsing address %s failed", *addr) + } + + route.Gw = ip + } + + if intf != nil { + link, err := netlink.LinkByName(*intf) + if err != nil { + return fmt.Errorf("set interface %s: %w", *intf, err) + } + route.LinkIndex = link.Attrs().Index + } + + return nil } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go new file mode 100644 index 00000000000..96e43d20f0b --- /dev/null +++ b/client/internal/routemanager/systemops_linux_test.go @@ -0,0 +1,469 @@ +//go:build !android + +package routemanager + +import ( + "errors" + "fmt" + "net" + "net/netip" + "os" + "strings" + "syscall" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +func TestEntryExists(t *testing.T) { + tempDir := t.TempDir() + tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) + + content := []string{ + "1000 reserved", + fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), + "9999 other_table", + } + require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) + + file, err := os.Open(tempFilePath) + require.NoError(t, err) + defer func() { + assert.NoError(t, file.Close()) + }() + + tests := []struct { + name string + id int + shouldExist bool + err error + }{ + { + name: "ExistsWithNetbirdPrefix", + id: 7120, + shouldExist: true, + err: nil, + }, + { + name: "ExistsWithDifferentName", + id: 1000, + shouldExist: true, + err: ErrTableIDExists, + }, + { + name: "DoesNotExist", + id: 1234, + shouldExist: false, + err: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + exists, err := entryExists(file, tc.id) + if tc.err != nil { + assert.ErrorIs(t, err, tc.err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.shouldExist, exists) + }) + } +} + +func TestRoutingWithTables(t *testing.T) { + testCases := []struct { + name string + destination string + captureInterface string + dialer *net.Dialer + packetExpectation PacketExpectation + }{ + { + name: "To external host without fwmark via vpn", + destination: "192.0.2.1:53", + captureInterface: "wgtest0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with fwmark via physical interface", + destination: "192.0.2.1:53", + captureInterface: "dummyext0", + dialer: nbnet.NewDialer(), + packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with fwmark via physical interface", + destination: "10.0.0.1:53", + captureInterface: "dummyint0", + dialer: nbnet.NewDialer(), + packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), + }, + { + name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence + destination: "10.0.0.1:53", + captureInterface: "dummyint0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), + }, + + { + name: "To unique vpn route with fwmark via physical interface", + destination: "172.16.0.1:53", + captureInterface: "dummyext0", + dialer: nbnet.NewDialer(), + packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53), + }, + { + name: "To unique vpn route without fwmark via vpn", + destination: "172.16.0.1:53", + captureInterface: "wgtest0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53), + }, + + { + name: "To more specific route without fwmark via vpn interface", + destination: "10.10.0.1:53", + captureInterface: "dummyint0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53), + }, + + { + name: "To more specific route (local) without fwmark via physical interface", + destination: "127.0.10.1:53", + captureInterface: "lo", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wgIface, _, _ := setupTestEnv(t) + + // default route exists in main table and vpn table + err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 10.0.0.0/8 route exists in main table and vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 10.10.0.0/24 more specific route exists in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 127.0.10.0/24 more specific route exists in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // unique route in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.captureInterface, filter) + + sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.packetExpectation) + }) + } +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } + +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy { + t.Helper() + + dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} + err := netlink.LinkDel(dummy) + if err != nil && !errors.Is(err, syscall.EINVAL) { + t.Logf("Failed to delete dummy interface: %v", err) + } + + err = netlink.LinkAdd(dummy) + require.NoError(t, err) + + err = netlink.LinkSetUp(dummy) + require.NoError(t, err) + + if ipAddressCIDR != "" { + addr, err := netlink.ParseAddr(ipAddressCIDR) + require.NoError(t, err) + err = netlink.AddrAdd(dummy, addr) + require.NoError(t, err) + } + + return dummy +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { + t.Helper() + + _, dstIPNet, err := net.ParseCIDR(dstCIDR) + require.NoError(t, err) + + if dstIPNet.String() == "0.0.0.0/0" { + gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + // Handle existing routes with metric 0 + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + if err == nil { + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + } else if !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } + } + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + err = netlink.RouteDel(route) + if err != nil && !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } + + err = netlink.RouteAdd(route) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } +} + +// fetchOriginalGateway returns the original gateway IP address and the interface index. +func fetchOriginalGateway(family int) (net.IP, int, error) { + routes, err := netlink.RouteList(nil, family) + if err != nil { + return nil, 0, err + } + + for _, route := range routes { + if route.Dst == nil { + return route.Gw, route.LinkIndex, nil + } + } + + return nil, 0, fmt.Errorf("default route not found") +} + +func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index) + + otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index) + + t.Cleanup(func() { + err := netlink.LinkDel(defaultDummy) + assert.NoError(t, err) + err = netlink.LinkDel(otherDummy) + assert.NoError(t, err) + }) + + return defaultDummy.Name, otherDummy.Name +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet(nil) + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { + t.Helper() + + defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + err := setupRouting() + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + return wgIface, defaultDummy, otherDummy +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + dialer.LocalAddr = localUDPAddr + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go index 11247c7dccd..65f670ace17 100644 --- a/client/internal/routemanager/systemops_nonandroid.go +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -1,11 +1,15 @@ -//go:build !android && !ios +//go:build !android +//nolint:unused package routemanager import ( + "errors" "fmt" "net" "net/netip" + "os/exec" + "runtime" "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" @@ -13,41 +17,16 @@ import ( var errRouteNotFound = fmt.Errorf("route not found") -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return err - } - if ok { - log.Warnf("skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return err - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, addr) -} - -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil && err != errRouteNotFound { - return err +func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + defaultGateway, err := getExistingRIBRouteGateway(defaultv4) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) } addr := netip.MustParseAddr(defaultGateway.String()) if !prefix.Contains(addr) { - log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) return nil } @@ -59,56 +38,79 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { } if ok { - log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) return nil } gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && err != errRouteNotFound { + if err != nil && !errors.Is(err, errRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } - log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop.String()) + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "") } -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() +func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + ok, err := existsInRouteTable(prefix) if err != nil { - return false, err + return fmt.Errorf("exists in route table: %w", err) } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil } - return false, nil -} -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() + ok, err = isSubRange(prefix) if err != nil { - return false, err + return fmt.Errorf("sub range: %w", err) } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil + + if ok { + err := genericAddRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) } } - return false, nil + + return genericAddToRouteTable(prefix, addr, intf) +} + +func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + return genericRemoveFromRouteTable(prefix, addr, intf) +} + +func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return fmt.Errorf("add route: %w", err) + } + log.Debugf(string(out)) + return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { - return removeFromRouteTable(prefix, addr) +func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error { + args := []string{"delete", prefix.String()} + if runtime.GOOS == "darwin" { + args = append(args, addr) + } + cmd := exec.Command("route", args...) + out, err := cmd.Output() + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + log.Debugf(string(out)) + return nil } func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { r, err := netroute.New() if err != nil { - return nil, err + return nil, fmt.Errorf("new netroute: %w", err) } _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) if err != nil { - log.Errorf("getting routes returned an error: %v", err) + log.Errorf("Getting routes returned an error: %v", err) return nil, errRouteNotFound } @@ -118,3 +120,29 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { return gateway, nil } + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index 6f32d9634bc..aae5e5faa16 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -8,17 +8,63 @@ import ( "net" "net/netip" "os" + "os/exec" + "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + + if runtime.GOOS == "linux" { + outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String()) + require.NoError(t, err, "getOutgoingInterfaceLinux should not return error") + if invert { + require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface") + } else { + require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface") + } + return + } + + prefixGateway, err := getExistingRIBRouteGateway(prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} + +func getOutgoingInterfaceLinux(destination string) (string, error) { + cmd := exec.Command("ip", "route", "get", destination) + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("executing ip route get: %w", err) + } + + return parseOutgoingInterface(string(output)), nil +} + +func parseOutgoingInterface(routeGetOutput string) string { + fields := strings.Fields(routeGetOutput) + for i, field := range fields { + if field == "dev" && i+1 < len(fields) { + return fields[i+1] + } + } + return "" +} + func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -54,23 +100,26 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, setupRouting()) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + assertWGOutInterface(t, testCase.prefix, wgInterface, false) } else { - require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") + assertWGOutInterface(t, testCase.prefix, wgInterface, true) } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) require.NoError(t, err, "getExistingRIBRouteGateway should not return err") internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) @@ -189,16 +238,21 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + require.NoError(t, setupRouting()) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + MockAddr := wgInterface.Address().IP.String() // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) + err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) + err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -208,7 +262,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) + err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -217,72 +271,12 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - if !strings.Contains(buf.String(), "because it already exists") { + + // Linux uses a separate routing table, so the route can exist in both tables. + // The main routing table takes precedence over the wireguard routing table. + if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { require.False(t, ok, "route should not exist") } }) } } - -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { - addressPrefixes = append(addressPrefixes, p.Masked()) - } - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - -func TestIsSubRange(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var subRangeAddressPrefixes []netip.Prefix - var nonSubRangeAddressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { - p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) - subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) - nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) - } - } - - for _, prefix := range subRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if !isSubRangePrefix { - t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) - } - } - - for _, prefix := range nonSubRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if isSubRangePrefix { - t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) - } - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 47bd60eb02b..d793f0fbde0 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,41 +1,22 @@ -//go:build !linux -// +build !linux +//go:build !linux || android package routemanager import ( - "net/netip" - "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func addToRouteTable(prefix netip.Prefix, addr string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) +func setupRouting() error { return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) +func cleanupRouting() error { return nil } func enableIPForwarding() error { - log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go new file mode 100644 index 00000000000..afaf5ba7724 --- /dev/null +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -0,0 +1,80 @@ +//go:build !linux || android + +package routemanager + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsSubRange(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var subRangeAddressPrefixes []netip.Prefix + var nonSubRangeAddressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { + p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) + subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) + nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) + } + } + + for _, prefix := range subRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if !isSubRangePrefix { + t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) + } + } + + for _, prefix := range nonSubRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if isSubRangePrefix { + t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) + } + } +} + +func TestExistsInRouteTable(t *testing.T) { + require.NoError(t, setupRouting()) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is4() { + addressPrefixes = append(addressPrefixes, p.Masked()) + } + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 309c184b9ca..c009ce66b9d 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,12 +1,13 @@ //go:build windows -// +build windows package routemanager import ( + "fmt" "net" "net/netip" + log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" ) @@ -21,17 +22,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) { err := wmi.Query(query, &routes) if err != nil { - return nil, err + return nil, fmt.Errorf("get routes: %w", err) } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { + log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { + log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -44,3 +47,11 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + return genericAddToRouteTableIfNoExists(prefix, addr, intf) +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go new file mode 100644 index 00000000000..e80adb42b20 --- /dev/null +++ b/client/internal/stdnet/dialer.go @@ -0,0 +1,24 @@ +package stdnet + +import ( + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// Dial connects to the address on the named network. +func (n *Net) Dial(network, address string) (net.Conn, error) { + return nbnet.NewDialer().Dial(network, address) +} + +// DialUDP connects to the address on the named UDP network. +func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.DialUDP(network, laddr, raddr) +} + +// DialTCP connects to the address on the named TCP network. +func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + return nbnet.DialTCP(network, laddr, raddr) +} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go new file mode 100644 index 00000000000..9ce0a555610 --- /dev/null +++ b/client/internal/stdnet/listener.go @@ -0,0 +1,20 @@ +package stdnet + +import ( + "context" + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// ListenPacket listens for incoming packets on the given network and address. +func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { + return nbnet.NewListener().ListenPacket(context.Background(), network, address) +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.ListenUDP(network, locAddr) +} diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/portlookup.go index 6f3d33487ea..6ede4b83f1d 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/portlookup.go @@ -1,8 +1,10 @@ package wgproxy import ( + "context" "fmt" - "net" + + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -23,7 +25,7 @@ func (pl portLookup) searchFreePort() (int, error) { } func (pl portLookup) tryToBind(port int) error { - l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) + l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port)) if err != nil { return err } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 6ca19c9737e..b91cd7b439d 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -66,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - p.conn, err = net.ListenUDP("udp", &addr) + p.conn, err = nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -208,20 +209,41 @@ generatePort: } func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) if err != nil { - return nil, err + return nil, fmt.Errorf("creating raw socket failed: %w", err) } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) if err != nil { - return nil, err + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) } + + // Bind the socket to the "lo" interface. err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") if err != nil { - return nil, err + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) } - return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))) + return packetConn, nil } func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index b692ea70842..17ebfbc499b 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,6 +6,8 @@ import ( "net" log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -33,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index ce3da619e5a..67ec9c42ee0 100644 --- a/go.mod +++ b/go.mod @@ -47,8 +47,9 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 - github.com/hashicorp/go-multierror v1.1.0 + github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 github.com/libp2p/go-netroute v0.2.0 @@ -123,7 +124,6 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect - github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index e304e3191dd..c36b8aff31d 100644 --- a/go.sum +++ b/go.sum @@ -291,8 +291,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f2 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= -github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= diff --git a/iface/address.go b/iface/address.go index 5ff4fbc0645..2920d009fa1 100644 --- a/iface/address.go +++ b/iface/address.go @@ -23,6 +23,24 @@ func parseWGAddress(address string) (WGAddress, error) { }, nil } +// Masked returns the WGAddress with the IP address part masked according to its network mask. +func (addr WGAddress) Masked() WGAddress { + ip := addr.IP.To4() + if ip == nil { + ip = addr.IP.To16() + } + + maskedIP := make(net.IP, len(ip)) + for i := range ip { + maskedIP[i] = ip[i] & addr.Network.Mask[i] + } + + return WGAddress{ + IP: maskedIP, + Network: addr.Network, + } +} + func (addr WGAddress) String() string { maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 36fd13cc262..9fe987cee21 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,6 +10,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := 0 + fwmark := nbnet.NetbirdFwmark config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 200bfbc9614..24dfadf1408 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,6 +13,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -37,7 +39,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := 0 + fwmark := getFwmark() config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -345,3 +347,10 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } + +func getFwmark() int { + if runtime.GOOS == "linux" { + return nbnet.NetbirdFwmark + } + return 0 +} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0234f866cb8..0b1804906c2 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -57,6 +58,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 656fdc8ca24..74ac6c163ad 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,6 +21,8 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" + + nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -55,8 +57,7 @@ var writeSerializerOptions = gopacket.SerializeOptions{ } // Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines -func Listen(port int, filter BPFFilter) (net.PacketConn, error) { - var err error +func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { ctx, cancel := context.WithCancel(context.Background()) rawSock := &SharedSocket{ ctx: ctx, @@ -65,37 +66,51 @@ func Listen(port int, filter BPFFilter) (net.PacketConn, error) { packetDemux: make(chan rcvdPacket), } + defer func() { + if err != nil { + if closeErr := rawSock.Close(); closeErr != nil { + log.Errorf("Failed to close raw socket: %v", closeErr) + } + } + }() + rawSock.router, err = netroute.New() if err != nil { - return nil, fmt.Errorf("failed to create raw socket router: %v", err) + return nil, fmt.Errorf("failed to create raw socket router: %w", err) } rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { - return nil, fmt.Errorf("failed to create ipv4 raw socket: %v", err) + return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } - rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) - if err != nil { - log.Errorf("failed to create ipv6 raw socket: %v", err) + if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + } + + var sockErr error + rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) + if sockErr != nil { + log.Errorf("Failed to create ipv6 raw socket: %v", err) + } else { + if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) if err != nil { - _ = rawSock.Close() - return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err) + return nil, fmt.Errorf("getBPFInstructions failed with: %w", err) } err = rawSock.conn4.SetBPF(ipv4Instructions) if err != nil { - _ = rawSock.Close() - return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err) + return nil, fmt.Errorf("socket4.SetBPF failed with: %w", err) } if rawSock.conn6 != nil { err = rawSock.conn6.SetBPF(ipv6Instructions) if err != nil { - _ = rawSock.Close() - return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err) + return nil, fmt.Errorf("socket6.SetBPF failed with: %w", err) } } @@ -121,7 +136,7 @@ func (s *SharedSocket) updateRouter() { case <-ticker.C: router, err := netroute.New() if err != nil { - log.Errorf("failed to create and update packet router for stunListener: %s", err) + log.Errorf("Failed to create and update packet router for stunListener: %s", err) continue } s.routerMux.Lock() @@ -144,7 +159,7 @@ func (s *SharedSocket) LocalAddr() net.Addr { func (s *SharedSocket) SetDeadline(t time.Time) error { err := s.conn4.SetDeadline(t) if err != nil { - return fmt.Errorf("s.conn4.SetDeadline error: %s", err) + return fmt.Errorf("s.conn4.SetDeadline error: %w", err) } if s.conn6 == nil { return nil @@ -152,7 +167,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error { err = s.conn6.SetDeadline(t) if err != nil { - return fmt.Errorf("s.conn6.SetDeadline error: %s", err) + return fmt.Errorf("s.conn6.SetDeadline error: %w", err) } return nil } @@ -161,7 +176,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error { func (s *SharedSocket) SetReadDeadline(t time.Time) error { err := s.conn4.SetReadDeadline(t) if err != nil { - return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err) + return fmt.Errorf("s.conn4.SetReadDeadline error: %w", err) } if s.conn6 == nil { return nil @@ -169,7 +184,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error { err = s.conn6.SetReadDeadline(t) if err != nil { - return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err) + return fmt.Errorf("s.conn6.SetReadDeadline error: %w", err) } return nil } @@ -178,7 +193,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error { func (s *SharedSocket) SetWriteDeadline(t time.Time) error { err := s.conn4.SetWriteDeadline(t) if err != nil { - return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err) + return fmt.Errorf("s.conn4.SetWriteDeadline error: %w", err) } if s.conn6 == nil { return nil @@ -186,7 +201,7 @@ func (s *SharedSocket) SetWriteDeadline(t time.Time) error { err = s.conn6.SetWriteDeadline(t) if err != nil { - return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err) + return fmt.Errorf("s.conn6.SetWriteDeadline error: %w", err) } return nil } @@ -282,7 +297,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { _, _, src, err := s.router.Route(rUDPAddr.IP) if err != nil { - return 0, fmt.Errorf("got an error while checking route, err: %s", err) + return 0, fmt.Errorf("got an error while checking route, err: %w", err) } rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) @@ -292,7 +307,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { } if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil { - return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err) + return -1, fmt.Errorf("failed serialize rcvdPacket: %w", err) } bufser := buffer.Bytes() diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7531608c3bb..7c4535e2896 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -76,6 +77,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer_generic.go b/util/grpc/dialer_generic.go new file mode 100644 index 00000000000..1c2285b14bf --- /dev/null +++ b/util/grpc/dialer_generic.go @@ -0,0 +1,9 @@ +//go:build !linux || android + +package grpc + +import "google.golang.org/grpc" + +func WithCustomDialer() grpc.DialOption { + return grpc.EmptyDialOption{} +} diff --git a/util/grpc/dialer_linux.go b/util/grpc/dialer_linux.go new file mode 100644 index 00000000000..b29ee4b2936 --- /dev/null +++ b/util/grpc/dialer_linux.go @@ -0,0 +1,18 @@ +//go:build !android + +package grpc + +import ( + "context" + "net" + + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func WithCustomDialer() grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return nbnet.NewDialer().DialContext(ctx, "tcp", addr) + }) +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go new file mode 100644 index 00000000000..a3c3ad67c74 --- /dev/null +++ b/util/net/dialer_generic.go @@ -0,0 +1,19 @@ +//go:build !linux || android + +package net + +import ( + "net" +) + +func NewDialer() *net.Dialer { + return &net.Dialer{} +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go new file mode 100644 index 00000000000..d559490c517 --- /dev/null +++ b/util/net/dialer_linux.go @@ -0,0 +1,60 @@ +//go:build !android + +package net + +import ( + "context" + "fmt" + "net" + "syscall" + + log "github.com/sirupsen/logrus" +) + +func NewDialer() *net.Dialer { + return &net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return SetRawSocketMark(c) + }, + } +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.DialContext(context.Background(), network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type") + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.DialContext(context.Background(), network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type") + } + + return tcpConn, nil +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go new file mode 100644 index 00000000000..241c744e528 --- /dev/null +++ b/util/net/listener_generic.go @@ -0,0 +1,13 @@ +//go:build !linux || android + +package net + +import "net" + +func NewListener() *net.ListenConfig { + return &net.ListenConfig{} +} + +func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, locAddr) +} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go new file mode 100644 index 00000000000..7b9bda97c7d --- /dev/null +++ b/util/net/listener_linux.go @@ -0,0 +1,30 @@ +//go:build !android + +package net + +import ( + "context" + "fmt" + "net" + "syscall" +) + +func NewListener() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + return SetRawSocketMark(c) + }, + } +} + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err) + } + udpConn, ok := pc.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("packetConn is not a *net.UDPConn") + } + return udpConn, nil +} diff --git a/util/net/net.go b/util/net/net.go new file mode 100644 index 00000000000..5714e52294e --- /dev/null +++ b/util/net/net.go @@ -0,0 +1,6 @@ +package net + +const ( + // NetbirdFwmark is the fwmark value used by Netbird via wireguard + NetbirdFwmark = 0x1BD00 +) diff --git a/util/net/net_linux.go b/util/net/net_linux.go new file mode 100644 index 00000000000..82141750029 --- /dev/null +++ b/util/net/net_linux.go @@ -0,0 +1,35 @@ +//go:build !android + +package net + +import ( + "fmt" + "syscall" +) + +// SetSocketMark sets the SO_MARK option on the given socket connection +func SetSocketMark(conn syscall.Conn) error { + sysconn, err := conn.SyscallConn() + if err != nil { + return fmt.Errorf("get raw conn: %w", err) + } + + return SetRawSocketMark(sysconn) +} + +func SetRawSocketMark(conn syscall.RawConn) error { + var setErr error + + err := conn.Control(func(fd uintptr) { + setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + }) + if err != nil { + return fmt.Errorf("control: %w", err) + } + + if setErr != nil { + return fmt.Errorf("set SO_MARK: %w", setErr) + } + + return nil +} From af50eb350f56827461171a35079b9197f83437da Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 25 Mar 2024 14:25:26 +0100 Subject: [PATCH 15/89] Change log level for JWT override message of single account mode (#1747) --- management/server/account.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/account.go b/management/server/account.go index 8b326d93a60..8588cf343f2 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1588,7 +1588,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain claims.DomainCategory = PrivateCategory - log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + log.Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } newAcc, err := am.getAccountWithAuthorizationClaims(claims) From 68b377a28caf09e77613fa1b1a4d23bc6b3c7f03 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 26 Mar 2024 15:33:01 +0100 Subject: [PATCH 16/89] Collect chassis.serial (#1748) --- client/system/info_linux.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/system/info_linux.go b/client/system/info_linux.go index ca3be9d1c5b..652bc111518 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -120,5 +120,5 @@ func _getReleaseInfo() string { func sysInfo() (serialNumber string, productName string, manufacturer string) { var si sysinfo.SysInfo si.GetSysInfo() - return si.Product.Version, si.Product.Name, si.Product.Vendor + return si.Chassis.Serial, si.Product.Name, si.Product.Vendor } From ea2d060f93275695423bed1ab88cfd89e62d8c63 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 27 Mar 2024 16:11:45 +0100 Subject: [PATCH 17/89] Add limited dashboard view (#1738) --- management/server/account.go | 13 +- management/server/group.go | 38 ++++- management/server/http/accounts_handler.go | 2 + .../server/http/accounts_handler_test.go | 9 +- management/server/http/api/openapi.yml | 17 +- management/server/http/api/types.gen.go | 50 +++++- management/server/http/groups_handler.go | 22 ++- management/server/http/groups_handler_test.go | 8 +- management/server/http/users_handler.go | 3 + management/server/http/users_handler_test.go | 2 +- management/server/mock_server/account_mock.go | 27 ++- management/server/peer.go | 9 + management/server/peer_test.go | 155 +++++++++++++++++- management/server/setupkey.go | 8 + management/server/setupkey_test.go | 31 ++++ management/server/user.go | 36 +++- management/server/user_test.go | 77 +++++++++ 17 files changed, 466 insertions(+), 41 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 8588cf343f2..d9030007db7 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -85,7 +85,8 @@ type AccountManager interface { GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID string) (*Group, error) + GetGroup(accountId, groupID, userID string) (*Group, error) + GetAllGroups(accountID, userID string) ([]*Group, error) GetGroupByName(groupName, accountID string) (*Group, error) SaveGroup(accountID, userID string, group *Group) error DeleteGroup(accountId, userId, groupID string) error @@ -162,6 +163,9 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements + RegularUsersViewBlocked bool + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer GroupsPropagationEnabled bool @@ -188,6 +192,7 @@ func (s *Settings) Copy() *Settings { JWTGroupsClaimName: s.JWTGroupsClaimName, GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -226,6 +231,10 @@ type Account struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + type UserInfo struct { ID string `json:"id"` Email string `json:"email"` @@ -239,6 +248,7 @@ type UserInfo struct { LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` IntegrationReference IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes @@ -1885,6 +1895,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { PeerLoginExpirationEnabled: true, PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, }, } diff --git a/management/server/group.go b/management/server/group.go index 43d48e6227f..59f05a354ac 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -63,7 +63,7 @@ func (g *Group) Copy() *Group { } // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -72,6 +72,15 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return nil, err } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + group, ok := account.Groups[groupID] if ok { return group, nil @@ -80,6 +89,33 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) } +// GetAllGroups returns all groups in an account +func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*Group, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + + groups := make([]*Group, 0, len(account.Groups)) + for _, item := range account.Groups { + groups = append(groups, item) + } + + return groups, nil +} + // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { unlock := am.Store.AcquireAccountLock(accountID) diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 71088cfaf3f..d3c9954d364 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -76,6 +76,7 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings := &server.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), + RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, } if req.Settings.Extra != nil { @@ -143,6 +144,7 @@ func toAccountResponse(account *server.Account) *api.Account { JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, + RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, } if account.Settings.Extra != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index fd2c4bfcd33..9d174d0be90 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -69,6 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { Settings: &server.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, + RegularUsersViewBlocked: true, }, }, adminUser) @@ -96,6 +97,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, }, expectedArray: true, expectedID: accountID, @@ -114,6 +116,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, }, expectedArray: false, expectedID: accountID, @@ -123,7 +126,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -132,6 +135,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr("roles"), JwtGroupsEnabled: br(true), JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, }, expectedArray: false, expectedID: accountID, @@ -141,7 +145,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 554400, @@ -150,6 +154,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr("groups"), JwtGroupsEnabled: br(true), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 7ec2310afe6..2810893962b 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,10 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + regular_users_view_blocked: + description: Allows blocking regular users from viewing parts of the system. + type: boolean + example: true groups_propagation_enabled: description: Allows propagate the new user auto groups to peers that belongs to the user type: boolean @@ -77,6 +81,7 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - regular_users_view_blocked AccountExtraSettings: type: object properties: @@ -144,6 +149,8 @@ components: description: How user was issued by API or Integration type: string example: api + permissions: + $ref: '#/components/schemas/UserPermissions' required: - id - email @@ -152,6 +159,14 @@ components: - auto_groups - status - is_blocked + UserPermissions: + type: object + properties: + dashboard_view: + description: User's permission to view the dashboard + type: string + enum: [ "limited", "blocked", "full" ] + example: limited UserRequest: type: object properties: @@ -589,8 +604,6 @@ components: type: string enum: ["api", "integration", "jwt"] example: api - type: string - example: api required: - id - name diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index a4c492bb870..78cd83a2769 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -69,6 +69,20 @@ const ( GeoLocationCheckActionDeny GeoLocationCheckAction = "deny" ) +// Defines values for GroupIssued. +const ( + GroupIssuedApi GroupIssued = "api" + GroupIssuedIntegration GroupIssued = "integration" + GroupIssuedJwt GroupIssued = "jwt" +) + +// Defines values for GroupMinimumIssued. +const ( + GroupMinimumIssuedApi GroupMinimumIssued = "api" + GroupMinimumIssuedIntegration GroupMinimumIssued = "integration" + GroupMinimumIssuedJwt GroupMinimumIssued = "jwt" +) + // Defines values for NameserverNsType. const ( NameserverNsTypeUdp NameserverNsType = "udp" @@ -129,6 +143,13 @@ const ( UserStatusInvited UserStatus = "invited" ) +// Defines values for UserPermissionsDashboardView. +const ( + UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked" + UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full" + UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited" +) + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud @@ -186,6 +207,9 @@ type AccountSettings struct { // PeerLoginExpirationEnabled Enables or disables peer login expiration globally. After peer's login has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). PeerLoginExpirationEnabled bool `json:"peer_login_expiration_enabled"` + + // RegularUsersViewBlocked Allows blocking regular users from viewing parts of the system. + RegularUsersViewBlocked bool `json:"regular_users_view_blocked"` } // Checks List of objects that perform the actual checks @@ -283,8 +307,8 @@ type Group struct { // Id Group ID Id string `json:"id"` - // Issued How group was issued by API or from JWT token - Issued *string `json:"issued,omitempty"` + // Issued How the group was issued (api, integration, jwt) + Issued *GroupIssued `json:"issued,omitempty"` // Name Group Name identifier Name string `json:"name"` @@ -296,13 +320,16 @@ type Group struct { PeersCount int `json:"peers_count"` } +// GroupIssued How the group was issued (api, integration, jwt) +type GroupIssued string + // GroupMinimum defines model for GroupMinimum. type GroupMinimum struct { // Id Group ID Id string `json:"id"` - // Issued How group was issued by API or from JWT token - Issued *string `json:"issued,omitempty"` + // Issued How the group was issued (api, integration, jwt) + Issued *GroupMinimumIssued `json:"issued,omitempty"` // Name Group Name identifier Name string `json:"name"` @@ -311,6 +338,9 @@ type GroupMinimum struct { PeersCount int `json:"peers_count"` } +// GroupMinimumIssued How the group was issued (api, integration, jwt) +type GroupMinimumIssued string + // GroupRequest defines model for GroupRequest. type GroupRequest struct { // Name Group name identifier @@ -1072,7 +1102,8 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` + Name string `json:"name"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` @@ -1102,6 +1133,15 @@ type UserCreateRequest struct { Role string `json:"role"` } +// UserPermissions defines model for UserPermissions. +type UserPermissions struct { + // DashboardView User's permission to view the dashboard + DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"` +} + +// UserPermissionsDashboardView User's permission to view the dashboard +type UserPermissionsDashboardView string + // UserRequest defines model for UserRequest. type UserRequest struct { // AutoGroups Group IDs to auto-assign to peers registered by this user diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index b37f4fd2f46..56d06595fe6 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -35,19 +35,25 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - var groups []*api.Group - for _, g := range account.Groups { - groups = append(groups, toGroupResponse(account, g)) + groups, err := h.accountManager.GetAllGroups(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return } - util.WriteJSONObject(w, groups) + groupsResponse := make([]*api.Group, 0, len(groups)) + for _, group := range groups { + groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + } + + util.WriteJSONObject(w, groupsResponse) } // UpdateGroup handles update to a group identified by a given ID @@ -207,7 +213,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) return @@ -221,7 +227,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := h.accountManager.GetGroup(account.Id, groupID) + group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id) if err != nil { util.WriteError(err, w) return @@ -239,7 +245,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group { gr := api.Group{ Id: group.ID, Name: group.Name, - Issued: &group.Issued, + Issued: (*api.GroupIssued)(&group.Issued), } for _, pid := range group.Peers { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 5b47b120861..303efc9d704 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -37,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle } return nil }, - GetGroupFunc: func(_, groupID string) (*server.Group, error) { + GetGroupFunc: func(_, groupID, _ string) (*server.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } @@ -187,7 +187,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-was-set", Name: "Default POSTed Group", - Issued: &groupIssuedAPI, + Issued: (*api.GroupIssued)(&groupIssuedAPI), }, }, { @@ -209,7 +209,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-existed", Name: "Default POSTed Group", - Issued: &groupIssuedAPI, + Issued: (*api.GroupIssued)(&groupIssuedAPI), }, }, { @@ -240,7 +240,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-jwt-group", Name: "changed", - Issued: &groupIssuedJWT, + Issued: (*api.GroupIssued)(&groupIssuedJWT), }, }, } diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 5d92b65e5d8..ed8a3f5438c 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -288,5 +288,8 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { IsBlocked: user.IsBlocked, LastLogin: &user.LastLogin, Issued: &user.Issued, + Permissions: &api.UserPermissions{ + DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView), + }, } } diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index ff886ca9fda..91f19d8d89e 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -105,7 +105,7 @@ func initUsersTestData() *UsersHandler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil) + info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index f518372ed95..9463498cf7b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,8 @@ type MockAccountManager struct { GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID string) (*server.Group, error) + GetGroupFunc func(accountID, groupID, userID string) (*server.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*server.Group, error) GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) SaveGroupFunc func(accountID, userID string, group *server.Group) error DeleteGroupFunc func(accountID, userId, groupID string) error @@ -92,6 +93,22 @@ type MockAccountManager struct { GetIdpManagerFunc func() idp.Manager } +// GetGroup mock implementation of GetGroup from server.AccountManager interface +func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*server.Group, error) { + if am.GetGroupFunc != nil { + return am.GetGroupFunc(accountId, groupID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") +} + +// GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface +func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*server.Group, error) { + if am.GetAllGroupsFunc != nil { + return am.GetAllGroupsFunc(accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented") +} + // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { @@ -243,14 +260,6 @@ func (am *MockAccountManager) AddPeer( return nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } -// GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountID, groupID string) (*server.Group, error) { - if am.GetGroupFunc != nil { - return am.GetGroupFunc(accountID, groupID) - } - return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") -} - // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { if am.GetGroupFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index 53b86e9b35a..7de1b654254 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -54,6 +54,11 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return peers, nil + } + for _, peer := range account.Peers { if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin @@ -738,6 +743,10 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) + } + peer := account.GetPeer(peerID) if peer == nil { return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index ee84ea47dab..7f6d440bb10 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -4,9 +4,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/rs/xid" + "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -392,6 +391,8 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { Id: someUser, Role: UserRoleUser, } + account.Settings.RegularUsersViewBlocked = false + err = manager.Store.SaveAccount(account) if err != nil { t.Fatal(err) @@ -480,3 +481,153 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } assert.NotNil(t, peer) } + +func TestDefaultAccountManager_GetPeers(t *testing.T) { + testCases := []struct { + name string + role UserRole + limitedViewSettings bool + isServiceUser bool + expectedPeerCount int + }{ + { + name: "Regular user, no limited view settings, not a service user", + role: UserRoleUser, + limitedViewSettings: false, + isServiceUser: false, + expectedPeerCount: 1, + }, + { + name: "Service user, no limited view settings", + role: UserRoleUser, + limitedViewSettings: false, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Regular user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 0, + }, + { + name: "Service user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Admin, no limited view settings, not a service user", + role: UserRoleAdmin, + limitedViewSettings: false, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Admin service user, no limited view settings", + role: UserRoleAdmin, + limitedViewSettings: false, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Admin, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Admin Service user, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Owner, no limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Owner, limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + // account with an admin and a regular user + accountID := "test_account" + adminUser := "account_creator" + someUser := "some_user" + account := newAccountWithId(accountID, adminUser, "") + account.Users[someUser] = &User{ + Id: someUser, + Role: testCase.role, + IsServiceUser: testCase.isServiceUser, + } + account.Policies = []*Policy{} + account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + + err = manager.Store.SaveAccount(account) + if err != nil { + t.Fatal(err) + return + } + + peerKey1, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + peerKey2, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ + Key: peerKey1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, + }) + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ + Key: peerKey2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, + }) + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + peers, err := manager.GetPeers(accountID, someUser) + if err != nil { + t.Fatal(err) + return + } + assert.NotNil(t, peers) + + assert.Len(t, peers, testCase.expectedPeerCount) + + }) + } + +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 97266552782..ff6fb320409 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -339,6 +339,10 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + } + keys := make([]*SetupKey, 0, len(account.SetupKeys)) for _, key := range account.SetupKeys { var k *SetupKey @@ -368,6 +372,10 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (* return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + } + var foundKey *SetupKey for _, key := range account.SetupKeys { if key.Id == keyID { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index c22df2094de..b714652f137 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -166,6 +166,37 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } +func TestGetSetupKeys(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(userID, "") + if err != nil { + t.Fatal(err) + } + + err = manager.SaveGroup(account.Id, userID, &Group{ + ID: "group_1", + Name: "group_name_1", + Peers: []string{}, + }) + if err != nil { + t.Fatal(err) + } + + err = manager.SaveGroup(account.Id, userID, &Group{ + ID: "group_2", + Name: "group_name_2", + Peers: []string{}, + }) + if err != nil { + t.Fatal(err) + } +} + func TestGenerateDefaultSetupKey(t *testing.T) { expectedName := "Default key" expectedRevoke := false diff --git a/management/server/user.go b/management/server/user.go index f1516139b63..15517db41b4 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -113,12 +113,20 @@ func (u *User) HasAdminPower() bool { } // ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { +func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups if autoGroups == nil { autoGroups = []string{} } + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + if userData == nil { return &UserInfo{ ID: u.Id, @@ -131,6 +139,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsBlocked: u.Blocked, LastLogin: u.LastLogin, Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, }, nil } if userData.ID != u.Id { @@ -153,6 +164,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsBlocked: u.Blocked, LastLogin: u.LastLogin, Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, }, nil } @@ -358,7 +372,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite am.StoreEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) - return newUser.ToUserInfo(idpUser) + return newUser.ToUserInfo(idpUser, account.Settings) } // GetUser looks up a user by provided authorization claims. @@ -905,9 +919,9 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string if err != nil { return nil, err } - return newUser.ToUserInfo(userData) + return newUser.ToUserInfo(userData, account.Settings) } - return newUser.ToUserInfo(nil) + return newUser.ToUserInfo(nil, account.Settings) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist @@ -998,7 +1012,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( // if user is not an admin then show only current user and do not show other users continue } - info, err := accountUser.ToUserInfo(nil) + info, err := accountUser.ToUserInfo(nil, account.Settings) if err != nil { return nil, err } @@ -1015,7 +1029,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( var info *UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser) + info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { return nil, err } @@ -1024,6 +1038,15 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( if localUser.IsServiceUser { name = localUser.ServiceUserName } + + dashboardViewPermissions := "full" + if !localUser.HasAdminPower() { + dashboardViewPermissions = "limited" + if account.Settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + info = &UserInfo{ ID: localUser.Id, Email: "", @@ -1033,6 +1056,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( Status: string(UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, + Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) diff --git a/management/server/user_test.go b/management/server/user_test.go index 50cd726ef20..e34aa406d2e 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -709,6 +709,83 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { assert.Equal(t, 2, regular) } +func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { + testCases := []struct { + name string + role UserRole + limitedViewSettings bool + expectedDashboardPermissions string + }{ + { + name: "Regular user, no limited view settings", + role: UserRoleUser, + limitedViewSettings: false, + expectedDashboardPermissions: "limited", + }, + { + name: "Admin user, no limited view settings", + role: UserRoleAdmin, + limitedViewSettings: false, + expectedDashboardPermissions: "full", + }, + { + name: "Owner, no limited view settings", + role: UserRoleOwner, + limitedViewSettings: false, + expectedDashboardPermissions: "full", + }, + { + name: "Regular user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + expectedDashboardPermissions: "blocked", + }, + { + name: "Admin user, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + expectedDashboardPermissions: "full", + }, + { + name: "Owner, limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + expectedDashboardPermissions: "full", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + delete(account.Users, mockUserID) + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + users, err := am.ListUsers(mockAccountID) + if err != nil { + t.Fatalf("Error when checking user role: %s", err) + } + + assert.Equal(t, 1, len(users)) + + userInfo, _ := users[0].ToUserInfo(nil, account.Settings) + assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) + }) + } + +} + func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") From 2d76b058fcee1bebc54745a91e2b447da2a2c9db Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 27 Mar 2024 18:48:48 +0100 Subject: [PATCH 18/89] Feature/peer validator (#1553) Follow up management-integrations changes move groups to separated packages to avoid circle dependencies save location information in Login action --- client/cmd/testutil.go | 4 +- client/internal/engine_test.go | 4 +- client/server/server_test.go | 4 +- go.mod | 4 +- go.sum | 7 +- management/client/client_test.go | 11 +- management/cmd/management.go | 8 +- management/server/account.go | 88 ++++++++---- management/server/account/account.go | 8 +- management/server/account_test.go | 99 ++++++++++---- management/server/dns_test.go | 7 +- management/server/ephemeral.go | 2 +- management/server/file_store.go | 3 +- management/server/file_store_test.go | 8 +- management/server/group.go | 77 +++-------- management/server/group/group.go | 46 +++++++ management/server/group_test.go | 35 ++--- management/server/grpcserver.go | 1 + management/server/http/api/openapi.yml | 1 + management/server/http/api/types.gen.go | 6 +- management/server/http/groups_handler.go | 20 +-- management/server/http/groups_handler_test.go | 27 ++-- management/server/http/handler.go | 1 - management/server/http/peers_handler.go | 125 ++++++++++++------ .../server/http/policies_handler_test.go | 3 +- .../server/http/setupkeys_handler_test.go | 9 +- management/server/http/util/util.go | 2 + management/server/integrated_validator.go | 80 +++++++++++ .../server/integrated_validator/interface.go | 19 +++ .../integration_reference.go | 23 ++++ management/server/management_proto_test.go | 5 +- management/server/management_test.go | 57 ++++++-- management/server/metrics/selfhosted_test.go | 5 +- management/server/mock_server/account_mock.go | 47 +++++-- management/server/nameserver.go | 3 +- management/server/nameserver_test.go | 7 +- management/server/peer.go | 119 ++++++++++++++--- management/server/peer_test.go | 5 +- management/server/policy.go | 19 +-- management/server/policy_test.go | 51 ++++--- management/server/route_test.go | 9 +- management/server/setupkey_test.go | 11 +- management/server/sqlite_store.go | 15 ++- management/server/user.go | 20 +-- management/server/user_test.go | 10 +- 45 files changed, 777 insertions(+), 338 deletions(-) create mode 100644 management/server/group/group.go create mode 100644 management/server/integrated_validator.go create mode 100644 management/server/integrated_validator/interface.go create mode 100644 management/server/integration_reference/integration_reference.go diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 2cfc934159e..2f92e1c03dc 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" mgmtProto "github.com/netbirdio/netbird/management/proto" @@ -78,7 +79,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + iv, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 952b3c90cfb..309b2e7c6f9 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -1050,7 +1051,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 7f8310c903b..4e4a091453f 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "github.com/netbirdio/management-integrations/integrations" "net" "testing" "time" @@ -114,7 +115,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 67ec9c42ee0..5566f85599b 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 + github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 @@ -59,8 +60,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index c36b8aff31d..6da405341d5 100644 --- a/go.sum +++ b/go.sum @@ -255,6 +255,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= @@ -382,10 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= diff --git a/management/client/client_test.go b/management/client/client_test.go index f30ae0cfd66..30f91c73b8a 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -3,6 +3,7 @@ package client import ( "context" "net" + "os" "path/filepath" "sync" "testing" @@ -15,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" mgmt "github.com/netbirdio/netbird/management/server" @@ -30,6 +32,12 @@ import ( const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" +func TestMain(m *testing.M) { + _ = util.InitLog("debug", "console") + code := m.Run() + os.Exit(code) +} + func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Helper() level, _ := log.ParseLevel("debug") @@ -60,7 +68,8 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index e8bcdc97d0a..23d9c195cd3 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -172,8 +173,12 @@ var ( log.Infof("geo location service has been initialized from %s", config.Datadir) } + integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore) + if err != nil { + return fmt.Errorf("failed to initialize integrated peer validator: %v", err) + } accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore, geo, userDeleteFromIDPEnabled) + dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -323,6 +328,7 @@ var ( SetupCloseHandler() <-stopCh + integratedPeerValidator.Stop() if geo != nil { _ = geo.Stop() } diff --git a/management/server/account.go b/management/server/account.go index d9030007db7..c145c1bd789 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -21,14 +21,15 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/management-integrations/additions" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integrated_validator" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -85,12 +86,12 @@ type AccountManager interface { GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID, userID string) (*Group, error) - GetAllGroups(accountID, userID string) ([]*Group, error) - GetGroupByName(groupName, accountID string) (*Group, error) - SaveGroup(accountID, userID string, group *Group) error + GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error) + GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error) + GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) + SaveGroup(accountID, userID string, group *nbgroup.Group) error DeleteGroup(accountId, userId, groupID string) error - ListGroups(accountId string) ([]*Group, error) + ListGroups(accountId string) ([]*nbgroup.Group, error) GroupAddPeer(accountId, groupID, peerID string) error GroupDeletePeer(accountId, groupID, peerID string) error GetPolicy(accountID, policyID, userID string) (*Policy, error) @@ -124,6 +125,9 @@ type AccountManager interface { DeletePostureChecks(accountID, postureChecksID, userID string) error ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager + UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error + GroupValidation(accountId string, groups []string) (bool, error) + GetValidatedPeers(account *Account) (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -152,6 +156,8 @@ type DefaultAccountManager struct { // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool + + integratedPeerValidator integrated_validator.IntegratedValidator } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -218,8 +224,8 @@ type Account struct { PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*nbgroup.Group `gorm:"-"` + GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[string]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` @@ -247,7 +253,7 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` - IntegrationReference IntegrationReference `json:"-"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` Permissions UserPermissions `json:"permissions"` } @@ -372,25 +378,26 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { } // GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *Group { +func (a *Account) GetGroup(groupID string) *nbgroup.Group { return a.Groups[groupID] } // GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { +func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ Network: a.Network.Copy(), } } - validatedPeers := additions.ValidatePeers([]*nbpeer.Peer{peer}) - if len(validatedPeers) == 0 { + + if _, ok := validatedPeersMap[peerID]; !ok { return &NetworkMap{ Network: a.Network.Copy(), } } - aclPeers, firewallRules := a.getPeerConnectionResources(peerID) + + aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -564,7 +571,7 @@ func (a *Account) FindUser(userID string) (*User, error) { } // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*Group, error) { +func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { for _, group := range a.Groups { if group.Name == groupName { return group, nil @@ -583,6 +590,20 @@ func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { return key, nil } +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + func (a *Account) getUserGroups(userID string) ([]string, error) { user, err := a.FindUser(userID) if err != nil { @@ -660,7 +681,7 @@ func (a *Account) Copy() *Account { setupKeys[id] = key.Copy() } - groups := map[string]*Group{} + groups := map[string]*nbgroup.Group{} for id, group := range a.Groups { groups[id] = group.Copy() } @@ -713,7 +734,7 @@ func (a *Account) Copy() *Account { } } -func (a *Account) GetGroupAll() (*Group, error) { +func (a *Account) GetGroupAll() (*nbgroup.Group, error) { for _, g := range a.Groups { if g.Name == "All" { return g, nil @@ -734,7 +755,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { return false } - existedGroupsByName := make(map[string]*Group) + existedGroupsByName := make(map[string]*nbgroup.Group) for _, group := range a.Groups { existedGroupsByName[group.Name] = group } @@ -743,7 +764,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { removed := 0 jwtAutoGroups := make(map[string]struct{}) for i, id := range user.AutoGroups { - if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT { + if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = struct{}{} user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) removed++ @@ -756,15 +777,15 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { for _, name := range groupsNames { group, ok := existedGroupsByName[name] if !ok { - group = &Group{ + group = &nbgroup.Group{ ID: xid.New().String(), Name: name, - Issued: GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, } a.Groups[group.ID] = group } // only JWT groups will be synced - if group.Issued == GroupIssuedJWT { + if group.Issued == nbgroup.GroupIssuedJWT { user.AutoGroups = append(user.AutoGroups, group.ID) if _, ok := jwtAutoGroups[name]; !ok { modified = true @@ -837,6 +858,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, + integratedPeerValidator integrated_validator.IntegratedValidator, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, @@ -850,6 +872,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, + integratedPeerValidator: integratedPeerValidator, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -906,6 +929,8 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage }() } + am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated) + return am, nil } @@ -948,7 +973,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } - err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore) + err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } @@ -1823,18 +1848,27 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } +func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { + updatedAccount, err := am.Store.GetAccount(accountID) + if err != nil { + log.Errorf("failed to get account %s: %v", accountID, err) + return + } + am.updateAccountPeers(updatedAccount) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { - allGroup := &Group{ + allGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "All", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*Group{allGroup.ID: allGroup} + account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} id := xid.New().String() diff --git a/management/server/account/account.go b/management/server/account/account.go index b8b71a6de9e..40f032fbed4 100644 --- a/management/server/account/account.go +++ b/management/server/account/account.go @@ -3,11 +3,17 @@ package account type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool + + // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations + IntegratedValidatorGroups []string `gorm:"serializer:json"` } // Copy copies the ExtraSettings struct func (e *ExtraSettings) Copy() *ExtraSettings { + var cpGroup []string + return &ExtraSettings{ - PeerApprovalEnabled: e.PeerApprovalEnabled, + PeerApprovalEnabled: e.PeerApprovalEnabled, + IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 2b0c4419671..a0eff239b54 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -12,19 +12,56 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" +) - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +type MocIntegratedValidator struct { +} - "github.com/netbirdio/netbird/management/server/jwtclaims" -) +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() { +} func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { t.Helper() @@ -367,7 +404,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { account.Groups[all.ID].Peers = append(account.Groups[all.ID].Peers, peer.ID) } - networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io") + validatedPeers := map[string]struct{}{} + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -667,7 +709,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*Group{} + groupsByNames := map[string]*group.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -675,12 +717,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") }) } @@ -800,7 +842,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) } @@ -815,7 +857,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to get an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("updating domain. expected %s got %s", domain, account.Domain) } } @@ -835,13 +877,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { } if account == nil { t.Fatalf("expected to create an account for a user %s", userId) + return } - accountId := account.Id - - _, err = manager.GetAccountByUserOrAccountID("", accountId, "") + _, err = manager.GetAccountByUserOrAccountID("", account.Id, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) } _, err = manager.GetAccountByUserOrAccountID("", "", "") @@ -1124,7 +1165,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) defer manager.peersUpdateManager.CloseChannel(peer1.ID) - group := Group{ + group := group.Group{ ID: "group-id", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1417,7 +1458,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[string]*route.Route{ "route-1": { ID: "route-1", @@ -1518,7 +1559,7 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*group.Group{ "group1": { ID: "group1", Peers: []string{"peer1"}, @@ -2112,8 +2153,8 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Settings: &Settings{GroupsPropagationEnabled: true}, Users: map[string]*User{ @@ -2160,10 +2201,10 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2196,10 +2237,10 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2223,7 +2264,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index aac35308c93..18f942e68d6 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -193,7 +194,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) } func createDNSStore(t *testing.T) (Store, error) { @@ -278,13 +279,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &Group{ + newGroup1 := &group.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &Group{ + newGroup2 := &group.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 9d70a05d148..4fffa024d02 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -165,7 +165,7 @@ func (e *EphemeralManager) cleanup() { log.Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) if err != nil { - log.Tracef("failed to delete ephemeral peer: %s", err) + log.Errorf("failed to delete ephemeral peer: %s", err) } } } diff --git a/management/server/file_store.go b/management/server/file_store.go index 0228285cbe9..2de852beed7 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,6 +10,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" @@ -170,7 +171,7 @@ func restore(file string) (*FileStore, error) { // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI } } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index d8575a3bfed..d53298d8fa3 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -188,7 +189,7 @@ func TestStore(t *testing.T) { Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - account.Groups["all"] = &Group{ + account.Groups["all"] = &group.Group{ ID: "all", Name: "all", Peers: []string{"testpeer"}, @@ -320,7 +321,7 @@ func TestRestoreGroups_Migration(t *testing.T) { // create default group account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*Group{ + account.Groups = map[string]*group.Group{ "cfefqs706sqkneg59g3g": { ID: "cfefqs706sqkneg59g3g", Name: "All", @@ -336,7 +337,7 @@ func TestRestoreGroups_Migration(t *testing.T) { account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") + require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") } func TestGetAccountByPrivateDomain(t *testing.T) { @@ -384,6 +385,7 @@ func TestFileStore_GetAccount(t *testing.T) { expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] if expected == nil { t.Fatalf("expected account doesn't exist") + return } account, err := store.GetAccount(expected.Id) diff --git a/management/server/group.go b/management/server/group.go index 59f05a354ac..0fc952cdbc9 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -19,51 +20,8 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -const ( - GroupIssuedAPI = "api" - GroupIssuedJWT = "jwt" - GroupIssuedIntegration = "integration" -) - -// Group of the peers for ACL -type Group struct { - // ID of the group - ID string - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name visible in the UI - Name string - - // Issued defines how this group was created (enum of "api", "integration" or "jwt") - Issued string - - // Peers list of the group - Peers []string `gorm:"serializer:json"` - - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// EventMeta returns activity event meta related to the group -func (g *Group) EventMeta() map[string]any { - return map[string]any{"name": g.Name} -} - -func (g *Group) Copy() *Group { - group := &Group{ - ID: g.ID, - Name: g.Name, - Issued: g.Issued, - Peers: make([]string, len(g.Peers)), - IntegrationReference: g.IntegrationReference, - } - copy(group.Peers, g.Peers) - return group -} - // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -90,7 +48,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*G } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*Group, error) { +func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -108,7 +66,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } @@ -117,7 +75,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -126,7 +84,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G return nil, err } - matchingGroups := make([]*Group, 0) + matchingGroups := make([]*nbgroup.Group, 0) for _, group := range account.Groups { if group.Name == groupName { matchingGroups = append(matchingGroups, group) @@ -138,7 +96,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } maxPeers := -1 - var groupWithMostPeers *Group + var groupWithMostPeers *nbgroup.Group for i, group := range matchingGroups { if len(group.Peers) > maxPeers { maxPeers = len(group.Peers) @@ -150,7 +108,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error { +func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -159,11 +117,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G return err } - if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { existingGroup, err := account.FindGroupByName(newGroup.Name) if err != nil { @@ -270,7 +228,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // disable a deleting integration group if the initiator is not an admin service user - if g.Issued == GroupIssuedIntegration { + if g.Issued == nbgroup.GroupIssuedIntegration { executingUser := account.Users[userId] if executingUser == nil { return status.Errorf(status.NotFound, "user not found") @@ -340,6 +298,15 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } } + // check integrated peer validator groups + if account.Settings.Extra != nil { + for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups { + if groupID == integratedPeerValidatorGroups { + return &GroupLinkError{"integrated validator", g.Name} + } + } + } + delete(account.Groups, groupID) account.Network.IncSerial() @@ -355,7 +322,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { +func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -364,7 +331,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) return nil, err } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } diff --git a/management/server/group/group.go b/management/server/group/group.go new file mode 100644 index 00000000000..79dfd995ce0 --- /dev/null +++ b/management/server/group/group.go @@ -0,0 +1,46 @@ +package group + +import "github.com/netbirdio/netbird/management/server/integration_reference" + +const ( + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" + GroupIssuedIntegration = "integration" +) + +// Group of the peers for ACL +type Group struct { + // ID of the group + ID string + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name visible in the UI + Name string + + // Issued defines how this group was created (enum of "api", "integration" or "jwt") + Issued string + + // Peers list of the group + Peers []string `gorm:"serializer:json"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// EventMeta returns activity event meta related to the group +func (g *Group) EventMeta() map[string]any { + return map[string]any{"name": g.Name} +} + +func (g *Group) Copy() *Group { + group := &Group{ + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: make([]string, len(g.Peers)), + IntegrationReference: g.IntegrationReference, + } + copy(group.Peers, g.Peers) + return group +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 3a2195c889d..35e9b2170eb 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -5,6 +5,7 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" ) @@ -24,22 +25,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = GroupIssuedIntegration + group.Issued = nbgroup.GroupIssuedIntegration err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = GroupIssuedJWT + group.Issued = nbgroup.GroupIssuedJWT err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedJWT) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI group.ID = "" err = am.SaveGroup(account.Id, groupAdminUserID, group) if err == nil { @@ -129,51 +130,51 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &Group{ + groupForRoute := &nbgroup.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &Group{ + groupForNameServerGroups := &nbgroup.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &Group{ + groupForPolicies := &nbgroup.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &Group{ + groupForSetupKeys := &nbgroup.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &Group{ + groupForUsers := &nbgroup.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &Group{ + groupForIntegration := &nbgroup.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: GroupIssuedIntegration, + Issued: nbgroup.GroupIssuedIntegration, Peers: make([]string, 0), } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 341d202b6e8..340adcfc61e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -361,6 +361,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p Meta: extractPeerMeta(loginReq), UserID: userID, SetupKey: loginReq.GetSetupKey(), + ConnectionIP: realIP, }) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2810893962b..f8f581bd41c 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -355,6 +355,7 @@ components: - user_id - version - ui_version + - approval_required AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 78cd83a2769..0bed93b3c8e 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -470,7 +470,7 @@ type Peer struct { AccessiblePeers []AccessiblePeer `json:"accessible_peers"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -539,7 +539,7 @@ type Peer struct { // PeerBase defines model for PeerBase. type PeerBase struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -611,7 +611,7 @@ type PeerBatch struct { AccessiblePeersCount int `json:"accessible_peers_count"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 56d06595fe6..47bcf2f320f 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -4,15 +4,15 @@ import ( "encoding/json" "net/http" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" + "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" - - "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/status" ) // GroupsHandler is a handler that returns groups of the account @@ -110,7 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ ID: groupID, Name: req.Name, Peers: peers, @@ -154,10 +154,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ Name: req.Name, Peers: peers, - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } err = h.accountManager.SaveGroup(account.Id, user.Id, &group) @@ -240,7 +240,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { } } -func toGroupResponse(account *server.Account, group *server.Group) *api.Group { +func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 303efc9d704..3d74b848c7d 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -15,6 +15,7 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -28,30 +29,30 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandler { +func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(accountID, userID string, group *server.Group) error { + SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_, groupID, _ string) (*server.Group, error) { + GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } if groupID == "id-jwt-group" { - return &server.Group{ + return &nbgroup.Group{ ID: "id-jwt-group", Name: "Default Group", - Issued: server.GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, }, nil } - return &server.Group{ + return &nbgroup.Group{ ID: "idofthegroup", Name: "Group", - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, }, nil }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { @@ -62,10 +63,10 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Users: map[string]*server.User{ user.Id: user, }, - Groups: map[string]*server.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI}, + Groups: map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, }, }, user, nil }, @@ -118,7 +119,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &server.Group{ + group := &nbgroup.Group{ ID: "idofthegroup", Name: "Group", } @@ -153,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &server.Group{} + got := &nbgroup.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index d035ae0b750..bdbeba3464f 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,7 +9,6 @@ import ( "github.com/rs/cors" "github.com/netbirdio/management-integrations/integrations" - s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index d4d2558e88a..77b4578f8d6 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -6,8 +6,10 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -61,10 +63,18 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w groupsInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -75,11 +85,18 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe return } - update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name, - LoginExpirationEnabled: req.LoginExpirationEnabled} + update := &nbpeer.Peer{ + ID: peerID, + SSHEnabled: req.SshEnabled, + Name: req.Name, + LoginExpirationEnabled: req.LoginExpirationEnabled, + } if req.ApprovalRequired != nil { - update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired} + // todo: looks like that we reset all status property, is it right? + update.Status = &nbpeer.PeerStatus{ + RequiresApproval: *req.ApprovalRequired, + } } peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) @@ -91,15 +108,24 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + + util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(accountID, peerID, userID) if err != nil { + log.Errorf("failed to delete peer: %v", err) util.WriteError(err, w) return } @@ -138,46 +164,68 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } + if r.Method != http.MethodGet { + util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + return + } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + peers, err := h.accountManager.GetPeers(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain() - peers, err := h.accountManager.GetPeers(account.Id, user.Id) + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { + peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(err, w) return } + groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - dnsDomain := h.accountManager.GetDNSDomain() + accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(err, w) - return - } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - - accessiblePeerNumbers := h.accessiblePeersNumber(account, peer.ID) + respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) + } - respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) - } - util.WriteJSONObject(w, respBody) + validPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) return - default: - util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) } + h.setApprovalRequiredFlag(respBody, validPeersMap) + + util.WriteJSONObject(w, respBody) +} + +func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { + validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + return 0, err + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) + return len(netMap.Peers) + len(netMap.OfflinePeers), nil } -func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) int { - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) - return len(netMap.Peers) + len(netMap.OfflinePeers) +func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { + for _, peer := range respBody { + _, ok := approvedPeersMap[peer.Id] + if !ok { + peer.ApprovalRequired = true + } + } } func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { @@ -206,7 +254,7 @@ func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.Access return accessiblePeers } -func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMinimum { +func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { var groupsInfo []api.GroupMinimum groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -230,7 +278,7 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin return groupsInfo } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core @@ -257,7 +305,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeers: accessiblePeer, - ApprovalRequired: &peer.Status.RequiresApproval, + ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } @@ -290,7 +338,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeersCount: accessiblePeersCount, - ApprovalRequired: &peer.Status.RequiresApproval, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index e6b858036b7..74e682854f5 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" @@ -51,7 +52,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Policies: []*server.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 7b68479eddd..ebbd5954fdd 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -13,13 +13,12 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/jwtclaims" - - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" ) const ( @@ -44,7 +43,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup SetupKeys: map[string]*server.SetupKey{ defaultKey.Key: defaultKey, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}, }, diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 4e2c3d0b328..2bb279c7671 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -99,6 +99,8 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusUnprocessableEntity case status.Unauthorized: httpStatus = http.StatusUnauthorized + case status.BadRequest: + httpStatus = http.StatusBadRequest default: } msg = strings.ToLower(err.Error()) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go new file mode 100644 index 00000000000..cd770a80146 --- /dev/null +++ b/management/server/integrated_validator.go @@ -0,0 +1,80 @@ +package server + +import ( + "errors" + + "github.com/google/martian/v3/log" + + "github.com/netbirdio/netbird/management/server/account" +) + +// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. +// It retrieves the account associated with the provided userID, then updates the integrated validator groups +// with the provided list of group ids. The updated account is then saved. +// +// Parameters: +// - accountID: The ID of the account for which integrated validator groups are to be updated. +// - userID: The ID of the user whose account is being updated. +// - groups: A slice of strings representing the ids of integrated validator groups to be updated. +// +// Returns: +// - error: An error if any occurred during the process, otherwise returns nil +func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + ok, err := am.GroupValidation(accountID, groups) + if err != nil { + log.Debugf("error validating groups: %s", err.Error()) + return err + } + + if !ok { + log.Debugf("invalid groups") + return errors.New("invalid groups") + } + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + a, err := am.Store.GetAccountByUser(userID) + if err != nil { + return err + } + + var extra *account.ExtraSettings + + if a.Settings.Extra != nil { + extra = a.Settings.Extra + } else { + extra = &account.ExtraSettings{} + a.Settings.Extra = extra + } + extra.IntegratedValidatorGroups = groups + return am.Store.SaveAccount(a) +} + +func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if len(groups) == 0 { + return true, nil + } + accountsGroups, err := am.ListGroups(accountId) + if err != nil { + return false, err + } + for _, group := range groups { + var found bool + for _, accountGroup := range accountsGroups { + if accountGroup.ID == group { + found = true + break + } + } + if !found { + return false, nil + } + } + + return true, nil +} + +func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { + return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +} diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go new file mode 100644 index 00000000000..e87755b874c --- /dev/null +++ b/management/server/integrated_validator/interface.go @@ -0,0 +1,19 @@ +package integrated_validator + +import ( + "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +// IntegratedValidator interface exists to avoid the circle dependencies +type IntegratedValidator interface { + ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer + IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) + GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + PeerDeleted(accountID, peerID string) error + SetPeerInvalidationListener(fn func(accountID string)) + Stop() +} diff --git a/management/server/integration_reference/integration_reference.go b/management/server/integration_reference/integration_reference.go new file mode 100644 index 00000000000..254b4e62f44 --- /dev/null +++ b/management/server/integration_reference/integration_reference.go @@ -0,0 +1,23 @@ +package integration_reference + +import ( + "fmt" + "strings" +) + +// IntegrationReference holds the reference to a particular integration +type IntegrationReference struct { + ID int + IntegrationType string +} + +func (ir IntegrationReference) String() string { + return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) +} + +func (ir IntegrationReference) CacheKey(path ...string) string { + if len(path) == 0 { + return ir.String() + } + return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 6ea902003de..98ad0de0ca0 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -19,6 +17,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" ) @@ -413,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index fb3f74cb9fa..13db5ae9589 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -10,24 +10,22 @@ import ( sync2 "sync" "time" - "github.com/netbirdio/netbird/management/server/activity" - - "google.golang.org/grpc/credentials/insecure" - - "github.com/netbirdio/netbird/management/server" - pb "github.com/golang/protobuf/proto" //nolint - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/encryption" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -448,6 +446,43 @@ var _ = Describe("Management service", func() { }) }) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} + +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for p := range peers { + validatedPeers[p] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() {} + func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -504,7 +539,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 7be3c818d6a..c479867d294 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" @@ -32,7 +33,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, @@ -117,7 +118,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 9463498cf7b..8e7c47a280a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -31,12 +32,12 @@ type MockAccountManager struct { GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID, userID string) (*server.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*server.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) - SaveGroupFunc func(accountID, userID string, group *server.Group) error + GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(accountID, userID string, group *group.Group) error DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*server.Group, error) + ListGroupsFunc func(accountID string) ([]*group.Group, error) GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error @@ -91,10 +92,20 @@ type MockAccountManager struct { DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error + GroupValidationFunc func(accountId string, groups []string) (bool, error) +} + +func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { + approvedPeers := make(map[string]struct{}) + for id := range account.Peers { + approvedPeers[id] = struct{}{} + } + return approvedPeers, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*server.Group, error) { +func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(accountId, groupID, userID) } @@ -102,7 +113,7 @@ func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*serv } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*server.Group, error) { +func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(accountID, userID) } @@ -261,7 +272,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { +func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(accountID, groupName) } @@ -269,7 +280,7 @@ func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*serv } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.Group) error { +func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(accountID, userID, group) } @@ -285,7 +296,7 @@ func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) err } // ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) { +func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { return am.ListGroupsFunc(accountID) } @@ -694,3 +705,19 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { } return nil } + +// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface +func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + if am.UpdateIntegratedValidatorGroupsFunc != nil { + return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) + } + return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") +} + +// GroupValidation mocks GroupValidation of the AccountManager interface +func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if am.GroupValidationFunc != nil { + return am.GroupValidationFunc(accountId, groups) + } + return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e521805c810..fa77936024b 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -261,7 +262,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*Group) error { +func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d04ac1a20a1..b10f9387a62 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -8,6 +8,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -759,7 +760,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createNSStore(t *testing.T) (Store, error) { @@ -831,12 +832,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &Group{ + newGroup1 := &nbgroup.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &Group{ + newGroup2 := &nbgroup.Group{ ID: group2ID, Name: group2ID, } diff --git a/management/server/peer.go b/management/server/peer.go index 7de1b654254..fda8e49e9cc 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -7,16 +7,12 @@ import ( "time" "github.com/rs/xid" + log "github.com/sirupsen/logrus" - "github.com/netbirdio/management-integrations/additions" - + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/management/proto" ) // PeerSync used as a data object between the gRPC API and AccountManager on Sync request. @@ -37,6 +33,8 @@ type PeerLogin struct { UserID string // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. SetupKey string + // ConnectionIP is the real IP of the peer + ConnectionIP net.IP } // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if @@ -52,6 +50,10 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) @@ -71,7 +73,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(peer.ID) + aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -167,7 +169,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = additions.ValidatePeersUpdateRequest(update, peer, userID, accountID, am.eventStore, am.GetDNSDomain()) + update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -244,6 +246,12 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // the 2nd loop performs the actual modification for _, peer := range peers { + + err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID) + if err != nil { + return err + } + account.DeletePeer(peer.ID) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{ @@ -304,7 +312,17 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro if peer == nil { return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) } - return account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + if err != nil { + return nil, err + } + return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil } // GetPeerNetwork returns the Network for a given peer @@ -433,10 +451,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P CreatedAt: registrationTime, LoginExpirationEnabled: addedByUser, Ephemeral: ephemeral, - } - - if account.Settings.Extra != nil { - newPeer = additions.PreparePeer(newPeer, account.Settings.Extra) + Location: peer.Location, } // add peer to 'All' group @@ -467,6 +482,8 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } } + newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) + if addedByUser { user, err := account.FindUser(userID) if err != nil { @@ -492,7 +509,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P am.updateAccountPeers(account) - networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) return newPeer, networkMap, nil } @@ -529,23 +550,53 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network if peerLoginExpired(peer, account) { return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + if requiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + if isStatusChanged { + am.updateAccountPeers(account) + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) - if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. - return am.AddPeer(login.SetupKey, login.UserID, &nbpeer.Peer{ + newPeer := &nbpeer.Peer{ Key: login.WireGuardPubKey, Meta: login.Meta, SSHKey: login.SSHKey, - }) + } + if am.geo != nil && login.ConnectionIP != nil { + location, err := am.geo.Lookup(login.ConnectionIP) + if err != nil { + log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) + } else { + newPeer.Location.ConnectionIP = login.ConnectionIP + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + + } + } + + return am.AddPeer(login.SetupKey, login.UserID, newPeer) } log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") @@ -595,6 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } + isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { shouldStoreAccount = true @@ -612,10 +664,23 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } } - if updateRemotePeers { + if updateRemotePeers || isStatusChanged { am.updateAccountPeers(account) } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + if isRequiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -764,8 +829,13 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } + for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(p.ID) + aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -789,8 +859,13 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco func (am *DefaultAccountManager) updateAccountPeers(account *Account) { peers := account.GetPeers() + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed send out updates to peers, failed to validate peer: %v", err) + return + } for _, peer := range peers { - remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain) + remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 7f6d440bb10..6063cc2a742 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -199,8 +200,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 Group - group2 Group + group1 nbgroup.Group + group2 nbgroup.Group policy Policy ) diff --git a/management/server/policy.go b/management/server/policy.go index 8265dabb51c..e162d2b3bc8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,11 +5,11 @@ import ( "strconv" "strings" - "github.com/netbirdio/management-integrations/additions" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -211,7 +211,8 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator() for _, policy := range a.Policies { if !policy.Enabled { @@ -223,10 +224,8 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []* continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil) - sourcePeers = additions.ValidatePeers(sourcePeers) - destinationPeers = additions.ValidatePeers(destinationPeers) + sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -264,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in all, err := a.GetGroupAll() if err != nil { log.Errorf("failed to get group all: %v", err) - all = &Group{} + all = &nbgroup.Group{} } return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { @@ -491,7 +490,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { +func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -512,6 +511,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou continue } + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + if peer.ID == peerID { peerInGroups = true continue diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 681bab1dac9..1ea3bb379a9 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" ) @@ -56,7 +57,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -135,16 +136,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, } + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(p.ID) + peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -299,7 +305,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -374,8 +380,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, } + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } + t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -403,7 +414,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -433,7 +444,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -454,7 +465,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -569,7 +580,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -644,10 +655,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }) + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -657,7 +672,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -673,7 +688,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -683,7 +698,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -698,19 +713,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -725,14 +740,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources("peerA") + peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5a56eaa8bd9..9f8ea08c932 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" ) @@ -858,7 +859,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { groups, err := am.ListGroups(account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *Group + var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -967,7 +968,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &Group{ + newGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1014,7 +1015,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createRouterStore(t *testing.T) (Store, error) { @@ -1195,7 +1196,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*Group{ + newGroup := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index b714652f137..43edabbd6e2 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -82,7 +83,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -91,7 +92,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -178,7 +179,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -187,7 +188,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index f6a6f92a726..e6a9c846726 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -17,6 +17,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -64,7 +65,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, sql.SetMaxOpenConns(conns) // TODO: make it configurable err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, + &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, ) @@ -99,17 +100,17 @@ func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics t // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { - log.Debugf("acquiring global lock") + log.Tracef("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Debugf("released global lock in %v", time.Since(start)) + log.Tracef("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Debugf("took %v to acquire global lock", took) + log.Tracef("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -118,7 +119,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { } func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { - log.Debugf("acquiring lock for account %s", accountID) + log.Tracef("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) @@ -127,7 +128,7 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.Tracef("released lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -434,7 +435,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { } account.UsersG = nil - account.Groups = make(map[string]*Group, len(account.GroupsG)) + account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } diff --git a/management/server/user.go b/management/server/user.go index 15517db41b4..b955c405895 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -10,6 +10,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -49,23 +50,6 @@ type UserStatus string // UserRole is the role of a User type UserRole string -// IntegrationReference holds the reference to a particular integration -type IntegrationReference struct { - ID int - IntegrationType string -} - -func (ir IntegrationReference) String() string { - return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) -} - -func (ir IntegrationReference) CacheKey(path ...string) string { - if len(path) == 0 { - return ir.String() - } - return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) -} - // User represents a user of the system type User struct { Id string `gorm:"primaryKey"` @@ -91,7 +75,7 @@ type User struct { // Issued of the user Issued string `gorm:"default:api"` - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } // IsBlocked returns true if the user is blocked, false otherwise diff --git a/management/server/user_test.go b/management/server/user_test.go index e34aa406d2e..c92f87e6c7f 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -276,7 +277,7 @@ func TestUser_Copy(t *testing.T) { LastLogin: time.Now().UTC(), CreatedAt: time.Now().UTC(), Issued: "test", - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 0, IntegrationType: "test", }, @@ -603,8 +604,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + integratedPeerValidator: MocIntegratedValidator{}, } testCases := []struct { @@ -793,7 +795,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { Id: "externalUser", Role: UserRoleUser, Issued: UserIssuedIntegration, - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", }, From bd7a65d7984ad802e105a6451ffd31560ee8312e Mon Sep 17 00:00:00 2001 From: Jeremy Wu Date: Thu, 28 Mar 2024 16:56:41 +0800 Subject: [PATCH 19/89] support to configure extra blacklist of iface in "up" command (#1734) Support to configure extra blacklist of iface in "up" command --- client/cmd/root.go | 2 + client/cmd/up.go | 14 +- client/internal/config.go | 18 +- client/internal/config_test.go | 22 +- client/proto/daemon.pb.go | 401 +++++++++++++++++---------------- client/proto/daemon.proto | 2 + client/server/server.go | 8 +- 7 files changed, 259 insertions(+), 208 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index c3ff0a3c876..9c4ad99dec0 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -34,6 +34,7 @@ const ( wireguardPortFlag = "wireguard-port" disableAutoConnectFlag = "disable-auto-connect" serverSSHAllowedFlag = "allow-server-ssh" + extraIFaceBlackListFlag = "extra-iface-blacklist" ) var ( @@ -63,6 +64,7 @@ var ( wireguardPort uint16 serviceName string autoConnectDisabled bool + extraIFaceBlackList []string rootCmd = &cobra.Command{ Use: "netbird", Short: "", diff --git a/client/cmd/up.go b/client/cmd/up.go index f44f29a4702..c2c3c7c90ba 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -40,6 +40,7 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") + upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") } func upFunc(cmd *cobra.Command, args []string) error { @@ -83,11 +84,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } ic := internal.ConfigInput{ - ManagementURL: managementURL, - AdminURL: adminURL, - ConfigPath: configPath, - NATExternalIPs: natExternalIPs, - CustomDNSAddress: customDNSAddressConverted, + ManagementURL: managementURL, + AdminURL: adminURL, + ConfigPath: configPath, + NATExternalIPs: natExternalIPs, + CustomDNSAddress: customDNSAddressConverted, + ExtraIFaceBlackList: extraIFaceBlackList, } if cmd.Flag(enableRosenpassFlag).Changed { @@ -149,7 +151,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { - customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed) if err != nil { return err @@ -190,6 +191,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { CustomDNSAddress: customDNSAddressConverted, IsLinuxDesktopClient: isLinuxRunningDesktop(), Hostname: hostName, + ExtraIFaceBlacklist: extraIFaceBlackList, } if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { diff --git a/client/internal/config.go b/client/internal/config.go index 2f69582350e..5b3c61cbd66 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -30,8 +30,10 @@ const ( DefaultAdminURL = "https://app.netbird.io:443" ) -var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", - "Tailscale", "tailscale", "docker", "veth", "br-", "lo"} +var defaultInterfaceBlacklist = []string{ + iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", + "Tailscale", "tailscale", "docker", "veth", "br-", "lo", +} // ConfigInput carries configuration changes to the client type ConfigInput struct { @@ -47,6 +49,7 @@ type ConfigInput struct { InterfaceName *string WireguardPort *int DisableAutoConnect *bool + ExtraIFaceBlackList []string } // Config Configuration type @@ -220,7 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) { config.AdminURL = newURL } - config.IFaceBlackList = defaultInterfaceBlacklist + // nolint:gocritic + config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...) return config, nil } @@ -320,6 +324,13 @@ func update(input ConfigInput) (*Config, error) { refresh = true } + if len(input.ExtraIFaceBlackList) > 0 { + for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) { + config.IFaceBlackList = append(config.IFaceBlackList, iFace) + refresh = true + } + } + if refresh { // since we have new management URL, we need to update config file if err := util.WriteJson(input.ConfigPath, config); err != nil { @@ -384,7 +395,6 @@ func configFileIsExists(path string) bool { // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) { - defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL) if err != nil { return nil, err diff --git a/client/internal/config_test.go b/client/internal/config_test.go index 7453c8fdf86..978d0b3df54 100644 --- a/client/internal/config_test.go +++ b/client/internal/config_test.go @@ -18,7 +18,6 @@ func TestGetConfig(t *testing.T) { config, err := UpdateOrCreateConfig(ConfigInput{ ConfigPath: filepath.Join(t.TempDir(), "config.json"), }) - if err != nil { return } @@ -86,6 +85,26 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL) } +func TestExtraIFaceBlackList(t *testing.T) { + extraIFaceBlackList := []string{"eth1"} + path := filepath.Join(t.TempDir(), "config.json") + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: path, + ExtraIFaceBlackList: extraIFaceBlackList, + }) + if err != nil { + return + } + + assert.Contains(t, config.IFaceBlackList, "eth1") + readConf, err := util.ReadJson(path, config) + if err != nil { + return + } + + assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1") +} + func TestHiddenPreSharedKey(t *testing.T) { hidden := "**********" samplePreSharedKey := "mysecretpresharedkey" @@ -111,7 +130,6 @@ func TestHiddenPreSharedKey(t *testing.T) { ConfigPath: cfgFile, PreSharedKey: tt.preSharedKey, }) - if err != nil { t.Fatalf("failed to get cfg: %s", err) } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 81998b115d3..4b850226893 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -44,17 +44,18 @@ type LoginRequest struct { // cleanNATExternalIPs clean map list of external IPs. // This is needed because the generated code // omits initialized empty slices due to omitempty tags - CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` - CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` - IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` - Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` - RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` - InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` - WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` - OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` - DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` - RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` + CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` + Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` + RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` + InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` + WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` + OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` + DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` + RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` } func (x *LoginRequest) Reset() { @@ -202,6 +203,13 @@ func (x *LoginRequest) GetRosenpassPermissive() bool { return false } +func (x *LoginRequest) GetExtraIFaceBlacklist() []string { + if x != nil { + return x.ExtraIFaceBlacklist + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1385,7 +1393,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x06, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8f, 0x07, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -1430,192 +1438,195 @@ var file_daemon_proto_rawDesc = []byte{ 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x48, 0x06, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, - 0x88, 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, - 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, - 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, - 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6e, - 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x28, 0x0a, - 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, - 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, - 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, - 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, - 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x32, 0x0a, - 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, 0x11, 0x47, - 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, - 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, 0x65, - 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, - 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, - 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, - 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, - 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, - 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, + 0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63, + 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, + 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, + 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, + 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, + 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, + 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, + 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, + 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, + 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, + 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, + 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, + 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, + 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, + 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, + 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, + 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, + 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, + 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, + 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, + 0x52, 0x4c, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, + 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, + 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, + 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, + 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, - 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, - 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, - 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, - 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, - 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, - 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, - 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, - 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, - 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, - 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, - 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, - 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, - 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, - 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, - 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, - 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, + 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, + 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, + 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, + 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, - 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, - 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, - 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, - 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, - 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, - 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, - 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, - 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, - 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, - 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, - 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, - 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, + 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, + 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, + 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, + 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, + 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, + 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, + 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, + 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, + 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, + 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, + 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, + 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, + 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, + 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, + 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 8f9148d68af..5f8878a11b9 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -70,6 +70,8 @@ message LoginRequest { optional bool serverSSHAllowed = 15; optional bool rosenpassPermissive = 16; + + repeated string extraIFaceBlacklist = 17; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 481ef0f7cc6..d1d9dbda451 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -152,7 +152,8 @@ func (s *Server) Start() error { // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, - mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe) { + mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, +) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -351,6 +352,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.WireguardPort = &port } + if len(msg.ExtraIFaceBlacklist) > 0 { + inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + } + s.mutex.Unlock() if msg.OptionalPreSharedKey != nil { From 22beac1b1b510ac117d84f95af4ddb08d7a048b9 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 28 Mar 2024 12:33:56 +0100 Subject: [PATCH 20/89] Fix invalid token due to the cache race (#1763) --- management/server/grpcserver.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 340adcfc61e..4df24711ead 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -343,10 +343,18 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p userID := "" // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, // or it uses a setup key to register. + if loginReq.GetJwtToken() != "" { - userID, err = s.validateToken(loginReq.GetJwtToken()) + for i := 0; i < 3; i++ { + userID, err = s.validateToken(loginReq.GetJwtToken()) + if err == nil { + break + } + log.Warnf("failed validating JWT token sent from peer %s with error %v. "+ + "Trying again as it may be due to the IdP cache issue", peerKey, err) + time.Sleep(200 * time.Millisecond) + } if err != nil { - log.Warnf("failed validating JWT token sent from peer %s", peerKey) return nil, err } } From 4fff93a1f228fe15a802b5a0862aa1180f912c5d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 28 Mar 2024 13:06:54 +0100 Subject: [PATCH 21/89] Ignore unsupported address families (#1766) --- client/internal/routemanager/systemops_linux.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 192509992c7..3510f95531a 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -162,7 +162,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink add route: %w", err) } @@ -185,7 +185,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { Dst: ipNet, } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink add unreachable route: %w", err) } @@ -205,7 +205,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { Dst: ipNet, } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink remove unreachable route: %w", err) } @@ -231,7 +231,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink remove route: %w", err) } @@ -255,7 +255,7 @@ func flushRoutes(tableID, family int) error { routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } } - if err := netlink.RouteDel(&routes[i]); err != nil { + if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) } } @@ -385,7 +385,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("add routing rule: %w", err) } @@ -402,7 +402,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("remove routing rule: %w", err) } From fd23d0c28ff069a5e60ed9c10139f4babf15a1b4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 28 Mar 2024 18:12:25 +0100 Subject: [PATCH 22/89] Don't block on failed routing setup (#1768) --- client/internal/engine.go | 3 +-- client/internal/routemanager/systemops_linux.go | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 7f7b5ef55ba..046a6c94450 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -261,8 +261,7 @@ func (e *Engine) Start() error { e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) if err := e.routeManager.Init(); err != nil { - e.close() - return fmt.Errorf("init route manager: %w", err) + log.Errorf("Failed to initialize route manager: %s", err) } e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 3510f95531a..83af5008ae0 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -45,10 +45,10 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, } } From 40d56e5d29608d55aa628a5172fceac1116a3aa9 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 28 Mar 2024 18:43:32 +0100 Subject: [PATCH 23/89] Update network security image (#1765) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9b52f5b5f12..d0b07feffa9 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,8 @@ ### Open-Source Network Security in a Single Platform -![download (2)](https://github.com/netbirdio/netbird/assets/700848/16210ac2-7265-44c1-8d4e-8fae85534dac) +![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f) + ### Key features From 9c2dc05df1a735adc4ab8ec19fc6371c528817c2 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 31 Mar 2024 19:39:52 +0200 Subject: [PATCH 24/89] Eval/higher timeouts (#1776) --- client/internal/peer/conn.go | 28 ++++++++++++++++------------ client/internal/peer/env_config.go | 27 ++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c180e8f032b..ce8cc4b9779 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -26,6 +26,8 @@ import ( const ( iceKeepAliveDefault = 4 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second defaultWgKeepAlive = 25 * time.Second ) @@ -196,20 +198,22 @@ func (conn *Conn) reCreateAgent() error { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: conn.config.StunTurn, - CandidateTypes: conn.candidateTypes(), - FailedTimeout: &failedTimeout, - InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), - UDPMux: conn.config.UDPMux, - UDPMuxSrflx: conn.config.UDPMuxSrflx, - NAT1To1IPs: conn.config.NATExternalIPs, - Net: transportNet, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: conn.config.StunTurn, + CandidateTypes: conn.candidateTypes(), + FailedTimeout: &failedTimeout, + InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), + UDPMux: conn.config.UDPMux, + UDPMuxSrflx: conn.config.UDPMuxSrflx, + NAT1To1IPs: conn.config.NATExternalIPs, + Net: transportNet, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, } if conn.config.DisableIPv6Discovery { diff --git a/client/internal/peer/env_config.go b/client/internal/peer/env_config.go index 540bc413ea7..87b626df763 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/env_config.go @@ -10,9 +10,10 @@ import ( ) const ( - envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" - envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" + envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" + envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" + envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" ) func iceKeepAlive() time.Duration { @@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration { return iceKeepAliveDefault } - log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv) + log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) if err != nil { log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) @@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration { return iceDisconnectedTimeoutDefault } - log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) + log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) if err != nil { log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) @@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration { return time.Duration(disconnectedTimeoutSec) * time.Second } +func iceRelayAcceptanceMinWait() time.Duration { + iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec) + if iceRelayAcceptanceMinWaitEnv == "" { + return iceRelayAcceptanceMinWaitDefault + } + + log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) + disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) + if err != nil { + log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) + return iceRelayAcceptanceMinWaitDefault + } + + return time.Duration(disconnectedTimeoutSec) * time.Second +} + func hasICEForceRelayConn() bool { disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) return strings.ToLower(disconnectedTimeoutEnv) == "true" From 23a14737974e3849fa86408d136cc46db8a885d0 Mon Sep 17 00:00:00 2001 From: Vilian Gerdzhikov Date: Tue, 2 Apr 2024 11:08:58 +0300 Subject: [PATCH 25/89] Fix grammar in readme (#1778) --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d0b07feffa9..d2a2bd6b9af 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. -**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. +**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. ### Open-Source Network Security in a Single Platform @@ -77,7 +77,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird - **Public domain** name pointing to the VM. **Software requirements:** -- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. +- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. - [jq](https://jqlang.github.io/jq/) installed. In most distributions Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq` - [curl](https://curl.se/) installed. @@ -94,9 +94,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird - Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard. - Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers). - NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines. -- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. +- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. - Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates. -- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. +- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups. @@ -120,7 +120,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu ![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png) ### Testimonials -We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution). +We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution). ### Legal _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld. From 9af532fe719e3851db756087730f278ea5559751 Mon Sep 17 00:00:00 2001 From: rqi14 Date: Tue, 2 Apr 2024 19:43:57 +0800 Subject: [PATCH 26/89] Get scope from endpoint url instead of hardcoding (#1770) --- management/server/idp/azure.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 706e4d33014..2f21b3b5417 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -115,7 +115,15 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_secret", ac.clientConfig.ClientSecret) data.Set("grant_type", ac.clientConfig.GrantType) - data.Set("scope", "https://graph.microsoft.com/.default") + parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint) + if err != nil { + return nil, err + } + + // get base url and add "/.default" as scope + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + scopeURL := baseURL + "/.default" + data.Set("scope", scopeURL) payload := strings.NewReader(data.Encode()) req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload) From 79382951905d9fe399ed5b07fff1ddcda8de7ee2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 11:11:46 +0200 Subject: [PATCH 27/89] Feature/exit nodes - Windows and macOS support (#1726) --- .github/workflows/golang-test-darwin.yml | 3 + .github/workflows/golang-test-windows.yml | 2 +- client/internal/engine.go | 26 +- client/internal/peer/conn.go | 33 ++ client/internal/routemanager/client.go | 4 +- client/internal/routemanager/manager.go | 41 +- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 +- client/internal/routemanager/routemanager.go | 119 +++++ .../routemanager/server_nonandroid.go | 8 +- .../routemanager/systemops_android.go | 24 +- .../routemanager/systemops_bsd_nonios.go | 13 - .../internal/routemanager/systemops_darwin.go | 61 +++ .../routemanager/systemops_darwin_test.go | 100 +++++ client/internal/routemanager/systemops_ios.go | 24 +- .../internal/routemanager/systemops_linux.go | 44 +- .../routemanager/systemops_linux_test.go | 386 +++-------------- .../routemanager/systemops_nonandroid.go | 148 ------- .../routemanager/systemops_nonandroid_test.go | 282 ------------ .../routemanager/systemops_nonlinux.go | 406 +++++++++++++++++- .../routemanager/systemops_nonlinux_test.go | 242 ++++++++++- .../routemanager/systemops_unix_test.go | 234 ++++++++++ .../routemanager/systemops_windows.go | 81 +++- .../routemanager/systemops_windows_test.go | 289 +++++++++++++ client/internal/routemanager/sytemops_test.go | 101 +++++ client/internal/wgproxy/portlookup.go | 6 +- client/internal/wgproxy/proxy_ebpf.go | 6 +- go.mod | 2 +- go.sum | 6 +- util/grpc/{dialer_linux.go => dialer.go} | 10 +- util/grpc/dialer_generic.go | 9 - util/net/dialer.go | 64 +++ util/net/dialer_generic.go | 118 ++++- util/net/dialer_linux.go | 58 +-- util/net/dialer_nonlinux.go | 6 + util/net/listener.go | 21 + util/net/listener_generic.go | 153 ++++++- util/net/listener_linux.go | 24 +- util/net/listener_mobile.go | 11 + util/net/listener_nonlinux.go | 6 + util/net/net.go | 11 + 41 files changed, 2253 insertions(+), 964 deletions(-) create mode 100644 client/internal/routemanager/routemanager.go delete mode 100644 client/internal/routemanager/systemops_bsd_nonios.go create mode 100644 client/internal/routemanager/systemops_darwin.go create mode 100644 client/internal/routemanager/systemops_darwin_test.go delete mode 100644 client/internal/routemanager/systemops_nonandroid.go delete mode 100644 client/internal/routemanager/systemops_nonandroid_test.go create mode 100644 client/internal/routemanager/systemops_unix_test.go create mode 100644 client/internal/routemanager/systemops_windows_test.go create mode 100644 client/internal/routemanager/sytemops_test.go rename util/grpc/{dialer_linux.go => dialer.go} (56%) delete mode 100644 util/grpc/dialer_generic.go create mode 100644 util/net/dialer.go create mode 100644 util/net/dialer_nonlinux.go create mode 100644 util/net/listener.go create mode 100644 util/net/listener_mobile.go create mode 100644 util/net/listener_nonlinux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index f8afd3d6eab..d7007c86080 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,6 +32,9 @@ jobs: restore-keys: | macos-go- + - name: Install libpcap + run: brew install libpcap + - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6027d36269f..2d63acbcd5a 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/client/internal/engine.go b/client/internal/engine.go index 046a6c94450..d6238c4b3ca 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -93,6 +93,10 @@ type Engine struct { mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn + + beforePeerHook peer.BeforeAddPeerHookFunc + afterPeerHook peer.AfterRemovePeerHookFunc + // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -260,9 +264,14 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - if err := e.routeManager.Init(); err != nil { + beforePeerHook, afterPeerHook, err := e.routeManager.Init() + if err != nil { log.Errorf("Failed to initialize route manager: %s", err) + } else { + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook } + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -808,10 +817,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { if _, ok := e.peerConns[peerKey]; !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { - return err + return fmt.Errorf("create peer connection: %w", err) } e.peerConns[peerKey] = conn + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) @@ -1105,6 +1119,10 @@ func (e *Engine) close() { e.dnsServer.Stop() } + if e.routeManager != nil { + e.routeManager.Stop() + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { @@ -1119,10 +1137,6 @@ func (e *Engine) close() { } } - if e.routeManager != nil { - e.routeManager.Stop() - } - if e.firewall != nil { err := e.firewall.Reset() if err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ce8cc4b9779..f3d07dcad1f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -100,6 +101,9 @@ type IceCredentials struct { Pwd string } +type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error +type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error + type Conn struct { config ConnConfig mu sync.Mutex @@ -138,6 +142,10 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn + + connID nbnet.ConnectionID + beforeAddPeerHooks []BeforeAddPeerHookFunc + afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -393,6 +401,14 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } +func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { + conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) +} + +func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { + conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) +} + // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -419,6 +435,14 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.remoteEndpoint = endpointUdpAddr + log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + + conn.connID = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + log.Errorf("Before add peer hook failed: %v", err) + } + } err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { @@ -510,6 +534,15 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if conn.connID != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connID); err != nil { + log.Errorf("After remove peer hook failed: %v", err) + } + } + } + conn.connID = "" + if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index b2dff7f08cf..38cf4bf6550 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -193,7 +193,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } @@ -234,7 +234,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // otherwise add the route to the system - if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6a0d954da09..36a37f02c50 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -3,7 +3,9 @@ package routemanager import ( "context" "fmt" + "net" "net/netip" + "net/url" "runtime" "sync" @@ -24,7 +26,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) // Manager is a route manager interface type Manager interface { - Init() error + Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -65,16 +67,21 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } // Init sets up the routing -func (m *DefaultManager) Init() error { +func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { if err := cleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } - if err := setupRouting(); err != nil { - return fmt.Errorf("setup routing: %w", err) + mgmtAddress := m.statusRecorder.GetManagementState().URL + signalAddress := m.statusRecorder.GetSignalState().URL + ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + + beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + if err != nil { + return nil, nil, fmt.Errorf("setup routing: %w", err) } log.Info("Routing setup complete") - return nil + return beforePeerHook, afterPeerHook, nil } func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { @@ -203,16 +210,36 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } func isPrefixSupported(prefix netip.Prefix) bool { - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "linux", "windows", "darwin": return true } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported // we skip this prefix management - if prefix.Bits() < minRangeBits { + if prefix.Bits() <= minRangeBits { log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", version.NetbirdVersion(), prefix) return false } return true } + +// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. +func resolveURLsToIPs(urls []string) []net.IP { + var ips []net.IP + for _, rawurl := range urls { + u, err := url.Parse(rawurl) + if err != nil { + log.Errorf("Failed to parse url %s: %v", rawurl, err) + continue + } + ipAddrs, err := net.LookupIP(u.Hostname()) + if err != nil { + log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) + continue + } + ips = append(ips, ipAddrs...) + } + return ips +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 9d92bf90d2f..03e77e09bcb 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedLinux int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedAllowed int }{ { name: "Should create 2 client networks", @@ -201,9 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedLinux: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedAllowed: 1, }, { name: "Remove 1 Client Route", @@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - err = routeManager.Init() + + _, _, err = routeManager.Init() + require.NoError(t, err, "should init route manager") defer routeManager.Stop() @@ -434,8 +436,8 @@ func TestManagerUpdateRoutes(t *testing.T) { require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected - if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedLinux + if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index e812b3a85b6..dd2c28e5927 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,6 +6,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -16,8 +17,8 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() error { - return nil +func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil } // InitialRouteRange mock implementation of InitialRouteRange from Manager interface diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go new file mode 100644 index 00000000000..fe8d7b4ef19 --- /dev/null +++ b/client/internal/routemanager/routemanager.go @@ -0,0 +1,119 @@ +//go:build !android + +package routemanager + +import ( + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type ref struct { + count int + nexthop netip.Addr + intf string +} + +type RouteManager struct { + // refCountMap keeps track of the reference ref for prefixes + refCountMap map[netip.Prefix]ref + // prefixMap keeps track of the prefixes associated with a connection ID for removal + prefixMap map[nbnet.ConnectionID][]netip.Prefix + addRoute AddRouteFunc + removeRoute RemoveRouteFunc + mutex sync.Mutex +} + +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error + +func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { + // TODO: read initial routing table into refCountMap + return &RouteManager{ + refCountMap: map[netip.Prefix]ref{}, + prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, + addRoute: addRoute, + removeRoute: removeRoute, + } +} + +func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + ref := rm.refCountMap[prefix] + log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) + + // Add route to the system, only if it's a new prefix + if ref.count == 0 { + log.Debugf("Adding route for prefix %s", prefix) + nexthop, intf, err := rm.addRoute(prefix) + if err != nil { + return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) + } + ref.nexthop = nexthop + ref.intf = intf + } + + ref.count++ + rm.refCountMap[prefix] = ref + rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) + + return nil +} + +func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + prefixes, ok := rm.prefixMap[connID] + if !ok { + log.Debugf("No prefixes found for connection ID %s", connID) + return nil + } + + var result *multierror.Error + for _, prefix := range prefixes { + ref := rm.refCountMap[prefix] + log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) + if ref.count == 1 { + log.Debugf("Removing route for prefix %s", prefix) + // TODO: don't fail if the route is not found + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + continue + } + delete(rm.refCountMap, prefix) + } else { + ref.count-- + rm.refCountMap[prefix] = ref + } + } + delete(rm.prefixMap, connID) + + return result.ErrorOrNil() +} + +// Flush removes all references and routes from the system +func (rm *RouteManager) Flush() error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + var result *multierror.Error + for prefix := range rm.refCountMap { + log.Debugf("Removing route for prefix %s", prefix) + ref := rm.refCountMap[prefix] + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]ref{} + rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 00df735fb8a..af82dc91349 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() { log.Errorf("Failed to remove cleanup route: %v", err) } - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } + func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { parsed, err := netip.ParsePrefix(source) if err != nil { diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 291826780af..34d2d270fe3 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { + return nil +} + +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd_nonios.go b/client/internal/routemanager/systemops_bsd_nonios.go deleted file mode 100644 index f60c7afc3a0..00000000000 --- a/client/internal/routemanager/systemops_bsd_nonios.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios - -package routemanager - -import "net/netip" - -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) -} diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go new file mode 100644 index 00000000000..f34964a8343 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin.go @@ -0,0 +1,61 @@ +//go:build darwin && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" +) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("add", prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("delete", prefix, nexthop, intf) +} + +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { + inet := "-inet" + if prefix.Addr().Is6() { + inet = "-inet6" + // Special case for IPv6 split default route, pointing to the wg interface fails + // TODO: Remove once we have IPv6 support on the interface + if prefix.Bits() == 1 { + intf = "lo0" + } + } + + args := []string{"-n", action, inet, prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } else if intf != "" { + args = append(args, "-interface", intf) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go new file mode 100644 index 00000000000..5c5aaa24fe1 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -0,0 +1,100 @@ +//go:build !ios + +package routemanager + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedVPNint = "utun100" +var expectedExternalInt = "lo0" +var expectedInternalInt = "lo0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via vpn", + destination: "10.10.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), + }, + }...) +} + +func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { + t.Helper() + + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return "lo0" +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { + t.Helper() + + var originalNexthop net.IP + if dstCIDR == "0.0.0.0/0" { + var err error + originalNexthop, err = fetchOriginalGateway() + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { + t.Logf("Failed to delete route: %v, output: %s", err, output) + } + } + + t.Cleanup(func() { + if originalNexthop != nil { + err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() + assert.NoError(t, err, "Failed to restore original route") + } + }) + + err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() + require.NoError(t, err, "Failed to add route") + + t.Cleanup(func() { + err := exec.Command("route", "delete", "-net", dstCIDR).Run() + assert.NoError(t, err, "Failed to remove route") + }) +} + +func fetchOriginalGateway() (net.IP, error) { + output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() + if err != nil { + return nil, err + } + + matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) + if len(matches) == 0 { + return nil, fmt.Errorf("gateway not found") + } + + return net.ParseIP(matches[1]), nil +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 291826780af..34d2d270fe3 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { + return nil +} + +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 83af5008ae0..d21a3bfbfea 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -15,6 +15,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -64,7 +66,7 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting() (err error) { +func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -80,11 +82,11 @@ func setupRouting() (err error) { rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - return fmt.Errorf("%s: %w", rule.description, err) + return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } - return nil + return nil, nil, nil } // cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. @@ -110,7 +112,7 @@ func cleanupRouting() error { return result.ErrorOrNil() } -func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error { +func addVPNRoute(prefix netip.Prefix, intf string) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 // TODO remove this once we have ipv6 support @@ -125,7 +127,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error { +func removeVPNRoute(prefix netip.Prefix, intf string) error { // TODO remove this once we have ipv6 support if prefix == defaultv4 { if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { @@ -138,10 +140,6 @@ func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) return nil } -func getRoutesFromTable() ([]netip.Prefix, error) { - return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4) -} - // addRoute adds a route to a specific routing table identified by tableID. func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { route := &netlink.Route{ @@ -263,34 +261,6 @@ func flushRoutes(tableID, family int) error { return result.ErrorOrNil() } -// getRoutes fetches routes from a specific routing table identified by tableID. -func getRoutes(tableID, family int) ([]netip.Prefix, error) { - var prefixList []netip.Prefix - - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) - } - - for _, route := range routes { - if route.Dst != nil { - addr, ok := netip.AddrFromSlice(route.Dst.IP) - if !ok { - return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) - } - - ones, _ := route.Dst.Mask.Size() - - prefix := netip.PrefixFrom(addr, ones) - if prefix.IsValid() { - prefixList = append(prefixList, prefix) - } - } - } - - return prefixList, nil -} - func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 96e43d20f0b..50a02401a68 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -6,34 +6,40 @@ import ( "errors" "fmt" "net" - "net/netip" "os" "strings" "syscall" "testing" - "time" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool +var expectedVPNint = "wgtest0" +var expectedLoopbackInt = "lo" +var expectedExternalInt = "dummyext0" +var expectedInternalInt = "dummyint0" + +var errRouteNotFound = fmt.Errorf("route not found") + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: expectedLoopbackInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + }...) } func TestEntryExists(t *testing.T) { @@ -92,157 +98,7 @@ func TestEntryExists(t *testing.T) { } } -func TestRoutingWithTables(t *testing.T) { - testCases := []struct { - name string - destination string - captureInterface string - dialer *net.Dialer - packetExpectation PacketExpectation - }{ - { - name: "To external host without fwmark via vpn", - destination: "192.0.2.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with fwmark via physical interface", - destination: "192.0.2.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with fwmark via physical interface", - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - { - name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - - { - name: "To unique vpn route with fwmark via physical interface", - destination: "172.16.0.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53), - }, - { - name: "To unique vpn route without fwmark via vpn", - destination: "172.16.0.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53), - }, - - { - name: "To more specific route without fwmark via vpn interface", - destination: "10.10.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53), - }, - - { - name: "To more specific route (local) without fwmark via physical interface", - destination: "127.0.10.1:53", - captureInterface: "lo", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - wgIface, _, _ := setupTestEnv(t) - - // default route exists in main table and vpn table - err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.0.0.0/8 route exists in main table and vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.10.0.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 127.0.10.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // unique route in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.captureInterface, filter) - - sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.packetExpectation) - }) - } -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } - -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy { +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { t.Helper() dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} @@ -264,35 +120,52 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str require.NoError(t, err) } - return dummy + t.Cleanup(func() { + err := netlink.LinkDel(dummy) + assert.NoError(t, err) + }) + + return dummy.Name } -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { t.Helper() _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + // Handle existing routes with metric 0 + var originalNexthop net.IP + var originalLinkIndex int if dstIPNet.String() == "0.0.0.0/0" { - gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil { + var err error + originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil && !errors.Is(err, errRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } - // Handle existing routes with metric 0 - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - if err == nil { - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - } else if !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) + if originalNexthop != nil { + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + switch { + case err != nil && !errors.Is(err, syscall.ESRCH): + t.Logf("Failed to delete route: %v", err) + case err == nil: + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + default: + t.Logf("Failed to delete route: %v", err) + } } } + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + route := &netlink.Route{ Dst: dstIPNet, Gw: gw, @@ -307,9 +180,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } + require.NoError(t, err) } -// fetchOriginalGateway returns the original gateway IP address and the interface index. func fetchOriginalGateway(family int) (net.IP, int, error) { routes, err := netlink.RouteList(nil, family) if err != nil { @@ -317,153 +190,20 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil { + if route.Dst == nil && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } - return nil, 0, fmt.Errorf("default route not found") + return nil, 0, errRouteNotFound } -func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) { +func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index) + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index) - - t.Cleanup(func() { - err := netlink.LinkDel(defaultDummy) - assert.NoError(t, err) - err = netlink.LinkDel(otherDummy) - assert.NoError(t, err) - }) - - return defaultDummy.Name, otherDummy.Name -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { - t.Helper() - - defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - err := setupRouting() - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - return wgIface, defaultDummy, otherDummy -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - dialer.LocalAddr = localUDPAddr - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) } diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go deleted file mode 100644 index 65f670ace17..00000000000 --- a/client/internal/routemanager/systemops_nonandroid.go +++ /dev/null @@ -1,148 +0,0 @@ -//go:build !android - -//nolint:unused -package routemanager - -import ( - "errors" - "fmt" - "net" - "net/netip" - "os/exec" - "runtime" - - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" -) - -var errRouteNotFound = fmt.Errorf("route not found") - -func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(defaultv4) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - addr := netip.MustParseAddr(defaultGateway.String()) - - if !prefix.Contains(addr) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(addr, 32) - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "") -} - -func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := genericAddRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return genericAddToRouteTable(prefix, addr, intf) -} - -func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTable(prefix, addr, intf) -} - -func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("add route: %w", err) - } - log.Debugf(string(out)) - return nil -} - -func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - log.Debugf(string(out)) - return nil -} - -func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { - r, err := netroute.New() - if err != nil { - return nil, fmt.Errorf("new netroute: %w", err) - } - _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return nil, errRouteNotFound - } - - if gateway == nil { - return preferredSrc, nil - } - - return gateway, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go deleted file mode 100644 index aae5e5faa16..00000000000 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ /dev/null @@ -1,282 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "bytes" - "fmt" - "net" - "net/netip" - "os" - "os/exec" - "runtime" - "strings" - "testing" - - "github.com/pion/transport/v3/stdnet" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/iface" -) - -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - if runtime.GOOS == "linux" { - outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String()) - require.NoError(t, err, "getOutgoingInterfaceLinux should not return error") - if invert { - require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface") - } else { - require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface") - } - return - } - - prefixGateway, err := getExistingRIBRouteGateway(prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} - -func getOutgoingInterfaceLinux(destination string) (string, error) { - cmd := exec.Command("ip", "route", "get", destination) - output, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("executing ip route get: %w", err) - } - - return parseOutgoingInterface(string(output)), nil -} - -func parseOutgoingInterface(routeGetOutput string) string { - fields := strings.Fields(routeGetOutput) - for i, field := range fields { - if field == "dev" && i+1 < len(fields) { - return fields[i+1] - } - } - return "" -} - -func TestAddRemoveRoutes(t *testing.T) { - testCases := []struct { - name string - prefix netip.Prefix - shouldRouteToWireguard bool - shouldBeRemoved bool - }{ - { - name: "Should Add And Remove Route 100.66.120.0/24", - prefix: netip.MustParsePrefix("100.66.120.0/24"), - shouldRouteToWireguard: true, - shouldBeRemoved: true, - }, - { - name: "Should Not Add Or Remove Route 127.0.0.1/32", - prefix: netip.MustParsePrefix("127.0.0.1/32"), - shouldRouteToWireguard: false, - shouldBeRemoved: false, - }, - } - - for n, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - - require.NoError(t, setupRouting()) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) - } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) - } - exists, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "existsInRouteTable should not return err") - if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - require.NoError(t, err) - - if testCase.shouldBeRemoved { - require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") - } else { - require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") - } - } - }) - } -} - -func TestGetExistingRIBRouteGateway(t *testing.T) { - gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - if gateway == nil { - t.Fatal("should return a gateway") - } - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var testingIP string - var testingPrefix netip.Prefix - for _, address := range addresses { - if address.Network() != "ip+net" { - continue - } - prefix := netip.MustParsePrefix(address.String()) - if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { - testingIP = prefix.Addr().String() - testingPrefix = prefix.Masked() - break - } - } - - localIP, err := getExistingRIBRouteGateway(testingPrefix) - if err != nil { - t.Fatal("shouldn't return error: ", err) - } - if localIP == nil { - t.Fatal("should return a gateway for local network") - } - if localIP.String() == gateway.String() { - t.Fatal("local ip should not match with gateway IP") - } - if localIP.String() != testingIP { - t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) - } -} - -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - t.Log("defaultGateway: ", defaultGateway) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - testCases := []struct { - name string - prefix netip.Prefix - preExistingPrefix netip.Prefix - shouldAddRoute bool - }{ - { - name: "Should Add And Remove random Route", - prefix: netip.MustParsePrefix("99.99.99.99/32"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if overlaps with default gateway", - prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), - shouldAddRoute: false, - }, - { - name: "Should Add Route if bigger network exists", - prefix: netip.MustParsePrefix("100.100.100.0/24"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: true, - }, - { - name: "Should Add Route if smaller network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if same network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: false, - }, - } - - for n, testCase := range testCases { - var buf bytes.Buffer - log.SetOutput(&buf) - defer func() { - log.SetOutput(os.Stderr) - }() - t.Run(testCase.name, func(t *testing.T) { - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - - require.NoError(t, setupRouting()) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - MockAddr := wgInterface.Address().IP.String() - - // Prepare the environment - if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err when adding pre-existing route") - } - - // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err when adding route") - - if testCase.shouldAddRoute { - // test if route exists after adding - ok, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "should not return err") - require.True(t, ok, "route should exist") - - // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err") - } - - // route should either not have been added or should have been removed - // In case of already existing route, it should not have been added (but still exist) - ok, err := existsInRouteTable(testCase.prefix) - t.Log("Buffer string: ", buf.String()) - require.NoError(t, err, "should not return err") - - // Linux uses a separate routing table, so the route can exist in both tables. - // The main routing table takes precedence over the wireguard routing table. - if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { - require.False(t, ok, "route should not exist") - } - }) - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index d793f0fbde0..4bc186f215e 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,22 +1,416 @@ -//go:build !linux || android +//go:build !linux && !ios package routemanager import ( + "context" + "errors" + "fmt" + "net" + "net/netip" "runtime" + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) -func setupRouting() error { +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var errRouteNotFound = fmt.Errorf("route not found") + +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } -func cleanupRouting() error { - return nil +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) } -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Errorf("Getting routes returned an error: %v", err) + return netip.Addr{}, nil, errRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, errRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. +// If the next hop or interface is pointing to the VPN interface, it will return an error +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// addVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func addVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// removeVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func removeVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + return nil } + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index afaf5ba7724..adb83bac6d8 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -1,16 +1,250 @@ -//go:build !linux || android +//go:build !linux && !ios package routemanager import ( + "bytes" + "fmt" "net" "net/netip" + "os" + "strings" "testing" + "github.com/pion/transport/v3/stdnet" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/iface" ) +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} + +func TestAddRemoveRoutes(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + shouldRouteToWireguard bool + shouldBeRemoved bool + }{ + { + name: "Should Add And Remove Route 100.66.120.0/24", + prefix: netip.MustParsePrefix("100.66.120.0/24"), + shouldRouteToWireguard: true, + shouldBeRemoved: true, + }, + { + name: "Should Not Add Or Remove Route 127.0.0.1/32", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + shouldRouteToWireguard: false, + shouldBeRemoved: false, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + err = addVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + + if testCase.shouldRouteToWireguard { + assertWGOutInterface(t, testCase.prefix, wgInterface, false) + } else { + assertWGOutInterface(t, testCase.prefix, wgInterface, true) + } + exists, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "existsInRouteTable should not return err") + if exists && testCase.shouldRouteToWireguard { + err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "removeVPNRoute should not return err") + + prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + + internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + require.NoError(t, err) + + if testCase.shouldBeRemoved { + require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") + } else { + require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + } + } + }) + } +} + +func TestGetNextHop(t *testing.T) { + gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + if err != nil { + t.Fatal("shouldn't return error when fetching the gateway: ", err) + } + if !gateway.IsValid() { + t.Fatal("should return a gateway") + } + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var testingIP string + var testingPrefix netip.Prefix + for _, address := range addresses { + if address.Network() != "ip+net" { + continue + } + prefix := netip.MustParsePrefix(address.String()) + if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { + testingIP = prefix.Addr().String() + testingPrefix = prefix.Masked() + break + } + } + + localIP, _, err := getNextHop(testingPrefix.Addr()) + if err != nil { + t.Fatal("shouldn't return error: ", err) + } + if !localIP.IsValid() { + t.Fatal("should return a gateway for local network") + } + if localIP.String() == gateway.String() { + t.Fatal("local ip should not match with gateway IP") + } + if localIP.String() != testingIP { + t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) + } +} + +func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { + defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + t.Log("defaultGateway: ", defaultGateway) + if err != nil { + t.Fatal("shouldn't return error when fetching the gateway: ", err) + } + testCases := []struct { + name string + prefix netip.Prefix + preExistingPrefix netip.Prefix + shouldAddRoute bool + }{ + { + name: "Should Add And Remove random Route", + prefix: netip.MustParsePrefix("99.99.99.99/32"), + shouldAddRoute: true, + }, + { + name: "Should Not Add Route if overlaps with default gateway", + prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), + shouldAddRoute: false, + }, + { + name: "Should Add Route if bigger network exists", + prefix: netip.MustParsePrefix("100.100.100.0/24"), + preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), + shouldAddRoute: true, + }, + { + name: "Should Add Route if smaller network exists", + prefix: netip.MustParsePrefix("100.100.0.0/16"), + preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), + shouldAddRoute: true, + }, + { + name: "Should Not Add Route if same network exists", + prefix: netip.MustParsePrefix("100.100.0.0/16"), + preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), + shouldAddRoute: false, + }, + } + + for n, testCase := range testCases { + var buf bytes.Buffer + log.SetOutput(&buf) + defer func() { + log.SetOutput(os.Stderr) + }() + t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // Prepare the environment + if testCase.preExistingPrefix.IsValid() { + err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + require.NoError(t, err, "should not return err when adding pre-existing route") + } + + // Add the route + err = addVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "should not return err when adding route") + + if testCase.shouldAddRoute { + // test if route exists after adding + ok, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "should not return err") + require.True(t, ok, "route should exist") + + // remove route again if added + err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "should not return err") + } + + // route should either not have been added or should have been removed + // In case of already existing route, it should not have been added (but still exist) + ok, err := existsInRouteTable(testCase.prefix) + t.Log("Buffer string: ", buf.String()) + require.NoError(t, err, "should not return err") + + if !strings.Contains(buf.String(), "because it already exists") { + require.False(t, ok, "route should not exist") + } + }) + } +} + func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -50,7 +284,8 @@ func TestIsSubRange(t *testing.T) { } func TestExistsInRouteTable(t *testing.T) { - require.NoError(t, setupRouting()) + _, _, err := setupRouting(nil, nil) + require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) }) @@ -63,7 +298,8 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() { addressPrefixes = append(addressPrefixes, p.Masked()) } } diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go new file mode 100644 index 00000000000..561eaeea4b2 --- /dev/null +++ b/client/internal/routemanager/systemops_unix_test.go @@ -0,0 +1,234 @@ +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly + +package routemanager + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) + + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.expectedPacket) + }) + } +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index c009ce66b9d..50fff0cd58d 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -6,9 +6,14 @@ import ( "fmt" "net" "net/netip" + "os/exec" + "strings" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -16,6 +21,16 @@ type Win32_IP4RouteTable struct { Mask string } +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" @@ -48,10 +63,68 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + destinationPrefix := prefix.String() + psCmd := "New-NetRoute" + + addressFamily := "IPv4" + if prefix.Addr().Is6() { + addressFamily = "IPv6" + } + + script := fmt.Sprintf( + `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, intf, + ) + + if nexthop.IsValid() { + script = fmt.Sprintf( + `%s -NextHop "%s"`, script, nexthop, + ) + } + + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + log.Tracef("PowerShell add route: %s", string(out)) + + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) + } + + return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"add", prefix.String(), nexthop.Unmap().String()} + + out, err := exec.Command("route", args...).CombinedOutput() + + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + if err != nil { + return fmt.Errorf("route add: %w", err) + } + + return nil +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + // Powershell doesn't support adding routes without an interface but allows to add interface by name + if intf != "" { + return addRoutePowershell(prefix, nexthop, intf) + } + return addRouteCmd(prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"delete", prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil } diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go new file mode 100644 index 00000000000..a5e03b8d2ce --- /dev/null +++ b/client/internal/routemanager/systemops_windows_test.go @@ -0,0 +1,289 @@ +package routemanager + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +var expectedExtInt = "Ethernet1" + +type RouteInfo struct { + NextHop string `json:"nexthop"` + InterfaceAlias string `json:"interfacealias"` + RouteMetric int `json:"routemetric"` +} + +type FindNetRouteOutput struct { + IPAddress string `json:"IPAddress"` + InterfaceIndex int `json:"InterfaceIndex"` + InterfaceAlias string `json:"InterfaceAlias"` + AddressFamily int `json:"AddressFamily"` + NextHop string `json:"NextHop"` + DestinationPrefix string `json:"DestinationPrefix"` +} + +type testCase struct { + name string + destination string + expectedSourceIP string + expectedDestPrefix string + expectedNextHop string + expectedInterface string + dialer dialer +} + +var expectedVPNint = "wgtest0" + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "128.0.0.0/1", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedDestPrefix: "192.0.2.1/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedDestPrefix: "10.0.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "10.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedDestPrefix: "172.16.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "172.16.0.0/12", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route without custom dialer via vpn interface", + destination: "10.10.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "10.10.0.0/24", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "127.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + route, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original gateway") + ip, err := fetchInterfaceIP(route.InterfaceAlias) + require.NoError(t, err, "Failed to fetch interface IP") + + output := testRoute(t, tc.destination, tc.dialer) + if tc.expectedInterface == expectedExtInt { + verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) + } else { + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + } + }) + } +} + +// fetchInterfaceIP fetches the IPv4 address of the specified interface. +func fetchInterfaceIP(interfaceAlias string) (string, error) { + script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) + out, err := exec.Command("powershell", "-Command", script).Output() + if err != nil { + return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) + } + + ip := strings.TrimSpace(string(out)) + return ip, nil +} + +func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", destination) + require.NoError(t, err, "Failed to dial destination") + defer func() { + err := conn.Close() + assert.NoError(t, err, "Failed to close connection") + }() + + host, _, err := net.SplitHostPort(destination) + require.NoError(t, err) + + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + + out, err := exec.Command("powershell", "-Command", script).Output() + require.NoError(t, err, "Failed to execute Find-NetRoute") + + var outputs []FindNetRouteOutput + err = json.Unmarshal(out, &outputs) + require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") + + require.Greater(t, len(outputs), 0, "No route found for destination") + combinedOutput := combineOutputs(outputs) + + return combinedOutput +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) + require.NoError(t, err) + subnetMaskSize, _ := ipNet.Mask.Size() + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to assign IP address to loopback adapter") + + // Wait for the IP address to be applied + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = waitForIPAddress(ctx, interfaceName, ip.String()) + require.NoError(t, err, "IP address not applied within timeout") + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove IP address from loopback adapter") + }) + + return interfaceName +} + +func fetchOriginalGateway() (*RouteInfo, error) { + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) + } + + var routeInfo RouteInfo + err = json.Unmarshal(output, &routeInfo) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON output: %w", err) + } + + return &routeInfo, nil +} + +func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { + t.Helper() + + assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") + assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") + assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") + assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") +} + +func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() + if err != nil { + return err + } + + ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") + for _, ip := range ipAddresses { + if strings.TrimSpace(ip) == expectedIPAddress { + return nil + } + } + } + } +} + +func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { + var combined FindNetRouteOutput + + for _, output := range outputs { + if output.IPAddress != "" { + combined.IPAddress = output.IPAddress + } + if output.InterfaceIndex != 0 { + combined.InterfaceIndex = output.InterfaceIndex + } + if output.InterfaceAlias != "" { + combined.InterfaceAlias = output.InterfaceAlias + } + if output.AddressFamily != 0 { + combined.AddressFamily = output.AddressFamily + } + if output.NextHop != "" { + combined.NextHop = output.NextHop + } + if output.DestinationPrefix != "" { + combined.DestinationPrefix = output.DestinationPrefix + } + } + + return &combined +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") +} diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go new file mode 100644 index 00000000000..28a6502d2ef --- /dev/null +++ b/client/internal/routemanager/sytemops_test.go @@ -0,0 +1,101 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/iface" +) + +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet(nil) + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/portlookup.go index 6ede4b83f1d..6f3d33487ea 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/portlookup.go @@ -1,10 +1,8 @@ package wgproxy import ( - "context" "fmt" - - nbnet "github.com/netbirdio/netbird/util/net" + "net" ) const ( @@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) { } func (pl portLookup) tryToBind(port int) error { - l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port)) + l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) if err != nil { return err } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index b91cd7b439d..2235c5d2bdf 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/ebpf" @@ -29,7 +30,7 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex rawConn net.PacketConn - conn *net.UDPConn + conn transport.UDPConn } // NewWGEBPFProxy create new WGEBPFProxy instance @@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - p.conn, err = nbnet.ListenUDP("udp", &addr) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -75,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error { } return err } + p.conn = conn go p.proxyToRemote() log.Infof("local wg proxy listening on: %d", wgPorxyPort) diff --git a/go.mod b/go.mod index 5566f85599b..29a1570c896 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 - github.com/libp2p/go-netroute v0.2.0 + github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.5 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/socket v0.4.1 diff --git a/go.sum b/go.sum index 6da405341d5..b488a42a42a 100644 --- a/go.sum +++ b/go.sum @@ -345,8 +345,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= -github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= +github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= +github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= @@ -659,7 +659,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -746,7 +745,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/util/grpc/dialer_linux.go b/util/grpc/dialer.go similarity index 56% rename from util/grpc/dialer_linux.go rename to util/grpc/dialer.go index b29ee4b2936..96b2bc32be0 100644 --- a/util/grpc/dialer_linux.go +++ b/util/grpc/dialer.go @@ -1,11 +1,10 @@ -//go:build !android - package grpc import ( "context" "net" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" nbnet "github.com/netbirdio/netbird/util/net" @@ -13,6 +12,11 @@ import ( func WithCustomDialer() grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return nbnet.NewDialer().DialContext(ctx, "tcp", addr) + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, err + } + return conn, nil }) } diff --git a/util/grpc/dialer_generic.go b/util/grpc/dialer_generic.go deleted file mode 100644 index 1c2285b14bf..00000000000 --- a/util/grpc/dialer_generic.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !linux || android - -package grpc - -import "google.golang.org/grpc" - -func WithCustomDialer() grpc.DialOption { - return grpc.EmptyDialOption{} -} diff --git a/util/net/dialer.go b/util/net/dialer.go new file mode 100644 index 00000000000..7b9bddbb52a --- /dev/null +++ b/util/net/dialer.go @@ -0,0 +1,64 @@ +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +// Dialer extends the standard net.Dialer with the ability to execute hooks before +// and after connections. This can be used to bypass the VPN for connections using this dialer. +type Dialer struct { + *net.Dialer +} + +// NewDialer returns a customized net.Dialer with overridden Control method +func NewDialer() *Dialer { + dialer := &Dialer{ + Dialer: &net.Dialer{}, + } + dialer.init() + + return dialer +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to closeConn connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type") + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type") + } + + return tcpConn, nil +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index a3c3ad67c74..2e102da50f8 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,19 +1,123 @@ -//go:build !linux || android +//go:build !android && !ios package net import ( + "context" + "fmt" "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" +) + +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc ) -func NewDialer() *net.Dialer { - return &net.Dialer{} +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) } -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - return net.DialUDP(network, laddr, raddr) +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil } -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - return net.DialTCP(network, laddr, raddr) +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() } diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go index d559490c517..aed5c59a322 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_linux.go @@ -2,59 +2,11 @@ package net -import ( - "context" - "fmt" - "net" - "syscall" +import "syscall" - log "github.com/sirupsen/logrus" -) - -func NewDialer() *net.Dialer { - return &net.Dialer{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, - } -} - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type") +// init configures the net.Dialer Control function to set the fwmark on the socket +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type") - } - - return tcpConn, nil } diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go new file mode 100644 index 00000000000..3254e6d066b --- /dev/null +++ b/util/net/dialer_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/listener.go b/util/net/listener.go new file mode 100644 index 00000000000..f4d769f587e --- /dev/null +++ b/util/net/listener.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before +// responding via the socket and after closing. This can be used to bypass the VPN for listeners. +type ListenerConfig struct { + *net.ListenConfig +} + +// NewListener creates a new ListenerConfig instance. +func NewListener() *ListenerConfig { + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } + listener.init() + + return listener +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index 241c744e528..ae412415ff9 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,13 +1,154 @@ -//go:build !linux || android +//go:build !android && !ios package net -import "net" +import ( + "context" + "fmt" + "net" + "sync" -func NewListener() *net.ListenConfig { - return &net.ListenConfig{} + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) +} + +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// RemoveListenerHooks removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.UDPConn) +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + return + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + return + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(id, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } +} + +func closeConn(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(id, conn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + return err } -func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, locAddr) +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { + udpConn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + connID := GenerateConnID() + return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil } diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go index 7b9bda97c7d..8d332160a04 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_linux.go @@ -3,28 +3,12 @@ package net import ( - "context" - "fmt" - "net" "syscall" ) -func NewListener() *net.ListenConfig { - return &net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } } - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err) - } - udpConn, ok := pc.(*net.UDPConn) - if !ok { - return nil, fmt.Errorf("packetConn is not a *net.UDPConn") - } - return udpConn, nil -} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go new file mode 100644 index 00000000000..0dbbb360b53 --- /dev/null +++ b/util/net/listener_mobile.go @@ -0,0 +1,11 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) +} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go new file mode 100644 index 00000000000..fb6eadaaad8 --- /dev/null +++ b/util/net/listener_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (l *ListenerConfig) init() { +} diff --git a/util/net/net.go b/util/net/net.go index 5714e52294e..9ea7ae80340 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,6 +1,17 @@ package net +import "github.com/google/uuid" + const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard NetbirdFwmark = 0x1BD00 ) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} From bb0d5c5bafebe63653018387c8cdfd712ec65ba0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 18:04:22 +0200 Subject: [PATCH 28/89] Linux legacy routing (#1774) * Add Linux legacy routing if ip rule functionality is not available * Ignore exclusion route errors if host has no route * Exclude iOS from route manager * Also retrieve IPv6 routes * Ignore loopback addresses not being in the main table * Ignore "not supported" errors on cleanup * Fix regression in ListenUDP not using fwmarks --- client/internal/routemanager/routemanager.go | 6 +- client/internal/routemanager/systemops.go | 410 ++++++++++++++++++ .../internal/routemanager/systemops_linux.go | 138 ++++-- .../routemanager/systemops_linux_test.go | 2 - .../routemanager/systemops_nonlinux.go | 397 +---------------- ...ops_nonlinux_test.go => systemops_test.go} | 152 +++++-- client/internal/routemanager/sytemops_test.go | 101 ----- util/net/dialer.go | 2 +- util/net/listener_generic.go | 15 +- 9 files changed, 655 insertions(+), 568 deletions(-) create mode 100644 client/internal/routemanager/systemops.go rename client/internal/routemanager/{systemops_nonlinux_test.go => systemops_test.go} (70%) delete mode 100644 client/internal/routemanager/sytemops_test.go diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index fe8d7b4ef19..39b55f1052b 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -1,8 +1,9 @@ -//go:build !android +//go:build !android && !ios package routemanager import ( + "errors" "fmt" "net/netip" "sync" @@ -53,6 +54,9 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref if ref.count == 0 { log.Debugf("Adding route for prefix %s", prefix) nexthop, intf, err := rm.addRoute(prefix) + if errors.Is(err, errRouteNotFound) { + return nil + } if err != nil { return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 00000000000..c6f3376e032 --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,410 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var errRouteNotFound = fmt.Errorf("route not found") + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Warnf("Failed to get route for %s: %v", ip, err) + return netip.Addr{}, nil, errRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, errRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. +// If the next hop or interface is pointing to the VPN interface, it will return an error +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func genericAddVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, errRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, errRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index d21a3bfbfea..44691f0d65b 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -35,6 +35,9 @@ const ( var ErrTableIDExists = errors.New("ID exists with different name") +var routeManager = &RouteManager{} +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" + type ruleParams struct { fwmark int tableID int @@ -66,7 +69,12 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { + if isLegacy { + log.Infof("Using legacy routing setup") + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -82,6 +90,11 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") + isLegacy = true + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } @@ -93,6 +106,10 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func cleanupRouting() error { + if isLegacy { + return cleanupRoutingWithRouteManager(routeManager) + } + var result *multierror.Error if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { @@ -104,7 +121,7 @@ func cleanupRouting() error { rules := getSetupRules() for _, rule := range rules { - if err := removeAllRules(rule); err != nil { + if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) } } @@ -112,49 +129,104 @@ func cleanupRouting() error { return result.ErrorOrNil() } +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + func addVPNRoute(prefix netip.Prefix, intf string) error { - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 + if isLegacy { + return genericAddVPNRoute(prefix, intf) + } + + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("add blackhole: %w", err) } } - if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("add route: %w", err) } return nil } func removeVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericRemoveVPNRoute(prefix, intf) + } + // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove unreachable route: %w", err) } } - if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove route: %w", err) } return nil } +func getRoutesFromTable() ([]netip.Prefix, error) { + v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) + if err != nil { + return nil, fmt.Errorf("get v4 routes: %w", err) + } + v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) + if err != nil { + return nil, fmt.Errorf("get v6 routes: %w", err) + + } + return append(v4Routes, v6Routes...), nil +} + +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { + var prefixList []netip.Prefix + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) + } + + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) + } + } + } + + return prefixList, nil +} + // addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { +func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, - Family: family, + Family: getAddressFamily(prefix), } - if prefix != nil { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) } + route.Dst = ipNet if err := addNextHop(addr, intf, route); err != nil { return fmt.Errorf("add gateway and device: %w", err) @@ -170,7 +242,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err // addUnreachableRoute adds an unreachable route for the specified IP family and routing table. // ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. // tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { +func addUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -179,7 +251,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { route := &netlink.Route{ Type: syscall.RTN_UNREACHABLE, Table: tableID, - Family: ipFamily, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -190,7 +262,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { return nil } -func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { +func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -199,7 +271,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { route := &netlink.Route{ Type: syscall.RTN_UNREACHABLE, Table: tableID, - Family: ipFamily, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -212,7 +284,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { } // removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -221,7 +293,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, - Family: family, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -392,23 +464,25 @@ func removeAllRules(params ruleParams) error { } // addNextHop adds the gateway and device to the route. -func addNextHop(addr *string, intf *string, route *netlink.Route) error { - if addr != nil { - ip := net.ParseIP(*addr) - if ip == nil { - return fmt.Errorf("parsing address %s failed", *addr) - } - - route.Gw = ip +func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { + if addr.IsValid() { + route.Gw = addr.AsSlice() } - if intf != nil { - link, err := netlink.LinkByName(*intf) + if intf != "" { + link, err := netlink.LinkByName(intf) if err != nil { - return fmt.Errorf("set interface %s: %w", *intf, err) + return fmt.Errorf("set interface %s: %w", intf, err) } route.LinkIndex = link.Attrs().Index } return nil } + +func getAddressFamily(prefix netip.Prefix) int { + if prefix.Addr().Is4() { + return netlink.FAMILY_V4 + } + return netlink.FAMILY_V6 +} diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 50a02401a68..d77c7cc7dcf 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -21,8 +21,6 @@ var expectedLoopbackInt = "lo" var expectedExternalInt = "dummyext0" var expectedInternalInt = "dummyint0" -var errRouteNotFound = fmt.Errorf("route not found") - func init() { testCases = append(testCases, []testCase{ { diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 4bc186f215e..38026107ec7 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -3,414 +3,21 @@ package routemanager import ( - "context" - "errors" - "fmt" - "net" "net/netip" "runtime" - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var errRouteNotFound = fmt.Errorf("route not found") - func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) -} - -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return netip.Addr{}, nil, errRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, errRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) - } - return addr.Unmap(), intf, nil - } - - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) - } - - return addr.Unmap(), intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. -// If the next hop or interface is pointing to the VPN interface, it will return an error -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(): - return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) - case addr.IsLinkLocalUnicast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) - case addr.IsLinkLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) - case addr.IsInterfaceLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) - case addr.IsUnspecified(): - return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) - case addr.IsMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// addVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route func addVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) + return genericAddVPNRoute(prefix, intf) } -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// removeVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes func removeVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil + return genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_test.go similarity index 70% rename from client/internal/routemanager/systemops_nonlinux_test.go rename to client/internal/routemanager/systemops_test.go index adb83bac6d8..97386f19a1a 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -1,13 +1,15 @@ -//go:build !linux && !ios +//go:build !android && !ios package routemanager import ( "bytes" + "context" "fmt" "net" "net/netip" "os" + "runtime" "strings" "testing" @@ -20,16 +22,9 @@ import ( "github.com/netbirdio/netbird/iface" ) -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) } func TestAddRemoveRoutes(t *testing.T) { @@ -72,8 +67,8 @@ func TestAddRemoveRoutes(t *testing.T) { assert.NoError(t, cleanupRouting()) }) - err = addVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { assertWGOutInterface(t, testCase.prefix, wgInterface, false) @@ -83,8 +78,8 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "removeVPNRoute should not return err") + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericRemoveVPNRoute should not return err") prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) require.NoError(t, err, "getNextHop should not return err") @@ -144,7 +139,7 @@ func TestGetNextHop(t *testing.T) { } } -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { +func TestAddExistAndRemoveRoute(t *testing.T) { defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { @@ -205,20 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addVPNRoute(testCase.prefix, wgInterface.Name()) + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -228,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -284,12 +273,6 @@ func TestIsSubRange(t *testing.T) { } func TestExistsInRouteTable(t *testing.T) { - _, _, err := setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - addresses, err := net.InterfaceAddrs() if err != nil { t.Fatal("shouldn't return error when fetching interface addresses: ", err) @@ -298,10 +281,19 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) + if p.Addr().Is6() { + continue + } // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() { - addressPrefixes = append(addressPrefixes, p.Masked()) + if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { + continue + } + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + if runtime.GOOS == "linux" && p.Addr().IsLoopback() { + continue } + + addressPrefixes = append(addressPrefixes, p.Masked()) } for _, prefix := range addressPrefixes { @@ -314,3 +306,97 @@ func TestExistsInRouteTable(t *testing.T) { } } } + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet() + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { + return + } + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go deleted file mode 100644 index 28a6502d2ef..00000000000 --- a/client/internal/routemanager/sytemops_test.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "net" - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" -) - -type dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) { - t.Helper() - - setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, wgIface) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) -} diff --git a/util/net/dialer.go b/util/net/dialer.go index 7b9bddbb52a..d3adef363a0 100644 --- a/util/net/dialer.go +++ b/util/net/dialer.go @@ -35,7 +35,7 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { udpConn, ok := conn.(*net.UDPConn) if !ok { if err := conn.Close(); err != nil { - log.Errorf("Failed to closeConn connection: %v", err) + log.Errorf("Failed to close connection: %v", err) } return nil, fmt.Errorf("expected UDP connection, got different type") } diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index ae412415ff9..a195bdeb917 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -145,10 +145,19 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { // ListenUDP listens on the network address and returns a transport.UDPConn // which includes support for write and close hooks. func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { - udpConn, err := net.ListenUDP(network, laddr) + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) if err != nil { return nil, fmt.Errorf("listen UDP: %w", err) } - connID := GenerateConnID() - return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type") + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil } From 25f5f26527d777c1064d2825209da5183ea3dcae Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 18:57:50 +0200 Subject: [PATCH 29/89] Timeout rule removing loop and catch IPv6 unsupported error in loop (#1791) --- .../internal/routemanager/systemops_linux.go | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 44691f0d65b..ef464372737 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -4,12 +4,14 @@ package routemanager import ( "bufio" + "context" "errors" "fmt" "net" "net/netip" "os" "syscall" + "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -444,7 +446,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil { return fmt.Errorf("remove routing rule: %w", err) } @@ -452,15 +454,33 @@ func removeRule(params ruleParams) error { } func removeAllRules(params ruleParams) error { - for { - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) { - break + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + for { + if ctx.Err() != nil { + done <- ctx.Err() + return + } + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { + done <- nil + return + } + done <- err + return } - return err } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err } - return nil } // addNextHop adds the gateway and device to the route. From 3d2a2377c685d397c876587f9a55e04057f341db Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 3 Apr 2024 19:06:04 +0200 Subject: [PATCH 30/89] Don't return errors on disallowed routes (#1792) --- client/internal/routemanager/routemanager.go | 5 ++- client/internal/routemanager/systemops.go | 39 +++++++++---------- .../routemanager/systemops_linux_test.go | 4 +- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index 39b55f1052b..8f9ff9f4bd0 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -54,9 +54,12 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref if ref.count == 0 { log.Debugf("Adding route for prefix %s", prefix) nexthop, intf, err := rm.addRoute(prefix) - if errors.Is(err, errRouteNotFound) { + if errors.Is(err, ErrRouteNotFound) { return nil } + if errors.Is(err, ErrRouteNotAllowed) { + log.Debugf("Adding route for prefix %s: %s", prefix, err) + } if err != nil { return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index c6f3376e032..a91f53636da 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -23,7 +23,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -var errRouteNotFound = fmt.Errorf("route not found") +var ErrRouteNotFound = errors.New("route not found") +var ErrRouteNotAllowed = errors.New("route not allowed") // TODO: fix: for default our wg address now appears as the default gw func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { @@ -33,7 +34,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { } defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("get existing route gateway: %s", err) } @@ -59,7 +60,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { var exitIntf string gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } if intf != nil { @@ -78,13 +79,13 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) if err != nil { log.Warnf("Failed to get route for %s: %v", ip, err) - return netip.Addr{}, nil, errRouteNotFound + return netip.Addr{}, nil, ErrRouteNotFound } log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { if preferredSrc == nil { - return netip.Addr{}, nil, errRouteNotFound + return netip.Addr{}, nil, ErrRouteNotFound } log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) @@ -129,8 +130,8 @@ func isSubRange(prefix netip.Prefix) (bool, error) { return false, nil } -// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. -// If the next hop or interface is pointing to the VPN interface, it will return an error +// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. +// If the next hop or interface is pointing to the VPN interface, it will return the initial values. func addRouteToNonVPNIntf( prefix netip.Prefix, vpnIntf *iface.WGIface, @@ -139,18 +140,14 @@ func addRouteToNonVPNIntf( ) (netip.Addr, string, error) { addr := prefix.Addr() switch { - case addr.IsLoopback(): - return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) - case addr.IsLinkLocalUnicast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) - case addr.IsLinkLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) - case addr.IsInterfaceLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) - case addr.IsUnspecified(): - return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) - case addr.IsMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + case addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsUnspecified(), + addr.IsMulticast(): + + return netip.Addr{}, "", ErrRouteNotAllowed } // Determine the exit interface and next hop for the prefix, so we can add a specific route @@ -316,11 +313,11 @@ func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) } initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { log.Errorf("Unable to get initial v6 default next hop: %v", err) } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index d77c7cc7dcf..0043c3f4e94 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -138,7 +138,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if dstIPNet.String() == "0.0.0.0/0" { var err error originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil && !errors.Is(err, errRouteNotFound) { + if err != nil && !errors.Is(err, ErrRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } @@ -193,7 +193,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } } - return nil, 0, errRouteNotFound + return nil, 0, ErrRouteNotFound } func setupDummyInterfacesAndRoutes(t *testing.T) { From 3461b1bb90e71152084b57e4feec1583a3cf60b9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 5 Apr 2024 00:10:32 +0200 Subject: [PATCH 31/89] Expect correct conn type (#1801) --- util/net/dialer.go | 43 ------------------------------------ util/net/dialer_generic.go | 40 +++++++++++++++++++++++++++++++++ util/net/dialer_mobile.go | 15 +++++++++++++ util/net/listener_generic.go | 2 +- 4 files changed, 56 insertions(+), 44 deletions(-) create mode 100644 util/net/dialer_mobile.go diff --git a/util/net/dialer.go b/util/net/dialer.go index d3adef363a0..0786c667e53 100644 --- a/util/net/dialer.go +++ b/util/net/dialer.go @@ -1,10 +1,7 @@ package net import ( - "fmt" "net" - - log "github.com/sirupsen/logrus" ) // Dialer extends the standard net.Dialer with the ability to execute hooks before @@ -22,43 +19,3 @@ func NewDialer() *Dialer { return dialer } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type") - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type") - } - - return tcpConn, nil -} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 2e102da50f8..06fac3bbf85 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -121,3 +121,43 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, return result.ErrorOrNil() } + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_mobile.go b/util/net/dialer_mobile.go new file mode 100644 index 00000000000..b95aaa973e9 --- /dev/null +++ b/util/net/dialer_mobile.go @@ -0,0 +1,15 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index a195bdeb917..451279e9d25 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -156,7 +156,7 @@ func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { if err := packetConn.Close(); err != nil { log.Errorf("Failed to close connection: %v", err) } - return nil, fmt.Errorf("expected UDPConn, got different type") + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) } return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil From 1d1d057e7d01f8e3e34bb509738a7dbb5f21a62e Mon Sep 17 00:00:00 2001 From: trax <38944599+4nx@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:51:28 +0200 Subject: [PATCH 32/89] Change the dashboard image pull from wiretrustee to netbirdio (#1804) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index fd194a042bf..d3ae6529a23 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -2,7 +2,7 @@ version: "3" services: #UI dashboard dashboard: - image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG + image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG restart: unless-stopped #ports: # - 80:80 From 9f32ccd4533d5301bcb901677af9816cb3408f92 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 5 Apr 2024 20:38:49 +0200 Subject: [PATCH 33/89] Rollback new routing functionality (#1805) --- .github/workflows/golang-test-darwin.yml | 3 - .github/workflows/golang-test-linux.yml | 10 +- .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- client/internal/engine.go | 16 - client/internal/peer/conn.go | 32 -- client/internal/relay/relay.go | 7 +- client/internal/routemanager/client.go | 63 +-- client/internal/routemanager/manager.go | 79 +-- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 - client/internal/routemanager/routemanager.go | 126 ----- .../routemanager/server_nonandroid.go | 57 +- client/internal/routemanager/systemops.go | 407 -------------- .../routemanager/systemops_android.go | 24 +- client/internal/routemanager/systemops_bsd.go | 1 + .../internal/routemanager/systemops_darwin.go | 61 --- .../routemanager/systemops_darwin_test.go | 100 ---- client/internal/routemanager/systemops_ios.go | 26 +- .../internal/routemanager/systemops_linux.go | 511 +++--------------- .../routemanager/systemops_linux_test.go | 207 ------- .../routemanager/systemops_nonandroid.go | 120 ++++ ...s_test.go => systemops_nonandroid_test.go} | 212 ++------ .../routemanager/systemops_nonlinux.go | 32 +- .../routemanager/systemops_unix_test.go | 234 -------- .../routemanager/systemops_windows.go | 88 +-- .../routemanager/systemops_windows_test.go | 289 ---------- client/internal/stdnet/dialer.go | 24 - client/internal/stdnet/listener.go | 20 - client/internal/wgproxy/proxy_ebpf.go | 9 +- client/internal/wgproxy/proxy_userspace.go | 4 +- go.mod | 2 +- iface/wg_configurer_kernel.go | 4 +- iface/wg_configurer_usp.go | 11 +- management/client/grpc.go | 2 - sharedsock/sock_linux.go | 10 - signal/client/grpc.go | 2 - util/grpc/dialer.go | 22 - util/net/dialer.go | 21 - util/net/dialer_generic.go | 163 ------ util/net/dialer_linux.go | 12 - util/net/dialer_nonlinux.go | 6 - util/net/listener.go | 21 - util/net/listener_generic.go | 163 ------ util/net/listener_linux.go | 14 - util/net/listener_mobile.go | 11 - util/net/listener_nonlinux.go | 6 - util/net/net.go | 17 - util/net/net_linux.go | 35 -- 49 files changed, 354 insertions(+), 2969 deletions(-) delete mode 100644 client/internal/routemanager/routemanager.go delete mode 100644 client/internal/routemanager/systemops.go delete mode 100644 client/internal/routemanager/systemops_darwin.go delete mode 100644 client/internal/routemanager/systemops_darwin_test.go delete mode 100644 client/internal/routemanager/systemops_linux_test.go create mode 100644 client/internal/routemanager/systemops_nonandroid.go rename client/internal/routemanager/{systemops_test.go => systemops_nonandroid_test.go} (59%) delete mode 100644 client/internal/routemanager/systemops_unix_test.go delete mode 100644 client/internal/routemanager/systemops_windows_test.go delete mode 100644 client/internal/stdnet/dialer.go delete mode 100644 client/internal/stdnet/listener.go delete mode 100644 util/grpc/dialer.go delete mode 100644 util/net/dialer.go delete mode 100644 util/net/dialer_generic.go delete mode 100644 util/net/dialer_linux.go delete mode 100644 util/net/dialer_nonlinux.go delete mode 100644 util/net/listener.go delete mode 100644 util/net/listener_generic.go delete mode 100644 util/net/listener_linux.go delete mode 100644 util/net/listener_mobile.go delete mode 100644 util/net/listener_nonlinux.go delete mode 100644 util/net/net.go delete mode 100644 util/net/net_linux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index d7007c86080..f8afd3d6eab 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,9 +32,6 @@ jobs: restore-keys: | macos-go- - - name: Install libpcap - run: brew install libpcap - - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 42f740e9b54..74e6d1203ab 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -36,11 +36,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -71,7 +67,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -86,7 +82,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... + run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2d63acbcd5a..6027d36269f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 13228250d59..9f543c74c45 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index d6238c4b3ca..13ef8ce1563 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -94,9 +94,6 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn - beforePeerHook peer.BeforeAddPeerHookFunc - afterPeerHook peer.AfterRemovePeerHookFunc - // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -264,14 +261,6 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() - if err != nil { - log.Errorf("Failed to initialize route manager: %s", err) - } else { - e.beforePeerHook = beforePeerHook - e.afterPeerHook = afterPeerHook - } - e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -821,11 +810,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { } e.peerConns[peerKey] = conn - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f3d07dcad1f..17ef7e87fd2 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,7 +20,6 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -101,9 +100,6 @@ type IceCredentials struct { Pwd string } -type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error -type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error - type Conn struct { config ConnConfig mu sync.Mutex @@ -142,10 +138,6 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn - - connID nbnet.ConnectionID - beforeAddPeerHooks []BeforeAddPeerHookFunc - afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -401,14 +393,6 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } -func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { - conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) -} - -func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { - conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -437,13 +421,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem conn.remoteEndpoint = endpointUdpAddr log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) - conn.connID = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { - log.Errorf("Before add peer hook failed: %v", err) - } - } - err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { @@ -534,15 +511,6 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if conn.connID != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connID); err != nil { - log.Errorf("After remove peer hook failed: %v", err) - } - } - } - conn.connID = "" - if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 84fd72e49c9..ad3b94f2a5f 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,7 +12,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -96,13 +95,15 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") + listener := &net.ListenConfig{} + conn, err = listener.ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) + dialer := &net.Dialer{} + tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 38cf4bf6550..f7ead582720 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,7 +41,6 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) - client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -73,18 +72,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } -// getBestRouteFromStatuses determines the most optimal route from the available routes -// within a clientNetwork, taking into account peer connection status, route metrics, and -// preference for non-relayed and direct connections. -// -// It follows these prioritization rules: -// * Connected peers: Only routes with connected peers are considered. -// * Metric: Routes with lower metrics (better) are prioritized. -// * Non-relayed: Routes without relays are preferred. -// * Direct connections: Routes with direct peer connections are favored. -// * Stability: In case of equal scores, the currently active route (if any) is maintained. -// -// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -171,7 +158,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return fmt.Errorf("get peer state: %v", err) + return err } delete(state.Routes, c.network.String()) @@ -185,7 +172,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -193,26 +180,30 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { - return fmt.Errorf("remove route %s from system, err: %v", c.network, err) + err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err } - - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route: %v", err) + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) + if err != nil { + return fmt.Errorf("couldn't remove route %s from system, err: %v", + c.network, err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { + + var err error + routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) - - // If no route is chosen, remove the route from the peer and system if chosen == "" { - if err := c.removeRouteFromPeerAndSystem(); err != nil { - return fmt.Errorf("remove route from peer and system: %v", err) + err = c.removeRouteFromPeerAndSystem() + if err != nil { + return err } c.chosenRoute = nil @@ -220,7 +211,6 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } - // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -228,13 +218,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - // If a previous route exists, remove it from the peer - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route from peer: %v", err) + err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err } } else { - // otherwise add the route to the system - if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { + err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) + if err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -255,7 +245,8 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { + err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -296,21 +287,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) + log.Error(err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system: %v", err) + log.Error(err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("Received a routes update with smaller serial number, ignoring it") + log.Warnf("received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("Received a new client network route update for %s", c.network) + log.Debugf("received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -318,7 +309,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) + log.Error(err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 36a37f02c50..b624d8c34ce 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,10 +2,6 @@ package routemanager import ( "context" - "fmt" - "net" - "net/netip" - "net/url" "runtime" "sync" @@ -19,14 +15,8 @@ import ( "github.com/netbirdio/netbird/version" ) -var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - -// nolint:unused -var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - // Manager is a route manager interface type Manager interface { - Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -66,24 +56,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } -// Init sets up the routing -func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - if err := cleanupRouting(); err != nil { - log.Warnf("Failed cleaning up routing: %v", err) - } - - mgmtAddress := m.statusRecorder.GetManagementState().URL - signalAddress := m.statusRecorder.GetSignalState().URL - ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) - - beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) - if err != nil { - return nil, nil, fmt.Errorf("setup routing: %w", err) - } - log.Info("Routing setup complete") - return beforePeerHook, afterPeerHook, nil -} - func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -99,15 +71,9 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } - if err := cleanupRouting(); err != nil { - log.Errorf("Error cleaning up routing: %v", err) - } else { - log.Info("Routing cleanup complete") - } - m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -125,7 +91,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return fmt.Errorf("update routes: %w", err) + return err } } @@ -190,7 +156,11 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - if !isPrefixSupported(newRoute.Network) { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < minRangeBits { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", + version.NetbirdVersion(), newRoute.Network) continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -208,38 +178,3 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } - -func isPrefixSupported(prefix netip.Prefix) bool { - switch runtime.GOOS { - case "linux", "windows", "darwin": - return true - } - - // If prefix is too small, lets assume it is a possible default prefix which is not yet supported - // we skip this prefix management - if prefix.Bits() <= minRangeBits { - log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", - version.NetbirdVersion(), prefix) - return false - } - return true -} - -// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. -func resolveURLsToIPs(urls []string) []net.IP { - var ips []net.IP - for _, rawurl := range urls { - u, err := url.Parse(rawurl) - if err != nil { - log.Errorf("Failed to parse url %s: %v", rawurl, err) - continue - } - ipAddrs, err := net.LookupIP(u.Hostname()) - if err != nil { - log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) - continue - } - ips = append(ips, ipAddrs...) - } - return ips -} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 03e77e09bcb..2e5cf6649d8 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,13 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedAllowed int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int }{ { name: "Should create 2 client networks", @@ -201,9 +200,8 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedAllowed: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, }, { name: "Remove 1 Client Route", @@ -417,10 +415,6 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - - _, _, err = routeManager.Init() - - require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -435,11 +429,7 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - expectedWatchers := testCase.clientNetworkWatchersExpected - if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed - } - require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") + require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index dd2c28e5927..a1214cbb9ec 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,7 +6,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -17,10 +16,6 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go deleted file mode 100644 index 8f9ff9f4bd0..00000000000 --- a/client/internal/routemanager/routemanager.go +++ /dev/null @@ -1,126 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "errors" - "fmt" - "net/netip" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type ref struct { - count int - nexthop netip.Addr - intf string -} - -type RouteManager struct { - // refCountMap keeps track of the reference ref for prefixes - refCountMap map[netip.Prefix]ref - // prefixMap keeps track of the prefixes associated with a connection ID for removal - prefixMap map[nbnet.ConnectionID][]netip.Prefix - addRoute AddRouteFunc - removeRoute RemoveRouteFunc - mutex sync.Mutex -} - -type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) -type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error - -func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { - // TODO: read initial routing table into refCountMap - return &RouteManager{ - refCountMap: map[netip.Prefix]ref{}, - prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, - addRoute: addRoute, - removeRoute: removeRoute, - } -} - -func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - ref := rm.refCountMap[prefix] - log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) - - // Add route to the system, only if it's a new prefix - if ref.count == 0 { - log.Debugf("Adding route for prefix %s", prefix) - nexthop, intf, err := rm.addRoute(prefix) - if errors.Is(err, ErrRouteNotFound) { - return nil - } - if errors.Is(err, ErrRouteNotAllowed) { - log.Debugf("Adding route for prefix %s: %s", prefix, err) - } - if err != nil { - return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) - } - ref.nexthop = nexthop - ref.intf = intf - } - - ref.count++ - rm.refCountMap[prefix] = ref - rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) - - return nil -} - -func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - prefixes, ok := rm.prefixMap[connID] - if !ok { - log.Debugf("No prefixes found for connection ID %s", connID) - return nil - } - - var result *multierror.Error - for _, prefix := range prefixes { - ref := rm.refCountMap[prefix] - log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) - if ref.count == 1 { - log.Debugf("Removing route for prefix %s", prefix) - // TODO: don't fail if the route is not found - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - continue - } - delete(rm.refCountMap, prefix) - } else { - ref.count-- - rm.refCountMap[prefix] = ref - } - } - delete(rm.prefixMap, connID) - - return result.ErrorOrNil() -} - -// Flush removes all references and routes from the system -func (rm *RouteManager) Flush() error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - var result *multierror.Error - for prefix := range rm.refCountMap { - log.Debugf("Removing route for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - } - } - rm.refCountMap = map[netip.Prefix]ref{} - rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} - - return result.ErrorOrNil() -} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index af82dc91349..19236787772 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,7 +4,6 @@ package routemanager import ( "context" - "fmt" "net/netip" "sync" @@ -49,7 +48,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -63,7 +62,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -82,22 +81,15 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("Not removing from server network because context is done") + log.Infof("not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.RemoveRoutingRules(routerPair) + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { - return fmt.Errorf("remove routing rules: %w", err) + return err } - delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -111,22 +103,15 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("Not adding to server network because context is done") + log.Infof("not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.InsertRoutingRules(routerPair) - if err != nil { - return fmt.Errorf("insert routing rules: %w", err) + return err } - m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -144,33 +129,23 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) - if err != nil { - log.Errorf("Failed to convert route to router pair: %v", err) - continue - } - - err = m.firewall.RemoveRoutingRules(routerPair) + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) if err != nil { - log.Errorf("Failed to remove cleanup route: %v", err) + log.Warnf("failed to remove clean up route: %s", r.ID) } + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } - - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { - parsed, err := netip.ParsePrefix(source) - if err != nil { - return firewall.RouterPair{}, err - } +func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { + parsed := netip.MustParsePrefix(source).Masked() return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - }, nil + } } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go deleted file mode 100644 index a91f53636da..00000000000 --- a/client/internal/routemanager/systemops.go +++ /dev/null @@ -1,407 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" -) - -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var ErrRouteNotFound = errors.New("route not found") -var ErrRouteNotAllowed = errors.New("route not allowed") - -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) -} - -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Warnf("Failed to get route for %s: %v", ip, err) - return netip.Addr{}, nil, ErrRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, ErrRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) - } - return addr.Unmap(), intf, nil - } - - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) - } - - return addr.Unmap(), intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. -// If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(), - addr.IsLinkLocalUnicast(), - addr.IsLinkLocalMulticast(), - addr.IsInterfaceLocalMulticast(), - addr.IsUnspecified(), - addr.IsMulticast(): - - return netip.Addr{}, "", ErrRouteNotAllowed - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route -func genericAddVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) -} - -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes -func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil -} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 34d2d270fe3..950a268434c 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,33 +1,13 @@ package routemanager import ( - "net" "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { - return nil -} - -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, string) error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } -func removeVPNRoute(netip.Prefix, string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e847..b2da8075cfa 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,4 +1,5 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd +// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go deleted file mode 100644 index f34964a8343..00000000000 --- a/client/internal/routemanager/systemops_darwin.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build darwin && !ios - -package routemanager - -import ( - "fmt" - "net" - "net/netip" - "os/exec" - "strings" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" -) - -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) -} - -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return routeCmd("add", prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return routeCmd("delete", prefix, nexthop, intf) -} - -func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { - inet := "-inet" - if prefix.Addr().Is6() { - inet = "-inet6" - // Special case for IPv6 split default route, pointing to the wg interface fails - // TODO: Remove once we have IPv6 support on the interface - if prefix.Bits() == 1 { - intf = "lo0" - } - } - - args := []string{"-n", action, inet, prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } else if intf != "" { - args = append(args, "-interface", intf) - } - - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) - } - return nil -} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go deleted file mode 100644 index 5c5aaa24fe1..00000000000 --- a/client/internal/routemanager/systemops_darwin_test.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build !ios - -package routemanager - -import ( - "fmt" - "net" - "os/exec" - "regexp" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var expectedVPNint = "utun100" -var expectedExternalInt = "lo0" -var expectedInternalInt = "lo0" - -func init() { - testCases = append(testCases, []testCase{ - { - name: "To more specific route without custom dialer via vpn", - destination: "10.10.0.2:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), - }, - }...) -} - -func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { - t.Helper() - - err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() - require.NoError(t, err, "Failed to create loopback alias") - - t.Cleanup(func() { - err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() - assert.NoError(t, err, "Failed to remove loopback alias") - }) - - return "lo0" -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { - t.Helper() - - var originalNexthop net.IP - if dstCIDR == "0.0.0.0/0" { - var err error - originalNexthop, err = fetchOriginalGateway() - if err != nil { - t.Logf("Failed to fetch original gateway: %v", err) - } - - if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { - t.Logf("Failed to delete route: %v, output: %s", err, output) - } - } - - t.Cleanup(func() { - if originalNexthop != nil { - err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() - assert.NoError(t, err, "Failed to restore original route") - } - }) - - err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() - require.NoError(t, err, "Failed to add route") - - t.Cleanup(func() { - err := exec.Command("route", "delete", "-net", dstCIDR).Run() - assert.NoError(t, err, "Failed to remove route") - }) -} - -func fetchOriginalGateway() (net.IP, error) { - output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() - if err != nil { - return nil, err - } - - matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) - if len(matches) == 0 { - return nil, fmt.Errorf("gateway not found") - } - - return net.ParseIP(matches[1]), nil -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) - - otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) -} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 34d2d270fe3..aae0f8dc8f2 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,33 +1,15 @@ +//go:build ios + package routemanager import ( - "net" "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { - return nil -} - -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, string) error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } -func removeVPNRoute(netip.Prefix, string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index ef464372737..0562826a55d 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,342 +3,142 @@ package routemanager import ( - "bufio" - "context" - "errors" - "fmt" "net" "net/netip" "os" "syscall" - "time" + "unsafe" - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" -) - -const ( - // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. - NetbirdVPNTableID = 0x1BD0 - // NetbirdVPNTableName is the name of the custom routing table used by Netbird. - NetbirdVPNTableName = "netbird" - - // rtTablesPath is the path to the file containing the routing table names. - rtTablesPath = "/etc/iproute2/rt_tables" - - // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. - ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" ) -var ErrTableIDExists = errors.New("ID exists with different name") - -var routeManager = &RouteManager{} -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" - -type ruleParams struct { - fwmark int - tableID int - family int - priority int - invert bool - suppressPrefix int - description string -} - -func getSetupRules() []ruleParams { - return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, - } -} - -// setupRouting establishes the routing configuration for the VPN, including essential rules -// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. -// -// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over -// potential routes received and configured for the VPN. This rule is skipped for the default route and routes -// that are not in the main table. -// -// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. -// This table is where a default route or other specific routes received from the management server are configured, -// enabling VPN connectivity. -// -// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { - if isLegacy { - log.Infof("Using legacy routing setup") - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) - } - - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) - } +// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html +// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. +type routeInfoInMemory struct { + Family byte + DstLen byte + SrcLen byte + TOS byte - defer func() { - if err != nil { - if cleanErr := cleanupRouting(); cleanErr != nil { - log.Errorf("Error cleaning up routing: %v", cleanErr) - } - } - }() + Table byte + Protocol byte + Scope byte + Type byte - rules := getSetupRules() - for _, rule := range rules { - if err := addRule(rule); err != nil { - if errors.Is(err, syscall.EOPNOTSUPP) { - log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") - isLegacy = true - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) - } - return nil, nil, fmt.Errorf("%s: %w", rule.description, err) - } - } - - return nil, nil, nil -} - -// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. -// It systematically removes the three rules and any associated routing table entries to ensure a clean state. -// The function uses error aggregation to report any errors encountered during the cleanup process. -func cleanupRouting() error { - if isLegacy { - return cleanupRoutingWithRouteManager(routeManager) - } - - var result *multierror.Error - - if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { - result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) - } - if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { - result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) - } - - rules := getSetupRules() - for _, rule := range rules { - if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { - result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) - } - } - - return result.ErrorOrNil() -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) -} - -func addVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { - return genericAddVPNRoute(prefix, intf) - } - - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - - // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) - } - } - if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add route: %w", err) - } - return nil + Flags uint32 } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { - return genericRemoveVPNRoute(prefix, intf) - } +const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" - // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) - } - } - if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove route: %w", err) - } - return nil -} - -func getRoutesFromTable() ([]netip.Prefix, error) { - v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) +func addToRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return nil, fmt.Errorf("get v4 routes: %w", err) + return err } - v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) - if err != nil { - return nil, fmt.Errorf("get v6 routes: %w", err) + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" } - return append(v4Routes, v6Routes...), nil -} -// getRoutes fetches routes from a specific routing table identified by tableID. -func getRoutes(tableID, family int) ([]netip.Prefix, error) { - var prefixList []netip.Prefix - - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + ip, _, err := net.ParseCIDR(addr + addrMask) if err != nil { - return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) - } - - for _, route := range routes { - if route.Dst != nil { - addr, ok := netip.AddrFromSlice(route.Dst.IP) - if !ok { - return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) - } - - ones, _ := route.Dst.Mask.Size() - - prefix := netip.PrefixFrom(addr, ones) - if prefix.IsValid() { - prefixList = append(prefixList, prefix) - } - } + return err } - return prefixList, nil -} - -// addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: getAddressFamily(prefix), + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, } - _, ipNet, err := net.ParseCIDR(prefix.String()) + err = netlink.RouteAdd(route) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink add route: %w", err) + return err } return nil } -// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. -// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. -// tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix netip.Prefix, tableID int) error { +func removeFromRouteTable(prefix netip.Prefix, addr string) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, + return err } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink add unreachable route: %w", err) + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" } - return nil -} - -func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) + ip, _, err := net.ParseCIDR(addr + addrMask) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return err } route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink remove unreachable route: %w", err) + err = netlink.RouteDel(route) + if err != nil { + return err } return nil - } -// removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) +func getRoutesFromTable() ([]netip.Prefix, error) { + tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink remove route: %w", err) + return nil, err } - - return nil -} - -func flushRoutes(tableID, family int) error { - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { - return fmt.Errorf("list routes from table %d: %w", tableID, err) + return nil, err } + var prefixList []netip.Prefix +loop: + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + break loop + case syscall.RTM_NEWROUTE: + rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) + msg := m + attrs, err := syscall.ParseNetlinkRouteAttr(&msg) + if err != nil { + return nil, err + } + if rt.Family != syscall.AF_INET { + continue loop + } - var result *multierror.Error - for i := range routes { - route := routes[i] - // unreachable default routes don't come back with Dst set - if route.Gw == nil && route.Src == nil && route.Dst == nil { - if family == netlink.FAMILY_V4 { - routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} - } else { - routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} + for _, attr := range attrs { + if attr.Attr.Type == syscall.RTA_DST { + addr, ok := netip.AddrFromSlice(attr.Value) + if !ok { + continue + } + mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) + cidr, _ := mask.Size() + routePrefix := netip.PrefixFrom(addr, cidr) + if routePrefix.IsValid() && routePrefix.Addr().Is4() { + prefixList = append(prefixList, routePrefix) + } + } } } - if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) - } } - - return result.ErrorOrNil() + return prefixList, nil } func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) + return err } // check if it is already enabled @@ -347,162 +147,5 @@ func enableIPForwarding() error { return nil } - //nolint:gosec - if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { - return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) - } - return nil -} - -// entryExists checks if the specified ID or name already exists in the rt_tables file -// and verifies if existing names start with "netbird_". -func entryExists(file *os.File, id int) (bool, error) { - if _, err := file.Seek(0, 0); err != nil { - return false, fmt.Errorf("seek rt_tables: %w", err) - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - var existingID int - var existingName string - if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { - if existingID == id { - if existingName != NetbirdVPNTableName { - return true, ErrTableIDExists - } - return true, nil - } - } - } - if err := scanner.Err(); err != nil { - return false, fmt.Errorf("scan rt_tables: %w", err) - } - return false, nil -} - -// addRoutingTableName adds human-readable names for custom routing tables. -func addRoutingTableName() error { - file, err := os.Open(rtTablesPath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return fmt.Errorf("open rt_tables: %w", err) - } - defer func() { - if err := file.Close(); err != nil { - log.Errorf("Error closing rt_tables: %v", err) - } - }() - - exists, err := entryExists(file, NetbirdVPNTableID) - if err != nil { - return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) - } - if exists { - return nil - } - - // Reopen the file in append mode to add new entries - if err := file.Close(); err != nil { - log.Errorf("Error closing rt_tables before appending: %v", err) - } - file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) - if err != nil { - return fmt.Errorf("open rt_tables for appending: %w", err) - } - - if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { - return fmt.Errorf("append entry to rt_tables: %w", err) - } - - return nil -} - -// addRule adds a routing rule to a specific routing table identified by tableID. -func addRule(params ruleParams) error { - rule := netlink.NewRule() - rule.Table = params.tableID - rule.Mark = params.fwmark - rule.Family = params.family - rule.Priority = params.priority - rule.Invert = params.invert - rule.SuppressPrefixlen = params.suppressPrefix - - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("add routing rule: %w", err) - } - - return nil -} - -// removeRule removes a routing rule from a specific routing table identified by tableID. -func removeRule(params ruleParams) error { - rule := netlink.NewRule() - rule.Table = params.tableID - rule.Mark = params.fwmark - rule.Family = params.family - rule.Invert = params.invert - rule.Priority = params.priority - rule.SuppressPrefixlen = params.suppressPrefix - - if err := netlink.RuleDel(rule); err != nil { - return fmt.Errorf("remove routing rule: %w", err) - } - - return nil -} - -func removeAllRules(params ruleParams) error { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - done := make(chan error, 1) - go func() { - for { - if ctx.Err() != nil { - done <- ctx.Err() - return - } - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { - done <- nil - return - } - done <- err - return - } - } - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - -// addNextHop adds the gateway and device to the route. -func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { - if addr.IsValid() { - route.Gw = addr.AsSlice() - } - - if intf != "" { - link, err := netlink.LinkByName(intf) - if err != nil { - return fmt.Errorf("set interface %s: %w", intf, err) - } - route.LinkIndex = link.Attrs().Index - } - - return nil -} - -func getAddressFamily(prefix netip.Prefix) int { - if prefix.Addr().Is4() { - return netlink.FAMILY_V4 - } - return netlink.FAMILY_V6 + return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go deleted file mode 100644 index 0043c3f4e94..00000000000 --- a/client/internal/routemanager/systemops_linux_test.go +++ /dev/null @@ -1,207 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "errors" - "fmt" - "net" - "os" - "strings" - "syscall" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" -) - -var expectedVPNint = "wgtest0" -var expectedLoopbackInt = "lo" -var expectedExternalInt = "dummyext0" -var expectedInternalInt = "dummyint0" - -func init() { - testCases = append(testCases, []testCase{ - { - name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", - expectedInterface: expectedInternalInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), - }, - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", - expectedInterface: expectedLoopbackInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - }...) -} - -func TestEntryExists(t *testing.T) { - tempDir := t.TempDir() - tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) - - content := []string{ - "1000 reserved", - fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), - "9999 other_table", - } - require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) - - file, err := os.Open(tempFilePath) - require.NoError(t, err) - defer func() { - assert.NoError(t, file.Close()) - }() - - tests := []struct { - name string - id int - shouldExist bool - err error - }{ - { - name: "ExistsWithNetbirdPrefix", - id: 7120, - shouldExist: true, - err: nil, - }, - { - name: "ExistsWithDifferentName", - id: 1000, - shouldExist: true, - err: ErrTableIDExists, - }, - { - name: "DoesNotExist", - id: 1234, - shouldExist: false, - err: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - exists, err := entryExists(file, tc.id) - if tc.err != nil { - assert.ErrorIs(t, err, tc.err) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.shouldExist, exists) - }) - } -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} - err := netlink.LinkDel(dummy) - if err != nil && !errors.Is(err, syscall.EINVAL) { - t.Logf("Failed to delete dummy interface: %v", err) - } - - err = netlink.LinkAdd(dummy) - require.NoError(t, err) - - err = netlink.LinkSetUp(dummy) - require.NoError(t, err) - - if ipAddressCIDR != "" { - addr, err := netlink.ParseAddr(ipAddressCIDR) - require.NoError(t, err) - err = netlink.AddrAdd(dummy, addr) - require.NoError(t, err) - } - - t.Cleanup(func() { - err := netlink.LinkDel(dummy) - assert.NoError(t, err) - }) - - return dummy.Name -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { - t.Helper() - - _, dstIPNet, err := net.ParseCIDR(dstCIDR) - require.NoError(t, err) - - // Handle existing routes with metric 0 - var originalNexthop net.IP - var originalLinkIndex int - if dstIPNet.String() == "0.0.0.0/0" { - var err error - originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - t.Logf("Failed to fetch original gateway: %v", err) - } - - if originalNexthop != nil { - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - switch { - case err != nil && !errors.Is(err, syscall.ESRCH): - t.Logf("Failed to delete route: %v", err) - case err == nil: - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - default: - t.Logf("Failed to delete route: %v", err) - } - } - } - - link, err := netlink.LinkByName(intf) - require.NoError(t, err) - linkIndex := link.Attrs().Index - - route := &netlink.Route{ - Dst: dstIPNet, - Gw: gw, - LinkIndex: linkIndex, - } - err = netlink.RouteDel(route) - if err != nil && !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) - } - - err = netlink.RouteAdd(route) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - require.NoError(t, err) -} - -func fetchOriginalGateway(family int) (net.IP, int, error) { - routes, err := netlink.RouteList(nil, family) - if err != nil { - return nil, 0, err - } - - for _, route := range routes { - if route.Dst == nil && route.Priority == 0 { - return route.Gw, route.LinkIndex, nil - } - } - - return nil, 0, ErrRouteNotFound -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) - - otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) -} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go new file mode 100644 index 00000000000..11247c7dccd --- /dev/null +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -0,0 +1,120 @@ +//go:build !android && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" +) + +var errRouteNotFound = fmt.Errorf("route not found") + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return err + } + if ok { + log.Warnf("skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return err + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, addr) +} + +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil && err != errRouteNotFound { + return err + } + + addr := netip.MustParseAddr(defaultGateway.String()) + + if !prefix.Contains(addr) { + log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(addr, 32) + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) + if err != nil && err != errRouteNotFound { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop.String()) +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + return removeFromRouteTable(prefix, addr) +} + +func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { + r, err := netroute.New() + if err != nil { + return nil, err + } + _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) + if err != nil { + log.Errorf("getting routes returned an error: %v", err) + return nil, errRouteNotFound + } + + if gateway == nil { + return preferredSrc, nil + } + + return gateway, nil +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_nonandroid_test.go similarity index 59% rename from client/internal/routemanager/systemops_test.go rename to client/internal/routemanager/systemops_nonandroid_test.go index 97386f19a1a..6f32d9634bc 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -1,32 +1,24 @@ -//go:build !android && !ios +//go:build !android package routemanager import ( "bytes" - "context" "fmt" "net" "net/netip" "os" - "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) -type dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -61,30 +53,27 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "genericAddVPNRoute should not return err") + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) + require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) + require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "genericRemoveVPNRoute should not return err") + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") + prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -97,12 +86,12 @@ func TestAddRemoveRoutes(t *testing.T) { } } -func TestGetNextHop(t *testing.T) { - gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) +func TestGetExistingRIBRouteGateway(t *testing.T) { + gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if !gateway.IsValid() { + if gateway == nil { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -124,11 +113,11 @@ func TestGetNextHop(t *testing.T) { } } - localIP, _, err := getNextHop(testingPrefix.Addr()) + localIP, err := getExistingRIBRouteGateway(testingPrefix) if err != nil { t.Fatal("shouldn't return error: ", err) } - if !localIP.IsValid() { + if localIP == nil { t.Fatal("should return a gateway for local network") } if localIP.String() == gateway.String() { @@ -139,8 +128,8 @@ func TestGetNextHop(t *testing.T) { } } -func TestAddExistAndRemoveRoute(t *testing.T) { - defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) +func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -200,14 +189,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + MockAddr := wgInterface.Address().IP.String() + // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -217,7 +208,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) require.NoError(t, err, "should not return err") } @@ -226,7 +217,6 @@ func TestAddExistAndRemoveRoute(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } @@ -234,6 +224,31 @@ func TestAddExistAndRemoveRoute(t *testing.T) { } } +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is4() { + addressPrefixes = append(addressPrefixes, p.Masked()) + } + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -271,132 +286,3 @@ func TestIsSubRange(t *testing.T) { } } } - -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is6() { - continue - } - // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { - continue - } - // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence - if runtime.GOOS == "linux" && p.Addr().IsLoopback() { - continue - } - - addressPrefixes = append(addressPrefixes, p.Masked()) - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet() - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) { - t.Helper() - - setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, wgIface) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) -} - -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { - return - } - - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 38026107ec7..47bd60eb02b 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,23 +1,41 @@ -//go:build !linux && !ios +//go:build !linux +// +build !linux package routemanager import ( "net/netip" + "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) +func addToRouteTable(prefix netip.Prefix, addr string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) return nil } -func addVPNRoute(prefix netip.Prefix, intf string) error { - return genericAddVPNRoute(prefix, intf) +func removeFromRouteTable(prefix netip.Prefix, addr string) error { + args := []string{"delete", prefix.String()} + if runtime.GOOS == "darwin" { + args = append(args, addr) + } + cmd := exec.Command("route", args...) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - return genericRemoveVPNRoute(prefix, intf) +func enableIPForwarding() error { + log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil } diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go deleted file mode 100644 index 561eaeea4b2..00000000000 --- a/client/internal/routemanager/systemops_unix_test.go +++ /dev/null @@ -1,234 +0,0 @@ -//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly - -package routemanager - -import ( - "fmt" - "net" - "strings" - "testing" - "time" - - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool -} - -type testCase struct { - name string - destination string - expectedInterface string - dialer dialer - expectedPacket PacketExpectation -} - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedInterface: expectedExternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedInterface: expectedInternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedInterface: expectedInternalInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedInterface: expectedExternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.expectedInterface, filter) - - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.expectedPacket) - }) - } -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - switch dialer := dialer.(type) { - case *nbnet.Dialer: - dialer.LocalAddr = localUDPAddr - case *net.Dialer: - dialer.LocalAddr = localUDPAddr - default: - t.Fatal("Unsupported dialer type") - } - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } -} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 50fff0cd58d..309c184b9ca 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,19 +1,13 @@ //go:build windows +// +build windows package routemanager import ( - "fmt" "net" "net/netip" - "os/exec" - "strings" - log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -21,35 +15,23 @@ type Win32_IP4RouteTable struct { Mask string } -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) -} - -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) -} - func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" err := wmi.Query(query, &routes) if err != nil { - return nil, fmt.Errorf("get routes: %w", err) + return nil, err } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { - log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { - log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -62,69 +44,3 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } - -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - destinationPrefix := prefix.String() - psCmd := "New-NetRoute" - - addressFamily := "IPv4" - if prefix.Addr().Is6() { - addressFamily = "IPv6" - } - - script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, - psCmd, addressFamily, destinationPrefix, intf, - ) - - if nexthop.IsValid() { - script = fmt.Sprintf( - `%s -NextHop "%s"`, script, nexthop, - ) - } - - out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell add route: %s", string(out)) - - if err != nil { - return fmt.Errorf("PowerShell add route: %w", err) - } - - return nil -} - -func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"add", prefix.String(), nexthop.Unmap().String()} - - out, err := exec.Command("route", args...).CombinedOutput() - - log.Tracef("route %s output: %s", strings.Join(args, " "), out) - if err != nil { - return fmt.Errorf("route add: %w", err) - } - - return nil -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" { - return addRoutePowershell(prefix, nexthop, intf) - } - return addRouteCmd(prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"delete", prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } - - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - return nil -} diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go deleted file mode 100644 index a5e03b8d2ce..00000000000 --- a/client/internal/routemanager/systemops_windows_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package routemanager - -import ( - "context" - "encoding/json" - "fmt" - "net" - "os/exec" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -var expectedExtInt = "Ethernet1" - -type RouteInfo struct { - NextHop string `json:"nexthop"` - InterfaceAlias string `json:"interfacealias"` - RouteMetric int `json:"routemetric"` -} - -type FindNetRouteOutput struct { - IPAddress string `json:"IPAddress"` - InterfaceIndex int `json:"InterfaceIndex"` - InterfaceAlias string `json:"InterfaceAlias"` - AddressFamily int `json:"AddressFamily"` - NextHop string `json:"NextHop"` - DestinationPrefix string `json:"DestinationPrefix"` -} - -type testCase struct { - name string - destination string - expectedSourceIP string - expectedDestPrefix string - expectedNextHop string - expectedInterface string - dialer dialer -} - -var expectedVPNint = "wgtest0" - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "128.0.0.0/1", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedDestPrefix: "192.0.2.1/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedDestPrefix: "10.0.0.2/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedSourceIP: "10.0.0.1", - expectedDestPrefix: "10.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedDestPrefix: "172.16.0.2/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "172.16.0.0/12", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route without custom dialer via vpn interface", - destination: "10.10.0.2:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "10.10.0.0/24", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.2:53", - expectedSourceIP: "10.0.0.1", - expectedDestPrefix: "127.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - route, err := fetchOriginalGateway() - require.NoError(t, err, "Failed to fetch original gateway") - ip, err := fetchInterfaceIP(route.InterfaceAlias) - require.NoError(t, err, "Failed to fetch interface IP") - - output := testRoute(t, tc.destination, tc.dialer) - if tc.expectedInterface == expectedExtInt { - verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) - } else { - verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) - } - }) - } -} - -// fetchInterfaceIP fetches the IPv4 address of the specified interface. -func fetchInterfaceIP(interfaceAlias string) (string, error) { - script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) - out, err := exec.Command("powershell", "-Command", script).Output() - if err != nil { - return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) - } - - ip := strings.TrimSpace(string(out)) - return ip, nil -} - -func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - conn, err := dialer.DialContext(ctx, "udp", destination) - require.NoError(t, err, "Failed to dial destination") - defer func() { - err := conn.Close() - assert.NoError(t, err, "Failed to close connection") - }() - - host, _, err := net.SplitHostPort(destination) - require.NoError(t, err) - - script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) - - out, err := exec.Command("powershell", "-Command", script).Output() - require.NoError(t, err, "Failed to execute Find-NetRoute") - - var outputs []FindNetRouteOutput - err = json.Unmarshal(out, &outputs) - require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") - - require.Greater(t, len(outputs), 0, "No route found for destination") - combinedOutput := combineOutputs(outputs) - - return combinedOutput -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) - require.NoError(t, err) - subnetMaskSize, _ := ipNet.Mask.Size() - script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to assign IP address to loopback adapter") - - // Wait for the IP address to be applied - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - err = waitForIPAddress(ctx, interfaceName, ip.String()) - require.NoError(t, err, "IP address not applied within timeout") - - t.Cleanup(func() { - script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to remove IP address from loopback adapter") - }) - - return interfaceName -} - -func fetchOriginalGateway() (*RouteInfo, error) { - cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") - output, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) - } - - var routeInfo RouteInfo - err = json.Unmarshal(output, &routeInfo) - if err != nil { - return nil, fmt.Errorf("failed to parse JSON output: %w", err) - } - - return &routeInfo, nil -} - -func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { - t.Helper() - - assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") - assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") - assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") - assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") -} - -func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() - if err != nil { - return err - } - - ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") - for _, ip := range ipAddresses { - if strings.TrimSpace(ip) == expectedIPAddress { - return nil - } - } - } - } -} - -func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { - var combined FindNetRouteOutput - - for _, output := range outputs { - if output.IPAddress != "" { - combined.IPAddress = output.IPAddress - } - if output.InterfaceIndex != 0 { - combined.InterfaceIndex = output.InterfaceIndex - } - if output.InterfaceAlias != "" { - combined.InterfaceAlias = output.InterfaceAlias - } - if output.AddressFamily != 0 { - combined.AddressFamily = output.AddressFamily - } - if output.NextHop != "" { - combined.NextHop = output.NextHop - } - if output.DestinationPrefix != "" { - combined.DestinationPrefix = output.DestinationPrefix - } - } - - return &combined -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") -} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go deleted file mode 100644 index e80adb42b20..00000000000 --- a/client/internal/stdnet/dialer.go +++ /dev/null @@ -1,24 +0,0 @@ -package stdnet - -import ( - "net" - - "github.com/pion/transport/v3" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// Dial connects to the address on the named network. -func (n *Net) Dial(network, address string) (net.Conn, error) { - return nbnet.NewDialer().Dial(network, address) -} - -// DialUDP connects to the address on the named UDP network. -func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.DialUDP(network, laddr, raddr) -} - -// DialTCP connects to the address on the named TCP network. -func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { - return nbnet.DialTCP(network, laddr, raddr) -} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go deleted file mode 100644 index 9ce0a555610..00000000000 --- a/client/internal/stdnet/listener.go +++ /dev/null @@ -1,20 +0,0 @@ -package stdnet - -import ( - "context" - "net" - - "github.com/pion/transport/v3" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// ListenPacket listens for incoming packets on the given network and address. -func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { - return nbnet.NewListener().ListenPacket(context.Background(), network, address) -} - -// ListenUDP acts like ListenPacket for UDP networks. -func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.ListenUDP(network, locAddr) -} diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 2235c5d2bdf..f02b4943bc6 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -17,7 +17,6 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -68,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - conn, err := nbnet.ListenUDP("udp", &addr) + conn, err := net.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -229,12 +228,6 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return nil, fmt.Errorf("binding to lo interface failed: %w", err) } - // Set the fwmark on the socket. - err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index 17ebfbc499b..b692ea70842 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,8 +6,6 @@ import ( "net" log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -35,7 +33,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index 29a1570c896..e4e36b96685 100644 --- a/go.mod +++ b/go.mod @@ -48,7 +48,6 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 - github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 @@ -124,6 +123,7 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect + github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 9fe987cee21..36fd13cc262 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,8 +10,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := nbnet.NetbirdFwmark + fwmark := 0 config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 24dfadf1408..200bfbc9614 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,8 +13,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -39,7 +37,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := getFwmark() + fwmark := 0 config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -347,10 +345,3 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } - -func getFwmark() int { - if runtime.GOOS == "linux" { - return nbnet.NetbirdFwmark - } - return 0 -} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0b1804906c2..0234f866cb8 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -58,7 +57,6 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, - nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 74ac6c163ad..02b4e174dab 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,8 +21,6 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - - nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -84,18 +82,10 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } - if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) - } - var sockErr error rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if sockErr != nil { log.Errorf("Failed to create ipv6 raw socket: %v", err) - } else { - if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) - } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7c4535e2896..7531608c3bb 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,7 +23,6 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -77,7 +76,6 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, - nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go deleted file mode 100644 index 96b2bc32be0..00000000000 --- a/util/grpc/dialer.go +++ /dev/null @@ -1,22 +0,0 @@ -package grpc - -import ( - "context" - "net" - - log "github.com/sirupsen/logrus" - "google.golang.org/grpc" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, err - } - return conn, nil - }) -} diff --git a/util/net/dialer.go b/util/net/dialer.go deleted file mode 100644 index 0786c667e53..00000000000 --- a/util/net/dialer.go +++ /dev/null @@ -1,21 +0,0 @@ -package net - -import ( - "net" -) - -// Dialer extends the standard net.Dialer with the ability to execute hooks before -// and after connections. This can be used to bypass the VPN for connections using this dialer. -type Dialer struct { - *net.Dialer -} - -// NewDialer returns a customized net.Dialer with overridden Control method -func NewDialer() *Dialer { - dialer := &Dialer{ - Dialer: &net.Dialer{}, - } - dialer.init() - - return dialer -} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go deleted file mode 100644 index 06fac3bbf85..00000000000 --- a/util/net/dialer_generic.go +++ /dev/null @@ -1,163 +0,0 @@ -//go:build !android && !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHook removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - -func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go deleted file mode 100644 index aed5c59a322..00000000000 --- a/util/net/dialer_linux.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !android - -package net - -import "syscall" - -// init configures the net.Dialer Control function to set the fwmark on the socket -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) - } -} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go deleted file mode 100644 index 3254e6d066b..00000000000 --- a/util/net/dialer_nonlinux.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !linux || android - -package net - -func (d *Dialer) init() { -} diff --git a/util/net/listener.go b/util/net/listener.go deleted file mode 100644 index f4d769f587e..00000000000 --- a/util/net/listener.go +++ /dev/null @@ -1,21 +0,0 @@ -package net - -import ( - "net" -) - -// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before -// responding via the socket and after closing. This can be used to bypass the VPN for listeners. -type ListenerConfig struct { - *net.ListenConfig -} - -// NewListener creates a new ListenerConfig instance. -func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } - listener.init() - - return listener -} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go deleted file mode 100644 index 451279e9d25..00000000000 --- a/util/net/listener_generic.go +++ /dev/null @@ -1,163 +0,0 @@ -//go:build !android && !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// RemoveListenerHooks removes all dialer hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -} - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go deleted file mode 100644 index 8d332160a04..00000000000 --- a/util/net/listener_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !android - -package net - -import ( - "syscall" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) - } -} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go deleted file mode 100644 index 0dbbb360b53..00000000000 --- a/util/net/listener_mobile.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build android || ios - -package net - -import ( - "net" -) - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, laddr) -} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go deleted file mode 100644 index fb6eadaaad8..00000000000 --- a/util/net/listener_nonlinux.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !linux || android - -package net - -func (l *ListenerConfig) init() { -} diff --git a/util/net/net.go b/util/net/net.go deleted file mode 100644 index 9ea7ae80340..00000000000 --- a/util/net/net.go +++ /dev/null @@ -1,17 +0,0 @@ -package net - -import "github.com/google/uuid" - -const ( - // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 -) - -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -// GenerateConnID generates a unique identifier for each connection. -func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go deleted file mode 100644 index 82141750029..00000000000 --- a/util/net/net_linux.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build !android - -package net - -import ( - "fmt" - "syscall" -) - -// SetSocketMark sets the SO_MARK option on the given socket connection -func SetSocketMark(conn syscall.Conn) error { - sysconn, err := conn.SyscallConn() - if err != nil { - return fmt.Errorf("get raw conn: %w", err) - } - - return SetRawSocketMark(sysconn) -} - -func SetRawSocketMark(conn syscall.RawConn) error { - var setErr error - - err := conn.Control(func(fd uintptr) { - setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) - }) - if err != nil { - return fmt.Errorf("control: %w", err) - } - - if setErr != nil { - return fmt.Errorf("set SO_MARK: %w", setErr) - } - - return nil -} From 3875c29f6b1d92ef2fc00029e74036001c55829b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 9 Apr 2024 01:56:52 +0900 Subject: [PATCH 34/89] Revert "Rollback new routing functionality (#1805)" (#1813) This reverts commit 9f32ccd4533d5301bcb901677af9816cb3408f92. --- .github/workflows/golang-test-darwin.yml | 3 + .github/workflows/golang-test-linux.yml | 10 +- .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- client/internal/engine.go | 16 + client/internal/peer/conn.go | 32 ++ client/internal/relay/relay.go | 7 +- client/internal/routemanager/client.go | 63 ++- client/internal/routemanager/manager.go | 79 ++- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 + client/internal/routemanager/routemanager.go | 126 +++++ .../routemanager/server_nonandroid.go | 57 +- client/internal/routemanager/systemops.go | 407 ++++++++++++++ .../routemanager/systemops_android.go | 24 +- client/internal/routemanager/systemops_bsd.go | 1 - .../internal/routemanager/systemops_darwin.go | 61 +++ .../routemanager/systemops_darwin_test.go | 100 ++++ client/internal/routemanager/systemops_ios.go | 26 +- .../internal/routemanager/systemops_linux.go | 511 +++++++++++++++--- .../routemanager/systemops_linux_test.go | 207 +++++++ .../routemanager/systemops_nonandroid.go | 120 ---- .../routemanager/systemops_nonlinux.go | 32 +- ...s_nonandroid_test.go => systemops_test.go} | 212 ++++++-- .../routemanager/systemops_unix_test.go | 234 ++++++++ .../routemanager/systemops_windows.go | 88 ++- .../routemanager/systemops_windows_test.go | 289 ++++++++++ client/internal/stdnet/dialer.go | 24 + client/internal/stdnet/listener.go | 20 + client/internal/wgproxy/proxy_ebpf.go | 9 +- client/internal/wgproxy/proxy_userspace.go | 4 +- go.mod | 2 +- iface/wg_configurer_kernel.go | 4 +- iface/wg_configurer_usp.go | 11 +- management/client/grpc.go | 2 + sharedsock/sock_linux.go | 10 + signal/client/grpc.go | 2 + util/grpc/dialer.go | 22 + util/net/dialer.go | 21 + util/net/dialer_generic.go | 163 ++++++ util/net/dialer_linux.go | 12 + util/net/dialer_nonlinux.go | 6 + util/net/listener.go | 21 + util/net/listener_generic.go | 163 ++++++ util/net/listener_linux.go | 14 + util/net/listener_mobile.go | 11 + util/net/listener_nonlinux.go | 6 + util/net/net.go | 17 + util/net/net_linux.go | 35 ++ 49 files changed, 2969 insertions(+), 354 deletions(-) create mode 100644 client/internal/routemanager/routemanager.go create mode 100644 client/internal/routemanager/systemops.go create mode 100644 client/internal/routemanager/systemops_darwin.go create mode 100644 client/internal/routemanager/systemops_darwin_test.go create mode 100644 client/internal/routemanager/systemops_linux_test.go delete mode 100644 client/internal/routemanager/systemops_nonandroid.go rename client/internal/routemanager/{systemops_nonandroid_test.go => systemops_test.go} (59%) create mode 100644 client/internal/routemanager/systemops_unix_test.go create mode 100644 client/internal/routemanager/systemops_windows_test.go create mode 100644 client/internal/stdnet/dialer.go create mode 100644 client/internal/stdnet/listener.go create mode 100644 util/grpc/dialer.go create mode 100644 util/net/dialer.go create mode 100644 util/net/dialer_generic.go create mode 100644 util/net/dialer_linux.go create mode 100644 util/net/dialer_nonlinux.go create mode 100644 util/net/listener.go create mode 100644 util/net/listener_generic.go create mode 100644 util/net/listener_linux.go create mode 100644 util/net/listener_mobile.go create mode 100644 util/net/listener_nonlinux.go create mode 100644 util/net/net.go create mode 100644 util/net/net_linux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index f8afd3d6eab..d7007c86080 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,6 +32,9 @@ jobs: restore-keys: | macos-go- + - name: Install libpcap + run: brew install libpcap + - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 74e6d1203ab..42f740e9b54 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -36,7 +36,11 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - name: Install modules run: go mod tidy @@ -67,7 +71,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - name: Install modules run: go mod tidy @@ -82,7 +86,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... + run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6027d36269f..2d63acbcd5a 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9f543c74c45..13228250d59 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index 13ef8ce1563..d6238c4b3ca 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -94,6 +94,9 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn + beforePeerHook peer.BeforeAddPeerHookFunc + afterPeerHook peer.AfterRemovePeerHookFunc + // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -261,6 +264,14 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() + if err != nil { + log.Errorf("Failed to initialize route manager: %s", err) + } else { + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook + } + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -810,6 +821,11 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { } e.peerConns[peerKey] = conn + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 17ef7e87fd2..f3d07dcad1f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -100,6 +101,9 @@ type IceCredentials struct { Pwd string } +type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error +type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error + type Conn struct { config ConnConfig mu sync.Mutex @@ -138,6 +142,10 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn + + connID nbnet.ConnectionID + beforeAddPeerHooks []BeforeAddPeerHookFunc + afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -393,6 +401,14 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } +func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { + conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) +} + +func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { + conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) +} + // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -421,6 +437,13 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem conn.remoteEndpoint = endpointUdpAddr log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.connID = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + log.Errorf("Before add peer hook failed: %v", err) + } + } + err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { @@ -511,6 +534,15 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if conn.connID != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connID); err != nil { + log.Errorf("After remove peer hook failed: %v", err) + } + } + } + conn.connID = "" + if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index ad3b94f2a5f..84fd72e49c9 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -95,15 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - listener := &net.ListenConfig{} - conn, err = listener.ListenPacket(ctx, "udp", "") + conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - dialer := &net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) + tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index f7ead582720..38cf4bf6550 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,6 +41,7 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) + client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -72,6 +73,18 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } +// getBestRouteFromStatuses determines the most optimal route from the available routes +// within a clientNetwork, taking into account peer connection status, route metrics, and +// preference for non-relayed and direct connections. +// +// It follows these prioritization rules: +// * Connected peers: Only routes with connected peers are considered. +// * Metric: Routes with lower metrics (better) are prioritized. +// * Non-relayed: Routes without relays are preferred. +// * Direct connections: Routes with direct peer connections are favored. +// * Stability: In case of equal scores, the currently active route (if any) is maintained. +// +// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -158,7 +171,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return err + return fmt.Errorf("get peer state: %v", err) } delete(state.Routes, c.network.String()) @@ -172,7 +185,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -180,30 +193,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { + return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } - err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) - if err != nil { - return fmt.Errorf("couldn't remove route %s from system, err: %v", - c.network, err) + + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route: %v", err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { - - var err error - routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + + // If no route is chosen, remove the route from the peer and system if chosen == "" { - err = c.removeRouteFromPeerAndSystem() - if err != nil { - return err + if err := c.removeRouteFromPeerAndSystem(); err != nil { + return fmt.Errorf("remove route from peer and system: %v", err) } c.chosenRoute = nil @@ -211,6 +220,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } + // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -218,13 +228,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + // If a previous route exists, remove it from the peer + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route from peer: %v", err) } } else { - err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) - if err != nil { + // otherwise add the route to the system + if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -245,8 +255,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) - if err != nil { + if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -287,21 +296,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system: %v", err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("received a routes update with smaller serial number, ignoring it") + log.Warnf("Received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("received a new client network route update for %s", c.network) + log.Debugf("Received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -309,7 +318,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b624d8c34ce..36a37f02c50 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,6 +2,10 @@ package routemanager import ( "context" + "fmt" + "net" + "net/netip" + "net/url" "runtime" "sync" @@ -15,8 +19,14 @@ import ( "github.com/netbirdio/netbird/version" ) +var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + +// nolint:unused +var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + // Manager is a route manager interface type Manager interface { + Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -56,6 +66,24 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } +// Init sets up the routing +func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + if err := cleanupRouting(); err != nil { + log.Warnf("Failed cleaning up routing: %v", err) + } + + mgmtAddress := m.statusRecorder.GetManagementState().URL + signalAddress := m.statusRecorder.GetSignalState().URL + ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + + beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + if err != nil { + return nil, nil, fmt.Errorf("setup routing: %w", err) + } + log.Info("Routing setup complete") + return beforePeerHook, afterPeerHook, nil +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -71,9 +99,15 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } + if err := cleanupRouting(); err != nil { + log.Errorf("Error cleaning up routing: %v", err) + } else { + log.Info("Routing cleanup complete") + } + m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -91,7 +125,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return err + return fmt.Errorf("update routes: %w", err) } } @@ -156,11 +190,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < minRangeBits { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", - version.NetbirdVersion(), newRoute.Network) + if !isPrefixSupported(newRoute.Network) { continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -178,3 +208,38 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } + +func isPrefixSupported(prefix netip.Prefix) bool { + switch runtime.GOOS { + case "linux", "windows", "darwin": + return true + } + + // If prefix is too small, lets assume it is a possible default prefix which is not yet supported + // we skip this prefix management + if prefix.Bits() <= minRangeBits { + log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", + version.NetbirdVersion(), prefix) + return false + } + return true +} + +// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. +func resolveURLsToIPs(urls []string) []net.IP { + var ips []net.IP + for _, rawurl := range urls { + u, err := url.Parse(rawurl) + if err != nil { + log.Errorf("Failed to parse url %s: %v", rawurl, err) + continue + } + ipAddrs, err := net.LookupIP(u.Hostname()) + if err != nil { + log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) + continue + } + ips = append(ips, ipAddrs...) + } + return ips +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2e5cf6649d8..03e77e09bcb 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedAllowed int }{ { name: "Should create 2 client networks", @@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedAllowed: 1, }, { name: "Remove 1 Client Route", @@ -415,6 +417,10 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) + + _, _, err = routeManager.Init() + + require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -429,7 +435,11 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") + expectedWatchers := testCase.clientNetworkWatchersExpected + if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed + } + require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index a1214cbb9ec..dd2c28e5927 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,6 +6,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -16,6 +17,10 @@ type MockManager struct { StopFunc func() } +func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go new file mode 100644 index 00000000000..8f9ff9f4bd0 --- /dev/null +++ b/client/internal/routemanager/routemanager.go @@ -0,0 +1,126 @@ +//go:build !android && !ios + +package routemanager + +import ( + "errors" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type ref struct { + count int + nexthop netip.Addr + intf string +} + +type RouteManager struct { + // refCountMap keeps track of the reference ref for prefixes + refCountMap map[netip.Prefix]ref + // prefixMap keeps track of the prefixes associated with a connection ID for removal + prefixMap map[nbnet.ConnectionID][]netip.Prefix + addRoute AddRouteFunc + removeRoute RemoveRouteFunc + mutex sync.Mutex +} + +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error + +func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { + // TODO: read initial routing table into refCountMap + return &RouteManager{ + refCountMap: map[netip.Prefix]ref{}, + prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, + addRoute: addRoute, + removeRoute: removeRoute, + } +} + +func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + ref := rm.refCountMap[prefix] + log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) + + // Add route to the system, only if it's a new prefix + if ref.count == 0 { + log.Debugf("Adding route for prefix %s", prefix) + nexthop, intf, err := rm.addRoute(prefix) + if errors.Is(err, ErrRouteNotFound) { + return nil + } + if errors.Is(err, ErrRouteNotAllowed) { + log.Debugf("Adding route for prefix %s: %s", prefix, err) + } + if err != nil { + return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) + } + ref.nexthop = nexthop + ref.intf = intf + } + + ref.count++ + rm.refCountMap[prefix] = ref + rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) + + return nil +} + +func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + prefixes, ok := rm.prefixMap[connID] + if !ok { + log.Debugf("No prefixes found for connection ID %s", connID) + return nil + } + + var result *multierror.Error + for _, prefix := range prefixes { + ref := rm.refCountMap[prefix] + log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) + if ref.count == 1 { + log.Debugf("Removing route for prefix %s", prefix) + // TODO: don't fail if the route is not found + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + continue + } + delete(rm.refCountMap, prefix) + } else { + ref.count-- + rm.refCountMap[prefix] = ref + } + } + delete(rm.prefixMap, connID) + + return result.ErrorOrNil() +} + +// Flush removes all references and routes from the system +func (rm *RouteManager) Flush() error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + var result *multierror.Error + for prefix := range rm.refCountMap { + log.Debugf("Removing route for prefix %s", prefix) + ref := rm.refCountMap[prefix] + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]ref{} + rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 19236787772..af82dc91349 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,6 +4,7 @@ package routemanager import ( "context" + "fmt" "net/netip" "sync" @@ -48,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -62,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -81,15 +82,22 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not removing from server network because context is done") + log.Infof("Not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.RemoveRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("remove routing rules: %w", err) } + delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -103,15 +111,22 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not adding to server network because context is done") + log.Infof("Not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.InsertRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("insert routing rules: %w", err) } + m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -129,23 +144,33 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) if err != nil { - log.Warnf("failed to remove clean up route: %s", r.ID) + log.Errorf("Failed to convert route to router pair: %v", err) + continue + } + + err = m.firewall.RemoveRoutingRules(routerPair) + if err != nil { + log.Errorf("Failed to remove cleanup route: %v", err) } - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { - parsed := netip.MustParsePrefix(source).Masked() +func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { + parsed, err := netip.ParsePrefix(source) + if err != nil { + return firewall.RouterPair{}, err + } return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - } + }, nil } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 00000000000..a91f53636da --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,407 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var ErrRouteNotFound = errors.New("route not found") +var ErrRouteNotAllowed = errors.New("route not allowed") + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Warnf("Failed to get route for %s: %v", ip, err) + return netip.Addr{}, nil, ErrRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, ErrRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. +// If the next hop or interface is pointing to the VPN interface, it will return the initial values. +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsUnspecified(), + addr.IsMulticast(): + + return netip.Addr{}, "", ErrRouteNotAllowed + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func genericAddVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 950a268434c..34d2d270fe3 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { + return nil +} + +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index b2da8075cfa..173e7c0e847 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go new file mode 100644 index 00000000000..f34964a8343 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin.go @@ -0,0 +1,61 @@ +//go:build darwin && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" +) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("add", prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("delete", prefix, nexthop, intf) +} + +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { + inet := "-inet" + if prefix.Addr().Is6() { + inet = "-inet6" + // Special case for IPv6 split default route, pointing to the wg interface fails + // TODO: Remove once we have IPv6 support on the interface + if prefix.Bits() == 1 { + intf = "lo0" + } + } + + args := []string{"-n", action, inet, prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } else if intf != "" { + args = append(args, "-interface", intf) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go new file mode 100644 index 00000000000..5c5aaa24fe1 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -0,0 +1,100 @@ +//go:build !ios + +package routemanager + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedVPNint = "utun100" +var expectedExternalInt = "lo0" +var expectedInternalInt = "lo0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via vpn", + destination: "10.10.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), + }, + }...) +} + +func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { + t.Helper() + + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return "lo0" +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { + t.Helper() + + var originalNexthop net.IP + if dstCIDR == "0.0.0.0/0" { + var err error + originalNexthop, err = fetchOriginalGateway() + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { + t.Logf("Failed to delete route: %v, output: %s", err, output) + } + } + + t.Cleanup(func() { + if originalNexthop != nil { + err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() + assert.NoError(t, err, "Failed to restore original route") + } + }) + + err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() + require.NoError(t, err, "Failed to add route") + + t.Cleanup(func() { + err := exec.Command("route", "delete", "-net", dstCIDR).Run() + assert.NoError(t, err, "Failed to remove route") + }) +} + +func fetchOriginalGateway() (net.IP, error) { + output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() + if err != nil { + return nil, err + } + + matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) + if len(matches) == 0 { + return nil, fmt.Errorf("gateway not found") + } + + return net.ParseIP(matches[1]), nil +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index aae0f8dc8f2..34d2d270fe3 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,15 +1,33 @@ -//go:build ios - package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { + return nil +} + +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 0562826a55d..ef464372737 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,142 +3,342 @@ package routemanager import ( + "bufio" + "context" + "errors" + "fmt" "net" "net/netip" "os" "syscall" - "unsafe" + "time" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +const ( + // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. + NetbirdVPNTableID = 0x1BD0 + // NetbirdVPNTableName is the name of the custom routing table used by Netbird. + NetbirdVPNTableName = "netbird" + + // rtTablesPath is the path to the file containing the routing table names. + rtTablesPath = "/etc/iproute2/rt_tables" + + // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. + ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" ) -// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html -// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. -type routeInfoInMemory struct { - Family byte - DstLen byte - SrcLen byte - TOS byte +var ErrTableIDExists = errors.New("ID exists with different name") - Table byte - Protocol byte - Scope byte - Type byte +var routeManager = &RouteManager{} +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" - Flags uint32 +type ruleParams struct { + fwmark int + tableID int + family int + priority int + invert bool + suppressPrefix int + description string } -const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +func getSetupRules() []ruleParams { + return []ruleParams{ + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, + } +} -func addToRouteTable(prefix netip.Prefix, addr string) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) +// setupRouting establishes the routing configuration for the VPN, including essential rules +// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. +// +// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over +// potential routes received and configured for the VPN. This rule is skipped for the default route and routes +// that are not in the main table. +// +// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. +// This table is where a default route or other specific routes received from the management server are configured, +// enabling VPN connectivity. +// +// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { + if isLegacy { + log.Infof("Using legacy routing setup") + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) + } + + defer func() { + if err != nil { + if cleanErr := cleanupRouting(); cleanErr != nil { + log.Errorf("Error cleaning up routing: %v", cleanErr) + } + } + }() + + rules := getSetupRules() + for _, rule := range rules { + if err := addRule(rule); err != nil { + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") + isLegacy = true + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + return nil, nil, fmt.Errorf("%s: %w", rule.description, err) + } + } + + return nil, nil, nil +} + +// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. +// It systematically removes the three rules and any associated routing table entries to ensure a clean state. +// The function uses error aggregation to report any errors encountered during the cleanup process. +func cleanupRouting() error { + if isLegacy { + return cleanupRoutingWithRouteManager(routeManager) + } + + var result *multierror.Error + + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) + } + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) + } + + rules := getSetupRules() + for _, rule := range rules { + if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { + result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) + } + } + + return result.ErrorOrNil() +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func addVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericAddVPNRoute(prefix, intf) + } + + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 + + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { + return fmt.Errorf("add blackhole: %w", err) + } + } + if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { + return fmt.Errorf("add route: %w", err) + } + return nil +} + +func removeVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericRemoveVPNRoute(prefix, intf) + } + + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { + return fmt.Errorf("remove unreachable route: %w", err) + } + } + if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil +} + +func getRoutesFromTable() ([]netip.Prefix, error) { + v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) if err != nil { - return err + return nil, fmt.Errorf("get v4 routes: %w", err) } + v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) + if err != nil { + return nil, fmt.Errorf("get v6 routes: %w", err) - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" } + return append(v4Routes, v6Routes...), nil +} - ip, _, err := net.ParseCIDR(addr + addrMask) +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { + var prefixList []netip.Prefix + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) if err != nil { - return err + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) } + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) + } + + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) + } + } + } + + return prefixList, nil +} + +// addRoute adds a route to a specific routing table identified by tableID. +func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: getAddressFamily(prefix), } - err = netlink.RouteAdd(route) + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + route.Dst = ipNet + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink add route: %w", err) } return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { +// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. +// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. +// tableID specifies the routing table to which the unreachable route will be added. +func addUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, } - ip, _, err := net.ParseCIDR(addr + addrMask) + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink add unreachable route: %w", err) + } + + return nil +} + +func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, } - err = netlink.RouteDel(route) - if err != nil { - return err + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink remove unreachable route: %w", err) } return nil + } -func getRoutesFromTable() ([]netip.Prefix, error) { - tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) +// removeRoute removes a route from a specific routing table identified by tableID. +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return nil, err + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink remove route: %w", err) } - msgs, err := syscall.ParseNetlinkMessage(tab) + + return nil +} + +func flushRoutes(tableID, family int) error { + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) if err != nil { - return nil, err + return fmt.Errorf("list routes from table %d: %w", tableID, err) } - var prefixList []netip.Prefix -loop: - for _, m := range msgs { - switch m.Header.Type { - case syscall.NLMSG_DONE: - break loop - case syscall.RTM_NEWROUTE: - rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) - msg := m - attrs, err := syscall.ParseNetlinkRouteAttr(&msg) - if err != nil { - return nil, err - } - if rt.Family != syscall.AF_INET { - continue loop - } - for _, attr := range attrs { - if attr.Attr.Type == syscall.RTA_DST { - addr, ok := netip.AddrFromSlice(attr.Value) - if !ok { - continue - } - mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) - cidr, _ := mask.Size() - routePrefix := netip.PrefixFrom(addr, cidr) - if routePrefix.IsValid() && routePrefix.Addr().Is4() { - prefixList = append(prefixList, routePrefix) - } - } + var result *multierror.Error + for i := range routes { + route := routes[i] + // unreachable default routes don't come back with Dst set + if route.Gw == nil && route.Src == nil && route.Dst == nil { + if family == netlink.FAMILY_V4 { + routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} + } else { + routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } } + if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) + } } - return prefixList, nil + + return result.ErrorOrNil() } func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return err + return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) } // check if it is already enabled @@ -147,5 +347,162 @@ func enableIPForwarding() error { return nil } - return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec + //nolint:gosec + if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { + return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) + } + return nil +} + +// entryExists checks if the specified ID or name already exists in the rt_tables file +// and verifies if existing names start with "netbird_". +func entryExists(file *os.File, id int) (bool, error) { + if _, err := file.Seek(0, 0); err != nil { + return false, fmt.Errorf("seek rt_tables: %w", err) + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + var existingID int + var existingName string + if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { + if existingID == id { + if existingName != NetbirdVPNTableName { + return true, ErrTableIDExists + } + return true, nil + } + } + } + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("scan rt_tables: %w", err) + } + return false, nil +} + +// addRoutingTableName adds human-readable names for custom routing tables. +func addRoutingTableName() error { + file, err := os.Open(rtTablesPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("open rt_tables: %w", err) + } + defer func() { + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables: %v", err) + } + }() + + exists, err := entryExists(file, NetbirdVPNTableID) + if err != nil { + return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) + } + if exists { + return nil + } + + // Reopen the file in append mode to add new entries + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables before appending: %v", err) + } + file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("open rt_tables for appending: %w", err) + } + + if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { + return fmt.Errorf("append entry to rt_tables: %w", err) + } + + return nil +} + +// addRule adds a routing rule to a specific routing table identified by tableID. +func addRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Priority = params.priority + rule.Invert = params.invert + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("add routing rule: %w", err) + } + + return nil +} + +// removeRule removes a routing rule from a specific routing table identified by tableID. +func removeRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Invert = params.invert + rule.Priority = params.priority + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleDel(rule); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + + return nil +} + +func removeAllRules(params ruleParams) error { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + for { + if ctx.Err() != nil { + done <- ctx.Err() + return + } + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { + done <- nil + return + } + done <- err + return + } + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err + } +} + +// addNextHop adds the gateway and device to the route. +func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { + if addr.IsValid() { + route.Gw = addr.AsSlice() + } + + if intf != "" { + link, err := netlink.LinkByName(intf) + if err != nil { + return fmt.Errorf("set interface %s: %w", intf, err) + } + route.LinkIndex = link.Attrs().Index + } + + return nil +} + +func getAddressFamily(prefix netip.Prefix) int { + if prefix.Addr().Is4() { + return netlink.FAMILY_V4 + } + return netlink.FAMILY_V6 } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go new file mode 100644 index 00000000000..0043c3f4e94 --- /dev/null +++ b/client/internal/routemanager/systemops_linux_test.go @@ -0,0 +1,207 @@ +//go:build !android + +package routemanager + +import ( + "errors" + "fmt" + "net" + "os" + "strings" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" +) + +var expectedVPNint = "wgtest0" +var expectedLoopbackInt = "lo" +var expectedExternalInt = "dummyext0" +var expectedInternalInt = "dummyint0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: expectedLoopbackInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + }...) +} + +func TestEntryExists(t *testing.T) { + tempDir := t.TempDir() + tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) + + content := []string{ + "1000 reserved", + fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), + "9999 other_table", + } + require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) + + file, err := os.Open(tempFilePath) + require.NoError(t, err) + defer func() { + assert.NoError(t, file.Close()) + }() + + tests := []struct { + name string + id int + shouldExist bool + err error + }{ + { + name: "ExistsWithNetbirdPrefix", + id: 7120, + shouldExist: true, + err: nil, + }, + { + name: "ExistsWithDifferentName", + id: 1000, + shouldExist: true, + err: ErrTableIDExists, + }, + { + name: "DoesNotExist", + id: 1234, + shouldExist: false, + err: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + exists, err := entryExists(file, tc.id) + if tc.err != nil { + assert.ErrorIs(t, err, tc.err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.shouldExist, exists) + }) + } +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} + err := netlink.LinkDel(dummy) + if err != nil && !errors.Is(err, syscall.EINVAL) { + t.Logf("Failed to delete dummy interface: %v", err) + } + + err = netlink.LinkAdd(dummy) + require.NoError(t, err) + + err = netlink.LinkSetUp(dummy) + require.NoError(t, err) + + if ipAddressCIDR != "" { + addr, err := netlink.ParseAddr(ipAddressCIDR) + require.NoError(t, err) + err = netlink.AddrAdd(dummy, addr) + require.NoError(t, err) + } + + t.Cleanup(func() { + err := netlink.LinkDel(dummy) + assert.NoError(t, err) + }) + + return dummy.Name +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { + t.Helper() + + _, dstIPNet, err := net.ParseCIDR(dstCIDR) + require.NoError(t, err) + + // Handle existing routes with metric 0 + var originalNexthop net.IP + var originalLinkIndex int + if dstIPNet.String() == "0.0.0.0/0" { + var err error + originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if originalNexthop != nil { + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + switch { + case err != nil && !errors.Is(err, syscall.ESRCH): + t.Logf("Failed to delete route: %v", err) + case err == nil: + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + default: + t.Logf("Failed to delete route: %v", err) + } + } + } + + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + err = netlink.RouteDel(route) + if err != nil && !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } + + err = netlink.RouteAdd(route) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + require.NoError(t, err) +} + +func fetchOriginalGateway(family int) (net.IP, int, error) { + routes, err := netlink.RouteList(nil, family) + if err != nil { + return nil, 0, err + } + + for _, route := range routes { + if route.Dst == nil && route.Priority == 0 { + return route.Gw, route.LinkIndex, nil + } + } + + return nil, 0, ErrRouteNotFound +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go deleted file mode 100644 index 11247c7dccd..00000000000 --- a/client/internal/routemanager/systemops_nonandroid.go +++ /dev/null @@ -1,120 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "fmt" - "net" - "net/netip" - - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" -) - -var errRouteNotFound = fmt.Errorf("route not found") - -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return err - } - if ok { - log.Warnf("skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return err - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, addr) -} - -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil && err != errRouteNotFound { - return err - } - - addr := netip.MustParseAddr(defaultGateway.String()) - - if !prefix.Contains(addr) { - log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(addr, 32) - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && err != errRouteNotFound { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop.String()) -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, err - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, err - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { - return removeFromRouteTable(prefix, addr) -} - -func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { - r, err := netroute.New() - if err != nil { - return nil, err - } - _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) - if err != nil { - log.Errorf("getting routes returned an error: %v", err) - return nil, errRouteNotFound - } - - if gateway == nil { - return preferredSrc, nil - } - - return gateway, nil -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 47bd60eb02b..38026107ec7 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,41 +1,23 @@ -//go:build !linux -// +build !linux +//go:build !linux && !ios package routemanager import ( "net/netip" - "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func addToRouteTable(prefix netip.Prefix, addr string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) - return nil +func addVPNRoute(prefix netip.Prefix, intf string) error { + return genericAddVPNRoute(prefix, intf) } -func enableIPForwarding() error { - log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil +func removeVPNRoute(prefix netip.Prefix, intf string) error { + return genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_test.go similarity index 59% rename from client/internal/routemanager/systemops_nonandroid_test.go rename to client/internal/routemanager/systemops_test.go index 6f32d9634bc..97386f19a1a 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -1,24 +1,32 @@ -//go:build !android +//go:build !android && !ios package routemanager import ( "bytes" + "context" "fmt" "net" "net/netip" "os" + "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -53,27 +61,30 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericAddVPNRoute should not return err") - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + assertWGOutInterface(t, testCase.prefix, wgInterface, false) } else { - require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") + assertWGOutInterface(t, testCase.prefix, wgInterface, true) } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) - require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericRemoveVPNRoute should not return err") - prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") + prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -86,12 +97,12 @@ func TestAddRemoveRoutes(t *testing.T) { } } -func TestGetExistingRIBRouteGateway(t *testing.T) { - gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) +func TestGetNextHop(t *testing.T) { + gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if gateway == nil { + if !gateway.IsValid() { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -113,11 +124,11 @@ func TestGetExistingRIBRouteGateway(t *testing.T) { } } - localIP, err := getExistingRIBRouteGateway(testingPrefix) + localIP, _, err := getNextHop(testingPrefix.Addr()) if err != nil { t.Fatal("shouldn't return error: ", err) } - if localIP == nil { + if !localIP.IsValid() { t.Fatal("should return a gateway for local network") } if localIP.String() == gateway.String() { @@ -128,8 +139,8 @@ func TestGetExistingRIBRouteGateway(t *testing.T) { } } -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) +func TestAddExistAndRemoveRoute(t *testing.T) { + defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -189,16 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - MockAddr := wgInterface.Address().IP.String() - // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) + err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -208,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -217,6 +226,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") + if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } @@ -224,31 +234,6 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { } } -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { - addressPrefixes = append(addressPrefixes, p.Masked()) - } - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -286,3 +271,132 @@ func TestIsSubRange(t *testing.T) { } } } + +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is6() { + continue + } + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { + continue + } + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + if runtime.GOOS == "linux" && p.Addr().IsLoopback() { + continue + } + + addressPrefixes = append(addressPrefixes, p.Masked()) + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet() + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { + return + } + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go new file mode 100644 index 00000000000..561eaeea4b2 --- /dev/null +++ b/client/internal/routemanager/systemops_unix_test.go @@ -0,0 +1,234 @@ +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly + +package routemanager + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) + + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.expectedPacket) + }) + } +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 309c184b9ca..50fff0cd58d 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,13 +1,19 @@ //go:build windows -// +build windows package routemanager import ( + "fmt" "net" "net/netip" + "os/exec" + "strings" + log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -15,23 +21,35 @@ type Win32_IP4RouteTable struct { Mask string } +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" err := wmi.Query(query, &routes) if err != nil { - return nil, err + return nil, fmt.Errorf("get routes: %w", err) } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { + log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { + log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -44,3 +62,69 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } + +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + destinationPrefix := prefix.String() + psCmd := "New-NetRoute" + + addressFamily := "IPv4" + if prefix.Addr().Is6() { + addressFamily = "IPv6" + } + + script := fmt.Sprintf( + `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, intf, + ) + + if nexthop.IsValid() { + script = fmt.Sprintf( + `%s -NextHop "%s"`, script, nexthop, + ) + } + + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + log.Tracef("PowerShell add route: %s", string(out)) + + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) + } + + return nil +} + +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"add", prefix.String(), nexthop.Unmap().String()} + + out, err := exec.Command("route", args...).CombinedOutput() + + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + if err != nil { + return fmt.Errorf("route add: %w", err) + } + + return nil +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + // Powershell doesn't support adding routes without an interface but allows to add interface by name + if intf != "" { + return addRoutePowershell(prefix, nexthop, intf) + } + return addRouteCmd(prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"delete", prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go new file mode 100644 index 00000000000..a5e03b8d2ce --- /dev/null +++ b/client/internal/routemanager/systemops_windows_test.go @@ -0,0 +1,289 @@ +package routemanager + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +var expectedExtInt = "Ethernet1" + +type RouteInfo struct { + NextHop string `json:"nexthop"` + InterfaceAlias string `json:"interfacealias"` + RouteMetric int `json:"routemetric"` +} + +type FindNetRouteOutput struct { + IPAddress string `json:"IPAddress"` + InterfaceIndex int `json:"InterfaceIndex"` + InterfaceAlias string `json:"InterfaceAlias"` + AddressFamily int `json:"AddressFamily"` + NextHop string `json:"NextHop"` + DestinationPrefix string `json:"DestinationPrefix"` +} + +type testCase struct { + name string + destination string + expectedSourceIP string + expectedDestPrefix string + expectedNextHop string + expectedInterface string + dialer dialer +} + +var expectedVPNint = "wgtest0" + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "128.0.0.0/1", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedDestPrefix: "192.0.2.1/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedDestPrefix: "10.0.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "10.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedDestPrefix: "172.16.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "172.16.0.0/12", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route without custom dialer via vpn interface", + destination: "10.10.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "10.10.0.0/24", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "127.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + route, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original gateway") + ip, err := fetchInterfaceIP(route.InterfaceAlias) + require.NoError(t, err, "Failed to fetch interface IP") + + output := testRoute(t, tc.destination, tc.dialer) + if tc.expectedInterface == expectedExtInt { + verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) + } else { + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + } + }) + } +} + +// fetchInterfaceIP fetches the IPv4 address of the specified interface. +func fetchInterfaceIP(interfaceAlias string) (string, error) { + script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) + out, err := exec.Command("powershell", "-Command", script).Output() + if err != nil { + return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) + } + + ip := strings.TrimSpace(string(out)) + return ip, nil +} + +func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", destination) + require.NoError(t, err, "Failed to dial destination") + defer func() { + err := conn.Close() + assert.NoError(t, err, "Failed to close connection") + }() + + host, _, err := net.SplitHostPort(destination) + require.NoError(t, err) + + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + + out, err := exec.Command("powershell", "-Command", script).Output() + require.NoError(t, err, "Failed to execute Find-NetRoute") + + var outputs []FindNetRouteOutput + err = json.Unmarshal(out, &outputs) + require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") + + require.Greater(t, len(outputs), 0, "No route found for destination") + combinedOutput := combineOutputs(outputs) + + return combinedOutput +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) + require.NoError(t, err) + subnetMaskSize, _ := ipNet.Mask.Size() + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to assign IP address to loopback adapter") + + // Wait for the IP address to be applied + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = waitForIPAddress(ctx, interfaceName, ip.String()) + require.NoError(t, err, "IP address not applied within timeout") + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove IP address from loopback adapter") + }) + + return interfaceName +} + +func fetchOriginalGateway() (*RouteInfo, error) { + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) + } + + var routeInfo RouteInfo + err = json.Unmarshal(output, &routeInfo) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON output: %w", err) + } + + return &routeInfo, nil +} + +func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { + t.Helper() + + assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") + assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") + assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") + assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") +} + +func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() + if err != nil { + return err + } + + ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") + for _, ip := range ipAddresses { + if strings.TrimSpace(ip) == expectedIPAddress { + return nil + } + } + } + } +} + +func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { + var combined FindNetRouteOutput + + for _, output := range outputs { + if output.IPAddress != "" { + combined.IPAddress = output.IPAddress + } + if output.InterfaceIndex != 0 { + combined.InterfaceIndex = output.InterfaceIndex + } + if output.InterfaceAlias != "" { + combined.InterfaceAlias = output.InterfaceAlias + } + if output.AddressFamily != 0 { + combined.AddressFamily = output.AddressFamily + } + if output.NextHop != "" { + combined.NextHop = output.NextHop + } + if output.DestinationPrefix != "" { + combined.DestinationPrefix = output.DestinationPrefix + } + } + + return &combined +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") +} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go new file mode 100644 index 00000000000..e80adb42b20 --- /dev/null +++ b/client/internal/stdnet/dialer.go @@ -0,0 +1,24 @@ +package stdnet + +import ( + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// Dial connects to the address on the named network. +func (n *Net) Dial(network, address string) (net.Conn, error) { + return nbnet.NewDialer().Dial(network, address) +} + +// DialUDP connects to the address on the named UDP network. +func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.DialUDP(network, laddr, raddr) +} + +// DialTCP connects to the address on the named TCP network. +func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + return nbnet.DialTCP(network, laddr, raddr) +} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go new file mode 100644 index 00000000000..9ce0a555610 --- /dev/null +++ b/client/internal/stdnet/listener.go @@ -0,0 +1,20 @@ +package stdnet + +import ( + "context" + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// ListenPacket listens for incoming packets on the given network and address. +func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { + return nbnet.NewListener().ListenPacket(context.Background(), network, address) +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.ListenUDP(network, locAddr) +} diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index f02b4943bc6..2235c5d2bdf 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - conn, err := net.ListenUDP("udp", &addr) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -228,6 +229,12 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return nil, fmt.Errorf("binding to lo interface failed: %w", err) } + // Set the fwmark on the socket. + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index b692ea70842..17ebfbc499b 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,6 +6,8 @@ import ( "net" log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -33,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index e4e36b96685..29a1570c896 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 @@ -123,7 +124,6 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect - github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 36fd13cc262..9fe987cee21 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,6 +10,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := 0 + fwmark := nbnet.NetbirdFwmark config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 200bfbc9614..24dfadf1408 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,6 +13,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -37,7 +39,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := 0 + fwmark := getFwmark() config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -345,3 +347,10 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } + +func getFwmark() int { + if runtime.GOOS == "linux" { + return nbnet.NetbirdFwmark + } + return 0 +} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0234f866cb8..0b1804906c2 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -57,6 +58,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 02b4e174dab..74ac6c163ad 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,6 +21,8 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" + + nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -82,10 +84,18 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } + if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + } + var sockErr error rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if sockErr != nil { log.Errorf("Failed to create ipv6 raw socket: %v", err) + } else { + if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7531608c3bb..7c4535e2896 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -76,6 +77,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go new file mode 100644 index 00000000000..96b2bc32be0 --- /dev/null +++ b/util/grpc/dialer.go @@ -0,0 +1,22 @@ +package grpc + +import ( + "context" + "net" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func WithCustomDialer() grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, err + } + return conn, nil + }) +} diff --git a/util/net/dialer.go b/util/net/dialer.go new file mode 100644 index 00000000000..0786c667e53 --- /dev/null +++ b/util/net/dialer.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// Dialer extends the standard net.Dialer with the ability to execute hooks before +// and after connections. This can be used to bypass the VPN for connections using this dialer. +type Dialer struct { + *net.Dialer +} + +// NewDialer returns a customized net.Dialer with overridden Control method +func NewDialer() *Dialer { + dialer := &Dialer{ + Dialer: &net.Dialer{}, + } + dialer.init() + + return dialer +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go new file mode 100644 index 00000000000..06fac3bbf85 --- /dev/null +++ b/util/net/dialer_generic.go @@ -0,0 +1,163 @@ +//go:build !android && !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" +) + +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) +} + +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go new file mode 100644 index 00000000000..aed5c59a322 --- /dev/null +++ b/util/net/dialer_linux.go @@ -0,0 +1,12 @@ +//go:build !android + +package net + +import "syscall" + +// init configures the net.Dialer Control function to set the fwmark on the socket +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) + } +} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go new file mode 100644 index 00000000000..3254e6d066b --- /dev/null +++ b/util/net/dialer_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/listener.go b/util/net/listener.go new file mode 100644 index 00000000000..f4d769f587e --- /dev/null +++ b/util/net/listener.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before +// responding via the socket and after closing. This can be used to bypass the VPN for listeners. +type ListenerConfig struct { + *net.ListenConfig +} + +// NewListener creates a new ListenerConfig instance. +func NewListener() *ListenerConfig { + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } + listener.init() + + return listener +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go new file mode 100644 index 00000000000..451279e9d25 --- /dev/null +++ b/util/net/listener_generic.go @@ -0,0 +1,163 @@ +//go:build !android && !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) +} + +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// RemoveListenerHooks removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.UDPConn) +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + return + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + return + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(id, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } +} + +func closeConn(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(id, conn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + return err +} + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go new file mode 100644 index 00000000000..8d332160a04 --- /dev/null +++ b/util/net/listener_linux.go @@ -0,0 +1,14 @@ +//go:build !android + +package net + +import ( + "syscall" +) + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) + } +} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go new file mode 100644 index 00000000000..0dbbb360b53 --- /dev/null +++ b/util/net/listener_mobile.go @@ -0,0 +1,11 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) +} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go new file mode 100644 index 00000000000..fb6eadaaad8 --- /dev/null +++ b/util/net/listener_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (l *ListenerConfig) init() { +} diff --git a/util/net/net.go b/util/net/net.go new file mode 100644 index 00000000000..9ea7ae80340 --- /dev/null +++ b/util/net/net.go @@ -0,0 +1,17 @@ +package net + +import "github.com/google/uuid" + +const ( + // NetbirdFwmark is the fwmark value used by Netbird via wireguard + NetbirdFwmark = 0x1BD00 +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} diff --git a/util/net/net_linux.go b/util/net/net_linux.go new file mode 100644 index 00000000000..82141750029 --- /dev/null +++ b/util/net/net_linux.go @@ -0,0 +1,35 @@ +//go:build !android + +package net + +import ( + "fmt" + "syscall" +) + +// SetSocketMark sets the SO_MARK option on the given socket connection +func SetSocketMark(conn syscall.Conn) error { + sysconn, err := conn.SyscallConn() + if err != nil { + return fmt.Errorf("get raw conn: %w", err) + } + + return SetRawSocketMark(sysconn) +} + +func SetRawSocketMark(conn syscall.RawConn) error { + var setErr error + + err := conn.Control(func(fd uintptr) { + setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + }) + if err != nil { + return fmt.Errorf("control: %w", err) + } + + if setErr != nil { + return fmt.Errorf("set SO_MARK: %w", setErr) + } + + return nil +} From c28657710a7184a43f14b0a0d341de819ebc3af7 Mon Sep 17 00:00:00 2001 From: verytrap <166317454+verytrap@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:18:38 +0800 Subject: [PATCH 35/89] Fix function names in comments (#1816) Signed-off-by: verytrap --- management/server/account.go | 2 +- management/server/idp/okta.go | 2 +- management/server/mock_server/account_mock.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c145c1bd789..20bd15ad698 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -278,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou return routes } -// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index d20ee7e4839..c8d33a207e3 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error { return nil } -// parseOktaUserToUserData parse okta user to UserData. +// parseOktaUser parse okta user to UserData. func parseOktaUser(user *okta.User) (*UserData, error) { var oktaUser struct { Email string `json:"email"` diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8e7c47a280a..8687937dc49 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -706,7 +706,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { return nil } -// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface +// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { if am.UpdateIntegratedValidatorGroupsFunc != nil { return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) From ac0fe6025b0ea565262b1fdd961f1eba439f59e5 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 9 Apr 2024 13:25:14 +0200 Subject: [PATCH 36/89] Fix routing issues with MacOS (#1815) * Handle zones properly * Use host routes for single IPs * Add GOOS and GOARCH to startup log * Log powershell command --- client/internal/connect.go | 3 +- client/internal/routemanager/systemops.go | 35 +++++++++++++++---- client/internal/routemanager/systemops_bsd.go | 28 +++++++++------ .../internal/routemanager/systemops_darwin.go | 6 +++- .../internal/routemanager/systemops_linux.go | 3 ++ .../routemanager/systemops_windows.go | 33 ++++++++++++----- util/net/dialer_generic.go | 4 +-- 7 files changed, 83 insertions(+), 29 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 682a1efedb8..b50b3a62910 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "runtime" "strings" "time" @@ -93,7 +94,7 @@ func runClient( relayProbe *Probe, wgProbe *Probe, ) error { - log.Infof("starting NetBird client version %s", version.NetbirdVersion()) + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) // Check if client was not shut down in a clean way and restore DNS config if required. // Otherwise, we might not be able to connect to the management server to retrieve new config. diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index a91f53636da..1ee54b746d8 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "net/netip" + "runtime" + "strconv" "github.com/hashicorp/go-multierror" "github.com/libp2p/go-netroute" @@ -85,23 +87,42 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { if preferredSrc == nil { - return netip.Addr{}, nil, ErrRouteNotFound + return netip.Addr{}, nil, ErrRouteNotFound } log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + addr, err := ipToAddr(preferredSrc, intf) + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err) } return addr.Unmap(), intf, nil } - addr, ok := netip.AddrFromSlice(gateway) + addr, err := ipToAddr(gateway, intf) + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err) + } + + return addr, intf, nil +} + +// converts a net.IP to a netip.Addr including the zone based on the passed interface +func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { + addr, ok := netip.AddrFromSlice(ip) if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip) + } + + if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) { + log.Tracef("Adding zone %s to address %s", intf.Name, addr) + if runtime.GOOS == "windows" { + addr = addr.WithZone(strconv.Itoa(intf.Index)) + } else { + addr = addr.WithZone(intf.Name) + } } - return addr.Unmap(), intf, nil + return addr.Unmap(), nil } func existsInRouteTable(prefix netip.Prefix) (bool, error) { diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e847..b6a2006e776 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -8,6 +8,7 @@ import ( "net/netip" "syscall" + log "github.com/sirupsen/logrus" "golang.org/x/net/route" ) @@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) { continue } - addr, ok := toNetIPAddr(m.Addrs[0]) - if !ok { + if len(m.Addrs) < 3 { + log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs) continue } - mask, ok := toNetIPMASK(m.Addrs[2]) + addr, ok := toNetIPAddr(m.Addrs[0]) if !ok { continue } - cidr, _ := mask.Size() + + cidr := 32 + if mask := m.Addrs[2]; mask != nil { + cidr, ok = toCIDR(mask) + if !ok { + log.Debugf("Unexpected RIB message Addrs[2]: %v", mask) + continue + } + } routePrefix := netip.PrefixFrom(addr, cidr) if routePrefix.IsValid() { @@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) { func toNetIPAddr(a route.Addr) (netip.Addr, bool) { switch t := a.(type) { case *route.Inet4Addr: - ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - addr := netip.MustParseAddr(ip.String()) - return addr, true + return netip.AddrFrom4(t.IP), true default: return netip.Addr{}, false } } -func toNetIPMASK(a route.Addr) (net.IPMask, bool) { +func toCIDR(a route.Addr) (int, bool) { switch t := a.(type) { case *route.Inet4Addr: mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - return mask, true + cidr, _ := mask.Size() + return cidr, true default: - return nil, false + return 0, false } } diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go index f34964a8343..33b8287b6b8 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops_darwin.go @@ -35,6 +35,10 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { inet := "-inet" + network := prefix.String() + if prefix.IsSingleIP() { + network = prefix.Addr().String() + } if prefix.Addr().Is6() { inet = "-inet6" // Special case for IPv6 split default route, pointing to the wg interface fails @@ -44,7 +48,7 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin } } - args := []string{"-n", action, inet, prefix.String()} + args := []string{"-n", action, inet, network} if nexthop.IsValid() { args = append(args, nexthop.Unmap().String()) } else if intf != "" { diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index ef464372737..dd00626e125 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -487,6 +487,9 @@ func removeAllRules(params ruleParams) error { func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { if addr.IsValid() { route.Gw = addr.AsSlice() + if intf == "" { + intf = addr.Zone() + } } if intf != "" { diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 50fff0cd58d..334ace45324 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -63,7 +63,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error { destinationPrefix := prefix.String() psCmd := "New-NetRoute" @@ -73,10 +73,20 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er } script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, - psCmd, addressFamily, destinationPrefix, intf, + `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, ) + if intfIdx != "" { + script = fmt.Sprintf( + `%s -InterfaceIndex %s`, script, intfIdx, + ) + } else { + script = fmt.Sprintf( + `%s -InterfaceAlias "%s"`, script, intf, + ) + } + if nexthop.IsValid() { script = fmt.Sprintf( `%s -NextHop "%s"`, script, nexthop, @@ -84,7 +94,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) er } out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell add route: %s", string(out)) + log.Tracef("PowerShell %s: %s", script, string(out)) if err != nil { return fmt.Errorf("PowerShell add route: %w", err) @@ -98,7 +108,7 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) + log.Tracef("route %s: %s", strings.Join(args, " "), out) if err != nil { return fmt.Errorf("route add: %w", err) } @@ -107,9 +117,15 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { } func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + var intfIdx string + if nexthop.Zone() != "" { + intfIdx = nexthop.Zone() + nexthop.WithZone("") + } + // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" { - return addRoutePowershell(prefix, nexthop, intf) + if intf != "" || intfIdx != "" { + return addRoutePowershell(prefix, nexthop, intf, intfIdx) } return addRouteCmd(prefix, nexthop, intf) } @@ -117,11 +133,12 @@ func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { args := []string{"delete", prefix.String()} if nexthop.IsValid() { + nexthop.WithZone("") args = append(args, nexthop.Unmap().String()) } out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) + log.Tracef("route %s: %s", strings.Join(args, " "), out) if err != nil { return fmt.Errorf("remove route: %w", err) diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 06fac3bbf85..4eda710ac40 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -56,7 +56,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. connID := GenerateConnID() if dialerDialHooks != nil { - if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + if err := callDialerHooks(ctx, connID, address, resolver); err != nil { log.Errorf("Failed to call dialer hooks: %v", err) } } @@ -97,7 +97,7 @@ func (c *Conn) Close() error { return err } -func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { +func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { return fmt.Errorf("split host and port: %w", err) From c1f66d135487b6093fe899cb0f078ded3d465ed8 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 9 Apr 2024 15:27:19 +0200 Subject: [PATCH 37/89] Retry macOS route command (#1817) --- .../internal/routemanager/systemops_darwin.go | 30 +++++++++++++-- .../routemanager/systemops_darwin_test.go | 38 +++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go index 33b8287b6b8..f7ce72a4e89 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops_darwin.go @@ -8,7 +8,9 @@ import ( "net/netip" "os/exec" "strings" + "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/peer" @@ -55,11 +57,33 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin args = append(args, "-interface", intf) } - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) + if err := retryRouteCmd(args); err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} +func retryRouteCmd(args []string) error { + operation := func() error { + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + // https://github.com/golang/go/issues/45736 + if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") { + return err + } else if err != nil { + return backoff.Permanent(err) + } + return nil + } + + expBackOff := backoff.NewExponentialBackOff() + expBackOff.InitialInterval = 50 * time.Millisecond + expBackOff.MaxInterval = 500 * time.Millisecond + expBackOff.MaxElapsedTime = 1 * time.Second + + err := backoff.Retry(operation, expBackOff) if err != nil { - return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + return fmt.Errorf("route cmd retry failed: %w", err) } return nil } diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go index 5c5aaa24fe1..cc9bb9db598 100644 --- a/client/internal/routemanager/systemops_darwin_test.go +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -5,8 +5,10 @@ package routemanager import ( "fmt" "net" + "net/netip" "os/exec" "regexp" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -29,6 +31,42 @@ func init() { }...) } +func TestConcurrentRoutes(t *testing.T) { + baseIP := netip.MustParseAddr("192.0.2.0") + intf := "lo0" + + var wg sync.WaitGroup + for i := 0; i < 1024; i++ { + wg.Add(1) + go func(ip netip.Addr) { + defer wg.Done() + prefix := netip.PrefixFrom(ip, 32) + if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil { + t.Errorf("Failed to add route for %s: %v", prefix, err) + } + }(baseIP) + baseIP = baseIP.Next() + } + + wg.Wait() + + baseIP = netip.MustParseAddr("192.0.2.0") + + for i := 0; i < 1024; i++ { + wg.Add(1) + go func(ip netip.Addr) { + defer wg.Done() + prefix := netip.PrefixFrom(ip, 32) + if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil { + t.Errorf("Failed to remove route for %s: %v", prefix, err) + } + }(baseIP) + baseIP = baseIP.Next() + } + + wg.Wait() +} + func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { t.Helper() From 22b2caffc60ead6828d59d495b5670475cf95f3c Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:01:31 +0200 Subject: [PATCH 38/89] Remove dns based cloud detection (#1812) * remove dns based cloud checks * remove dns based cloud checks --- client/system/detect_cloud/detect.go | 2 - client/system/detect_cloud/gcp.go | 2 +- client/system/detect_cloud/ibmcloud.go | 54 ------------------------- client/system/detect_cloud/softlayer.go | 25 ------------ 4 files changed, 1 insertion(+), 82 deletions(-) delete mode 100644 client/system/detect_cloud/ibmcloud.go delete mode 100644 client/system/detect_cloud/softlayer.go diff --git a/client/system/detect_cloud/detect.go b/client/system/detect_cloud/detect.go index 3bbff434580..8a8de763eb2 100644 --- a/client/system/detect_cloud/detect.go +++ b/client/system/detect_cloud/detect.go @@ -25,8 +25,6 @@ func Detect(ctx context.Context) string { detectDigitalOcean, detectGCP, detectOracle, - detectIBMCloud, - detectSoftlayer, detectVultr, } diff --git a/client/system/detect_cloud/gcp.go b/client/system/detect_cloud/gcp.go index c673f893739..a24c38c0c6b 100644 --- a/client/system/detect_cloud/gcp.go +++ b/client/system/detect_cloud/gcp.go @@ -6,7 +6,7 @@ import ( ) func detectGCP(ctx context.Context) string { - req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil) if err != nil { return "" } diff --git a/client/system/detect_cloud/ibmcloud.go b/client/system/detect_cloud/ibmcloud.go deleted file mode 100644 index 07de6a2ee11..00000000000 --- a/client/system/detect_cloud/ibmcloud.go +++ /dev/null @@ -1,54 +0,0 @@ -package detect_cloud - -import ( - "context" - "net/http" -) - -func detectIBMCloud(ctx context.Context) string { - v1ResultChan := make(chan bool, 1) - v2ResultChan := make(chan bool, 1) - - go func() { - v1ResultChan <- detectIBMSecure(ctx) - }() - - go func() { - v2ResultChan <- detectIBM(ctx) - }() - - v1Result, v2Result := <-v1ResultChan, <-v2ResultChan - - if v1Result || v2Result { - return "IBM Cloud" - } - return "" -} - -func detectIBMSecure(ctx context.Context) bool { - req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil) - if err != nil { - return false - } - - resp, err := hc.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == http.StatusOK -} - -func detectIBM(ctx context.Context) bool { - req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil) - if err != nil { - return false - } - - resp, err := hc.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == http.StatusOK -} diff --git a/client/system/detect_cloud/softlayer.go b/client/system/detect_cloud/softlayer.go deleted file mode 100644 index a09b522c454..00000000000 --- a/client/system/detect_cloud/softlayer.go +++ /dev/null @@ -1,25 +0,0 @@ -package detect_cloud - -import ( - "context" - "net/http" -) - -func detectSoftlayer(ctx context.Context) string { - req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil) - if err != nil { - return "" - } - - resp, err := hc.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - // Since SoftLayer was acquired by IBM, we should return "IBM Cloud" - return "IBM Cloud" - } - return "" -} From dd0cf4114713069c11e2a829e0ee90769e399fcf Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 10 Apr 2024 03:10:59 +0900 Subject: [PATCH 39/89] Auto restart Windows agent daemon service (#1819) This enables auto restart of the windows agent daemon service on event of failure --- client/cmd/service_installer.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index da6beef4ff9..5e147262bb4 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -64,6 +64,10 @@ var installCmd = &cobra.Command{ } } + if runtime.GOOS == "windows" { + svcConfig.Option["OnFailure"] = "restart" + } + ctx, cancel := context.WithCancel(cmd.Context()) s, err := newSVC(newProgram(ctx, cancel), svcConfig) From 90bd39c74017cb4974e7b5d9de56180eb2e2b66b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 9 Apr 2024 20:27:27 +0200 Subject: [PATCH 40/89] Log panics (#1818) --- client/internal/connect.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/client/internal/connect.go b/client/internal/connect.go index b50b3a62910..6b888c9cca8 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "runtime" + "runtime/debug" "strings" "time" @@ -94,6 +95,12 @@ func runClient( relayProbe *Probe, wgProbe *Probe, ) error { + defer func() { + if r := recover(); r != nil { + log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + } + }() + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) // Check if client was not shut down in a clean way and restore DNS config if required. From 4c83408f27259b7fecfc0fa933a2ad68cbcfffaa Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 10 Apr 2024 04:00:43 +0900 Subject: [PATCH 41/89] Add log-level to the management's docker service command (#1820) --- infrastructure_files/docker-compose.yml.tmpl | 1 + 1 file changed, 1 insertion(+) diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 47cadb93c83..747eebd539d 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -58,6 +58,7 @@ services: command: [ "--port", "443", "--log-file", "console", + "--log-level", "info", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" From 3ed2f08f3c5dd930a598a26f24cf028807816486 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 9 Apr 2024 21:20:02 +0200 Subject: [PATCH 42/89] Add latency based routing (#1732) Now that we have the latency between peers available we can use this data to consider when choosing the best route. This way the route with the routing peer with the lower latency will be preferred over others with the same target network. --- client/internal/routemanager/client.go | 38 ++++-- client/internal/routemanager/client_test.go | 142 ++++++++++++++++++-- 2 files changed, 163 insertions(+), 17 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 38cf4bf6550..370ad5cf44b 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "time" log "github.com/sirupsen/logrus" @@ -18,6 +19,7 @@ type routerPeerStatus struct { connected bool relayed bool direct bool + latency time.Duration } type routesUpdate struct { @@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { connected: peerStatus.ConnStatus == peer.StatusConnected, relayed: peerStatus.Relayed, direct: peerStatus.Direct, + latency: peerStatus.Latency, } } return routePeerStatuses @@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { // * Non-relayed: Routes without relays are preferred. // * Direct connections: Routes with direct peer connections are favored. // * Stability: In case of equal scores, the currently active route (if any) is maintained. +// * Latency: Routes with lower latency are prioritized. // // It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" - chosenScore := 0 + chosenScore := float64(0) + currScore := float64(0) currID := "" if c.chosenRoute != nil { @@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro } for _, r := range c.routes { - tempScore := 0 + tempScore := float64(0) peerStatus, found := routePeerStatuses[r.ID] if !found || !peerStatus.connected { continue @@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro if r.Metric < route.MaxMetric { metricDiff := route.MaxMetric - r.Metric - tempScore = metricDiff * 10 + tempScore = float64(metricDiff) * 10 } + // in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route + latency := time.Second + if peerStatus.latency != 0 { + latency = peerStatus.latency + } else { + log.Warnf("peer %s has 0 latency", r.Peer) + } + tempScore += 1 - latency.Seconds() + if !peerStatus.relayed { tempScore++ } @@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro tempScore++ } - if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) { + if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") { chosen = r.ID chosenScore = tempScore } @@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro chosen = r.ID chosenScore = tempScore } + + if r.ID == currID { + currScore = tempScore + } } - if chosen == "" { + switch { + case chosen == "": var peers []string for _, r := range c.routes { peers = append(peers, r.Peer) } log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) - - } else if chosen != currID { - log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) + case chosen != currID: + if currScore != 0 && currScore < chosenScore+0.1 { + return currID + } else { + log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) + } } return chosen diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 3700d72ecd7..d24d42b8eaf 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -3,6 +3,7 @@ package routemanager import ( "net/netip" "testing" + "time" "github.com/netbirdio/netbird/route" ) @@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name string statuses map[string]routerPeerStatus expectedRouteID string - currentRoute *route.Route + currentRoute string existingRoutes map[string]*route.Route }{ { @@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "", }, { @@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, + { + name: "multiple connected peers with different latencies", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + latency: 300 * time.Millisecond, + }, + "route2": { + connected: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "should ignore routes with latency 0", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "current route with similar score and similar but slightly worse latency should not change", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + direct: true, + latency: 12 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + direct: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "current chosen route doesn't exist anymore", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + direct: true, + latency: 20 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + direct: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "routeDoesntExistAnymore", + expectedRouteID: "route2", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + currentRoute := &route.Route{ + ID: "routeDoesntExistAnymore", + } + if tc.currentRoute != "" { + currentRoute = tc.existingRoutes[tc.currentRoute] + } + // create new clientNetwork client := &clientNetwork{ network: netip.MustParsePrefix("192.168.0.0/24"), routes: tc.existingRoutes, - chosenRoute: tc.currentRoute, + chosenRoute: currentRoute, } chosenRoute := client.getBestRouteFromStatuses(tc.statuses) From 704c67dec8a1e62da43523938dc0e5322d306411 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 11 Apr 2024 10:02:51 +0200 Subject: [PATCH 43/89] Allow owners that did not create the account to delete it (#1825) Sometimes the Owner role will be passed to new users, and they need to be able to delete the account --- management/server/account.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 20bd15ad698..099369fc2af 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -242,19 +242,19 @@ type UserPermissions struct { } type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - NonDeletable bool `json:"non_deletable"` - LastLogin time.Time `json:"last_login"` - Issued string `json:"issued"` + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` + Permissions UserPermissions `json:"permissions"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes @@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } - if user.Id != account.CreatedBy { + if user.Role != UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { From 9505805313137c5018839d633508c2b47e590525 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 11 Apr 2024 14:08:03 +0200 Subject: [PATCH 44/89] Rename variable (#1829) --- management/server/peer.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index fda8e49e9cc..1448e301197 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -551,8 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) - if requiresApproval { + peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } @@ -563,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network am.updateAccountPeers(account) } - approvedPeersMap, err := am.GetValidatedPeers(account) + validPeersMap, err := am.GetValidatedPeers(account) if err != nil { return nil, nil, err } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil } // LoginPeer logs in or registers a peer. From 061f673a4f4c7ec40c288d39e920764aa72c7888 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 11 Apr 2024 15:29:03 +0200 Subject: [PATCH 45/89] Don't use the custom dialer as non-root (#1823) --- util/grpc/dialer.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 96b2bc32be0..63c56de17d6 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -3,6 +3,8 @@ package grpc import ( "context" "net" + "os/user" + "runtime" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -12,6 +14,20 @@ import ( func WithCustomDialer() grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + if runtime.GOOS == "linux" { + currentUser, err := user.Current() + if err != nil { + log.Fatalf("failed to get current user: %v", err) + } + + // the custom dialer requires root permissions which are not required for use cases run as non-root + if currentUser.Uid != "0" { + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) + } + } + + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) From 76702c8a09f050214bb9fb1b079c72869bc861c9 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Thu, 11 Apr 2024 14:12:23 -0600 Subject: [PATCH 46/89] Add safe read/write to route map (#1760) --- client/internal/engine.go | 1 + client/internal/peer/conn.go | 5 +++- client/internal/peer/status.go | 40 +++++++++++++++++++++++--- client/internal/peer/status_test.go | 5 ++++ client/internal/routemanager/client.go | 7 ++--- client/server/server.go | 2 +- 6 files changed, 49 insertions(+), 11 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index d6238c4b3ca..ba7074672c3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { FQDN: offlinePeer.GetFqdn(), ConnStatus: peer.StatusDisconnected, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } } e.statusRecorder.ReplaceOfflinePeers(replacement) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f3d07dcad1f..9e7ee695932 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error { } conn.agent, err = ice.NewAgent(agentConfig) - if err != nil { return err } @@ -285,6 +284,7 @@ func (conn *Conn) Open() error { IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], ConnStatusUpdate: time.Now(), ConnStatus: conn.status, + Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { @@ -344,6 +344,7 @@ func (conn *Conn) Open() error { PubKey: conn.config.Key, ConnStatus: conn.status, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } err = conn.statusRecorder.UpdatePeerState(peerState) if err != nil { @@ -468,6 +469,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()), Direct: !isRelayCandidate(pair.Local), RosenpassEnabled: rosenpassEnabled, + Mux: new(sync.RWMutex), } if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { peerState.Relayed = true @@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error { PubKey: conn.config.Key, ConnStatus: conn.status, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index ca97c3ea497..ddea7d04e16 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -14,6 +14,7 @@ import ( // State contains the latest state of a peer type State struct { + Mux *sync.RWMutex IP string PubKey string FQDN string @@ -30,7 +31,38 @@ type State struct { BytesRx int64 Latency time.Duration RosenpassEnabled bool - Routes map[string]struct{} + routes map[string]struct{} +} + +// AddRoute add a single route to routes map +func (s *State) AddRoute(network string) { + s.Mux.Lock() + if s.routes == nil { + s.routes = make(map[string]struct{}) + } + s.routes[network] = struct{}{} + s.Mux.Unlock() +} + +// SetRoutes set state routes +func (s *State) SetRoutes(routes map[string]struct{}) { + s.Mux.Lock() + s.routes = routes + s.Mux.Unlock() +} + +// DeleteRoute removes a route from the network amp +func (s *State) DeleteRoute(network string) { + s.Mux.Lock() + delete(s.routes, network) + s.Mux.Unlock() +} + +// GetRoutes return routes map +func (s *State) GetRoutes() map[string]struct{} { + s.Mux.RLock() + defer s.Mux.RUnlock() + return s.routes } // LocalPeerState contains the latest state of the local peer @@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error { PubKey: peerPubKey, ConnStatus: StatusDisconnected, FQDN: fqdn, + Mux: new(sync.RWMutex), } d.peerListChangedForNotification = true return nil @@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.Routes != nil { - peerState.Routes = receivedState.Routes + if receivedState.GetRoutes() != nil { + peerState.SetRoutes(receivedState.GetRoutes()) } skipNotification := shouldSkipNotify(receivedState, peerState) @@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool { s, ok := gstatus.FromError(d.managementError) if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return true - } return false } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 9038371bd1c..a4a6e608132 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -3,6 +3,7 @@ package peer import ( "errors" "testing" + "sync" "github.com/stretchr/testify/assert" ) @@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 370ad5cf44b..d41ed422b81 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { return fmt.Errorf("get peer state: %v", err) } - delete(state.Routes, c.network.String()) + state.DeleteRoute(c.network.String()) if err := c.statusRecorder.UpdatePeerState(state); err != nil { log.Warnf("Failed to update peer state: %v", err) } @@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { if err != nil { log.Errorf("Failed to get peer state: %v", err) } else { - if state.Routes == nil { - state.Routes = map[string]struct{}{} - } - state.Routes[c.network.String()] = struct{}{} + state.AddRoute(c.network.String()) if err := c.statusRecorder.UpdatePeerState(state); err != nil { log.Warnf("Failed to update peer state: %v", err) } diff --git a/client/server/server.go b/client/server/server.go index d1d9dbda451..d33bb515582 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.Routes), + Routes: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) From 91b2f9fc51b38d8d05d8906ba2899a5f7ffc58ca Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 12 Apr 2024 15:22:40 +0200 Subject: [PATCH 47/89] Use route active store (#1834) --- client/internal/routemanager/systemops_windows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 334ace45324..ba211082f1f 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s } script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`, + `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`, psCmd, addressFamily, destinationPrefix, ) From 15a2feb7237e7aa7d3b16a251e8cb35a87a2a7d7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 12 Apr 2024 16:07:03 +0200 Subject: [PATCH 48/89] Use fixed preference for rules (#1836) --- .../internal/routemanager/systemops_linux.go | 50 +++---------------- 1 file changed, 8 insertions(+), 42 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index dd00626e125..2dfde31a222 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -4,14 +4,12 @@ package routemanager import ( "bufio" - "context" "errors" "fmt" "net" "net/netip" "os" "syscall" - "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -41,10 +39,10 @@ var routeManager = &RouteManager{} var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" type ruleParams struct { + priority int fwmark int tableID int family int - priority int invert bool suppressPrefix int description string @@ -52,10 +50,10 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, + {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, + {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, + {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, + {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, } } @@ -69,8 +67,6 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -// -// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { if isLegacy { log.Infof("Using legacy routing setup") @@ -123,7 +119,7 @@ func cleanupRouting() error { rules := getSetupRules() for _, rule := range rules { - if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { + if err := removeRule(rule); err != nil { result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) } } @@ -429,7 +425,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("add routing rule: %w", err) } @@ -446,43 +442,13 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("remove routing rule: %w", err) } return nil } -func removeAllRules(params ruleParams) error { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - done := make(chan error, 1) - go func() { - for { - if ctx.Err() != nil { - done <- ctx.Err() - return - } - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { - done <- nil - return - } - done <- err - return - } - } - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - // addNextHop adds the gateway and device to the route. func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { if addr.IsValid() { From d30cf8706ae197eed2b65bfd287c73c9395db0c2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 12 Apr 2024 16:53:11 +0200 Subject: [PATCH 49/89] Allow disabling custom routing (#1840) --- client/internal/routemanager/manager.go | 25 +++++++++++++------ .../internal/routemanager/systemops_linux.go | 2 +- client/internal/wgproxy/proxy_ebpf.go | 2 +- iface/wg_configurer_kernel.go | 4 +-- iface/wg_configurer_usp.go | 2 +- util/net/dialer_generic.go | 12 +++++++++ util/net/listener_generic.go | 11 +++++++- util/net/net.go | 12 ++++++++- util/net/net_linux.go | 10 +++++++- 9 files changed, 64 insertions(+), 16 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 36a37f02c50..0dfc0f7e008 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, // Init sets up the routing func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + if nbnet.CustomRoutingDisabled() { + return nil, nil, nil + } + if err := cleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } - if err := cleanupRouting(); err != nil { - log.Errorf("Error cleaning up routing: %v", err) - } else { - log.Info("Routing cleanup complete") + + if !nbnet.CustomRoutingDisabled() { + if err := cleanupRouting(); err != nil { + log.Errorf("Error cleaning up routing: %v", err) + } else { + log.Info("Routing cleanup complete") + } } + m.ctx = nil } @@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } func isPrefixSupported(prefix netip.Prefix) bool { - switch runtime.GOOS { - case "linux", "windows", "darwin": - return true + if !nbnet.CustomRoutingDisabled() { + switch runtime.GOOS { + case "linux", "windows", "darwin": + return true + } } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 2dfde31a222..a0f55131df4 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -36,7 +36,7 @@ const ( var ErrTableIDExists = errors.New("ID exists with different name") var routeManager = &RouteManager{} -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() type ruleParams struct { priority int diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 2235c5d2bdf..22d3273762b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { } // Set the fwmark on the socket. - err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) + err = nbnet.SetSocketOpt(fd) if err != nil { return nil, fmt.Errorf("setting fwmark failed: %w", err) } diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 9fe987cee21..67bfb716d0f 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,8 +10,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := nbnet.NetbirdFwmark + fwmark := getFwmark() config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 24dfadf1408..c15bc1448fc 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } func getFwmark() int { - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() { return nbnet.NetbirdFwmark } return 0 diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 4eda710ac40..1e217da1369 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -49,6 +49,10 @@ func RemoveDialerHooks() { // DialContext wraps the net.Dialer's DialContext method to use the custom connection func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if CustomRoutingDisabled() { + return d.Dialer.DialContext(ctx, network, address) + } + var resolver *net.Resolver if d.Resolver != nil { resolver = d.Resolver @@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r } func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + dialer := NewDialer() dialer.LocalAddr = laddr @@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { } func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + dialer := NewDialer() dialer.LocalAddr = laddr diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index 451279e9d25..7847a29c737 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -8,6 +8,7 @@ import ( "net" "sync" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -52,6 +53,10 @@ func RemoveListenerHooks() { // ListenPacket listens on the network address and returns a PacketConn // which includes support for write hooks. func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) if err != nil { return nil, fmt.Errorf("listen packet: %w", err) @@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { // ListenUDP listens on the network address and returns a transport.UDPConn // which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) if err != nil { return nil, fmt.Errorf("listen UDP: %w", err) diff --git a/util/net/net.go b/util/net/net.go index 9ea7ae80340..3856911b1b7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,10 +1,16 @@ package net -import "github.com/google/uuid" +import ( + "os" + + "github.com/google/uuid" +) const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard NetbirdFwmark = 0x1BD00 + + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. @@ -15,3 +21,7 @@ type ConnectionID string func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } + +func CustomRoutingDisabled() bool { + return os.Getenv(envDisableCustomRouting) == "true" +} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 82141750029..954545eb556 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + setErr = SetSocketOpt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } + +func SetSocketOpt(fd int) error { + if CustomRoutingDisabled() { + return nil + } + + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) +} From 5ea24ba56e219d08fa9a5e444e0c7dfb9dd4ec50 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 12 Apr 2024 17:53:07 +0200 Subject: [PATCH 50/89] Add sysctl opts to prevent reverse path filtering from dropping fwmark packets (#1839) --- .../internal/routemanager/systemops_linux.go | 131 +++++++++++++++--- .../internal/routemanager/systemops_test.go | 2 +- 2 files changed, 115 insertions(+), 18 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index a0f55131df4..d1302b39cc6 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -9,6 +9,8 @@ import ( "net" "net/netip" "os" + "strconv" + "strings" "syscall" "github.com/hashicorp/go-multierror" @@ -30,14 +32,26 @@ const ( rtTablesPath = "/etc/iproute2/rt_tables" // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. - ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" + ipv4ForwardingPath = "net.ipv4.ip_forward" + + rpFilterPath = "net.ipv4.conf.all.rp_filter" + rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter" + srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" ) var ErrTableIDExists = errors.New("ID exists with different name") var routeManager = &RouteManager{} + +// originalSysctl stores the original sysctl values before they are modified +var originalSysctl map[string]int + +// determines whether to use the legacy routing setup var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() +// sysctlFailed is used as an indicator to emit a warning when default routes are configured +var sysctlFailed bool + type ruleParams struct { priority int fwmark int @@ -77,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before log.Errorf("Error adding routing table name: %v", err) } + originalValues, err := setupSysctl(wgIface) + if err != nil { + log.Errorf("Error setting up sysctl: %v", err) + sysctlFailed = true + } + originalSysctl = originalValues + defer func() { if err != nil { if cleanErr := cleanupRouting(); cleanErr != nil { @@ -124,6 +145,12 @@ func cleanupRouting() error { } } + if err := cleanupSysctl(originalSysctl); err != nil { + result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err)) + } + originalSysctl = nil + sysctlFailed = false + return result.ErrorOrNil() } @@ -140,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error { return genericAddVPNRoute(prefix, intf) } + if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) { + log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)") + } + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support @@ -332,22 +363,8 @@ func flushRoutes(tableID, family int) error { } func enableIPForwarding() error { - bytes, err := os.ReadFile(ipv4ForwardingPath) - if err != nil { - return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) - } - - // check if it is already enabled - // see more: https://github.com/netbirdio/netbird/issues/872 - if len(bytes) > 0 && bytes[0] == 49 { - return nil - } - - //nolint:gosec - if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { - return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) - } - return nil + _, err := setSysctl(ipv4ForwardingPath, 1, false) + return err } // entryExists checks if the specified ID or name already exists in the rt_tables file @@ -475,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int { } return netlink.FAMILY_V6 } + +// setupSysctl configures sysctl settings for RP filtering and source validation. +func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) { + keys := map[string]int{} + var result *multierror.Error + + oldVal, err := setSysctl(srcValidMarkPath, 1, false) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[srcValidMarkPath] = oldVal + } + + oldVal, err = setSysctl(rpFilterPath, 2, true) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[rpFilterPath] = oldVal + } + + interfaces, err := net.Interfaces() + if err != nil { + result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err)) + } + + for _, intf := range interfaces { + if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() { + continue + } + + i := fmt.Sprintf(rpFilterInterfacePath, intf.Name) + oldVal, err := setSysctl(i, 2, true) + if err != nil { + result = multierror.Append(result, err) + } else { + keys[i] = oldVal + } + } + + return keys, result.ErrorOrNil() +} + +// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1 +func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) { + path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/")) + currentValue, err := os.ReadFile(path) + if err != nil { + return -1, fmt.Errorf("read sysctl %s: %w", key, err) + } + + currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue))) + if err != nil && len(currentValue) > 0 { + return -1, fmt.Errorf("convert current desiredValue to int: %w", err) + } + + if currentV == desiredValue || onlyIfOne && currentV != 1 { + return currentV, nil + } + + //nolint:gosec + if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil { + return currentV, fmt.Errorf("write sysctl %s: %w", key, err) + } + log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue) + + return currentV, nil +} + +func cleanupSysctl(originalSettings map[string]int) error { + var result *multierror.Error + + for key, value := range originalSettings { + _, err := setSysctl(key, value, false) + if err != nil { + result = multierror.Append(result, err) + } + } + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 97386f19a1a..9f906c06fbe 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) + _, _, err = setupRouting(nil, wgInterface) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) From 5204d0781103a45669b5e29aed0e1289412f69cf Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Apr 2024 12:08:38 +0200 Subject: [PATCH 51/89] Pass integrated validator for API (#1814) Pass integrated validator for API handler --- go.mod | 2 +- go.sum | 4 ++-- management/cmd/management.go | 2 +- management/server/account.go | 1 + management/server/http/handler.go | 5 +++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 29a1570c896..1d5ec92532e 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index b488a42a42a..97b935d9cf0 100644 --- a/go.sum +++ b/go.sum @@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= diff --git a/management/cmd/management.go b/management/cmd/management.go index 23d9c195cd3..3669358029a 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -251,7 +251,7 @@ var ( ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg) + httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 099369fc2af..c3ba0c86c0d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut } func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { + log.Debugf("validated peers has been invalidated for account %s", accountID) updatedAccount, err := am.Store.GetAccount(accountID) if err != nil { log.Errorf("failed to get account %s: %v", accountID, err) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index bdbeba3464f..4405d295c5c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,7 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -38,7 +39,7 @@ type emptyObject struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa AuthCfg: authCfg, } - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } From e0de86d6c9e7082919010bd60df959c9e872c486 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 15 Apr 2024 14:15:46 +0200 Subject: [PATCH 52/89] Use fixed activity codes (#1846) * Add duplicate constants check --- .github/workflows/golangci-lint.yml | 4 + management/server/activity/codes.go | 127 ++++++++++++++-------------- 2 files changed, 68 insertions(+), 63 deletions(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 13228250d59..50cb4e2afaf 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -33,6 +33,10 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v3 + - name: Check for duplicate constants + if: matrix.os == 'ubuntu-latest' + run: | + ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . - name: Install Go uses: actions/setup-go@v4 with: diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index e179fd14d38..4ee57f1817c 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -11,133 +11,134 @@ type Code struct { Code string } +// Existing consts must not be changed, as this will break the compatibility with the existing data const ( // PeerAddedByUser indicates that a user added a new peer to the system - PeerAddedByUser Activity = iota + PeerAddedByUser Activity = 0 // PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key - PeerAddedWithSetupKey + PeerAddedWithSetupKey Activity = 1 // UserJoined indicates that a new user joined the account - UserJoined + UserJoined Activity = 2 // UserInvited indicates that a new user was invited to join the account - UserInvited + UserInvited Activity = 3 // AccountCreated indicates that a new account has been created - AccountCreated + AccountCreated Activity = 4 // PeerRemovedByUser indicates that a user removed a peer from the system - PeerRemovedByUser + PeerRemovedByUser Activity = 5 // RuleAdded indicates that a user added a new rule - RuleAdded + RuleAdded Activity = 6 // RuleUpdated indicates that a user updated a rule - RuleUpdated + RuleUpdated Activity = 7 // RuleRemoved indicates that a user removed a rule - RuleRemoved + RuleRemoved Activity = 8 // PolicyAdded indicates that a user added a new policy - PolicyAdded + PolicyAdded Activity = 9 // PolicyUpdated indicates that a user updated a policy - PolicyUpdated + PolicyUpdated Activity = 10 // PolicyRemoved indicates that a user removed a policy - PolicyRemoved + PolicyRemoved Activity = 11 // SetupKeyCreated indicates that a user created a new setup key - SetupKeyCreated + SetupKeyCreated Activity = 12 // SetupKeyUpdated indicates that a user updated a setup key - SetupKeyUpdated + SetupKeyUpdated Activity = 13 // SetupKeyRevoked indicates that a user revoked a setup key - SetupKeyRevoked + SetupKeyRevoked Activity = 14 // SetupKeyOverused indicates that setup key usage exhausted - SetupKeyOverused + SetupKeyOverused Activity = 15 // GroupCreated indicates that a user created a group - GroupCreated + GroupCreated Activity = 16 // GroupUpdated indicates that a user updated a group - GroupUpdated + GroupUpdated Activity = 17 // GroupAddedToPeer indicates that a user added group to a peer - GroupAddedToPeer + GroupAddedToPeer Activity = 18 // GroupRemovedFromPeer indicates that a user removed peer group - GroupRemovedFromPeer + GroupRemovedFromPeer Activity = 19 // GroupAddedToUser indicates that a user added group to a user - GroupAddedToUser + GroupAddedToUser Activity = 20 // GroupRemovedFromUser indicates that a user removed a group from a user - GroupRemovedFromUser + GroupRemovedFromUser Activity = 21 // UserRoleUpdated indicates that a user changed the role of a user - UserRoleUpdated + UserRoleUpdated Activity = 22 // GroupAddedToSetupKey indicates that a user added group to a setup key - GroupAddedToSetupKey + GroupAddedToSetupKey Activity = 23 // GroupRemovedFromSetupKey indicates that a user removed a group from a setup key - GroupRemovedFromSetupKey + GroupRemovedFromSetupKey Activity = 24 // GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups - GroupAddedToDisabledManagementGroups + GroupAddedToDisabledManagementGroups Activity = 25 // GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups - GroupRemovedFromDisabledManagementGroups + GroupRemovedFromDisabledManagementGroups Activity = 26 // RouteCreated indicates that a user created a route - RouteCreated + RouteCreated Activity = 27 // RouteRemoved indicates that a user deleted a route - RouteRemoved + RouteRemoved Activity = 28 // RouteUpdated indicates that a user updated a route - RouteUpdated + RouteUpdated Activity = 29 // PeerSSHEnabled indicates that a user enabled SSH server on a peer - PeerSSHEnabled + PeerSSHEnabled Activity = 30 // PeerSSHDisabled indicates that a user disabled SSH server on a peer - PeerSSHDisabled + PeerSSHDisabled Activity = 31 // PeerRenamed indicates that a user renamed a peer - PeerRenamed + PeerRenamed Activity = 32 // PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer - PeerLoginExpirationEnabled + PeerLoginExpirationEnabled Activity = 33 // PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer - PeerLoginExpirationDisabled + PeerLoginExpirationDisabled Activity = 34 // NameserverGroupCreated indicates that a user created a nameservers group - NameserverGroupCreated + NameserverGroupCreated Activity = 35 // NameserverGroupDeleted indicates that a user deleted a nameservers group - NameserverGroupDeleted + NameserverGroupDeleted Activity = 36 // NameserverGroupUpdated indicates that a user updated a nameservers group - NameserverGroupUpdated + NameserverGroupUpdated Activity = 37 // AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account - AccountPeerLoginExpirationEnabled + AccountPeerLoginExpirationEnabled Activity = 38 // AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account - AccountPeerLoginExpirationDisabled + AccountPeerLoginExpirationDisabled Activity = 39 // AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account - AccountPeerLoginExpirationDurationUpdated + AccountPeerLoginExpirationDurationUpdated Activity = 40 // PersonalAccessTokenCreated indicates that a user created a personal access token - PersonalAccessTokenCreated + PersonalAccessTokenCreated Activity = 41 // PersonalAccessTokenDeleted indicates that a user deleted a personal access token - PersonalAccessTokenDeleted + PersonalAccessTokenDeleted Activity = 42 // ServiceUserCreated indicates that a user created a service user - ServiceUserCreated + ServiceUserCreated Activity = 43 // ServiceUserDeleted indicates that a user deleted a service user - ServiceUserDeleted + ServiceUserDeleted Activity = 44 // UserBlocked indicates that a user blocked another user - UserBlocked + UserBlocked Activity = 45 // UserUnblocked indicates that a user unblocked another user - UserUnblocked + UserUnblocked Activity = 46 // UserDeleted indicates that a user deleted another user - UserDeleted + UserDeleted Activity = 47 // GroupDeleted indicates that a user deleted group - GroupDeleted + GroupDeleted Activity = 48 // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login - UserLoggedInPeer + UserLoggedInPeer Activity = 49 // PeerLoginExpired indicates that the user peer login has been expired and peer disconnected - PeerLoginExpired + PeerLoginExpired Activity = 50 // DashboardLogin indicates that the user logged in to the dashboard - DashboardLogin + DashboardLogin Activity = 51 // IntegrationCreated indicates that the user created an integration - IntegrationCreated + IntegrationCreated Activity = 52 // IntegrationUpdated indicates that the user updated an integration - IntegrationUpdated + IntegrationUpdated Activity = 53 // IntegrationDeleted indicates that the user deleted an integration - IntegrationDeleted + IntegrationDeleted Activity = 54 // AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account - AccountPeerApprovalEnabled + AccountPeerApprovalEnabled Activity = 55 // AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account - AccountPeerApprovalDisabled + AccountPeerApprovalDisabled Activity = 56 // PeerApproved indicates that the peer has been approved - PeerApproved + PeerApproved Activity = 57 // PeerApprovalRevoked indicates that the peer approval has been revoked - PeerApprovalRevoked + PeerApprovalRevoked Activity = 58 // TransferredOwnerRole indicates that the user transferred the owner role of the account - TransferredOwnerRole + TransferredOwnerRole Activity = 59 // PostureCheckCreated indicates that the user created a posture check - PostureCheckCreated + PostureCheckCreated Activity = 60 // PostureCheckUpdated indicates that the user updated a posture check - PostureCheckUpdated + PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check - PostureCheckDeleted + PostureCheckDeleted Activity = 62 ) var activityMap = map[Activity]Code{ From e3b76448f3dfc3ad597423cd5e034e76307ffbe8 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 16 Apr 2024 14:01:59 +0200 Subject: [PATCH 53/89] Fix ICE endpoint remote port in status command (#1851) --- client/internal/peer/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 9e7ee695932..a0da82b8d21 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -466,7 +466,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem LocalIceCandidateType: pair.Local.Type().String(), RemoteIceCandidateType: pair.Remote.Type().String(), LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), - RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()), + RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), Direct: !isRelayCandidate(pair.Local), RosenpassEnabled: rosenpassEnabled, Mux: new(sync.RWMutex), From 77488ad11a8394741b43a5a22346459484e300fa Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 18 Apr 2024 18:14:21 +0200 Subject: [PATCH 54/89] Migrate serializer:gob fields to serializer:json (#1855) --- management/server/migration/migration.go | 101 ++++++++++++++++++ management/server/migration/migration_test.go | 91 ++++++++++++++++ management/server/network.go | 2 +- management/server/sqlite_store.go | 39 ++++++- management/server/sqlite_store_test.go | 57 ++++++++++ route/route.go | 4 +- 6 files changed, 290 insertions(+), 4 deletions(-) create mode 100644 management/server/migration/migration.go create mode 100644 management/server/migration/migration_test.go diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go new file mode 100644 index 00000000000..ba31ee45990 --- /dev/null +++ b/management/server/migration/migration.go @@ -0,0 +1,101 @@ +package migration + +import ( + "encoding/gob" + "encoding/json" + "errors" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding. +// T is the type of the model that contains the field to be migrated. +// S is the type of the field to be migrated. +func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error { + + oldColumnName := fieldName + newColumnName := fieldName + "_tmp" + + var model T + + if !db.Migrator().HasTable(&model) { + log.Debugf("Table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + var item string + if err := db.Model(model).Select(oldColumnName).First(&item).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Debugf("No records in table %s, no migration needed", tableName) + return nil + } + return fmt.Errorf("fetch first record: %w", err) + } + + var js json.RawMessage + var syntaxError *json.SyntaxError + err = json.Unmarshal([]byte(item), &js) + if err == nil || !errors.As(err, &syntaxError) { + log.Debugf("No migration needed for %s, %s", tableName, fieldName) + return nil + } + + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil { + return fmt.Errorf("add column %s: %w", newColumnName, err) + } + + var rows []map[string]any + if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil { + return fmt.Errorf("find rows: %w", err) + } + + for _, row := range rows { + var field S + + str, ok := row[oldColumnName].(string) + if !ok { + return fmt.Errorf("type assertion failed") + } + reader := strings.NewReader(str) + + if err := gob.NewDecoder(reader).Decode(&field); err != nil { + return fmt.Errorf("gob decode error: %w", err) + } + + jsonValue, err := json.Marshal(field) + if err != nil { + return fmt.Errorf("re-encode to JSON: %w", err) + } + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil { + return fmt.Errorf("update row: %w", err) + } + } + + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil { + return fmt.Errorf("drop column %s: %w", oldColumnName, err) + } + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil { + return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err) + } + + return nil + }); err != nil { + return err + } + + log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName) + + return nil +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go new file mode 100644 index 00000000000..4bef41b86a4 --- /dev/null +++ b/management/server/migration/migration_test.go @@ -0,0 +1,91 @@ +package migration_test + +import ( + "encoding/gob" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/migration" + "github.com/netbirdio/netbird/route" +) + +func setupDatabase(t *testing.T) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ + PrepareStmt: true, + }) + + require.NoError(t, err, "Failed to open database") + return db +} + +func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { + db := setupDatabase(t) + err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + require.NoError(t, err, "Migration should not fail for an empty database") +} + +func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.Account{}, &route.Route{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + _, ipnet, err := net.ParseCIDR("10.0.0.0/24") + require.NoError(t, err, "Failed to parse CIDR") + + type network struct { + server.Network + Net net.IPNet `gorm:"serializer:gob"` + } + + type account struct { + server.Account + Network *network `gorm:"embedded;embeddedPrefix:network_"` + } + + err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error + require.NoError(t, err, "Failed to insert Gob data") + + var gobStr string + err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error + assert.NoError(t, err, "Failed to fetch Gob data") + + err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) + require.NoError(t, err, "Failed to decode Gob data") + + err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + require.NoError(t, err, "Migration should not fail with Gob data") + + var jsonStr string + db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated") +} + +func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.Account{}, &route.Route{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + _, ipnet, err := net.ParseCIDR("10.0.0.0/24") + require.NoError(t, err, "Failed to parse CIDR") + + err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error + require.NoError(t, err, "Failed to insert JSON data") + + err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + require.NoError(t, err, "Migration should not fail with JSON data") + + var jsonStr string + db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged") +} diff --git a/management/server/network.go b/management/server/network.go index ffe098c964c..0e7d753a73d 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -36,7 +36,7 @@ type NetworkMap struct { type Network struct { Identifier string `json:"id"` - Net net.IPNet `gorm:"serializer:gob"` + Net net.IPNet `gorm:"serializer:json"` Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index e6a9c846726..06369696091 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -3,6 +3,8 @@ package server import ( "errors" "fmt" + "net" + "net/netip" "path/filepath" "runtime" "strings" @@ -18,6 +20,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -40,6 +43,8 @@ type installation struct { InstallationIDValue string } +type migrationFunc func(*gorm.DB) error + // NewSqliteStore restores a store from the file located in the datadir func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) { storeStr := "store.db?cache=shared" @@ -64,13 +69,16 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, conns := runtime.NumCPU() sql.SetMaxOpenConns(conns) // TODO: make it configurable + if err := migrate(db); err != nil { + return nil, fmt.Errorf("migrate: %w", err) + } err = db.AutoMigrate( &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, ) if err != nil { - return nil, err + return nil, fmt.Errorf("auto migrate: %w", err) } return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil @@ -542,3 +550,32 @@ func (s *SqliteStore) Close() error { func (s *SqliteStore) GetStoreEngine() StoreEngine { return SqliteStoreEngine } + +// migrate migrates the SQLite database to the latest schema +func migrate(db *gorm.DB) error { + migrations := getMigrations() + + for _, m := range migrations { + if err := m(db); err != nil { + return err + } + } + + return nil +} + +func getMigrations() []migrationFunc { + return []migrationFunc{ + func(db *gorm.DB) error { + return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net") + }, + + func(db *gorm.DB) error { + return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network") + }, + + func(db *gorm.DB) error { + return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups") + }, + } +} diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index e43a0cd9a86..31f9b8a5b32 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -3,6 +3,7 @@ package server import ( "fmt" "net" + "net/netip" "path/filepath" "runtime" "testing" @@ -12,6 +13,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + route2 "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -349,6 +352,60 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } +func TestMigrate(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStore(t) + + err := migrate(store.db) + require.NoError(t, err, "Migration should not fail on empty db") + + _, ipnet, err := net.ParseCIDR("10.0.0.0/24") + require.NoError(t, err, "Failed to parse CIDR") + + type network struct { + Network + Net net.IPNet `gorm:"serializer:gob"` + } + + type account struct { + Account + Network *network `gorm:"embedded;embeddedPrefix:network_"` + } + + act := &account{ + Network: &network{ + Net: *ipnet, + }, + } + + err = store.db.Save(act).Error + require.NoError(t, err, "Failed to insert Gob data") + + type route struct { + route2.Route + Network netip.Prefix `gorm:"serializer:gob"` + PeerGroups []string `gorm:"serializer:gob"` + } + + prefix := netip.MustParsePrefix("11.0.0.0/24") + rt := &route{ + Network: prefix, + PeerGroups: []string{"group1", "group2"}, + } + + err = store.db.Save(rt).Error + require.NoError(t, err, "Failed to insert Gob data") + + err = migrate(store.db) + require.NoError(t, err, "Migration should not fail on gob populated db") + + err = migrate(store.db) + require.NoError(t, err, "Migration should not fail on migrated db") +} + func newSqliteStore(t *testing.T) *SqliteStore { t.Helper() diff --git a/route/route.go b/route/route.go index 194e0c80d0f..7e8a8377c41 100644 --- a/route/route.go +++ b/route/route.go @@ -68,11 +68,11 @@ type Route struct { ID string `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` - Network netip.Prefix `gorm:"serializer:gob"` + Network netip.Prefix `gorm:"serializer:json"` NetID string Description string Peer string - PeerGroups []string `gorm:"serializer:gob"` + PeerGroups []string `gorm:"serializer:json"` NetworkType NetworkType Masquerade bool Metric int From b74078fd95ae9e0b7749f6907addd6fc73445d9c Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sat, 20 Apr 2024 22:04:20 +0200 Subject: [PATCH 55/89] Use a better way to insert data in batches (#1874) --- management/server/sqlite_store.go | 8 +- management/server/sqlite_store_test.go | 148 +++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 3 deletions(-) diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index 06369696091..bfde82a6de7 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -55,8 +55,9 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, file := filepath.Join(dataDir, storeStr) db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - PrepareStmt: true, + Logger: logger.Default.LogMode(logger.Silent), + CreateBatchSize: 400, + PrepareStmt: true, }) if err != nil { return nil, err @@ -196,7 +197,8 @@ func (s *SqliteStore) SaveAccount(account *Account) error { result = tx. Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}).Create(account) + Clauses(clause.OnConflict{UpdateAll: true}). + Create(account) if result.Error != nil { return result.Error } diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index 31f9b8a5b32..8a1bcd10aeb 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -2,6 +2,9 @@ package server import ( "fmt" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "math/rand" "net" "net/netip" "path/filepath" @@ -33,6 +36,151 @@ func TestSqlite_NewStore(t *testing.T) { } } +func TestSqlite_SaveAccount_Large(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStore(t) + + account := newAccountWithId("account_id", "testuser", "") + groupALL, err := account.GetGroupAll() + if err != nil { + t.Fatal(err) + } + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + const numPerAccount = 2000 + for n := 0; n < numPerAccount; n++ { + netIP := randomIPv4() + peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + + peer := &nbpeer.Peer{ + ID: peerID, + Key: peerID, + SetupKey: "", + IP: netIP, + Name: peerID, + DNSLabel: peerID, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + account.Peers[peerID] = peer + group, _ := account.GetGroupAll() + group.Peers = append(group.Peers, peerID) + user := &User{ + Id: fmt.Sprintf("%s-user-%d", account.Id, n), + AccountID: account.Id, + } + account.Users[user.Id] = user + route := &route2.Route{ + ID: fmt.Sprintf("network-id-%d", n), + Description: "base route", + NetID: fmt.Sprintf("network-id-%d", n), + Network: netip.MustParsePrefix(netIP.String() + "/24"), + NetworkType: route2.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + Groups: []string{groupALL.ID}, + } + account.Routes[route.ID] = route + + group = &nbgroup.Group{ + ID: fmt.Sprintf("group-id-%d", n), + AccountID: account.Id, + Name: fmt.Sprintf("group-id-%d", n), + Issued: "api", + Peers: nil, + } + account.Groups[group.ID] = group + + nameserver := &nbdns.NameServerGroup{ + ID: fmt.Sprintf("nameserver-id-%d", n), + AccountID: account.Id, + Name: fmt.Sprintf("nameserver-id-%d", n), + Description: "", + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, + Groups: []string{group.ID}, + Primary: false, + Domains: nil, + Enabled: false, + SearchDomainsEnabled: false, + } + account.NameServerGroups[nameserver.ID] = nameserver + + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + } + + err = store.SaveAccount(account) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 1 { + t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") + } + + a, err := store.GetAccount(account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } + + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } + + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } + + if a != nil && len(a.Peers) != numPerAccount { + t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d", + numPerAccount, len(a.Peers)) + return + } + + if a != nil && len(a.Users) != numPerAccount+1 { + t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d", + numPerAccount+1, len(a.Users)) + return + } + + if a != nil && len(a.Routes) != numPerAccount { + t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d", + numPerAccount, len(a.Routes)) + return + } + + if a != nil && len(a.NameServerGroups) != numPerAccount { + t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d", + numPerAccount, len(a.NameServerGroups)) + return + } + + if a != nil && len(a.NameServerGroups) != numPerAccount { + t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d", + numPerAccount, len(a.NameServerGroups)) + return + } + + if a != nil && len(a.SetupKeys) != numPerAccount+1 { + t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d", + numPerAccount+1, len(a.SetupKeys)) + return + } +} + +func randomIPv4() net.IP { + rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, 4) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return net.IP(b) +} + func TestSqlite_SaveAccount(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") From 3c3111ad01bf14dd32209afaf8881bee982d1d9a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Apr 2024 10:14:07 +0200 Subject: [PATCH 56/89] Copy client binary to a directory in path (#1842) --- client/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/Dockerfile b/client/Dockerfile index 327d39f94a1..7f4060f3d16 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,5 +1,5 @@ FROM alpine:3.18.5 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true -ENTRYPOINT [ "/go/bin/netbird","up"] -COPY netbird /go/bin/netbird \ No newline at end of file +ENTRYPOINT [ "/usr/local/bin/netbird","up"] +COPY netbird /usr/local/bin/netbird \ No newline at end of file From 9e01155d2e03bf62f677866ab2b0a9a3b5ac0269 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Mon, 22 Apr 2024 11:00:52 +0200 Subject: [PATCH 57/89] Add new intro image --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d2a2bd6b9af..5be1826b475 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,8 @@ ### Open-Source Network Security in a Single Platform -![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f) + +![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) ### Key features From a80c8b017689f363ffbadf5db843267e3392455e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Apr 2024 11:10:27 +0200 Subject: [PATCH 58/89] Redeem invite only when incoming user was invited (#1861) checks for users with pending invite status in the cache that already logged in and refresh the cache --- management/server/account.go | 36 +++++++++++++------ management/server/http/users_handler.go | 2 -- management/server/jwtclaims/claims.go | 1 + management/server/jwtclaims/extractor.go | 6 ++++ management/server/jwtclaims/extractor_test.go | 6 ++++ management/server/user.go | 6 ++-- 6 files changed, 41 insertions(+), 16 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c3ba0c86c0d..23f03015ecb 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -46,6 +46,8 @@ const ( DefaultPeerLoginExpiration = 24 * time.Hour ) +type userLoggedInOnce bool + type ExternalCacheManager cache.CacheInterface[*idp.UserData] func cacheEntryExpiration() time.Duration { @@ -1092,13 +1094,15 @@ func (am *DefaultAccountManager) warmupIDPCache() error { } delete(userData, idp.UnsetAccountID) + rcvdUsers := 0 for accountID, users := range userData { + rcvdUsers += len(users) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) if err != nil { return err } } - log.Infof("warmed up IDP cache with %d entries", len(userData)) + log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData)) return nil } @@ -1263,7 +1267,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { - users := make(map[string]struct{}, len(account.Users)) + users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { if user.IsServiceUser { @@ -1272,7 +1276,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou if user.Issued == UserIssuedIntegration { continue } - users[user.Id] = struct{}{} + users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } log.Debugf("looking up user %s of account %s in cache", userID, account.Id) userData, err := am.lookupCache(users, account.Id) @@ -1345,22 +1349,30 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo } } -func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) { +func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { data, err := am.getAccountFromCache(accountID, false) if err != nil { return nil, err } - userDataMap := make(map[string]struct{}) + userDataMap := make(map[string]*idp.UserData, len(data)) for _, datum := range data { - userDataMap[datum.ID] = struct{}{} + userDataMap[datum.ID] = datum } + mustRefreshInviteStatus := false + // the accountUsers ID list of non integration users from store, we check if cache has all of them // as result of for loop knownUsersCount will have number of users are not presented in the cashed knownUsersCount := len(accountUsers) - for user := range accountUsers { - if _, ok := userDataMap[user]; ok { + for user, loggedInOnce := range accountUsers { + if datum, ok := userDataMap[user]; ok { + // check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed + if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple + mustRefreshInviteStatus = true + log.Infof("user %s has a pending invite and has logged in once, forcing cache refresh", user) + break + } knownUsersCount-- continue } @@ -1368,8 +1380,10 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a } // if we know users that are not yet in cache more likely cache is outdated - if knownUsersCount > 0 { - log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount) + if knownUsersCount > 0 || mustRefreshInviteStatus { + if !mustRefreshInviteStatus { + log.Infof("reloading cache with IDP manager. Users unknown to the cache: %d", knownUsersCount) + } // reload cache once avoiding loops data, err = am.refreshCache(accountID) if err != nil { @@ -1649,7 +1663,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) } - if !user.IsServiceUser { + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(account, claims.UserId) if err != nil { return nil, nil, err diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index ed8a3f5438c..531822668e6 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { serviceUser := r.URL.Query().Get("service_user") - log.Debugf("UserCount: %v", len(data)) - users := make([]*api.User, 0) for _, r := range data { if r.NonDeletable { diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go index 1fa00b2fecb..2527acbe329 100644 --- a/management/server/jwtclaims/claims.go +++ b/management/server/jwtclaims/claims.go @@ -13,6 +13,7 @@ type AuthorizationClaims struct { Domain string DomainCategory string LastLogin time.Time + Invited bool Raw jwt.MapClaims } diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 42a41f14020..c441650e97f 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -20,6 +20,8 @@ const ( UserIDClaim = "sub" // LastLoginSuffix claim for the last login LastLoginSuffix = "nb_last_login" + // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation + Invited = "nb_invited" ) // ExtractClaims Extract function type @@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { if ok { jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string)) } + invitedBool, ok := claims[c.authAudience+Invited] + if ok { + jwtClaims.Invited = invitedBool.(bool) + } return jwtClaims } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index e9316b194c0..eccd7c9e7c9 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st if claims.LastLogin != (time.Time{}) { claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout) } + + if claims.Invited { + claimMaps[audience+Invited] = true + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") @@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { AccountId: "testAcc", LastLogin: lastLogin, DomainCategory: "public", + Invited: true, Raw: jwt.MapClaims{ "https://login/wt_account_domain": "test.com", "https://login/wt_account_domain_category": "public", "https://login/wt_account_id": "testAcc", "https://login/nb_last_login": lastLogin.Format(layout), "sub": "test", + "https://login/" + Invited: true, }, }, testingFunc: require.EqualValues, diff --git a/management/server/user.go b/management/server/user.go index b955c405895..4ae13d1012f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( queriedUsers := make([]*idp.UserData, 0) if !isNil(am.idpManager) { - users := make(map[string]struct{}, len(account.Users)) + users := make(map[string]userLoggedInOnce, len(account.Users)) usersFromIntegration := make([]*idp.UserData, 0) for _, user := range account.Users { if user.Issued == UserIssuedIntegration { @@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { log.Infof("Get ExternalCache for key: %s, error: %s", key, err) - users[user.Id] = struct{}{} + users[user.Id] = true continue } usersFromIntegration = append(usersFromIntegration, info) continue } if !user.IsServiceUser { - users[user.Id] = struct{}{} + users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } } queriedUsers, err = am.lookupCache(users, accountID) From 4c5e987e027518dcbe4693c104cc3865801e90c9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Apr 2024 11:57:38 +0200 Subject: [PATCH 59/89] Add support for GUI app to display error (#1844) --- client/ui/client_ui.go | 63 ++++++++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index e242a26db66..aec2c8fac81 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -58,12 +58,19 @@ func main() { var showSettings bool flag.BoolVar(&showSettings, "settings", false, "run settings windows") + var errorMSG string + flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") flag.Parse() a := app.NewWithID("NetBird") a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG)) + if errorMSG != "" { + showErrorMSG(errorMSG) + return + } + client := newServiceClient(daemonAddr, a, showSettings) if showSettings { a.Run() @@ -209,6 +216,18 @@ func (s *serviceClient) showUIElements() { s.wSettings.Show() } +// showErrorMSG opens a fyne app window to display the supplied message +func showErrorMSG(msg string) { + app := app.New() + w := app.NewWindow("NetBird Error") + content := widget.NewLabel(msg) + content.Wrapping = fyne.TextWrapWord + w.SetContent(content) + w.Resize(fyne.NewSize(400, 100)) + w.Show() + app.Run() +} + // getSettingsForm to embed it into settings window. func (s *serviceClient) getSettingsForm() *widget.Form { return &widget.Form{ @@ -504,16 +523,22 @@ func (s *serviceClient) onTrayReady() { case <-s.mAdminPanel.ClickedCh: err = open.Run(s.adminURL) case <-s.mUp.ClickedCh: + s.mUp.Disabled() go func() { + defer s.mUp.Enable() err := s.menuUpClick() if err != nil { + s.runSelfCommand("error-msg", err.Error()) return } }() case <-s.mDown.ClickedCh: + s.mDown.Disable() go func() { + defer s.mDown.Enable() err := s.menuDownClick() if err != nil { + s.runSelfCommand("error-msg", err.Error()) return } }() @@ -521,24 +546,8 @@ func (s *serviceClient) onTrayReady() { s.mSettings.Disable() go func() { defer s.mSettings.Enable() - proc, err := os.Executable() - if err != nil { - log.Errorf("show settings: %v", err) - return - } - - cmd := exec.Command(proc, "--settings=true") - out, err := cmd.CombinedOutput() - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { - log.Errorf("start settings UI: %v, %s", err, string(out)) - return - } - if len(out) != 0 { - log.Info("settings change:", string(out)) - } - - // update config in systray when settings windows closed - s.getSrvConfig() + defer s.getSrvConfig() + s.runSelfCommand("settings", "true") }() case <-s.mQuit.ClickedCh: systray.Quit() @@ -556,6 +565,24 @@ func (s *serviceClient) onTrayReady() { }() } +func (s *serviceClient) runSelfCommand(command, arg string) { + proc, err := os.Executable() + if err != nil { + log.Errorf("show %s failed with error: %v", command, err) + return + } + + cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg)) + out, err := cmd.CombinedOutput() + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { + log.Errorf("start %s UI: %v, %s", command, err, string(out)) + return + } + if len(out) != 0 { + log.Infof("command %s executed: %s", command, string(out)) + } +} + func normalizedVersion(version string) string { versionString := version if unicode.IsDigit(rune(versionString[0])) { From 012e624296dcff7a1a011321a4bfcd0075af67bf Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 23 Apr 2024 10:20:09 +0200 Subject: [PATCH 60/89] Fix DNS not found query response (#1877) for local queries, we should return NXDOMAIN instead of NOERROR Also, updated gomobile for Android and iOS builds --- .github/workflows/mobile-build-validation.yml | 6 +++--- client/internal/dns/local.go | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index 85296484229..e5a5ff485e4 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -38,7 +38,7 @@ jobs: - name: Setup NDK run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed - name: gomobile init run: gomobile init - name: build android netbird lib @@ -56,10 +56,10 @@ jobs: with: go-version: "1.21.x" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed - name: gomobile init run: gomobile init - name: build iOS netbird lib - run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK + run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK env: CGO_ENABLED: 0 \ No newline at end of file diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index ba4ae42d930..6a459794b96 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -31,6 +31,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { response := d.lookupRecord(r) if response != nil { replyMessage.Answer = append(replyMessage.Answer, response) + } else { + replyMessage.Rcode = dns.RcodeNameError } err := w.WriteMsg(replyMessage) From 3477108ce7f71626db6efd7d4638b1704f7615c1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Apr 2024 12:48:25 +0200 Subject: [PATCH 61/89] Bump golang.org/x/net from 0.20.0 to 0.23.0 (#1867) Bumps [golang.org/x/net](https://github.com/golang/net) from 0.20.0 to 0.23.0. - [Commits](https://github.com/golang/net/compare/v0.20.0...v0.23.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 12 ++++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 1d5ec92532e..24af3c28de5 100644 --- a/go.mod +++ b/go.mod @@ -21,8 +21,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 - golang.org/x/crypto v0.18.0 - golang.org/x/sys v0.16.0 + golang.org/x/crypto v0.21.0 + golang.org/x/sys v0.18.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -82,10 +82,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 - golang.org/x/net v0.20.0 + golang.org/x/net v0.23.0 golang.org/x/oauth2 v0.8.0 golang.org/x/sync v0.3.0 - golang.org/x/term v0.16.0 + golang.org/x/term v0.18.0 google.golang.org/api v0.126.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/sqlite v1.5.3 diff --git a/go.sum b/go.sum index 97b935d9cf0..9b7b8952cbc 100644 --- a/go.sum +++ b/go.sum @@ -581,8 +581,9 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -669,8 +670,9 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -761,16 +763,18 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= From f51dc13f8c1073fea815bfa9356e4e6403a000ae Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 23 Apr 2024 14:42:53 +0200 Subject: [PATCH 62/89] Add route selection functionality for CLI and GUI (#1865) --- client/cmd/root.go | 7 + client/cmd/route.go | 144 +++++ client/cmd/ssh.go | 4 +- client/internal/connect.go | 17 +- client/internal/engine.go | 34 +- client/internal/engine_test.go | 13 +- client/internal/routemanager/manager.go | 71 ++- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 25 +- .../internal/routemanager/systemops_linux.go | 5 +- .../internal/routeselector/routeselector.go | 132 +++++ .../routeselector/routeselector_test.go | 275 +++++++++ client/proto/daemon.pb.go | 539 +++++++++++++++--- client/proto/daemon.proto | 31 + client/proto/daemon_grpc.pb.go | 114 ++++ client/server/route.go | 100 ++++ client/server/server.go | 40 +- client/server/server_test.go | 5 +- client/ui/client_ui.go | 33 +- client/ui/route.go | 203 +++++++ 20 files changed, 1650 insertions(+), 146 deletions(-) create mode 100644 client/cmd/route.go create mode 100644 client/internal/routeselector/routeselector.go create mode 100644 client/internal/routeselector/routeselector_test.go create mode 100644 client/server/route.go create mode 100644 client/ui/route.go diff --git a/client/cmd/root.go b/client/cmd/root.go index 9c4ad99dec0..ca143ffc2bb 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -119,6 +119,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") + rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(upCmd) rootCmd.AddCommand(downCmd) @@ -126,8 +127,14 @@ func init() { rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(sshCmd) + rootCmd.AddCommand(routesCmd) + serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service + + routesCmd.AddCommand(routesListCmd) + routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd) + upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, `Sets external IPs maps between local addresses and interfaces.`+ `You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+ diff --git a/client/cmd/route.go b/client/cmd/route.go new file mode 100644 index 00000000000..3d5d4b24722 --- /dev/null +++ b/client/cmd/route.go @@ -0,0 +1,144 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + "google.golang.org/grpc" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var appendFlag bool + +var routesCmd = &cobra.Command{ + Use: "routes", + Short: "Manage network routes", + Long: `Commands to list, select, or deselect network routes.`, +} + +var routesListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List routes", + Example: " netbird routes list", + Long: "List all available network routes.", + RunE: routesList, +} + +var routesSelectCmd = &cobra.Command{ + Use: "select route...|all", + Short: "Select routes", + Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.", + Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3", + Args: cobra.MinimumNArgs(1), + RunE: routesSelect, +} + +var routesDeselectCmd = &cobra.Command{ + Use: "deselect route...|all", + Short: "Deselect routes", + Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.", + Example: " netbird routes deselect all\n netbird routes deselect route1 route2", + Args: cobra.MinimumNArgs(1), + RunE: routesDeselect, +} + +func init() { + routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing") +} + +func routesList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd.Context()) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{}) + if err != nil { + return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message()) + } + + if len(resp.Routes) == 0 { + cmd.Println("No routes available.") + return nil + } + + cmd.Println("Available Routes:") + for _, route := range resp.Routes { + selectedStatus := "Not Selected" + if route.GetSelected() { + selectedStatus = "Selected" + } + cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) + } + + return nil +} + +func routesSelect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd.Context()) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectRoutesRequest{ + RouteIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } else if appendFlag { + req.Append = true + } + + if _, err := client.SelectRoutes(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message()) + } + + cmd.Println("Routes selected successfully.") + + return nil +} + +func routesDeselect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd.Context()) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectRoutesRequest{ + RouteIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } + + if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message()) + } + + cmd.Println("Routes deselected successfully.") + + return nil +} + +func getClient(ctx context.Context) (*grpc.ClientConn, error) { + conn, err := DialClientGRPCServer(ctx, daemonAddr) + if err != nil { + return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ + "If the daemon is not running please run: "+ + "\nnetbird service install \nnetbird service start\n", err) + } + + return conn, nil +} diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index dd9407738ba..81e6c255a17 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -24,7 +24,7 @@ var ( ) var sshCmd = &cobra.Command{ - Use: "ssh", + Use: "ssh [user@]host", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 { return errors.New("requires a host argument") @@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) if err != nil { cmd.Printf("Error: %v\n", err) cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + - "You can verify the connection by running:\n\n" + + "\nYou can verify the connection by running:\n\n" + " netbird status\n\n") return err } diff --git a/client/internal/connect.go b/client/internal/connect.go index 6b888c9cca8..c238fa31ccf 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -31,7 +31,7 @@ import ( // RunClient with main logic. func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error { - return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil) + return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil) } // RunClientWithProbes runs the client's main logic with probes attached @@ -43,8 +43,9 @@ func RunClientWithProbes( signalProbe *Probe, relayProbe *Probe, wgProbe *Probe, + engineChan chan<- *Engine, ) error { - return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe) + return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan) } // RunClientMobile with main logic on mobile system @@ -66,7 +67,7 @@ func RunClientMobile( HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, } - return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil) + return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil) } func RunClientiOS( @@ -82,7 +83,7 @@ func RunClientiOS( NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, } - return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil) + return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil) } func runClient( @@ -94,6 +95,7 @@ func runClient( signalProbe *Probe, relayProbe *Probe, wgProbe *Probe, + engineChan chan<- *Engine, ) error { defer func() { if r := recover(); r != nil { @@ -243,6 +245,9 @@ func runClient( log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } + if engineChan != nil { + engineChan <- engine + } log.Print("Netbird engine started, my IP is: ", peerConfig.Address) state.Set(StatusConnected) @@ -252,6 +257,10 @@ func runClient( backOff.Reset() + if engineChan != nil { + engineChan <- nil + } + err = engine.Stop() if err != nil { log.Errorf("failed stopping engine %v", err) diff --git a/client/internal/engine.go b/client/internal/engine.go index ba7074672c3..28e1f1b55d0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -111,6 +111,9 @@ type Engine struct { // TURNs is a list of STUN servers used by ICE TURNs []*stun.URI + // clientRoutes is the most recent list of clientRoutes received from the Management Service + clientRoutes map[string][]*route.Route + cancel context.CancelFunc ctx context.Context @@ -216,6 +219,8 @@ func (e *Engine) Stop() error { return err } + e.clientRoutes = nil + // very ugly but we want to remove peers from the WireGuard interface first before removing interface. // Removing peers happens in the conn.CLose() asynchronously time.Sleep(500 * time.Millisecond) @@ -695,11 +700,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } - err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) + + _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) if err != nil { - log.Errorf("failed to update routes, err: %v", err) + log.Errorf("failed to update clientRoutes, err: %v", err) } + e.clientRoutes = clientRoutes + protoDNSConfig := networkMap.GetDNSConfig() if protoDNSConfig == nil { protoDNSConfig = &mgmProto.DNSConfig{} @@ -1229,6 +1237,28 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } } +// GetClientRoutes returns the current routes from the route map +func (e *Engine) GetClientRoutes() map[string][]*route.Route { + return e.clientRoutes +} + +// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only +func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route { + routes := make(map[string][]*route.Route, len(e.clientRoutes)) + for id, v := range e.clientRoutes { + if i := strings.LastIndex(id, "-"); i != -1 { + id = id[:i] + } + routes[id] = v + } + return routes +} + +// GetRouteManager returns the route manager +func (e *Engine) GetRouteManager() routemanager.Manager { + return e.routeManager +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 309b2e7c6f9..f487cc71e72 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -577,10 +578,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { input.inputSerial = updateSerial input.inputRoutes = newRoutes - return testCase.inputErr + return nil, nil, testCase.inputErr }, } @@ -597,8 +598,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { err = engine.updateNetworkMap(testCase.networkMap) assert.NoError(t, err, "shouldn't return error") assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") - assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match") - assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match") + assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match") + assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match") }) } } @@ -742,8 +743,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { - return nil + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { + return nil, nil, nil }, } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0dfc0f7e008..57007c4a3a5 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -28,7 +29,9 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) // Manager is a route manager interface type Manager interface { Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) + TriggerSelection(map[string][]*route.Route) + GetRouteSelector() *routeselector.RouteSelector SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error @@ -41,6 +44,7 @@ type DefaultManager struct { stop context.CancelFunc mux sync.Mutex clientNetworks map[string]*clientNetwork + routeSelector *routeselector.RouteSelector serverRouter serverRouter statusRecorder *peer.Status wgInterface *iface.WGIface @@ -54,6 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), + routeSelector: routeselector.NewRouteSelector(), statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, @@ -117,28 +122,29 @@ func (m *DefaultManager) Stop() { } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return m.ctx.Err() + return nil, nil, m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes) + newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - m.updateClientNetworks(updateSerial, newClientRoutesIDMap) - m.notifier.onNewRoutes(newClientRoutesIDMap) + filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + m.updateClientNetworks(updateSerial, filteredClientRoutes) + m.notifier.onNewRoutes(filteredClientRoutes) if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return fmt.Errorf("update routes: %w", err) + return nil, nil, fmt.Errorf("update routes: %w", err) } } - return nil + return newServerRoutesMap, newClientRoutesIDMap, nil } } @@ -152,16 +158,51 @@ func (m *DefaultManager) InitialRouteRange() []string { return m.notifier.initialRouteRanges() } -func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { - // removing routes that do not exist as per the update from the Management service. +// GetRouteSelector returns the route selector +func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { + return m.routeSelector +} + +// GetClientRoutes returns the client routes +func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork { + return m.clientNetworks +} + +// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones +func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) { + m.mux.Lock() + defer m.mux.Unlock() + + networks = m.routeSelector.FilterSelected(networks) + m.stopObsoleteClients(networks) + + for id, routes := range networks { + if _, found := m.clientNetworks[id]; found { + // don't touch existing client network watchers + continue + } + + clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + m.clientNetworks[id] = clientNetworkWatcher + go clientNetworkWatcher.peersStateAndUpdateWatcher() + clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) + } +} + +// stopObsoleteClients stops the client network watcher for the networks that are not in the new list +func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) { for id, client := range m.clientNetworks { - _, found := networks[id] - if !found { - log.Debugf("stopping client network watcher, %s", id) + if _, ok := networks[id]; !ok { + log.Debugf("Stopping client network watcher, %s", id) client.stop() delete(m.clientNetworks, id) } } +} + +func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { + // removing routes that do not exist as per the update from the Management service. + m.stopObsoleteClients(networks) for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] @@ -178,7 +219,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[ } } -func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { +func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { newClientRoutesIDMap := make(map[string][]*route.Route) newServerRoutesMap := make(map[string]*route.Route) ownNetworkIDs := make(map[string]bool) @@ -210,7 +251,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] } func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { - _, crMap := m.classifiesRoutes(initialRoutes) + _, crMap := m.classifyRoutes(initialRoutes) rs := make([]*route.Route, 0) for _, routes := range crMap { rs = append(rs, routes...) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 03e77e09bcb..7eb8dd00210 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -428,11 +428,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) require.NoError(t, err, "should update routes with init routes") } - err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index dd2c28e5927..b3464018ece 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -7,14 +7,17 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error - StopFunc func() + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) + TriggerSelectionFunc func(map[string][]*route.Route) + GetRouteSelectorFunc func() *routeselector.RouteSelector + StopFunc func() } func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { @@ -27,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } - return fmt.Errorf("method UpdateRoutes is not implemented") + return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") +} + +func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) { + if m.TriggerSelectionFunc != nil { + m.TriggerSelectionFunc(networks) + } +} + +// GetRouteSelector mock implementation of GetRouteSelector from Manager interface +func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { + if m.GetRouteSelectorFunc != nil { + return m.GetRouteSelectorFunc() + } + return nil } // Start mock implementation of Start from Manager interface diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index d1302b39cc6..7c77c9fbbbf 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -304,7 +304,10 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { Dst: ipNet, } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteDel(route); err != nil && + !errors.Is(err, syscall.ESRCH) && + !errors.Is(err, syscall.ENOENT) && + !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("netlink remove unreachable route: %w", err) } diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go new file mode 100644 index 00000000000..7bd93b46ebb --- /dev/null +++ b/client/internal/routeselector/routeselector.go @@ -0,0 +1,132 @@ +package routeselector + +import ( + "fmt" + "slices" + "strings" + + "github.com/hashicorp/go-multierror" + "golang.org/x/exp/maps" + + route "github.com/netbirdio/netbird/route" +) + +type RouteSelector struct { + selectedRoutes map[string]struct{} + selectAll bool +} + +func NewRouteSelector() *RouteSelector { + return &RouteSelector{ + selectedRoutes: map[string]struct{}{}, + // default selects all routes + selectAll: true, + } +} + +// SelectRoutes updates the selected routes based on the provided route IDs. +func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error { + if !appendRoute { + rs.selectedRoutes = map[string]struct{}{} + } + + var multiErr *multierror.Error + for _, route := range routes { + if !slices.Contains(allRoutes, route) { + multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route)) + continue + } + + rs.selectedRoutes[route] = struct{}{} + } + rs.selectAll = false + + if multiErr != nil { + multiErr.ErrorFormat = formatError + } + + return multiErr.ErrorOrNil() +} + +// SelectAllRoutes sets the selector to select all routes. +func (rs *RouteSelector) SelectAllRoutes() { + rs.selectAll = true + rs.selectedRoutes = map[string]struct{}{} +} + +// DeselectRoutes removes specific routes from the selection. +// If the selector is in "select all" mode, it will transition to "select specific" mode. +func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error { + if rs.selectAll { + rs.selectAll = false + rs.selectedRoutes = map[string]struct{}{} + for _, route := range allRoutes { + rs.selectedRoutes[route] = struct{}{} + } + } + + var multiErr *multierror.Error + + for _, route := range routes { + if !slices.Contains(allRoutes, route) { + multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route)) + continue + } + delete(rs.selectedRoutes, route) + } + + if multiErr != nil { + multiErr.ErrorFormat = formatError + } + + return multiErr.ErrorOrNil() +} + +// DeselectAllRoutes deselects all routes, effectively disabling route selection. +func (rs *RouteSelector) DeselectAllRoutes() { + rs.selectAll = false + rs.selectedRoutes = map[string]struct{}{} +} + +// IsSelected checks if a specific route is selected. +func (rs *RouteSelector) IsSelected(routeID string) bool { + if rs.selectAll { + return true + } + _, selected := rs.selectedRoutes[routeID] + return selected +} + +// FilterSelected removes unselected routes from the provided map. +func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route { + if rs.selectAll { + return maps.Clone(routes) + } + + filtered := map[string][]*route.Route{} + for id, rt := range routes { + netID := id + if i := strings.LastIndex(id, "-"); i != -1 { + netID = id[:i] + } + if rs.IsSelected(netID) { + filtered[id] = rt + } + } + return filtered +} + +func formatError(es []error) string { + if len(es) == 1 { + return fmt.Sprintf("1 error occurred:\n\t* %s", es[0]) + } + + points := make([]string, len(es)) + for i, err := range es { + points[i] = fmt.Sprintf("* %s", err) + } + + return fmt.Sprintf( + "%d errors occurred:\n\t%s", + len(es), strings.Join(points, "\n\t")) +} diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go new file mode 100644 index 00000000000..b3d0547b591 --- /dev/null +++ b/client/internal/routeselector/routeselector_test.go @@ -0,0 +1,275 @@ +package routeselector_test + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/route" +) + +func TestRouteSelector_SelectRoutes(t *testing.T) { + allRoutes := []string{"route1", "route2", "route3"} + + tests := []struct { + name string + initialSelected []string + + selectRoutes []string + append bool + + wantSelected []string + wantError bool + }{ + { + name: "Select specific routes, initial all selected", + selectRoutes: []string{"route1", "route2"}, + wantSelected: []string{"route1", "route2"}, + }, + { + name: "Select specific routes, initial all deselected", + initialSelected: []string{}, + selectRoutes: []string{"route1", "route2"}, + wantSelected: []string{"route1", "route2"}, + }, + { + name: "Select specific routes with initial selection", + initialSelected: []string{"route1"}, + selectRoutes: []string{"route2", "route3"}, + wantSelected: []string{"route2", "route3"}, + }, + { + name: "Select non-existing route", + selectRoutes: []string{"route1", "route4"}, + wantSelected: []string{"route1"}, + wantError: true, + }, + { + name: "Append route with initial selection", + initialSelected: []string{"route1"}, + selectRoutes: []string{"route2"}, + append: true, + wantSelected: []string{"route1", "route2"}, + }, + { + name: "Append route without initial selection", + selectRoutes: []string{"route2"}, + append: true, + wantSelected: []string{"route2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + if tt.initialSelected != nil { + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) + } + + err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + for _, id := range allRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id)) + } + }) + } +} + +func TestRouteSelector_SelectAllRoutes(t *testing.T) { + allRoutes := []string{"route1", "route2", "route3"} + + tests := []struct { + name string + initialSelected []string + + wantSelected []string + }{ + { + name: "Initial all selected", + wantSelected: []string{"route1", "route2", "route3"}, + }, + { + name: "Initial all deselected", + initialSelected: []string{}, + wantSelected: []string{"route1", "route2", "route3"}, + }, + { + name: "Initial some selected", + initialSelected: []string{"route1"}, + wantSelected: []string{"route1", "route2", "route3"}, + }, + { + name: "Initial all selected", + initialSelected: []string{"route1", "route2", "route3"}, + wantSelected: []string{"route1", "route2", "route3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + if tt.initialSelected != nil { + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) + } + + rs.SelectAllRoutes() + + for _, id := range allRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id)) + } + }) + } +} + +func TestRouteSelector_DeselectRoutes(t *testing.T) { + allRoutes := []string{"route1", "route2", "route3"} + + tests := []struct { + name string + initialSelected []string + + deselectRoutes []string + + wantSelected []string + wantError bool + }{ + { + name: "Deselect specific routes, initial all selected", + deselectRoutes: []string{"route1", "route2"}, + wantSelected: []string{"route3"}, + }, + { + name: "Deselect specific routes, initial all deselected", + initialSelected: []string{}, + deselectRoutes: []string{"route1", "route2"}, + wantSelected: []string{}, + }, + { + name: "Deselect specific routes with initial selection", + initialSelected: []string{"route1", "route2"}, + deselectRoutes: []string{"route1", "route3"}, + wantSelected: []string{"route2"}, + }, + { + name: "Deselect non-existing route", + initialSelected: []string{"route1", "route2"}, + deselectRoutes: []string{"route1", "route4"}, + wantSelected: []string{"route2"}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + if tt.initialSelected != nil { + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) + } + + err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + for _, id := range allRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id)) + } + }) + } +} + +func TestRouteSelector_DeselectAll(t *testing.T) { + allRoutes := []string{"route1", "route2", "route3"} + + tests := []struct { + name string + initialSelected []string + + wantSelected []string + }{ + { + name: "Initial all selected", + wantSelected: []string{}, + }, + { + name: "Initial all deselected", + initialSelected: []string{}, + wantSelected: []string{}, + }, + { + name: "Initial some selected", + initialSelected: []string{"route1", "route2"}, + wantSelected: []string{}, + }, + { + name: "Initial all selected", + initialSelected: []string{"route1", "route2", "route3"}, + wantSelected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + if tt.initialSelected != nil { + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) + } + + rs.DeselectAllRoutes() + + for _, id := range allRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id)) + } + }) + } +} + +func TestRouteSelector_IsSelected(t *testing.T) { + rs := routeselector.NewRouteSelector() + + err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) + require.NoError(t, err) + + assert.True(t, rs.IsSelected("route1")) + assert.True(t, rs.IsSelected("route2")) + assert.False(t, rs.IsSelected("route3")) + assert.False(t, rs.IsSelected("route4")) +} + +func TestRouteSelector_FilterSelected(t *testing.T) { + rs := routeselector.NewRouteSelector() + + err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) + require.NoError(t, err) + + routes := map[string][]*route.Route{ + "route1-10.0.0.0/8": {}, + "route2-192.168.0.0/16": {}, + "route3-172.16.0.0/12": {}, + } + + filtered := rs.FilterSelected(routes) + + assert.Equal(t, map[string][]*route.Route{ + "route1-10.0.0.0/8": {}, + "route2-192.168.0.0/16": {}, + }, filtered) +} diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 4b850226893..fbb754fc608 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,17 +1,17 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v3.12.4 // source: daemon.proto package proto import ( + _ "github.com/golang/protobuf/protoc-gen-go/descriptor" + duration "github.com/golang/protobuf/ptypes/duration" + timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - _ "google.golang.org/protobuf/types/descriptorpb" - durationpb "google.golang.org/protobuf/types/known/durationpb" - timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -766,23 +766,23 @@ type PeerState struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` - PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` - ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` - ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` - Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` - Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` - LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` - RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` - Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` - LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` - RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` - LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` - BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` - BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` - RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` - Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` + IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` + PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` + ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` + ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` + Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` + Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` + LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` + RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` + Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` + RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` + LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` + BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` + BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` + RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Latency *duration.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` } func (x *PeerState) Reset() { @@ -838,7 +838,7 @@ func (x *PeerState) GetConnStatus() string { return "" } -func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp { +func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp { if x != nil { return x.ConnStatusUpdate } @@ -894,7 +894,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string { return "" } -func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp { +func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp { if x != nil { return x.LastWireguardHandshake } @@ -929,7 +929,7 @@ func (x *PeerState) GetRoutes() []string { return nil } -func (x *PeerState) GetLatency() *durationpb.Duration { +func (x *PeerState) GetLatency() *duration.Duration { if x != nil { return x.Latency } @@ -1383,6 +1383,255 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState { return nil } +type ListRoutesRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *ListRoutesRequest) Reset() { + *x = ListRoutesRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListRoutesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListRoutesRequest) ProtoMessage() {} + +func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[19] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead. +func (*ListRoutesRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{19} +} + +type ListRoutesResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` +} + +func (x *ListRoutesResponse) Reset() { + *x = ListRoutesResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListRoutesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListRoutesResponse) ProtoMessage() {} + +func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[20] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead. +func (*ListRoutesResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{20} +} + +func (x *ListRoutesResponse) GetRoutes() []*Route { + if x != nil { + return x.Routes + } + return nil +} + +type SelectRoutesRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"` + Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` + All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` +} + +func (x *SelectRoutesRequest) Reset() { + *x = SelectRoutesRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SelectRoutesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SelectRoutesRequest) ProtoMessage() {} + +func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[21] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead. +func (*SelectRoutesRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{21} +} + +func (x *SelectRoutesRequest) GetRouteIDs() []string { + if x != nil { + return x.RouteIDs + } + return nil +} + +func (x *SelectRoutesRequest) GetAppend() bool { + if x != nil { + return x.Append + } + return false +} + +func (x *SelectRoutesRequest) GetAll() bool { + if x != nil { + return x.All + } + return false +} + +type SelectRoutesResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SelectRoutesResponse) Reset() { + *x = SelectRoutesResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SelectRoutesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SelectRoutesResponse) ProtoMessage() {} + +func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[22] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead. +func (*SelectRoutesResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{22} +} + +type Route struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` + Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` +} + +func (x *Route) Reset() { + *x = Route{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Route) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Route) ProtoMessage() {} + +func (x *Route) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Route.ProtoReflect.Descriptor instead. +func (*Route) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{23} +} + +func (x *Route) GetID() string { + if x != nil { + return x.ID + } + return "" +} + +func (x *Route) GetNetwork() string { + if x != nil { + return x.Network + } + return "" +} + +func (x *Route) GetSelected() bool { + if x != nil { + return x.Selected + } + return false +} + var File_daemon_proto protoreflect.FileDescriptor var file_daemon_proto_rawDesc = []byte{ @@ -1601,32 +1850,64 @@ var file_daemon_proto_rawDesc = []byte{ 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, - 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, - 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, - 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, - 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, - 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a, + 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, + 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, + 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, + 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, + 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14, + 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4d, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x32, 0xda, 0x04, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, + 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, + 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( @@ -1641,58 +1922,70 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 24) var file_daemon_proto_goTypes = []interface{}{ - (*LoginRequest)(nil), // 0: daemon.LoginRequest - (*LoginResponse)(nil), // 1: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 4: daemon.UpRequest - (*UpResponse)(nil), // 5: daemon.UpResponse - (*StatusRequest)(nil), // 6: daemon.StatusRequest - (*StatusResponse)(nil), // 7: daemon.StatusResponse - (*DownRequest)(nil), // 8: daemon.DownRequest - (*DownResponse)(nil), // 9: daemon.DownResponse - (*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse - (*PeerState)(nil), // 12: daemon.PeerState - (*LocalPeerState)(nil), // 13: daemon.LocalPeerState - (*SignalState)(nil), // 14: daemon.SignalState - (*ManagementState)(nil), // 15: daemon.ManagementState - (*RelayState)(nil), // 16: daemon.RelayState - (*NSGroupState)(nil), // 17: daemon.NSGroupState - (*FullStatus)(nil), // 18: daemon.FullStatus - (*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 20: google.protobuf.Duration + (*LoginRequest)(nil), // 0: daemon.LoginRequest + (*LoginResponse)(nil), // 1: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 4: daemon.UpRequest + (*UpResponse)(nil), // 5: daemon.UpResponse + (*StatusRequest)(nil), // 6: daemon.StatusRequest + (*StatusResponse)(nil), // 7: daemon.StatusResponse + (*DownRequest)(nil), // 8: daemon.DownRequest + (*DownResponse)(nil), // 9: daemon.DownResponse + (*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse + (*PeerState)(nil), // 12: daemon.PeerState + (*LocalPeerState)(nil), // 13: daemon.LocalPeerState + (*SignalState)(nil), // 14: daemon.SignalState + (*ManagementState)(nil), // 15: daemon.ManagementState + (*RelayState)(nil), // 16: daemon.RelayState + (*NSGroupState)(nil), // 17: daemon.NSGroupState + (*FullStatus)(nil), // 18: daemon.FullStatus + (*ListRoutesRequest)(nil), // 19: daemon.ListRoutesRequest + (*ListRoutesResponse)(nil), // 20: daemon.ListRoutesResponse + (*SelectRoutesRequest)(nil), // 21: daemon.SelectRoutesRequest + (*SelectRoutesResponse)(nil), // 22: daemon.SelectRoutesResponse + (*Route)(nil), // 23: daemon.Route + (*timestamp.Timestamp)(nil), // 24: google.protobuf.Timestamp + (*duration.Duration)(nil), // 25: google.protobuf.Duration } var file_daemon_proto_depIdxs = []int32{ 18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 20, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 24, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 24, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 25, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration 15, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 14, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState 13, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 12, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState 16, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState 17, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 0, // 10: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 2, // 11: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 4, // 12: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 6, // 13: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 8, // 14: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 10, // 15: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 1, // 16: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 3, // 17: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 5, // 18: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 7, // 19: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 9, // 20: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 11, // 21: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 16, // [16:22] is the sub-list for method output_type - 10, // [10:16] is the sub-list for method input_type - 10, // [10:10] is the sub-list for extension type_name - 10, // [10:10] is the sub-list for extension extendee - 0, // [0:10] is the sub-list for field type_name + 23, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route + 0, // 11: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 2, // 12: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 4, // 13: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 6, // 14: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 8, // 15: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 10, // 16: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 19, // 17: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest + 21, // 18: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest + 21, // 19: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest + 1, // 20: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 3, // 21: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 5, // 22: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 7, // 23: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 9, // 24: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 11, // 25: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 20, // 26: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse + 22, // 27: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse + 22, // 28: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 20, // [20:29] is the sub-list for method output_type + 11, // [11:20] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -1929,6 +2222,66 @@ func file_daemon_proto_init() { return nil } } + file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListRoutesRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListRoutesResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SelectRoutesRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SelectRoutesResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Route); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} type x struct{} @@ -1937,7 +2290,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 0, - NumMessages: 19, + NumMessages: 24, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5f8878a11b9..31ef4abc78b 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -27,6 +27,15 @@ service DaemonService { // GetConfig of the daemon. rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} + + // List available network routes + rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {} + + // Select specific routes + rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + + // Deselect specific routes + rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} }; message LoginRequest { @@ -195,4 +204,26 @@ message FullStatus { repeated PeerState peers = 4; repeated RelayState relays = 5; repeated NSGroupState dns_servers = 6; +} + +message ListRoutesRequest { +} + +message ListRoutesResponse { + repeated Route routes = 1; +} + +message SelectRoutesRequest { + repeated string routeIDs = 1; + bool append = 2; + bool all = 3; +} + +message SelectRoutesResponse { +} + +message Route { + string ID = 1; + string network = 2; + bool selected = 3; } \ No newline at end of file diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 0b339fab2a7..d149ee6cd85 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -31,6 +31,12 @@ type DaemonServiceClient interface { Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) + // List available network routes + ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) + // Select specific routes + SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + // Deselect specific routes + DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) } type daemonServiceClient struct { @@ -95,6 +101,33 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques return out, nil } +func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) { + out := new(ListRoutesResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { + out := new(SelectRoutesResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { + out := new(SelectRoutesResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -112,6 +145,12 @@ type DaemonServiceServer interface { Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) + // List available network routes + ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) + // Select specific routes + SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + // Deselect specific routes + DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -137,6 +176,15 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") } +func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented") +} +func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented") +} +func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -258,6 +306,60 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListRoutesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).ListRoutes(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/ListRoutes", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectRoutesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).SelectRoutes(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/SelectRoutes", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectRoutesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).DeselectRoutes(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/DeselectRoutes", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -289,6 +391,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetConfig", Handler: _DaemonService_GetConfig_Handler, }, + { + MethodName: "ListRoutes", + Handler: _DaemonService_ListRoutes_Handler, + }, + { + MethodName: "SelectRoutes", + Handler: _DaemonService_SelectRoutes_Handler, + }, + { + MethodName: "DeselectRoutes", + Handler: _DaemonService_DeselectRoutes_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "daemon.proto", diff --git a/client/server/route.go b/client/server/route.go new file mode 100644 index 00000000000..4aa37dbb78b --- /dev/null +++ b/client/server/route.go @@ -0,0 +1,100 @@ +package server + +import ( + "context" + "fmt" + "sort" + + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" +) + +// ListRoutes returns a list of all available routes. +func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.engine == nil { + return nil, fmt.Errorf("not connected") + } + + routesMap := s.engine.GetClientRoutesWithNetID() + routeSelector := s.engine.GetRouteManager().GetRouteSelector() + + var routes []*route.Route + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + rt[0].ID = id + routes = append(routes, rt[0]) + } + + sort.Slice(routes, func(i, j int) bool { + iPrefix := routes[i].Network.Bits() + jPrefix := routes[j].Network.Bits() + + if iPrefix == jPrefix { + iAddr := routes[i].Network.Addr() + jAddr := routes[j].Network.Addr() + if iAddr == jAddr { + return routes[i].ID < routes[j].ID + } + return iAddr.String() < jAddr.String() + } + return iPrefix < jPrefix + }) + + var pbRoutes []*proto.Route + for _, route := range routes { + pbRoutes = append(pbRoutes, &proto.Route{ + ID: route.ID, + Network: route.Network.String(), + Selected: routeSelector.IsSelected(route.ID), + }) + } + + return &proto.ListRoutesResponse{ + Routes: pbRoutes, + }, nil +} + +// SelectRoutes selects specific routes based on the client request. +func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + routeManager := s.engine.GetRouteManager() + routeSelector := routeManager.GetRouteSelector() + if req.GetAll() { + routeSelector.SelectAllRoutes() + } else { + if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { + return nil, fmt.Errorf("select routes: %w", err) + } + } + routeManager.TriggerSelection(s.engine.GetClientRoutes()) + + return &proto.SelectRoutesResponse{}, nil +} + +// DeselectRoutes deselects specific routes based on the client request. +func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + routeManager := s.engine.GetRouteManager() + routeSelector := routeManager.GetRouteSelector() + if req.GetAll() { + routeSelector.DeselectAllRoutes() + } else { + if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { + return nil, fmt.Errorf("deselect routes: %w", err) + } + } + routeManager.TriggerSelection(s.engine.GetClientRoutes()) + + return &proto.SelectRoutesResponse{}, nil +} diff --git a/client/server/server.go b/client/server/server.go index d33bb515582..e0e9504faac 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -15,15 +15,15 @@ import ( "google.golang.org/protobuf/types/known/durationpb" - "github.com/netbirdio/netbird/client/internal/auth" - "github.com/netbirdio/netbird/client/system" - log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" @@ -57,6 +57,8 @@ type Server struct { config *internal.Config proto.UnimplementedDaemonServiceServer + engine *internal.Engine + statusRecorder *peer.Status sessionWatcher *internal.SessionWatcher @@ -141,8 +143,11 @@ func (s *Server) Start() error { s.sessionWatcher.SetOnExpireListener(s.onSessionExpire) } + engineChan := make(chan *internal.Engine, 1) + go s.watchEngine(ctx, engineChan) + if !config.DisableAutoConnect { - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan) } return nil @@ -153,6 +158,7 @@ func (s *Server) Start() error { // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, + engineChan chan<- *internal.Engine, ) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -182,7 +188,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf runOperation := func() error { log.Tracef("running client connection") - err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) + err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) } @@ -562,7 +568,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + engineChan := make(chan *internal.Engine, 1) + go s.watchEngine(ctx, engineChan) + + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan) return &proto.UpResponse{}, nil } @@ -579,6 +588,8 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) + s.engine = nil + return &proto.DownResponse{}, nil } @@ -661,7 +672,6 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto PreSharedKey: preSharedKey, }, nil } - func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() @@ -673,6 +683,22 @@ func (s *Server) onSessionExpire() { } } +// watchEngine watches the engine channel and updates the engine state +func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) { + log.Tracef("Started watching engine") + for { + select { + case <-ctx.Done(): + s.engine = nil + log.Tracef("Stopped watching engine") + return + case engine := <-engineChan: + log.Tracef("Received engine from watcher") + s.engine = engine + } + } +} + func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus := proto.FullStatus{ ManagementState: &proto.ManagementState{}, diff --git a/client/server/server_test.go b/client/server/server_test.go index 4e4a091453f..8082e6bbaad 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -2,11 +2,12 @@ package server import ( "context" - "github.com/netbirdio/management-integrations/integrations" "net" "testing" "time" + "github.com/netbirdio/management-integrations/integrations" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -69,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index aec2c8fac81..0f16369a5ca 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1,7 +1,5 @@ //go:build !(linux && 386) -// +build !linux !386 -// skipping linux 32 bits build and tests package main import ( @@ -58,6 +56,8 @@ func main() { var showSettings bool flag.BoolVar(&showSettings, "settings", false, "run settings windows") + var showRoutes bool + flag.BoolVar(&showRoutes, "routes", false, "run routes windows") var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") @@ -71,8 +71,8 @@ func main() { return } - client := newServiceClient(daemonAddr, a, showSettings) - if showSettings { + client := newServiceClient(daemonAddr, a, showSettings, showRoutes) + if showSettings || showRoutes { a.Run() } else { if err := checkPIDFile(); err != nil { @@ -135,6 +135,7 @@ type serviceClient struct { mVersionDaemon *systray.MenuItem mUpdate *systray.MenuItem mQuit *systray.MenuItem + mRoutes *systray.MenuItem // application with main windows. app fyne.App @@ -159,12 +160,15 @@ type serviceClient struct { daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + + showRoutes bool + wRoutes fyne.Window } // newServiceClient instance constructor // // This constructor also builds the UI elements for the settings window. -func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient { +func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient { s := &serviceClient{ ctx: context.Background(), addr: addr, @@ -172,6 +176,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient sendNotification: false, showSettings: showSettings, + showRoutes: showRoutes, update: version.NewUpdate(), } @@ -191,14 +196,16 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient } if showSettings { - s.showUIElements() + s.showSettingsUI() return s + } else if showRoutes { + s.showRoutesUI() } return s } -func (s *serviceClient) showUIElements() { +func (s *serviceClient) showSettingsUI() { // add settings window UI elements. s.wSettings = s.app.NewWindow("NetBird Settings") s.iMngURL = widget.NewEntry() @@ -416,6 +423,7 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetTitle("Connected") s.mUp.Disable() s.mDown.Enable() + s.mRoutes.Enable() systrayIconState = true } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { s.connected = false @@ -428,6 +436,7 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetTitle("Disconnected") s.mDown.Disable() s.mUp.Enable() + s.mRoutes.Disable() systrayIconState = false } @@ -483,9 +492,11 @@ func (s *serviceClient) onTrayReady() { s.mUp = systray.AddMenuItem("Connect", "Connect") s.mDown = systray.AddMenuItem("Disconnect", "Disconnect") s.mDown.Disable() - s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Wiretrustee Admin Panel") + s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel") systray.AddSeparator() s.mSettings = systray.AddMenuItem("Settings", "Settings of the application") + s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window") + s.mRoutes.Disable() systray.AddSeparator() s.mAbout = systray.AddMenuItem("About", "About") @@ -557,6 +568,12 @@ func (s *serviceClient) onTrayReady() { if err != nil { log.Errorf("%s", err) } + case <-s.mRoutes.ClickedCh: + s.mRoutes.Disable() + go func() { + defer s.mRoutes.Enable() + s.runSelfCommand("routes", "true") + }() } if err != nil { log.Errorf("process connection: %v", err) diff --git a/client/ui/route.go b/client/ui/route.go new file mode 100644 index 00000000000..0ac58e5d5b0 --- /dev/null +++ b/client/ui/route.go @@ -0,0 +1,203 @@ +//go:build !(linux && 386) + +package main + +import ( + "fmt" + "strings" + "time" + + "fyne.io/fyne/v2" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/layout" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +func (s *serviceClient) showRoutesUI() { + s.wRoutes = s.app.NewWindow("NetBird Routes") + + grid := container.New(layout.NewGridLayout(2)) + go s.updateRoutes(grid) + routeCheckContainer := container.NewVBox() + routeCheckContainer.Add(grid) + scrollContainer := container.NewVScroll(routeCheckContainer) + scrollContainer.SetMinSize(fyne.NewSize(200, 300)) + + buttonBox := container.NewHBox( + layout.NewSpacer(), + widget.NewButton("Refresh", func() { + s.updateRoutes(grid) + }), + widget.NewButton("Select all", func() { + s.selectAllRoutes() + s.updateRoutes(grid) + }), + widget.NewButton("Deselect All", func() { + s.deselectAllRoutes() + s.updateRoutes(grid) + }), + layout.NewSpacer(), + ) + + content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer) + + s.wRoutes.SetContent(content) + s.wRoutes.Show() + + s.startAutoRefresh(5*time.Second, grid) +} + +func (s *serviceClient) updateRoutes(grid *fyne.Container) { + routes, err := s.fetchRoutes() + if err != nil { + log.Errorf("get client: %v", err) + s.showError(fmt.Errorf("get client: %v", err)) + return + } + + grid.Objects = nil + idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + + grid.Add(idHeader) + grid.Add(networkHeader) + for _, route := range routes { + r := route + + checkBox := widget.NewCheck(r.ID, func(checked bool) { + s.selectRoute(r.ID, checked) + }) + checkBox.Checked = route.Selected + checkBox.Resize(fyne.NewSize(20, 20)) + checkBox.Refresh() + + grid.Add(checkBox) + grid.Add(widget.NewLabel(r.Network)) + } + + s.wRoutes.Content().Refresh() +} + +func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + return nil, fmt.Errorf("get client: %v", err) + } + + resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to list routes: %v", err) + } + + return resp.Routes, nil +} + +func (s *serviceClient) selectRoute(id string, checked bool) { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("get client: %v", err) + s.showError(fmt.Errorf("get client: %v", err)) + return + } + + req := &proto.SelectRoutesRequest{ + RouteIDs: []string{id}, + Append: checked, + } + + if checked { + if _, err := conn.SelectRoutes(s.ctx, req); err != nil { + log.Errorf("failed to select route: %v", err) + s.showError(fmt.Errorf("failed to select route: %v", err)) + return + } + log.Infof("Route %s selected", id) + } else { + if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { + log.Errorf("failed to deselect route: %v", err) + s.showError(fmt.Errorf("failed to deselect route: %v", err)) + return + } + log.Infof("Route %s deselected", id) + } +} + +func (s *serviceClient) selectAllRoutes() { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("get client: %v", err) + return + } + + req := &proto.SelectRoutesRequest{ + All: true, + } + if _, err := conn.SelectRoutes(s.ctx, req); err != nil { + log.Errorf("failed to select all routes: %v", err) + s.showError(fmt.Errorf("failed to select all routes: %v", err)) + return + } + + log.Debug("All routes selected") +} + +func (s *serviceClient) deselectAllRoutes() { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Errorf("get client: %v", err) + return + } + + req := &proto.SelectRoutesRequest{ + All: true, + } + if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { + log.Errorf("failed to deselect all routes: %v", err) + s.showError(fmt.Errorf("failed to deselect all routes: %v", err)) + return + } + + log.Debug("All routes deselected") +} + +func (s *serviceClient) showError(err error) { + wrappedMessage := wrapText(err.Error(), 50) + + dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes) +} + +func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) { + ticker := time.NewTicker(interval) + go func() { + for range ticker.C { + s.updateRoutes(grid) + } + }() + + s.wRoutes.SetOnClosed(func() { + ticker.Stop() + }) +} + +// wrapText inserts newlines into the text to ensure that each line is +// no longer than 'lineLength' runes. +func wrapText(text string, lineLength int) string { + var sb strings.Builder + var currentLineLength int + + for _, runeValue := range text { + sb.WriteRune(runeValue) + currentLineLength++ + + if currentLineLength >= lineLength || runeValue == '\n' { + sb.WriteRune('\n') + currentLineLength = 0 + } + } + + return sb.String() +} From 1e6addaa652305140fc48d04165a6a8ff4e90b18 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:09:58 +0200 Subject: [PATCH 63/89] Add account locks to getAccountWithAuthorizationClaims method (#1847) --- management/server/account.go | 46 ++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 23f03015ecb..9c6da05bdeb 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1439,29 +1439,14 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, } // handleExistingUserAccount handles existing User accounts and update its domain attributes. -// -// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, -// we compare the account's ID with the domain account ID, and if they don't match, we set the account as -// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain -// was previously unclassified or classified as public so N users that logged int that time, has they own account -// and peers that shouldn't be lost. func (am *DefaultAccountManager) handleExistingUserAccount( existingAcc *Account, - domainAcc *Account, + primaryDomain bool, claims jwtclaims.AuthorizationClaims, ) error { - var err error - - if domainAcc != nil && existingAcc.Id != domainAcc.Id { - err = am.updateAccountDomainAttributes(existingAcc, claims, false) - if err != nil { - return err - } - } else { - err = am.updateAccountDomainAttributes(existingAcc, claims, true) - if err != nil { - return err - } + err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain) + if err != nil { + return err } // we should register the account ID to this user's metadata in our IDP manager @@ -1795,12 +1780,33 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla account, err := am.Store.GetAccountByUser(claims.UserId) if err == nil { - err = am.handleExistingUserAccount(account, domainAccount, claims) + unlockAccount := am.Store.AcquireAccountLock(account.Id) + defer unlockAccount() + account, err = am.Store.GetAccountByUser(claims.UserId) + if err != nil { + return nil, err + } + // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, + // we compare the account's ID with the domain account ID, and if they don't match, we set the account as + // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain + // was previously unclassified or classified as public so N users that logged int that time, has they own account + // and peers that shouldn't be lost. + primaryDomain := domainAccount == nil || account.Id == domainAccount.Id + + err = am.handleExistingUserAccount(account, primaryDomain, claims) if err != nil { return nil, err } return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + if domainAccount != nil { + unlockAccount := am.Store.AcquireAccountLock(domainAccount.Id) + defer unlockAccount() + domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain) + if err != nil { + return nil, err + } + } return am.handleNewUserAccount(domainAccount, claims) } else { // other error From 1f33e2e0039169ecb0063dd86358104a14431dd1 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:12:16 +0200 Subject: [PATCH 64/89] Support exit nodes on iOS (#1878) --- client/internal/routemanager/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 57007c4a3a5..dfc39102f9c 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -262,7 +262,7 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou func isPrefixSupported(prefix netip.Prefix) bool { if !nbnet.CustomRoutingDisabled() { switch runtime.GOOS { - case "linux", "windows", "darwin": + case "linux", "windows", "darwin", "ios": return true } } From 8f3a0f2c38bfeb20b3122dedbad90d0ae6d31d2f Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:23:43 +0200 Subject: [PATCH 65/89] Add retry to IdP cache lookup (#1882) --- .github/workflows/golangci-lint.yml | 2 +- management/server/account.go | 53 ++++++++++++++++++++--------- management/server/peer.go | 2 +- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 50cb4e2afaf..78b9f504f67 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta + ignore_words_list: erro,clienta,hastable, skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/management/server/account.go b/management/server/account.go index 9c6da05bdeb..aac13665749 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1350,18 +1350,46 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo } func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { - data, err := am.getAccountFromCache(accountID, false) + var data []*idp.UserData + var err error + + maxAttempts := 2 + + data, err = am.getAccountFromCache(accountID, false) if err != nil { return nil, err } + for attempt := 1; attempt <= maxAttempts; attempt++ { + if am.isCacheFresh(accountUsers, data) { + return data, nil + } + + if attempt > 1 { + time.Sleep(200 * time.Millisecond) + } + + log.Infof("refreshing cache for account %s", accountID) + data, err = am.refreshCache(accountID) + if err != nil { + return nil, err + } + + if attempt == maxAttempts { + log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts) + } + } + + return data, nil +} + +// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status +func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool { userDataMap := make(map[string]*idp.UserData, len(data)) for _, datum := range data { userDataMap[datum.ID] = datum } - mustRefreshInviteStatus := false - // the accountUsers ID list of non integration users from store, we check if cache has all of them // as result of for loop knownUsersCount will have number of users are not presented in the cashed knownUsersCount := len(accountUsers) @@ -1369,9 +1397,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedI if datum, ok := userDataMap[user]; ok { // check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple - mustRefreshInviteStatus = true - log.Infof("user %s has a pending invite and has logged in once, forcing cache refresh", user) - break + log.Infof("user %s has a pending invite and has logged in once, cache invalid", user) + return false } knownUsersCount-- continue @@ -1380,18 +1407,12 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedI } // if we know users that are not yet in cache more likely cache is outdated - if knownUsersCount > 0 || mustRefreshInviteStatus { - if !mustRefreshInviteStatus { - log.Infof("reloading cache with IDP manager. Users unknown to the cache: %d", knownUsersCount) - } - // reload cache once avoiding loops - data, err = am.refreshCache(accountID) - if err != nil { - return nil, err - } + if knownUsersCount > 0 { + log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount) + return false } - return data, err + return true } func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error { diff --git a/management/server/peer.go b/management/server/peer.go index 1448e301197..1a8b183ed42 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -374,7 +374,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { if am.idpManager != nil { userdata, err := am.lookupUserInCache(userID, account) - if err == nil { + if err == nil && userdata != nil { peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) } } From 7b254cb966eeca3fe5d11620fb317087618930f5 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:26:03 +0200 Subject: [PATCH 66/89] add methods to manage rosenpass settings for iOS (#1879) --- client/ios/NetBirdSDK/preferences.go | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index 297d53ff08e..b7814667959 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -71,6 +71,42 @@ func (p *Preferences) SetPreSharedKey(key string) { p.configInput.PreSharedKey = &key } +// SetRosenpassEnabled store if rosenpass is enabled +func (p *Preferences) SetRosenpassEnabled(enabled bool) { + p.configInput.RosenpassEnabled = &enabled +} + +// GetRosenpassEnabled read rosenpass enabled from config file +func (p *Preferences) GetRosenpassEnabled() (bool, error) { + if p.configInput.RosenpassEnabled != nil { + return *p.configInput.RosenpassEnabled, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.RosenpassEnabled, err +} + +// SetRosenpassPermissive store the given permissive and wait for commit +func (p *Preferences) SetRosenpassPermissive(permissive bool) { + p.configInput.RosenpassPermissive = &permissive +} + +// GetRosenpassPermissive read rosenpass permissive from config file +func (p *Preferences) GetRosenpassPermissive() (bool, error) { + if p.configInput.RosenpassPermissive != nil { + return *p.configInput.RosenpassPermissive, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.RosenpassPermissive, err +} + // Commit write out the changes into config file func (p *Preferences) Commit() error { _, err := internal.UpdateOrCreateConfig(p.configInput) From 71c6437bab0c73a46f396476d70c3715ff125cf6 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 25 Apr 2024 22:20:24 +0300 Subject: [PATCH 67/89] add content type before writing header (#1887) --- management/server/http/setupkeys_handler.go | 2 +- management/server/http/util/util.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 4adf3fdd055..5faedea13f3 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -181,8 +181,8 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques } func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { - w.WriteHeader(200) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) err := json.NewEncoder(w).Encode(toResponseBody(key)) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 2bb279c7671..acaa2838c6a 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -20,8 +20,8 @@ type ErrorResponse struct { // WriteJSONObject simply writes object to the HTTP response in JSON format func WriteJSONObject(w http.ResponseWriter, obj interface{}) { - w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json; charset=UTF-8") + w.WriteHeader(http.StatusOK) err := json.NewEncoder(w).Encode(obj) if err != nil { WriteError(err, w) @@ -63,8 +63,8 @@ func (d *Duration) UnmarshalJSON(b []byte) error { // WriteErrorResponse prepares and writes an error response i nJSON func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { - w.WriteHeader(httpStatus) w.Header().Set("Content-Type", "application/json; charset=UTF-8") + w.WriteHeader(httpStatus) err := json.NewEncoder(w).Encode(&ErrorResponse{ Message: errMsg, Code: httpStatus, From 54b045d9ca511b15dd6ea43e6edff717161b2ca4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 26 Apr 2024 16:37:27 +0200 Subject: [PATCH 68/89] Replaces powershell with the route command and cache route lookups on windows (#1880) --- client/internal/routemanager/client.go | 17 +++- client/internal/routemanager/routemanager.go | 7 +- client/internal/routemanager/systemops.go | 42 +++------ .../routemanager/systemops_android.go | 4 +- .../internal/routemanager/systemops_darwin.go | 12 +-- .../routemanager/systemops_darwin_test.go | 2 +- client/internal/routemanager/systemops_ios.go | 4 +- .../internal/routemanager/systemops_linux.go | 61 +++++++----- .../routemanager/systemops_nonlinux.go | 5 +- .../internal/routemanager/systemops_test.go | 48 +++++++--- .../routemanager/systemops_windows.go | 92 +++++++++---------- 11 files changed, 160 insertions(+), 134 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index d41ed422b81..3569d13ae67 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -3,6 +3,7 @@ package routemanager import ( "context" "fmt" + "net" "net/netip" "time" @@ -215,7 +216,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { + if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } @@ -256,7 +257,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // otherwise add the route to the system - if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { + if err := addVPNRoute(c.network, c.getAsInterface()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -344,3 +345,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } } + +func (c *clientNetwork) getAsInterface() *net.Interface { + intf, err := net.InterfaceByName(c.wgInterface.Name()) + if err != nil { + log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err) + intf = &net.Interface{ + Name: c.wgInterface.Name(), + } + } + + return intf +} diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index 8f9ff9f4bd0..7715aa8194d 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -5,6 +5,7 @@ package routemanager import ( "errors" "fmt" + "net" "net/netip" "sync" @@ -17,7 +18,7 @@ import ( type ref struct { count int nexthop netip.Addr - intf string + intf *net.Interface } type RouteManager struct { @@ -30,8 +31,8 @@ type RouteManager struct { mutex sync.Mutex } -type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) -type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { // TODO: read initial routing table into refCountMap diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index 1ee54b746d8..1f37a8a3c22 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -60,17 +60,13 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { return nil } - var exitIntf string gatewayHop, intf, err := getNextHop(defaultGateway) if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } - if intf != nil { - exitIntf = intf.Name - } log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) + return addToRouteTable(gatewayPrefix, gatewayHop, intf) } func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { @@ -84,7 +80,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { return netip.Addr{}, nil, ErrRouteNotFound } - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { if preferredSrc == nil { return netip.Addr{}, nil, ErrRouteNotFound @@ -153,12 +149,7 @@ func isSubRange(prefix netip.Prefix) (bool, error) { // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { +func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) { addr := prefix.Addr() switch { case addr.IsLoopback(), @@ -168,39 +159,34 @@ func addRouteToNonVPNIntf( addr.IsUnspecified(), addr.IsMulticast(): - return netip.Addr{}, "", ErrRouteNotAllowed + return netip.Addr{}, nil, ErrRouteNotAllowed } // Determine the exit interface and next hop for the prefix, so we can add a specific route nexthop, intf, err := getNextHop(addr) if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err) } log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } + exitIntf := intf vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr") } // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() { log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } + exitIntf = initialIntf } log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err) } return exitNextHop, exitIntf, nil @@ -208,7 +194,7 @@ func addRouteToNonVPNIntf( // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // in two /1 prefixes to avoid replacing the existing default route -func genericAddVPNRoute(prefix netip.Prefix, intf string) error { +func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if prefix == defaultv4 { if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { return err @@ -250,7 +236,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf string) error { } // addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { +func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error { ok, err := existsInRouteTable(prefix) if err != nil { return fmt.Errorf("exists in route table: %w", err) @@ -277,7 +263,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf string) error { // genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, // it will remove the split /1 prefixes -func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { +func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { if prefix == defaultv4 { var result *multierror.Error if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { @@ -343,7 +329,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n } *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { + func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) { addr := prefix.Addr() nexthop, intf := initialNextHopV4, initialIntfV4 if addr.Is6() { diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 34d2d270fe3..4d23d39100e 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -24,10 +24,10 @@ func enableIPForwarding() error { return nil } -func addVPNRoute(netip.Prefix, string) error { +func addVPNRoute(netip.Prefix, *net.Interface) error { return nil } -func removeVPNRoute(netip.Prefix, string) error { +func removeVPNRoute(netip.Prefix, *net.Interface) error { return nil } diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go index f7ce72a4e89..017dc6c28a2 100644 --- a/client/internal/routemanager/systemops_darwin.go +++ b/client/internal/routemanager/systemops_darwin.go @@ -27,15 +27,15 @@ func cleanupRouting() error { return cleanupRoutingWithRouteManager(routeManager) } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return routeCmd("add", prefix, nexthop, intf) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return routeCmd("delete", prefix, nexthop, intf) } -func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { inet := "-inet" network := prefix.String() if prefix.IsSingleIP() { @@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin // Special case for IPv6 split default route, pointing to the wg interface fails // TODO: Remove once we have IPv6 support on the interface if prefix.Bits() == 1 { - intf = "lo0" + intf = &net.Interface{Name: "lo0"} } } args := []string{"-n", action, inet, network} if nexthop.IsValid() { args = append(args, nexthop.Unmap().String()) - } else if intf != "" { - args = append(args, "-interface", intf) + } else if intf != nil { + args = append(args, "-interface", intf.Name) } if err := retryRouteCmd(args); err != nil { diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go index cc9bb9db598..c23a7cde3fa 100644 --- a/client/internal/routemanager/systemops_darwin_test.go +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -33,7 +33,7 @@ func init() { func TestConcurrentRoutes(t *testing.T) { baseIP := netip.MustParseAddr("192.0.2.0") - intf := "lo0" + intf := &net.Interface{Name: "lo0"} var wg sync.WaitGroup for i := 0; i < 1024; i++ { diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 34d2d270fe3..4d23d39100e 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -24,10 +24,10 @@ func enableIPForwarding() error { return nil } -func addVPNRoute(netip.Prefix, string) error { +func addVPNRoute(netip.Prefix, *net.Interface) error { return nil } -func removeVPNRoute(netip.Prefix, string) error { +func removeVPNRoute(netip.Prefix, *net.Interface) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 7c77c9fbbbf..ce0c07ce69e 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -46,9 +46,6 @@ var routeManager = &RouteManager{} // originalSysctl stores the original sysctl values before they are modified var originalSysctl map[string]int -// determines whether to use the legacy routing setup -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() - // sysctlFailed is used as an indicator to emit a warning when default routes are configured var sysctlFailed bool @@ -62,6 +59,20 @@ type ruleParams struct { description string } +// isLegacy determines whether to use the legacy routing setup +func isLegacy() bool { + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() +} + +// setIsLegacy sets the legacy routing setup +func setIsLegacy(b bool) { + if b { + os.Setenv("NB_USE_LEGACY_ROUTING", "true") + } else { + os.Unsetenv("NB_USE_LEGACY_ROUTING") + } +} + func getSetupRules() []ruleParams { return []ruleParams{ {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, @@ -82,7 +93,7 @@ func getSetupRules() []ruleParams { // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { - if isLegacy { + if isLegacy() { log.Infof("Using legacy routing setup") return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) } @@ -111,7 +122,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before if err := addRule(rule); err != nil { if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") - isLegacy = true + setIsLegacy(true) return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) @@ -125,7 +136,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func cleanupRouting() error { - if isLegacy { + if isLegacy() { return cleanupRoutingWithRouteManager(routeManager) } @@ -154,16 +165,16 @@ func cleanupRouting() error { return result.ErrorOrNil() } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) } -func addVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { +func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if isLegacy() { return genericAddVPNRoute(prefix, intf) } @@ -185,8 +196,8 @@ func addVPNRoute(prefix netip.Prefix, intf string) error { return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { +func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if isLegacy() { return genericRemoveVPNRoute(prefix, intf) } @@ -244,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) { } // addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { +func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, @@ -316,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { } // removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -470,20 +481,22 @@ func removeRule(params ruleParams) error { } // addNextHop adds the gateway and device to the route. -func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { +func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error { + if intf != nil { + route.LinkIndex = intf.Index + } + if addr.IsValid() { route.Gw = addr.AsSlice() - if intf == "" { - intf = addr.Zone() - } - } - if intf != "" { - link, err := netlink.LinkByName(intf) - if err != nil { - return fmt.Errorf("set interface %s: %w", intf, err) + // if zone is set, it means the gateway is a link-local address, so we set the link index + if addr.Zone() != "" && intf == nil { + link, err := netlink.LinkByName(addr.Zone()) + if err != nil { + return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err) + } + route.LinkIndex = link.Attrs().Index } - route.LinkIndex = link.Attrs().Index } return nil diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 38026107ec7..91879790a1f 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -3,6 +3,7 @@ package routemanager import ( + "net" "net/netip" "runtime" @@ -14,10 +15,10 @@ func enableIPForwarding() error { return nil } -func addVPNRoute(prefix netip.Prefix, intf string) error { +func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { return genericAddVPNRoute(prefix, intf) } -func removeVPNRoute(prefix netip.Prefix, intf string) error { +func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { return genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 9f906c06fbe..8a92ac57971 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -50,6 +50,8 @@ func TestAddRemoveRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { + t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { @@ -67,7 +69,11 @@ func TestAddRemoveRoutes(t *testing.T) { assert.NoError(t, cleanupRouting()) }) - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + index, err := net.InterfaceByName(wgInterface.Name()) + require.NoError(t, err, "InterfaceByName should not return err") + intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} + + err = addVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { @@ -78,7 +84,7 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericRemoveVPNRoute should not return err") prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) @@ -182,12 +188,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) { } for n, testCase := range testCases { + var buf bytes.Buffer log.SetOutput(&buf) defer func() { log.SetOutput(os.Stderr) }() t.Run(testCase.name, func(t *testing.T) { + t.Setenv("NB_USE_LEGACY_ROUTING", "true") + t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() newNet, err := stdnet.NewNet() if err != nil { @@ -200,14 +210,18 @@ func TestAddExistAndRemoveRoute(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + index, err := net.InterfaceByName(wgInterface.Name()) + require.NoError(t, err, "InterfaceByName should not return err") + intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} + // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := addVPNRoute(testCase.preExistingPrefix, intf) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + err = addVPNRoute(testCase.prefix, intf) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -217,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeVPNRoute(testCase.prefix, intf) require.NoError(t, err, "should not return err") } @@ -345,43 +359,47 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, cleanupRouting()) }) + index, err := net.InterfaceByName(wgIface.Name()) + require.NoError(t, err, "InterfaceByName should not return err") + intf := &net.Interface{Index: index.Index, Name: wgIface.Name()} + // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) require.NoError(t, err, "addVPNRoute should not return err") t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) assert.NoError(t, err, "removeVPNRoute should not return err") }) } diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index ba211082f1f..f9e75e2ed5a 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -6,8 +6,12 @@ import ( "fmt" "net" "net/netip" + "os" "os/exec" + "strconv" "strings" + "sync" + "time" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" @@ -21,6 +25,10 @@ type Win32_IP4RouteTable struct { Mask string } +var prefixList []netip.Prefix +var lastUpdate time.Time +var mux = sync.Mutex{} + var routeManager *RouteManager func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { @@ -32,15 +40,23 @@ func cleanupRouting() error { } func getRoutesFromTable() ([]netip.Prefix, error) { - var routes []Win32_IP4RouteTable + mux.Lock() + defer mux.Unlock() + query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" + // If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result + if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second { + return prefixList, nil + } + + var routes []Win32_IP4RouteTable err := wmi.Query(query, &routes) if err != nil { return nil, fmt.Errorf("get routes: %w", err) } - var prefixList []netip.Prefix + prefixList = nil for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { @@ -60,54 +76,29 @@ func getRoutesFromTable() ([]netip.Prefix, error) { prefixList = append(prefixList, routePrefix) } } + + lastUpdate = time.Now() return prefixList, nil } -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error { - destinationPrefix := prefix.String() - psCmd := "New-NetRoute" - - addressFamily := "IPv4" - if prefix.Addr().Is6() { - addressFamily = "IPv6" - } - - script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`, - psCmd, addressFamily, destinationPrefix, - ) - - if intfIdx != "" { - script = fmt.Sprintf( - `%s -InterfaceIndex %s`, script, intfIdx, - ) - } else { - script = fmt.Sprintf( - `%s -InterfaceAlias "%s"`, script, intf, - ) - } +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { + args := []string{"add", prefix.String()} if nexthop.IsValid() { - script = fmt.Sprintf( - `%s -NextHop "%s"`, script, nexthop, - ) + args = append(args, nexthop.Unmap().String()) + } else { + addr := "0.0.0.0" + if prefix.Addr().Is6() { + addr = "::" + } + args = append(args, addr) } - out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell %s: %s", script, string(out)) - - if err != nil { - return fmt.Errorf("PowerShell add route: %w", err) + if intf != nil { + args = append(args, "if", strconv.Itoa(intf.Index)) } - return nil -} - -func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"add", prefix.String(), nexthop.Unmap().String()} - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) if err != nil { return fmt.Errorf("route add: %w", err) @@ -116,21 +107,20 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { return nil } -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - var intfIdx string - if nexthop.Zone() != "" { - intfIdx = nexthop.Zone() +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { + if nexthop.Zone() != "" && intf == nil { + zone, err := strconv.Atoi(nexthop.Zone()) + if err != nil { + return fmt.Errorf("invalid zone: %w", err) + } + intf = &net.Interface{Index: zone} nexthop.WithZone("") } - // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" || intfIdx != "" { - return addRoutePowershell(prefix, nexthop, intf, intfIdx) - } return addRouteCmd(prefix, nexthop, intf) } -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error { args := []string{"delete", prefix.String()} if nexthop.IsValid() { nexthop.WithZone("") @@ -145,3 +135,7 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) err } return nil } + +func isCacheDisabled() bool { + return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" +} From 4424162bce5dfc7d985ab1777131de1c6d59234e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 26 Apr 2024 17:20:10 +0200 Subject: [PATCH 69/89] Add client debug features (#1884) * Add status anonymization * Add OS/arch to the status command * Use human-friendly last-update status messages * Add debug bundle command to collect (anonymized) logs * Add debug log level command * And debug for a certain time span command --- client/anonymize/anonymize.go | 212 +++++++++++++ client/anonymize/anonymize_test.go | 223 ++++++++++++++ client/cmd/debug.go | 248 +++++++++++++++ client/cmd/root.go | 19 ++ client/cmd/route.go | 13 - client/cmd/status.go | 166 +++++++++- client/cmd/status_test.go | 55 +++- client/internal/engine_watcher.go | 1 + client/proto/daemon.pb.go | 471 +++++++++++++++++++++++++---- client/proto/daemon.proto | 33 ++ client/proto/daemon_grpc.pb.go | 76 +++++ client/server/debug.go | 175 +++++++++++ 12 files changed, 1588 insertions(+), 104 deletions(-) create mode 100644 client/anonymize/anonymize.go create mode 100644 client/anonymize/anonymize_test.go create mode 100644 client/cmd/debug.go create mode 100644 client/internal/engine_watcher.go create mode 100644 client/server/debug.go diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go new file mode 100644 index 00000000000..acbd0441e1c --- /dev/null +++ b/client/anonymize/anonymize.go @@ -0,0 +1,212 @@ +package anonymize + +import ( + "crypto/rand" + "fmt" + "math/big" + "net" + "net/netip" + "net/url" + "regexp" + "slices" + "strings" +) + +type Anonymizer struct { + ipAnonymizer map[netip.Addr]netip.Addr + domainAnonymizer map[string]string + currentAnonIPv4 netip.Addr + currentAnonIPv6 netip.Addr + startAnonIPv4 netip.Addr + startAnonIPv6 netip.Addr +} + +func DefaultAddresses() (netip.Addr, netip.Addr) { + // 192.51.100.0, 100:: + return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) +} + +func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { + return &Anonymizer{ + ipAnonymizer: map[netip.Addr]netip.Addr{}, + domainAnonymizer: map[string]string{}, + currentAnonIPv4: startIPv4, + currentAnonIPv6: startIPv6, + startAnonIPv4: startIPv4, + startAnonIPv6: startIPv6, + } +} + +func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr { + if ip.IsLoopback() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + ip.IsInterfaceLocalMulticast() || + ip.IsPrivate() || + ip.IsUnspecified() || + ip.IsMulticast() || + isWellKnown(ip) || + a.isInAnonymizedRange(ip) { + + return ip + } + + if _, ok := a.ipAnonymizer[ip]; !ok { + if ip.Is4() { + a.ipAnonymizer[ip] = a.currentAnonIPv4 + a.currentAnonIPv4 = a.currentAnonIPv4.Next() + } else { + a.ipAnonymizer[ip] = a.currentAnonIPv6 + a.currentAnonIPv6 = a.currentAnonIPv6.Next() + } + } + return a.ipAnonymizer[ip] +} + +// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs +func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { + if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { + return true + } else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 { + return true + } + return false +} + +func (a *Anonymizer) AnonymizeIPString(ip string) string { + addr, err := netip.ParseAddr(ip) + if err != nil { + return ip + } + + return a.AnonymizeIP(addr).String() +} + +func (a *Anonymizer) AnonymizeDomain(domain string) string { + if strings.HasSuffix(domain, "netbird.io") || + strings.HasSuffix(domain, "netbird.selfhosted") || + strings.HasSuffix(domain, "netbird.cloud") || + strings.HasSuffix(domain, "netbird.stage") || + strings.HasSuffix(domain, ".domain") { + return domain + } + + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + + baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1] + + anonymized, ok := a.domainAnonymizer[baseDomain] + if !ok { + anonymizedBase := "anon-" + generateRandomString(5) + ".domain" + a.domainAnonymizer[baseDomain] = anonymizedBase + anonymized = anonymizedBase + } + + return strings.Replace(domain, baseDomain, anonymized, 1) +} + +func (a *Anonymizer) AnonymizeURI(uri string) string { + u, err := url.Parse(uri) + if err != nil { + return uri + } + + var anonymizedHost string + if u.Opaque != "" { + host, port, err := net.SplitHostPort(u.Opaque) + if err == nil { + anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) + } else { + anonymizedHost = a.AnonymizeDomain(u.Opaque) + } + u.Opaque = anonymizedHost + } else if u.Host != "" { + host, port, err := net.SplitHostPort(u.Host) + if err == nil { + anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) + } else { + anonymizedHost = a.AnonymizeDomain(u.Host) + } + u.Host = anonymizedHost + } + return u.String() +} + +func (a *Anonymizer) AnonymizeString(str string) string { + ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`) + ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`) + + str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString) + str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString) + + for domain, anonDomain := range a.domainAnonymizer { + str = strings.ReplaceAll(str, domain, anonDomain) + } + + str = a.AnonymizeSchemeURI(str) + str = a.AnonymizeDNSLogLine(str) + + return str +} + +// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. +func (a *Anonymizer) AnonymizeSchemeURI(text string) string { + re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) + + return re.ReplaceAllStringFunc(text, a.AnonymizeURI) +} + +// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string. +func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { + domainPattern := `dns\.Question{Name:"([^"]+)",` + domainRegex := regexp.MustCompile(domainPattern) + + return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string { + parts := strings.Split(match, `"`) + if len(parts) >= 2 { + domain := parts[1] + if strings.HasSuffix(domain, ".domain") { + return match + } + randomDomain := generateRandomString(10) + ".domain" + return strings.Replace(match, domain, randomDomain, 1) + } + return match + }) +} + +func isWellKnown(addr netip.Addr) bool { + wellKnown := []string{ + "8.8.8.8", "8.8.4.4", // Google DNS IPv4 + "2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6 + "1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4 + "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 + "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 + "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 + } + + if slices.Contains(wellKnown, addr.String()) { + return true + } + + cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0}) + cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10) + + return cgnatRange.Contains(addr) +} + +func generateRandomString(length int) string { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + result := make([]byte, length) + for i := range result { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + if err != nil { + continue + } + result[i] = letters[num.Int64()] + } + return string(result) +} diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go new file mode 100644 index 00000000000..e660749ec5d --- /dev/null +++ b/client/anonymize/anonymize_test.go @@ -0,0 +1,223 @@ +package anonymize_test + +import ( + "net/netip" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/anonymize" +) + +func TestAnonymizeIP(t *testing.T) { + startIPv4 := netip.MustParseAddr("198.51.100.0") + startIPv6 := netip.MustParseAddr("100::") + anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6) + + tests := []struct { + name string + ip string + expect string + }{ + {"Well known", "8.8.8.8", "8.8.8.8"}, + {"First Public IPv4", "1.2.3.4", "198.51.100.0"}, + {"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, + {"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, + {"Private IPv4", "192.168.1.1", "192.168.1.1"}, + {"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, + {"Second Public IPv6", "a::b", "100::1"}, + {"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, + {"Private IPv6", "fe80::1", "fe80::1"}, + {"In Range IPv4", "198.51.100.2", "198.51.100.2"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ip := netip.MustParseAddr(tc.ip) + anonymizedIP := anonymizer.AnonymizeIP(ip) + if anonymizedIP.String() != tc.expect { + t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP) + } + }) + } +} + +func TestAnonymizeDNSLogLine(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) + testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}` + + result := anonymizer.AnonymizeDNSLogLine(testLog) + require.NotEqual(t, testLog, result) + assert.NotContains(t, result, "example.com") +} + +func TestAnonymizeDomain(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) + tests := []struct { + name string + domain string + expectPattern string + shouldAnonymize bool + }{ + { + "General Domain", + "example.com", + `^anon-[a-zA-Z0-9]+\.domain$`, + true, + }, + { + "Subdomain", + "sub.example.com", + `^sub\.anon-[a-zA-Z0-9]+\.domain$`, + true, + }, + { + "Protected Domain", + "netbird.io", + `^netbird\.io$`, + false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeDomain(tc.domain) + if tc.shouldAnonymize { + assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern") + assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result") + } else { + assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized") + } + }) + } +} + +func TestAnonymizeURI(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) + tests := []struct { + name string + uri string + regex string + }{ + { + "HTTP URI with Port", + "http://example.com:80/path", + `^http://anon-[a-zA-Z0-9]+\.domain:80/path$`, + }, + { + "HTTP URI without Port", + "http://example.com/path", + `^http://anon-[a-zA-Z0-9]+\.domain/path$`, + }, + { + "Opaque URI with Port", + "stun:example.com:80?transport=udp", + `^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`, + }, + { + "Opaque URI without Port", + "stun:example.com?transport=udp", + `^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeURI(tc.uri) + assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern") + require.NotContains(t, result, "example.com", "Original domain should not be present") + }) + } +} + +func TestAnonymizeSchemeURI(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) + tests := []struct { + name string + input string + expect string + }{ + {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, + {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, + {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeSchemeURI(tc.input) + assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern") + require.NotContains(t, result, "example.com", "Original domain should not be present") + }) + } +} + +func TestAnonymizString_MemorizedDomain(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) + domain := "example.com" + anonymizedDomain := anonymizer.AnonymizeDomain(domain) + + sampleString := "This is a test string including the domain example.com which should be anonymized." + + firstPassResult := anonymizer.AnonymizeString(sampleString) + secondPassResult := anonymizer.AnonymizeString(firstPassResult) + + assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass") + assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output") + + assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string") +} + +func TestAnonymizeString_DoubleURI(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) + domain := "example.com" + anonymizedDomain := anonymizer.AnonymizeDomain(domain) + + sampleString := "Check out our site at https://example.com for more info." + + firstPassResult := anonymizer.AnonymizeString(sampleString) + secondPassResult := anonymizer.AnonymizeString(firstPassResult) + + assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass") + assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output") + + assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI") +} + +func TestAnonymizeString_IPAddresses(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + tests := []struct { + name string + input string + expect string + }{ + { + name: "IPv4 Address", + input: "Error occurred at IP 122.138.1.1", + expect: "Error occurred at IP 198.51.100.0", + }, + { + name: "IPv6 Address", + input: "Access attempted from 2001:db8::ff00:42", + expect: "Access attempted from 100::", + }, + { + name: "IPv6 Address with Port", + input: "Access attempted from [2001:db8::ff00:42]:8080", + expect: "Access attempted from [100::]:8080", + }, + { + name: "Both IPv4 and IPv6", + input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", + expect: "IPv4: 198.51.100.1 and IPv6: 100::1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeString(tc.input) + assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly") + }) + } +} diff --git a/client/cmd/debug.go b/client/cmd/debug.go new file mode 100644 index 00000000000..4deff11a6ff --- /dev/null +++ b/client/cmd/debug.go @@ -0,0 +1,248 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var debugCmd = &cobra.Command{ + Use: "debug", + Short: "Debugging commands", + Long: "Provides commands for debugging and logging control within the Netbird daemon.", +} + +var debugBundleCmd = &cobra.Command{ + Use: "bundle", + Example: " netbird debug bundle", + Short: "Create a debug bundle", + Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.", + RunE: debugBundle, +} + +var logCmd = &cobra.Command{ + Use: "log", + Short: "Manage logging for the Netbird daemon", + Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`, +} + +var logLevelCmd = &cobra.Command{ + Use: "level ", + Short: "Set the logging level for this session", + Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart. +Available log levels are: + panic: for panic level, highest level of severity + fatal: for fatal level errors that cause the program to exit + error: for error conditions + warn: for warning conditions + info: for informational messages + debug: for debug-level messages + trace: for trace-level messages, which include more fine-grained information than debug`, + Args: cobra.ExactArgs(1), + RunE: setLogLevel, +} + +var forCmd = &cobra.Command{ + Use: "for