From e40a29ba17749e0276510149b8da7b3b71550ff4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:51:42 +0100 Subject: [PATCH 01/19] [client] Add support for state manager on iOS (#2996) --- client/internal/config.go | 21 +++-- client/internal/connect.go | 2 + client/internal/engine.go | 11 +++ client/internal/mobile_dependency.go | 1 + .../routeselector/routeselector_test.go | 85 +++++++++++++++++++ client/ios/NetBirdSDK/client.go | 9 +- client/ios/NetBirdSDK/preferences.go | 5 +- client/ios/NetBirdSDK/preferences_test.go | 11 ++- 8 files changed, 130 insertions(+), 15 deletions(-) diff --git a/client/internal/config.go b/client/internal/config.go index ce87835cdd5..998690ef159 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -46,6 +46,7 @@ type ConfigInput struct { ManagementURL string AdminURL string ConfigPath string + StateFilePath string PreSharedKey *string ServerSSHAllowed *bool NATExternalIPs []string @@ -105,10 +106,10 @@ type Config struct { // DNSRouteInterval is the interval in which the DNS routes are updated DNSRouteInterval time.Duration - //Path to a certificate used for mTLS authentication + // Path to a certificate used for mTLS authentication ClientCertPath string - //Path to corresponding private key of ClientCertPath + // Path to corresponding private key of ClientCertPath ClientCertKeyPath string ClientCertKeyPair *tls.Certificate `json:"-"` @@ -116,7 +117,7 @@ type Config struct { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func ReadConfig(configPath string) (*Config, error) { - if configFileIsExists(configPath) { + if fileExists(configPath) { err := util.EnforcePermission(configPath) if err != nil { log.Errorf("failed to enforce permission on config dir: %v", err) @@ -149,7 +150,7 @@ func ReadConfig(configPath string) (*Config, error) { // UpdateConfig update existing configuration according to input configuration and return with the configuration func UpdateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { return nil, status.Errorf(codes.NotFound, "config file doesn't exist") } @@ -158,7 +159,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) { // UpdateOrCreateConfig reads existing config or generates a new one func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { log.Infof("generating new config %s", input.ConfigPath) cfg, err := createNewConfig(input) if err != nil { @@ -472,11 +473,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool { return false } -func configFileIsExists(path string) bool { +func fileExists(path string) bool { _, err := os.Stat(path) return !os.IsNotExist(err) } +func createFile(path string) error { + file, err := os.Create(path) + if err != nil { + return err + } + return file.Close() +} + // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. diff --git a/client/internal/connect.go b/client/internal/connect.go index 4848b1c1143..782984e2796 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -91,6 +91,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. debug.SetGCPercent(5) @@ -99,6 +100,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, nil) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 782bb48bb1b..63caec02a9c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -243,6 +243,17 @@ func NewEngineWithProbes( probes: probes, checks: checks, } + if runtime.GOOS == "ios" { + if !fileExists(mobileDep.StateFilePath) { + err := createFile(mobileDep.StateFilePath) + if err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + + engine.stateManager = statemanager.New(mobileDep.StateFilePath) + } if path := statemanager.GetDefaultStatePath(); path != "" { engine.stateManager = statemanager.New(path) } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2b0c92cc690..4ac0fc141eb 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -19,4 +19,5 @@ type MobileDependency struct { // iOS only DnsManager dns.IosDnsManager FileDescriptor int32 + StateFilePath string } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 7df433f9264..b1671f25464 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) { "route2|192.168.0.0/16": {}, }, filtered) } + +func TestRouteSelector_NewRoutesBehavior(t *testing.T) { + initialRoutes := []route.NetID{"route1", "route2", "route3"} + newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} + + tests := []struct { + name string + initialState func(rs *routeselector.RouteSelector) error // Setup initial state + wantNewSelected []route.NetID // Expected selected routes after new routes appear + }{ + { + name: "New routes with initial selectAll state", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + // When selectAll is true, all routes including new ones should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, + }, + { + name: "New routes after specific selection", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) + }, + // When specific routes were selected, new routes should remain unselected + wantNewSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "New routes after deselect all", + initialState: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + // After deselect all, new routes should remain unselected + wantNewSelected: []route.NetID{}, + }, + { + name: "New routes after deselecting specific routes", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) + }, + // After deselecting specific routes, new routes should remain unselected + wantNewSelected: []route.NetID{"route2", "route3"}, + }, + { + name: "New routes after selecting with append", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) + }, + // When routes were appended, new routes should remain unselected + wantNewSelected: []route.NetID{"route1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + // Setup initial state + err := tt.initialState(rs) + require.NoError(t, err) + + // Verify selection state with new routes + for _, id := range newRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id), + "Route %s selection state incorrect", id) + } + + // Additional verification using FilterSelected + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + "route5|192.168.1.0/24": {}, + } + + filtered := rs.FilterSelected(routes) + expectedLen := len(tt.wantNewSelected) + assert.Equal(t, expectedLen, len(filtered), + "FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen) + }) + } +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 9d65bdbe080..6f501e0c636 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -59,6 +59,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string + stateFile string recorder *peer.Status ctxCancel context.CancelFunc ctxCancelLock *sync.Mutex @@ -73,9 +74,10 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { +func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { return &Client{ cfgFile: cfgFile, + stateFile: stateFile, deviceName: deviceName, osName: osName, osVersion: osVersion, @@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error { log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: c.cfgFile, + StateFilePath: c.stateFile, }) if err != nil { return err @@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index b7814667959..5a0abd9a72e 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -10,9 +10,10 @@ type Preferences struct { } // NewPreferences create new Preferences instance -func NewPreferences(configPath string) *Preferences { +func NewPreferences(configPath string, stateFilePath string) *Preferences { ci := internal.ConfigInput{ - ConfigPath: configPath, + ConfigPath: configPath, + StateFilePath: stateFilePath, } return &Preferences{ci} } diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index aa6a475aeab..7e5325a0036 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -9,7 +9,8 @@ import ( func TestPreferences_DefaultValues(t *testing.T) { cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) defaultVar, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read default value: %s", err) @@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) { func TestPreferences_ReadUncommitedValues(t *testing.T) { exampleString := "exampleString" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleString) resp, err := p.GetAdminURL() @@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) { exampleURL := "https://myurl.com:443" examplePresharedKey := "topsecret" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleURL) p.SetManagementURL(exampleURL) @@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) { t.Fatalf("failed to save changes: %s", err) } - p = NewPreferences(cfgFile) + p = NewPreferences(cfgFile, stateFile) resp, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read admin url: %s", err) From 2147bf75eb25a4d63ea428484a1dbfac5e2fbf82 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 9 Dec 2024 17:10:31 +0100 Subject: [PATCH 02/19] [client] Add peer conn init limit (#3001) Limit the peer connection initialization to 200 peers at the same time --- client/internal/engine.go | 6 +- client/internal/peer/conn.go | 9 ++- client/internal/peer/conn_test.go | 9 +-- util/semaphore-group/semaphore_group.go | 48 ++++++++++++++ util/semaphore-group/semaphore_group_test.go | 66 ++++++++++++++++++++ 5 files changed, 131 insertions(+), 7 deletions(-) create mode 100644 util/semaphore-group/semaphore_group.go create mode 100644 util/semaphore-group/semaphore_group_test.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 63caec02a9c..34219def185 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -39,6 +39,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -62,6 +63,7 @@ import ( const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms + connInitLimit = 200 ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -177,6 +179,7 @@ type Engine struct { // Network map persistence persistNetworkMap bool latestNetworkMap *mgmProto.NetworkMap + connSemaphore *semaphoregroup.SemaphoreGroup } // Peer is an instance of the Connection Peer @@ -242,6 +245,7 @@ func NewEngineWithProbes( statusRecorder: statusRecorder, probes: probes, checks: checks, + connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), } if runtime.GOOS == "ios" { if !fileExists(mobileDep.StateFilePath) { @@ -1051,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore) if err != nil { return nil, err } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 3a698a82a7c..5c2e2cb60b6 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -23,6 +23,7 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) type ConnPriority int @@ -104,12 +105,13 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - guard *guard.Guard + guard *guard.Guard + semaphore *semaphoregroup.SemaphoreGroup } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) @@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu allowedIP: allowedIP, statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), + semaphore: semaphore, } rFns := WorkerRelayCallbacks{ @@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. func (conn *Conn) Open() { + conn.semaphore.Add(conn.ctx) conn.log.Debugf("open connection to peer") conn.mu.Lock() @@ -191,6 +195,7 @@ func (conn *Conn) Open() { } func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + defer conn.semaphore.Done(conn.ctx) conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 039952588d8..b3e9d5b60dc 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) var connConf = ConnConfig{ @@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { func TestConn_GetKey(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { } func TestConn_Status(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } diff --git a/util/semaphore-group/semaphore_group.go b/util/semaphore-group/semaphore_group.go new file mode 100644 index 00000000000..ad74e1bfc81 --- /dev/null +++ b/util/semaphore-group/semaphore_group.go @@ -0,0 +1,48 @@ +package semaphoregroup + +import ( + "context" + "sync" +) + +// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore. +type SemaphoreGroup struct { + waitGroup sync.WaitGroup + semaphore chan struct{} +} + +// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit. +func NewSemaphoreGroup(limit int) *SemaphoreGroup { + return &SemaphoreGroup{ + semaphore: make(chan struct{}, limit), + } +} + +// Add increments the internal WaitGroup counter and acquires a semaphore slot. +func (sg *SemaphoreGroup) Add(ctx context.Context) { + sg.waitGroup.Add(1) + + // Acquire semaphore slot + select { + case <-ctx.Done(): + return + case sg.semaphore <- struct{}{}: + } +} + +// Done decrements the internal WaitGroup counter and releases a semaphore slot. +func (sg *SemaphoreGroup) Done(ctx context.Context) { + sg.waitGroup.Done() + + // Release semaphore slot + select { + case <-ctx.Done(): + return + case <-sg.semaphore: + } +} + +// Wait waits until the internal WaitGroup counter is zero. +func (sg *SemaphoreGroup) Wait() { + sg.waitGroup.Wait() +} diff --git a/util/semaphore-group/semaphore_group_test.go b/util/semaphore-group/semaphore_group_test.go new file mode 100644 index 00000000000..d4491cf772e --- /dev/null +++ b/util/semaphore-group/semaphore_group_test.go @@ -0,0 +1,66 @@ +package semaphoregroup + +import ( + "context" + "testing" + "time" +) + +func TestSemaphoreGroup(t *testing.T) { + semGroup := NewSemaphoreGroup(2) + + for i := 0; i < 5; i++ { + semGroup.Add(context.Background()) + go func(id int) { + defer semGroup.Done(context.Background()) + + got := len(semGroup.semaphore) + if got == 0 { + t.Errorf("Expected semaphore length > 0 , got 0") + } + + time.Sleep(time.Millisecond) + t.Logf("Goroutine %d is running\n", id) + }(i) + } + + semGroup.Wait() + + want := 0 + got := len(semGroup.semaphore) + if got != want { + t.Errorf("Expected semaphore length %d, got %d", want, got) + } +} + +func TestSemaphoreGroupContext(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + semGroup.Add(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + rChan := make(chan struct{}) + + go func() { + semGroup.Add(ctx) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Adding to semaphore group should not block when context is not done") + } + + semGroup.Done(context.Background()) + + ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancelDone) + go func() { + semGroup.Done(ctxDone) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Releasing from semaphore group should not block when context is not done") + } +} From 97bb74f824786c20f78f3131bb8ed3e9ece26782 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 9 Dec 2024 18:40:06 +0100 Subject: [PATCH 03/19] Remove peer login log (#3005) Signed-off-by: bcmmbaga --- management/server/peer.go | 1 - 1 file changed, 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index 761aa39a2da..ba211be9694 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -740,7 +740,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // it means that the client has already checked if it needs login and had been through the SSO flow // so, we can skip this check and directly proceed with the login if login.UserID == "" { - log.Info("Peer needs login") err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) if err != nil { return nil, nil, nil, err From 6142828a9ce8ca72d25a1ded4d3ddc9605566f58 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:59:25 +0100 Subject: [PATCH 04/19] [management] restructure api files (#3013) --- management/cmd/management.go | 3 +- management/server/http/configs/auth.go | 9 ++ management/server/http/handler.go | 151 +++--------------- .../accounts}/accounts_handler.go | 34 ++-- .../accounts}/accounts_handler_test.go | 12 +- .../dns}/dns_settings_handler.go | 33 ++-- .../dns}/dns_settings_handler_test.go | 10 +- .../{ => handlers/dns}/nameservers_handler.go | 44 +++-- .../dns}/nameservers_handler_test.go | 14 +- .../{ => handlers/events}/events_handler.go | 25 +-- .../events}/events_handler_test.go | 10 +- .../{ => handlers/groups}/groups_handler.go | 47 +++--- .../groups}/groups_handler_test.go | 18 +-- .../{ => handlers/peers}/peers_handler.go | 39 +++-- .../peers}/peers_handler_test.go | 6 +- .../policies}/geolocation_handler_test.go | 16 +- .../policies}/geolocations_handler.go | 29 ++-- .../policies}/policies_handler.go | 49 +++--- .../policies}/policies_handler_test.go | 16 +- .../policies}/posture_checks_handler.go | 47 +++--- .../policies}/posture_checks_handler_test.go | 32 ++-- .../{ => handlers/routes}/routes_handler.go | 46 +++--- .../routes}/routes_handler_test.go | 16 +- .../setup_keys}/setupkeys_handler.go | 42 +++-- .../setup_keys}/setupkeys_handler_test.go | 17 +- .../http/{ => handlers/users}/pat_handler.go | 39 +++-- .../{ => handlers/users}/pat_handler_test.go | 14 +- .../{ => handlers/users}/users_handler.go | 47 +++--- .../users}/users_handler_test.go | 18 +-- management/server/http/util/util.go | 4 + 30 files changed, 461 insertions(+), 426 deletions(-) create mode 100644 management/server/http/configs/auth.go rename management/server/http/{ => handlers/accounts}/accounts_handler.go (79%) rename management/server/http/{ => handlers/accounts}/accounts_handler_test.go (97%) rename management/server/http/{ => handlers/dns}/dns_settings_handler.go (62%) rename management/server/http/{ => handlers/dns}/dns_settings_handler_test.go (94%) rename management/server/http/{ => handlers/dns}/nameservers_handler.go (77%) rename management/server/http/{ => handlers/dns}/nameservers_handler_test.go (95%) rename management/server/http/{ => handlers/events}/events_handler.go (79%) rename management/server/http/{ => handlers/events}/events_handler_test.go (97%) rename management/server/http/{ => handlers/groups}/groups_handler.go (81%) rename management/server/http/{ => handlers/groups}/groups_handler_test.go (95%) rename management/server/http/{ => handlers/peers}/peers_handler.go (88%) rename management/server/http/{ => handlers/peers}/peers_handler_test.go (99%) rename management/server/http/{ => handlers/policies}/geolocation_handler_test.go (94%) rename management/server/http/{ => handlers/policies}/geolocations_handler.go (72%) rename management/server/http/{ => handlers/policies}/policies_handler.go (84%) rename management/server/http/{ => handlers/policies}/policies_handler_test.go (95%) rename management/server/http/{ => handlers/policies}/posture_checks_handler.go (70%) rename management/server/http/{ => handlers/policies}/posture_checks_handler_test.go (96%) rename management/server/http/{ => handlers/routes}/routes_handler.go (85%) rename management/server/http/{ => handlers/routes}/routes_handler_test.go (98%) rename management/server/http/{ => handlers/setup_keys}/setupkeys_handler.go (78%) rename management/server/http/{ => handlers/setup_keys}/setupkeys_handler_test.go (95%) rename management/server/http/{ => handlers/users}/pat_handler.go (75%) rename management/server/http/{ => handlers/users}/pat_handler_test.go (96%) rename management/server/http/{ => handlers/users}/users_handler.go (80%) rename management/server/http/{ => handlers/users}/users_handler_test.go (97%) diff --git a/management/cmd/management.go b/management/cmd/management.go index 719d1a78c1a..bfa158c5b53 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,6 +42,7 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" @@ -257,7 +258,7 @@ var ( return fmt.Errorf("failed creating JWT validator: %v", err) } - httpAPIAuthCfg := httpapi.AuthCfg{ + httpAPIAuthCfg := configs.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, diff --git a/management/server/http/configs/auth.go b/management/server/http/configs/auth.go new file mode 100644 index 00000000000..aa91fa55b05 --- /dev/null +++ b/management/server/http/configs/auth.go @@ -0,0 +1,9 @@ +package configs + +// AuthCfg contains parameters for authentication middleware +type AuthCfg struct { + Issuer string + Audience string + UserIDClaim string + KeysLocation string +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c3928bff681..373aa4dd7c0 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,16 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/handlers/accounts" + "github.com/netbirdio/netbird/management/server/http/handlers/dns" + "github.com/netbirdio/netbird/management/server/http/handlers/events" + "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/peers" + "github.com/netbirdio/netbird/management/server/http/handlers/policies" + "github.com/netbirdio/netbird/management/server/http/handlers/routes" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -20,27 +30,15 @@ import ( const apiPrefix = "/api" -// AuthCfg contains parameters for authentication middleware -type AuthCfg struct { - Issuer string - Audience string - UserIDClaim string - KeysLocation string -} - type apiHandler struct { Router *mux.Router AccountManager s.AccountManager geolocationManager *geolocation.Geolocation - AuthCfg AuthCfg -} - -// EmptyObject is an empty struct used to return empty JSON object -type emptyObject struct { + AuthCfg configs.AuthCfg } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -86,122 +84,15 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa return nil, fmt.Errorf("register integrations endpoints: %w", err) } - api.addAccountsEndpoint() - api.addPeersEndpoint() - api.addUsersEndpoint() - api.addUsersTokensEndpoint() - api.addSetupKeysEndpoint() - api.addPoliciesEndpoint() - api.addGroupsEndpoint() - api.addRoutesEndpoint() - api.addDNSNameserversEndpoint() - api.addDNSSettingEndpoint() - api.addEventsEndpoint() - api.addPostureCheckEndpoint() - api.addLocationsEndpoint() + accounts.AddEndpoints(api.AccountManager, authCfg, router) + peers.AddEndpoints(api.AccountManager, authCfg, router) + users.AddEndpoints(api.AccountManager, authCfg, router) + setup_keys.AddEndpoints(api.AccountManager, authCfg, router) + policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) + groups.AddEndpoints(api.AccountManager, authCfg, router) + routes.AddEndpoints(api.AccountManager, authCfg, router) + dns.AddEndpoints(api.AccountManager, authCfg, router) + events.AddEndpoints(api.AccountManager, authCfg, router) return rootRouter, nil } - -func (apiHandler *apiHandler) addAccountsEndpoint() { - accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPeersEndpoint() { - peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). - Methods("GET", "PUT", "DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersEndpoint() { - userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersTokensEndpoint() { - tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addSetupKeysEndpoint() { - keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addPoliciesEndpoint() { - policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addGroupsEndpoint() { - groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addRoutesEndpoint() { - routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSNameserversEndpoint() { - nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSSettingEndpoint() { - dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") -} - -func (apiHandler *apiHandler) addEventsEndpoint() { - eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPostureCheckEndpoint() { - postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addLocationsEndpoint() { - locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS") -} diff --git a/management/server/http/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go similarity index 79% rename from management/server/http/accounts_handler.go rename to management/server/http/handlers/accounts/accounts_handler.go index 4baf9c6925f..c952077777e 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "encoding/json" @@ -10,20 +10,28 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// AccountsHandler is a handler that handles the server.Account HTTP endpoints -type AccountsHandler struct { +// handler is a handler that handles the server.Account HTTP endpoints +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewAccountsHandler creates a new AccountsHandler HTTP handler -func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { - return &AccountsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + accountsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") +} + +// newHandler creates a new handler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +40,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { +// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -51,8 +59,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } -// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { +// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -111,8 +119,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteAccount is a HTTP DELETE handler to delete an account -func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { +// deleteAccount is a HTTP DELETE handler to delete an account +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -127,7 +135,7 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toAccountResponse(accountID string, settings *server.Settings) *api.Account { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go similarity index 97% rename from management/server/http/accounts_handler_test.go rename to management/server/http/handlers/accounts/accounts_handler_test.go index cacb3d43010..9d7e8a84ddc 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "bytes" @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { - return &AccountsHandler{ +func initAccountsTestData(account *server.Account, admin *server.User) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil @@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllAccounts OK", + name: "getAllAccounts OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/accounts", @@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go similarity index 62% rename from management/server/http/dns_settings_handler.go rename to management/server/http/handlers/dns/dns_settings_handler.go index 13c2101a755..7dd8c1fc1aa 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -1,26 +1,39 @@ -package http +package dns import ( "encoding/json" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// DNSSettingsHandler is a handler that returns the DNS settings of the account -type DNSSettingsHandler struct { +// dnsSettingsHandler is a handler that returns the DNS settings of the account +type dnsSettingsHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler -func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { - return &DNSSettingsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + addDNSSettingEndpoint(accountManager, authCfg, router) + addDNSNameserversEndpoint(accountManager, authCfg, router) +} + +func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) + router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") +} + +// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler +func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -29,8 +42,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetDNSSettings returns the DNS settings for the account -func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { +// getDNSSettings returns the DNS settings for the account +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -52,8 +65,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque util.WriteJSONObject(r.Context(), w, apiDNSSettings) } -// UpdateDNSSettings handles update to DNS settings of an account -func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { +// updateDNSSettings handles update to DNS settings of an account +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go similarity index 94% rename from management/server/http/dns_settings_handler_test.go rename to management/server/http/handlers/dns/dns_settings_handler_test.go index 8baea7b1538..a64e3fd8356 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -40,8 +40,8 @@ var testingDNSSettingsAccount = &server.Account{ DNSSettings: baseExistingDNSSettings, } -func initDNSSettingsTestData() *DNSSettingsHandler { - return &DNSSettingsHandler{ +func initDNSSettingsTestData() *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil @@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") + router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go similarity index 77% rename from management/server/http/nameservers_handler.go rename to management/server/http/handlers/dns/nameservers_handler.go index e7a2bc2ae8a..09047e231af 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -1,4 +1,4 @@ -package http +package dns import ( "encoding/json" @@ -11,20 +11,30 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// NameserversHandler is the nameserver group handler of the account -type NameserversHandler struct { +// nameserversHandler is the nameserver group handler of the account +type nameserversHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewNameserversHandler returns a new instance of NameserversHandler handler -func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler { - return &NameserversHandler{ +func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager, authCfg) + router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") +} + +// newNameserversHandler returns a new instance of nameserversHandler handler +func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { + return &nameserversHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetAllNameservers returns the list of nameserver groups for the account -func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { +// getAllNameservers returns the list of nameserver groups for the account +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, apiNameservers) } -// CreateNameserverGroup handles nameserver group creation request -func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// createNameserverGroup handles nameserver group creation request +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// UpdateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// updateNameserverGroup handles update to a nameserver group identified by a given ID +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteNameserverGroup handles nameserver group deletion request -func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { +// deleteNameserverGroup handles nameserver group deletion request +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetNameserverGroup handles a nameserver group Get request identified by ID -func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { +// getNameserverGroup handles a nameserver group Get request identified by ID +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go similarity index 95% rename from management/server/http/nameservers_handler_test.go rename to management/server/http/handlers/dns/nameservers_handler_test.go index 98c2e402d84..c6561e4d826 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ Enabled: true, } -func initNameserversTestData() *NameserversHandler { - return &NameserversHandler{ +func initNameserversTestData() *nameserversHandler { + return &nameserversHandler{ accountManager: &mock_server.MockAccountManager{ GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { @@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") + router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/events_handler.go b/management/server/http/handlers/events/events_handler.go similarity index 79% rename from management/server/http/events_handler.go rename to management/server/http/handlers/events/events_handler.go index ee0c63f2822..62da5953524 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -1,28 +1,35 @@ -package http +package events import ( "context" "fmt" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// EventsHandler HTTP handler -type EventsHandler struct { +// handler HTTP handler +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewEventsHandler creates a new EventsHandler HTTP handler -func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { - return &EventsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + eventsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") +} + +// newHandler creates a new events handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev } } -// GetAllEvents list of the given account -func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { +// getAllEvents list of the given account +func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { +func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { diff --git a/management/server/http/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go similarity index 97% rename from management/server/http/events_handler_test.go rename to management/server/http/handlers/events/events_handler_test.go index e525cf2ee01..6af2e534646 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -1,4 +1,4 @@ -package http +package events import ( "context" @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { - return &EventsHandler{ +func initEventsTestData(account string, events ...*activity.Event) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { @@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllEvents OK", + name: "getAllEvents OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/events/", @@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") + router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go similarity index 81% rename from management/server/http/groups_handler.go rename to management/server/http/handlers/groups/groups_handler.go index f369d1a0091..e60529cec94 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -1,13 +1,15 @@ -package http +package groups import ( "encoding/json" "net/http" "github.com/gorilla/mux" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/http/configs" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -16,15 +18,24 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -// GroupsHandler is a handler that returns groups of the account -type GroupsHandler struct { +// handler is a handler that returns groups of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGroupsHandler creates a new GroupsHandler HTTP handler -func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler { - return &GroupsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + groupsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new groups handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr } } -// GetAllGroups list for the account -func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { +// getAllGroups list for the account +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, groupsResponse) } -// UpdateGroup handles update to a group identified by a given ID -func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { +// updateGroup handles update to a group identified by a given ID +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +152,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// CreateGroup handles group creation request -func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { +// createGroup handles group creation request +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -189,8 +200,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// DeleteGroup handles group deletion request -func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { +// deleteGroup handles group deletion request +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -215,11 +226,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetGroup returns a group -func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { +// getGroup returns a group +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go similarity index 95% rename from management/server/http/groups_handler_test.go rename to management/server/http/handlers/groups/groups_handler_test.go index 7f3c81f1872..089c1a40f0a 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -1,4 +1,4 @@ -package http +package groups import ( "bytes" @@ -31,8 +31,8 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { - return &GroupsHandler{ +func initGroupTestData(initGroups ...*nbgroup.Group) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { @@ -106,14 +106,14 @@ func TestGetGroup(t *testing.T) { requestBody io.Reader }{ { - name: "GetGroup OK", + name: "getGroup OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/groups/idofthegroup", expectedStatus: http.StatusOK, }, { - name: "GetGroup not found", + name: "getGroup not found", requestType: http.MethodGet, requestPath: "/api/groups/notexists", expectedStatus: http.StatusNotFound, @@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT") + router.HandleFunc("/api/groups", p.createGroup).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go similarity index 88% rename from management/server/http/peers_handler.go rename to management/server/http/handlers/peers/peers_handler.go index f5027cd7798..c53cbc038e9 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,4 +1,4 @@ -package http +package peers import ( "context" @@ -12,21 +12,30 @@ import ( "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) -// PeersHandler is a handler that returns peers of the account -type PeersHandler struct { +// Handler is a handler that returns peers of the account +type Handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPeersHandler creates a new PeersHandler HTTP handler -func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler { - return &PeersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + peersHandler := NewHandler(accountManager, authCfg) + router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). + Methods("GET", "PUT", "DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") +} + +// NewHandler creates a new peers Handler +func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { + return &Handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -35,7 +44,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } -func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { peerToReturn := peer.Copy() if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected @@ -48,7 +57,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -75,7 +84,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -120,18 +129,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) } -func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { +func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { log.WithContext(ctx).Errorf("failed to delete peer: %v", err) util.WriteError(ctx, err, w) return } - util.WriteJSONObject(ctx, w, emptyObject{}) + util.WriteJSONObject(ctx, w, util.EmptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { +func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -168,7 +177,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { } // GetAllPeers returns a list of all peers associated with a provided account -func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -219,7 +228,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { for _, peer := range respBody { _, ok := approvedPeersMap[peer.Id] if !ok { @@ -229,7 +238,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. -func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go similarity index 99% rename from management/server/http/peers_handler_test.go rename to management/server/http/handlers/peers/peers_handler_test.go index dd49c03b848..3e3e39deb60 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -1,4 +1,4 @@ -package http +package peers import ( "bytes" @@ -38,8 +38,8 @@ const ( userIDKey ctxKey = "user_id" ) -func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { - return &PeersHandler{ +func initTestMetaData(peers ...*nbpeer.Peer) *Handler { + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go similarity index 94% rename from management/server/http/geolocation_handler_test.go rename to management/server/http/handlers/policies/geolocation_handler_test.go index 19c916dd2e3..002b914efff 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "context" @@ -11,9 +11,9 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -21,12 +21,12 @@ import ( "github.com/netbirdio/netbird/util" ) -func initGeolocationTestData(t *testing.T) *GeolocationsHandler { +func initGeolocationTestData(t *testing.T) *geolocationsHandler { t.Helper() var ( - mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" - geonamesdbPath = "../testdata/geonames_20240305.db" + mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb" + geonamesdbPath = "../../../testdata/geonames_20240305.db" ) tempDir := t.TempDir() @@ -41,7 +41,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) - return &GeolocationsHandler{ + return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go similarity index 72% rename from management/server/http/geolocations_handler.go rename to management/server/http/handlers/policies/geolocations_handler.go index 418228abfe6..e5bf3e6952d 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "net/http" @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -18,16 +19,22 @@ var ( countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") ) -// GeolocationsHandler is a handler that returns locations. -type GeolocationsHandler struct { +// geolocationsHandler is a handler that returns locations. +type geolocationsHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGeolocationsHandlerHandler creates a new Geolocations handler -func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { - return &GeolocationsHandler{ +func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") +} + +// newGeolocationsHandlerHandler creates a new Geolocations handler +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { + return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca } } -// GetAllCountries retrieves a list of all countries -func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { +// getAllCountries retrieves a list of all countries +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req util.WriteJSONObject(r.Context(), w, countries) } -// GetCitiesByCountry retrieves a list of cities based on the given country code -func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { +// getCitiesByCountry retrieves a list of cities based on the given country code +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { +func (l *geolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go similarity index 84% rename from management/server/http/policies_handler.go rename to management/server/http/handlers/policies/policies_handler.go index eff9092d45e..a47f2e620f1 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -6,23 +6,36 @@ import ( "strconv" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/geolocation" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// Policies is a handler that returns policy of the account -type Policies struct { +// handler is a handler that returns policy of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPoliciesHandler creates a new Policies handler -func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies { - return &Policies{ +func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + policiesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) +} + +// newHandler creates a new policies handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllPolicies list for the account -func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { +// getAllPolicies list for the account +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -65,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, policies) } -// UpdatePolicy handles update to a policy identified by a given ID -func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { +// updatePolicy handles update to a policy identified by a given ID +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { h.savePolicy(w, r, accountID, userID, policyID) } -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { +// createPolicy handles policy creation request +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -103,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { } // savePolicy handles policy creation and update -func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { +func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -245,8 +258,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteJSONObject(r.Context(), w, resp) } -// DeletePolicy handles policy deletion request -func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { +// deletePolicy handles policy deletion request +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -266,11 +279,11 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetPolicy handles a group Get request identified by ID -func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { +// getPolicy handles a group Get request identified by ID +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go similarity index 95% rename from management/server/http/policies_handler_test.go rename to management/server/http/handlers/policies/policies_handler_test.go index f8a897eb27b..4b465a85a98 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -24,12 +24,12 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *Policies { +func initPoliciesTestData(policies ...*server.Policy) *handler { testPolicies := make(map[string]*server.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } - return &Policies{ + return &handler{ accountManager: &mock_server.MockAccountManager{ GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { policy, ok := testPolicies[policyID] @@ -91,14 +91,14 @@ func TestPoliciesGetPolicy(t *testing.T) { requestBody io.Reader }{ { - name: "GetPolicy OK", + name: "getPolicy OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/policies/idofthepolicy", expectedStatus: http.StatusOK, }, { - name: "GetPolicy not found", + name: "getPolicy not found", requestType: http.MethodGet, requestPath: "/api/policies/notexists", expectedStatus: http.StatusNotFound, @@ -121,7 +121,7 @@ func TestPoliciesGetPolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -269,8 +269,8 @@ func TestPoliciesWritePolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go similarity index 70% rename from management/server/http/posture_checks_handler.go rename to management/server/http/handlers/policies/posture_checks_handler.go index 2c820429278..44917605ba2 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -9,22 +9,33 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PostureChecksHandler is a handler that returns posture checks of the account. -type PostureChecksHandler struct { +// postureChecksHandler is a handler that returns posture checks of the account. +type postureChecksHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPostureChecksHandler creates a new PostureChecks handler -func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { - return &PostureChecksHandler{ +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + addLocationsEndpoint(accountManager, locationManager, authCfg, router) +} + +// newPostureChecksHandler creates a new PostureChecks handler +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { + return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa } } -// GetAllPostureChecks list for the account -func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { +// getAllPostureChecks list for the account +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, postureChecks) } -// UpdatePostureCheck handles update to a posture check identified by a given ID -func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { +// updatePostureCheck handles update to a posture check identified by a given ID +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, postureChecksID) } -// CreatePostureCheck handles posture check creation request -func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { +// createPostureCheck handles posture check creation request +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -94,8 +105,8 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, "") } -// GetPostureCheck handles a posture check Get request identified by ID -func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { +// getPostureCheck handles a posture check Get request identified by ID +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } -// DeletePostureCheck handles posture check deletion request -func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { +// deletePostureCheck handles posture check deletion request +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { +func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go similarity index 96% rename from management/server/http/posture_checks_handler_test.go rename to management/server/http/handlers/policies/posture_checks_handler_test.go index f400cec8154..e9a539e450a 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -25,13 +25,13 @@ import ( var berlin = "Berlin" var losAngeles = "Los Angeles" -func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler { +func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler { testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) for _, postureCheck := range postureChecks { testPostureChecks[postureCheck.ID] = postureCheck } - return &PostureChecksHandler{ + return &postureChecksHandler{ accountManager: &mock_server.MockAccountManager{ GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] @@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) { requestBody io.Reader }{ { - name: "GetPostureCheck NBVersion OK", + name: "getPostureCheck NBVersion OK", expectedBody: true, id: postureCheck.ID, checkName: postureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck OSVersion OK", + name: "getPostureCheck OSVersion OK", expectedBody: true, id: osPostureCheck.ID, checkName: osPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck GeoLocation OK", + name: "getPostureCheck GeoLocation OK", expectedBody: true, id: geoPostureCheck.ID, checkName: geoPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck PrivateNetwork OK", + name: "getPostureCheck PrivateNetwork OK", expectedBody: true, id: privateNetworkCheck.ID, checkName: privateNetworkCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck Not Found", + name: "getPostureCheck Not Found", id: "not-exists", expectedStatus: http.StatusNotFound, }, @@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) { requestType string requestPath string requestBody io.Reader - setupHandlerFunc func(handler *PostureChecksHandler) + setupHandlerFunc func(handler *postureChecksHandler) }{ { name: "Create Posture Checks NB version", @@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go similarity index 85% rename from management/server/http/routes_handler.go rename to management/server/http/handlers/routes/routes_handler.go index f44a164e26e..9d420066cc8 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -1,4 +1,4 @@ -package http +package routes import ( "encoding/json" @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -23,15 +24,24 @@ import ( const maxDomains = 32 const failedToConvertRoute = "failed to convert route to response: %v" -// RoutesHandler is the routes handler of the account -type RoutesHandler struct { +// handler is the routes handler of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewRoutesHandler returns a new instance of RoutesHandler handler -func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler { - return &RoutesHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + routesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") +} + +// newHandler returns a new instance of routes handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro } } -// GetAllRoutes returns the list of routes for the account -func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { +// getAllRoutes returns the list of routes for the account +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRoutes) } -// CreateRoute handles route creation request -func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { +// createRoute handles route creation request +func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { +func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { if req.Network != nil && req.Domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } @@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro return nil } -// UpdateRoute handles update to a route identified by a given ID -func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { +// updateRoute handles update to a route identified by a given ID +func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -// DeleteRoute handles route deletion request -func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { +// deleteRoute handles route deletion request +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetRoute handles a route Get request identified by ID -func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { +// getRoute handles a route Get request identified by ID +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go similarity index 98% rename from management/server/http/routes_handler_test.go rename to management/server/http/handlers/routes/routes_handler_test.go index 83bd7004d1c..a25c899c960 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -1,4 +1,4 @@ -package http +package routes import ( "bytes" @@ -87,8 +87,8 @@ var testingAccount = &server.Account{ }, } -func initRoutesTestData() *RoutesHandler { - return &RoutesHandler{ +func initRoutesTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { @@ -152,7 +152,7 @@ func initRoutesTestData() *RoutesHandler { return nil }, GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - //return testingAccount, testingAccount.Users["test_user"], nil + // return testingAccount, testingAccount.Users["test_user"], nil return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, @@ -521,10 +521,10 @@ func TestRoutesHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET") - router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE") - router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST") - router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT") + router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") + router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") + router.HandleFunc("/api/routes", p.createRoute).Methods("POST") + router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go similarity index 78% rename from management/server/http/setupkeys_handler.go rename to management/server/http/handlers/setup_keys/setupkeys_handler.go index 9ba5977bb25..9432d554912 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "context" @@ -10,20 +10,30 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// SetupKeysHandler is a handler that returns a list of setup keys of the account -type SetupKeysHandler struct { +// handler is a handler that returns a list of setup keys of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler -func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler { - return &SetupKeysHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + keysHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new setup key handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +42,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) } } -// CreateSetupKey is a POST requests that creates a new SetupKey -func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { +// createSetupKey is a POST requests that creates a new SetupKey +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -89,8 +99,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -// GetSetupKey is a GET request to get a SetupKey by ID -func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { +// getSetupKey is a GET request to get a SetupKey by ID +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -114,8 +124,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { writeSuccess(r.Context(), w, key) } -// UpdateSetupKey is a PUT request to update server.SetupKey -func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { +// updateSetupKey is a PUT request to update server.SetupKey +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -155,8 +165,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request writeSuccess(r.Context(), w, newKey) } -// GetAllSetupKeys is a GET request that returns a list of SetupKey -func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { +// getAllSetupKeys is a GET request that returns a list of SetupKey +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -178,7 +188,7 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -199,7 +209,7 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go similarity index 95% rename from management/server/http/setupkeys_handler_test.go rename to management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 09256d0ea5e..516a2ab8b01 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "bytes" @@ -26,12 +26,13 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" + testAccountID = "test_id" ) func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, user *server.User, -) *SetupKeysHandler { - return &SetupKeysHandler{ +) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -178,11 +179,11 @@ func TestSetupKeysHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/pat_handler.go b/management/server/http/handlers/users/pat_handler.go similarity index 75% rename from management/server/http/pat_handler.go rename to management/server/http/handlers/users/pat_handler.go index dfa9563e3b9..2caf98ad857 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,20 +9,29 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// PATHandler is the nameserver group handler of the account -type PATHandler struct { +// patHandler is the nameserver group handler of the account +type patHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPATsHandler creates a new PATHandler HTTP handler -func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { - return &PATHandler{ +func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager, authCfg) + router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") +} + +// newPATsHandler creates a new patHandler HTTP handler +func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { + return &patHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +40,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } } -// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { +// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -61,8 +70,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, patResponse) } -// GetToken is HTTP GET handler that returns a personal access token for the given user -func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { +// getToken is HTTP GET handler that returns a personal access token for the given user +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -92,8 +101,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } -// CreateToken is HTTP POST handler that creates a personal access token for the given user -func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { +// createToken is HTTP POST handler that creates a personal access token for the given user +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -124,8 +133,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } -// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { +// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -152,7 +161,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { diff --git a/management/server/http/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go similarity index 96% rename from management/server/http/pat_handler_test.go rename to management/server/http/handlers/users/pat_handler_test.go index c28228a506e..ef6fb973edc 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -61,8 +61,8 @@ var testAccount = &server.Account{ }, } -func initPATTestData() *PATHandler { - return &PATHandler{ +func initPATTestData() *patHandler { + return &patHandler{ accountManager: &mock_server.MockAccountManager{ CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { @@ -186,10 +186,10 @@ func TestTokenHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/users_handler.go b/management/server/http/handlers/users/users_handler.go similarity index 80% rename from management/server/http/users_handler.go rename to management/server/http/handlers/users/users_handler.go index 6e151a0da3a..c843bc52b08 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -16,15 +17,25 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// UsersHandler is a handler that returns users of the account -type UsersHandler struct { +// handler is a handler that returns users of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewUsersHandler creates a new UsersHandler HTTP handler -func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler { - return &UsersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + userHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + addUsersTokensEndpoint(accountManager, authCfg, router) +} + +// newHandler creates a new UsersHandler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use } } -// UpdateUser is a PUT requests to update User data -func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { +// updateUser is a PUT requests to update User data +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -94,8 +105,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// DeleteUser is a DELETE request to delete a user -func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { +// deleteUser is a DELETE request to delete a user +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -121,11 +132,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { +// createUser creates a User in the system with a status "invited" (effectively this is a user invite). +func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -175,9 +186,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// GetAllUsers returns a list of users of the account this user belongs to. +// getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -222,9 +233,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, users) } -// InviteUser resend invitations to users who haven't activated their accounts, +// inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. -func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -250,7 +261,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { diff --git a/management/server/http/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go similarity index 97% rename from management/server/http/users_handler_test.go rename to management/server/http/handlers/users/users_handler_test.go index f3d989da19f..6f6a912360d 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -61,8 +61,8 @@ var usersTestAccount = &server.Account{ }, } -func initUsersTestData() *UsersHandler { - return &UsersHandler{ +func initUsersTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil @@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) { requestPath string expectedUserIDs []string }{ - {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, + {name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, {name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}}, {name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}}, } @@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - userHandler.GetAllUsers(recorder, req) + userHandler.getAllUsers(recorder, req) res := recorder.Result() defer res.Body.Close() @@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - userHandler.CreateUser(rr, req) + userHandler.createUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.InviteUser(rr, req) + userHandler.inviteUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.DeleteUser(rr, req) + userHandler.deleteUser(rr, req) res := rr.Result() defer res.Body.Close() diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 603c1c6963c..3d7eed49871 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -14,6 +14,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// EmptyObject is an empty struct used to return empty JSON object +type EmptyObject struct { +} + type ErrorResponse struct { Message string `json:"message"` Code int `json:"code"` From dcba6a6b7e5bcb8980860fee4d634002a9270e56 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 11 Dec 2024 16:46:51 +0100 Subject: [PATCH 05/19] fix: client/Dockerfile to reduce vulnerabilities (#3019) The following vulnerabilities are fixed with an upgrade: - https://snyk.io/vuln/SNYK-ALPINE320-OPENSSL-8235201 - https://snyk.io/vuln/SNYK-ALPINE320-OPENSSL-8235201 Co-authored-by: snyk-bot --- client/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/Dockerfile b/client/Dockerfile index b9f7c135575..2f5ff14ae87 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.20 +FROM alpine:3.21.0 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] From a4a30744adab711799e10e1355c20b73c81fb2b3 Mon Sep 17 00:00:00 2001 From: "M. Essam" Date: Sat, 14 Dec 2024 22:17:53 +0200 Subject: [PATCH 06/19] Fix race condition with systray ready (#2993) --- client/ui/client_ui.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d046bab5f1f..8ca0db73f38 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -572,6 +572,7 @@ func (s *serviceClient) onTrayReady() { s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() + time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon for { err := s.updateStatus() if err != nil { From 287ae811958c362b9acaf888f07968b7960964d2 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 14 Dec 2024 21:18:46 +0100 Subject: [PATCH 07/19] [misc] split tests with management and rest (#3051) optimize go cache for tests --- .github/workflows/golang-test-darwin.yml | 6 +- .github/workflows/golang-test-linux.yml | 192 +++++++++++++++++++--- .github/workflows/golang-test-windows.yml | 25 ++- 3 files changed, 197 insertions(+), 26 deletions(-) diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 88db8c5e89f..2dbeb106abb 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -21,6 +21,7 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.x" + cache: false - name: Checkout code uses: actions/checkout@v4 @@ -28,8 +29,9 @@ jobs: uses: actions/cache@v4 with: path: ~/go/pkg/mod - key: macos-go-${{ hashFiles('**/go.sum') }} + key: macos-gotest-${{ hashFiles('**/go.sum') }} restore-keys: | + macos-gotest- macos-go- - name: Install libpcap @@ -42,4 +44,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 36dcb791f76..85658d237a2 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -11,31 +11,164 @@ concurrency: cancel-in-progress: true jobs: + build-cache: + runs-on: ubuntu-22.04 + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache@v4 + id: cache + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + + + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: steps.cache.outputs.cache-hit != 'true' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Build client + if: steps.cache.outputs.cache-hit != 'true' + working-directory: client + run: CGO_ENABLED=1 go build . + + - name: Build client 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: client + run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 . + + - name: Build management + if: steps.cache.outputs.cache-hit != 'true' + working-directory: management + run: CGO_ENABLED=1 go build . + + - name: Build management 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: management + run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 . + + - name: Build signal + if: steps.cache.outputs.cache-hit != 'true' + working-directory: signal + run: CGO_ENABLED=1 go build . + + - name: Build signal 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: signal + run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 . + + - name: Build relay + if: steps.cache.outputs.cache-hit != 'true' + working-directory: relay + run: CGO_ENABLED=1 go build . + + - name: Build relay 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: relay + run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . + test: + needs: [build-cache] strategy: fail-fast: false matrix: arch: [ '386','amd64' ] - store: [ 'sqlite', 'postgres'] runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 with: go-version: "1.23.x" + cache: false + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV - name: Cache Go modules - uses: actions/cache@v4 + uses: actions/cache/restore@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go- + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) + + test_management: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres'] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -50,9 +183,10 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) benchmark: + needs: [ build-cache ] strategy: fail-fast: false matrix: @@ -64,18 +198,25 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.x" + cache: false + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV - name: Cache Go modules - uses: actions/cache@v4 + uses: actions/cache/restore@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go- - - - name: Checkout code - uses: actions/checkout@v4 + ${{ runner.os }}-gotest-cache- - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -91,26 +232,35 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./... test_client_on_docker: + needs: [ build-cache ] runs-on: ubuntu-20.04 steps: - name: Install Go uses: actions/setup-go@v5 with: go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV - name: Cache Go modules - uses: actions/cache@v4 + uses: actions/cache/restore@v4 with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go- - - - name: Checkout code - uses: actions/checkout@v4 + ${{ runner.os }}-gotest-cache- - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index d378bec3fe4..3a3c470525f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -24,6 +24,23 @@ jobs: id: go with: go-version: "1.23.x" + cache: false + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest- + ${{ runner.os }}-go- - name: Download wintun uses: carlosperate/download-file-action@v2 @@ -42,11 +59,13 @@ jobs: - run: choco install -y sysinternals --ignore-checksums - run: choco install -y mingw - - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy + - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt From f591e4740475e7bc0028fe91f27e4865ab699ec0 Mon Sep 17 00:00:00 2001 From: "M. Essam" Date: Mon, 16 Dec 2024 10:41:36 +0200 Subject: [PATCH 08/19] Handle DNF5 install script (#3026) --- release_files/install.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/release_files/install.sh b/release_files/install.sh index b0fec27339c..bb917c39a96 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -239,7 +239,12 @@ install_netbird() { dnf) add_rpm_repo ${SUDO} dnf -y install dnf-plugin-config-manager - ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + if [[ "$(dnf --version | head -n1 | cut -d. -f1)" > "4" ]]; + then + ${SUDO} dnf config-manager addrepo --from-repofile=/etc/yum.repos.d/netbird.repo + else + ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + fi ${SUDO} dnf -y install netbird if ! $SKIP_UI_APP; then From 3844516aa7722e3a9cdaf478bfb409caed701b8a Mon Sep 17 00:00:00 2001 From: Jesse R Codling Date: Mon, 16 Dec 2024 03:58:54 -0500 Subject: [PATCH 09/19] [client] fix: reformat IPv6 ICE addresses when punching (#3050) Should fix #2327 and #2606 by checking for IPv6 addresses from ICE --- client/internal/peer/worker_ice.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 7ce4797c3ff..4cdd18ff106 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -264,7 +264,13 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort)) + addrString := pair.Remote.Address() + parsed, err := netip.ParseAddr(addrString) + if (err == nil) && (parsed.Is6()) { + addrString = fmt.Sprintf("[%s]", addrString) + //IPv6 Literals need to be wrapped in brackets for Resolve*Addr() + } + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort)) if err != nil { w.log.Warnf("got an error while resolving the udp address, err: %s", err) return From 9eff58ae62c094363172a0bf2cdc1557b13b4a20 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 16 Dec 2024 10:30:41 +0100 Subject: [PATCH 10/19] Upgrade x/crypto package (#3055) Mitigates the CVE-2024-45337 --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 2b4111ce32f..c504925d200 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.28.0 - golang.org/x/sys v0.26.0 + golang.org/x/crypto v0.31.0 + golang.org/x/sys v0.28.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -92,8 +92,8 @@ require ( golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.8.0 - golang.org/x/term v0.25.0 + golang.org/x/sync v0.10.0 + golang.org/x/term v0.27.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.19.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 35abe82d239..04d2bc59af8 100644 --- a/go.sum +++ b/go.sum @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 703647da1e59636ac4affe34261c27f4e18b8878 Mon Sep 17 00:00:00 2001 From: "VYSE V.E.O" Date: Mon, 16 Dec 2024 21:17:46 +0800 Subject: [PATCH 11/19] fix client unsupported h2 protocol when only 443 activated (#3009) When I remove 80 http port in Caddyfile, netbird client cannot connect server:443. Logs show error below: {"level":"debug","ts":1733809631.4012625,"logger":"http.stdlib","msg":"http: TLS handshake error from redacted:41580: tls: client requested unsupported application protocols ([h2])"} I wonder here h2 protocol is absent. --- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 7793d1fda1d..9b80058c27a 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -530,7 +530,7 @@ renderCaddyfile() { { debug servers :80,:443 { - protocols h1 h2c h3 + protocols h1 h2c h2 h3 } } From 37ad370344aec96c01d461db53473c78ed7f7240 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 16 Dec 2024 18:09:31 +0100 Subject: [PATCH 12/19] [client] Avoid using iota on mixed const block (#3057) Used the values as resolved when the first iota value was the second const in the block. --- client/iface/device/kernel_module_linux.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/client/iface/device/kernel_module_linux.go b/client/iface/device/kernel_module_linux.go index 0d195779dfe..b28ddd36c67 100644 --- a/client/iface/device/kernel_module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -27,14 +27,14 @@ import ( type status int const ( - defaultModuleDir = "/lib/modules" - unknown status = iota - unloaded - unloading - loading - live - inuse - envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" + unknown status = 1 + unloaded status = 2 + unloading status = 3 + loading status = 4 + live status = 5 + inuse status = 6 + defaultModuleDir = "/lib/modules" + envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" ) type module struct { From ddc365f7a0d52ab1d10d44d76f70176040e2fe64 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:30:28 +0100 Subject: [PATCH 13/19] [client, management] Add new network concept (#3047) --------- Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: bcmmbaga Co-authored-by: Maycon Santos Co-authored-by: Zoltan Papp --- .github/workflows/golang-test-linux.yml | 2 +- .github/workflows/golangci-lint.yml | 4 +- client/anonymize/anonymize.go | 15 +- client/anonymize/anonymize_test.go | 56 +- client/cmd/networks.go | 173 ++ client/cmd/root.go | 6 +- client/cmd/route.go | 174 -- client/cmd/status.go | 44 +- client/cmd/status_test.go | 28 +- client/cmd/testutil_test.go | 6 +- client/internal/dns/handler_chain.go | 222 +++ client/internal/dns/handler_chain_test.go | 511 ++++++ client/internal/dns/local.go | 14 +- client/internal/dns/mock_server.go | 22 +- client/internal/dns/server.go | 104 +- client/internal/dns/server_test.go | 91 +- client/internal/dns/service_listener.go | 1 + client/internal/dns/upstream.go | 9 + client/internal/dnsfwd/forwarder.go | 157 ++ client/internal/dnsfwd/manager.go | 106 ++ client/internal/engine.go | 155 +- client/internal/engine_test.go | 27 +- client/internal/peer/conn.go | 5 + client/internal/peer/status.go | 28 +- client/internal/peerstore/store.go | 87 + client/internal/routemanager/client.go | 81 +- .../routemanager/dnsinterceptor/handler.go | 356 ++++ client/internal/routemanager/dynamic/route.go | 8 +- client/internal/routemanager/manager.go | 105 +- client/internal/routemanager/manager_test.go | 6 +- client/internal/routemanager/mock.go | 31 +- client/ios/NetBirdSDK/client.go | 17 +- client/proto/daemon.pb.go | 620 +++---- client/proto/daemon.proto | 28 +- client/proto/daemon_grpc.pb.go | 88 +- client/server/{route.go => network.go} | 67 +- client/server/server.go | 4 +- client/server/server_test.go | 6 +- client/ui/client_ui.go | 8 +- client/ui/{route.go => network.go} | 164 +- dns/dns.go | 6 + go.mod | 3 +- go.sum | 5 +- management/client/client_test.go | 6 +- management/cmd/management.go | 24 +- management/cmd/migration_up.go | 4 +- management/proto/management.pb.go | 527 +++--- management/proto/management.proto | 9 + management/server/account.go | 1105 ++---------- management/server/account_request_buffer.go | 11 +- management/server/account_test.go | 345 ++-- management/server/activity/codes.go | 36 + management/server/config.go | 3 +- management/server/dns.go | 141 +- management/server/dns_test.go | 39 +- management/server/ephemeral.go | 10 +- management/server/ephemeral_test.go | 12 +- management/server/group.go | 237 ++- management/server/group_test.go | 76 +- management/server/groups/manager.go | 196 +++ management/server/grpcserver.go | 33 +- management/server/http/api/openapi.yml | 697 +++++++- management/server/http/api/types.gen.go | 194 ++- management/server/http/handler.go | 8 +- .../handlers/accounts/accounts_handler.go | 9 +- .../accounts/accounts_handler_test.go | 78 +- .../http/handlers/dns/dns_settings_handler.go | 3 +- .../handlers/dns/dns_settings_handler_test.go | 14 +- .../handlers/events/events_handler_test.go | 8 +- .../http/handlers/groups/groups_handler.go | 49 +- .../handlers/groups/groups_handler_test.go | 24 +- .../server/http/handlers/networks/handler.go | 321 ++++ .../handlers/networks/resources_handler.go | 222 +++ .../http/handlers/networks/routers_handler.go | 165 ++ .../http/handlers/peers/peers_handler.go | 47 +- .../http/handlers/peers/peers_handler_test.go | 27 +- .../policies/geolocation_handler_test.go | 6 +- .../handlers/policies/policies_handler.go | 101 +- .../policies/policies_handler_test.go | 47 +- .../http/handlers/routes/routes_handler.go | 2 +- .../handlers/routes/routes_handler_test.go | 40 +- .../handlers/setup_keys/setupkeys_handler.go | 13 +- .../setup_keys/setupkeys_handler_test.go | 26 +- .../server/http/handlers/users/pat_handler.go | 5 +- .../http/handlers/users/pat_handler_test.go | 24 +- .../http/handlers/users/users_handler.go | 15 +- .../http/handlers/users/users_handler_test.go | 32 +- .../server/http/middleware/access_control.go | 4 +- .../server/http/middleware/auth_middleware.go | 4 +- .../http/middleware/auth_middleware_test.go | 10 +- management/server/integrated_validator.go | 8 +- .../server/integrated_validator/interface.go | 4 +- management/server/management_proto_test.go | 11 +- management/server/management_test.go | 10 +- management/server/metrics/selfhosted.go | 7 +- management/server/metrics/selfhosted_test.go | 58 +- management/server/migration/migration_test.go | 66 +- management/server/mock_server/account_mock.go | 162 +- management/server/nameserver.go | 63 +- management/server/nameserver_test.go | 15 +- management/server/networks/manager.go | 187 +++ management/server/networks/manager_test.go | 254 +++ .../server/networks/resources/manager.go | 383 +++++ .../server/networks/resources/manager_test.go | 411 +++++ .../networks/resources/types/resource.go | 169 ++ .../networks/resources/types/resource_test.go | 53 + management/server/networks/routers/manager.go | 289 ++++ .../server/networks/routers/manager_test.go | 234 +++ .../server/networks/routers/types/router.go | 75 + .../networks/routers/types/router_test.go | 100 ++ management/server/networks/types/network.go | 56 + management/server/peer.go | 104 +- management/server/peer_test.go | 243 ++- management/server/permissions/manager.go | 102 ++ management/server/policy.go | 467 +----- management/server/policy_test.go | 192 +-- management/server/posture_checks.go | 54 +- management/server/posture_checks_test.go | 51 +- management/server/resource.go | 21 + management/server/route.go | 304 +--- management/server/route_test.go | 620 ++++++- management/server/settings/manager.go | 37 + management/server/setupkey.go | 241 +-- management/server/setupkey_test.go | 60 +- management/server/status/error.go | 32 + management/server/{ => store}/file_store.go | 32 +- management/server/{ => store}/sql_store.go | 461 ++++-- .../server/{ => store}/sql_store_test.go | 696 ++++++-- management/server/{ => store}/store.go | 119 +- management/server/{ => store}/store_test.go | 10 +- management/server/testdata/networks.sql | 18 + management/server/testdata/store.sql | 14 + management/server/types/account.go | 1475 +++++++++++++++++ management/server/types/account_test.go | 375 +++++ management/server/types/dns_settings.go | 16 + management/server/types/firewall_rule.go | 130 ++ management/server/{group => types}/group.go | 44 +- .../server/{group => types}/group_test.go | 2 +- management/server/{ => types}/network.go | 12 +- management/server/{ => types}/network_test.go | 2 +- .../{ => types}/personal_access_token.go | 2 +- .../{ => types}/personal_access_token_test.go | 2 +- management/server/types/policy.go | 125 ++ management/server/types/policyrule.go | 87 + management/server/types/resource.go | 30 + .../server/types/route_firewall_rule.go | 32 + management/server/types/settings.go | 68 + management/server/types/setupkey.go | 181 ++ management/server/types/user.go | 231 +++ management/server/updatechannel.go | 3 +- management/server/user.go | 356 +--- management/server/user_test.go | 442 +++-- management/server/users/manager.go | 49 + management/server/util/util.go | 16 + route/route.go | 58 +- 155 files changed, 13903 insertions(+), 4987 deletions(-) create mode 100644 client/cmd/networks.go delete mode 100644 client/cmd/route.go create mode 100644 client/internal/dns/handler_chain.go create mode 100644 client/internal/dns/handler_chain_test.go create mode 100644 client/internal/dnsfwd/forwarder.go create mode 100644 client/internal/dnsfwd/manager.go create mode 100644 client/internal/peerstore/store.go create mode 100644 client/internal/routemanager/dnsinterceptor/handler.go rename client/server/{route.go => network.go} (58%) rename client/ui/{route.go => network.go} (56%) create mode 100644 management/server/groups/manager.go create mode 100644 management/server/http/handlers/networks/handler.go create mode 100644 management/server/http/handlers/networks/resources_handler.go create mode 100644 management/server/http/handlers/networks/routers_handler.go create mode 100644 management/server/networks/manager.go create mode 100644 management/server/networks/manager_test.go create mode 100644 management/server/networks/resources/manager.go create mode 100644 management/server/networks/resources/manager_test.go create mode 100644 management/server/networks/resources/types/resource.go create mode 100644 management/server/networks/resources/types/resource_test.go create mode 100644 management/server/networks/routers/manager.go create mode 100644 management/server/networks/routers/manager_test.go create mode 100644 management/server/networks/routers/types/router.go create mode 100644 management/server/networks/routers/types/router_test.go create mode 100644 management/server/networks/types/network.go create mode 100644 management/server/permissions/manager.go create mode 100644 management/server/resource.go create mode 100644 management/server/settings/manager.go rename management/server/{ => store}/file_store.go (89%) rename management/server/{ => store}/sql_store.go (76%) rename management/server/{ => store}/sql_store_test.go (75%) rename management/server/{ => store}/store.go (73%) rename management/server/{ => store}/store_test.go (93%) create mode 100644 management/server/testdata/networks.sql create mode 100644 management/server/types/account.go create mode 100644 management/server/types/account_test.go create mode 100644 management/server/types/dns_settings.go create mode 100644 management/server/types/firewall_rule.go rename management/server/{group => types}/group.go (58%) rename management/server/{group => types}/group_test.go (99%) rename management/server/{ => types}/network.go (96%) rename management/server/{ => types}/network_test.go (98%) rename management/server/{ => types}/personal_access_token.go (99%) rename management/server/{ => types}/personal_access_token_test.go (98%) create mode 100644 management/server/types/policy.go create mode 100644 management/server/types/policyrule.go create mode 100644 management/server/types/resource.go create mode 100644 management/server/types/route_firewall_rule.go create mode 100644 management/server/types/settings.go create mode 100644 management/server/types/setupkey.go create mode 100644 management/server/types/user.go create mode 100644 management/server/users/manager.go create mode 100644 management/server/util/util.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 85658d237a2..50c6cba848c 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -232,7 +232,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... test_client_on_docker: needs: [ build-cache ] diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index dacb1922be9..89defce3289 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -46,7 +46,7 @@ jobs: if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - args: --timeout=12m + args: --timeout=12m --out-format colored-line-number diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 9a6d9720794..89552724ae7 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -21,6 +21,8 @@ type Anonymizer struct { currentAnonIPv6 netip.Addr startAnonIPv4 netip.Addr startAnonIPv6 netip.Addr + + domainKeyRegex *regexp.Regexp } func DefaultAddresses() (netip.Addr, netip.Addr) { @@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { currentAnonIPv6: startIPv6, startAnonIPv4: startIPv4, startAnonIPv6: startIPv6, + + domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`), } } @@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string { return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } -// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string. func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { - domainPattern := `dns\.Question{Name:"([^"]+)",` - domainRegex := regexp.MustCompile(domainPattern) - - return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string { - parts := strings.Split(match, `"`) + return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string { + parts := strings.SplitN(match, "=", 2) if len(parts) >= 2 { domain := parts[1] if strings.HasSuffix(domain, anonTLD) { return match } - randomDomain := generateRandomString(10) + anonTLD - return strings.Replace(match, domain, randomDomain, 1) + return "domain=" + a.AnonymizeDomain(domain) } return match }) diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index a3aae1ee982..ff2e4886943 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) { func TestAnonymizeDNSLogLine(t *testing.T) { anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) - testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}` + tests := []struct { + name string + input string + original string + expect string + }{ + { + name: "Basic domain with trailing content", + input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123", + original: "example.com", + expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`, + }, + { + name: "Domain with trailing dot", + input: "domain=example.com. processing request with status=pending", + original: "example.com", + expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`, + }, + { + name: "Multiple domains in log", + input: "forward domain=first.com status=ok, redirect to domain=second.com port=443", + original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately + expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`, + }, + { + name: "Already anonymized domain", + input: "got request domain=anon-xyz123.domain from=client1 to=server2", + original: "", // nothing should be anonymized + expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`, + }, + { + name: "Subdomain with trailing dot", + input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp", + original: "example.com", + expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`, + }, + { + name: "Handler chain pattern log", + input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100", + original: "example.com", + expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`, + }, + } - result := anonymizer.AnonymizeDNSLogLine(testLog) - require.NotEqual(t, testLog, result) - assert.NotContains(t, result, "example.com") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeDNSLogLine(tc.input) + if tc.original != "" { + assert.NotContains(t, result, tc.original) + } + assert.Regexp(t, tc.expect, result) + }) + } } func TestAnonymizeDomain(t *testing.T) { diff --git a/client/cmd/networks.go b/client/cmd/networks.go new file mode 100644 index 00000000000..7b9724bc595 --- /dev/null +++ b/client/cmd/networks.go @@ -0,0 +1,173 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var appendFlag bool + +var networksCMD = &cobra.Command{ + Use: "networks", + Aliases: []string{"routes"}, + Short: "Manage networks", + Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`, +} + +var routesListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List networks", + Example: " netbird networks list", + Long: "List all available network routes.", + RunE: networksList, +} + +var routesSelectCmd = &cobra.Command{ + Use: "select network...|all", + Short: "Select network", + Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.", + Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3", + Args: cobra.MinimumNArgs(1), + RunE: networksSelect, +} + +var routesDeselectCmd = &cobra.Command{ + Use: "deselect network...|all", + Short: "Deselect networks", + Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.", + Example: " netbird networks deselect all\n netbird networks deselect route1 route2", + Args: cobra.MinimumNArgs(1), + RunE: networksDeselect, +} + +func init() { + routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing") +} + +func networksList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{}) + if err != nil { + return fmt.Errorf("failed to list network: %v", status.Convert(err).Message()) + } + + if len(resp.Routes) == 0 { + cmd.Println("No networks available.") + return nil + } + + printNetworks(cmd, resp) + + return nil +} + +func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) { + cmd.Println("Available Networks:") + for _, route := range resp.Routes { + printNetwork(cmd, route) + } +} + +func printNetwork(cmd *cobra.Command, route *proto.Network) { + selectedStatus := getSelectedStatus(route) + domains := route.GetDomains() + + if len(domains) > 0 { + printDomainRoute(cmd, route, domains, selectedStatus) + } else { + printNetworkRoute(cmd, route, selectedStatus) + } +} + +func getSelectedStatus(route *proto.Network) string { + if route.GetSelected() { + return "Selected" + } + return "Not Selected" +} + +func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) + resolvedIPs := route.GetResolvedIPs() + + if len(resolvedIPs) > 0 { + printResolvedIPs(cmd, domains, resolvedIPs) + } else { + cmd.Printf(" Resolved IPs: -\n") + } +} + +func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus) +} + +func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) { + cmd.Printf(" Resolved IPs:\n") + for resolvedDomain, ipList := range resolvedIPs { + cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", ")) + } +} + +func networksSelect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } else if appendFlag { + req.Append = true + } + + if _, err := client.SelectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks selected successfully.") + + return nil +} + +func networksDeselect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } + + if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks deselected successfully.") + + return nil +} diff --git a/client/cmd/root.go b/client/cmd/root.go index 3f2d04ef304..0305bacc88f 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -142,14 +142,14 @@ func init() { rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(sshCmd) - rootCmd.AddCommand(routesCmd) + rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(debugCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service - routesCmd.AddCommand(routesListCmd) - routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd) + networksCMD.AddCommand(routesListCmd) + networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(logCmd) diff --git a/client/cmd/route.go b/client/cmd/route.go deleted file mode 100644 index c8881822b30..00000000000 --- a/client/cmd/route.go +++ /dev/null @@ -1,174 +0,0 @@ -package cmd - -import ( - "fmt" - "strings" - - "github.com/spf13/cobra" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/proto" -) - -var appendFlag bool - -var routesCmd = &cobra.Command{ - Use: "routes", - Short: "Manage network routes", - Long: `Commands to list, select, or deselect network routes.`, -} - -var routesListCmd = &cobra.Command{ - Use: "list", - Aliases: []string{"ls"}, - Short: "List routes", - Example: " netbird routes list", - Long: "List all available network routes.", - RunE: routesList, -} - -var routesSelectCmd = &cobra.Command{ - Use: "select route...|all", - Short: "Select routes", - Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.", - Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3", - Args: cobra.MinimumNArgs(1), - RunE: routesSelect, -} - -var routesDeselectCmd = &cobra.Command{ - Use: "deselect route...|all", - Short: "Deselect routes", - Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.", - Example: " netbird routes deselect all\n netbird routes deselect route1 route2", - Args: cobra.MinimumNArgs(1), - RunE: routesDeselect, -} - -func init() { - routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing") -} - -func routesList(cmd *cobra.Command, _ []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{}) - if err != nil { - return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message()) - } - - if len(resp.Routes) == 0 { - cmd.Println("No routes available.") - return nil - } - - printRoutes(cmd, resp) - - return nil -} - -func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) { - cmd.Println("Available Routes:") - for _, route := range resp.Routes { - printRoute(cmd, route) - } -} - -func printRoute(cmd *cobra.Command, route *proto.Route) { - selectedStatus := getSelectedStatus(route) - domains := route.GetDomains() - - if len(domains) > 0 { - printDomainRoute(cmd, route, domains, selectedStatus) - } else { - printNetworkRoute(cmd, route, selectedStatus) - } -} - -func getSelectedStatus(route *proto.Route) string { - if route.GetSelected() { - return "Selected" - } - return "Not Selected" -} - -func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) - resolvedIPs := route.GetResolvedIPs() - - if len(resolvedIPs) > 0 { - printResolvedIPs(cmd, domains, resolvedIPs) - } else { - cmd.Printf(" Resolved IPs: -\n") - } -} - -func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) -} - -func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { - cmd.Printf(" Resolved IPs:\n") - for _, domain := range domains { - if ipList, exists := resolvedIPs[domain]; exists { - cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) - } - } -} - -func routesSelect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } else if appendFlag { - req.Append = true - } - - if _, err := client.SelectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes selected successfully.") - - return nil -} - -func routesDeselect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } - - if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes deselected successfully.") - - return nil -} diff --git a/client/cmd/status.go b/client/cmd/status.go index 6db52a67795..fa4bff77ba8 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -40,6 +40,7 @@ type peerStateDetailOutput struct { Latency time.Duration `json:"latency" yaml:"latency"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` } type peersStateOutput struct { @@ -98,6 +99,7 @@ type statusOutputOverview struct { RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` } @@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), - Routes: pbFullStatus.GetLocalPeerState().GetRoutes(), + Routes: pbFullStatus.GetLocalPeerState().GetNetworks(), + Networks: pbFullStatus.GetLocalPeerState().GetNetworks(), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), } @@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { TransferSent: transferSent, Latency: pbPeerState.GetLatency().AsDuration(), RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), - Routes: pbPeerState.GetRoutes(), + Routes: pbPeerState.GetNetworks(), + Networks: pbPeerState.GetNetworks(), } peersStateDetail = append(peersStateDetail, peerState) @@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) } - routes := "-" - if len(overview.Routes) > 0 { - sort.Strings(overview.Routes) - routes = strings.Join(overview.Routes, ", ") + networks := "-" + if len(overview.Networks) > 0 { + sort.Strings(overview.Networks) + networks = strings.Join(overview.Networks, ", ") } var dnsServersString string @@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Routes: %s\n"+ + "Networks: %s\n"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), overview.DaemonVersion, @@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays interfaceIP, interfaceTypeString, rosenpassEnabledStatus, - routes, + networks, + networks, peersCountString, ) return summary @@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo } } - routes := "-" - if len(peerState.Routes) > 0 { - sort.Strings(peerState.Routes) - routes = strings.Join(peerState.Routes, ", ") + networks := "-" + if len(peerState.Networks) > 0 { + sort.Strings(peerState.Networks) + networks = strings.Join(peerState.Networks, ", ") } peerString := fmt.Sprintf( @@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Transfer status (received/sent) %s/%s\n"+ " Quantum resistance: %s\n"+ " Routes: %s\n"+ + " Networks: %s\n"+ " Latency: %s\n", peerState.FQDN, peerState.IP, @@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo toIEC(peerState.TransferReceived), toIEC(peerState.TransferSent), rosenpassEnabledStatus, - routes, + networks, + networks, peerState.Latency.String(), ) @@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeIPString(route) + } + + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range peer.Routes { peer.Routes[i] = a.AnonymizeIPString(route) } @@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) } } + for i, route := range overview.Networks { + overview.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range overview.Routes { overview.Routes[i] = a.AnonymizeRoute(route) } diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index ca43df8a523..1f1e957263c 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{ LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), BytesRx: 200, BytesTx: 100, - Routes: []string{ + Networks: []string{ "10.1.0.0/24", }, Latency: durationpb.New(time.Duration(10000000)), @@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{ PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", - Routes: []string{ + Networks: []string{ "10.10.0.0/24", }, }, @@ -149,6 +149,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.1.0.0/24", }, + Networks: []string{ + "10.1.0.0/24", + }, Latency: time.Duration(10000000), }, { @@ -230,6 +233,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.10.0.0/24", }, + Networks: []string{ + "10.10.0.0/24", + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) { "quantumResistance": false, "routes": [ "10.1.0.0/24" + ], + "networks": [ + "10.1.0.0/24" ] }, { @@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) { "transferSent": 1000, "latency": 10000000, "quantumResistance": false, - "routes": null + "routes": null, + "networks": null } ] }, @@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) { "routes": [ "10.10.0.0/24" ], + "networks": [ + "10.10.0.0/24" + ], "dnsServers": [ { "servers": [ @@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) { quantumResistance: false routes: - 10.1.0.0/24 + networks: + - 10.1.0.0/24 - fqdn: peer-2.awesome-domain.com netbirdIp: 192.168.178.102 publicKey: Pubkey2 @@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) { latency: 10ms quantumResistance: false routes: [] + networks: [] cliVersion: development daemonVersion: 0.14.1 management: @@ -465,6 +481,8 @@ quantumResistance: false quantumResistancePermissive: false routes: - 10.10.0.0/24 +networks: + - 10.10.0.0/24 dnsServers: - servers: - 8.8.8.8:53 @@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 200 B/100 B Quantum resistance: false Routes: 10.1.0.0/24 + Networks: 10.1.0.0/24 Latency: 10ms peer-2.awesome-domain.com: @@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 2.0 KiB/1000 B Quantum resistance: false Routes: - + Networks: - Latency: 10ms OS: %s/%s @@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) @@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected ` diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d998f9ea9e6..e3e644357e7 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -10,6 +10,8 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -71,7 +73,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -93,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go new file mode 100644 index 00000000000..c6ac3ebd61e --- /dev/null +++ b/client/internal/dns/handler_chain.go @@ -0,0 +1,222 @@ +package dns + +import ( + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +const ( + PriorityDNSRoute = 100 + PriorityMatchDomain = 50 + PriorityDefault = 0 +) + +type SubdomainMatcher interface { + dns.Handler + MatchSubdomains() bool +} + +type HandlerEntry struct { + Handler dns.Handler + Priority int + Pattern string + OrigPattern string + IsWildcard bool + StopHandler handlerWithStop + MatchSubdomains bool +} + +// HandlerChain represents a prioritized chain of DNS handlers +type HandlerChain struct { + mu sync.RWMutex + handlers []HandlerEntry +} + +// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain +type ResponseWriterChain struct { + dns.ResponseWriter + origPattern string + shouldContinue bool +} + +func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { + // Check if this is a continue signal (NXDOMAIN with Zero bit set) + if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero { + w.shouldContinue = true + return nil + } + return w.ResponseWriter.WriteMsg(m) +} + +func NewHandlerChain() *HandlerChain { + return &HandlerChain{ + handlers: make([]HandlerEntry, 0), + } +} + +// GetOrigPattern returns the original pattern of the handler that wrote the response +func (w *ResponseWriterChain) GetOrigPattern() string { + return w.origPattern +} + +// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { + c.mu.Lock() + defer c.mu.Unlock() + + origPattern := pattern + isWildcard := strings.HasPrefix(pattern, "*.") + if isWildcard { + pattern = pattern[2:] + } + pattern = dns.Fqdn(pattern) + origPattern = dns.Fqdn(origPattern) + + // First remove any existing handler with same original pattern and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { + if c.handlers[i].StopHandler != nil { + c.handlers[i].StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + break + } + } + + // Check if handler implements SubdomainMatcher interface + matchSubdomains := false + if matcher, ok := handler.(SubdomainMatcher); ok { + matchSubdomains = matcher.MatchSubdomains() + } + + log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", + pattern, origPattern, isWildcard, matchSubdomains, priority) + + entry := HandlerEntry{ + Handler: handler, + Priority: priority, + Pattern: pattern, + OrigPattern: origPattern, + IsWildcard: isWildcard, + StopHandler: stopHandler, + MatchSubdomains: matchSubdomains, + } + + // Insert handler in priority order + pos := 0 + for i, h := range c.handlers { + if h.Priority < priority { + pos = i + break + } + pos = i + 1 + } + + c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) +} + +// RemoveHandler removes a handler for the given pattern and priority +func (c *HandlerChain) RemoveHandler(pattern string, priority int) { + c.mu.Lock() + defer c.mu.Unlock() + + pattern = dns.Fqdn(pattern) + + // Find and remove handlers matching both original pattern and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + entry := c.handlers[i] + if entry.OrigPattern == pattern && entry.Priority == priority { + if entry.StopHandler != nil { + entry.StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + return + } + } +} + +// HasHandlers returns true if there are any handlers remaining for the given pattern +func (c *HandlerChain) HasHandlers(pattern string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + pattern = dns.Fqdn(pattern) + for _, entry := range c.handlers { + if entry.Pattern == pattern { + return true + } + } + return false +} + +func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + + qname := r.Question[0].Name + log.Tracef("handling DNS request for domain=%s", qname) + + c.mu.RLock() + defer c.mu.RUnlock() + + log.Tracef("current handlers (%d):", len(c.handlers)) + for _, h := range c.handlers { + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + } + + // Try handlers in priority order + for _, entry := range c.handlers { + var matched bool + switch { + case entry.Pattern == ".": + matched = true + case entry.IsWildcard: + parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") + matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + default: + // For non-wildcard patterns: + // If handler wants subdomain matching, allow suffix match + // Otherwise require exact match + if entry.MatchSubdomains { + matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + } else { + matched = qname == entry.Pattern + } + } + + if !matched { + log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", + qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) + continue + } + + log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", + qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) + + chainWriter := &ResponseWriterChain{ + ResponseWriter: w, + origPattern: entry.OrigPattern, + } + entry.Handler.ServeDNS(chainWriter, r) + + // If handler wants to continue, try next handler + if chainWriter.shouldContinue { + log.Tracef("handler requested continue to next handler") + continue + } + return + } + + // No handler matched or all handlers passed + log.Tracef("no handler found for domain=%s", qname) + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go new file mode 100644 index 00000000000..727b6e9087d --- /dev/null +++ b/client/internal/dns/handler_chain_test.go @@ -0,0 +1,511 @@ +package dns_test + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" +) + +// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order +func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create mock handlers for different priorities + defaultHandler := &nbdns.MockHandler{} + matchDomainHandler := &nbdns.MockHandler{} + dnsRouteHandler := &nbdns.MockHandler{} + + // Setup handlers with different priorities + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Create test writer + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations - only highest priority handler should be called + dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() + matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe() + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() + + // Execute + chain.ServeDNS(w, r) + + // Verify all expectations were met + dnsRouteHandler.AssertExpectations(t) + matchDomainHandler.AssertExpectations(t) + defaultHandler.AssertExpectations(t) +} + +// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios +func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { + tests := []struct { + name string + handlerDomain string + queryDomain string + isWildcard bool + matchSubdomains bool + shouldMatch bool + }{ + { + name: "exact match", + handlerDomain: "example.com.", + queryDomain: "example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard and MatchSubdomains true", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard and MatchSubdomains false", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "wildcard match", + handlerDomain: "*.example.com.", + queryDomain: "sub.example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "wildcard no match on apex", + handlerDomain: "*.example.com.", + queryDomain: "example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "root zone match", + handlerDomain: ".", + queryDomain: "anything.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "no match different domain", + handlerDomain: "example.com.", + queryDomain: "example.org.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handler dns.Handler + + if tt.matchSubdomains { + mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + handler = mockSubHandler + if tt.shouldMatch { + mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } else { + mockHandler := &nbdns.MockHandler{} + handler = mockHandler + if tt.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } + + pattern := tt.handlerDomain + if tt.isWildcard { + pattern = "*." + tt.handlerDomain[2:] + } + + chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) + + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + chain.ServeDNS(w, r) + + if h, ok := handler.(*nbdns.MockHandler); ok { + h.AssertExpectations(t) + } else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok { + h.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns +func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { + tests := []struct { + name string + handlers []struct { + pattern string + priority int + } + queryDomain string + expectedCalls int + expectedHandler int // index of the handler that should be called + }{ + { + name: "wildcard and exact same priority - exact should win", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // exact match handler should be called + }, + { + name: "higher priority wildcard over lower priority exact", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority wildcard handler should be called + }, + { + name: "multiple wildcards different priorities", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority handler should be called + }, + { + name: "subdomain with mix of patterns", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "sub.test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority matching handler should be called + }, + { + name: "root zone with specific domain", + handlers: []struct { + pattern string + priority int + }{ + {pattern: ".", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority specific domain should win over root + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handlers []*nbdns.MockHandler + + // Setup handlers and expectations + for i := range tt.handlers { + handler := &nbdns.MockHandler{} + handlers = append(handlers, handler) + + // Set expectation based on whether this handler should be called + if i == tt.expectedHandler { + handler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } else { + handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() + } + + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + } + + // Create and execute request + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality +func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create handlers + handler1 := &nbdns.MockHandler{} + handler2 := &nbdns.MockHandler{} + handler3 := &nbdns.MockHandler{} + + // Add handlers in priority order + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Setup mock responses to simulate chain continuation + handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // First handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true // Signal to continue + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Second handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Last handler responds normally + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + // Execute + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify all handlers were called in order + handler1.AssertExpectations(t) + handler2.AssertExpectations(t) + handler3.AssertExpectations(t) +} + +// mockResponseWriter implements dns.ResponseWriter for testing +type mockResponseWriter struct { + mock.Mock +} + +func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} + +func TestHandlerChain_PriorityDeregistration(t *testing.T) { + tests := []struct { + name string + ops []struct { + action string // "add" or "remove" + pattern string + priority int + } + query string + expectedCalls map[int]bool // map[priority]shouldBeCalled + }{ + { + name: "remove high priority keeps lower priority handler", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: true, + }, + }, + { + name: "remove lower priority keeps high priority handler", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: true, + nbdns.PriorityMatchDomain: false, + }, + }, + { + name: "remove all handlers in order", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityDefault}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: false, + nbdns.PriorityDefault: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlers := make(map[int]*nbdns.MockHandler) + + // Execute operations + for _, op := range tt.ops { + if op.action == "add" { + handler := &nbdns.MockHandler{} + handlers[op.priority] = handler + chain.AddHandler(op.pattern, handler, op.priority, nil) + } else { + chain.RemoveHandler(op.pattern, op.priority) + } + } + + // Create test request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations + for priority, handler := range handlers { + if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall { + handler.On("ServeDNS", mock.Anything, r).Once() + } else { + handler.On("ServeDNS", mock.Anything, r).Maybe() + } + } + + // Execute request + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + + // Verify handler exists check + for priority, shouldExist := range tt.expectedCalls { + if shouldExist { + assert.True(t, chain.HasHandlers(tt.ops[0].pattern), + "Handler chain should have handlers for pattern after removing priority %d", priority) + } + } + }) + } +} + +func TestHandlerChain_MultiPriorityHandling(t *testing.T) { + chain := nbdns.NewHandlerChain() + + testDomain := "example.com." + testQuery := "test.example.com." + + // Create handlers with MatchSubdomains enabled + routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + + // Create test request that will be reused + r := new(dns.Msg) + r.SetQuestion(testQuery, dns.TypeA) + + // Add handlers in mixed order + chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) + + // Test 1: Initial state with all three handlers + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Highest priority handler (routeHandler) should be called + routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + routeHandler.AssertExpectations(t) + + // Test 2: Remove highest priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now middle priority handler (matchHandler) should be called + matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + matchHandler.AssertExpectations(t) + + // Test 3: Remove middle priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now lowest priority handler (defaultHandler) should be called + defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + defaultHandler.AssertExpectations(t) + + // Test 4: Remove last handler + chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + assert.False(t, chain.HasHandlers(testDomain)) +} diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 6a459794b96..9a78d4d5057 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -17,12 +17,24 @@ type localResolver struct { records sync.Map } +func (d *localResolver) MatchSubdomains() bool { + return true +} + func (d *localResolver) stop() { } +// String returns a string representation of the local resolver +func (d *localResolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v", r.Question[0]) + if len(r.Question) > 0 { + log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + } + replyMessage := &dns.Msg{} replyMessage.SetReply(r) replyMessage.RecursionAvailable = true diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0739f05429a..7e36ea5df19 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,14 +3,30 @@ package dns import ( "fmt" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func([]string, dns.Handler, int) + DeregisterHandlerFunc func([]string, int) +} + +func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + if m.RegisterHandlerFunc != nil { + m.RegisterHandlerFunc(domains, handler, priority) + } +} + +func (m *MockServer) DeregisterHandler(domains []string, priority int) { + if m.DeregisterHandlerFunc != nil { + m.DeregisterHandlerFunc(domains, priority) + } } // Initialize mock implementation of Initialize from Server interface diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f0277319cd5..5a9cb50d081 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -30,6 +30,8 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { + RegisterHandler(domains []string, handler dns.Handler, priority int) + DeregisterHandler(domains []string, priority int) Initialize() error Stop() DnsIP() string @@ -48,12 +50,14 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap + handlerPriorities map[string]int localResolver *localResolver wgInterface WGIface hostManager hostManager updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + handlerChain *HandlerChain // permanent related properties permanent bool @@ -74,8 +78,9 @@ type handlerWithStop interface { } type muxUpdate struct { - domain string - handler handlerWithStop + domain string + handler handlerWithStop + priority int } // NewDefaultServer returns a new dns server @@ -135,10 +140,12 @@ func NewDefaultServerIos( func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - service: dnsService, - dnsMuxMap: make(registeredHandlerMap), + ctx: ctx, + ctxCancel: stop, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), + handlerPriorities: make(map[string]int), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -151,6 +158,51 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi return defaultServer } +func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.registerHandler(domains, handler, priority) +} + +func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { + log.Debugf("registering handler %s with priority %d", handler, priority) + + for _, domain := range domains { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.handlerChain.AddHandler(domain, handler, priority, nil) + s.handlerPriorities[domain] = priority + s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) + } +} + +func (s *DefaultServer) DeregisterHandler(domains []string, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.deregisterHandler(domains, priority) +} + +func (s *DefaultServer) deregisterHandler(domains []string, priority int) { + log.Debugf("deregistering handler %v with priority %d", domains, priority) + + for _, domain := range domains { + s.handlerChain.RemoveHandler(domain, priority) + + // Only deregister from service if no handlers remain + if !s.handlerChain.HasHandlers(domain) { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.service.DeregisterMux(nbdns.NormalizeZone(domain)) + } + } +} + // Initialize instantiate host manager and the dns service func (s *DefaultServer) Initialize() (err error) { s.mux.Lock() @@ -343,14 +395,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) localRecords := make(map[string]nbdns.SimpleRecord, 0) for _, customZone := range customZones { - if len(customZone.Records) == 0 { return nil, nil, fmt.Errorf("received an empty list of records") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: customZone.Domain, - handler: s.localResolver, + domain: customZone.Domain, + handler: s.localResolver, + priority: PriorityMatchDomain, }) for _, record := range customZone.Records { @@ -412,8 +464,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam if nsGroup.Primary { muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, + domain: nbdns.RootZone, + handler: handler, + priority: PriorityDefault, }) continue } @@ -429,8 +482,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam return nil, fmt.Errorf("received a nameserver group with an empty domain element") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, + domain: domain, + handler: handler, + priority: PriorityMatchDomain, }) } } @@ -440,12 +494,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { muxUpdateMap := make(registeredHandlerMap) + handlersByPriority := make(map[string]int) var isContainRootUpdate bool + // First register new handlers for _, update := range muxUpdates { - s.service.RegisterMux(update.domain, update.handler) + s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.domain] = update.handler + handlersByPriority[update.domain] = update.priority + if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { existingHandler.stop() } @@ -455,6 +513,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { } } + // Then deregister old handlers not in the update for key, existingHandler := range s.dnsMuxMap { _, found := muxUpdateMap[key] if !found { @@ -463,12 +522,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { existingHandler.stop() } else { existingHandler.stop() - s.service.DeregisterMux(key) + // Deregister with the priority that was used to register + if oldPriority, ok := s.handlerPriorities[key]; ok { + s.deregisterHandler([]string{key}, oldPriority) + } } } } s.dnsMuxMap = muxUpdateMap + s.handlerPriorities = handlersByPriority } func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { @@ -517,13 +580,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.service.DeregisterMux(nbdns.RootZone) + s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.service.DeregisterMux(item.Domain) + s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) removeIndex[item.Domain] = i } } @@ -554,7 +617,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.service.RegisterMux(domain, handler) + s.registerHandler([]string{domain}, handler, PriorityMatchDomain) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -562,7 +625,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.service.RegisterMux(nbdns.RootZone, handler) + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") @@ -593,7 +656,8 @@ func (s *DefaultServer) addHostRootZone() { } handler.deactivate = func(error) {} handler.reactivate = func() {} - s.service.RegisterMux(nbdns.RootZone, handler) + + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index eab9f4ecbfa..44d20c6f362 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -11,7 +11,9 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" @@ -512,7 +514,7 @@ func TestDNSServerStartStop(t *testing.T) { t.Error(err) } - dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver) + dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) resolver := &net.Resolver{ PreferGo: true, @@ -560,7 +562,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { localResolver: &localResolver{ registeredMap: make(registrationMap), }, - hostManager: hostManager, + handlerChain: NewHandlerChain(), + handlerPriorities: make(map[string]int), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -872,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver { }, } } + +// MockHandler implements dns.Handler interface for testing +type MockHandler struct { + mock.Mock +} + +func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + m.Called(w, r) +} + +type MockSubdomainHandler struct { + MockHandler + Subdomains bool +} + +func (m *MockSubdomainHandler) MatchSubdomains() bool { + return m.Subdomains +} + +func TestHandlerChain_DomainPriorities(t *testing.T) { + chain := NewHandlerChain() + + dnsRouteHandler := &MockHandler{} + upstreamHandler := &MockSubdomainHandler{ + Subdomains: true, + } + + chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) + chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) + + testCases := []struct { + name string + query string + expectedHandler dns.Handler + }{ + { + name: "exact domain with dns route handler", + query: "example.com.", + expectedHandler: dnsRouteHandler, + }, + { + name: "subdomain should use upstream handler", + query: "sub.example.com.", + expectedHandler: upstreamHandler, + }, + { + name: "deep subdomain should use upstream handler", + query: "deep.sub.example.com.", + expectedHandler: upstreamHandler, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := new(dns.Msg) + r.SetQuestion(tc.query, dns.TypeA) + w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } + + chain.ServeDNS(w, r) + + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.AssertExpectations(t) + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.AssertExpectations(t) + } + + // Reset mocks + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } + }) + } +} diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index e0f9da26f83..72dc4bc6ef7 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() { } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { + log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b3baf2fa8fd..f0aa12b6539 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -66,6 +66,15 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) * } } +// String returns a string representation of the upstream resolver +func (u *upstreamResolverBase) String() string { + return fmt.Sprintf("upstream %v", u.upstreamServers) +} + +func (u *upstreamResolverBase) MatchSubdomains() bool { + return true +} + func (u *upstreamResolverBase) stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go new file mode 100644 index 00000000000..ae31ffac602 --- /dev/null +++ b/client/internal/dnsfwd/forwarder.go @@ -0,0 +1,157 @@ +package dnsfwd + +import ( + "context" + "errors" + "net" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" +) + +const errResolveFailed = "failed to resolve query for domain=%s: %v" + +type DNSForwarder struct { + listenAddress string + ttl uint32 + domains []string + + dnsServer *dns.Server + mux *dns.ServeMux +} + +func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder { + log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) + return &DNSForwarder{ + listenAddress: listenAddress, + ttl: ttl, + } +} + +func (f *DNSForwarder) Listen(domains []string) error { + log.Infof("listen DNS forwarder on address=%s", f.listenAddress) + mux := dns.NewServeMux() + + dnsServer := &dns.Server{ + Addr: f.listenAddress, + Net: "udp", + Handler: mux, + } + f.dnsServer = dnsServer + f.mux = mux + + f.UpdateDomains(domains) + + return dnsServer.ListenAndServe() +} + +func (f *DNSForwarder) UpdateDomains(domains []string) { + log.Debugf("Updating domains from %v to %v", f.domains, domains) + + for _, d := range f.domains { + f.mux.HandleRemove(d) + } + + newDomains := filterDomains(domains) + for _, d := range newDomains { + f.mux.HandleFunc(d, f.handleDNSQuery) + } + f.domains = newDomains +} + +func (f *DNSForwarder) Close(ctx context.Context) error { + if f.dnsServer == nil { + return nil + } + return f.dnsServer.ShutdownContext(ctx) +} + +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { + if len(query.Question) == 0 { + return + } + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) + + question := query.Question[0] + domain := question.Name + + resp := query.SetReply(query) + + ips, err := net.LookupIP(domain) + if err != nil { + var dnsErr *net.DNSError + + switch { + case errors.As(err, &dnsErr): + resp.Rcode = dns.RcodeServerFailure + if dnsErr.IsNotFound { + // Pass through NXDOMAIN + resp.Rcode = dns.RcodeNameError + } + + if dnsErr.Server != "" { + log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + } else { + log.Warnf(errResolveFailed, domain, err) + } + default: + resp.Rcode = dns.RcodeServerFailure + log.Warnf(errResolveFailed, domain, err) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write failure DNS response: %v", err) + } + return + } + + for _, ip := range ips { + var respRecord dns.RR + if ip.To4() == nil { + log.Tracef("resolved domain=%s to IPv6=%s", domain, ip) + rr := dns.AAAA{ + AAAA: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } else { + log.Tracef("resolved domain=%s to IPv4=%s", domain, ip) + rr := dns.A{ + A: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } + resp.Answer = append(resp.Answer, respRecord) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +// filterDomains returns a list of normalized domains +func filterDomains(domains []string) []string { + newDomains := make([]string, 0, len(domains)) + for _, d := range domains { + if d == "" { + log.Warn("empty domain in DNS forwarder") + continue + } + newDomains = append(newDomains, nbdns.NormalizeZone(d)) + } + return newDomains +} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go new file mode 100644 index 00000000000..7cff6d51780 --- /dev/null +++ b/client/internal/dnsfwd/manager.go @@ -0,0 +1,106 @@ +package dnsfwd + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also + ListenPort = 5353 + dnsTTL = 60 //seconds +) + +type Manager struct { + firewall firewall.Manager + + fwRules []firewall.Rule + dnsForwarder *DNSForwarder +} + +func NewManager(fw firewall.Manager) *Manager { + return &Manager{ + firewall: fw, + } +} + +func (m *Manager) Start(domains []string) error { + log.Infof("starting DNS forwarder") + if m.dnsForwarder != nil { + return nil + } + + if err := m.allowDNSFirewall(); err != nil { + return err + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL) + go func() { + if err := m.dnsForwarder.Listen(domains); err != nil { + // todo handle close error if it is exists + log.Errorf("failed to start DNS forwarder, err: %v", err) + } + }() + + return nil +} + +func (m *Manager) UpdateDomains(domains []string) { + if m.dnsForwarder == nil { + return + } + + m.dnsForwarder.UpdateDomains(domains) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m.dnsForwarder == nil { + return nil + } + + var mErr *multierror.Error + if err := m.dropDNSFirewall(); err != nil { + mErr = multierror.Append(mErr, err) + } + + if err := m.dnsForwarder.Close(ctx); err != nil { + mErr = multierror.Append(mErr, err) + } + + m.dnsForwarder = nil + return nberrors.FormatErrorOrNil(mErr) +} + +func (h *Manager) allowDNSFirewall() error { + dport := &firewall.Port{ + IsRange: false, + Values: []int{ListenPort}, + } + dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + h.fwRules = dnsRules + + return nil +} + +func (h *Manager) dropDNSFirewall() error { + var mErr *multierror.Error + for _, rule := range h.fwRules { + if err := h.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } + + h.fwRules = nil + return nberrors.FormatErrorOrNil(mErr) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 34219def185..9724e2a2257 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "maps" "math/rand" "net" "net/netip" @@ -30,10 +29,12 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -117,7 +118,7 @@ type Engine struct { // mgmClient is a Management Service client mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer - peerConns map[string]*peer.Conn + peerStore *peerstore.Store beforePeerHook nbnet.AddHookFunc afterPeerHook nbnet.RemoveHookFunc @@ -137,10 +138,6 @@ type Engine struct { TURNs []*stun.URI stunTurn atomic.Value - // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - clientRoutesMu sync.RWMutex - clientCtx context.Context clientCancel context.CancelFunc @@ -161,9 +158,10 @@ type Engine struct { statusRecorder *peer.Status - firewall manager.Manager - routeManager routemanager.Manager - acl acl.Manager + firewall manager.Manager + routeManager routemanager.Manager + acl acl.Manager + dnsForwardMgr *dnsfwd.Manager dnsServer dns.Server @@ -234,7 +232,7 @@ func NewEngineWithProbes( signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), mgmClient: mgmClient, relayManager: relayManager, - peerConns: make(map[string]*peer.Conn), + peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, mobileDep: mobileDep, @@ -287,6 +285,13 @@ func (e *Engine) Stop() error { e.routeManager.Stop(e.stateManager) } + if e.dnsForwardMgr != nil { + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } + if e.srWatcher != nil { e.srWatcher.Close() } @@ -300,10 +305,6 @@ func (e *Engine) Stop() error { return fmt.Errorf("failed to remove all peers: %s", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = nil - e.clientRoutesMu.Unlock() - if e.cancel != nil { e.cancel() } @@ -382,6 +383,8 @@ func (e *Engine) Start() error { e.relayManager, initialRoutes, e.stateManager, + dnsServer, + e.peerStore, ) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { @@ -460,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if peerConn, ok := e.peerConns[peerPubKey]; ok { - if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { + if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { + if allowedIPs != strings.Join(p.AllowedIps, ",") { modified = append(modified, p) continue } @@ -492,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { - currentPeers := make([]string, 0, len(e.peerConns)) - for p := range e.peerConns { - currentPeers = append(currentPeers, p) - } - newPeers := make([]string, 0, len(peersUpdate)) for _, p := range peersUpdate { newPeers = append(newPeers, p.GetWgPubKey()) } - toRemove := util.SliceDiff(currentPeers, newPeers) + toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers) for _, p := range toRemove { err := e.removePeer(p) @@ -516,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removeAllPeers() error { log.Debugf("removing all peer connections") - for p := range e.peerConns { + for _, p := range e.peerStore.PeersPubKey() { err := e.removePeer(p) if err != nil { return err @@ -540,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error { } }() - conn, exists := e.peerConns[peerKey] + conn, exists := e.peerStore.Remove(peerKey) if exists { - delete(e.peerConns, peerKey) conn.Close() } return nil @@ -786,7 +783,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { } func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { - // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't if networkMap.GetPeerConfig() != nil { err := e.updateConfig(networkMap.GetPeerConfig()) @@ -806,20 +802,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} + var dnsRouteFeatureFlag bool + if networkMap.PeerConfig != nil { + dnsRouteFeatureFlag = networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled } + routedDomains, routes := toRoutes(networkMap.GetRoutes()) - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { + e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) + + if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() - log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) @@ -867,8 +861,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) - if err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { log.Errorf("failed to update dns server, err: %v", err) } @@ -881,7 +874,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { +func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + var dnsRoutes []string routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -892,6 +890,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { continue } } + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) + convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), Network: prefix, @@ -905,7 +905,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { } routes = append(routes, convertedRoute) } - return routes + return dnsRoutes, routes } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { @@ -982,12 +982,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerConns[peerKey]; !ok { + if _, ok := e.peerStore.PeerConn(peerKey); !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - e.peerConns[peerKey] = conn + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } if e.beforePeerHook != nil && e.afterPeerHook != nil { conn.AddBeforeAddPeerHook(e.beforePeerHook) @@ -1076,8 +1080,8 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - conn := e.peerConns[msg.Key] - if conn == nil { + conn, ok := e.peerStore.PeerConn(msg.Key) + if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) } @@ -1135,7 +1139,7 @@ func (e *Engine) receiveSignalEvents() { return err } - go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) + go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) case sProto.Body_MODE: } @@ -1239,7 +1243,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { if err != nil { return nil, nil, err } - routes := toRoutes(netMap.GetRoutes()) + _, routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, &dnsCfg, nil } @@ -1322,26 +1326,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } } -// GetClientRoutes returns the current routes from the route map -func (e *Engine) GetClientRoutes() route.HAMap { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - return maps.Clone(e.clientRoutes) -} - -// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only -func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) - for id, v := range e.clientRoutes { - routes[id.NetID()] = v - } - return routes -} - // GetRouteManager returns the route manager func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager @@ -1426,9 +1410,8 @@ func (e *Engine) receiveProbeEvents() { go e.probes.WgProbe.Receive(e.ctx, func() bool { log.Debug("received wg probe request") - for _, peer := range e.peerConns { - key := peer.GetKey() - wgStats, err := peer.WgConfig().WgInterface.GetStats(key) + for _, key := range e.peerStore.PeersPubKey() { + wgStats, err := e.wgInterface.GetStats(key) if err != nil { log.Debugf("failed to get wg stats for peer %s: %s", key, err) } @@ -1505,7 +1488,7 @@ func (e *Engine) startNetworkMonitor() { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { var vpnRoutes []netip.Prefix - for _, routes := range e.GetClientRoutes() { + for _, routes := range e.routeManager.GetClientRoutes() { if len(routes) > 0 && routes[0] != nil { vpnRoutes = append(vpnRoutes, routes[0].Network) } @@ -1573,6 +1556,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag +func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { + if !enabled { + if e.dnsForwardMgr == nil { + return + } + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + return + } + + if len(domains) > 0 { + log.Infof("enable domain router service for domains: %v", domains) + if e.dnsForwardMgr == nil { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall) + + if err := e.dnsForwardMgr.Start(domains); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } + } else { + log.Infof("update domain router service for domains: %v", domains) + e.dnsForwardMgr.UpdateDomains(domains) + } + } else if e.dnsForwardMgr != nil { + log.Infof("disable domain router service") + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { for _, check := range checks { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b58c1f7e93a..b81d8bd3f5e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -39,6 +39,8 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -251,7 +253,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil) _, _, err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ @@ -391,8 +393,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { return } - if len(engine.peerConns) != c.expectedLen { - t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns)) + if len(engine.peerStore.PeersPubKey()) != c.expectedLen { + t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey())) } if engine.networkSerial != c.expectedSerial { @@ -400,7 +402,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } for _, p := range c.expectedPeers { - conn, ok := engine.peerConns[p.GetWgPubKey()] + conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey()) if !ok { t.Errorf("expecting Engine.peerConns to contain peer %s", p) } @@ -625,10 +627,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { input.inputSerial = updateSerial input.inputRoutes = newRoutes - return nil, nil, testCase.inputErr + return testCase.inputErr }, } @@ -801,8 +803,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { - return nil, nil, nil + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + return nil }, } @@ -1196,7 +1198,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } @@ -1218,7 +1220,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } @@ -1237,7 +1239,8 @@ func getConnectedPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() i := 0 - for _, conn := range e.peerConns { + for _, id := range e.peerStore.PeersPubKey() { + conn, _ := e.peerStore.PeerConn(id) if conn.Status() == peer.StatusConnected { i++ } @@ -1249,5 +1252,5 @@ func getPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - return len(e.peerConns) + return len(e.peerStore.PeersPubKey()) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 5c2e2cb60b6..b8cb2582fb9 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// AllowedIP returns the allowed IP of the remote peer +func (conn *Conn) AllowedIP() net.IP { + return conn.allowedIP +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 74e2ee82c0e..dc461257adf 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -17,6 +17,11 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" ) +type ResolvedDomainInfo struct { + Prefixes []netip.Prefix + ParentDomain domain.Domain +} + // State contains the latest state of a peer type State struct { Mux *sync.RWMutex @@ -138,7 +143,7 @@ type Status struct { rosenpassEnabled bool rosenpassPermissive bool nsGroupStates []NSGroupState - resolvedDomainsStates map[domain.Domain][]netip.Prefix + resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo // To reduce the number of notification invocation this bool will be true when need to call the notification // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events @@ -156,7 +161,7 @@ func NewRecorder(mgmAddress string) *Status { offlinePeers: make([]State, 0), notifier: newNotifier(), mgmAddress: mgmAddress, - resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix), + resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{}, } } @@ -591,16 +596,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) { d.mux.Lock() defer d.mux.Unlock() - d.resolvedDomainsStates[domain] = prefixes + + // Store both the original domain pattern and resolved domain + d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{ + Prefixes: prefixes, + ParentDomain: originalDomain, + } } func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { d.mux.Lock() defer d.mux.Unlock() - delete(d.resolvedDomainsStates, domain) + + // Remove all entries that have this domain as their parent + for k, v := range d.resolvedDomainsStates { + if v.ParentDomain == domain { + delete(d.resolvedDomainsStates, k) + } + } } func (d *Status) GetRosenpassState() RosenpassState { @@ -702,7 +718,7 @@ func (d *Status) GetDNSStates() []NSGroupState { return d.nsGroupStates } -func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { +func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { d.mux.Lock() defer d.mux.Unlock() return maps.Clone(d.resolvedDomainsStates) diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go new file mode 100644 index 00000000000..6b3385ff584 --- /dev/null +++ b/client/internal/peerstore/store.go @@ -0,0 +1,87 @@ +package peerstore + +import ( + "net" + "sync" + + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +// Store is a thread-safe store for peer connections. +type Store struct { + peerConns map[string]*peer.Conn + peerConnsMu sync.RWMutex +} + +func NewConnStore() *Store { + return &Store{ + peerConns: make(map[string]*peer.Conn), + } +} + +func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + _, ok := s.peerConns[pubKey] + if ok { + return false + } + + s.peerConns[pubKey] = conn + return true +} + +func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + delete(s.peerConns, pubKey) + return p, true +} + +func (s *Store) AllowedIPs(pubKey string) (string, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return "", false + } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p.AllowedIP(), true +} + +func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p, true +} + +func (s *Store) PeersPubKey() []string { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + return maps.Keys(s.peerConns) +} diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 13e45b3a360..73f552aab74 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -13,12 +13,20 @@ import ( "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) +const ( + handlerTypeDynamic = iota + handlerTypeDomain + handlerTypeStatic +) + type routerPeerStatus struct { connected bool relayed bool @@ -53,7 +61,18 @@ type clientNetwork struct { updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { +func newClientNetworkWatcher( + ctx context.Context, + dnsRouteInterval time.Duration, + wgInterface iface.IWGIface, + statusRecorder *peer.Status, + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsServer nbdns.Server, + peerStore *peerstore.Store, + useNewDNSRoute bool, +) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ @@ -65,7 +84,17 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), + handler: handlerFromRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouteInterval, + statusRecorder, + wgInterface, + dnsServer, + peerStore, + useNewDNSRoute, + ), } return client } @@ -368,10 +397,50 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { - if rt.IsDynamic() { +func handlerFromRoute( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsRouterInteval time.Duration, + statusRecorder *peer.Status, + wgInterface iface.IWGIface, + dnsServer nbdns.Server, + peerStore *peerstore.Store, + useNewDNSRoute bool, +) RouteHandler { + switch handlerType(rt, useNewDNSRoute) { + case handlerTypeDomain: + return dnsinterceptor.New( + rt, + routeRefCounter, + allowedIPsRefCounter, + statusRecorder, + dnsServer, + peerStore, + ) + case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) + return dynamic.NewRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouterInteval, + statusRecorder, + wgInterface, + fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), + ) + default: + return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) + } +} + +func handlerType(rt *route.Route, useNewDNSRoute bool) int { + if !rt.IsDynamic() { + return handlerTypeStatic + } + + if useNewDNSRoute { + return handlerTypeDomain } - return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) + return handlerTypeDynamic } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go new file mode 100644 index 00000000000..10cb03f1d2a --- /dev/null +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -0,0 +1,356 @@ +package dnsinterceptor + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" +) + +type domainMap map[domain.Domain][]netip.Prefix + +type DnsInterceptor struct { + mu sync.RWMutex + route *route.Route + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefcounter *refcounter.AllowedIPsRefCounter + statusRecorder *peer.Status + dnsServer nbdns.Server + currentPeerKey string + interceptedDomains domainMap + peerStore *peerstore.Store +} + +func New( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + statusRecorder *peer.Status, + dnsServer nbdns.Server, + peerStore *peerstore.Store, +) *DnsInterceptor { + return &DnsInterceptor{ + route: rt, + routeRefCounter: routeRefCounter, + allowedIPsRefcounter: allowedIPsRefCounter, + statusRecorder: statusRecorder, + dnsServer: dnsServer, + interceptedDomains: make(domainMap), + peerStore: peerStore, + } +} + +func (d *DnsInterceptor) String() string { + return d.route.Domains.SafeString() +} + +func (d *DnsInterceptor) AddRoute(context.Context) error { + d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) + return nil +} + +func (d *DnsInterceptor) RemoveRoute() error { + d.mu.Lock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + + } + for _, domain := range d.route.Domains { + d.statusRecorder.DeleteResolvedDomainsStates(domain) + } + + clear(d.interceptedDomains) + d.mu.Unlock() + + d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) + + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + } + } + + d.currentPeerKey = peerKey + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) RemoveAllowedIPs() error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for _, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + + d.currentPeerKey = "" + return nberrors.FormatErrorOrNil(merr) +} + +// ServeDNS implements the dns.Handler interface +func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + log.Tracef("received DNS request for domain=%s type=%v class=%v", + r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + d.mu.RLock() + peerKey := d.currentPeerKey + d.mu.RUnlock() + + if peerKey == "" { + log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name) + + d.continueToNextHandler(w, r, "no current peer key") + return + } + + upstreamIP, err := d.getUpstreamIP(peerKey) + if err != nil { + log.Errorf("failed to get upstream IP: %v", err) + d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err)) + return + } + + client := &dns.Client{ + Timeout: 5 * time.Second, + Net: "udp", + } + upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + reply, _, err := client.ExchangeContext(context.Background(), r, upstream) + + var answer []dns.RR + if reply != nil { + answer = reply.Answer + } + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) + + if err != nil { + log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + reply.Id = r.Id + if err := d.writeMsg(w, reply); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } +} + +// continueToNextHandler signals the handler chain to try the next handler +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { + log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) + + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + // Set Zero bit to signal handler chain to continue + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed writing DNS continue response: %v", err) + } +} + +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { + peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) + if !exists { + return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) + } + return peerAllowedIP, nil +} + +func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { + if r == nil { + return fmt.Errorf("received nil DNS message") + } + + if len(r.Answer) > 0 && len(r.Question) > 0 { + origPattern := "" + if writer, ok := w.(*nbdns.ResponseWriterChain); ok { + origPattern = writer.GetOrigPattern() + } + + resolvedDomain := domain.Domain(r.Question[0].Name) + + // already punycode via RegisterHandler() + originalDomain := domain.Domain(origPattern) + if originalDomain == "" { + originalDomain = resolvedDomain + } + + var newPrefixes []netip.Prefix + for _, answer := range r.Answer { + var ip netip.Addr + switch rr := answer.(type) { + case *dns.A: + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) + continue + } + ip = addr + case *dns.AAAA: + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) + continue + } + ip = addr + default: + continue + } + + prefix := netip.PrefixFrom(ip, ip.BitLen()) + newPrefixes = append(newPrefixes, prefix) + } + + if len(newPrefixes) > 0 { + if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { + log.Errorf("failed to update domain prefixes: %v", err) + } + } + } + + if err := w.WriteMsg(r); err != nil { + return fmt.Errorf("failed to write DNS response: %v", err) + } + + return nil +} + +func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { + d.mu.Lock() + defer d.mu.Unlock() + + oldPrefixes := d.interceptedDomains[resolvedDomain] + toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) + + var merr *multierror.Error + + // Add new prefixes + for _, prefix := range toAdd { + if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) + continue + } + + if d.currentPeerKey == "" { + continue + } + if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != d.currentPeerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + resolvedDomain.SafeString(), + ref.Out, + ) + } + } + + if !d.route.KeepRoute { + // Remove old prefixes + for _, prefix := range toRemove { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + } + + // Update domain prefixes using resolved domain as key + if len(toAdd) > 0 || len(toRemove) > 0 { + d.interceptedDomains[resolvedDomain] = newPrefixes + originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) + d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes) + + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toAdd) + } + if len(toRemove) > 0 { + log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toRemove) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { + prefixSet := make(map[netip.Prefix]bool) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = false + } + for _, prefix := range newPrefixes { + if _, exists := prefixSet[prefix]; exists { + prefixSet[prefix] = true + } else { + toAdd = append(toAdd, prefix) + } + } + for prefix, inUse := range prefixSet { + if !inUse { + toRemove = append(toRemove, prefix) + } + } + return +} diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index ac94d4a5c74..a0fff7713ca 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -74,11 +74,7 @@ func NewRoute( } func (r *Route) String() string { - s, err := r.route.Domains.String() - if err != nil { - return r.route.Domains.PunycodeString() - } - return s + return r.route.Domains.SafeString() } func (r *Route) AddRoute(ctx context.Context) error { @@ -292,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) r.dynamicDomains[domain] = updatedPrefixes - r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes) + r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 8bf3a91b0d2..389e97e2dcc 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,12 +12,15 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -33,9 +36,11 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector + GetClientRoutes() route.HAMap + GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error @@ -60,6 +65,11 @@ type DefaultManager struct { allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration stateManager *statemanager.Manager + // clientRoutes is the most recent list of clientRoutes received from the Management Service + clientRoutes route.HAMap + dnsServer dns.Server + peerStore *peerstore.Store + useNewDNSRoute bool } func NewManager( @@ -71,6 +81,8 @@ func NewManager( relayMgr *relayClient.Manager, initialRoutes []*route.Route, stateManager *statemanager.Manager, + dnsServer dns.Server, + peerStore *peerstore.Store, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -88,6 +100,8 @@ func NewManager( pubKey: pubKey, notifier: notifier, stateManager: stateManager, + dnsServer: dnsServer, + peerStore: peerStore, } dm.routeRefCounter = refcounter.New( @@ -116,7 +130,7 @@ func NewManager( ) if runtime.GOOS == "android" { - cr := dm.clientRoutes(initialRoutes) + cr := dm.initialClientRoutes(initialRoutes) dm.notifier.SetInitialClientRoutes(cr) } return dm @@ -207,33 +221,41 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } m.ctx = nil + + m.mux.Lock() + defer m.mux.Unlock() + m.clientRoutes = nil } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return nil, nil, m.ctx.Err() + return nil default: - m.mux.Lock() - defer m.mux.Unlock() + } + + m.mux.Lock() + defer m.mux.Unlock() + m.useNewDNSRoute = useNewDNSRoute - newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) - m.updateClientNetworks(updateSerial, filteredClientRoutes) - m.notifier.OnNewRoutes(filteredClientRoutes) + filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + m.updateClientNetworks(updateSerial, filteredClientRoutes) + m.notifier.OnNewRoutes(filteredClientRoutes) - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return nil, nil, fmt.Errorf("update routes: %w", err) - } + if m.serverRouter != nil { + err := m.serverRouter.updateRoutes(newServerRoutesMap) + if err != nil { + return err } - - return newServerRoutesMap, newClientRoutesIDMap, nil } + + m.clientRoutes = newClientRoutesIDMap + + return nil } // SetRouteChangeListener set RouteListener for route change Notifier @@ -251,9 +273,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { return m.routeSelector } -// GetClientRoutes returns the client routes -func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { - return m.clientNetworks +// GetClientRoutes returns most recent list of clientRoutes received from the Management Service +func (m *DefaultManager) GetClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return maps.Clone(m.clientRoutes) +} + +// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only +func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + m.mux.Lock() + defer m.mux.Unlock() + + routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes)) + for id, v := range m.clientRoutes { + routes[id.NetID()] = v + } + return routes } // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones @@ -273,7 +310,18 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { continue } - clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher := newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) @@ -302,7 +350,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher = newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } @@ -345,7 +404,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] return newServerRoutesMap, newClientRoutesIDMap } -func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { +func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route { _, crMap := m.classifyRoutes(initialRoutes) rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 07dac21b819..4b7c984e5a0 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil) _, _, err = routeManager.Init() @@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) require.NoError(t, err, "should update routes with init routes") } - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 556a6235138..64fdffceb3e 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -2,7 +2,6 @@ package routemanager import ( "context" - "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" @@ -15,10 +14,12 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) - TriggerSelectionFunc func(haMap route.HAMap) - GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func(manager *statemanager.Manager) + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + TriggerSelectionFunc func(haMap route.HAMap) + GetRouteSelectorFunc func() *routeselector.RouteSelector + GetClientRoutesFunc func() route.HAMap + GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route + StopFunc func(manager *statemanager.Manager) } func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { @@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } - return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") + return nil } func (m *MockManager) TriggerSelection(networks route.HAMap) { @@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } +// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +func (m *MockManager) GetClientRoutes() route.HAMap { + if m.GetClientRoutesFunc != nil { + return m.GetClientRoutesFunc() + } + return nil +} + +// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface +func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + if m.GetClientRoutesWithNetIDFunc != nil { + return m.GetClientRoutesWithNetIDFunc() + } + return nil +} + // Start mock implementation of Start from Manager interface func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 6f501e0c636..befce56a2d3 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() routeManager := engine.GetRouteManager() + routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } @@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } -func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails { +func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { domainList := make([]DomainInfo, 0) @@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom domainResp := DomainInfo{ Domain: d.SafeString(), } - if prefixes, exists := resolvedDomains[d]; exists { + + if info, exists := resolvedDomains[d]; exists { var ipStrings []string - for _, prefix := range prefixes { + for _, prefix := range info.Prefixes { ipStrings = append(ipStrings, prefix.Addr().String()) } domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") @@ -365,12 +366,12 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } @@ -392,12 +393,12 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 98ce2c4a289..f0d3021e92b 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v3.21.9 // source: daemon.proto package proto @@ -908,7 +908,7 @@ type PeerState struct { BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` } @@ -1043,9 +1043,9 @@ func (x *PeerState) GetRosenpassEnabled() bool { return false } -func (x *PeerState) GetRoutes() []string { +func (x *PeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1076,7 +1076,7 @@ type LocalPeerState struct { Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - Routes []string `protobuf:"bytes,7,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"` } func (x *LocalPeerState) Reset() { @@ -1153,9 +1153,9 @@ func (x *LocalPeerState) GetRosenpassPermissive() bool { return false } -func (x *LocalPeerState) GetRoutes() []string { +func (x *LocalPeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1511,14 +1511,14 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState { return nil } -type ListRoutesRequest struct { +type ListNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *ListRoutesRequest) Reset() { - *x = ListRoutesRequest{} +func (x *ListNetworksRequest) Reset() { + *x = ListNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1526,13 +1526,13 @@ func (x *ListRoutesRequest) Reset() { } } -func (x *ListRoutesRequest) String() string { +func (x *ListNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesRequest) ProtoMessage() {} +func (*ListNetworksRequest) ProtoMessage() {} -func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1544,21 +1544,21 @@ func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead. -func (*ListRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. +func (*ListNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{19} } -type ListRoutesResponse struct { +type ListNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` + Routes []*Network `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` } -func (x *ListRoutesResponse) Reset() { - *x = ListRoutesResponse{} +func (x *ListNetworksResponse) Reset() { + *x = ListNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1566,13 +1566,13 @@ func (x *ListRoutesResponse) Reset() { } } -func (x *ListRoutesResponse) String() string { +func (x *ListNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesResponse) ProtoMessage() {} +func (*ListNetworksResponse) ProtoMessage() {} -func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1584,30 +1584,30 @@ func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead. -func (*ListRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. +func (*ListNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{20} } -func (x *ListRoutesResponse) GetRoutes() []*Route { +func (x *ListNetworksResponse) GetRoutes() []*Network { if x != nil { return x.Routes } return nil } -type SelectRoutesRequest struct { +type SelectNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"` - Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` - All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` + NetworkIDs []string `protobuf:"bytes,1,rep,name=networkIDs,proto3" json:"networkIDs,omitempty"` + Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` + All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` } -func (x *SelectRoutesRequest) Reset() { - *x = SelectRoutesRequest{} +func (x *SelectNetworksRequest) Reset() { + *x = SelectNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1615,13 +1615,13 @@ func (x *SelectRoutesRequest) Reset() { } } -func (x *SelectRoutesRequest) String() string { +func (x *SelectNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesRequest) ProtoMessage() {} +func (*SelectNetworksRequest) ProtoMessage() {} -func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1633,40 +1633,40 @@ func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead. -func (*SelectRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. +func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{21} } -func (x *SelectRoutesRequest) GetRouteIDs() []string { +func (x *SelectNetworksRequest) GetNetworkIDs() []string { if x != nil { - return x.RouteIDs + return x.NetworkIDs } return nil } -func (x *SelectRoutesRequest) GetAppend() bool { +func (x *SelectNetworksRequest) GetAppend() bool { if x != nil { return x.Append } return false } -func (x *SelectRoutesRequest) GetAll() bool { +func (x *SelectNetworksRequest) GetAll() bool { if x != nil { return x.All } return false } -type SelectRoutesResponse struct { +type SelectNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *SelectRoutesResponse) Reset() { - *x = SelectRoutesResponse{} +func (x *SelectNetworksResponse) Reset() { + *x = SelectNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1674,13 +1674,13 @@ func (x *SelectRoutesResponse) Reset() { } } -func (x *SelectRoutesResponse) String() string { +func (x *SelectNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesResponse) ProtoMessage() {} +func (*SelectNetworksResponse) ProtoMessage() {} -func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1692,8 +1692,8 @@ func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead. -func (*SelectRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. +func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{22} } @@ -1744,20 +1744,20 @@ func (x *IPList) GetIps() []string { return nil } -type Route struct { +type Network struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Range string `protobuf:"bytes,2,opt,name=range,proto3" json:"range,omitempty"` Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` Domains []string `protobuf:"bytes,4,rep,name=domains,proto3" json:"domains,omitempty"` ResolvedIPs map[string]*IPList `protobuf:"bytes,5,rep,name=resolvedIPs,proto3" json:"resolvedIPs,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } -func (x *Route) Reset() { - *x = Route{} +func (x *Network) Reset() { + *x = Network{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1765,13 +1765,13 @@ func (x *Route) Reset() { } } -func (x *Route) String() string { +func (x *Network) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Route) ProtoMessage() {} +func (*Network) ProtoMessage() {} -func (x *Route) ProtoReflect() protoreflect.Message { +func (x *Network) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1783,40 +1783,40 @@ func (x *Route) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Route.ProtoReflect.Descriptor instead. -func (*Route) Descriptor() ([]byte, []int) { +// Deprecated: Use Network.ProtoReflect.Descriptor instead. +func (*Network) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{24} } -func (x *Route) GetID() string { +func (x *Network) GetID() string { if x != nil { return x.ID } return "" } -func (x *Route) GetNetwork() string { +func (x *Network) GetRange() string { if x != nil { - return x.Network + return x.Range } return "" } -func (x *Route) GetSelected() bool { +func (x *Network) GetSelected() bool { if x != nil { return x.Selected } return false } -func (x *Route) GetDomains() []string { +func (x *Network) GetDomains() []string { if x != nil { return x.Domains } return nil } -func (x *Route) GetResolvedIPs() map[string]*IPList { +func (x *Network) GetResolvedIPs() map[string]*IPList { if x != nil { return x.ResolvedIPs } @@ -2671,7 +2671,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65, + 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xde, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, @@ -2710,233 +2710,235 @@ var file_daemon_proto_rawDesc = []byte{ 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, - 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, - 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, - 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, - 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, - 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, - 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, - 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, - 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, - 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, - 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, - 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, - 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, - 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, - 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, - 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, - 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, - 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, - 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, - 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a, - 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, - 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, - 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, - 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, - 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, - 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, - 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, - 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, + 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x0e, 0x4c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, + 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, + 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, + 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, + 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x18, 0x07, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x22, 0x53, 0x0a, + 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, + 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, + 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, + 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, + 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, + 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, + 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, + 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, + 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, + 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x69, 0x73, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0x3f, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x22, 0x61, 0x0a, 0x15, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x6e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x6e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, + 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, + 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x61, 0x6c, 0x6c, 0x22, 0x18, 0x0a, 0x16, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, + 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x07, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x42, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, + 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, + 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, + 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, + 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, + 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, + 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, + 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, + 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, + 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, + 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, + 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, + 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, + 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, + 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, + 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, + 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, + 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, + 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, + 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, + 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, + 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, + 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, + 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, + 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, - 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, - 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, - 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, - 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, + 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, + 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, - 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, - 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a, - 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, - 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, - 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, - 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, - 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, - 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, - 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, - 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, - 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, - 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, - 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, - 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, - 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, - 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, - 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2974,12 +2976,12 @@ var file_daemon_proto_goTypes = []interface{}{ (*RelayState)(nil), // 17: daemon.RelayState (*NSGroupState)(nil), // 18: daemon.NSGroupState (*FullStatus)(nil), // 19: daemon.FullStatus - (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest - (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse - (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest - (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse + (*ListNetworksRequest)(nil), // 20: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 21: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 22: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 23: daemon.SelectNetworksResponse (*IPList)(nil), // 24: daemon.IPList - (*Route)(nil), // 25: daemon.Route + (*Network)(nil), // 25: daemon.Network (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest @@ -2995,7 +2997,7 @@ var file_daemon_proto_goTypes = []interface{}{ (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse - nil, // 41: daemon.Route.ResolvedIPsEntry + nil, // 41: daemon.Network.ResolvedIPsEntry (*durationpb.Duration)(nil), // 42: google.protobuf.Duration (*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp } @@ -3011,21 +3013,21 @@ var file_daemon_proto_depIdxs = []int32{ 13, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State - 24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList + 24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest - 22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest - 22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest + 20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest @@ -3039,9 +3041,9 @@ var file_daemon_proto_depIdxs = []int32{ 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse @@ -3291,7 +3293,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesRequest); i { + switch v := v.(*ListNetworksRequest); i { case 0: return &v.state case 1: @@ -3303,7 +3305,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesResponse); i { + switch v := v.(*ListNetworksResponse); i { case 0: return &v.state case 1: @@ -3315,7 +3317,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesRequest); i { + switch v := v.(*SelectNetworksRequest); i { case 0: return &v.state case 1: @@ -3327,7 +3329,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesResponse); i { + switch v := v.(*SelectNetworksResponse); i { case 0: return &v.state case 1: @@ -3351,7 +3353,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*Network); i { case 0: return &v.state case 1: diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 96ade5b4e51..cddf78242dc 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -28,14 +28,14 @@ service DaemonService { // GetConfig of the daemon. rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} - // List available network routes - rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {} + // List available networks + rpc ListNetworks(ListNetworksRequest) returns (ListNetworksResponse) {} // Select specific routes - rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc SelectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // Deselect specific routes - rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc DeselectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // DebugBundle creates a debug bundle rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {} @@ -190,7 +190,7 @@ message PeerState { int64 bytesRx = 13; int64 bytesTx = 14; bool rosenpassEnabled = 15; - repeated string routes = 16; + repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; } @@ -203,7 +203,7 @@ message LocalPeerState { string fqdn = 4; bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; - repeated string routes = 7; + repeated string networks = 7; } // SignalState contains the latest state of a signal connection @@ -244,20 +244,20 @@ message FullStatus { repeated NSGroupState dns_servers = 6; } -message ListRoutesRequest { +message ListNetworksRequest { } -message ListRoutesResponse { - repeated Route routes = 1; +message ListNetworksResponse { + repeated Network routes = 1; } -message SelectRoutesRequest { - repeated string routeIDs = 1; +message SelectNetworksRequest { + repeated string networkIDs = 1; bool append = 2; bool all = 3; } -message SelectRoutesResponse { +message SelectNetworksResponse { } message IPList { @@ -265,9 +265,9 @@ message IPList { } -message Route { +message Network { string ID = 1; - string network = 2; + string range = 2; bool selected = 3; repeated string domains = 4; map resolvedIPs = 5; diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 2e063604a4a..39424aee938 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -31,12 +31,12 @@ type DaemonServiceClient interface { Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) - // List available network routes - ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) + // List available networks + ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon @@ -115,27 +115,27 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques return out, nil } -func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) { - out := new(ListRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...) +func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { + out := new(ListNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...) +func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...) +func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -222,12 +222,12 @@ type DaemonServiceServer interface { Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) - // List available network routes - ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) + // List available networks + ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon @@ -267,14 +267,14 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") } -func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented") +func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented") } -func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented") +func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented") } -func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented") +func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented") } func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented") @@ -418,56 +418,56 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } -func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ListRoutesRequest) +func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).ListRoutes(ctx, in) + return srv.(DaemonServiceServer).ListNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListRoutes", + FullMethod: "/daemon.DaemonService/ListNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest)) + return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).SelectRoutes(ctx, in) + return srv.(DaemonServiceServer).SelectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SelectRoutes", + FullMethod: "/daemon.DaemonService/SelectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, in) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeselectRoutes", + FullMethod: "/daemon.DaemonService/DeselectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } @@ -630,16 +630,16 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ Handler: _DaemonService_GetConfig_Handler, }, { - MethodName: "ListRoutes", - Handler: _DaemonService_ListRoutes_Handler, + MethodName: "ListNetworks", + Handler: _DaemonService_ListNetworks_Handler, }, { - MethodName: "SelectRoutes", - Handler: _DaemonService_SelectRoutes_Handler, + MethodName: "SelectNetworks", + Handler: _DaemonService_SelectNetworks_Handler, }, { - MethodName: "DeselectRoutes", - Handler: _DaemonService_DeselectRoutes_Handler, + MethodName: "DeselectNetworks", + Handler: _DaemonService_DeselectNetworks_Handler, }, { MethodName: "DebugBundle", diff --git a/client/server/route.go b/client/server/network.go similarity index 58% rename from client/server/route.go rename to client/server/network.go index d70e0dca391..aaf361524dd 100644 --- a/client/server/route.go +++ b/client/server/network.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "slices" "sort" "golang.org/x/exp/maps" @@ -20,8 +21,8 @@ type selectRoute struct { Selected bool } -// ListRoutes returns a list of all available routes. -func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { +// ListNetworks returns a list of all available networks. +func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*proto.ListNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -34,7 +35,7 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() + routesMap := engine.GetRouteManager().GetClientRoutesWithNetID() routeSelector := engine.GetRouteManager().GetRouteSelector() var routes []*selectRoute @@ -67,37 +68,47 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L }) resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() - var pbRoutes []*proto.Route + var pbRoutes []*proto.Network for _, route := range routes { - pbRoute := &proto.Route{ + pbRoute := &proto.Network{ ID: string(route.NetID), - Network: route.Network.String(), + Range: route.Network.String(), Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, } - for _, domain := range route.Domains { - if prefixes, exists := resolvedDomains[domain]; exists { - var ipStrings []string - for _, prefix := range prefixes { - ipStrings = append(ipStrings, prefix.Addr().String()) - } - pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ - Ips: ipStrings, + // Group resolved IPs by their parent domain + domainMap := map[domain.Domain][]string{} + + for resolvedDomain, info := range resolvedDomains { + // Check if this resolved domain's parent is in our route's domains + if slices.Contains(route.Domains, info.ParentDomain) { + ips := make([]string, 0, len(info.Prefixes)) + for _, prefix := range info.Prefixes { + ips = append(ips, prefix.Addr().String()) } + domainMap[resolvedDomain] = ips + } + } + + // Convert to proto format + for domain, ips := range domainMap { + pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ + Ips: ips, } } + pbRoutes = append(pbRoutes, pbRoute) } - return &proto.ListRoutesResponse{ + return &proto.ListNetworksResponse{ Routes: pbRoutes, }, nil } -// SelectRoutes selects specific routes based on the client request. -func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// SelectNetworks selects specific networks based on the client request. +func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -115,18 +126,19 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) if req.GetAll() { routeSelector.SelectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetNetworkIDs()) + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } -// DeselectRoutes deselects specific routes based on the client request. -func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// DeselectNetworks deselects specific networks based on the client request. +func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -144,14 +156,15 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques if req.GetAll() { routeSelector.DeselectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetNetworkIDs()) + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } func toNetIDs(routes []string) []route.NetID { diff --git a/client/server/server.go b/client/server/server.go index 71eb58a66bc..5640ffa3926 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -772,7 +772,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Routes = maps.Keys(fullStatus.LocalPeerState.Routes) + pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ @@ -791,7 +791,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.GetRoutes()), + Networks: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) diff --git a/client/server/server_test.go b/client/server/server_test.go index 61bdaf660d2..128de8e020f 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -20,6 +20,8 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -110,7 +112,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } @@ -132,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 8ca0db73f38..49b0f53cf05 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -58,7 +58,7 @@ func main() { var showSettings bool flag.BoolVar(&showSettings, "settings", false, "run settings windows") var showRoutes bool - flag.BoolVar(&showRoutes, "routes", false, "run routes windows") + flag.BoolVar(&showRoutes, "networks", false, "run networks windows") var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") @@ -233,7 +233,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo s.showSettingsUI() return s } else if showRoutes { - s.showRoutesUI() + s.showNetworksUI() } return s @@ -549,7 +549,7 @@ func (s *serviceClient) onTrayReady() { s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application") s.loadSettings() - s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window") + s.mRoutes = systray.AddMenuItem("Networks", "Open the networks management window") s.mRoutes.Disable() systray.AddSeparator() @@ -657,7 +657,7 @@ func (s *serviceClient) onTrayReady() { s.mRoutes.Disable() go func() { defer s.mRoutes.Enable() - s.runSelfCommand("routes", "true") + s.runSelfCommand("networks", "true") }() } if err != nil { diff --git a/client/ui/route.go b/client/ui/network.go similarity index 56% rename from client/ui/route.go rename to client/ui/network.go index 5b6b8fee0d8..e6f027f0edf 100644 --- a/client/ui/route.go +++ b/client/ui/network.go @@ -19,32 +19,32 @@ import ( ) const ( - allRoutesText = "All routes" - overlappingRoutesText = "Overlapping routes" - exitNodeRoutesText = "Exit-node routes" - allRoutes filter = "all" - overlappingRoutes filter = "overlapping" - exitNodeRoutes filter = "exit-node" - getClientFMT = "get client: %v" + allNetworksText = "All networks" + overlappingNetworksText = "Overlapping networks" + exitNodeNetworksText = "Exit-node networks" + allNetworks filter = "all" + overlappingNetworks filter = "overlapping" + exitNodeNetworks filter = "exit-node" + getClientFMT = "get client: %v" ) type filter string -func (s *serviceClient) showRoutesUI() { - s.wRoutes = s.app.NewWindow("NetBird Routes") +func (s *serviceClient) showNetworksUI() { + s.wRoutes = s.app.NewWindow("Networks") allGrid := container.New(layout.NewGridLayout(3)) - go s.updateRoutes(allGrid, allRoutes) + go s.updateNetworks(allGrid, allNetworks) overlappingGrid := container.New(layout.NewGridLayout(3)) exitNodeGrid := container.New(layout.NewGridLayout(3)) routeCheckContainer := container.NewVBox() tabs := container.NewAppTabs( - container.NewTabItem(allRoutesText, allGrid), - container.NewTabItem(overlappingRoutesText, overlappingGrid), - container.NewTabItem(exitNodeRoutesText, exitNodeGrid), + container.NewTabItem(allNetworksText, allGrid), + container.NewTabItem(overlappingNetworksText, overlappingGrid), + container.NewTabItem(exitNodeNetworksText, exitNodeGrid), ) tabs.OnSelected = func(item *container.TabItem) { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) } tabs.OnUnselected = func(item *container.TabItem) { grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) @@ -58,17 +58,17 @@ func (s *serviceClient) showRoutesUI() { buttonBox := container.NewHBox( layout.NewSpacer(), widget.NewButton("Refresh", func() { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Select all", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.selectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.selectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Deselect All", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.deselectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.deselectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), layout.NewSpacer(), ) @@ -81,36 +81,36 @@ func (s *serviceClient) showRoutesUI() { s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } -func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { +func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Objects = nil grid.Refresh() idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) - networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + networkHeader := widget.NewLabelWithStyle("Range/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) grid.Add(idHeader) grid.Add(networkHeader) grid.Add(resolvedIPsHeader) - filteredRoutes, err := s.getFilteredRoutes(f) + filteredRoutes, err := s.getFilteredNetworks(f) if err != nil { return } - sortRoutesByIDs(filteredRoutes) + sortNetworksByIDs(filteredRoutes) for _, route := range filteredRoutes { r := route checkBox := widget.NewCheck(r.GetID(), func(checked bool) { - s.selectRoute(r.ID, checked) + s.selectNetwork(r.ID, checked) }) checkBox.Checked = route.Selected checkBox.Resize(fyne.NewSize(20, 20)) checkBox.Refresh() grid.Add(checkBox) - network := r.GetNetwork() + network := r.GetRange() domains := r.GetDomains() if len(domains) == 0 { @@ -129,10 +129,8 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Add(domainsSelector) var resolvedIPsList []string - for _, domain := range domains { - if ipList, exists := r.GetResolvedIPs()[domain]; exists { - resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) - } + for domain, ipList := range r.GetResolvedIPs() { + resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) } if len(resolvedIPsList) == 0 { @@ -151,35 +149,35 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Refresh() } -func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) { - routes, err := s.fetchRoutes() +func (s *serviceClient) getFilteredNetworks(f filter) ([]*proto.Network, error) { + routes, err := s.fetchNetworks() if err != nil { log.Errorf(getClientFMT, err) s.showError(fmt.Errorf(getClientFMT, err)) return nil, err } switch f { - case overlappingRoutes: - return getOverlappingRoutes(routes), nil - case exitNodeRoutes: - return getExitNodeRoutes(routes), nil + case overlappingNetworks: + return getOverlappingNetworks(routes), nil + case exitNodeNetworks: + return getExitNodeNetworks(routes), nil default: } return routes, nil } -func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route - existingRange := make(map[string][]*proto.Route) +func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network + existingRange := make(map[string][]*proto.Network) for _, route := range routes { if len(route.Domains) > 0 { continue } - if r, exists := existingRange[route.GetNetwork()]; exists { + if r, exists := existingRange[route.GetRange()]; exists { r = append(r, route) - existingRange[route.GetNetwork()] = r + existingRange[route.GetRange()] = r } else { - existingRange[route.GetNetwork()] = []*proto.Route{route} + existingRange[route.GetRange()] = []*proto.Network{route} } } for _, r := range existingRange { @@ -190,29 +188,29 @@ func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { return filteredRoutes } -func getExitNodeRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route +func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network for _, route := range routes { - if route.Network == "0.0.0.0/0" { + if route.Range == "0.0.0.0/0" { filteredRoutes = append(filteredRoutes, route) } } return filteredRoutes } -func sortRoutesByIDs(routes []*proto.Route) { +func sortNetworksByIDs(routes []*proto.Network) { sort.Slice(routes, func(i, j int) bool { return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID()) }) } -func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { +func (s *serviceClient) fetchNetworks() ([]*proto.Network, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { return nil, fmt.Errorf(getClientFMT, err) } - resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) + resp, err := conn.ListNetworks(s.ctx, &proto.ListNetworksRequest{}) if err != nil { return nil, fmt.Errorf("failed to list routes: %v", err) } @@ -220,7 +218,7 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { return resp.Routes, nil } -func (s *serviceClient) selectRoute(id string, checked bool) { +func (s *serviceClient) selectNetwork(id string, checked bool) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) @@ -228,73 +226,73 @@ func (s *serviceClient) selectRoute(id string, checked bool) { return } - req := &proto.SelectRoutesRequest{ - RouteIDs: []string{id}, - Append: checked, + req := &proto.SelectNetworksRequest{ + NetworkIDs: []string{id}, + Append: checked, } if checked { - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select route: %v", err) - s.showError(fmt.Errorf("failed to select route: %v", err)) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select network: %v", err) + s.showError(fmt.Errorf("failed to select network: %v", err)) return } log.Infof("Route %s selected", id) } else { - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect route: %v", err) - s.showError(fmt.Errorf("failed to deselect route: %v", err)) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect network: %v", err) + s.showError(fmt.Errorf("failed to deselect network: %v", err)) return } - log.Infof("Route %s deselected", id) + log.Infof("Network %s deselected", id) } } -func (s *serviceClient) selectAllFilteredRoutes(f filter) { +func (s *serviceClient) selectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, true) - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select all routes: %v", err) - s.showError(fmt.Errorf("failed to select all routes: %v", err)) + req := s.getNetworksRequest(f, true) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select all networks: %v", err) + s.showError(fmt.Errorf("failed to select all networks: %v", err)) return } - log.Debug("All routes selected") + log.Debug("All networks selected") } -func (s *serviceClient) deselectAllFilteredRoutes(f filter) { +func (s *serviceClient) deselectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, false) - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect all routes: %v", err) - s.showError(fmt.Errorf("failed to deselect all routes: %v", err)) + req := s.getNetworksRequest(f, false) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect all networks: %v", err) + s.showError(fmt.Errorf("failed to deselect all networks: %v", err)) return } - log.Debug("All routes deselected") + log.Debug("All networks deselected") } -func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest { - req := &proto.SelectRoutesRequest{} - if f == allRoutes { +func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.SelectNetworksRequest { + req := &proto.SelectNetworksRequest{} + if f == allNetworks { req.All = true } else { - routes, err := s.getFilteredRoutes(f) + routes, err := s.getFilteredNetworks(f) if err != nil { return nil } for _, route := range routes { - req.RouteIDs = append(req.RouteIDs, route.GetID()) + req.NetworkIDs = append(req.NetworkIDs, route.GetID()) } req.Append = appendRoute } @@ -311,7 +309,7 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container ticker := time.NewTicker(interval) go func() { for range ticker.C { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) } }() @@ -320,20 +318,20 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container }) } -func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { +func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) s.wRoutes.Content().Refresh() - s.updateRoutes(grid, f) + s.updateNetworks(grid, f) } func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) { switch tabs.Selected().Text { - case overlappingRoutesText: - return overlappingGrid, overlappingRoutes - case exitNodeRoutesText: - return exitNodesGrid, exitNodeRoutes + case overlappingNetworksText: + return overlappingGrid, overlappingNetworks + case exitNodeNetworksText: + return exitNodesGrid, exitNodeNetworks default: - return allGrid, allRoutes + return allGrid, allNetworks } } diff --git a/dns/dns.go b/dns/dns.go index 18528c74328..8dfdf852619 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) { return validHost, nil } + +// NormalizeZone returns a normalized domain name without the wildcard prefix +func NormalizeZone(domain string) string { + d, _ := strings.CutPrefix(domain, "*.") + return d +} diff --git a/go.mod b/go.mod index c504925d200..d48280df02a 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 + github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -207,6 +207,7 @@ require ( github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect diff --git a/go.sum b/go.sum index 04d2bc59af8..540cbf20bb9 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= @@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/management/client/client_test.go b/management/client/client_test.go index 100b3fcaa12..8bd8af8d2aa 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/client/system" @@ -57,7 +59,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -76,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index bfa158c5b53..4f34009b7e1 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -41,12 +41,20 @@ import ( "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" ) @@ -150,7 +158,7 @@ var ( if err != nil { return err } - store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) + store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } @@ -265,7 +273,15 @@ var ( KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + userManager := users.NewManager(store) + settingsManager := settings.NewManager(store) + permissionsManager := permissions.NewManager(userManager, settingsManager) + groupsManager := groups.NewManager(store, permissionsManager, accountManager) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) + routersManager := routers.NewManager(store, permissionsManager, accountManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) + + httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -274,7 +290,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } @@ -400,7 +416,7 @@ func notifyStop(ctx context.Context, msg string) { } } -func getInstallationID(ctx context.Context, store server.Store) (string, error) { +func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { return installationID, nil diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go index 7aa11f0c92c..183fc554dec 100644 --- a/management/cmd/migration_up.go +++ b/management/cmd/migration_up.go @@ -9,7 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -32,7 +32,7 @@ var upCmd = &cobra.Command{ //nolint ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) - if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { + if err := store.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { return err } log.WithContext(ctx).Info("Migration finished successfully") diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 672b2a10228..b4ff16e6d0a 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v4.24.3 // source: management.proto package proto @@ -29,6 +29,7 @@ const ( RuleProtocol_TCP RuleProtocol = 2 RuleProtocol_UDP RuleProtocol = 3 RuleProtocol_ICMP RuleProtocol = 4 + RuleProtocol_CUSTOM RuleProtocol = 5 ) // Enum value maps for RuleProtocol. @@ -39,6 +40,7 @@ var ( 2: "TCP", 3: "UDP", 4: "ICMP", + 5: "CUSTOM", } RuleProtocol_value = map[string]int32{ "UNKNOWN": 0, @@ -46,6 +48,7 @@ var ( "TCP": 2, "UDP": 3, "ICMP": 4, + "CUSTOM": 5, } ) @@ -1393,7 +1396,8 @@ type PeerConfig struct { // SSHConfig of the peer. SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"` // Peer fully qualified domain name - Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` } func (x *PeerConfig) Reset() { @@ -1456,6 +1460,13 @@ func (x *PeerConfig) GetFqdn() string { return "" } +func (x *PeerConfig) GetRoutingPeerDnsResolutionEnabled() bool { + if x != nil { + return x.RoutingPeerDnsResolutionEnabled + } + return false +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -2780,6 +2791,10 @@ type RouteFirewallRule struct { PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` // IsDynamic indicates if the route is a DNS route. IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` + // Domains is a list of domains for which the rule is applicable. + Domains []string `protobuf:"bytes,7,rep,name=domains,proto3" json:"domains,omitempty"` + // CustomProtocol is a custom protocol ID. + CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` } func (x *RouteFirewallRule) Reset() { @@ -2856,6 +2871,20 @@ func (x *RouteFirewallRule) GetIsDynamic() bool { return false } +func (x *RouteFirewallRule) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *RouteFirewallRule) GetCustomProtocol() uint32 { + if x != nil { + return x.CustomProtocol + } + return 0 +} + type PortInfo_Range struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3075,7 +3104,7 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x81, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xcb, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, @@ -3083,250 +3112,260 @@ var file_management_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, - 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, - 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, - 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, - 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, - 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, - 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, - 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, + 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, + 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0xf3, 0x04, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, + 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, + 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, - 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, - 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, - 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, - 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, - 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, - 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, - 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, - 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, - 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, - 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, - 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, - 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, - 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, - 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, - 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, - 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, - 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, - 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, - 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, - 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, - 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, - 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, - 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, - 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, - 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, - 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, - 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, - 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, - 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, - 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, - 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, - 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, - 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, - 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, - 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, - 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, - 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, - 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, - 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, - 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, - 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, - 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, - 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, - 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, - 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, - 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, + 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, + 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, + 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, + 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, + 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, + 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, + 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, + 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, + 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, + 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, + 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, + 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, + 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, + 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, + 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, + 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, + 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, + 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, + 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, + 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, + 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, + 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, + 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, + 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xd1, 0x02, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, + 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, + 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, + 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, + 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, + 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, + 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, + 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, + 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, + 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, + 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, + 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index fe6a828b1e5..5f4e0df46b0 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -222,6 +222,8 @@ message PeerConfig { SSHConfig sshConfig = 3; // Peer fully qualified domain name string fqdn = 4; + + bool RoutingPeerDnsResolutionEnabled = 5; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -396,6 +398,7 @@ enum RuleProtocol { TCP = 2; UDP = 3; ICMP = 4; + CUSTOM = 5; } enum RuleDirection { @@ -459,5 +462,11 @@ message RouteFirewallRule { // IsDynamic indicates if the route is a DNS route. bool isDynamic = 6; + + // Domains is a list of domains for which the rule is applicable. + repeated string domains = 7; + + // CustomProtocol is a custom protocol ID. + uint32 customProtocol = 8; } diff --git a/management/server/account.go b/management/server/account.go index fbe6fcc1a4b..e60b41b4ec1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -19,8 +19,6 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" - "github.com/hashicorp/go-multierror" - "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -29,31 +27,26 @@ import ( "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour - DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -66,56 +59,56 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) - GetAccount(ctx context.Context, accountID string) (*Account, error) - CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, - autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) - SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccount(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, + autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) - SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) - SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) - SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) - GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) + SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error - GetUserByID(ctx context.Context, id string) (*User, error) - GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) - ListUsers(ctx context.Context, accountID string) ([]*User, error) + GetUserByID(ctx context.Context, id string) (*types.User, error) + GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) - GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) + AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error - GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) - GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error + GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) + GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error - GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) + GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error - ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) + ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error @@ -129,12 +122,12 @@ type AccountManager interface { GetDNSDomain() string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) - SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error + GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) + SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -145,18 +138,19 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + GetValidatedPeers(account *types.Account) (map[string]struct{}, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) - GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error + UpdateAccountPeers(ctx context.Context, accountID string) } type DefaultAccountManager struct { - Store Store + Store store.Store // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded @@ -191,763 +185,40 @@ type DefaultAccountManager struct { metrics telemetry.AppMetrics } -// Settings represents Account settings structure that can be modified via API and Dashboard -type Settings struct { - // PeerLoginExpirationEnabled globally enables or disables peer login expiration - PeerLoginExpirationEnabled bool - - // PeerLoginExpiration is a setting that indicates when peer login expires. - // Applies to all peers that have Peer.LoginExpirationEnabled set to true. - PeerLoginExpiration time.Duration - - // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration - PeerInactivityExpirationEnabled bool - - // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. - // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. - PeerInactivityExpiration time.Duration - - // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements - RegularUsersViewBlocked bool - - // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer - GroupsPropagationEnabled bool - - // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName - // and add it to account groups. - JWTGroupsEnabled bool - - // JWTGroupsClaimName from which we extract groups name to add it to account groups - JWTGroupsClaimName string - - // JWTAllowGroups list of groups to which users are allowed access - JWTAllowGroups []string `gorm:"serializer:json"` - - // Extra is a dictionary of Account settings - Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` -} - -// Copy copies the Settings struct -func (s *Settings) Copy() *Settings { - settings := &Settings{ - PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, - PeerLoginExpiration: s.PeerLoginExpiration, - JWTGroupsEnabled: s.JWTGroupsEnabled, - JWTGroupsClaimName: s.JWTGroupsClaimName, - GroupsPropagationEnabled: s.GroupsPropagationEnabled, - JWTAllowGroups: s.JWTAllowGroups, - RegularUsersViewBlocked: s.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: s.PeerInactivityExpiration, - } - if s.Extra != nil { - settings.Extra = s.Extra.Copy() - } - return settings -} - -// Account represents a unique account of the system -type Account struct { - // we have to name column to aid as it collides with Network.Id when work with associations - Id string `gorm:"primaryKey"` - - // User.Id it was created by - CreatedBy string - CreatedAt time.Time - Domain string `gorm:"index"` - DomainCategory string - IsDomainPrimaryAccount bool - SetupKeys map[string]*SetupKey `gorm:"-"` - SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` - Network *Network `gorm:"embedded;embeddedPrefix:network_"` - Peers map[string]*nbpeer.Peer `gorm:"-"` - PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` - Users map[string]*User `gorm:"-"` - UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*nbgroup.Group `gorm:"-"` - GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` - Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` - Routes map[route.ID]*route.Route `gorm:"-"` - RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` - NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` - NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` - PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` - // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load settings and not whole account -type AccountSettings struct { - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load network and not whole account -type AccountNetwork struct { - Network *Network `gorm:"embedded;embeddedPrefix:network_"` -} - -// AccountDNSSettings used in gorm to only load dns settings and not whole account -type AccountDNSSettings struct { - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` -} - -type UserPermissions struct { - DashboardView string `json:"dashboard_view"` -} - -type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - NonDeletable bool `json:"non_deletable"` - LastLogin time.Time `json:"last_login"` - Issued string `json:"issued"` - IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` -} - -// getRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(lookupMap) - for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - groupListMap := a.getPeerGroups(peerID) - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) - } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - // currently we support only linux routing peers - if peer.Meta.GoOS != "linux" { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - -// GetRoutesByPrefixOrDomains return list of routes by account and route prefix -func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { - var routes []*route.Route - for _, r := range a.Routes { - dynamic := r.IsDynamic() - if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || - !dynamic && r.Network.String() == prefix.String() { - routes = append(routes, r) - } - } - - return routes -} - -// GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *nbgroup.Group { - return a.Groups[groupID] -} - -// GetPeerNetworkMap returns the networkmap for the given peer ID. -func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - validatedPeersMap map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) - routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) - } - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if metrics != nil { - objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", - a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) - } - } - - return nm -} - -func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { - var merr *multierror.Error - - if dnsDomain == "" { - log.WithContext(ctx).Error("no dns domain is set, returning empty zone") - return nbdns.CustomZone{} - } - - customZone := nbdns.CustomZone{ - Domain: dns.Fqdn(dnsDomain), - Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), - } - - domainSuffix := "." + dnsDomain - - var sb strings.Builder - for _, peer := range a.Peers { - if peer.DNSLabel == "" { - merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) - continue - } - - sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) - sb.WriteString(peer.DNSLabel) - sb.WriteString(domainSuffix) - - customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: defaultTTL, - RData: peer.IP.String(), - }) - - sb.Reset() - } - - go func() { - if merr != nil { - log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) - } - }() - - return customZone -} - -// GetExpiredPeers returns peers that have been expired -func (a *Account) GetExpiredPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.GetPeersWithExpiration() { - expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if expired { - peers = append(peers, peer) - } - } - - return peers -} - -// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are connected. -func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithExpiration() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - // consider only connected peers because others will require login on connecting to the management server - if peer.Status.LoginExpired || !peer.Status.Connected { - continue - } - _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetInactivePeers returns peers that have been expired by inactivity -func (a *Account) GetInactivePeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, inactivePeer := range a.GetPeersWithInactivity() { - inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) - if inactive { - peers = append(peers, inactivePeer) - } - } - return peers -} - -// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are not connected. -func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithInactivity() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - if peer.Status.LoginExpired || peer.Status.Connected { - continue - } - _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetPeers returns a list of all Account peers -func (a *Account) GetPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.Peers { - peers = append(peers, peer) - } - return peers -} - -// UpdateSettings saves new account settings -func (a *Account) UpdateSettings(update *Settings) *Account { - a.Settings = update.Copy() - return a -} - -// UpdatePeer saves new or replaces existing peer -func (a *Account) UpdatePeer(update *nbpeer.Peer) { - a.Peers[update.ID] = update -} - -// DeletePeer deletes peer from the account cleaning up all the references -func (a *Account) DeletePeer(peerID string) { - // delete peer from groups - for _, g := range a.Groups { - for i, pk := range g.Peers { - if pk == peerID { - g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) - break - } - } - } - - for _, r := range a.Routes { - if r.Peer == peerID { - r.Enabled = false - r.Peer = "" - } - } - - delete(a.Peers, peerID) - a.Network.IncSerial() -} - -// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. -// It will return an object copy of the peer. -func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { - for _, peer := range a.Peers { - if peer.Key == peerPubKey { - return peer.Copy(), nil - } - } - - return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) -} - -// FindUserPeers returns a list of peers that user owns (created) -func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.UserID == userID { - peers = append(peers, peer) - } - } - - return peers, nil -} - -// FindUser looks for a given user in the Account or returns error if user wasn't found. -func (a *Account) FindUser(userID string) (*User, error) { - user := a.Users[userID] - if user == nil { - return nil, status.Errorf(status.NotFound, "user %s not found", userID) - } - - return user, nil -} - -// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { - for _, group := range a.Groups { - if group.Name == groupName { - return group, nil - } - } - return nil, status.Errorf(status.NotFound, "group %s not found", groupName) -} - -// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. -func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { - key := a.SetupKeys[setupKey] - if key == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return key, nil -} - -// GetPeerGroupsList return with the list of groups ID. -func (a *Account) GetPeerGroupsList(peerID string) []string { - var grps []string - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - grps = append(grps, groupID) - break - } - } - } - return grps -} - -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.getPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (a *Account) getPeerGroups(peerID string) lookupMap { - groupList := make(lookupMap) - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - groupList[groupID] = struct{}{} - break - } - } - } - return groupList -} - -func (a *Account) getTakenIPs() []net.IP { - var takenIps []net.IP - for _, existingPeer := range a.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps -} - -func (a *Account) getPeerDNSLabels() lookupMap { - existingLabels := make(lookupMap) - for _, peer := range a.Peers { - if peer.DNSLabel != "" { - existingLabels[peer.DNSLabel] = struct{}{} - } - } - return existingLabels -} - -func (a *Account) Copy() *Account { - peers := map[string]*nbpeer.Peer{} - for id, peer := range a.Peers { - peers[id] = peer.Copy() - } - - users := map[string]*User{} - for id, user := range a.Users { - users[id] = user.Copy() - } - - setupKeys := map[string]*SetupKey{} - for id, key := range a.SetupKeys { - setupKeys[id] = key.Copy() - } - - groups := map[string]*nbgroup.Group{} - for id, group := range a.Groups { - groups[id] = group.Copy() - } - - policies := []*Policy{} - for _, policy := range a.Policies { - policies = append(policies, policy.Copy()) - } - - routes := map[route.ID]*route.Route{} - for id, r := range a.Routes { - routes[id] = r.Copy() - } - - nsGroups := map[string]*nbdns.NameServerGroup{} - for id, nsGroup := range a.NameServerGroups { - nsGroups[id] = nsGroup.Copy() - } - - dnsSettings := a.DNSSettings.Copy() - - var settings *Settings - if a.Settings != nil { - settings = a.Settings.Copy() - } - - postureChecks := []*posture.Checks{} - for _, postureCheck := range a.PostureChecks { - postureChecks = append(postureChecks, postureCheck.Copy()) - } - - return &Account{ - Id: a.Id, - CreatedBy: a.CreatedBy, - CreatedAt: a.CreatedAt, - Domain: a.Domain, - DomainCategory: a.DomainCategory, - IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, - SetupKeys: setupKeys, - Network: a.Network.Copy(), - Peers: peers, - Users: users, - Groups: groups, - Policies: policies, - Routes: routes, - NameServerGroups: nsGroups, - DNSSettings: dnsSettings, - PostureChecks: postureChecks, - Settings: settings, - } -} - -func (a *Account) GetGroupAll() (*nbgroup.Group, error) { - for _, g := range a.Groups { - if g.Name == "All" { - return g, nil - } - } - return nil, fmt.Errorf("no group ALL found") -} - -// GetPeer looks up a Peer by ID -func (a *Account) GetPeer(peerID string) *nbpeer.Peer { - return a.Peers[peerID] -} - // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. -func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { - existedGroupsByName := make(map[string]*nbgroup.Group) +func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []*types.Group, groupNames []string) (bool, []string, []*types.Group, error) { + existedGroupsByName := make(map[string]*types.Group) for _, group := range groups { existedGroupsByName[group.Name] = group } newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) - groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) + groupsToAdd := util.Difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := util.Difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return false, nil, nil, nil } - newGroupsToCreate := make([]*nbgroup.Group, 0) + newGroupsToCreate := make([]*types.Group, 0) var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { - group = &nbgroup.Group{ + group = &types.Group{ ID: xid.New().String(), AccountID: user.AccountID, Name: name, - Issued: nbgroup.GroupIssuedJWT, + Issued: types.GroupIssuedJWT, } newGroupsToCreate = append(newGroupsToCreate, group) } - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } @@ -964,78 +235,10 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro return modified, newUserAutoGroups, newGroupsToCreate, nil } -// UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { - groupUpdates := make(map[string][]string) - - userPeers := make(map[string]struct{}) - for pid, peer := range a.Peers { - if peer.UserID == userID { - userPeers[pid] = struct{}{} - } - } - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok { - continue - } - - oldPeers := group.Peers - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for pid := range userPeers { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - groupUpdates[gid] = difference(group.Peers, oldPeers) - } - - return groupUpdates -} - -// UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { - groupUpdates := make(map[string][]string) - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok || group.Name == "All" { - continue - } - - oldPeers := group.Peers - - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - peer, ok := a.Peers[pid] - if !ok { - continue - } - if peer.UserID != userID { - update = append(update, pid) - } - } - group.Peers = update - groupUpdates[gid] = difference(oldPeers, group.Peers) - } - - return groupUpdates -} - // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, - store Store, + store store.Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, @@ -1137,7 +340,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1186,6 +389,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, account) } + updateAccountPeers := false + if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { + if newSettings.RoutingPeerDNSResolutionEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil) + } + updateAccountPeers = true + account.Network.Serial++ + } + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err @@ -1203,10 +417,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } + if updateAccountPeers { + go am.UpdateAccountPeers(ctx, accountID) + } + return updatedAccount, nil } -func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { if newSettings.GroupsPropagationEnabled { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) @@ -1219,7 +437,7 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con return nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { @@ -1272,7 +490,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *types.Account) { am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextPeerExpiration(); ok { go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) @@ -1309,7 +527,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *types.Account) { am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) @@ -1318,7 +536,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() @@ -1398,7 +616,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } - if user.Role != UserRoleOwner { + if user.Role != types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { @@ -1436,7 +654,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) + return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) } // GetAccountIDByUserID retrieves the account ID based on the userID provided. @@ -1473,13 +691,13 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } - cachedAccount := &Account{ + cachedAccount := &types.Account{ Id: accountID, - Users: make(map[string]*User), + Users: make(map[string]*types.User), } for _, user := range accountUsers { cachedAccount.Users[user.Id] = user @@ -1562,14 +780,14 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) { users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { if user.IsServiceUser { continue } - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { continue } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) @@ -1739,7 +957,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err @@ -1749,7 +967,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -1834,8 +1052,8 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - usersMap := make(map[string]*User) - usersMap[claims.UserId] = NewRegularUser(claims.UserId) + usersMap := make(map[string]*types.User) + usersMap[claims.UserId] = types.NewRegularUser(claims.UserId) err := am.Store.SaveUsers(domainAccountID, usersMap) if err != nil { return "", err @@ -1923,22 +1141,22 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string } // GetAccount returns an account associated with this account ID. -func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { return am.Store.GetAccount(ctx, accountID) } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { - if len(token) != PATLength { +func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { + if len(token) != types.PATLength { return nil, nil, nil, fmt.Errorf("token has wrong length") } - prefix := token[:len(PATPrefix)] - if prefix != PATPrefix { + prefix := token[:len(types.PATPrefix)] + if prefix != types.PATPrefix { return nil, nil, nil, fmt.Errorf("token has wrong prefix") } - secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] - encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] + secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] + encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { @@ -1976,8 +1194,8 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st } // GetAccountByID returns an account associated with this account ID. -func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -1998,7 +1216,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain - claims.DomainCategory = PrivateCategory + claims.DomainCategory = types.PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } @@ -2007,7 +1225,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) @@ -2034,7 +1252,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2060,14 +1278,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var addNewGroups []string var removeOldGroups []string var hasChanges bool - var user *User - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + var user *types.User + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2083,31 +1301,31 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } - addNewGroups = difference(updatedAutoGroups, user.AutoGroups) - removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) user.AutoGroups = updatedAutoGroups - if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) } // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } - peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -2117,11 +1335,11 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2139,7 +1357,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2152,7 +1370,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2177,7 +1395,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } } @@ -2210,7 +1428,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", errors.New(emptyUserID) } - if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { + if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) { return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } @@ -2248,7 +1466,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return am.addNewPrivateAccount(ctx, domainAccountID, claims) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2263,7 +1481,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont cancel := am.Store.AcquireGlobalLock(ctx) // check again if the domain has a primary account because of simultaneous requests - domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2284,7 +1502,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err @@ -2295,7 +1513,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err @@ -2322,10 +1540,10 @@ func handleNotFound(err error) error { } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain + return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) @@ -2422,7 +1640,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, return err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2444,7 +1662,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -2455,8 +1673,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } -func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -2477,14 +1695,14 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee return false, nil } -func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } labelMap := ConvertSliceToMap(existingLabels) - newLabel, err := getPeerHostLabel(peerHostName, labelMap) + newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap) if err != nil { return "", fmt.Errorf("failed to get new host label: %w", err) } @@ -2496,8 +1714,8 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } -func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -2506,70 +1724,70 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } - return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) } // addAllGroup to account object if it doesn't exist -func addAllGroup(account *Account) error { +func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ + allGroup := &types.Group{ ID: xid.New().String(), Name: "All", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} id := xid.New().String() - defaultPolicy := &Policy{ + defaultPolicy := &types.Policy{ ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, Sources: []string{allGroup.ID}, Destinations: []string{allGroup.ID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } - account.Policies = []*Policy{defaultPolicy} + account.Policies = []*types.Policy{defaultPolicy} } return nil } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { log.WithContext(ctx).Debugf("creating new account") - network := NewNetwork() + network := types.NewNetwork() peers := make(map[string]*nbpeer.Peer) - users := make(map[string]*User) + users := make(map[string]*types.User) routes := make(map[route.ID]*route.Route) - setupKeys := map[string]*SetupKey{} + setupKeys := map[string]*types.SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - owner := NewOwnerUser(userID) + owner := types.NewOwnerUser(userID) owner.AccountID = accountID users[userID] = owner - dnsSettings := DNSSettings{ + dnsSettings := types.DNSSettings{ DisabledManagementGroups: make([]string, 0), } log.WithContext(ctx).Debugf("created new account %s", accountID) - acc := &Account{ + acc := &types.Account{ Id: accountID, CreatedAt: time.Now().UTC(), SetupKeys: setupKeys, @@ -2581,14 +1799,15 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac Routes: routes, NameServerGroups: nameServersGroups, DNSSettings: dnsSettings, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, }, } @@ -2634,18 +1853,18 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID - allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + allGroupsMap := make(map[string]*types.Group, len(allGroups)) for _, group := range allGroups { allGroupsMap[group.ID] = group } for _, id := range autoGroups { if group, ok := allGroupsMap[id]; ok { - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { newAutoGroups = append(newAutoGroups, id) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index 5f4897e6a88..fa6c45856be 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -7,6 +7,9 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // AccountRequest holds the result channel to return the requested account. @@ -17,19 +20,19 @@ type AccountRequest struct { // AccountResult holds the account data or an error. type AccountResult struct { - Account *Account + Account *types.Account Err error } type AccountRequestBuffer struct { - store Store + store store.Store getAccountRequests map[string][]*AccountRequest mu sync.Mutex getAccountRequestCh chan *AccountRequest bufferInterval time.Duration } -func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer { +func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer { bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL") bufferInterval, err := time.ParseDuration(bufferIntervalStr) if err != nil { @@ -52,7 +55,7 @@ func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBu return &ac } -func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) { +func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) { req := &AccountRequest{ AccountID: accountID, ResultChan: make(chan *AccountResult, 1), diff --git a/management/server/account_test.go b/management/server/account_test.go index d952e118acf..d83eab6d120 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -16,6 +16,11 @@ import ( "time" "github.com/golang-jwt/jwt" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,11 +29,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -46,7 +52,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} @@ -73,7 +79,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) func (MocIntegratedValidator) Stop(_ context.Context) { } -func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { +func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=", @@ -101,7 +107,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac } } -func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) { +func verifyNewAccountHasDefaultFields(t *testing.T, account *types.Account, createdBy string, domain string, expectedUsers []string) { t.Helper() if len(account.Peers) != 0 { t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers)) @@ -156,7 +162,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // peerID3 := "peer-3" tt := []struct { name string - accountSettings Settings + accountSettings types.Settings peerID string expectedPeers []string expectedOfflinePeers []string @@ -164,7 +170,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }{ { name: "Should return ALL peers when global peer login expiration disabled", - accountSettings: Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{peerID2}, expectedOfflinePeers: []string{}, @@ -202,7 +208,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }, { name: "Should return no peers when global peer login expiration enabled and peers expired", - accountSettings: Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{}, expectedOfflinePeers: []string{peerID2}, @@ -396,12 +402,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { netIP := net.IP{100, 64, 0, 0} netMask := net.IPMask{255, 255, 0, 0} - network := &Network{ + network := &types.Network{ Identifier: "network", Net: net.IPNet{IP: netIP, Mask: netMask}, Dns: "netbird.selfhosted", Serial: 0, - mu: sync.Mutex{}, + Mu: sync.Mutex{}, } for _, testCase := range tt { @@ -420,7 +426,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -485,12 +491,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { } initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory + initUnknown.DomainCategory = types.UnknownCategory initUnknown.Domain = unknownDomain privateInitAccount := defaultInitAccount privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory + privateInitAccount.DomainCategory = types.PrivateCategory testCases := []struct { name string @@ -500,7 +506,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputUpdateClaimAccount bool testingFunc require.ComparisonAssertionFunc expectedMSG string - expectedUserRole UserRole + expectedUserRole types.UserRole expectedDomainCategory string expectedDomain string expectedPrimaryDomainStatus bool @@ -512,12 +518,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: publicDomain, UserId: "pub-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomainCategory: "", expectedDomain: publicDomain, expectedPrimaryDomainStatus: false, @@ -529,12 +535,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: unknownDomain, UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + DomainCategory: types.UnknownCategory, }, inputInitUserParams: initUnknown, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: unknownDomain, expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -546,14 +552,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: "pvt-domain-user", expectedUsers: []string{"pvt-domain-user"}, @@ -563,15 +569,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateAttrs: true, inputInitUserParams: privateInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, + expectedUserRole: types.UserRoleUser, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, @@ -581,14 +587,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -598,15 +604,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateClaimAccount: true, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -616,12 +622,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: "", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: "", expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -733,7 +739,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*group.Group{} + groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -741,32 +747,36 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match") }) } func TestAccountManager_GetAccountFromPAT(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -786,15 +796,20 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, @@ -802,7 +817,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -904,7 +919,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { return } - exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) assert.NoError(t, err) assert.True(t, exists, "expected to get existing account after creation using userid") @@ -914,7 +929,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } } -func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { +func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) { account := newAccountWithId(context.Background(), accountID, userID, domain) err := am.Store.SaveAccount(context.Background(), account) if err != nil { @@ -990,13 +1005,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { claims := jwtclaims.AuthorizationClaims{ Domain: "example.com", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, } publicClaims := jwtclaims.AuthorizationClaims{ Domain: "test.com", UserId: "public-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, } am, err := createManager(b) @@ -1074,13 +1089,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } -func genUsers(p string, n int) map[string]*User { - users := map[string]*User{} +func genUsers(p string, n int) map[string]*types.User { + users := map[string]*types.User{} now := time.Now() for i := 0; i < n; i++ { - users[fmt.Sprintf("%s-%d", p, i)] = &User{ + users[fmt.Sprintf("%s-%d", p, i)] = &types.User{ Id: fmt.Sprintf("%s-%d", p, i), - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, LastLogin: now, CreatedAt: now, Issued: "api", @@ -1105,7 +1120,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1232,7 +1247,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{}, @@ -1242,15 +1257,15 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1309,7 +1324,7 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -1334,15 +1349,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1357,7 +1372,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, @@ -1367,15 +1382,15 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1413,7 +1428,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1426,15 +1441,15 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1482,7 +1497,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1557,7 +1572,7 @@ func TestGetUsersFromAccount(t *testing.T) { t.Fatal(err) } - users := map[string]*User{"1": {Id: "1", Role: UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} + users := map[string]*types.User{"1": {Id: "1", Role: types.UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} accountId := "test_account_id" account, err := createAccount(manager, accountId, users["1"].Id, "") @@ -1589,7 +1604,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1636,11 +1651,11 @@ func TestAccount_GetRoutesToSync(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1681,7 +1696,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1691,26 +1706,26 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) assert.Len(t, emptyRoutes, 0) } func TestAccount_Copy(t *testing.T) { - account := &Account{ + account := &types.Account{ Id: "account1", CreatedBy: "tester", CreatedAt: time.Now().UTC(), Domain: "test.com", DomainCategory: "public", IsDomainPrimaryAccount: true, - SetupKeys: map[string]*SetupKey{ + SetupKeys: map[string]*types.SetupKey{ "setup1": { Id: "setup1", AutoGroups: []string{"group1"}, }, }, - Network: &Network{ + Network: &types.Network{ Identifier: "net1", }, Peers: map[string]*nbpeer.Peer{ @@ -1723,12 +1738,12 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Users: map[string]*User{ + Users: map[string]*types.User{ "user1": { Id: "user1", - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, AutoGroups: []string{"group1"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", @@ -1741,17 +1756,18 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "policy1", Enabled: true, - Rules: make([]*PolicyRule, 0), + Rules: make([]*types.PolicyRule, 0), SourcePostureChecks: make([]string, 0), }, }, @@ -1771,13 +1787,36 @@ func TestAccount_Copy(t *testing.T) { NameServers: []nbdns.NameServer{}, }, }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{}}, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, PostureChecks: []*posture.Checks{ { ID: "posture Checks1", }, }, - Settings: &Settings{}, + Settings: &types.Settings{}, + Networks: []*networkTypes.Network{ + { + ID: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 0, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "resource", + Type: "Subnet", + Address: "172.12.6.1/24", + }, + }, } err := hasNilField(account) if err != nil { @@ -1830,7 +1869,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.NotNil(t, settings) @@ -1863,7 +1902,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1911,7 +1950,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1980,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1993,7 +2032,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2011,7 +2050,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2019,19 +2058,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.False(t, settings.PeerLoginExpirationEnabled) assert.Equal(t, settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) @@ -2104,9 +2143,9 @@ func TestAccount_GetExpiredPeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -2188,9 +2227,9 @@ func TestAccount_GetInactivePeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -2255,7 +2294,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2324,7 +2363,7 @@ func TestAccount_GetPeersWithInactivity(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2488,9 +2527,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextPeerExpiration() @@ -2648,9 +2687,9 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextInactivePeerExpiration() @@ -2669,7 +2708,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { require.NoError(t, err, "unable to create account manager") // create a new account - account := &Account{ + account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, @@ -2678,11 +2717,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, - Users: map[string]*User{ + Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, + Users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "accountID"}, "user2": {Id: "user2", AccountID: "accountID"}, }, @@ -2698,7 +2737,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) @@ -2711,18 +2750,18 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} - assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) + assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2731,13 +2770,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { @@ -2748,7 +2787,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2761,7 +2800,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2774,11 +2813,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "new group should be added") }) @@ -2791,7 +2830,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") @@ -2799,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { } func TestAccount_UserGroupsAddToPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2807,12 +2846,12 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("add groups", func(t *testing.T) { @@ -2835,7 +2874,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { } func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2843,12 +2882,12 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("remove groups", func(t *testing.T) { @@ -2891,10 +2930,10 @@ func createManager(t TB) (*DefaultAccountManager, error) { return manager, nil } -func createStore(t TB) (Store, error) { +func createStore(t TB) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -2917,7 +2956,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() manager, err := createManager(t) @@ -2930,12 +2969,12 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") } - getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + getPeer := func(manager *DefaultAccountManager, setupKey *types.SetupKey) *nbpeer.Peer { key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) @@ -2999,11 +3038,11 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 1, 3, 3, 10}, - {"Medium", 500, 100, 7, 13, 10, 60}, - {"Large", 5000, 200, 65, 80, 60, 170}, - {"Small single", 50, 10, 1, 3, 3, 60}, + {"Medium", 500, 100, 7, 13, 10, 70}, + {"Large", 5000, 200, 65, 80, 60, 200}, + {"Small single", 50, 10, 1, 3, 3, 70}, {"Medium single", 500, 10, 7, 13, 10, 26}, - {"Large 5", 5000, 15, 65, 80, 60, 170}, + {"Large 5", 5000, 15, 65, 80, 60, 200}, } log.SetOutput(io.Discard) @@ -3047,7 +3086,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) @@ -3067,7 +3106,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { }{ {"Small", 50, 5, 102, 110, 102, 120}, {"Medium", 500, 100, 105, 140, 105, 170}, - {"Large", 5000, 200, 160, 200, 160, 270}, + {"Large", 5000, 200, 160, 200, 160, 300}, {"Small single", 50, 10, 102, 110, 102, 120}, {"Medium single", 500, 10, 105, 140, 105, 170}, {"Large 5", 5000, 15, 160, 200, 160, 270}, @@ -3121,7 +3160,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) @@ -3139,10 +3178,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 107, 120, 107, 140}, - {"Medium", 500, 100, 105, 140, 105, 170}, - {"Large", 5000, 200, 180, 220, 180, 340}, - {"Small single", 50, 10, 107, 120, 105, 140}, + {"Small", 50, 5, 107, 120, 107, 160}, + {"Medium", 500, 100, 105, 140, 105, 190}, + {"Large", 5000, 200, 180, 220, 180, 350}, + {"Small single", 50, 10, 107, 120, 105, 160}, {"Medium single", 500, 10, 105, 140, 105, 170}, {"Large 5", 5000, 15, 180, 220, 180, 340}, } @@ -3195,7 +3234,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4c57d65fb55..5379a8dd81b 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -151,6 +151,24 @@ const ( UserGroupPropagationEnabled Activity = 69 UserGroupPropagationDisabled Activity = 70 + + AccountRoutingPeerDNSResolutionEnabled Activity = 71 + AccountRoutingPeerDNSResolutionDisabled Activity = 72 + + NetworkCreated Activity = 73 + NetworkUpdated Activity = 74 + NetworkDeleted Activity = 75 + + NetworkResourceCreated Activity = 76 + NetworkResourceUpdated Activity = 77 + NetworkResourceDeleted Activity = 78 + + NetworkRouterCreated Activity = 79 + NetworkRouterUpdated Activity = 80 + NetworkRouterDeleted Activity = 81 + + ResourceAddedToGroup Activity = 82 + ResourceRemovedFromGroup Activity = 83 ) var activityMap = map[Activity]Code{ @@ -228,6 +246,24 @@ var activityMap = map[Activity]Code{ UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"}, UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"}, + + AccountRoutingPeerDNSResolutionEnabled: {"Account routing peer DNS resolution enabled", "account.setting.routing.peer.dns.resolution.enable"}, + AccountRoutingPeerDNSResolutionDisabled: {"Account routing peer DNS resolution disabled", "account.setting.routing.peer.dns.resolution.disable"}, + + NetworkCreated: {"Network created", "network.create"}, + NetworkUpdated: {"Network updated", "network.update"}, + NetworkDeleted: {"Network deleted", "network.delete"}, + + NetworkResourceCreated: {"Network resource created", "network.resource.create"}, + NetworkResourceUpdated: {"Network resource updated", "network.resource.update"}, + NetworkResourceDeleted: {"Network resource deleted", "network.resource.delete"}, + + NetworkRouterCreated: {"Network router created", "network.router.create"}, + NetworkRouterUpdated: {"Network router updated", "network.router.update"}, + NetworkRouterDeleted: {"Network router deleted", "network.router.delete"}, + + ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, + ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/config.go b/management/server/config.go index 2f7e497667a..f3555b92b63 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -5,6 +5,7 @@ import ( "net/url" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -156,7 +157,7 @@ type ProviderConfig struct { // StoreConfig contains Store configuration type StoreConfig struct { - Engine StoreEngine + Engine store.Engine } // ReverseProxy contains reverse proxy configuration in front of management. diff --git a/management/server/dns.go b/management/server/dns.go index 8df211b0b0b..39dc11eb247 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "slices" - "strconv" "sync" log "github.com/sirupsen/logrus" @@ -12,12 +10,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const defaultTTL = 300 - // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { CustomZones sync.Map @@ -62,26 +60,9 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG c.NameServerGroups.Store(key, value) } -type lookupMap map[string]struct{} - -// DNSSettings defines dns settings at the account level -type DNSSettings struct { - // DisabledManagementGroups groups whose DNS management is disabled - DisabledManagementGroups []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the DNS settings -func (d DNSSettings) Copy() DNSSettings { - settings := DNSSettings{ - DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), - } - copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) - return settings -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID -func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -94,16 +75,16 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings -func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { +func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -119,18 +100,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { return err } - oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + oldSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return err } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) if err != nil { @@ -140,11 +121,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) }) if err != nil { return err @@ -155,18 +136,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // prepareDNSSettingsEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) return nil @@ -203,8 +184,8 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t } // areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. -func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups) if err != nil { return false, err } @@ -213,16 +194,16 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, acc return true, nil } - return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) + return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups) } // validateDNSSettings validates the DNS settings. -func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { +func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error { if len(settings.DisabledManagementGroups) == 0 { return nil } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) if err != nil { return err } @@ -298,81 +279,3 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } return protoGroup } - -func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { - groupList := account.getPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - -func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) { - for _, peer := range account.Peers { - label, err := getPeerHostLabel(peer.Name, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) - label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) - continue - } - } - peer.DNSLabel = label - peerLabels[label] = struct{}{} - } -} - -func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) { - label, err := nbdns.GetParsedDomainLabel(name) - if err != nil { - return "", err - } - - uniqueLabel := getUniqueHostLabel(label, peerLabels) - if uniqueLabel == "" { - return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) - } - return uniqueLabel, nil -} - -// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 -func getUniqueHostLabel(name string, peerLabels lookupMap) string { - _, found := peerLabels[name] - if !found { - return name - } - for i := 1; i < 1000; i++ { - nameWithSuffix := name + "-" + strconv.Itoa(i) - _, found = peerLabels[nameWithSuffix] - if !found { - return nameWithSuffix - } - } - return "" -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c0f..6fb9f6a29a9 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -53,7 +54,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + account.DNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{group1ID}, } @@ -86,20 +87,20 @@ func TestSaveDNSSettings(t *testing.T) { testCases := []struct { name string userID string - inputSettings *DNSSettings + inputSettings *types.DNSSettings shouldFail bool }{ { name: "Saving As Admin Should Be OK", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, }, { name: "Should Not Update Settings As Regular User", userID: dnsRegularUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, shouldFail: true, @@ -113,7 +114,7 @@ func TestSaveDNSSettings(t *testing.T) { { name: "Should Not Update Settings If Group Is Invalid", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{"non-existing-group"}, }, shouldFail: true, @@ -210,10 +211,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createDNSStore(t *testing.T) (Store, error) { +func createDNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -222,7 +223,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -259,9 +260,9 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - account.Users[dnsRegularUserID] = &User{ + account.Users[dnsRegularUserID] = &types.User{ Id: dnsRegularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } err := am.Store.SaveAccount(context.Background(), account) @@ -293,13 +294,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &group.Group{ + newGroup1 := &types.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &group.Group{ + newGroup2 := &types.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } @@ -483,7 +484,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -510,7 +511,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -550,7 +551,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -589,7 +590,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA", "groupB"}, }) assert.NoError(t, err) @@ -609,7 +610,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -629,7 +630,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{}, }) assert.NoError(t, err) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 590b1d708bc..3c629a0dbda 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -9,6 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -21,7 +23,7 @@ var ( type ephemeralPeer struct { id string - account *Account + account *types.Account deadline time.Time next *ephemeralPeer } @@ -32,7 +34,7 @@ type ephemeralPeer struct { // EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store Store + store store.Store accountManager AccountManager headPeer *ephemeralPeer @@ -42,7 +44,7 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager { +func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager { return &EphemeralManager{ store: store, accountManager: accountManager, @@ -177,7 +179,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } } -func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(id string, account *types.Account, deadline time.Time) { ep := &ephemeralPeer{ id: id, account: account, diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 1390352a5d0..ac83724409d 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -8,18 +8,20 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) type MockStore struct { - Store - account *Account + store.Store + account *types.Account } -func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { - return []*Account{s.account} +func (s *MockStore) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{s.account} } -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { +func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*types.Account, error) { _, ok := s.account.Peers[peerId] if ok { return s.account, nil diff --git a/management/server/group.go b/management/server/group.go index 7b307cf1a8b..d433a348551 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -10,10 +10,12 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -28,7 +30,7 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -45,38 +47,38 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco } // GetGroup returns a specific group by groupID in an account -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { + return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) + return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}) } // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -90,10 +92,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } var eventsToStore []func() - var groupsToSave []*nbgroup.Group + var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { @@ -113,11 +115,11 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -128,23 +130,23 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { - addedPeers = difference(newGroup.Peers, oldGroup.Peers) - removedPeers = difference(oldGroup.Peers, newGroup.Peers) + addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) + removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { @@ -153,7 +155,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac } modifiedPeers := slices.Concat(addedPeers, removedPeers) - peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) if err != nil { log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) return nil @@ -194,21 +196,6 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return eventsToStore } -// difference returns the elements in `a` that aren't in `b`. -func difference(a, b []string) []string { - mb := make(map[string]struct{}, len(b)) - for _, x := range b { - mb[x] = struct{}{} - } - var diff []string - for _, x := range a { - if _, found := mb[x]; !found { - diff = append(diff, x) - } - } - return diff -} - // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -223,7 +210,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -238,11 +225,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us var allErrors error var groupIDsToDelete []string - var deletedGroups []*nbgroup.Group + var deletedGroups []*types.Group - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID) if err != nil { allErrors = errors.Join(allErrors, err) continue @@ -257,11 +244,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -279,12 +266,12 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -298,18 +285,59 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupAddResource appends resource to the group +func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -320,12 +348,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -339,31 +367,72 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupDeleteResource removes resource from the group +func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemoveResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // validateNewGroup validates the new group for existence and required fields. -func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { +func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *types.Group) error { + if newGroup.ID == "" && newGroup.Issued != types.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) if err != nil { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { return err @@ -380,7 +449,7 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, } for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } @@ -389,14 +458,14 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, return nil } -func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user - if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if group.Issued == types.GroupIssuedIntegration { + executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } - if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { + if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } @@ -429,8 +498,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. } // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. -func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) +func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -439,7 +508,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -452,8 +521,8 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { - routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -469,8 +538,8 @@ func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID stri } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -487,8 +556,8 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -506,8 +575,8 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -522,8 +591,8 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { - users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -538,12 +607,12 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } @@ -566,7 +635,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -575,15 +644,15 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s return false } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) +// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. +func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) if err != nil { return false, err } for _, group := range groups { - if group.HasPeers() { + if group.HasPeers() || group.HasResources() { return true, nil } } diff --git a/management/server/group_test.go b/management/server/group_test.go index ec017fc577a..834388d1ef3 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -32,22 +32,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedIntegration + group.Issued = types.GroupIssuedIntegration err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedJWT + group.Issued = types.GroupIssuedJWT err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) + t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI group.ID = "" err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { @@ -145,13 +145,13 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { manager, account, err := initTestGroupAccount(am) assert.NoError(t, err, "Failed to init testing account") - groups := make([]*nbgroup.Group, 10) + groups := make([]*types.Group, 10) for i := 0; i < 10; i++ { - groups[i] = &nbgroup.Group{ + groups[i] = &types.Group{ ID: fmt.Sprintf("group-%d", i+1), AccountID: account.Id, Name: fmt.Sprintf("group-%d", i+1), - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } } @@ -267,63 +267,63 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } -func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) { +func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &nbgroup.Group{ + groupForRoute := &types.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForRoute2 := &nbgroup.Group{ + groupForRoute2 := &types.Group{ ID: "grp-for-route2", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &nbgroup.Group{ + groupForNameServerGroups := &types.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &nbgroup.Group{ + groupForPolicies := &types.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &nbgroup.Group{ + groupForSetupKeys := &types.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &nbgroup.Group{ + groupForUsers := &types.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &nbgroup.Group{ + groupForIntegration := &types.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: nbgroup.GroupIssuedIntegration, + Issued: types.GroupIssuedIntegration, Peers: make([]string, 0), } @@ -342,9 +342,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A Groups: []string{groupForNameServerGroups.ID}, } - policy := &Policy{ + policy := &types.Policy{ ID: "example policy", - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "example policy rule", Destinations: []string{groupForPolicies.ID}, @@ -352,12 +352,12 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A }, } - setupKey := &SetupKey{ + setupKey := &types.SetupKey{ Id: "example setup key", AutoGroups: []string{groupForSetupKeys.ID}, } - user := &User{ + user := &types.User{ Id: "example user", AutoGroups: []string{groupForUsers.ID}, } @@ -392,7 +392,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -429,7 +429,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, @@ -500,15 +500,15 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -522,7 +522,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -591,7 +591,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, @@ -632,7 +632,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { // Saving a group linked to dns settings should update account peers and send peer update t.Run("saving group linked to dns settings", func(t *testing.T) { - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupD"}, }) assert.NoError(t, err) @@ -659,7 +659,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go new file mode 100644 index 00000000000..f5abb212e13 --- /dev/null +++ b/management/server/groups/manager.go @@ -0,0 +1,196 @@ +package groups + +import ( + "context" + "fmt" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) + AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error + AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) + RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) + if err != nil { + return nil, err + } + if !ok { + return nil, err + } + + groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account groups: %w", err) + } + + groupsMap := make(map[string]*types.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + + return groupsMap, nil +} + +func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Write) + if err != nil { + return err + } + if !ok { + return err + } + + event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource) + if err != nil { + return fmt.Errorf("error adding resource to group: %w", err) + } + + event() + + return nil +} + +func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resource *types.Resource) (func(), error) { + err := transaction.AddResourceToGroup(ctx, accountID, groupID, resource) + if err != nil { + return nil, fmt.Errorf("error adding resource to group: %w", err) + } + + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + if err != nil { + return nil, fmt.Errorf("error getting group: %w", err) + } + + // TODO: at some point, this will need to become a switch statement + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) + if err != nil { + return nil, fmt.Errorf("error getting network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, groupID, accountID, activity.ResourceAddedToGroup, group.EventMetaResource(networkResource)) + } + + return event, nil +} + +func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) { + err := transaction.RemoveResourceFromGroup(ctx, accountID, groupID, resourceID) + if err != nil { + return nil, fmt.Errorf("error removing resource from group: %w", err) + } + + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + if err != nil { + return nil, fmt.Errorf("error getting group: %w", err) + } + + // TODO: at some point, this will need to become a switch statement + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("error getting network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, groupID, accountID, activity.ResourceRemovedFromGroup, group.EventMetaResource(networkResource)) + } + + return event, nil +} + +func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) +} + +func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { + groupsInfo := []api.GroupMinimum{} + groupsChecked := make(map[string]struct{}) + for _, group := range groups { + _, ok := groupsChecked[group.ID] + if ok { + continue + } + groupsChecked[group.ID] = struct{}{} + for _, pk := range group.Peers { + if pk == id { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), + } + groupsInfo = append(groupsInfo, info) + break + } + } + for _, rk := range group.Resources { + if rk.ID == id { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), + } + groupsInfo = append(groupsInfo, info) + break + } + } + } + return groupsInfo +} + +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + return map[string]*types.Group{}, nil +} + +func (m *mockManager) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error { + return nil +} + +func (m *mockManager) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) { + return func() { + // noop + }, nil +} + +func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) { + return func() { + // noop + }, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 9c12336f80f..2635ac11b0a 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -23,14 +23,17 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { - accountManager AccountManager - wgKey wgtypes.Key + accountManager AccountManager + settingsManager settings.Manager + wgKey wgtypes.Key proto.UnimplementedManagementServiceServer peersUpdateManager *PeersUpdateManager config *Config @@ -47,6 +50,7 @@ func NewServer( ctx context.Context, config *Config, accountManager AccountManager, + settingsManager settings.Manager, peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, @@ -99,6 +103,7 @@ func NewServer( // peerKey -> event channel peersUpdateManager: peersUpdateManager, accountManager: accountManager, + settingsManager: settingsManager, config: config, secretsManager: secretsManager, jwtValidator: jwtValidator, @@ -483,7 +488,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false), Checks: toProtocolChecks(ctx, postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) @@ -599,20 +604,21 @@ func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Tok } } -func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dnsResolutionOnRoutingPeerEnabled bool) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: dnsResolutionOnRoutingPeerEnabled, } } -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnbled bool) *proto.SyncResponse { response := &proto.SyncResponse{ WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials), - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnbled), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -661,7 +667,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { var err error var turnToken *Token @@ -680,7 +686,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p } } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil) + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, peer.UserID) + if err != nil { + return status.Errorf(codes.Internal, "error handling request") + } + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2e084f6e407..351976baf02 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -84,6 +84,10 @@ components: items: type: string example: Administrators + routing_peer_dns_resolution_enabled: + description: Enables or disables DNS resolution on the routing peers + type: boolean + example: true extra: $ref: '#/components/schemas/AccountExtraSettings' required: @@ -668,6 +672,10 @@ components: description: Count of peers associated to the group type: integer example: 2 + resources_count: + description: Count of resources associated to the group + type: integer + example: 5 issued: description: How the group was issued (api, integration, jwt) type: string @@ -677,6 +685,7 @@ components: - id - name - peers_count + - resources_count GroupRequest: type: object properties: @@ -690,6 +699,10 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m1" + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - name Group: @@ -702,8 +715,13 @@ components: type: array items: $ref: '#/components/schemas/PeerMinimum' + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - peers + - resources PolicyRuleMinimum: type: object properties: @@ -782,15 +800,18 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv797" + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: type: string example: "ch8i4ug6lnn4g9h7v7m0" - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyRule: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -801,14 +822,17 @@ components: type: array items: $ref: '#/components/schemas/GroupMinimum' + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: $ref: '#/components/schemas/GroupMinimum' - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyMinimum: type: object properties: @@ -1176,6 +1200,171 @@ components: - id - network_type - $ref: '#/components/schemas/RouteRequest' + Resource: + type: object + properties: + id: + description: ID of the resource + type: string + example: chacdk86lnnboviihd7g + type: + description: Type of the resource + $ref: '#/components/schemas/ResourceType' + required: + - id + - type + ResourceType: + allOf: + - $ref: '#/components/schemas/NetworkResourceType' + - type: string + example: host + NetworkRequest: + type: object + properties: + name: + description: Network name + type: string + example: Remote Network 1 + description: + description: Network description + type: string + example: A remote network that needs to be accessed + required: + - name + Network: + allOf: + - type: object + properties: + id: + description: Network ID + type: string + example: chacdk86lnnboviihd7g + routers: + description: List of router IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + routing_peers_count: + description: Count of routing peers associated with the network + type: integer + example: 2 + resources: + description: List of network resource IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m1 + policies: + description: List of policy IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m2 + required: + - id + - routers + - resources + - routing_peers_count + - policies + - $ref: '#/components/schemas/NetworkRequest' + NetworkResourceMinimum: + type: object + properties: + name: + description: Network resource name + type: string + example: Remote Resource 1 + description: + description: Network resource description + type: string + example: A remote resource inside network 1 + address: + description: Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + type: string + example: "1.1.1.1" + required: + - name + - address + NetworkResourceRequest: + allOf: + - $ref: '#/components/schemas/NetworkResourceMinimum' + - type: object + properties: + groups: + description: Group IDs containing the resource + type: array + items: + type: string + example: "chacdk86lnnboviihd70" + required: + - groups + - address + NetworkResource: + allOf: + - type: object + properties: + id: + description: Network Resource ID + type: string + example: chacdk86lnnboviihd7g + type: + $ref: '#/components/schemas/NetworkResourceType' + groups: + description: Groups that the resource belongs to + type: array + items: + $ref: '#/components/schemas/GroupMinimum' + required: + - id + - type + - groups + - $ref: '#/components/schemas/NetworkResourceMinimum' + NetworkResourceType: + description: Network resource type based of the address + type: string + enum: [ "host", "subnet", "domain" ] + example: host + NetworkRouterRequest: + type: object + properties: + peer: + description: Peer Identifier associated with route. This property can not be set together with `peer_groups` + type: string + example: chacbco6lnnbn6cg5s91 + peer_groups: + description: Peers Group Identifier associated with route. This property can not be set together with `peer` + type: array + items: + type: string + example: chacbco6lnnbn6cg5s91 + metric: + description: Route metric number. Lowest number has higher priority + type: integer + maximum: 9999 + minimum: 1 + example: 9999 + masquerade: + description: Indicate if peer should masquerade traffic to this route's prefix + type: boolean + example: true + required: + # Only one property has to be set + #- peer + #- peer_groups + - metric + - masquerade + NetworkRouter: + allOf: + - type: object + properties: + id: + description: Network Router Id + type: string + example: chacdk86lnnboviihd7g + required: + - id + - $ref: '#/components/schemas/NetworkRouterRequest' Nameserver: type: object properties: @@ -2460,6 +2649,502 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/networks: + get: + summary: List all Networks + description: Returns a list of all networks + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Networks + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network + description: Creates a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New Network request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network Object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}: + get: + summary: Retrieve a Network + description: Get information about a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network + description: Update/Replace a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: Update Network request + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network + description: Delete a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources: + get: + summary: List all Network Resources + description: Returns a list of all resources in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Resources + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Resource + description: Creates a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources/{resourceId}: + get: + summary: Retrieve a Network Resource + description: Get information about a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Resource + description: Update a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a resource + requestBody: + description: Update Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Resource + description: Delete a network resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers: + get: + summary: List all Network Routers + description: Returns a list of all routers in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Routers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Router + description: Creates a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers/{routerId}: + get: + summary: Retrieve a Network Router + description: Get information about a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Router + description: Update a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + requestBody: + description: Update Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Router + description: Delete a network router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/dns/nameservers: get: summary: List all Nameserver Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 321395d2557..40574d6f163 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -88,6 +88,13 @@ const ( NameserverNsTypeUdp NameserverNsType = "udp" ) +// Defines values for NetworkResourceType. +const ( + NetworkResourceTypeDomain NetworkResourceType = "domain" + NetworkResourceTypeHost NetworkResourceType = "host" + NetworkResourceTypeSubnet NetworkResourceType = "subnet" +) + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" @@ -136,6 +143,13 @@ const ( PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) +// Defines values for ResourceType. +const ( + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + // Defines values for UserStatus. const ( UserStatusActive UserStatus = "active" @@ -234,6 +248,9 @@ type AccountSettings struct { // RegularUsersViewBlocked Allows blocking regular users from viewing parts of the system. RegularUsersViewBlocked bool `json:"regular_users_view_blocked"` + + // RoutingPeerDnsResolutionEnabled Enables or disables DNS resolution on the routing peers + RoutingPeerDnsResolutionEnabled *bool `json:"routing_peer_dns_resolution_enabled,omitempty"` } // Checks List of objects that perform the actual checks @@ -365,7 +382,11 @@ type Group struct { Peers []PeerMinimum `json:"peers"` // PeersCount Count of peers associated to the group - PeersCount int `json:"peers_count"` + PeersCount int `json:"peers_count"` + Resources []Resource `json:"resources"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupIssued How the group was issued (api, integration, jwt) @@ -384,6 +405,9 @@ type GroupMinimum struct { // PeersCount Count of peers associated to the group PeersCount int `json:"peers_count"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupMinimumIssued How the group was issued (api, integration, jwt) @@ -395,7 +419,8 @@ type GroupRequest struct { Name string `json:"name"` // Peers List of peers ids - Peers *[]string `json:"peers,omitempty"` + Peers *[]string `json:"peers,omitempty"` + Resources *[]Resource `json:"resources,omitempty"` } // Location Describe geographical location information @@ -494,6 +519,123 @@ type NameserverGroupRequest struct { SearchDomainsEnabled bool `json:"search_domains_enabled"` } +// Network defines model for Network. +type Network struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Id Network ID + Id string `json:"id"` + + // Name Network name + Name string `json:"name"` + + // Policies List of policy IDs associated with the network + Policies []string `json:"policies"` + + // Resources List of network resource IDs associated with the network + Resources []string `json:"resources"` + + // Routers List of router IDs associated with the network + Routers []string `json:"routers"` + + // RoutingPeersCount Count of routing peers associated with the network + RoutingPeersCount int `json:"routing_peers_count"` +} + +// NetworkRequest defines model for NetworkRequest. +type NetworkRequest struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Name Network name + Name string `json:"name"` +} + +// NetworkResource defines model for NetworkResource. +type NetworkResource struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Groups Groups that the resource belongs to + Groups []GroupMinimum `json:"groups"` + + // Id Network Resource ID + Id string `json:"id"` + + // Name Network resource name + Name string `json:"name"` + + // Type Network resource type based of the address + Type NetworkResourceType `json:"type"` +} + +// NetworkResourceMinimum defines model for NetworkResourceMinimum. +type NetworkResourceMinimum struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceRequest defines model for NetworkResourceRequest. +type NetworkResourceRequest struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Groups Group IDs containing the resource + Groups []string `json:"groups"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceType Network resource type based of the address +type NetworkResourceType string + +// NetworkRouter defines model for NetworkRouter. +type NetworkRouter struct { + // Id Network Router Id + Id string `json:"id"` + + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + +// NetworkRouterRequest defines model for NetworkRouterRequest. +type NetworkRouterRequest struct { + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + // OSVersionCheck Posture check for the version of operating system type OSVersionCheck struct { // Android Posture check for the version of operating system @@ -779,10 +921,11 @@ type PolicyRule struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []GroupMinimum `json:"destinations"` + Destinations *[]GroupMinimum `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -800,10 +943,11 @@ type PolicyRule struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleProtocol `json:"protocol"` + Protocol PolicyRuleProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []GroupMinimum `json:"sources"` + Sources *[]GroupMinimum `json:"sources,omitempty"` } // PolicyRuleAction Policy rule accept or drops packets @@ -857,10 +1001,11 @@ type PolicyRuleUpdate struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []string `json:"destinations"` + Destinations *[]string `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -878,10 +1023,11 @@ type PolicyRuleUpdate struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleUpdateProtocol `json:"protocol"` + Protocol PolicyRuleUpdateProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []string `json:"sources"` + Sources *[]string `json:"sources,omitempty"` } // PolicyRuleUpdateAction Policy rule accept or drops packets @@ -955,6 +1101,16 @@ type ProcessCheck struct { Processes []Process `json:"processes"` } +// Resource defines model for Resource. +type Resource struct { + // Id ID of the resource + Id string `json:"id"` + Type ResourceType `json:"type"` +} + +// ResourceType defines model for ResourceType. +type ResourceType string + // Route defines model for Route. type Route struct { // AccessControlGroups Access control group identifier associated with route. @@ -1292,6 +1448,24 @@ type PostApiGroupsJSONRequestBody = GroupRequest // PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType. type PutApiGroupsGroupIdJSONRequestBody = GroupRequest +// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType. +type PostApiNetworksJSONRequestBody = NetworkRequest + +// PutApiNetworksNetworkIdJSONRequestBody defines body for PutApiNetworksNetworkId for application/json ContentType. +type PutApiNetworksNetworkIdJSONRequestBody = NetworkRequest + +// PostApiNetworksNetworkIdResourcesJSONRequestBody defines body for PostApiNetworksNetworkIdResources for application/json ContentType. +type PostApiNetworksNetworkIdResourcesJSONRequestBody = NetworkResourceRequest + +// PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody defines body for PutApiNetworksNetworkIdResourcesResourceId for application/json ContentType. +type PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody = NetworkResourceRequest + +// PostApiNetworksNetworkIdRoutersJSONRequestBody defines body for PostApiNetworksNetworkIdRouters for application/json ContentType. +type PostApiNetworksNetworkIdRoutersJSONRequestBody = NetworkRouterRequest + +// PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody defines body for PutApiNetworksNetworkIdRoutersRouterId for application/json ContentType. +type PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody = NetworkRouterRequest + // PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType. type PutApiPeersPeerIdJSONRequestBody = PeerRequest diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 373aa4dd7c0..7db7ab5b842 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,11 +12,13 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroups "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/handlers/accounts" "github.com/netbirdio/netbird/management/server/http/handlers/dns" "github.com/netbirdio/netbird/management/server/http/handlers/events" "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/networks" "github.com/netbirdio/netbird/management/server/http/handlers/peers" "github.com/netbirdio/netbird/management/server/http/handlers/policies" "github.com/netbirdio/netbird/management/server/http/handlers/routes" @@ -25,6 +27,9 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbnetworks "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -38,7 +43,7 @@ type apiHandler struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -93,6 +98,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa routes.AddEndpoints(api.AccountManager, authCfg, router) dns.AddEndpoints(api.AccountManager, authCfg, router) events.AddEndpoints(api.AccountManager, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index c952077777e..a23628cdcc4 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that handles the server.Account HTTP endpoints @@ -82,7 +83,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - settings := &server.Settings{ + settings := &types.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, @@ -107,6 +108,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.JwtAllowGroups != nil { settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } + if req.Settings.RoutingPeerDnsResolutionEnabled != nil { + settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled + } updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { @@ -138,7 +142,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *server.Settings) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -154,6 +158,7 @@ func toAccountResponse(accountID string, settings *server.Settings) *api.Account JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, + RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, } if settings.Extra != nil { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 9d7e8a84ddc..e8a599863ce 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -13,23 +13,23 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -func initAccountsTestData(account *server.Account, admin *server.User) *handler { +func initAccountsTestData(account *types.Account, admin *types.User) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil }, - GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, - UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -58,19 +58,19 @@ func initAccountsTestData(account *server.Account, admin *server.User) *handler func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") sr := func(v string) *string { return &v } br := func(v bool) *bool { return &v } - handler := initAccountsTestData(&server.Account{ + handler := initAccountsTestData(&types.Account{ Id: accountID, Domain: "hotmail.com", - Network: server.NewNetwork(), - Users: map[string]*server.User{ + Network: types.NewNetwork(), + Users: map[string]*types.User{ adminUser.Id: adminUser, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, RegularUsersViewBlocked: true, @@ -95,13 +95,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestPath: "/api/accounts", expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: int(time.Hour.Seconds()), - PeerLoginExpirationEnabled: false, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr(""), - JwtGroupsEnabled: br(false), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: int(time.Hour.Seconds()), + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -114,13 +115,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 15552000, - PeerLoginExpirationEnabled: true, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr(""), - JwtGroupsEnabled: br(false), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: false, + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -133,13 +135,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 15552000, - PeerLoginExpirationEnabled: false, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr("roles"), - JwtGroupsEnabled: br(true), - JwtAllowGroups: &[]string{"test"}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -152,13 +155,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 554400, - PeerLoginExpirationEnabled: true, - GroupsPropagationEnabled: br(true), - JwtGroupsClaimName: sr("groups"), - JwtGroupsEnabled: br(true), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: 554400, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(true), + JwtGroupsClaimName: sr("groups"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 7dd8c1fc1aa..112eee1797b 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account @@ -81,7 +82,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re return } - updateDNSSettings := &server.DNSSettings{ + updateDNSSettings := &types.DNSSettings{ DisabledManagementGroups: req.DisabledManagementGroups, } diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index a64e3fd8356..9ca1dc03253 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -13,10 +13,10 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -27,15 +27,15 @@ const ( testDNSSettingsUserID = "test_user" ) -var baseExistingDNSSettings = server.DNSSettings{ +var baseExistingDNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{testDNSSettingsExistingGroup}, } -var testingDNSSettingsAccount = &server.Account{ +var testingDNSSettingsAccount = &types.Account{ Id: testDNSSettingsAccountID, Domain: "hotmail.com", - Users: map[string]*server.User{ - testDNSSettingsUserID: server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + testDNSSettingsUserID: types.NewAdminUser("test_user"), }, DNSSettings: baseExistingDNSSettings, } @@ -43,10 +43,10 @@ var testingDNSSettingsAccount = &server.Account{ func initDNSSettingsTestData() *dnsSettingsHandler { return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ - GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { + GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil }, - SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave != nil { return nil } diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index 6af2e534646..17478aba351 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" ) func initEventsTestData(account string, events ...*activity.Event) *handler { @@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - return make([]*server.UserInfo, 0), nil + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + return make([]*types.UserInfo, 0), nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -191,7 +191,7 @@ func TestEvents_GetEvents(t *testing.T) { }, } accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) handler := initEventsTestData(accountID, events...) diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e60529cec94..0ecea7ec2f4 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -129,10 +129,21 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ ID: groupID, Name: req.Name, Peers: peers, + Resources: resources, Issued: existingGroup.Issued, IntegrationReference: existingGroup.IntegrationReference, } @@ -179,10 +190,21 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ - Name: req.Name, - Peers: peers, - Issued: nbgroup.GroupIssuedAPI, + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ + Name: req.Name, + Peers: peers, + Resources: resources, + Issued: types.GroupIssuedAPI, } err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) @@ -259,13 +281,13 @@ func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { } -func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { peersMap := make(map[string]*nbpeer.Peer, len(peers)) for _, peer := range peers { peersMap[peer.ID] = peer } - cache := make(map[string]api.PeerMinimum) + peerCache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, Name: group.Name, @@ -273,7 +295,7 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { } for _, pid := range group.Peers { - _, ok := cache[pid] + _, ok := peerCache[pid] if !ok { peer, ok := peersMap[pid] if !ok { @@ -283,12 +305,19 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { Id: peer.ID, Name: peer.Name, } - cache[pid] = peerResp + peerCache[pid] = peerResp gr.Peers = append(gr.Peers, peerResp) } } gr.PeersCount = len(gr.Peers) + for _, res := range group.Resources { + resResp := res.ToAPIResponse() + gr.Resources = append(gr.Resources, *resResp) + } + + gr.ResourcesCount = len(gr.Resources) + return &gr } diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 089c1a40f0a..49805ca9b68 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -17,13 +17,13 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -31,20 +31,20 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *handler { +func initGroupTestData(initGroups ...*types.Group) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - groups := map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*types.Group, error) { + groups := map[string]*types.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, } for _, group := range initGroups { @@ -61,9 +61,9 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { - return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } return nil, fmt.Errorf("unknown group name") @@ -120,7 +120,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &nbgroup.Group{ + group := &types.Group{ ID: "idofthegroup", Name: "Group", } @@ -154,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &nbgroup.Group{} + got := &types.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go new file mode 100644 index 00000000000..6b36a8fcecf --- /dev/null +++ b/management/server/http/handlers/networks/handler.go @@ -0,0 +1,321 @@ +package networks + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/gorilla/mux" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/status" + nbtypes "github.com/netbirdio/netbird/management/server/types" +) + +// handler is a handler that returns networks of the account +type handler struct { + networksManager networks.Manager + resourceManager resources.Manager + routerManager routers.Manager + accountManager s.AccountManager + + groupsManager groups.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + addRouterEndpoints(routerManager, extractFromToken, authCfg, router) + addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) + + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg) + router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") + router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") +} + +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { + return &handler{ + networksManager: networksManager, + resourceManager: resourceManager, + routerManager: routerManager, + groupsManager: groupsManager, + accountManager: accountManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) +} + +func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.AccountID = accountID + network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) +} + +func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) +} + +func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.ID = networkID + network.AccountID = accountID + network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) +} + +func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, networkID string) ([]string, []string, int, error) { + resources, err := h.resourceManager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get resources in network: %w", err) + } + + var resourceIDs []string + for _, resource := range resources { + resourceIDs = append(resourceIDs, resource.ID) + } + + routers, err := h.routerManager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) + } + + groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) + } + + peerCounter := 0 + var routerIDs []string + for _, router := range routers { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + peerCounter += len(groups[groupID].Peers) + } + } + } + + return routerIDs, resourceIDs, peerCounter, nil +} + +func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group, account *nbtypes.Account) []*api.Network { + var networkResponse []*api.Network + for _, network := range networks { + routerIDs, peerCounter := getRouterIDs(network, routers, groups) + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter, policyIDs)) + } + return networkResponse +} + +func getRouterIDs(network *types.Network, routers map[string][]*routerTypes.NetworkRouter, groups map[string]*nbtypes.Group) ([]string, int) { + routerIDs := []string{} + peerCounter := 0 + for _, router := range routers[network.ID] { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + group, ok := groups[groupID] + if !ok { + continue + } + peerCounter += len(group.Peers) + } + } + } + return routerIDs, peerCounter +} diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go new file mode 100644 index 00000000000..a0dc9a10def --- /dev/null +++ b/management/server/http/handlers/networks/resources_handler.go @@ -0,0 +1,222 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) + +type resourceHandler struct { + resourceManager resources.Manager + groupsManager groups.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg) + router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") +} + +func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler { + return &resourceHandler{ + resourceManager: resourceManager, + groupsManager: groupsManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} +func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} + +func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.ID = mux.Vars(r)["resourceId"] + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go new file mode 100644 index 00000000000..2cf39a1329a --- /dev/null +++ b/management/server/http/handlers/networks/routers_handler.go @@ -0,0 +1,165 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/routers/types" +) + +type routersHandler struct { + routersManager routers.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg) + router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") +} + +func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler { + return &routersHandler{ + routersManager: routersManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var routersResponse []*api.NetworkRouter + for _, router := range routers { + routersResponse = append(routersResponse, router.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, routersResponse) +} + +func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = networkID + router.AccountID = accountID + + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = mux.Vars(r)["networkId"] + router.ID = mux.Vars(r)["routerId"] + router.AccountID = accountID + + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, struct{}{}) +} diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index c53cbc038e9..5bc616e47cb 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,13 +10,14 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // Handler is a handler that returns peers of the account @@ -57,7 +58,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { return peerToReturn, nil } -func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -71,7 +72,7 @@ func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := toGroupsInfo(account.Groups, peer.ID) + groupsInfo := groups.ToGroupsInfo(account.Groups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { @@ -84,7 +85,7 @@ func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -115,7 +116,7 @@ func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userI } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(account.Groups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { @@ -199,9 +200,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - groupsMap := map[string]*nbgroup.Group{} - groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) - for _, group := range groups { + groupsMap := map[string]*types.Group{} + grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range grps { groupsMap[group.ID] = group } @@ -212,7 +213,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -290,12 +291,12 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } -func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { +func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain)) @@ -324,30 +325,6 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - groupsInfo := []api.GroupMinimum{} - groupsChecked := make(map[string]struct{}) - for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } - } - return groupsInfo -} - func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 3e3e39deb60..83abc1c400e 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -15,11 +15,10 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/assert" @@ -73,18 +72,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() } - policy := &server.Policy{ + policy := &types.Policy{ ID: "policy", AccountID: accountID, Name: "policy", Enabled: true, - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "rule", Name: "rule", @@ -99,19 +98,19 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, } - srvUser := server.NewRegularUser(serviceUser) + srvUser := types.NewRegularUser(serviceUser) srvUser.IsServiceUser = true - account := &server.Account{ + account := &types.Account{ Id: accountID, Domain: "hotmail.com", Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), + Users: map[string]*types.User{ + adminUser: types.NewAdminUser(adminUser), + regularUser: types.NewRegularUser(regularUser), serviceUser: srvUser, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "group1": { ID: "group1", AccountID: accountID, @@ -120,12 +119,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { Peers: maps.Keys(peersMap), }, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ + Policies: []*types.Policy{policy}, + Network: &types.Network{ Identifier: "ciclqisab2ss43jdn8q0", Net: net.IPNet{ IP: net.ParseIP("100.67.0.0"), diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index 002b914efff..fc5839baaab 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -46,8 +46,8 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { - return server.NewAdminUser(id), nil + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { + return types.NewAdminUser(id), nil }, }, geolocationManager: geo, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index a47f2e620f1..d538d07dbe7 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -9,12 +9,12 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account @@ -133,7 +133,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s return } - policy := &server.Policy{ + policy := &types.Policy{ ID: policyID, AccountID: accountID, Name: req.Name, @@ -146,15 +146,56 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s ruleID = *rule.Id } - pr := server.PolicyRule{ + hasSources := rule.Sources != nil + hasSourceResource := rule.SourceResource != nil + + hasDestinations := rule.Destinations != nil + hasDestinationResource := rule.DestinationResource != nil + + if hasSources && hasSourceResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources, not both"), w) + return + } + + if hasDestinations && hasDestinationResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either destinations or destination resources, not both"), w) + return + } + + if !(hasSources || hasSourceResource) || !(hasDestinations || hasDestinationResource) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources and destinations or destination resources"), w) + return + } + + pr := types.PolicyRule{ ID: ruleID, PolicyID: policyID, Name: rule.Name, - Destinations: rule.Destinations, - Sources: rule.Sources, Bidirectional: rule.Bidirectional, } + if hasSources { + pr.Sources = *rule.Sources + } + + if hasSourceResource { + // TODO: validate the resource id and type + sourceResource := &types.Resource{} + sourceResource.FromAPIRequest(rule.SourceResource) + pr.SourceResource = *sourceResource + } + + if hasDestinations { + pr.Destinations = *rule.Destinations + } + + if hasDestinationResource { + // TODO: validate the resource id and type + destinationResource := &types.Resource{} + destinationResource.FromAPIRequest(rule.DestinationResource) + pr.DestinationResource = *destinationResource + } + pr.Enabled = rule.Enabled if rule.Description != nil { pr.Description = *rule.Description @@ -162,9 +203,9 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s switch rule.Action { case api.PolicyRuleUpdateActionAccept: - pr.Action = server.PolicyTrafficActionAccept + pr.Action = types.PolicyTrafficActionAccept case api.PolicyRuleUpdateActionDrop: - pr.Action = server.PolicyTrafficActionDrop + pr.Action = types.PolicyTrafficActionDrop default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w) return @@ -172,13 +213,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s switch rule.Protocol { case api.PolicyRuleUpdateProtocolAll: - pr.Protocol = server.PolicyRuleProtocolALL + pr.Protocol = types.PolicyRuleProtocolALL case api.PolicyRuleUpdateProtocolTcp: - pr.Protocol = server.PolicyRuleProtocolTCP + pr.Protocol = types.PolicyRuleProtocolTCP case api.PolicyRuleUpdateProtocolUdp: - pr.Protocol = server.PolicyRuleProtocolUDP + pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: - pr.Protocol = server.PolicyRuleProtocolICMP + pr.Protocol = types.PolicyRuleProtocolICMP default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -205,7 +246,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) return } - pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + pr.PortRanges = append(pr.PortRanges, types.RulePortRange{ Start: uint16(portRange.Start), End: uint16(portRange.End), }) @@ -214,7 +255,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s // validate policy object switch pr.Protocol { - case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP: if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return @@ -223,7 +264,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } - case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP: if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return @@ -319,8 +360,8 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { - groupsMap := make(map[string]*nbgroup.Group) +func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -337,13 +378,15 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rID := r.ID rDescription := r.Description rule := api.PolicyRule{ - Id: &rID, - Name: r.Name, - Enabled: r.Enabled, - Description: &rDescription, - Bidirectional: r.Bidirectional, - Protocol: api.PolicyRuleProtocol(r.Protocol), - Action: api.PolicyRuleAction(r.Action), + Id: &rID, + Name: r.Name, + Enabled: r.Enabled, + Description: &rDescription, + Bidirectional: r.Bidirectional, + Protocol: api.PolicyRuleProtocol(r.Protocol), + Action: api.PolicyRuleAction(r.Action), + SourceResource: r.SourceResource.ToAPIResponse(), + DestinationResource: r.DestinationResource.ToAPIResponse(), } if len(r.Ports) != 0 { @@ -362,26 +405,30 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.PortRanges = &portRanges } + var sources []api.GroupMinimum for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, PeersCount: len(group.Peers), } - rule.Sources = append(rule.Sources, minimum) + sources = append(sources, minimum) cache[gid] = minimum } } + rule.Sources = &sources + var destinations []api.GroupMinimum for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { - rule.Destinations = append(rule.Destinations, cachedMinimum) + destinations = append(destinations, cachedMinimum) continue } if group, ok := groupsMap[gid]; ok { @@ -390,10 +437,12 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic Name: group.Name, PeersCount: len(group.Peers), } - rule.Destinations = append(rule.Destinations, minimum) + destinations = append(destinations, minimum) cache[gid] = minimum } } + rule.Destinations = &destinations + ap.Rules = append(ap.Rules, rule) } return ap diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 4b465a85a98..956d0b7cdb7 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -10,9 +10,9 @@ import ( "strings" "testing" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" @@ -20,50 +20,49 @@ import ( "github.com/magiconair/properties/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *handler { - testPolicies := make(map[string]*server.Policy, len(policies)) +func initPoliciesTestData(policies ...*types.Policy) *handler { + testPolicies := make(map[string]*types.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } return &handler{ accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { + GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*types.Policy, error) { policy, ok := testPolicies[policyID] if !ok { return nil, status.Errorf(status.NotFound, "policy not found") } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return policy, nil }, - GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { - return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - user := server.NewAdminUser(userID) - return &server.Account{ + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user := types.NewAdminUser(userID) + return &types.Account{ Id: accountID, Domain: "hotmail.com", - Policies: []*server.Policy{ + Policies: []*types.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "test_user": user, }, }, nil @@ -105,10 +104,10 @@ func TestPoliciesGetPolicy(t *testing.T) { }, } - policy := &server.Policy{ + policy := &types.Policy{ ID: "idofthepolicy", Name: "Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ {ID: "idoftherule", Name: "Rule"}, }, } @@ -177,7 +176,9 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["G"] } ]}`)), expectedStatus: http.StatusOK, @@ -193,6 +194,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "G"}}, }, }, }, @@ -221,7 +224,9 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["F"] } ]}`)), expectedStatus: http.StatusOK, @@ -237,6 +242,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "F"}}, }, }, }, @@ -251,10 +258,10 @@ func TestPoliciesWritePolicy(t *testing.T) { }, } - p := initPoliciesTestData(&server.Policy{ + p := initPoliciesTestData(&types.Policy{ ID: "id-existed", Name: "Default POSTed Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "id-existed", Name: "Default POSTed Rule", diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 9d420066cc8..a29ba45629d 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -360,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) { return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } - domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) var domainList domain.List diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index a25c899c960..4cee3ee306b 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -16,13 +16,13 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/gorilla/mux" "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -61,7 +61,7 @@ var baseExistingRoute = &route.Route{ Groups: []string{existingGroupID}, } -var testingAccount = &server.Account{ +var testingAccount = &types.Account{ Id: testAccountID, Domain: "hotmail.com", Peers: map[string]*nbpeer.Peer{ @@ -82,8 +82,8 @@ var testingAccount = &server.Account{ }, }, }, - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + "test_user": types.NewAdminUser("test_user"), }, } @@ -330,6 +330,14 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "POST Wildcard Domain", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)), + expectedStatus: http.StatusOK, + expectedBody: false, + }, { name: "POST UnprocessableEntity when both network and domains are provided", requestType: http.MethodPost, @@ -609,6 +617,30 @@ func TestValidateDomains(t *testing.T) { expected: domain.List{"google.com"}, wantErr: true, }, + { + name: "Valid wildcard domain", + domains: []string{"*.example.com"}, + expected: domain.List{"*.example.com"}, + wantErr: false, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid wildcard domain", + domains: []string{"a.*.example.com"}, + expected: nil, + wantErr: true, + }, } for _, tt := range tests { diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 9432d554912..89696a16563 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account @@ -63,8 +64,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { return } - if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || - server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { + if !(types.SetupKeyType(req.Type) == types.SetupKeyReusable || + types.SetupKeyType(req.Type) == types.SetupKeyOneOff) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) return } @@ -85,7 +86,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) @@ -152,7 +153,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { return } - newKey := &server.SetupKey{} + newKey := &types.SetupKey{} newKey.AutoGroups = req.AutoGroups newKey.Revoked = req.Revoked newKey.Id = keyID @@ -212,7 +213,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { +func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) err := json.NewEncoder(w).Encode(toResponseBody(key)) @@ -222,7 +223,7 @@ func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupK } } -func toResponseBody(key *server.SetupKey) *api.SetupKey { +func toResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 516a2ab8b01..4ecb1e9ed4c 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -29,17 +29,17 @@ const ( testAccountID = "test_id" ) -func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, - user *server.User, +func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, + user *types.User, ) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, + CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, - ) (*server.SetupKey, error) { + ) (*types.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { nk := newKey.Copy() nk.Ephemeral = ephemeral @@ -47,7 +47,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -58,15 +58,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } }, - SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { + SaveSetupKeyFunc: func(_ context.Context, accountID string, key *types.SetupKey, _ string) (*types.SetupKey, error) { if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { - return []*server.SetupKey{defaultKey}, nil + ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*types.SetupKey, error) { + return []*types.SetupKey{defaultKey}, nil }, DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { @@ -89,13 +89,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey, _ := server.GenerateDefaultSetupKey() + defaultSetupKey, _ := types.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") - newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, - server.SetupKeyUnlimitedUsage, true) + newSetupKey, plainKey := types.GenerateSetupKey(newSetupKeyName, types.SetupKeyReusable, 0, []string{"group-1"}, + types.SetupKeyUnlimitedUsage, true) newSetupKey.Key = plainKey updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 2caf98ad857..197785b349d 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account @@ -164,7 +165,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { +func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken { var lastUsed *time.Time if !pat.LastUsed.IsZero() { lastUsed = &pat.LastUsed @@ -179,7 +180,7 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { } } -func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { +func toPATGeneratedResponse(pat *types.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { return &api.PersonalAccessTokenGenerated{ PlainToken: pat.PlainToken, PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index ef6fb973edc..21bdc461e9a 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -31,13 +31,13 @@ const ( testDomain = "hotmail.com" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ existingTokenID: { ID: existingTokenID, Name: "My first token", @@ -64,16 +64,16 @@ var testAccount = &server.Account{ func initPATTestData() *patHandler { return &patHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return &server.PersonalAccessTokenGenerated{ + return &types.PersonalAccessTokenGenerated{ PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", - PersonalAccessToken: server.PersonalAccessToken{}, + PersonalAccessToken: types.PersonalAccessToken{}, }, nil }, @@ -92,7 +92,7 @@ func initPATTestData() *patHandler { } return nil }, - GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -104,14 +104,14 @@ func initPATTestData() *patHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -217,7 +217,7 @@ func TestTokenHandlers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } assert.NotEmpty(t, got.PlainToken) - assert.Equal(t, server.PATLength, len(got.PlainToken)) + assert.Equal(t, types.PATLength, len(got.PlainToken)) case "Get All Tokens": expectedTokens := []api.PersonalAccessToken{ toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), @@ -243,7 +243,7 @@ func TestTokenHandlers(t *testing.T) { } } -func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { +func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken { return api.PersonalAccessToken{ Id: serverToken.ID, Name: serverToken.Name, diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index c843bc52b08..7380dd97e8a 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -83,13 +84,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } - userRole := server.StrRoleToUserRole(req.Role) - if userRole == server.UserRoleUnknown { + userRole := types.StrRoleToUserRole(req.Role) + if userRole == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w) return } - newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, @@ -156,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { return } - if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { + if types.StrRoleToUserRole(req.Role) == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -171,13 +172,13 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ Email: email, Name: name, Role: req.Role, AutoGroups: req.AutoGroups, IsServiceUser: req.IsServiceUser, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }) if err != nil { util.WriteError(r.Context(), err, w) @@ -264,7 +265,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { +func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { autoGroups := user.AutoGroups if autoGroups == nil { autoGroups = []string{} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 6f6a912360d..90081830a0d 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -26,37 +26,37 @@ const ( regularUserID = "regularUserID" ) -var usersTestAccount = &server.Account{ +var usersTestAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, Role: "admin", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, regularUserID: { Id: regularUserID, Role: "user", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, serviceUserID: { Id: serviceUserID, Role: "user", IsServiceUser: true, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nonDeletableServiceUserID: { Id: serviceUserID, Role: "admin", IsServiceUser: true, NonDeletable: true, - Issued: server.UserIssuedIntegration, + Issued: types.UserIssuedIntegration, }, }, } @@ -67,13 +67,13 @@ func initUsersTestData() *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return usersTestAccount.Users[id], nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - users := make([]*server.UserInfo, 0) + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + users := make([]*types.UserInfo, 0) for _, v := range usersTestAccount.Users { - users = append(users, &server.UserInfo{ + users = append(users, &types.UserInfo{ ID: v.Id, Role: string(v.Role), Name: "", @@ -85,7 +85,7 @@ func initUsersTestData() *handler { } return users, nil }, - CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { + CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) { if userID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } @@ -100,7 +100,7 @@ func initUsersTestData() *handler { } return nil }, - SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) { + SaveUserFunc: func(_ context.Context, accountID, userID string, update *types.User) (*types.UserInfo, error) { if update.Id == notFoundUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) } @@ -109,7 +109,7 @@ func initUsersTestData() *handler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) + info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } @@ -175,7 +175,7 @@ func TestGetUsers(t *testing.T) { return } - respBody := []*server.UserInfo{} + respBody := []*types.UserInfo{} err = json.Unmarshal(content, &respBody) if err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) @@ -342,7 +342,7 @@ func TestCreateUser(t *testing.T) { requestType string requestPath string requestBody io.Reader - expectedResult []*server.User + expectedResult []*types.User }{ {name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)}, // right now creation is blocked in AC middleware, will be refactored in the future diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 0ad250f433c..c5bdf5fe7f1 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,16 +7,16 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/jwtclaims" ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) +type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index b25aad99cae..0d345971202 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -11,16 +11,16 @@ import ( "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // GetAccountFromPATFunc function -type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index fdfb0ea240f..b0d970c5dca 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -10,9 +10,9 @@ import ( "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -28,13 +28,13 @@ const ( wrongToken = "wrongToken" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: accountID, Domain: domain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ userID: { Id: userID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ tokenID: { ID: tokenID, Name: "My first token", @@ -49,7 +49,7 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { if token == PAT { return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 0c70b702a01..47c4ca6aebf 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -7,6 +7,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. @@ -57,9 +59,9 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) if err != nil { return err } @@ -73,6 +75,6 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) } diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 03be9d039ba..22b8026aa24 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -4,8 +4,8 @@ import ( "context" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" ) // IntegratedValidator interface exists to avoid the circle dependencies @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index dc8765e197f..c664237366a 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -23,7 +23,10 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -413,7 +416,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -437,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { return nil, nil, "", cleanup, err } @@ -618,7 +621,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -705,7 +708,7 @@ func Test_LoginPerformance(t *testing.T) { return } - setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) + setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), types.SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) if err != nil { t.Logf("error creating setup key: %v", err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 5361da53fd5..40514ae14db 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -23,9 +23,11 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -457,7 +459,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for p := range peers { validatedPeers[p] = struct{}{} @@ -532,7 +534,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) + store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } @@ -551,7 +553,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 843fa575e83..82b34393f8c 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -15,7 +15,8 @@ import ( "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" nbversion "github.com/netbirdio/netbird/version" ) @@ -47,8 +48,8 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { - GetAllAccounts(ctx context.Context) []*server.Account - GetStoreEngine() server.StoreEngine + GetAllAccounts(ctx context.Context) []*types.Account + GetStoreEngine() store.Engine } // ConnManager peer connection manager that holds state for current active connections diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 2ac2d68a0cf..1d356387f38 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,10 +5,10 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -22,19 +22,19 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { } // GetAllAccounts returns a list of *server.Account for use in tests with predefined information -func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { - return []*server.Account{ +func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{ { Id: "1", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -49,20 +49,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, SourcePostureChecks: []string{"1"}, @@ -94,16 +94,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -111,15 +111,15 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, { Id: "2", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -134,20 +134,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, @@ -158,16 +158,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { PeerGroups: make([]string, 1), }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -177,8 +177,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { } // GetStoreEngine returns FileStoreEngine -func (mockDatasource) GetStoreEngine() server.StoreEngine { - return server.FileStoreEngine +func (mockDatasource) GetStoreEngine() store.Engine { + return store.FileStoreEngine } // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties @@ -267,7 +267,7 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) } - if properties["store_engine"] != server.FileStoreEngine { + if properties["store_engine"] != store.FileStoreEngine { t.Errorf("expected JsonFile, got %s", properties["store_engine"]) } diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 51358c7ad67..a645ae32569 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -12,9 +12,9 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -31,64 +31,64 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err := migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail for an empty database") } func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") type network struct { - server.Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } type account struct { - server.Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` } - err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error + err = db.Save(&account{Account: types.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert Gob data") var gobStr string - err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error + err = db.Model(&types.Account{}).Select("network_net").First(&gobStr).Error assert.NoError(t, err, "Failed to fetch Gob data") err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) require.NoError(t, err, "Failed to decode Gob data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with Gob data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated") } func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") - err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error + err = db.Save(&types.Account{Network: &types.Network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with JSON data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged") } @@ -101,7 +101,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") type location struct { @@ -115,12 +115,12 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { } type account struct { - server.Account + types.Account Peers []peer `gorm:"foreignKey:AccountID;references:id"` } err = db.Save(&account{ - Account: server.Account{Id: "123"}, + Account: types.Account{Id: "123"}, Peers: []peer{ {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, }}, @@ -142,10 +142,10 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.Account{ + err = db.Save(&types.Account{ Id: "1234", PeersG: []nbpeer.Peer{ {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, @@ -164,20 +164,20 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -187,21 +187,21 @@ func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", KeySecret: "EEFDA****", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -211,20 +211,20 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing. func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 46a4fbc1faf..042137b1b02 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,47 +13,47 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) - GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) - CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) + GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -62,35 +62,35 @@ type MockAccountManager struct { SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -105,12 +105,16 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error } +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) @@ -118,7 +122,7 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } @@ -130,7 +134,7 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -139,7 +143,7 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(ctx, accountId, groupID, userID) } @@ -147,7 +151,7 @@ func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(ctx, accountID, userID) } @@ -155,7 +159,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { return am.GetUsersFromAccountFunc(ctx, accountID, userID) } @@ -173,7 +177,7 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( ctx context.Context, userId, domain string, -) (*server.Account, error) { +) (*types.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) } @@ -188,13 +192,13 @@ func (am *MockAccountManager) CreateSetupKey( ctx context.Context, accountID string, keyName string, - keyType server.SetupKeyType, + keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, -) (*server.SetupKey, error) { +) (*types.SetupKey, error) { if am.CreateSetupKeyFunc != nil { return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } @@ -221,7 +225,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } @@ -229,7 +233,7 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { return am.GetAccountFromPATFunc(ctx, pat) } @@ -253,7 +257,7 @@ func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn) } @@ -269,7 +273,7 @@ func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, i } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if am.GetPATFunc != nil { return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } @@ -277,7 +281,7 @@ func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, init } // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID) } @@ -285,7 +289,7 @@ func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, } // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface -func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) { +func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*types.NetworkMap, error) { if am.GetNetworkMapFunc != nil { return am.GetNetworkMapFunc(ctx, peerKey) } @@ -293,7 +297,7 @@ func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) } // GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface -func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) { +func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*types.Network, error) { if am.GetPeerNetworkFunc != nil { return am.GetPeerNetworkFunc(ctx, peerKey) } @@ -306,7 +310,7 @@ func (am *MockAccountManager) AddPeer( setupKey string, userId string, peer *nbpeer.Peer, -) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { return am.AddPeerFunc(ctx, setupKey, userId, peer) } @@ -314,7 +318,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(ctx, accountID, groupName) } @@ -322,7 +326,7 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(ctx, accountID, userID, group) } @@ -330,7 +334,7 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s } // SaveGroups mock implementation of SaveGroups from server.AccountManager interface -func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error { +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { if am.SaveGroupsFunc != nil { return am.SaveGroupsFunc(ctx, accountID, userID, groups) } @@ -378,7 +382,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface -func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) { +func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { if am.GetPolicyFunc != nil { return am.GetPolicyFunc(ctx, accountID, policyID, userID) } @@ -386,7 +390,7 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { if am.SavePolicyFunc != nil { return am.SavePolicyFunc(ctx, accountID, userID, policy) } @@ -402,7 +406,7 @@ func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, polic } // ListPolicies mock implementation of ListPolicies from server.AccountManager interface -func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) { +func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { if am.ListPoliciesFunc != nil { return am.ListPoliciesFunc(ctx, accountID, userID) } @@ -418,14 +422,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { +func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { if am.GetUserFunc != nil { return am.GetUserFunc(ctx, claims) } return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") } -func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) { +func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { if am.ListUsersFunc != nil { return am.ListUsersFunc(ctx, accountID) } @@ -481,7 +485,7 @@ func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID } // SaveSetupKey mocks SaveSetupKey of the AccountManager interface -func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { +func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) { if am.SaveSetupKeyFunc != nil { return am.SaveSetupKeyFunc(ctx, accountID, key, userID) } @@ -490,7 +494,7 @@ func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { if am.GetSetupKeyFunc != nil { return am.GetSetupKeyFunc(ctx, accountID, userID, keyID) } @@ -499,7 +503,7 @@ func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { if am.ListSetupKeysFunc != nil { return am.ListSetupKeysFunc(ctx, accountID, userID) } @@ -508,7 +512,7 @@ func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, user } // SaveUser mocks SaveUser of the AccountManager interface -func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) { if am.SaveUserFunc != nil { return am.SaveUserFunc(ctx, accountID, userID, user) } @@ -516,7 +520,7 @@ func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID st } // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) { if am.SaveOrAddUserFunc != nil { return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists) } @@ -524,7 +528,7 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user } // SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if am.SaveOrAddUsersFunc != nil { return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists) } @@ -595,7 +599,7 @@ func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountI } // CreateUser mocks CreateUser of the AccountManager interface -func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { +func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { if am.CreateUserFunc != nil { return am.CreateUserFunc(ctx, accountID, userID, invite) } @@ -642,7 +646,7 @@ func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID s } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface -func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { +func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { if am.GetDNSSettingsFunc != nil { return am.GetDNSSettingsFunc(ctx, accountID, userID) } @@ -650,7 +654,7 @@ func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID stri } // SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface -func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { +func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if am.SaveDNSSettingsFunc != nil { return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave) } @@ -666,7 +670,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } @@ -674,7 +678,7 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { return am.LoginPeerFunc(ctx, login) } @@ -682,7 +686,7 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, account) } @@ -803,7 +807,7 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } // GetAccountByID mocks GetAccountByID of the AccountManager interface -func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { if am.GetAccountByIDFunc != nil { return am.GetAccountByIDFunc(ctx, accountID, userID) } @@ -811,21 +815,21 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri } // GetUserByID mocks GetUserByID of the AccountManager interface -func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { return am.GetUserByIDFunc(ctx, id) } return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") } -func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { if am.GetAccountSettingsFunc != nil { return am.GetAccountSettingsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") } -func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { if am.GetAccountFunc != nil { return am.GetAccountFunc(ctx, accountID) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e7a5387a142..1a01c7a89ca 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -11,15 +11,16 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -32,7 +33,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, status.NewAdminPermissionError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -40,7 +41,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -64,21 +65,21 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups) if err != nil { return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) }) if err != nil { return nil, err @@ -87,7 +88,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return newNSGroup.Copy(), nil @@ -102,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -113,8 +114,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) if err != nil { return err } @@ -129,11 +130,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) }) if err != nil { return err @@ -142,7 +143,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -153,7 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -165,22 +166,22 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco var nsGroup *nbdns.NameServerGroup var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) if err != nil { return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups) if err != nil { return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) }) if err != nil { return err @@ -189,7 +190,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -197,7 +198,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -210,10 +211,10 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) } -func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { +func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err @@ -224,7 +225,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -234,7 +235,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) if err != nil { return err } @@ -243,12 +244,12 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) if err != nil { return false, err } @@ -257,7 +258,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store return true, nil } - return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) + return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -305,7 +306,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*nbgroup.Group) error { +func validateGroups(list []string, groups map[string]*types.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf02370..0743db51396 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,9 +11,10 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -772,10 +773,10 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createNSStore(t *testing.T) (Store, error) { +func createNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -784,7 +785,7 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, @@ -842,12 +843,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &nbgroup.Group{ + newGroup1 := &types.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &nbgroup.Group{ + newGroup2 := &types.Group{ ID: group2ID, Name: group2ID, } @@ -944,7 +945,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go new file mode 100644 index 00000000000..4a7b3db775c --- /dev/null +++ b/management/server/networks/manager.go @@ -0,0 +1,187 @@ +package networks + +import ( + "context" + "fmt" + + "github.com/rs/xid" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) + CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) + GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) + UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) + DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error +} + +type managerImpl struct { + store store.Store + accountManager s.AccountManager + permissionsManager permissions.Manager + resourcesManager resources.Manager + routersManager routers.Manager +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + resourcesManager: resourceManager, + routersManager: routersManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + network.ID = xid.New().String() + + unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) + defer unlock() + + err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) + if err != nil { + return nil, fmt.Errorf("failed to save network: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta()) + + return network, nil +} + +func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) + defer unlock() + + _, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) + + return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) +} + +func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + network, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get resources in network: %w", err) + } + + for _, resource := range resources { + event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID) + if err != nil { + return fmt.Errorf("failed to delete resource: %w", err) + } + eventsToStore = append(eventsToStore, event...) + } + + routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get routers in network: %w", err) + } + + for _, router := range routers { + event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID) + if err != nil { + return fmt.Errorf("failed to delete router: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) + }) + + return nil + }) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go new file mode 100644 index 00000000000..edd830c2564 --- /dev/null +++ b/management/server/networks/manager_test.go @@ -0,0 +1,254 @@ +package networks + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllNetworksReturnsNetworks(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetAllNetworks(ctx, accountID, userID) + require.NoError(t, err) + require.Len(t, networks, 1) + require.Equal(t, "testNetworkId", networks[0].ID) +} + +func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetAllNetworks(ctx, accountID, userID) + require.Error(t, err) + require.Nil(t, networks) +} + +func Test_GetNetworkReturnsNetwork(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Equal(t, "testNetworkId", networks.ID) +} + +func Test_GetNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + network, err := manager.GetNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Nil(t, network) +} + +func Test_CreateNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + network := &types.Network{ + AccountID: "testAccountId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + createdNetwork, err := manager.CreateNetwork(ctx, userID, network) + require.NoError(t, err) + require.Equal(t, network.Name, createdNetwork.Name) +} + +func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + network := &types.Network{ + AccountID: "testAccountId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + createdNetwork, err := manager.CreateNetwork(ctx, userID, network) + require.Error(t, err) + require.Nil(t, createdNetwork) +} + +func Test_DeleteNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + err = manager.DeleteNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) +} + +func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + err = manager.DeleteNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) +} + +func Test_UpdateNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + network := &types.Network{ + AccountID: "testAccountId", + ID: "testNetworkId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) + require.NoError(t, err) + require.Equal(t, network.Name, updatedNetwork.Name) +} + +func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + network := &types.Network{ + AccountID: "testAccountId", + ID: "testNetworkId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) + require.Error(t, err) + require.Nil(t, updatedNetwork) +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go new file mode 100644 index 00000000000..0fff5bcf8e9 --- /dev/null +++ b/management/server/networks/resources/manager.go @@ -0,0 +1,383 @@ +package resources + +import ( + "context" + "errors" + "fmt" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" +) + +type Manager interface { + GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) + GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) + GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) + CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) + UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error + DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + groupsManager groups.Manager + accountManager s.AccountManager +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + groupsManager: groupsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network resources: %w", err) + } + + resourceMap := make(map[string][]string) + for _, resource := range resources { + resourceMap[resource.NetworkID] = append(resourceMap[resource.NetworkID], resource.ID) + } + + return resourceMap, nil +} + +func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to create new network resource: %w", err) + } + + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil { + return errors.New("resource already exists") + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network)) + } + eventsToStore = append(eventsToStore, event) + + res := nbtypes.Resource{ + ID: resource.ID, + Type: resource.Type.String(), + } + for _, groupID := range resource.GroupIDs { + event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) + if err != nil { + return fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to create network resource: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil +} + +func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + return resource, nil +} + +func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resourceType, domain, prefix, err := types.GetResourceType(resource.Address) + if err != nil { + return nil, fmt.Errorf("failed to get resource type: %w", err) + } + + resource.Type = resourceType + resource.Domain = domain + resource.Prefix = prefix + + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != resource.NetworkID { + return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) + } + + _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil && oldResource.ID != resource.ID { + return errors.New("new resource name already exists") + } + + oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + events, err := m.updateResourceGroups(ctx, transaction, userID, resource, oldResource) + if err != nil { + return fmt.Errorf("failed to update resource groups: %w", err) + } + + eventsToStore = append(eventsToStore, events...) + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network)) + }) + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to update network resource: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil +} + +func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { + res := nbtypes.Resource{ + ID: newResource.ID, + Type: newResource.Type.String(), + } + + oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + oldGroupsIds := make([]string, 0) + for _, group := range oldResourceGroups { + oldGroupsIds = append(oldGroupsIds, group.ID) + } + + var eventsToStore []func() + groupsToAdd := util.Difference(newResource.GroupIDs, oldGroupsIds) + for _, groupID := range groupsToAdd { + events, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, newResource.AccountID, userID, groupID, &res) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + groupsToRemove := util.Difference(oldGroupsIds, newResource.GroupIDs) + for _, groupID := range groupsToRemove { + events, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, newResource.AccountID, userID, groupID, res.ID) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + return eventsToStore, nil +} + +func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var events []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID) + if err != nil { + return fmt.Errorf("failed to delete resource: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return fmt.Errorf("failed to delete network resource: %w", err) + } + + for _, event := range events { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + var eventsToStore []func() + + for _, group := range groups { + event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, userID, group.ID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to remove resource from group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to delete network resource: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network)) + }) + + return eventsToStore, nil +} diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go new file mode 100644 index 00000000000..993cd65df87 --- /dev/null +++ b/management/server/networks/resources/manager_test.go @@ -0,0 +1,411 @@ +package resources + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Len(t, resources, 2) +} + +func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} +func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) + require.NoError(t, err) + require.Len(t, resources, 2) +} + +func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} + +func Test_GetResourceInNetworkReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) + require.Equal(t, resourceID, resource.ID) +} + +func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} + +func Test_CreateResourceSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "newResourceId", + Description: "description", + Address: "192.168.1.1", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.NoError(t, err) + require.Equal(t, resource.Name, createdResource.Name) +} + +func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "192.168.1.1", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, createdResource) +} + +func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "invalid-address", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, createdResource) +} + +func Test_CreateResourceFailsWithUsedName(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "invalid-address", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, createdResource) +} + +func Test_UpdateResourceSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: "someNewName", + ID: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.NoError(t, err) + require.NotNil(t, updatedResource) + require.Equal(t, "new-description", updatedResource.Description) + require.Equal(t, "1.2.3.0/24", updatedResource.Address) + require.Equal(t, types.NetworkResourceType("subnet"), updatedResource.Type) +} + +func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "otherResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + ID: resourceID, + Name: "used-name", + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_DeleteResourceSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) +} + +func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) +} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go new file mode 100644 index 00000000000..7eecdce0fef --- /dev/null +++ b/management/server/networks/resources/types/resource.go @@ -0,0 +1,169 @@ +package types + +import ( + "errors" + "fmt" + "net/netip" + "regexp" + + "github.com/rs/xid" + + nbDomain "github.com/netbirdio/netbird/management/domain" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type NetworkResourceType string + +const ( + host NetworkResourceType = "host" + subnet NetworkResourceType = "subnet" + domain NetworkResourceType = "domain" +) + +func (p NetworkResourceType) String() string { + return string(p) +} + +type NetworkResource struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string + Type NetworkResourceType + Address string `gorm:"-"` + GroupIDs []string `gorm:"-"` + Domain string + Prefix netip.Prefix `gorm:"serializer:json"` +} + +func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string) (*NetworkResource, error) { + resourceType, domain, prefix, err := GetResourceType(address) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + + return &NetworkResource{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Name: name, + Description: description, + Type: resourceType, + Address: address, + Domain: domain, + Prefix: prefix, + GroupIDs: groupIDs, + }, nil +} + +func (n *NetworkResource) ToAPIResponse(groups []api.GroupMinimum) *api.NetworkResource { + addr := n.Prefix.String() + if n.Type == domain { + addr = n.Domain + } + + return &api.NetworkResource{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Type: api.NetworkResourceType(n.Type.String()), + Address: addr, + Groups: groups, + } +} + +func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) { + n.Name = req.Name + + if req.Description != nil { + n.Description = *req.Description + } + n.Address = req.Address + n.GroupIDs = req.Groups +} + +func (n *NetworkResource) Copy() *NetworkResource { + return &NetworkResource{ + ID: n.ID, + AccountID: n.AccountID, + NetworkID: n.NetworkID, + Name: n.Name, + Description: n.Description, + Type: n.Type, + Address: n.Address, + Domain: n.Domain, + Prefix: n.Prefix, + GroupIDs: n.GroupIDs, + } +} + +func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(fmt.Sprintf("%s:%s", n.ID, peer.ID)), + AccountID: n.AccountID, + KeepRoute: true, + NetID: route.NetID(n.Name), + Description: n.Description, + Peer: peer.Key, + PeerGroups: nil, + Masquerade: router.Masquerade, + Metric: router.Metric, + Enabled: true, + Groups: nil, + AccessControlGroups: nil, + } + + if n.Type == host || n.Type == subnet { + r.Network = n.Prefix + + r.NetworkType = route.IPv4Network + if n.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if n.Type == domain { + domainList, err := nbDomain.FromStringList([]string{n.Domain}) + if err != nil { + return nil + } + r.Domains = domainList + r.NetworkType = route.DomainNetwork + + // add default placeholder for domain network + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + + return r +} + +func (n *NetworkResource) EventMeta(network *networkTypes.Network) map[string]any { + return map[string]any{"name": n.Name, "type": n.Type, "network_name": network.Name, "network_id": network.ID} +} + +// GetResourceType returns the type of the resource based on the address +func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) { + if prefix, err := netip.ParsePrefix(address); err == nil { + if prefix.Bits() == 32 || prefix.Bits() == 128 { + return host, "", prefix, nil + } + return subnet, "", prefix, nil + } + + if ip, err := netip.ParseAddr(address); err == nil { + return host, "", netip.PrefixFrom(ip, ip.BitLen()), nil + } + + domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) + if domainRegex.MatchString(address) { + return domain, address, netip.Prefix{}, nil + } + + return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain") +} diff --git a/management/server/networks/resources/types/resource_test.go b/management/server/networks/resources/types/resource_test.go new file mode 100644 index 00000000000..6af384cce2f --- /dev/null +++ b/management/server/networks/resources/types/resource_test.go @@ -0,0 +1,53 @@ +package types + +import ( + "net/netip" + "testing" +) + +func TestGetResourceType(t *testing.T) { + tests := []struct { + input string + expectedType NetworkResourceType + expectedErr bool + expectedDomain string + expectedPrefix netip.Prefix + }{ + // Valid host IPs + {"1.1.1.1", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, + {"1.1.1.1/32", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, + // Valid subnets + {"192.168.1.0/24", subnet, false, "", netip.MustParsePrefix("192.168.1.0/24")}, + {"10.0.0.0/16", subnet, false, "", netip.MustParsePrefix("10.0.0.0/16")}, + // Valid domains + {"example.com", domain, false, "example.com", netip.Prefix{}}, + {"*.example.com", domain, false, "*.example.com", netip.Prefix{}}, + {"sub.example.com", domain, false, "sub.example.com", netip.Prefix{}}, + // Invalid inputs + {"invalid", "", true, "", netip.Prefix{}}, + {"1.1.1.1/abc", "", true, "", netip.Prefix{}}, + {"1234", "", true, "", netip.Prefix{}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, domain, prefix, err := GetResourceType(tt.input) + + if result != tt.expectedType { + t.Errorf("Expected type %v, got %v", tt.expectedType, result) + } + + if tt.expectedErr && err == nil { + t.Errorf("Expected error, got nil") + } + + if prefix != tt.expectedPrefix { + t.Errorf("Expected address %v, got %v", tt.expectedPrefix, prefix) + } + + if domain != tt.expectedDomain { + t.Errorf("Expected domain %v, got %v", tt.expectedDomain, domain) + } + }) + } +} diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go new file mode 100644 index 00000000000..3b32810a2c6 --- /dev/null +++ b/management/server/networks/routers/manager.go @@ -0,0 +1,289 @@ +package routers + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/xid" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) + GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) + CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) + UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error + DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network routers: %w", err) + } + + routersMap := make(map[string][]*types.NetworkRouter) + for _, router := range routers { + routersMap[router.NetworkID] = append(routersMap[router.NetworkID], router) + } + + return routersMap, nil +} + +func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewNetworkNotFoundError(router.NetworkID) + } + + router.ID = xid.New().String() + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to create network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil +} + +func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, errors.New("router not part of network") + } + + return router, nil +} + +func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) + } + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to update network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil +} + +func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var event func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID) + if err != nil { + return fmt.Errorf("failed to delete network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + event() + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) + } + + err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to delete network router: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network)) + } + + return event, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + return []*types.NetworkRouter{}, nil +} + +func (m *mockManager) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + return map[string][]*types.NetworkRouter{}, nil +} + +func (m *mockManager) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + return &types.NetworkRouter{}, nil +} + +func (m *mockManager) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + return nil +} + +func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { + return func() {}, nil +} diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go new file mode 100644 index 00000000000..e650074cc17 --- /dev/null +++ b/management/server/networks/routers/manager_test.go @@ -0,0 +1,234 @@ +package routers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Len(t, routers, 1) + require.Equal(t, "testRouterId", routers[0].ID) +} + +func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, routers) +} + +func Test_GetRouterReturnsRouter(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) + require.Equal(t, "testRouterId", router.ID) +} + +func Test_GetRouterReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, router) +} + +func Test_CreateRouterSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.NoError(t, err) + require.NotEqual(t, "", router.ID) + require.Equal(t, router.NetworkID, createdRouter.NetworkID) + require.Equal(t, router.Peer, createdRouter.Peer) + require.Equal(t, router.Metric, createdRouter.Metric) + require.Equal(t, router.Masquerade, createdRouter.Masquerade) +} + +func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, createdRouter) +} + +func Test_DeleteRouterSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + routerID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) + require.NoError(t, err) +} + +func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + routerID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) +} + +func Test_UpdateRouterSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.NoError(t, err) + require.Equal(t, router.Metric, updatedRouter.Metric) +} + +func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, updatedRouter) +} diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go new file mode 100644 index 00000000000..f37ae0861a0 --- /dev/null +++ b/management/server/networks/routers/types/router.go @@ -0,0 +1,75 @@ +package types + +import ( + "errors" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/networks/types" +) + +type NetworkRouter struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Peer string + PeerGroups []string `gorm:"serializer:json"` + Masquerade bool + Metric int +} + +func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int) (*NetworkRouter, error) { + if peer != "" && len(peerGroups) > 0 { + return nil, errors.New("peer and peerGroups cannot be set at the same time") + } + + return &NetworkRouter{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Peer: peer, + PeerGroups: peerGroups, + Masquerade: masquerade, + Metric: metric, + }, nil +} + +func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { + return &api.NetworkRouter{ + Id: n.ID, + Peer: &n.Peer, + PeerGroups: &n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + } +} + +func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) { + if req.Peer != nil { + n.Peer = *req.Peer + } + + if req.PeerGroups != nil { + n.PeerGroups = *req.PeerGroups + } + + n.Masquerade = req.Masquerade + n.Metric = req.Metric +} + +func (n *NetworkRouter) Copy() *NetworkRouter { + return &NetworkRouter{ + ID: n.ID, + NetworkID: n.NetworkID, + AccountID: n.AccountID, + Peer: n.Peer, + PeerGroups: n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + } +} + +func (n *NetworkRouter) EventMeta(network *types.Network) map[string]any { + return map[string]any{"network_name": network.Name, "network_id": network.ID, "peer": n.Peer, "peer_groups": n.PeerGroups} +} diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go new file mode 100644 index 00000000000..3335f7c895b --- /dev/null +++ b/management/server/networks/routers/types/router_test.go @@ -0,0 +1,100 @@ +package types + +import "testing" + +func TestNewNetworkRouter(t *testing.T) { + tests := []struct { + name string + accountID string + networkID string + peer string + peerGroups []string + masquerade bool + metric int + expectedError bool + }{ + // Valid cases + { + name: "Valid with peer only", + networkID: "network-1", + accountID: "account-1", + peer: "peer-1", + peerGroups: nil, + masquerade: true, + metric: 100, + expectedError: false, + }, + { + name: "Valid with peerGroups only", + networkID: "network-2", + accountID: "account-2", + peer: "", + peerGroups: []string{"group-1", "group-2"}, + masquerade: false, + metric: 200, + expectedError: false, + }, + { + name: "Valid with no peer or peerGroups", + networkID: "network-3", + accountID: "account-3", + peer: "", + peerGroups: nil, + masquerade: true, + metric: 300, + expectedError: false, + }, + + // Invalid cases + { + name: "Invalid with both peer and peerGroups", + networkID: "network-4", + accountID: "account-4", + peer: "peer-2", + peerGroups: []string{"group-3"}, + masquerade: false, + metric: 400, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, err := NewNetworkRouter(tt.accountID, tt.networkID, tt.peer, tt.peerGroups, tt.masquerade, tt.metric) + + if tt.expectedError && err == nil { + t.Fatalf("Expected an error, got nil") + } + + if tt.expectedError == false { + if router == nil { + t.Fatalf("Expected a NetworkRouter object, got nil") + } + + if router.AccountID != tt.accountID { + t.Errorf("Expected AccountID %s, got %s", tt.accountID, router.AccountID) + } + + if router.NetworkID != tt.networkID { + t.Errorf("Expected NetworkID %s, got %s", tt.networkID, router.NetworkID) + } + + if router.Peer != tt.peer { + t.Errorf("Expected Peer %s, got %s", tt.peer, router.Peer) + } + + if len(router.PeerGroups) != len(tt.peerGroups) { + t.Errorf("Expected PeerGroups %v, got %v", tt.peerGroups, router.PeerGroups) + } + + if router.Masquerade != tt.masquerade { + t.Errorf("Expected Masquerade %v, got %v", tt.masquerade, router.Masquerade) + } + + if router.Metric != tt.metric { + t.Errorf("Expected Metric %d, got %d", tt.metric, router.Metric) + } + } + }) + } +} diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go new file mode 100644 index 00000000000..a4ba7b821ff --- /dev/null +++ b/management/server/networks/types/network.go @@ -0,0 +1,56 @@ +package types + +import ( + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Network struct { + ID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string +} + +func NewNetwork(accountId, name, description string) *Network { + return &Network{ + ID: xid.New().String(), + AccountID: accountId, + Name: name, + Description: description, + } +} + +func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int, policyIDs []string) *api.Network { + return &api.Network{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Routers: routerIDs, + Resources: resourceIDs, + RoutingPeersCount: routingPeersCount, + Policies: policyIDs, + } +} + +func (n *Network) FromAPIRequest(req *api.NetworkRequest) { + n.Name = req.Name + if req.Description != nil { + n.Description = *req.Description + } +} + +// Copy returns a copy of a posture checks. +func (n *Network) Copy() *Network { + return &Network{ + ID: n.ID, + AccountID: n.AccountID, + Name: n.Name, + Description: n.Description, + } +} + +func (n *Network) EventMeta() map[string]any { + return map[string]any{"name": n.Name} +} diff --git a/management/server/peer.go b/management/server/peer.go index ba211be9694..ad20d279a6a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -16,6 +16,8 @@ import ( "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" @@ -92,7 +94,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -107,7 +109,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *types.Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return fmt.Errorf("failed to find peer by pub key: %w", err) @@ -133,13 +135,13 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *types.Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -213,9 +215,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peerLabelUpdated { peer.Name = update.Name - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() - newLabel, err := getPeerHostLabel(peer.Name, existingLabels) + newLabel, err := types.GetPeerHostLabel(peer.Name, existingLabels) if err != nil { return nil, err } @@ -271,14 +273,14 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return peer, nil } // deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { +func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *types.Account, peerIDs []string, userID string) error { // the first loop is needed to ensure all peers present under the account before modifying, otherwise // we might have some inconsistencies @@ -316,7 +318,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, - NetworkMap: &NetworkMap{}, + NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -351,14 +353,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) -func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) { +func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -379,11 +381,11 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil + return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil), nil } // GetPeerNetwork returns the Network for a given peer -func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) { +func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -399,7 +401,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -433,7 +435,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } @@ -446,12 +448,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var newPeer *nbpeer.Peer var groupsToAdd []string - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var setupKeyID string var setupKeyName string var ephemeral bool if addedByUser { - user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) if err != nil { return fmt.Errorf("failed to get user groups: %w", err) } @@ -460,7 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -533,7 +535,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("failed to get account settings: %w", err) } @@ -558,7 +560,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -609,7 +611,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } if newGroupsAffectsPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -623,22 +625,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) + networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil } -func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { - takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, accountID string) (net.IP, error) { + takenIps, err := s.GetTakenIPs(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed to get taken IPs: %w", err) } - network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + network, err := s.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed getting network: %w", err) } - nextIp, err := AllocatePeerIP(network.Net, takenIps) + nextIp, err := types.AllocatePeerIP(network.Net, takenIps) if err != nil { return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) } @@ -647,7 +649,7 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() @@ -691,11 +693,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } if peerNotValid { - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: account.Network.Copy(), } return peer, emptyMap, []*posture.Checks{}, nil @@ -707,10 +709,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -730,7 +732,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. -func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -755,12 +757,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -785,7 +787,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -837,7 +839,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged || (updated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -849,7 +851,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -860,7 +862,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -872,11 +874,11 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *types.Account, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { var postureChecks []*posture.Checks if isRequiresApproval { - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: account.Network.Copy(), } return peer, emptyMap, nil, nil @@ -893,10 +895,10 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { return err @@ -918,7 +920,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us return nil } -func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error { +func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *types.User) error { if peer.AddedWithSSOLogin() { if user.IsBlocked() { return status.Errorf(status.PermissionDenied, "user is blocked") @@ -939,7 +941,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error return nil } -func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool { +func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) bool { expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { @@ -991,7 +993,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, } for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -1002,9 +1004,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) } -// updateAccountPeers updates all peers that belong to an account. +// UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) @@ -1031,6 +1033,8 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account dnsCache := &DNSConfigCache{} customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() for _, peer := range peers { if !am.peersUpdateManager.HasChannel(peer.ID) { @@ -1050,8 +1054,8 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1069,7 +1073,7 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *types.Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b15315f9870..2ab262ff086 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -19,15 +19,20 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" nbroute "github.com/netbirdio/netbird/route" ) @@ -37,13 +42,13 @@ func TestPeer_LoginExpired(t *testing.T) { expirationEnabled bool lastLogin time.Time expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Login Expiration Disabled. Peer Login Should Not Expire", expirationEnabled: false, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -53,7 +58,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Expire", expirationEnabled: true, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -63,7 +68,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Not Expire", expirationEnabled: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -92,14 +97,14 @@ func TestPeer_SessionExpired(t *testing.T) { lastLogin time.Time connected bool expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", expirationEnabled: false, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Hour, }, @@ -110,7 +115,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -121,7 +126,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -161,7 +166,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -233,9 +238,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) } - var setupKey *SetupKey + var setupKey *types.SetupKey for _, key := range account.SetupKeys { - if key.Type == SetupKeyReusable { + if key.Type == types.SetupKeyReusable { setupKey = key } } @@ -281,8 +286,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 nbgroup.Group - group2 nbgroup.Group + group1 types.Group + group2 types.Group ) group1.ID = xid.New().String() @@ -303,16 +308,16 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy := &Policy{ + policy := &types.Policy{ Name: "test", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{group1.ID}, Destinations: []string{group2.ID}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -410,7 +415,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -469,9 +474,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } account.Settings.RegularUsersViewBlocked = false @@ -482,7 +487,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -567,77 +572,77 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { func TestDefaultAccountManager_GetPeers(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool isServiceUser bool expectedPeerCount int }{ { name: "Regular user, no limited view settings, not a service user", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 1, }, { name: "Service user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 0, }, { name: "Service user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, no limited view settings, not a service user", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin service user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin Service user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, @@ -656,12 +661,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, Role: testCase.role, IsServiceUser: testCase.isServiceUser, } - account.Policies = []*Policy{} + account.Policies = []*types.Policy{} account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings err = manager.Store.SaveAccount(context.Background(), account) @@ -726,9 +731,9 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou regularUser := "regular_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ + account.Users[regularUser] = &types.User{ Id: regularUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } // Create peers @@ -746,10 +751,10 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) + account.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) - group := &nbgroup.Group{ + group := &types.Group{ ID: groupID, Name: fmt.Sprintf("Group %d", i), } @@ -757,14 +762,95 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } + + // Create network, router and resource for this group + network := &networkTypes.Network{ + ID: fmt.Sprintf("network-%d", i), + AccountID: account.Id, + Name: fmt.Sprintf("Network for Group %d", i), + } + account.Networks = append(account.Networks, network) + + ips := account.GetTakenIPs() + peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) + if err != nil { + return nil, "", "", err + } + + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + DNSLabel: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + Key: peerKey.PublicKey().String(), + IP: peerIP, + Status: &nbpeer.PeerStatus{}, + UserID: regularUser, + Meta: nbpeer.PeerSystemMeta{ + Hostname: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + } + account.Peers[peer.ID] = peer + + group.Peers = append(group.Peers, peer.ID) account.Groups[groupID] = group + router := &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("network-router-%d", i), + NetworkID: network.ID, + AccountID: account.Id, + Peer: peer.ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 9999, + } + account.NetworkRouters = append(account.NetworkRouters, router) + + resource := &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("network-resource-%d", i), + NetworkID: network.ID, + AccountID: account.Id, + Name: fmt.Sprintf("Network resource for Group %d", i), + Type: "host", + Address: "192.0.2.0/32", + } + account.NetworkResources = append(account.NetworkResources, resource) + + // Create a policy for this network resource + nrPolicy := &types.Policy{ + ID: fmt.Sprintf("policy-nr-%d", i), + Name: fmt.Sprintf("Policy for network resource %d", i), + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: fmt.Sprintf("rule-nr-%d", i), + Name: fmt.Sprintf("Rule for network resource %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{}, + DestinationResource: types.Resource{ + ID: resource.ID, + }, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, nrPolicy) + // Create a policy for this group - policy := &Policy{ + policy := &types.Policy{ ID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Policy for Group %d", i), Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), @@ -772,8 +858,8 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Sources: []string{groupID}, Destinations: []string{groupID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -845,11 +931,12 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 90, 120, 90, 120}, - {"Medium", 500, 100, 110, 140, 120, 200}, - {"Large", 5000, 200, 800, 1300, 2500, 3600}, + {"Medium", 500, 100, 110, 150, 120, 260}, + {"Large", 5000, 200, 800, 1390, 2500, 4600}, {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, - {"Large 5", 5000, 15, 1300, 1800, 5000, 6000}, + {"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, + {"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, } log.SetOutput(io.Discard) @@ -881,7 +968,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id) } duration := time.Since(start) @@ -899,7 +986,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) @@ -939,8 +1026,8 @@ func TestToSyncResponse(t *testing.T) { Payload: "turn-user", Signature: "turn-pass", } - networkMap := &NetworkMap{ - Network: &Network{Net: *ipnet, Serial: 1000}, + networkMap := &types.NetworkMap{ + Network: &types.Network{Net: *ipnet, Serial: 1000}, Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, Routes: []*nbroute.Route{ @@ -987,8 +1074,8 @@ func TestToSyncResponse(t *testing.T) { }, CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + FirewallRules: []*types.FirewallRule{ + {PeerIP: "192.168.1.2", Direction: types.FirewallRuleDirectionIN, Action: string(types.PolicyTrafficActionAccept), Protocol: string(types.PolicyRuleProtocolTCP), Port: "80"}, }, } dnsName := "example.com" @@ -1003,7 +1090,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, true) assert.NotNil(t, response) // assert peer config @@ -1088,7 +1175,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1099,13 +1186,13 @@ func Test_RegisterPeerByUser(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1128,12 +1215,12 @@ func Test_RegisterPeerByUser(t *testing.T) { addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.UserID, existingUserID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) @@ -1152,7 +1239,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1163,13 +1250,13 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1192,11 +1279,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) @@ -1219,7 +1306,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1230,13 +1317,13 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1258,10 +1345,10 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) require.Error(t, err) - _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.Error(t, err) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.NotContains(t, account.Peers, newPeer.ID) assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) @@ -1284,7 +1371,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1304,26 +1391,26 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) // create a user with auto groups - _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{ { Id: "regularUser1", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupA"}, }, { Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupB"}, }, { Id: "regularUser3", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupC"}, }, }, true) @@ -1464,15 +1551,15 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go new file mode 100644 index 00000000000..320aad027d5 --- /dev/null +++ b/management/server/permissions/manager.go @@ -0,0 +1,102 @@ +package permissions + +import ( + "context" + "errors" + "fmt" + + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" +) + +type Module string + +const ( + Networks Module = "networks" + Peers Module = "peers" + Groups Module = "groups" +) + +type Operation string + +const ( + Read Operation = "read" + Write Operation = "write" +) + +type Manager interface { + ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) +} + +type managerImpl struct { + userManager users.Manager + settingsManager settings.Manager +} + +type managerMock struct { +} + +func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager { + return &managerImpl{ + userManager: userManager, + settingsManager: settingsManager, + } +} + +func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + user, err := m.userManager.GetUser(ctx, userID) + if err != nil { + return false, err + } + + if user == nil { + return false, errors.New("user not found") + } + + if user.AccountID != accountID { + return false, errors.New("user does not belong to account") + } + + switch user.Role { + case types.UserRoleAdmin, types.UserRoleOwner: + return true, nil + case types.UserRoleUser: + return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation) + case types.UserRoleBillingAdmin: + return false, nil + default: + return false, errors.New("invalid role") + } +} + +func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + settings, err := m.settingsManager.GetSettings(ctx, accountID, userID) + if err != nil { + return false, fmt.Errorf("failed to get settings: %w", err) + } + if settings.RegularUsersViewBlocked { + return false, nil + } + + if operation == Write { + return false, nil + } + + if module == Peers { + return true, nil + } + + return false, nil +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + if userID == "allowedUser" { + return true, nil + } + return false, nil +} diff --git a/management/server/policy.go b/management/server/policy.go index 2d3abc3f1e2..45b3e93e697 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,344 +3,21 @@ package server import ( "context" _ "embed" - "strconv" - "strings" - "github.com/netbirdio/netbird/management/proto" "github.com/rs/xid" - log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PolicyUpdateOperationType operation type -type PolicyUpdateOperationType int - -// PolicyTrafficActionType action type for the firewall -type PolicyTrafficActionType string - -// PolicyRuleProtocolType type of traffic -type PolicyRuleProtocolType string - -// PolicyRuleDirection direction of traffic -type PolicyRuleDirection string - -const ( - // PolicyTrafficActionAccept indicates that the traffic is accepted - PolicyTrafficActionAccept = PolicyTrafficActionType("accept") - // PolicyTrafficActionDrop indicates that the traffic is dropped - PolicyTrafficActionDrop = PolicyTrafficActionType("drop") -) - -const ( - // PolicyRuleProtocolALL type of traffic - PolicyRuleProtocolALL = PolicyRuleProtocolType("all") - // PolicyRuleProtocolTCP type of traffic - PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") - // PolicyRuleProtocolUDP type of traffic - PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") - // PolicyRuleProtocolICMP type of traffic - PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") -) - -const ( - // PolicyRuleFlowDirect allows traffic from source to destination - PolicyRuleFlowDirect = PolicyRuleDirection("direct") - // PolicyRuleFlowBidirect allows traffic to both directions - PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") -) - -const ( - // DefaultRuleName is a name for the Default rule that is created for every account - DefaultRuleName = "Default" - // DefaultRuleDescription is a description for the Default rule that is created for every account - DefaultRuleDescription = "This is a default rule that allows connections between all the resources" - // DefaultPolicyName is a name for the Default policy that is created for every account - DefaultPolicyName = "Default" - // DefaultPolicyDescription is a description for the Default policy that is created for every account - DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" -) - -const ( - firewallRuleDirectionIN = 0 - firewallRuleDirectionOUT = 1 -) - -// PolicyUpdateOperation operation object with type and values to be applied -type PolicyUpdateOperation struct { - Type PolicyUpdateOperationType - Values []string -} - -// RulePortRange represents a range of ports for a firewall rule. -type RulePortRange struct { - Start uint16 - End uint16 -} - -// PolicyRule is the metadata of the policy -type PolicyRule struct { - // ID of the policy rule - ID string `gorm:"primaryKey"` - - // PolicyID is a reference to Policy that this object belongs - PolicyID string `json:"-" gorm:"index"` - - // Name of the rule visible in the UI - Name string - - // Description of the rule visible in the UI - Description string - - // Enabled status of rule in the system - Enabled bool - - // Action policy accept or drops packets - Action PolicyTrafficActionType - - // Destinations policy destination groups - Destinations []string `gorm:"serializer:json"` - - // Sources policy source groups - Sources []string `gorm:"serializer:json"` - - // Bidirectional define if the rule is applicable in both directions, sources, and destinations - Bidirectional bool - - // Protocol type of the traffic - Protocol PolicyRuleProtocolType - - // Ports or it ranges list - Ports []string `gorm:"serializer:json"` - - // PortRanges a list of port ranges. - PortRanges []RulePortRange `gorm:"serializer:json"` -} - -// Copy returns a copy of a policy rule -func (pm *PolicyRule) Copy() *PolicyRule { - rule := &PolicyRule{ - ID: pm.ID, - PolicyID: pm.PolicyID, - Name: pm.Name, - Description: pm.Description, - Enabled: pm.Enabled, - Action: pm.Action, - Destinations: make([]string, len(pm.Destinations)), - Sources: make([]string, len(pm.Sources)), - Bidirectional: pm.Bidirectional, - Protocol: pm.Protocol, - Ports: make([]string, len(pm.Ports)), - PortRanges: make([]RulePortRange, len(pm.PortRanges)), - } - copy(rule.Destinations, pm.Destinations) - copy(rule.Sources, pm.Sources) - copy(rule.Ports, pm.Ports) - copy(rule.PortRanges, pm.PortRanges) - return rule -} - -// Policy of the Rego query -type Policy struct { - // ID of the policy' - ID string `gorm:"primaryKey"` - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name of the Policy - Name string - - // Description of the policy visible in the UI - Description string - - // Enabled status of the policy - Enabled bool - - // Rules of the policy - Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` - - // SourcePostureChecks are ID references to Posture checks for policy source groups - SourcePostureChecks []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the policy. -func (p *Policy) Copy() *Policy { - c := &Policy{ - ID: p.ID, - AccountID: p.AccountID, - Name: p.Name, - Description: p.Description, - Enabled: p.Enabled, - Rules: make([]*PolicyRule, len(p.Rules)), - SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), - } - for i, r := range p.Rules { - c.Rules[i] = r.Copy() - } - copy(c.SourcePostureChecks, p.SourcePostureChecks) - return c -} - -// EventMeta returns activity event meta related to this policy -func (p *Policy) EventMeta() map[string]any { - return map[string]any{"name": p.Name} -} - -// UpgradeAndFix different version of policies to latest version -func (p *Policy) UpgradeAndFix() { - for _, r := range p.Rules { - // start migrate from version v0.20.3 - if r.Protocol == "" { - r.Protocol = PolicyRuleProtocolALL - } - if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { - r.Bidirectional = true - } - // -- v0.20.4 - } -} - -// ruleGroups returns a list of all groups referenced in the policy's rules, -// including sources and destinations. -func (p *Policy) ruleGroups() []string { - groups := make([]string, 0) - for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) - groups = append(groups, rule.Destinations...) - } - - return groups -} - -// FirewallRule is a rule of the firewall. -type FirewallRule struct { - // PeerIP of the peer - PeerIP string - - // Direction of the traffic - Direction int - - // Action of the traffic - Action string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port string -} - -// getPeerConnectionResources for a given peer -// -// This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) - for _, policy := range a.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) - - if rule.Bidirectional { - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionIN) - } - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionOUT) - } - } - - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionOUT) - } - - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionIN) - } - } - } - - return getAccumulatedResources() -} - -// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls -// -// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. -// It safe to call the generator function multiple times for same peer and different rules no duplicates will be -// generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - rules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &nbgroup.Group{} - } - - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) - for _, peer := range groupPeers { - if peer == nil { - continue - } - - if _, ok := peersExists[peer.ID]; !ok { - peers = append(peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: string(rule.Protocol), - } - - if isAll { - fr.PeerIP = "0.0.0.0" - } - - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if len(rule.Ports) == 0 { - rules = append(rules, &fr) - continue - } - - for _, port := range rule.Ports { - pr := fr // clone rule and add set new port - pr.Port = port - rules = append(rules, &pr) - } - } - }, func() ([]*nbpeer.Peer, []*FirewallRule) { - return peers, rules - } -} - // GetPolicy from the store -func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -353,15 +30,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, status.NewAdminPermissionError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -378,7 +55,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var updateAccountPeers bool var action = activity.PolicyAdded - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { return err } @@ -388,7 +65,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -398,7 +75,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user saveFunc = transaction.SavePolicy } - return saveFunc(ctx, LockingStrengthUpdate, policy) + return saveFunc(ctx, store.LockingStrengthUpdate, policy) }) if err != nil { return nil, err @@ -407,7 +84,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return policy, nil @@ -418,7 +95,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -431,11 +108,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return status.NewAdminPermissionError() } - var policy *Policy + var policy *types.Policy var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -445,11 +122,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID) }) if err != nil { return err @@ -458,15 +135,15 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // ListPolicies from the store. -func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -479,13 +156,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) } // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return false, err } @@ -494,7 +171,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err } @@ -504,13 +181,13 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account } } - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } // validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return err } @@ -519,12 +196,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po policy.AccountID = accountID } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) if err != nil { return err } @@ -548,84 +225,6 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po return nil } -// getAllPeersFromGroups for given peer ID and list of groups -// -// Returns a list of peers from specified groups that pass specified posture checks -// and a boolean indicating if the supplied peer ID exists within these groups. -// -// Important: Posture checks are applicable only to source group peers, -// for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { - peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { - continue - } - - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) - } - } - return filteredPeers, peerInGroups -} - -// validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { - peer, ok := a.Peers[peerID] - if !ok && peer == nil { - return false - } - - for _, postureChecksID := range sourcePostureChecksID { - postureChecks := a.getPostureChecks(postureChecksID) - if postureChecks == nil { - continue - } - - for _, check := range postureChecks.GetChecks() { - isValid, err := check.Check(ctx, *peer) - if err != nil { - log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) - } - if !isValid { - return false - } - } - } - return true -} - -func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { - for _, postureChecks := range a.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks - } - } - return nil -} - // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds)) @@ -639,7 +238,7 @@ func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureCh } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { +func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { if _, exists := groups[id]; exists { @@ -651,7 +250,7 @@ func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []str } // toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { result := make([]*proto.FirewallRule, len(rules)) for i := range rules { rule := rules[i] diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 62d80f46e7f..fab738abe53 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -10,13 +10,13 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" ) func TestAccount_getPeersByPolicy(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -59,7 +59,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -87,21 +87,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -116,15 +116,15 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", "GroupAll", @@ -145,14 +145,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -160,45 +160,45 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -206,14 +206,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -221,14 +221,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -236,14 +236,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -251,14 +251,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -266,14 +266,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -289,7 +289,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { } func TestAccount_getPeersByPolicyDirect(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -307,7 +307,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -332,21 +332,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: false, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: false, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -361,15 +361,15 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", }, @@ -388,20 +388,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -416,20 +416,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -446,13 +446,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -467,13 +467,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -489,7 +489,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -582,7 +582,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -630,17 +630,17 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, } - account.Policies = append(account.Policies, &Policy{ + account.Policies = append(account.Policies, &types.Policy{ ID: "PolicyPostureChecks", Name: "", Description: "This is the policy with posture checks applied", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Enabled: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, Destinations: []string{ "GroupSwarm", }, @@ -648,7 +648,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { "GroupAll", }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"}, }, }, @@ -664,7 +664,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -674,13 +674,13 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", @@ -690,7 +690,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -700,7 +700,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -715,19 +715,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -742,14 +742,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) @@ -760,45 +760,45 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // assert peers from Group All assert.Contains(t, peers, account.Peers["peerC"]) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", @@ -809,8 +809,8 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }) } -func sortFunc() func(a *FirewallRule, b *FirewallRule) int { - return func(a, b *FirewallRule) int { +func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { + return func(a, b *types.FirewallRule) int { // Concatenate PeerIP and Direction as string for comparison aStr := a.PeerIP + fmt.Sprintf("%d", a.Direction) bStr := b.PeerIP + fmt.Sprintf("%d", b.Direction) @@ -829,7 +829,7 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -858,9 +858,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var policyWithGroupRulesNoPeers *Policy - var policyWithDestinationPeersOnly *Policy - var policyWithSourceAndDestinationPeers *Policy + var policyWithGroupRulesNoPeers *types.Policy + var policyWithDestinationPeersOnly *types.Policy + var policyWithSourceAndDestinationPeers *types.Policy // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { @@ -870,16 +870,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -901,17 +901,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -933,17 +933,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -965,16 +965,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 0467efedb84..1690f8e339a 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -12,10 +12,12 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -28,7 +30,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, status.NewAdminPermissionError() } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) } // SavePostureChecks saves a posture check. @@ -36,7 +38,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -53,7 +55,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI var isUpdate = postureChecks.ID != "" var action = activity.PostureCheckCreated - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { return err } @@ -64,7 +66,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -72,7 +74,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } postureChecks.AccountID = accountID - return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) }) if err != nil { return nil, err @@ -81,7 +83,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return postureChecks, nil @@ -92,7 +94,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -107,8 +109,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun var postureChecks *posture.Checks - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) if err != nil { return err } @@ -117,11 +119,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) }) if err != nil { return err @@ -134,7 +136,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun // ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -147,11 +149,11 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) } // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) { +func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { peerPostureChecks := make(map[string]*posture.Checks) if len(account.PostureChecks) == 0 { @@ -172,15 +174,15 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID s } // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } for _, policy := range policies { if slices.Contains(policy.SourcePostureChecks, postureCheckID) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups()) if err != nil { return false, err } @@ -195,21 +197,21 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, a } // validatePostureChecks validates the posture checks. -func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { +func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { return status.Errorf(status.InvalidArgument, err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. if postureChecks.ID != "" { - if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { return err } return nil } // For new posture checks, ensure no duplicates by name. - checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -226,7 +228,7 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str } // addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) if err != nil { return err @@ -237,7 +239,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee } for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.getPostureChecks(sourcePostureCheckID) + postureCheck := account.GetPostureChecks(sourcePostureCheckID) if postureCheck == nil { return errors.New("failed to add policy posture checks: posture checks not found") } @@ -248,7 +250,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) { +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue @@ -270,8 +272,8 @@ func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) } // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. -func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 93e5741cf28..bad162f050c 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/posture" ) @@ -92,17 +93,17 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ + admin := &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, } - user := &User{ + user := &types.User{ Id: regularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) @@ -120,7 +121,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -209,15 +210,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -312,15 +313,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -356,15 +357,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -395,15 +396,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -443,18 +444,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { account, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") - groupA := &group.Group{ + groupA := &types.Group{ ID: "groupA", AccountID: account.Id, Peers: []string{"peer1"}, } - groupB := &group.Group{ + groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) require.NoError(t, err, "failed to save groups") postureCheckA := &posture.Checks{ @@ -477,9 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) require.NoError(t, err, "failed to save postureCheckB") - policy := &Policy{ + policy := &types.Policy{ AccountID: account.Id, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, @@ -534,7 +535,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { groupA.Peers = []string{} - err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA) require.NoError(t, err, "failed to save groups") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/resource.go b/management/server/resource.go new file mode 100644 index 00000000000..77a5612b3d3 --- /dev/null +++ b/management/server/resource.go @@ -0,0 +1,21 @@ +package server + +type ResourceType string + +const ( + // nolint + hostType ResourceType = "Host" + //nolint + subnetType ResourceType = "Subnet" + // nolint + domainType ResourceType = "Domain" +) + +func (p ResourceType) String() string { + return string(p) +} + +type Resource struct { + Type ResourceType + ID string +} diff --git a/management/server/route.go b/management/server/route.go index 23bea87e3b8..1eb51aea751 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,15 +4,12 @@ import ( "context" "fmt" "net/netip" - "slices" - "strconv" - "strings" "unicode/utf8" "github.com/rs/xid" - log "github.com/sirupsen/logrus" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -21,33 +18,9 @@ import ( "github.com/netbirdio/netbird/route" ) -// RouteFirewallRule a firewall rule applicable for a routed network. -type RouteFirewallRule struct { - // SourceRanges IP ranges of the routing peers. - SourceRanges []string - - // Action of the traffic when the rule is applicable - Action string - - // Destination a network prefix for the routed traffic - Destination string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port uint16 - - // PortRange represents the range of ports for a firewall rule - PortRange RulePortRange - - // isDynamic indicates whether the rule is for DNS routing - IsDynamic bool -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -56,11 +29,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) @@ -238,7 +211,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if am.isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +297,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +329,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if am.isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -364,7 +337,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -373,7 +346,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { @@ -404,244 +377,7 @@ func getPlaceholderIP() netip.Prefix { return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := a.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) - rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - -func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { - var fwRules []*RouteFirewallRule - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) - fwRules = append(fwRules, rules...) - } - } - return fwRules -} - -func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) - for _, id := range rule.Sources { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - if pID == peerID { - continue - } - _, distPeer := distributionPeers[pID] - _, valid := validatedPeersMap[pID] - if distPeer && valid { - distPeersWithPolicy[pID] = struct{}{} - } - } - } - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { - peer := a.Peers[pID] - if peer == nil { - continue - } - distributionGroupPeers = append(distributionGroupPeers, peer) - } - return distributionGroupPeers -} - -func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - IsDynamic: route.IsDynamic(), - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - -// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups -// and returns a list of policies that have rules with destinations matching the specified groups. -func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { - routePolicies := make([]*Policy, 0) - for _, groupID := range accessControlGroups { - group, ok := account.Groups[groupID] - if !ok { - continue - } - - for _, policy := range account.Policies { - for _, rule := range policy.Rules { - exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { - return groupID == group.ID - }) - if exist { - routePolicies = append(routePolicies, policy) - continue - } - } - } - } - - return routePolicies -} - -// generateRouteFirewallRules generates a list of firewall rules for a given route. -func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { - rulesExists := make(map[string]struct{}) - rules := make([]*RouteFirewallRule, 0) - - sourceRanges := make([]string, 0, len(groupPeers)) - for _, peer := range groupPeers { - if peer == nil { - continue - } - sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) - } - - baseRule := RouteFirewallRule{ - SourceRanges: sourceRanges, - Action: string(rule.Action), - Destination: route.Network.String(), - Protocol: string(rule.Protocol), - IsDynamic: route.IsDynamic(), - } - - // generate rule for port range - if len(rule.Ports) == 0 { - rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) - } else { - rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) - - } - - // TODO: generate IPv6 rules for dynamic routes - - return rules -} - -// generateRuleIDBase generates the base rule ID for checking duplicates. -func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { - return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action -} - -// generateRulesForPeer generates rules for a given peer based on ports and port ranges. -func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - - ruleIDBase := generateRuleIDBase(rule, baseRule) - if len(rule.Ports) == 0 { - if len(rule.PortRanges) == 0 { - if _, ok := rulesExists[ruleIDBase]; !ok { - rulesExists[ruleIDBase] = struct{}{} - rules = append(rules, &baseRule) - } - } else { - for _, portRange := range rule.PortRanges { - ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) - if _, ok := rulesExists[ruleID]; !ok { - rulesExists[ruleID] = struct{}{} - pr := baseRule - pr.PortRange = portRange - rules = append(rules, &pr) - } - } - } - return rules - } - - return rules -} - -// generateRulesWithPorts generates rules when specific ports are provided. -func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - ruleIDBase := generateRuleIDBase(rule, baseRule) - - for _, port := range rule.Ports { - ruleID := ruleIDBase + port - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - pr := baseRule - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) - continue - } - - pr.Port = uint16(p) - rules = append(rules, &pr) - } - - return rules -} - -func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { result := make([]*proto.RouteFirewallRule, len(rules)) for i := range rules { rule := rules[i] @@ -660,7 +396,7 @@ func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFir // getProtoDirection converts the direction to proto.RuleDirection. func getProtoDirection(direction int) proto.RuleDirection { - if direction == firewallRuleDirectionOUT { + if direction == types.FirewallRuleDirectionOUT { return proto.RuleDirection_OUT } return proto.RuleDirection_IN @@ -668,7 +404,7 @@ func getProtoDirection(direction int) proto.RuleDirection { // getProtoAction converts the action to proto.RuleAction. func getProtoAction(action string) proto.RuleAction { - if action == string(PolicyTrafficActionDrop) { + if action == string(types.PolicyTrafficActionDrop) { return proto.RuleAction_DROP } return proto.RuleAction_ACCEPT @@ -676,14 +412,14 @@ func getProtoAction(action string) proto.RuleAction { // getProtoProtocol converts the protocol to proto.RuleProtocol. func getProtoProtocol(protocol string) proto.RuleProtocol { - switch PolicyRuleProtocolType(protocol) { - case PolicyRuleProtocolALL: + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: return proto.RuleProtocol_ALL - case PolicyRuleProtocolTCP: + case types.PolicyRuleProtocolTCP: return proto.RuleProtocol_TCP - case PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolUDP: return proto.RuleProtocol_UDP - case PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolICMP: return proto.RuleProtocol_ICMP default: return proto.RuleProtocol_UNKNOWN @@ -691,7 +427,7 @@ func getProtoProtocol(protocol string) proto.RuleProtocol { } // getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { var portInfo proto.PortInfo if rule.Port != 0 { portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} @@ -708,6 +444,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/route_test.go b/management/server/route_test.go index 8bf9a3aebb3..5390cb66b94 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -13,11 +13,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -1092,9 +1097,9 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *nbgroup.Group + var groupHA1, groupHA2 *types.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -1202,7 +1207,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &nbgroup.Group{ + newGroup := &types.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1255,10 +1260,10 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createRouterStore(t *testing.T) (Store, error) { +func createRouterStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1267,7 +1272,7 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() accountID := "testingAcc" @@ -1279,8 +1284,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) + ips := account.GetTakenIPs() + peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1306,8 +1311,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer1.ID] = peer1 - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1333,8 +1338,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer2.ID] = peer2 - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1360,8 +1365,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer3.ID] = peer3 - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1387,8 +1392,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer4.ID] = peer4 - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1439,7 +1444,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*nbgroup.Group{ + newGroup := []*types.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1491,7 +1496,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerKIp = "100.65.29.66" ) - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -1555,7 +1560,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "routingPeer1": { ID: "routingPeer1", Name: "RoutingPeer1", @@ -1685,19 +1690,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { AccessControlGroups: []string{"route4"}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleRoute1", Name: "Route1", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute1", Name: "ruleRoute1", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80", "320"}, Sources: []string{ "dev", @@ -1712,15 +1717,15 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute2", Name: "Route2", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute2", Name: "ruleRoute2", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, - PortRanges: []RulePortRange{ + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ { Start: 80, End: 350, @@ -1742,14 +1747,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute4", Name: "RuleRoute4", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute4", Name: "RuleRoute4", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80"}, Sources: []string{ "restrictQA", @@ -1764,14 +1769,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute5", Name: "RuleRoute5", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute5", Name: "RuleRoute5", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "unrestrictedQA", }, @@ -1791,28 +1796,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] - policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) assert.Len(t, policies, 1) route2 := account.Routes["route2"] - policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) assert.Len(t, policies, 1) route3 := account.Routes["route3"] - policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) assert.Len(t, routesFirewallRules, 4) - expectedRoutesFirewallRules := []*RouteFirewallRule{ + expectedRoutesFirewallRules := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1821,9 +1826,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1831,10 +1836,10 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - additionalFirewallRule := []*RouteFirewallRule{ + additionalFirewallRule := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerJIp), + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), }, Action: "accept", Destination: "192.168.10.0/16", @@ -1843,7 +1848,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerKIp), + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), }, Action: "accept", Destination: "192.168.10.0/16", @@ -1854,27 +1859,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) assert.Len(t, routesFirewallRules, 3) - expectedRoutesFirewallRules = []*RouteFirewallRule{ + expectedRoutesFirewallRules = []*types.RouteFirewallRule{ { SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, Action: "accept", Destination: existingNetwork.String(), Protocol: "tcp", - PortRange: RulePortRange{Start: 80, End: 350}, + PortRange: types.RulePortRange{Start: 80, End: 350}, }, { SourceRanges: []string{"0.0.0.0/0"}, Action: "accept", Destination: "192.0.2.0/32", Protocol: "all", + Domains: domain.List{"example.com"}, IsDynamic: true, }, { @@ -1882,20 +1888,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Action: "accept", Destination: "192.0.2.0/32", Protocol: "all", + Domains: domain.List{"example.com"}, IsDynamic: true, }, } assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) assert.Len(t, routesFirewallRules, 0) }) } // orderList is a helper function to sort a list of strings -func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { +func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { for _, rule := range ruleList { sort.Strings(rule.SourceRanges) } @@ -1909,7 +1916,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -2105,7 +2112,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2145,7 +2152,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, @@ -2159,3 +2166,502 @@ func TestRouteAccountPeersUpdate(t *testing.T) { } }) } + +func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" + peerMIp = "100.65.29.67" + ) + + account := &types.Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Key: "peerA", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Key: "peerD", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: "peerE", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerL": { + ID: "peerL", + IP: net.ParseIP("100.65.19.186"), + Key: "peerL", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerM": { + ID: "peerM", + IP: net.ParseIP(peerMIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*types.Group{ + "router1": { + ID: "router1", + Name: "router1", + Peers: []string{ + "peerA", + }, + }, + "router2": { + ID: "router2", + Name: "router2", + Peers: []string{ + "peerD", + }, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, + Resources: []types.Resource{ + {ID: "resource2"}, + }, + }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + Resources: []types.Resource{ + {ID: "resource4"}, + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + Resources: []types.Resource{ + {ID: "resource4"}, + }, + }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + "pipeline": { + ID: "pipeline", + Name: "Pipeline", + Peers: []string{"peerM"}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: "network1", + Name: "Finance Network", + }, + { + ID: "network2", + Name: "Devs Network", + }, + { + ID: "network3", + Name: "Contractors Network", + }, + { + ID: "network4", + Name: "QA Network", + }, + { + ID: "network5", + Name: "Pipeline Network", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + Peer: "peerE", + PeerGroups: nil, + Masquerade: false, + Metric: 9999, + }, + { + ID: "router2", + NetworkID: "network2", + PeerGroups: []string{"router1", "router2"}, + Masquerade: false, + Metric: 9999, + }, + { + ID: "router3", + NetworkID: "network3", + Peer: "peerE", + PeerGroups: []string{}, + }, + { + ID: "router4", + NetworkID: "network4", + PeerGroups: []string{"router1"}, + Masquerade: false, + Metric: 9999, + }, + { + ID: "router5", + NetworkID: "network5", + Peer: "peerL", + Masquerade: false, + Metric: 9999, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "Resource 1", + Type: "subnet", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + }, + { + ID: "resource2", + NetworkID: "network2", + Name: "Resource 2", + Type: "subnet", + Prefix: netip.MustParsePrefix("192.168.0.0/16"), + }, + { + ID: "resource3", + NetworkID: "network3", + Name: "Resource 3", + Type: "domain", + Domain: "example.com", + }, + { + ID: "resource4", + NetworkID: "network4", + Name: "Resource 4", + Type: "domain", + Domain: "example.com", + }, + { + ID: "resource5", + NetworkID: "network5", + Name: "Resource 5", + Type: "host", + Prefix: netip.MustParsePrefix("10.12.12.1/32"), + }, + }, + Policies: []*types.Policy{ + { + ID: "policyResource1", + Name: "Policy for resource 1", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource1", + Name: "ruleResource1", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + DestinationResource: types.Resource{ID: "resource1"}, + }, + }, + }, + { + ID: "policyResource2", + Name: "Policy for resource 2", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource2", + Name: "ruleResource2", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{"dev"}, + Destinations: []string{"dev"}, + }, + }, + }, + { + ID: "policyResource3", + Name: "policyResource3", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource3", + Name: "ruleResource3", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{"restrictQA"}, + Destinations: []string{"restrictQA"}, + }, + }, + }, + { + ID: "policyResource4", + Name: "policyResource4", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource4", + Name: "ruleResource4", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Sources: []string{"unrestrictedQA"}, + Destinations: []string{"unrestrictedQA"}, + }, + }, + }, + { + ID: "policyResource5", + Name: "policyResource5", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource5", + Name: "ruleResource5", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"8080"}, + Sources: []string{"pipeline"}, + DestinationResource: types.Resource{ID: "resource5"}, + }, + }, + }, + }, + } + + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + t.Run("validate applied policies for different network resources", func(t *testing.T) { + // Test case: Resource1 is directly applied to the policy (policyResource1) + policies := account.GetPoliciesForNetworkResource("resource1") + assert.Len(t, policies, 1, "resource1 should have exactly 1 policy applied directly") + + // Test case: Resource2 is applied to an access control group (dev), + // which is part of the destination in the policy (policyResource2) + policies = account.GetPoliciesForNetworkResource("resource2") + assert.Len(t, policies, 1, "resource2 should have exactly 1 policy applied via access control group") + + // Test case: Resource3 is not applied to any access control group or policy + policies = account.GetPoliciesForNetworkResource("resource3") + assert.Len(t, policies, 0, "resource3 should have no policies applied") + + // Test case: Resource4 is applied to the access control groups (restrictQA and unrestrictedQA), + // which is part of the destination in the policies (policyResource3 and policyResource4) + policies = account.GetPoliciesForNetworkResource("resource4") + assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups") + }) + + t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { + resourcePoliciesMap := account.GetResourcePoliciesMap() + resourceRoutersMap := account.GetResourceRoutersMap() + _, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap) + firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 4) + assert.Len(t, sourcePeers, 5) + + expectedFirewallRules := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, + } + + additionalFirewallRules := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "tcp", + Port: 80, + Domains: domain.List{"example.com"}, + IsDynamic: true, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + Domains: domain.List{"example.com"}, + IsDynamic: true, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) + + // peerD is also the routing peer for resource2 + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 2) + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + assert.Len(t, sourcePeers, 3) + + // peerE is a single routing peer for resource1 and resource3 + // PeerE should only receive rules for resource1 since resource3 has no applied policy + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 1) + assert.Len(t, sourcePeers, 2) + + expectedFirewallRules = []*types.RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: "10.10.10.0/24", + Protocol: "tcp", + PortRange: types.RulePortRange{Start: 80, End: 350}, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + + // peerC is part of distribution groups for resource2 but should not receive the firewall rules + firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, firewallRules, 0) + + // peerL is the single routing peer for resource5 + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 1) + assert.Len(t, sourcePeers, 1) + + expectedFirewallRules = []*types.RouteFirewallRule{ + { + SourceRanges: []string{"100.65.29.67/32"}, + Action: "accept", + Destination: "10.12.12.1/32", + Protocol: "tcp", + Port: 8080, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + assert.Len(t, sourcePeers, 0) + }) +} diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go new file mode 100644 index 00000000000..37bc9f549b0 --- /dev/null +++ b/management/server/settings/manager.go @@ -0,0 +1,37 @@ +package settings + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) +} + +type managerImpl struct { + store store.Store +} + +type managerMock struct { +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return &types.Settings{}, nil +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index ef431d3adbc..9a4a1efb853 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,34 +2,16 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" - "hash/fnv" "slices" - "strconv" - "strings" "time" - "unicode/utf8" - "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - // SetupKeyReusable is a multi-use key (can be used for multiple machines) - SetupKeyReusable SetupKeyType = "reusable" - // SetupKeyOneOff is a single use key (can be used only once) - SetupKeyOneOff SetupKeyType = "one-off" - - // DefaultSetupKeyDuration = 1 month - DefaultSetupKeyDuration = 24 * 30 * time.Hour - // DefaultSetupKeyName is a default name of the default setup key - DefaultSetupKeyName = "Default key" - // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key - SetupKeyUnlimitedUsage = 0 + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -67,169 +49,14 @@ type SetupKeyUpdateOperation struct { Values []string } -// SetupKeyType is the type of setup key -type SetupKeyType string - -// SetupKey represents a pre-authorized key used to register machines (peers) -type SetupKey struct { - Id string - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Key string - KeySecret string - Name string - Type SetupKeyType - CreatedAt time.Time - ExpiresAt time.Time - UpdatedAt time.Time `gorm:"autoUpdateTime:false"` - // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) - Revoked bool - // UsedTimes indicates how many times the key was used - UsedTimes int - // LastUsed last time the key was used for peer registration - LastUsed time.Time - // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register - AutoGroups []string `gorm:"serializer:json"` - // UsageLimit indicates the number of times this key can be used to enroll a machine. - // The value of 0 indicates the unlimited usage. - UsageLimit int - // Ephemeral indicate if the peers will be ephemeral or not - Ephemeral bool -} - -// Copy copies SetupKey to a new object -func (key *SetupKey) Copy() *SetupKey { - autoGroups := make([]string, len(key.AutoGroups)) - copy(autoGroups, key.AutoGroups) - if key.UpdatedAt.IsZero() { - key.UpdatedAt = key.CreatedAt - } - return &SetupKey{ - Id: key.Id, - AccountID: key.AccountID, - Key: key.Key, - KeySecret: key.KeySecret, - Name: key.Name, - Type: key.Type, - CreatedAt: key.CreatedAt, - ExpiresAt: key.ExpiresAt, - UpdatedAt: key.UpdatedAt, - Revoked: key.Revoked, - UsedTimes: key.UsedTimes, - LastUsed: key.LastUsed, - AutoGroups: autoGroups, - UsageLimit: key.UsageLimit, - Ephemeral: key.Ephemeral, - } -} - -// EventMeta returns activity event meta related to the setup key -func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} -} - -// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. -// E.g., "831F6*******************************" -func hiddenKey(key string, length int) string { - prefix := key[0:5] - if length > utf8.RuneCountInString(key) { - length = utf8.RuneCountInString(key) - len(prefix) - } - return prefix + strings.Repeat("*", length) -} - -// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now -func (key *SetupKey) IncrementUsage() *SetupKey { - c := key.Copy() - c.UsedTimes++ - c.LastUsed = time.Now().UTC() - return c -} - -// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to -func (key *SetupKey) IsValid() bool { - return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() -} - -// IsRevoked if key was revoked -func (key *SetupKey) IsRevoked() bool { - return key.Revoked -} - -// IsExpired if key was expired -func (key *SetupKey) IsExpired() bool { - if key.ExpiresAt.IsZero() { - return false - } - return time.Now().After(key.ExpiresAt) -} - -// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. -func (key *SetupKey) IsOverUsed() bool { - limit := key.UsageLimit - if key.Type == SetupKeyOneOff { - limit = 1 - } - return limit > 0 && key.UsedTimes >= limit -} - -// GenerateSetupKey generates a new setup key -func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) (*SetupKey, string) { - key := strings.ToUpper(uuid.New().String()) - limit := usageLimit - if t == SetupKeyOneOff { - limit = 1 - } - - expiresAt := time.Time{} - if validFor != 0 { - expiresAt = time.Now().UTC().Add(validFor) - } - - hashedKey := sha256.Sum256([]byte(key)) - encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - - return &SetupKey{ - Id: strconv.Itoa(int(Hash(key))), - Key: encodedHashedKey, - KeySecret: hiddenKey(key, 4), - Name: name, - Type: t, - CreatedAt: time.Now().UTC(), - ExpiresAt: expiresAt, - UpdatedAt: time.Now().UTC(), - Revoked: false, - UsedTimes: 0, - AutoGroups: autoGroups, - UsageLimit: limit, - Ephemeral: ephemeral, - }, key -} - -// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() (*SetupKey, string) { - return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, - SetupKeyUnlimitedUsage, false) -} - -func Hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} - // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. -func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { +func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -242,22 +69,22 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - var setupKey *SetupKey + var setupKey *types.SetupKey var plainKey string var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { return err } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) }) if err != nil { return nil, err @@ -278,7 +105,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key. -func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { +func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *types.SetupKey, userID string) (*types.SetupKey, error) { if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } @@ -286,7 +113,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -299,16 +126,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewAdminPermissionError() } - var oldKey *SetupKey - var newKey *SetupKey + var oldKey *types.SetupKey + var newKey *types.SetupKey var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { return err } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err } @@ -323,13 +150,13 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + addedGroups := util.Difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := util.Difference(oldKey.AutoGroups, newKey.AutoGroups) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) }) if err != nil { return nil, err @@ -347,8 +174,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -361,12 +188,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. -func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -379,7 +206,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } @@ -394,7 +221,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use // DeleteSetupKey removes the setup key from the account func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -407,15 +234,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return status.NewAdminPermissionError() } - var deletedSetupKey *SetupKey + var deletedSetupKey *types.SetupKey - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return err } - return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) }) if err != nil { return err @@ -426,8 +253,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs) +func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) if err != nil { return err } @@ -447,11 +274,11 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI } // prepareSetupKeyEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string, key *types.SetupKey) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 614547c60f6..f728db5d458 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -30,7 +30,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,15 +49,15 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, - SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, types.SetupKeyReusable, expiresIn, []string{}, + types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } autoGroups := []string{"group_1", "group_2"} revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, Revoked: revoked, AutoGroups: autoGroups, @@ -85,7 +85,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, Revoked: revoked, AutoGroups: autoGroups, @@ -105,7 +105,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -114,7 +114,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -167,8 +167,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, - tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, types.SetupKeyReusable, expiresIn, + tCase.expectedGroups, types.SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { if err == nil { @@ -182,7 +182,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, - tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), + tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated @@ -210,7 +210,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } @@ -258,10 +258,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key, plainKey := GenerateDefaultSetupKey() + key, plainKey := types.GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -275,48 +275,48 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, -time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + revokedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyReusable, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) } } -func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, +func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() @@ -388,7 +388,7 @@ func isValidBase64SHA256(encodedKey string) bool { func TestSetupKey_Copy(t *testing.T) { - key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, @@ -399,22 +399,22 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) assert.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"group"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -426,7 +426,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var setupKey *SetupKey + var setupKey *types.SetupKey // Creating setup key should not update account peers and not send peer update t.Run("creating setup key", func(t *testing.T) { @@ -436,7 +436,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { close(done) }() - setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) assert.NoError(t, err) select { @@ -477,7 +477,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t t.Fatal(err) } - key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) assert.NoError(t, err) // revoke the key diff --git a/management/server/status/error.go b/management/server/status/error.go index 59f436f5b19..d9cab02315c 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -154,3 +154,35 @@ func NewPolicyNotFoundError(policyID string) error { func NewNameServerGroupNotFoundError(nsGroupID string) error { return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) } + +// NewNetworkNotFoundError creates a new Error with NotFound type for a missing network. +func NewNetworkNotFoundError(networkID string) error { + return Errorf(NotFound, "network: %s not found", networkID) +} + +// NewNetworkRouterNotFoundError creates a new Error with NotFound type for a missing network router. +func NewNetworkRouterNotFoundError(routerID string) error { + return Errorf(NotFound, "network router: %s not found", routerID) +} + +// NewNetworkResourceNotFoundError creates a new Error with NotFound type for a missing network resource. +func NewNetworkResourceNotFoundError(resourceID string) error { + return Errorf(NotFound, "network resource: %s not found", resourceID) +} + +// NewPermissionDeniedError creates a new Error with PermissionDenied type for a permission denied error. +func NewPermissionDeniedError() error { + return Errorf(PermissionDenied, "permission denied") +} + +func NewPermissionValidationError(err error) error { + return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err) +} + +func NewResourceNotPartOfNetworkError(resourceID, networkID string) error { + return Errorf(BadRequest, "resource %s is not part of the network %s", resourceID, networkID) +} + +func NewRouterNotPartOfNetworkError(routerID, networkID string) error { + return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID) +} diff --git a/management/server/file_store.go b/management/server/store/file_store.go similarity index 89% rename from management/server/file_store.go rename to management/server/store/file_store.go index f375fb99062..f40a0392b29 100644 --- a/management/server/file_store.go +++ b/management/server/store/file_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -11,9 +11,9 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -22,7 +22,7 @@ const storeFileName = "store.json" // FileStore represents an account storage backed by a file persisted to disk type FileStore struct { - Accounts map[string]*Account + Accounts map[string]*types.Account SetupKeyID2AccountID map[string]string `json:"-"` PeerKeyID2AccountID map[string]string `json:"-"` PeerID2AccountID map[string]string `json:"-"` @@ -55,7 +55,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ - Accounts: make(map[string]*Account), + Accounts: make(map[string]*types.Account), mux: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), @@ -92,12 +92,14 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for accountID, account := range store.Accounts { if account.Settings == nil { - account.Settings = &Settings{ + account.Settings = &types.Settings{ PeerLoginExpirationEnabled: false, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + + RoutingPeerDNSResolutionEnabled: true, } } @@ -112,7 +114,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID if user.Issued == "" { - user.Issued = UserIssuedAPI + user.Issued = types.UserIssuedAPI account.Users[user.Id] = user } @@ -122,7 +124,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { } } - if account.Domain != "" && account.DomainCategory == PrivateCategory && + if account.Domain != "" && account.DomainCategory == types.PrivateCategory && account.IsDomainPrimaryAccount { store.PrivateDomain2AccountID[account.Domain] = accountID } @@ -134,20 +136,20 @@ func restore(ctx context.Context, file string) (*FileStore, error) { policy.UpgradeAndFix() } if account.Policies == nil { - account.Policies = make([]*Policy, 0) + account.Policies = make([]*types.Policy, 0) } // for data migration. Can be removed once most base will be with labels - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() if len(existingLabels) != len(account.Peers) { - addPeerLabelsToAccount(ctx, account, existingLabels) + types.AddPeerLabelsToAccount(ctx, account, existingLabels) } // TODO: delete this block after migration // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI } } @@ -236,7 +238,7 @@ func (s *FileStore) persist(ctx context.Context, file string) error { } // GetAllAccounts returns all accounts -func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { +func (s *FileStore) GetAllAccounts(_ context.Context) (all []*types.Account) { s.mux.Lock() defer s.mux.Unlock() for _, a := range s.Accounts { @@ -257,6 +259,6 @@ func (s *FileStore) Close(ctx context.Context) error { } // GetStoreEngine returns FileStoreEngine -func (s *FileStore) GetStoreEngine() StoreEngine { +func (s *FileStore) GetStoreEngine() Engine { return FileStoreEngine } diff --git a/management/server/sql_store.go b/management/server/store/sql_store.go similarity index 76% rename from management/server/sql_store.go rename to management/server/store/sql_store.go index 1fd8ae2aabe..62b004f9c94 100644 --- a/management/server/sql_store.go +++ b/management/server/store/sql_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -24,11 +24,14 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -49,7 +52,7 @@ type SqlStore struct { globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int - storeEngine StoreEngine + storeEngine Engine } type installation struct { @@ -60,7 +63,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -86,9 +89,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, - &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, + &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, ) if err != nil { return nil, fmt.Errorf("auto migrate: %w", err) @@ -151,7 +155,7 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u return unlock } -func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { start := time.Now() defer func() { elapsed := time.Since(start) @@ -201,7 +205,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { } // generateAccountSQLTypes generates the GORM compatible types for the account -func generateAccountSQLTypes(account *Account) { +func generateAccountSQLTypes(account *types.Account) { for _, key := range account.SetupKeys { account.SetupKeysG = append(account.SetupKeysG, *key) } @@ -238,7 +242,7 @@ func generateAccountSQLTypes(account *Account) { // checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { - var acc Account + var acc types.Account var domain string result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) if result.Error != nil { @@ -252,7 +256,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, } } -func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() err := s.db.Transaction(func(tx *gorm.DB) error { @@ -333,14 +337,14 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. } func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { - accountCopy := Account{ + accountCopy := types.Account{ Domain: domain, DomainCategory: category, IsDomainPrimaryAccount: isPrimaryDomain, } fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.Model(&Account{}). + result := s.db.Model(&types.Account{}). Select(fieldsToUpdate). Where(idQueryCondition, accountID). Updates(&accountCopy) @@ -402,8 +406,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P // SaveUsers saves the given list of users to the database. // It updates existing users if a conflict occurs. -func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { - usersToSave := make([]User, 0, len(users)) +func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error { + usersToSave := make([]types.User, 0, len(users)) for _, user := range users { user.AccountID = accountID for id, pat := range user.PATs { @@ -423,7 +427,7 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { } // SaveUser saves the given user to the database. -func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) @@ -432,7 +436,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -454,7 +458,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { return nil } -func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { +func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { return nil, err @@ -466,9 +470,9 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory, + strings.ToLower(domain), true, types.PrivateCategory, ).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -481,8 +485,8 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return accountID, nil } -func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { - var key SetupKey +func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { + var key types.SetupKey result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -500,7 +504,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { - var token PersonalAccessToken + var token types.PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -513,8 +517,8 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return token.ID, nil } -func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { - var token PersonalAccessToken +func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) { + var token types.PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -528,13 +532,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - var user User + var user types.User result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID) if result.Error != nil { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -542,8 +546,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } -func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { - var user User +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { + var user types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { @@ -556,8 +560,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { - var users []*User +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { + var users []*types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -570,8 +574,8 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -584,8 +588,27 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return groups, nil } -func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { - var accounts []Account +func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + var groups []*types.Group + + likePattern := `%"ID":"` + resourceID + `"%` + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("resources LIKE ?", likePattern). + Find(&groups) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, result.Error + } + + return groups, nil +} + +func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { + var accounts []types.Account result := s.db.Find(&accounts) if result.Error != nil { return all @@ -600,7 +623,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { return all } -func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -609,7 +632,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } }() - var account Account + var account types.Account result := s.db.Model(&account). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). @@ -624,15 +647,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us for i, policy := range account.Policies { - var rules []*PolicyRule - err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error if err != nil { return nil, status.Errorf(status.NotFound, "rule not found") } account.Policies[i].Rules = rules } - account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { account.SetupKeys[key.Key] = key.Copy() } @@ -644,9 +667,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.PeersG = nil - account.Users = make(map[string]*User, len(account.UsersG)) + account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -654,7 +677,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.UsersG = nil - account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } @@ -675,8 +698,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, return &account, nil } -func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { - var user User +func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { + var user types.User result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -692,7 +715,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun return s.GetAccount(ctx, user.AccountID) } -func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { @@ -709,7 +732,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { @@ -742,7 +765,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { var accountID string - result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -755,7 +778,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) + result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewSetupKeyNotFoundError(setupKey) @@ -815,9 +838,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return labels, nil } -func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - var accountNetwork AccountNetwork - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + var accountNetwork types.AccountNetwork + if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -839,9 +862,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { - var accountSettings AccountSettings - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + var accountSettings types.AccountSettings + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -852,7 +875,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - var user User + var user types.User result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -890,7 +913,7 @@ func (s *SqlStore) Close(_ context.Context) error { } // GetStoreEngine returns underlying store engine -func (s *SqlStore) GetStoreEngine() StoreEngine { +func (s *SqlStore) GetStoreEngine() Engine { return s.storeEngine } @@ -982,8 +1005,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } -func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - var setupKey SetupKey +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { + var setupKey types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, keyQueryCondition, key) if result.Error != nil { @@ -997,7 +1020,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.Model(&SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1015,8 +1038,9 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string return nil } +// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1040,8 +1064,9 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer return nil } +// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1066,6 +1091,59 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction +func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error { + var group types.Group + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.NewGroupNotFoundError(groupID) + } + + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) + } + + for _, res := range group.Resources { + if res.ID == resource.ID { + return nil + } + } + + group.Resources = append(group.Resources, *resource) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group: %s", err) + } + + return nil +} + +// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction +func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error { + var group types.Group + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.NewGroupNotFoundError(groupID) + } + + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) + } + + for i, res := range group.Resources { + if res.ID == resourceID { + group.Resources = append(group.Resources[:i], group.Resources[i+1:]...) + break + } + } + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group: %s", err) + } + + return nil +} + // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) @@ -1114,7 +1192,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -1156,9 +1234,9 @@ func (s *SqlStore) GetDB() *gorm.DB { return s.db } -func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { - var accountDNSSettings AccountDNSSettings - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { + var accountDNSSettings types.AccountDNSSettings + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1173,7 +1251,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1187,8 +1265,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - var account Account - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + var account types.Account + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1201,8 +1279,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { - var group *nbgroup.Group +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { + var group *types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1216,8 +1294,8 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { - var group nbgroup.Group +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { + var group types.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. @@ -1240,15 +1318,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } // GetGroupsByIDs retrieves groups by their IDs and account ID. -func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } - groupsMap := make(map[string]*nbgroup.Group) + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -1257,7 +1335,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } // SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) @@ -1269,7 +1347,7 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete group from store") @@ -1285,7 +1363,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store") @@ -1295,8 +1373,8 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a } // GetAccountPolicies retrieves policies for an account. -func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - var policies []*Policy +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { + var policies []*types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1308,8 +1386,8 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { - var policy *Policy +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { + var policy *types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). First(&policy, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { @@ -1323,7 +1401,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return policy, nil } -func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) @@ -1334,7 +1412,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt } // SavePolicy saves a policy to the database. -func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) if err := result.Error; err != nil { @@ -1346,7 +1424,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) return status.Errorf(status.Internal, "failed to delete policy from store") @@ -1442,8 +1520,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt } // GetAccountSetupKeys retrieves setup keys for an account. -func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - var setupKeys []*SetupKey +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { + var setupKeys []*types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1455,8 +1533,8 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking } // GetSetupKeyByID retrieves a setup key by its ID and account ID. -func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { - var setupKey *SetupKey +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { + var setupKey *types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { @@ -1471,7 +1549,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } // SaveSetupKey saves a setup key to the database. -func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) @@ -1483,7 +1561,7 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt // DeleteSetupKey deletes a setup key from the database. func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") @@ -1583,9 +1661,9 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } // SaveDNSSettings saves the DNS settings to the store. -func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings}) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save dns settings to store") @@ -1597,3 +1675,198 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } + +func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { + var networks []*networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get networks from store") + } + + return networks, nil +} + +func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { + var network *networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&network, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkNotFoundError(networkID) + } + + log.WithContext(ctx).Errorf("failed to get network from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network from store") + } + + return network, nil +} + +func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkNotFoundError(networkID) + } + + return nil +} + +func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netRouters, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { + var netRouter *routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netRouter, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkRouterNotFoundError(routerID) + } + log.WithContext(ctx).Errorf("failed to get network router from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network router from store") + } + + return netRouter, nil +} + +func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network router to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network router from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkRouterNotFoundError(routerID) + } + + return nil +} + +func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netResources, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceID) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceName) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network resource to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network resource from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkResourceNotFoundError(resourceID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/store/sql_store_test.go similarity index 75% rename from management/server/sql_store_test.go rename to management/server/store/sql_store_test.go index 6064b019f29..845bc8fd474 100644 --- a/management/server/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -14,17 +14,24 @@ import ( "time" "github.com/google/uuid" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/posture" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" ) func TestSqlite_NewStore(t *testing.T) { @@ -73,7 +80,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -86,14 +93,14 @@ func runLargeTest(t *testing.T, store Store) { IP: netIP, Name: peerID, DNSLabel: peerID, - UserID: userID, + UserID: "testuser", Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } account.Peers[peerID] = peer group, _ := account.GetGroupAll() group.Peers = append(group.Peers, peerID) - user := &User{ + user := &types.User{ Id: fmt.Sprintf("%s-user-%d", account.Id, n), AccountID: account.Id, } @@ -111,7 +118,7 @@ func runLargeTest(t *testing.T, store Store) { } account.Routes[route.ID] = route - group = &nbgroup.Group{ + group = &types.Group{ ID: fmt.Sprintf("group-id-%d", n), AccountID: account.Id, Name: fmt.Sprintf("group-id-%d", n), @@ -134,7 +141,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -216,7 +223,7 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -230,7 +237,7 @@ func TestSqlite_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -289,14 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -306,6 +313,35 @@ func TestSqlite_DeleteAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user + account.Networks = []*networkTypes.Network{ + { + ID: "network_id", + AccountID: account.Id, + Name: "network name", + Description: "network description", + }, + } + account.NetworkRouters = []*routerTypes.NetworkRouter{ + { + ID: "router_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + PeerGroups: []string{"group_id"}, + Masquerade: true, + Metric: 1, + }, + } + account.NetworkResources = []*resourceTypes.NetworkResource{ + { + ID: "resource_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + Name: "Name", + Description: "Description", + Type: "Domain", + Address: "example.com", + }, + } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -337,21 +373,30 @@ func TestSqlite_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") } + for _, network := range account.Networks { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network routers") + require.Len(t, routers, 0, "expecting no network routers to be found after DeleteAccount") + + resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources") + require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") + } } func TestSqlite_GetAccount(t *testing.T) { @@ -360,7 +405,7 @@ func TestSqlite_GetAccount(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -383,7 +428,7 @@ func TestSqlite_SavePeer(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -433,7 +478,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -488,7 +533,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -542,7 +587,7 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -565,7 +610,7 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -589,7 +634,7 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -625,7 +670,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err, "Failed to parse CIDR") type network struct { - Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } @@ -640,7 +685,7 @@ func TestMigrate(t *testing.T) { } type account struct { - Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` Peers []peer `gorm:"foreignKey:AccountID;references:id"` } @@ -700,23 +745,10 @@ func TestMigrate(t *testing.T) { } -func newSqliteStore(t *testing.T) *SqlStore { - t.Helper() - - store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - t.Cleanup(func() { - store.Close(context.Background()) - }) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ Key: "peerkey" + str, @@ -755,7 +787,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -769,7 +801,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -828,14 +860,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -876,16 +908,16 @@ func TestPostgresql_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -899,7 +931,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -940,7 +972,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -960,7 +992,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -978,7 +1010,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -991,7 +1023,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { func TestSqlite_GetTakenIPs(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -1036,7 +1068,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1078,7 +1110,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { func TestSqlite_GetAccountNetwork(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1101,7 +1133,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { func TestSqlite_GetSetupKeyBySecret(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1119,14 +1151,14 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, encodedHashedKey, setupKey.Key) - assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) + assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1162,13 +1194,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: "account-id", Name: "group-name", @@ -1193,7 +1225,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { } func TestSqlite_GetAccoundUsers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1207,7 +1239,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { } func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1253,7 +1285,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { } func TestSqlite_GetGroupByName(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1267,7 +1299,7 @@ func TestSqlite_GetGroupByName(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1283,7 +1315,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1295,7 +1327,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { } func TestSqlStore_GetGroupsByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1338,13 +1370,13 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { } func TestSqlStore_SaveGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: accountID, Issued: "api", @@ -1359,13 +1391,13 @@ func TestSqlStore_SaveGroup(t *testing.T) { } func TestSqlStore_SaveGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - groups := []*nbgroup.Group{ + groups := []*types.Group{ { ID: "group-1", AccountID: accountID, @@ -1384,7 +1416,7 @@ func TestSqlStore_SaveGroups(t *testing.T) { } func TestSqlStore_DeleteGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1432,7 +1464,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) { } func TestSqlStore_DeleteGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1479,7 +1511,7 @@ func TestSqlStore_DeleteGroups(t *testing.T) { } func TestSqlStore_GetPeerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1525,7 +1557,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) { } func TestSqlStore_GetPeersByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1567,7 +1599,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { } func TestSqlStore_GetPostureChecksByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1613,7 +1645,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1656,7 +1688,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { } func TestSqlStore_SavePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1697,7 +1729,7 @@ func TestSqlStore_SavePostureChecks(t *testing.T) { } func TestSqlStore_DeletePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1744,7 +1776,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { } func TestSqlStore_GetPolicyByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1790,23 +1822,23 @@ func TestSqlStore_GetPolicyByID(t *testing.T) { } func TestSqlStore_CreatePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - policy := &Policy{ + policy := &types.Policy{ ID: "policy-id", AccountID: accountID, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1820,7 +1852,7 @@ func TestSqlStore_CreatePolicy(t *testing.T) { } func TestSqlStore_SavePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1843,7 +1875,7 @@ func TestSqlStore_SavePolicy(t *testing.T) { } func TestSqlStore_DeletePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1859,7 +1891,7 @@ func TestSqlStore_DeletePolicy(t *testing.T) { } func TestSqlStore_GetDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1903,7 +1935,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) { } func TestSqlStore_SaveDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1922,7 +1954,7 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { } func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1959,7 +1991,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { } func TestSqlStore_GetNameServerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2005,7 +2037,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) { } func TestSqlStore_SaveNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2037,7 +2069,7 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) { } func TestSqlStore_DeleteNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2051,3 +2083,481 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[nbroute.ID]*nbroute.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + }, + } + + if err := addAllGroup(acc); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} + +// addAllGroup to account object if it doesn't exist +func addAllGroup(account *types.Account) error { + if len(account.Groups) == 0 { + allGroup := &types.Group{ + ID: xid.New().String(), + Name: "All", + Issued: types.GroupIssuedAPI, + } + for _, peer := range account.Peers { + allGroup.Peers = append(allGroup.Peers, peer.ID) + } + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} + + id := xid.New().String() + + defaultPolicy := &types.Policy{ + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + + account.Policies = []*types.Policy{defaultPolicy} + } + return nil +} + +func TestSqlStore_GetAccountNetworks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve networks by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + + { + name: "retrieve networks by non-existing account ID", + accountID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, networks, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkID string + expectError bool + }{ + { + name: "retrieve existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectError: false, + }, + { + name: "retrieve non-existing network ID", + networkID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty ID", + networkID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, network) + } else { + require.NoError(t, err) + require.NotNil(t, network) + require.Equal(t, tt.networkID, network.ID) + } + }) + } +} + +func TestSqlStore_SaveNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + network := &networkTypes.Network{ + ID: "net-id", + AccountID: accountID, + Name: "net", + } + + err = store.SaveNetwork(context.Background(), LockingStrengthUpdate, network) + require.NoError(t, err) + + savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, network.ID) + require.NoError(t, err) + require.Equal(t, network, savedNet) +} + +func TestSqlStore_DeleteNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + err = store.DeleteNetwork(context.Background(), LockingStrengthUpdate, accountID, networkID) + require.NoError(t, err) + + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, networkID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, network) +} + +func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve routers by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve routers by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, routers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkRouterByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkRouterID string + expectError bool + }{ + { + name: "retrieve existing network router ID", + networkRouterID: "ctc20ji7qv9ck2sebc80", + expectError: false, + }, + { + name: "retrieve non-existing network router ID", + networkRouterID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty router ID", + networkRouterID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networkRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, tt.networkRouterID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, networkRouter) + } else { + require.NoError(t, err) + require.NotNil(t, networkRouter) + require.Equal(t, tt.networkRouterID, networkRouter.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0) + require.NoError(t, err) + + err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter) + require.NoError(t, err) + + savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, netRouter.ID) + require.NoError(t, err) + require.Equal(t, netRouter, savedNetRouter) +} + +func TestSqlStore_DeleteNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netRouterID := "ctc20ji7qv9ck2sebc80" + + err = store.DeleteNetworkRouter(context.Background(), LockingStrengthUpdate, accountID, netRouterID) + require.NoError(t, err) + + netRouter, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netRouterID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netRouter) +} + +func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve resources by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve resources by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, netResources, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkResourceByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + netResourceID string + expectError bool + }{ + { + name: "retrieve existing network resource ID", + netResourceID: "ctc4nci7qv9061u6ilfg", + expectError: false, + }, + { + name: "retrieve non-existing network resource ID", + netResourceID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty resource ID", + netResourceID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, tt.netResourceID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, netResource) + } else { + require.NoError(t, err) + require.NotNil(t, netResource) + require.Equal(t, tt.netResourceID, netResource.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}) + require.NoError(t, err) + + err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) + require.NoError(t, err) + + savedNetResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, netResource.ID) + require.NoError(t, err) + require.Equal(t, netResource.ID, savedNetResource.ID) + require.Equal(t, netResource.Name, savedNetResource.Name) + require.Equal(t, netResource.NetworkID, savedNetResource.NetworkID) + require.Equal(t, netResource.Type, resourceTypes.NetworkResourceType("domain")) + require.Equal(t, netResource.Domain, "example.com") + require.Equal(t, netResource.AccountID, savedNetResource.AccountID) + require.Equal(t, netResource.Prefix, netip.Prefix{}) +} + +func TestSqlStore_DeleteNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netResourceID := "ctc4nci7qv9061u6ilfg" + + err = store.DeleteNetworkResource(context.Background(), LockingStrengthUpdate, accountID, netResourceID) + require.NoError(t, err) + + netResource, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netResourceID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netResource) +} + +func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + resourceId := "ctc4nci7qv9061u6ilfg" + groupID := "cs1tnh0hhcjnqoiuebeg" + + res := &types.Resource{ + ID: resourceId, + Type: "host", + } + err = store.AddResourceToGroup(context.Background(), accountID, groupID, res) + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.Contains(t, group.Resources, *res) + + groups, err := store.GetResourceGroups(context.Background(), LockingStrengthShare, accountID, resourceId) + require.NoError(t, err) + require.Len(t, groups, 1) + + err = store.RemoveResourceFromGroup(context.Background(), accountID, groupID, res.ID) + require.NoError(t, err) + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.NotContains(t, group.Resources, *res) +} diff --git a/management/server/store.go b/management/server/store/store.go similarity index 73% rename from management/server/store.go rename to management/server/store/store.go index b16ad8a1aa4..d9dc6b8f7b8 100644 --- a/management/server/store.go +++ b/management/server/store/store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -18,13 +18,15 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" - - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/migration" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/testutil" @@ -41,49 +43,50 @@ const ( ) type Store interface { - GetAllAccounts(ctx context.Context) []*Account - GetAccount(ctx context.Context, accountID string) (*Account, error) + GetAllAccounts(ctx context.Context) []*types.Account + GetAccount(ctx context.Context, accountID string) (*types.Account, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) - GetAccountByUser(ctx context.Context, userID string) (*Account, error) - GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) + GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) - GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) - GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later - GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later + GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) - SaveAccount(ctx context.Context, account *Account) error - DeleteAccount(ctx context.Context, account *Account) error + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) + SaveAccount(ctx context.Context, account *types.Account) error + DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error - SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error - GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) - GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) - SaveUsers(accountID string, users map[string]*User) error - SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error + GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) + SaveUsers(accountID string, users map[string]*types.User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) + GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error - GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) - CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error - SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -96,6 +99,8 @@ type Store interface { GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error + RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) @@ -105,11 +110,11 @@ type Store interface { SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) - SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) @@ -122,7 +127,7 @@ type Store interface { GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error @@ -136,30 +141,48 @@ type Store interface { // Close should close the store persisting all unsaved data. Close(ctx context.Context) error - // GetStoreEngine should return StoreEngine of the current store implementation. + // GetStoreEngine should return Engine of the current store implementation. // This is also a method of metrics.DataSource interface. - GetStoreEngine() StoreEngine + GetStoreEngine() Engine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + + GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) + GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) + SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error + DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error + + GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) + SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error + DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error + + GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) + GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) + SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error + DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error } -type StoreEngine string +type Engine string const ( - FileStoreEngine StoreEngine = "jsonfile" - SqliteStoreEngine StoreEngine = "sqlite" - PostgresStoreEngine StoreEngine = "postgres" + FileStoreEngine Engine = "jsonfile" + SqliteStoreEngine Engine = "sqlite" + PostgresStoreEngine Engine = "postgres" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" ) -func getStoreEngineFromEnv() StoreEngine { +func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") if !ok { return "" } - value := StoreEngine(strings.ToLower(kind)) + value := Engine(strings.ToLower(kind)) if value == SqliteStoreEngine || value == PostgresStoreEngine { return value } @@ -171,7 +194,7 @@ func getStoreEngineFromEnv() StoreEngine { // If no engine is specified, it attempts to retrieve it from the environment. // If still not specified, it defaults to using SQLite. // Additionally, it handles the migration from a JSON store file to SQLite if applicable. -func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine { +func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { if kind == "" { kind = getStoreEngineFromEnv() if kind == "" { @@ -197,7 +220,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) Store } // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics -func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { +func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { kind = getStoreEngine(ctx, dataDir, kind) if err := checkFileStoreEngine(kind, dataDir); err != nil { @@ -216,7 +239,7 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel } } -func checkFileStoreEngine(kind StoreEngine, dataDir string) error { +func checkFileStoreEngine(kind Engine, dataDir string) error { if kind == FileStoreEngine { storeFile := filepath.Join(dataDir, storeFileName) if util.FileExists(storeFile) { @@ -243,7 +266,7 @@ func migrate(ctx context.Context, db *gorm.DB) error { func getMigrations(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net") + return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") }, func(db *gorm.DB) error { return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network") @@ -258,7 +281,7 @@ func getMigrations(ctx context.Context) []migrationFunc { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, func(db *gorm.DB) error { - return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db) }, } } diff --git a/management/server/store_test.go b/management/server/store/store_test.go similarity index 93% rename from management/server/store_test.go rename to management/server/store/store_test.go index fc821670d65..1d0026e3def 100644 --- a/management/server/store_test.go +++ b/management/server/store/store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -76,11 +76,3 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } - -func newStore(t *testing.T) Store { - t.Helper() - - store := newSqliteStore(t) - - return store -} diff --git a/management/server/testdata/networks.sql b/management/server/testdata/networks.sql new file mode 100644 index 00000000000..8138ce520d0 --- /dev/null +++ b/management/server/testdata/networks.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); + +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); + +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO networks VALUES('testNetworkId','testAccountId','some-name','some-description'); + +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999); + +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32'); +INSERT INTO network_resources VALUES('anotherTestResourceId','testNetworkId','testAccountId','used-name','some-description','host','3.3.3.3/32'); diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 168973cad91..7f0c7b5a4fd 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -12,6 +12,9 @@ CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`)); +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`type` text,`address` text,PRIMARY KEY (`id`)); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); CREATE INDEX `idx_peers_key` ON `peers`(`key`); @@ -24,6 +27,14 @@ CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); +CREATE INDEX `idx_network_routers_id` ON `network_routers`(`id`); +CREATE INDEX `idx_network_routers_account_id` ON `network_routers`(`account_id`); +CREATE INDEX `idx_network_routers_network_id` ON `network_routers`(`network_id`); +CREATE INDEX `idx_network_resources_account_id` ON `network_resources`(`account_id`); +CREATE INDEX `idx_network_resources_network_id` ON `network_resources`(`network_id`); +CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`); +CREATE INDEX `idx_networks_id` ON `networks`(`id`); +CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); @@ -34,3 +45,6 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO installations VALUES(1,''); INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); +INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); +INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); +INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); diff --git a/management/server/types/account.go b/management/server/types/account.go new file mode 100644 index 00000000000..5d0f1201984 --- /dev/null +++ b/management/server/types/account.go @@ -0,0 +1,1475 @@ +package types + +import ( + "context" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/route" +) + +const ( + defaultTTL = 300 + DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute + + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" +) + +type LookupMap map[string]struct{} + +// Account represents a unique account of the system +type Account struct { + // we have to name column to aid as it collides with Network.Id when work with associations + Id string `gorm:"primaryKey"` + + // User.Id it was created by + CreatedBy string + CreatedAt time.Time + Domain string `gorm:"index"` + DomainCategory string + IsDomainPrimaryAccount bool + SetupKeys map[string]*SetupKey `gorm:"-"` + SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` + Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Peers map[string]*nbpeer.Peer `gorm:"-"` + PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + Users map[string]*User `gorm:"-"` + UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` + Routes map[route.ID]*route.Route `gorm:"-"` + RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` + NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` + NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` + PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` + // Settings is a dictionary of Account settings + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` + + Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` + NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` + NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` +} + +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + +// Subclass used in gorm to only load settings and not whole account +type AccountSettings struct { + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` +} + +// GetRoutesToSync returns the enabled routes for the peer ID and the routes +// from the ACL peers that have distribution groups associated with the peer ID. +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + groupListMap := a.GetPeerGroups(peerID) + for _, peer := range aclPeers { + activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) + groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) + filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership +func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map +func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +// If the given is not a routing peer, then the lists are empty. +func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + + peer := a.GetPeer(peerID) + if peer == nil { + log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + return enabledRoutes, disabledRoutes + } + + // currently we support only linux routing peers + if peer.Meta.GoOS != "linux" { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + for _, r := range a.Routes { + for _, groupID := range r.PeerGroups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map + takeRoute(newPeerRoute, id) + break + } + } + if r.Peer == peerID { + takeRoute(r.Copy(), peerID) + } + } + + return enabledRoutes, disabledRoutes +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { + var routes []*route.Route + for _, r := range a.Routes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes +} + +// GetGroup returns a group by ID if exists, nil otherwise +func (a *Account) GetGroup(groupID string) *Group { + return a.Groups[groupID] +} + +// GetPeerNetworkMap returns the networkmap for the given peer ID. +func (a *Account) GetPeerNetworkMap( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + + peer := a.Peers[peerID] + if peer == nil { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) + // exclude expired peers + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) + isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) + } + peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + nm := &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: a.Network.Copy(), + Routes: slices.Concat(networkResourcesRoutes, routesUpdate), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + } + + if metrics != nil { + objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ + "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", + a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) + } + } + + return nm +} + +func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { + missingPeers := map[string]struct{}{} + for _, r := range networkResourcesRoutes { + if r.Peer == peer.Key { + continue + } + + missing := true + for _, p := range slices.Concat(peersToConnect, expiredPeers) { + if r.Peer == p.Key { + missing = false + break + } + } + if missing { + missingPeers[r.Peer] = struct{}{} + } + } + + if isRouter { + for _, s := range sourcePeers { + if s == peer.ID { + continue + } + + missing := true + for _, p := range slices.Concat(peersToConnect, expiredPeers) { + if s == p.ID { + missing = false + break + } + } + if missing { + p, ok := a.Peers[s] + if ok { + missingPeers[p.Key] = struct{}{} + } + } + } + } + + for p := range missingPeers { + for _, p2 := range a.Peers { + if p2.Key == p { + peersToConnect = append(peersToConnect, p2) + break + } + } + } + return peersToConnect +} + +func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { + groupList := account.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +// peerIsNameserver returns true if the peer is a nameserver for a nsGroup +func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peer.IP.Equal(ns.IP.AsSlice()) { + return true + } + } + return false +} + +func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { + for _, peer := range account.Peers { + label, err := GetPeerHostLabel(peer.Name, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) + label, err = GetPeerHostLabel(peer.Meta.Hostname, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) + continue + } + } + peer.DNSLabel = label + peerLabels[label] = struct{}{} + } +} + +func GetPeerHostLabel(name string, peerLabels LookupMap) (string, error) { + label, err := nbdns.GetParsedDomainLabel(name) + if err != nil { + return "", err + } + + uniqueLabel := getUniqueHostLabel(label, peerLabels) + if uniqueLabel == "" { + return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) + } + return uniqueLabel, nil +} + +// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 +func getUniqueHostLabel(name string, peerLabels LookupMap) string { + _, found := peerLabels[name] + if !found { + return name + } + for i := 1; i < 1000; i++ { + nameWithSuffix := name + "-" + strconv.Itoa(i) + _, found = peerLabels[nameWithSuffix] + if !found { + return nameWithSuffix + } + } + return "" +} + +func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { + var merr *multierror.Error + + if dnsDomain == "" { + log.WithContext(ctx).Error("no dns domain is set, returning empty zone") + return nbdns.CustomZone{} + } + + customZone := nbdns.CustomZone{ + Domain: dns.Fqdn(dnsDomain), + Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), + } + + domainSuffix := "." + dnsDomain + + var sb strings.Builder + for _, peer := range a.Peers { + if peer.DNSLabel == "" { + merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) + continue + } + + sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) + sb.WriteString(peer.DNSLabel) + sb.WriteString(domainSuffix) + + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: sb.String(), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IP.String(), + }) + + sb.Reset() + } + + go func() { + if merr != nil { + log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) + } + }() + + return customZone +} + +// GetExpiredPeers returns peers that have been expired +func (a *Account) GetExpiredPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.GetPeersWithExpiration() { + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers +} + +// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithExpiration() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetPeers returns a list of all Account peers +func (a *Account) GetPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.Peers { + peers = append(peers, peer) + } + return peers +} + +// UpdateSettings saves new account settings +func (a *Account) UpdateSettings(update *Settings) *Account { + a.Settings = update.Copy() + return a +} + +// UpdatePeer saves new or replaces existing peer +func (a *Account) UpdatePeer(update *nbpeer.Peer) { + a.Peers[update.ID] = update +} + +// DeletePeer deletes peer from the account cleaning up all the references +func (a *Account) DeletePeer(peerID string) { + // delete peer from groups + for _, g := range a.Groups { + for i, pk := range g.Peers { + if pk == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + break + } + } + } + + for _, r := range a.Routes { + if r.Peer == peerID { + r.Enabled = false + r.Peer = "" + } + } + + for i, r := range a.NetworkRouters { + if r.Peer == peerID { + a.NetworkRouters = append(a.NetworkRouters[:i], a.NetworkRouters[i+1:]...) + break + } + } + + delete(a.Peers, peerID) + a.Network.IncSerial() +} + +func (a *Account) DeleteResource(resourceID string) { + // delete resource from groups + for _, g := range a.Groups { + for i, pk := range g.Resources { + if pk.ID == resourceID { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + break + } + } + } +} + +// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. +// It will return an object copy of the peer. +func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { + for _, peer := range a.Peers { + if peer.Key == peerPubKey { + return peer.Copy(), nil + } + } + + return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) +} + +// FindUserPeers returns a list of peers that user owns (created) +func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.UserID == userID { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// FindUser looks for a given user in the Account or returns error if user wasn't found. +func (a *Account) FindUser(userID string) (*User, error) { + user := a.Users[userID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user %s not found", userID) + } + + return user, nil +} + +// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. +func (a *Account) FindGroupByName(groupName string) (*Group, error) { + for _, group := range a.Groups { + if group.Name == groupName { + return group, nil + } + } + return nil, status.Errorf(status.NotFound, "group %s not found", groupName) +} + +// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. +func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { + key := a.SetupKeys[setupKey] + if key == nil { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return key, nil +} + +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + +func (a *Account) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := a.GetPeerGroups(peerID) + enabled := true + for _, groupID := range a.DNSSettings.DisabledManagementGroups { + _, found := peerGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (a *Account) GetPeerGroups(peerID string) LookupMap { + groupList := make(LookupMap) + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + groupList[groupID] = struct{}{} + break + } + } + } + return groupList +} + +func (a *Account) GetTakenIPs() []net.IP { + var takenIps []net.IP + for _, existingPeer := range a.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps +} + +func (a *Account) GetPeerDNSLabels() LookupMap { + existingLabels := make(LookupMap) + for _, peer := range a.Peers { + if peer.DNSLabel != "" { + existingLabels[peer.DNSLabel] = struct{}{} + } + } + return existingLabels +} + +func (a *Account) Copy() *Account { + peers := map[string]*nbpeer.Peer{} + for id, peer := range a.Peers { + peers[id] = peer.Copy() + } + + users := map[string]*User{} + for id, user := range a.Users { + users[id] = user.Copy() + } + + setupKeys := map[string]*SetupKey{} + for id, key := range a.SetupKeys { + setupKeys[id] = key.Copy() + } + + groups := map[string]*Group{} + for id, group := range a.Groups { + groups[id] = group.Copy() + } + + policies := []*Policy{} + for _, policy := range a.Policies { + policies = append(policies, policy.Copy()) + } + + routes := map[route.ID]*route.Route{} + for id, r := range a.Routes { + routes[id] = r.Copy() + } + + nsGroups := map[string]*nbdns.NameServerGroup{} + for id, nsGroup := range a.NameServerGroups { + nsGroups[id] = nsGroup.Copy() + } + + dnsSettings := a.DNSSettings.Copy() + + var settings *Settings + if a.Settings != nil { + settings = a.Settings.Copy() + } + + postureChecks := []*posture.Checks{} + for _, postureCheck := range a.PostureChecks { + postureChecks = append(postureChecks, postureCheck.Copy()) + } + + nets := []*networkTypes.Network{} + for _, network := range a.Networks { + nets = append(nets, network.Copy()) + } + + networkRouters := []*routerTypes.NetworkRouter{} + for _, router := range a.NetworkRouters { + networkRouters = append(networkRouters, router.Copy()) + } + + networkResources := []*resourceTypes.NetworkResource{} + for _, resource := range a.NetworkResources { + networkResources = append(networkResources, resource.Copy()) + } + + return &Account{ + Id: a.Id, + CreatedBy: a.CreatedBy, + CreatedAt: a.CreatedAt, + Domain: a.Domain, + DomainCategory: a.DomainCategory, + IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, + SetupKeys: setupKeys, + Network: a.Network.Copy(), + Peers: peers, + Users: users, + Groups: groups, + Policies: policies, + Routes: routes, + NameServerGroups: nsGroups, + DNSSettings: dnsSettings, + PostureChecks: postureChecks, + Settings: settings, + Networks: nets, + NetworkRouters: networkRouters, + NetworkResources: networkResources, + } +} + +func (a *Account) GetGroupAll() (*Group, error) { + for _, g := range a.Groups { + if g.Name == "All" { + return g, nil + } + } + return nil, fmt.Errorf("no group ALL found") +} + +// GetPeer looks up a Peer by ID +func (a *Account) GetPeer(peerID string) *nbpeer.Peer { + return a.Peers[peerID] +} + +// UserGroupsAddToPeers adds groups to all peers of user +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + userPeers := make(map[string]struct{}) + for pid, peer := range a.Peers { + if peer.UserID == userID { + userPeers[pid] = struct{}{} + } + } + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok { + continue + } + + oldPeers := group.Peers + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeers { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupUpdates[gid] = util.Difference(group.Peers, oldPeers) + } + + return groupUpdates +} + +// UserGroupsRemoveFromPeers removes groups from all peers of user +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok || group.Name == "All" { + continue + } + + oldPeers := group.Peers + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok { + continue + } + if peer.UserID != userID { + update = append(update, pid) + } + } + group.Peers = update + groupUpdates[gid] = util.Difference(oldPeers, group.Peers) + } + + return groupUpdates +} + +// GetPeerConnectionResources for a given peer +// +// This function returns the list of peers and firewall rules that are applicable to a given peer. +func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + } + } + + return getAccumulatedResources() +} + +// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls +// +// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. +// It safe to call the generator function multiple times for same peer and different rules no duplicates will be +// generated. The accumulator function returns the result of all the generator calls. +func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + all, err := a.GetGroupAll() + if err != nil { + log.WithContext(ctx).Errorf("failed to get group all: %v", err) + all = &Group{} + } + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + isAll := (len(all.Peers) - 1) == len(groupPeers) + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = "0.0.0.0" + } + + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 { + rules = append(rules, &fr) + continue + } + + for _, port := range rule.Ports { + pr := fr // clone rule and add set new port + pr.Port = port + rules = append(rules, &pr) + } + } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +// getAllPeersFromGroups for given peer ID and list of groups +// +// Returns a list of peers from specified groups that pass specified posture checks +// and a boolean indicating if the supplied peer ID exists within these groups. +// +// Important: Posture checks are applicable only to source group peers, +// for destination group peers, call this method with an empty list of sourcePostureChecksIDs +func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { + peerInGroups := false + filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) + for _, g := range groups { + group, ok := a.Groups[g] + if !ok { + continue + } + + for _, p := range group.Peers { + peer, ok := a.Peers[p] + if !ok || peer == nil { + continue + } + + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) + } + } + return filteredPeers, peerInGroups +} + +// validatePostureChecksOnPeer validates the posture checks on a peer +func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { + peer, ok := a.Peers[peerID] + if !ok && peer == nil { + return false + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := a.GetPostureChecks(postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.GetChecks() { + isValid, err := check.Check(ctx, *peer) + if err != nil { + log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) + } + if !isValid { + return false + } + } + } + return true +} + +func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { + for _, postureChecks := range a.PostureChecks { + if postureChecks.ID == postureChecksID { + return postureChecks + } + } + return nil +} + +// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. +func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := a.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: route.Domains, + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account. +func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + for _, route := range routes { + if route.Peer != peer.Key { + continue + } + resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] + distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) + + rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + + return routesFirewallRules +} + +// getNetworkResourceGroups retrieves all groups associated with the given network resource. +func (a *Account) getNetworkResourceGroups(resourceID string) []*Group { + var networkResourceGroups []*Group + + for _, group := range a.Groups { + for _, resource := range group.Resources { + if resource.ID == resourceID { + networkResourceGroups = append(networkResourceGroups, group) + } + } + } + + return networkResourceGroups +} + +// GetResourcePoliciesMap returns a map of networks resource IDs and their associated policies. +func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { + resourcePolicies := make(map[string][]*Policy) + for _, resource := range a.NetworkResources { + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + resourcePolicies[resource.ID] = resourceAppliedPolicies + } + return resourcePolicies +} + +// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { + var isRoutingPeer bool + var routes []*route.Route + var allSourcePeers []string + + for _, resource := range a.NetworkResources { + var addSourcePeers bool + + networkRoutingPeers, exists := routers[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...) + } + } + + for _, policy := range resourcePolicies[resource.ID] { + for _, sourceGroup := range policy.SourceGroups() { + group := a.GetGroup(sourceGroup) + if group == nil { + log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) + continue + } + + // routing peer should be able to connect with all source peers + if addSourcePeers { + allSourcePeers = append(allSourcePeers, group.Peers...) + } + + // add routes for the resource if the peer is in the distribution group + if slices.Contains(group.Peers, peerID) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) + } + } + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +// getNetworkResources filters and returns a list of network resources associated with the given network ID. +func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource { + var resources []*resourceTypes.NetworkResource + for _, resource := range a.NetworkResources { + if resource.NetworkID == networkID { + resources = append(resources, resource) + } + } + return resources +} + +// GetPoliciesForNetworkResource retrieves the list of policies that apply to a specific network resource. +// A policy is deemed applicable if its destination groups include any of the given network resource groups +// or if its destination resource explicitly matches the provided resource. +func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { + var resourceAppliedPolicies []*Policy + + networkResourceGroups := a.getNetworkResourceGroups(resourceId) + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + if rule.DestinationResource.ID == resourceId { + resourceAppliedPolicies = append(resourceAppliedPolicies, policy) + break + } + + for _, group := range networkResourceGroups { + if slices.Contains(rule.Destinations, group.ID) { + resourceAppliedPolicies = append(resourceAppliedPolicies, policy) + break + } + } + } + } + + return resourceAppliedPolicies +} + +func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { + networkResources := a.getNetworkResources(networkID) + + policiesIDs := map[string]struct{}{} + for _, resource := range networkResources { + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + for _, policy := range resourceAppliedPolicies { + policiesIDs[policy.ID] = struct{}{} + } + } + + result := make([]string, 0, len(policiesIDs)) + for id := range policiesIDs { + result = append(result, id) + } + + return result +} + +// getNetworkResourcesRoutes convert the network resources list to routes list. +func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route { + resourceAppliedPolicies := resourcePolicies[resource.ID] + + var routes []*route.Route + // distribute the resource routes only if there is policy applied to it + if len(resourceAppliedPolicies) > 0 { + peer := a.GetPeer(peerId) + if peer != nil { + routes = append(routes, resource.ToRoute(peer, router)) + } + } + + return routes +} + +func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { + routers := make(map[string]map[string]*routerTypes.NetworkRouter) + + for _, router := range a.NetworkRouters { + if routers[router.NetworkID] == nil { + routers[router.NetworkID] = make(map[string]*routerTypes.NetworkRouter) + } + + if router.Peer != "" { + routers[router.NetworkID][router.Peer] = router + continue + } + + for _, peerGroup := range router.PeerGroups { + g := a.Groups[peerGroup] + if g != nil { + for _, peerID := range g.Peers { + routers[router.NetworkID][peerID] = router + } + } + } + } + + return routers +} + +// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies. +func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := groups[sourceGroup] + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + } + } + + return sourcePeers +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go new file mode 100644 index 00000000000..c73421d1612 --- /dev/null +++ b/management/server/types/account_test.go @@ -0,0 +1,375 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/require" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" +) + +func setupTestAccount() *Account { + return &Account{ + Id: "accountID", + Peers: map[string]*nbpeer.Peer{ + "peer1": { + ID: "peer1", + AccountID: "accountID", + Key: "peer1Key", + }, + "peer2": { + ID: "peer2", + AccountID: "accountID", + Key: "peer2Key", + }, + "peer3": { + ID: "peer3", + AccountID: "accountID", + Key: "peer3Key", + }, + "peer11": { + ID: "peer11", + AccountID: "accountID", + Key: "peer11Key", + }, + "peer12": { + ID: "peer12", + AccountID: "accountID", + Key: "peer12Key", + }, + "peer21": { + ID: "peer21", + AccountID: "accountID", + Key: "peer21Key", + }, + "peer31": { + ID: "peer31", + AccountID: "accountID", + Key: "peer31Key", + }, + "peer32": { + ID: "peer32", + AccountID: "accountID", + Key: "peer32Key", + }, + "peer41": { + ID: "peer41", + AccountID: "accountID", + Key: "peer41Key", + }, + "peer51": { + ID: "peer51", + AccountID: "accountID", + Key: "peer51Key", + }, + "peer61": { + ID: "peer61", + AccountID: "accountID", + Key: "peer61Key", + }, + }, + Groups: map[string]*Group{ + "group1": { + ID: "group1", + Peers: []string{"peer11", "peer12"}, + Resources: []Resource{ + { + ID: "resource1ID", + Type: "Host", + }, + }, + }, + "group2": { + ID: "group2", + Peers: []string{"peer21"}, + Resources: []Resource{ + { + ID: "resource2ID", + Type: "Domain", + }, + }, + }, + "group3": { + ID: "group3", + Peers: []string{"peer31", "peer32"}, + Resources: []Resource{ + { + ID: "resource3ID", + Type: "Subnet", + }, + }, + }, + "group4": { + ID: "group4", + Peers: []string{"peer41"}, + Resources: []Resource{ + { + ID: "resource3ID", + Type: "Subnet", + }, + }, + }, + "group5": { + ID: "group5", + Peers: []string{"peer51"}, + }, + "group6": { + ID: "group6", + Peers: []string{"peer61"}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: "network1ID", + AccountID: "accountID", + Name: "network1", + }, + { + ID: "network2ID", + AccountID: "accountID", + Name: "network2", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "peer1", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router2ID", + NetworkID: "network2ID", + AccountID: "accountID", + Peer: "peer2", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router3ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "peer3", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router4ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router5ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group2", "group3"}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router6ID", + NetworkID: "network2ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group4"}, + Masquerade: false, + Metric: 100, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1ID", + AccountID: "accountID", + NetworkID: "network1ID", + }, + { + ID: "resource2ID", + AccountID: "accountID", + NetworkID: "network2ID", + }, + { + ID: "resource3ID", + AccountID: "accountID", + NetworkID: "network1ID", + }, + { + ID: "resource4ID", + AccountID: "accountID", + NetworkID: "network1ID", + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Destinations: []string{"group1"}, + }, + }, + }, + { + ID: "policy2ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule2ID", + Enabled: true, + Destinations: []string{"group3"}, + }, + }, + }, + { + ID: "policy3ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule3ID", + Enabled: true, + Destinations: []string{"group2", "group4"}, + }, + }, + }, + { + ID: "policy4ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule4ID", + Enabled: true, + DestinationResource: Resource{ + ID: "resource4ID", + Type: "Host", + }, + }, + }, + }, + { + ID: "policy5ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule5ID", + Enabled: true, + }, + }, + }, + }, + } +} + +func Test_GetResourceRoutersMap(t *testing.T) { + account := setupTestAccount() + routers := account.GetResourceRoutersMap() + require.Equal(t, 2, len(routers)) + + require.Equal(t, 7, len(routers["network1ID"])) + require.NotNil(t, routers["network1ID"]["peer1"]) + require.NotNil(t, routers["network1ID"]["peer3"]) + require.NotNil(t, routers["network1ID"]["peer11"]) + require.NotNil(t, routers["network1ID"]["peer12"]) + require.NotNil(t, routers["network1ID"]["peer21"]) + require.NotNil(t, routers["network1ID"]["peer31"]) + require.NotNil(t, routers["network1ID"]["peer32"]) + + require.Equal(t, 2, len(routers["network2ID"])) + require.NotNil(t, routers["network2ID"]["peer2"]) + require.NotNil(t, routers["network2ID"]["peer41"]) +} + +func Test_GetResourcePoliciesMap(t *testing.T) { + account := setupTestAccount() + policies := account.GetResourcePoliciesMap() + require.Equal(t, 4, len(policies)) + require.Equal(t, 1, len(policies["resource1ID"])) + require.Equal(t, 1, len(policies["resource2ID"])) + require.Equal(t, 2, len(policies["resource3ID"])) + require.Equal(t, 1, len(policies["resource4ID"])) +} + +func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + {Peer: "peer3Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key"}, + } + expiredPeers := []*nbpeer.Peer{ + {Key: "peer4Key"}, + } + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 2) + require.Equal(t, "peer2Key", result[0].Key) + require.Equal(t, "peer3Key", result[1].Key) +} + +func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key"}, + } + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + {Peer: "peer3Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key"}, + } + expiredPeers := []*nbpeer.Peer{ + {Key: "peer3Key"}, + } + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1"} + networkResourcesRoutes := []*route.Route{} + peersToConnect := []*nbpeer.Peer{} + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 0) +} diff --git a/management/server/types/dns_settings.go b/management/server/types/dns_settings.go new file mode 100644 index 00000000000..1d33bb9fbb2 --- /dev/null +++ b/management/server/types/dns_settings.go @@ -0,0 +1,16 @@ +package types + +// DNSSettings defines dns settings at the account level +type DNSSettings struct { + // DisabledManagementGroups groups whose DNS management is disabled + DisabledManagementGroups []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the DNS settings +func (d DNSSettings) Copy() DNSSettings { + settings := DNSSettings{ + DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), + } + copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) + return settings +} diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go new file mode 100644 index 00000000000..3d1b7e225ec --- /dev/null +++ b/management/server/types/firewall_rule.go @@ -0,0 +1,130 @@ +package types + +import ( + "context" + "fmt" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" +) + +const ( + FirewallRuleDirectionIN = 0 + FirewallRuleDirectionOUT = 1 +) + +// FirewallRule is a rule of the firewall. +type FirewallRule struct { + // PeerIP of the peer + PeerIP string + + // Direction of the traffic + Direction int + + // Action of the traffic + Action string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port string +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. +func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + Domains: route.Domains, + IsDynamic: route.IsDynamic(), + } + + // generate rule for port range + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRulesForPeer generates rules for a given peer based on ports and port ranges. +func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. +func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} diff --git a/management/server/group/group.go b/management/server/types/group.go similarity index 58% rename from management/server/group/group.go rename to management/server/types/group.go index 24c60d3ceef..00a28fa7702 100644 --- a/management/server/group/group.go +++ b/management/server/types/group.go @@ -1,6 +1,9 @@ -package group +package types -import "github.com/netbirdio/netbird/management/server/integration_reference" +import ( + "github.com/netbirdio/netbird/management/server/integration_reference" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) const ( GroupIssuedAPI = "api" @@ -25,6 +28,9 @@ type Group struct { // Peers list of the group Peers []string `gorm:"serializer:json"` + // Resources contains a list of resources in that group + Resources []Resource `gorm:"serializer:json"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } @@ -33,15 +39,21 @@ func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} } +func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]any { + return map[string]any{"name": g.Name, "id": g.ID, "resource_name": resource.Name, "resource_id": resource.ID, "resource_type": resource.Type} +} + func (g *Group) Copy() *Group { group := &Group{ ID: g.ID, Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.Resources, g.Resources) return group } @@ -81,3 +93,31 @@ func (g *Group) RemovePeer(peerID string) bool { } return false } + +// AddResource adds resource to Resources if not present, returning true if added. +func (g *Group) AddResource(resource Resource) bool { + for _, item := range g.Resources { + if item == resource { + return false + } + } + + g.Resources = append(g.Resources, resource) + return true +} + +// RemoveResource removes resource from Resources if present, returning true if removed. +func (g *Group) RemoveResource(resource Resource) bool { + for i, item := range g.Resources { + if item == resource { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + return true + } + } + return false +} + +// HasResources checks if the group has any resources. +func (g *Group) HasResources() bool { + return len(g.Resources) > 0 +} diff --git a/management/server/group/group_test.go b/management/server/types/group_test.go similarity index 99% rename from management/server/group/group_test.go rename to management/server/types/group_test.go index cb002f8d9e1..12107c6030b 100644 --- a/management/server/group/group_test.go +++ b/management/server/types/group_test.go @@ -1,4 +1,4 @@ -package group +package types import ( "testing" diff --git a/management/server/network.go b/management/server/types/network.go similarity index 96% rename from management/server/network.go rename to management/server/types/network.go index a5b188b4610..d1fccd14906 100644 --- a/management/server/network.go +++ b/management/server/types/network.go @@ -1,4 +1,4 @@ -package server +package types import ( "math/rand" @@ -43,7 +43,7 @@ type Network struct { // Used to synchronize state to the client apps. Serial uint64 - mu sync.Mutex `json:"-" gorm:"-"` + Mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 @@ -66,15 +66,15 @@ func NewNetwork() *Network { // IncSerial increments Serial by 1 reflecting that the network state has been changed func (n *Network) IncSerial() { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() n.Serial++ } // CurrentSerial returns the Network.Serial of the network (latest state id) func (n *Network) CurrentSerial() uint64 { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() return n.Serial } diff --git a/management/server/network_test.go b/management/server/types/network_test.go similarity index 98% rename from management/server/network_test.go rename to management/server/types/network_test.go index b067c4991dc..d0b0894d42f 100644 --- a/management/server/network_test.go +++ b/management/server/types/network_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "net" diff --git a/management/server/personal_access_token.go b/management/server/types/personal_access_token.go similarity index 99% rename from management/server/personal_access_token.go rename to management/server/types/personal_access_token.go index f466661120f..1bf22585684 100644 --- a/management/server/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/personal_access_token_test.go b/management/server/types/personal_access_token_test.go similarity index 98% rename from management/server/personal_access_token_test.go rename to management/server/types/personal_access_token_test.go index 311ffd9cf05..ac337715178 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/types/personal_access_token_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/types/policy.go b/management/server/types/policy.go new file mode 100644 index 00000000000..c2b82d68a11 --- /dev/null +++ b/management/server/types/policy.go @@ -0,0 +1,125 @@ +package types + +const ( + // PolicyTrafficActionAccept indicates that the traffic is accepted + PolicyTrafficActionAccept = PolicyTrafficActionType("accept") + // PolicyTrafficActionDrop indicates that the traffic is dropped + PolicyTrafficActionDrop = PolicyTrafficActionType("drop") +) + +const ( + // PolicyRuleProtocolALL type of traffic + PolicyRuleProtocolALL = PolicyRuleProtocolType("all") + // PolicyRuleProtocolTCP type of traffic + PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") + // PolicyRuleProtocolUDP type of traffic + PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") + // PolicyRuleProtocolICMP type of traffic + PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") +) + +const ( + // PolicyRuleFlowDirect allows traffic from source to destination + PolicyRuleFlowDirect = PolicyRuleDirection("direct") + // PolicyRuleFlowBidirect allows traffic to both directions + PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") +) + +const ( + // DefaultRuleName is a name for the Default rule that is created for every account + DefaultRuleName = "Default" + // DefaultRuleDescription is a description for the Default rule that is created for every account + DefaultRuleDescription = "This is a default rule that allows connections between all the resources" + // DefaultPolicyName is a name for the Default policy that is created for every account + DefaultPolicyName = "Default" + // DefaultPolicyDescription is a description for the Default policy that is created for every account + DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" +) + +// PolicyUpdateOperation operation object with type and values to be applied +type PolicyUpdateOperation struct { + Type PolicyUpdateOperationType + Values []string +} + +// Policy of the Rego query +type Policy struct { + // ID of the policy' + ID string `gorm:"primaryKey"` + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name of the Policy + Name string + + // Description of the policy visible in the UI + Description string + + // Enabled status of the policy + Enabled bool + + // Rules of the policy + Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` + + // SourcePostureChecks are ID references to Posture checks for policy source groups + SourcePostureChecks []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the policy. +func (p *Policy) Copy() *Policy { + c := &Policy{ + ID: p.ID, + AccountID: p.AccountID, + Name: p.Name, + Description: p.Description, + Enabled: p.Enabled, + Rules: make([]*PolicyRule, len(p.Rules)), + SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), + } + for i, r := range p.Rules { + c.Rules[i] = r.Copy() + } + copy(c.SourcePostureChecks, p.SourcePostureChecks) + return c +} + +// EventMeta returns activity event meta related to this policy +func (p *Policy) EventMeta() map[string]any { + return map[string]any{"name": p.Name} +} + +// UpgradeAndFix different version of policies to latest version +func (p *Policy) UpgradeAndFix() { + for _, r := range p.Rules { + // start migrate from version v0.20.3 + if r.Protocol == "" { + r.Protocol = PolicyRuleProtocolALL + } + if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { + r.Bidirectional = true + } + // -- v0.20.4 + } +} + +// RuleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) RuleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + + return groups +} + +// SourceGroups returns a slice of all unique source groups referenced in the policy's rules. +func (p *Policy) SourceGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + } + return groups +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go new file mode 100644 index 00000000000..bd9a9929202 --- /dev/null +++ b/management/server/types/policyrule.go @@ -0,0 +1,87 @@ +package types + +// PolicyUpdateOperationType operation type +type PolicyUpdateOperationType int + +// PolicyTrafficActionType action type for the firewall +type PolicyTrafficActionType string + +// PolicyRuleProtocolType type of traffic +type PolicyRuleProtocolType string + +// PolicyRuleDirection direction of traffic +type PolicyRuleDirection string + +// RulePortRange represents a range of ports for a firewall rule. +type RulePortRange struct { + Start uint16 + End uint16 +} + +// PolicyRule is the metadata of the policy +type PolicyRule struct { + // ID of the policy rule + ID string `gorm:"primaryKey"` + + // PolicyID is a reference to Policy that this object belongs + PolicyID string `json:"-" gorm:"index"` + + // Name of the rule visible in the UI + Name string + + // Description of the rule visible in the UI + Description string + + // Enabled status of rule in the system + Enabled bool + + // Action policy accept or drops packets + Action PolicyTrafficActionType + + // Destinations policy destination groups + Destinations []string `gorm:"serializer:json"` + + // DestinationResource policy destination resource that the rule is applied to + DestinationResource Resource `gorm:"serializer:json"` + + // Sources policy source groups + Sources []string `gorm:"serializer:json"` + + // SourceResource policy source resource that the rule is applied to + SourceResource Resource `gorm:"serializer:json"` + + // Bidirectional define if the rule is applicable in both directions, sources, and destinations + Bidirectional bool + + // Protocol type of the traffic + Protocol PolicyRuleProtocolType + + // Ports or it ranges list + Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. + PortRanges []RulePortRange `gorm:"serializer:json"` +} + +// Copy returns a copy of a policy rule +func (pm *PolicyRule) Copy() *PolicyRule { + rule := &PolicyRule{ + ID: pm.ID, + PolicyID: pm.PolicyID, + Name: pm.Name, + Description: pm.Description, + Enabled: pm.Enabled, + Action: pm.Action, + Destinations: make([]string, len(pm.Destinations)), + Sources: make([]string, len(pm.Sources)), + Bidirectional: pm.Bidirectional, + Protocol: pm.Protocol, + Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), + } + copy(rule.Destinations, pm.Destinations) + copy(rule.Sources, pm.Sources) + copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) + return rule +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go new file mode 100644 index 00000000000..820872f2004 --- /dev/null +++ b/management/server/types/resource.go @@ -0,0 +1,30 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Resource struct { + ID string + Type string +} + +func (r *Resource) ToAPIResponse() *api.Resource { + if r.ID == "" && r.Type == "" { + return nil + } + + return &api.Resource{ + Id: r.ID, + Type: api.ResourceType(r.Type), + } +} + +func (r *Resource) FromAPIRequest(req *api.Resource) { + if req == nil { + return + } + + r.ID = req.Id + r.Type = string(req.Type) +} diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go new file mode 100644 index 00000000000..64708d68ad6 --- /dev/null +++ b/management/server/types/route_firewall_rule.go @@ -0,0 +1,32 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/domain" +) + +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. + SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // Domains list of network domains for the routed traffic + Domains domain.List + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} diff --git a/management/server/types/settings.go b/management/server/types/settings.go new file mode 100644 index 00000000000..0ce5a61333b --- /dev/null +++ b/management/server/types/settings.go @@ -0,0 +1,68 @@ +package types + +import ( + "time" + + "github.com/netbirdio/netbird/management/server/account" +) + +// Settings represents Account settings structure that can be modified via API and Dashboard +type Settings struct { + // PeerLoginExpirationEnabled globally enables or disables peer login expiration + PeerLoginExpirationEnabled bool + + // PeerLoginExpiration is a setting that indicates when peer login expires. + // Applies to all peers that have Peer.LoginExpirationEnabled set to true. + PeerLoginExpiration time.Duration + + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. + PeerInactivityExpiration time.Duration + + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements + RegularUsersViewBlocked bool + + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer + GroupsPropagationEnabled bool + + // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName + // and add it to account groups. + JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string + + // JWTAllowGroups list of groups to which users are allowed access + JWTAllowGroups []string `gorm:"serializer:json"` + + // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers + RoutingPeerDNSResolutionEnabled bool + + // Extra is a dictionary of Account settings + Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` +} + +// Copy copies the Settings struct +func (s *Settings) Copy() *Settings { + settings := &Settings{ + PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, + PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, + GroupsPropagationEnabled: s.GroupsPropagationEnabled, + JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, + + RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + } + if s.Extra != nil { + settings.Extra = s.Extra.Copy() + } + return settings +} diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go new file mode 100644 index 00000000000..a5cf346a06a --- /dev/null +++ b/management/server/types/setupkey.go @@ -0,0 +1,181 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "hash/fnv" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" +) + +const ( + // SetupKeyReusable is a multi-use key (can be used for multiple machines) + SetupKeyReusable SetupKeyType = "reusable" + // SetupKeyOneOff is a single use key (can be used only once) + SetupKeyOneOff SetupKeyType = "one-off" + // DefaultSetupKeyDuration = 1 month + DefaultSetupKeyDuration = 24 * 30 * time.Hour + // DefaultSetupKeyName is a default name of the default setup key + DefaultSetupKeyName = "Default key" + // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key + SetupKeyUnlimitedUsage = 0 +) + +// SetupKeyType is the type of setup key +type SetupKeyType string + +// SetupKey represents a pre-authorized key used to register machines (peers) +type SetupKey struct { + Id string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Key string + KeySecret string + Name string + Type SetupKeyType + CreatedAt time.Time + ExpiresAt time.Time + UpdatedAt time.Time `gorm:"autoUpdateTime:false"` + // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) + Revoked bool + // UsedTimes indicates how many times the key was used + UsedTimes int + // LastUsed last time the key was used for peer registration + LastUsed time.Time + // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register + AutoGroups []string `gorm:"serializer:json"` + // UsageLimit indicates the number of times this key can be used to enroll a machine. + // The value of 0 indicates the unlimited usage. + UsageLimit int + // Ephemeral indicate if the peers will be ephemeral or not + Ephemeral bool +} + +// Copy copies SetupKey to a new object +func (key *SetupKey) Copy() *SetupKey { + autoGroups := make([]string, len(key.AutoGroups)) + copy(autoGroups, key.AutoGroups) + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + return &SetupKey{ + Id: key.Id, + AccountID: key.AccountID, + Key: key.Key, + KeySecret: key.KeySecret, + Name: key.Name, + Type: key.Type, + CreatedAt: key.CreatedAt, + ExpiresAt: key.ExpiresAt, + UpdatedAt: key.UpdatedAt, + Revoked: key.Revoked, + UsedTimes: key.UsedTimes, + LastUsed: key.LastUsed, + AutoGroups: autoGroups, + UsageLimit: key.UsageLimit, + Ephemeral: key.Ephemeral, + } +} + +// EventMeta returns activity event meta related to the setup key +func (key *SetupKey) EventMeta() map[string]any { + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} +} + +// HiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func HiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} + +// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now +func (key *SetupKey) IncrementUsage() *SetupKey { + c := key.Copy() + c.UsedTimes++ + c.LastUsed = time.Now().UTC() + return c +} + +// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to +func (key *SetupKey) IsValid() bool { + return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() +} + +// IsRevoked if key was revoked +func (key *SetupKey) IsRevoked() bool { + return key.Revoked +} + +// IsExpired if key was expired +func (key *SetupKey) IsExpired() bool { + if key.ExpiresAt.IsZero() { + return false + } + return time.Now().After(key.ExpiresAt) +} + +// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. +func (key *SetupKey) IsOverUsed() bool { + limit := key.UsageLimit + if key.Type == SetupKeyOneOff { + limit = 1 + } + return limit > 0 && key.UsedTimes >= limit +} + +// GenerateSetupKey generates a new setup key +func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, + usageLimit int, ephemeral bool) (*SetupKey, string) { + key := strings.ToUpper(uuid.New().String()) + limit := usageLimit + if t == SetupKeyOneOff { + limit = 1 + } + + expiresAt := time.Time{} + if validFor != 0 { + expiresAt = time.Now().UTC().Add(validFor) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + return &SetupKey{ + Id: strconv.Itoa(int(Hash(key))), + Key: encodedHashedKey, + KeySecret: HiddenKey(key, 4), + Name: name, + Type: t, + CreatedAt: time.Now().UTC(), + ExpiresAt: expiresAt, + UpdatedAt: time.Now().UTC(), + Revoked: false, + UsedTimes: 0, + AutoGroups: autoGroups, + UsageLimit: limit, + Ephemeral: ephemeral, + }, key +} + +// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration +func GenerateDefaultSetupKey() (*SetupKey, string) { + return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, + SetupKeyUnlimitedUsage, false) +} + +func Hash(s string) uint32 { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + panic(err) + } + return h.Sum32() +} diff --git a/management/server/types/user.go b/management/server/types/user.go new file mode 100644 index 00000000000..5f1b717922d --- /dev/null +++ b/management/server/types/user.go @@ -0,0 +1,231 @@ +package types + +import ( + "fmt" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" +) + +const ( + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" + + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" + + UserIssuedAPI = "api" + UserIssuedIntegration = "integration" +) + +// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown +func StrRoleToUserRole(strRole string) UserRole { + switch strings.ToLower(strRole) { + case "owner": + return UserRoleOwner + case "admin": + return UserRoleAdmin + case "user": + return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin + default: + return UserRoleUnknown + } +} + +// UserStatus is the status of a User +type UserStatus string + +// UserRole is the role of a User +type UserRole string + +type UserInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` +} + +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + +// User represents a user of the system +type User struct { + Id string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Role UserRole + IsServiceUser bool + // NonDeletable indicates whether the service user can be deleted + NonDeletable bool + // ServiceUserName is only set if IsServiceUser is true + ServiceUserName string + // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user + AutoGroups []string `gorm:"serializer:json"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin time.Time + // CreatedAt records the time the user was created + CreatedAt time.Time + + // Issued of the user + Issued string `gorm:"default:api"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// IsBlocked returns true if the user is blocked, false otherwise +func (u *User) IsBlocked() bool { + return u.Blocked +} + +func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { + return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() +} + +// HasAdminPower returns true if the user has admin or owner roles, false otherwise +func (u *User) HasAdminPower() bool { + return u.Role == UserRoleAdmin || u.Role == UserRoleOwner +} + +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + +// ToUserInfo converts a User object to a UserInfo object. +func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { + autoGroups := u.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + + if userData == nil { + return &UserInfo{ + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil + } + if userData.ID != u.Id { + return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) + } + + userStatus := UserStatusActive + if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { + userStatus = UserStatusInvited + } + + return &UserInfo{ + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil +} + +// Copy the user +func (u *User) Copy() *User { + autoGroups := make([]string, len(u.AutoGroups)) + copy(autoGroups, u.AutoGroups) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) + for k, v := range u.PATs { + pats[k] = v.Copy() + } + return &User{ + Id: u.Id, + AccountID: u.AccountID, + Role: u.Role, + AutoGroups: autoGroups, + IsServiceUser: u.IsServiceUser, + NonDeletable: u.NonDeletable, + ServiceUserName: u.ServiceUserName, + PATs: pats, + Blocked: u.Blocked, + LastLogin: u.LastLogin, + CreatedAt: u.CreatedAt, + Issued: u.Issued, + IntegrationReference: u.IntegrationReference, + } +} + +// NewUser creates a new user +func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { + return &User{ + Id: id, + Role: role, + IsServiceUser: isServiceUser, + NonDeletable: nonDeletable, + ServiceUserName: serviceUserName, + AutoGroups: autoGroups, + Issued: issued, + CreatedAt: time.Now().UTC(), + } +} + +// NewRegularUser creates a new user with role UserRoleUser +func NewRegularUser(id string) *User { + return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +} + +// NewAdminUser creates a new user with role UserRoleAdmin +func NewAdminUser(id string) *User { + return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) +} + +// NewOwnerUser creates a new user with role UserRoleOwner +func NewOwnerUser(id string) *User { + return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index d338b84b1bb..de7dd57df7f 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -9,13 +9,14 @@ import ( "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const channelBufferSize = 100 type UpdateMessage struct { Update *proto.SyncResponse - NetworkMap *NetworkMap + NetworkMap *types.NetworkMap } type PeersUpdateManager struct { diff --git a/management/server/user.go b/management/server/user.go index edb5e6fd374..457721917ac 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -13,217 +13,17 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" - UserRoleBillingAdmin UserRole = "billing_admin" - - UserStatusActive UserStatus = "active" - UserStatusDisabled UserStatus = "disabled" - UserStatusInvited UserStatus = "invited" - - UserIssuedAPI = "api" - UserIssuedIntegration = "integration" -) - -// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown -func StrRoleToUserRole(strRole string) UserRole { - switch strings.ToLower(strRole) { - case "owner": - return UserRoleOwner - case "admin": - return UserRoleAdmin - case "user": - return UserRoleUser - case "billing_admin": - return UserRoleBillingAdmin - default: - return UserRoleUnknown - } -} - -// UserStatus is the status of a User -type UserStatus string - -// UserRole is the role of a User -type UserRole string - -// User represents a user of the system -type User struct { - Id string `gorm:"primaryKey"` - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Role UserRole - IsServiceUser bool - // NonDeletable indicates whether the service user can be deleted - NonDeletable bool - // ServiceUserName is only set if IsServiceUser is true - ServiceUserName string - // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user - AutoGroups []string `gorm:"serializer:json"` - PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` - // Blocked indicates whether the user is blocked. Blocked users can't use the system. - Blocked bool - // LastLogin is the last time the user logged in to IdP - LastLogin time.Time - // CreatedAt records the time the user was created - CreatedAt time.Time - - // Issued of the user - Issued string `gorm:"default:api"` - - IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// IsBlocked returns true if the user is blocked, false otherwise -func (u *User) IsBlocked() bool { - return u.Blocked -} - -func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { - return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() -} - -// HasAdminPower returns true if the user has admin or owner roles, false otherwise -func (u *User) HasAdminPower() bool { - return u.Role == UserRoleAdmin || u.Role == UserRoleOwner -} - -// IsAdminOrServiceUser checks if the user has admin power or is a service user. -func (u *User) IsAdminOrServiceUser() bool { - return u.HasAdminPower() || u.IsServiceUser -} - -// IsRegularUser checks if the user is a regular user. -func (u *User) IsRegularUser() bool { - return !u.HasAdminPower() && !u.IsServiceUser -} - -// ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { - autoGroups := u.AutoGroups - if autoGroups == nil { - autoGroups = []string{} - } - - dashboardViewPermissions := "full" - if !u.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - - if userData == nil { - return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil - } - if userData.ID != u.Id { - return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) - } - - userStatus := UserStatusActive - if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { - userStatus = UserStatusInvited - } - - return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil -} - -// Copy the user -func (u *User) Copy() *User { - autoGroups := make([]string, len(u.AutoGroups)) - copy(autoGroups, u.AutoGroups) - pats := make(map[string]*PersonalAccessToken, len(u.PATs)) - for k, v := range u.PATs { - pats[k] = v.Copy() - } - return &User{ - Id: u.Id, - AccountID: u.AccountID, - Role: u.Role, - AutoGroups: autoGroups, - IsServiceUser: u.IsServiceUser, - NonDeletable: u.NonDeletable, - ServiceUserName: u.ServiceUserName, - PATs: pats, - Blocked: u.Blocked, - LastLogin: u.LastLogin, - CreatedAt: u.CreatedAt, - Issued: u.Issued, - IntegrationReference: u.IntegrationReference, - } -} - -// NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { - return &User{ - Id: id, - Role: role, - IsServiceUser: isServiceUser, - NonDeletable: nonDeletable, - ServiceUserName: serviceUserName, - AutoGroups: autoGroups, - Issued: issued, - CreatedAt: time.Now().UTC(), - } -} - -// NewRegularUser creates a new user with role UserRoleUser -func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) -} - -// NewAdminUser creates a new user with role UserRoleAdmin -func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) -} - -// NewOwnerUser creates a new user with role UserRoleOwner -func NewOwnerUser(id string) *User { - return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) -} - // createServiceUser creates a new service user under the given account. -func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { +func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -240,12 +40,12 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users") } - if role == UserRoleOwner { + if role == types.UserRoleOwner { return nil, status.Errorf(status.InvalidArgument, "can't create a service user with owner role") } newUserID := uuid.New().String() - newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) + newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI) log.WithContext(ctx).Debugf("New User: %v", newUser) account.Users[newUserID] = newUser @@ -257,29 +57,29 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI meta := map[string]any{"name": newUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) - return &UserInfo{ + return &types.UserInfo{ ID: newUser.Id, Email: "", Name: newUser.ServiceUserName, Role: string(newUser.Role), AutoGroups: newUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: true, LastLogin: time.Time{}, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nil } // CreateUser creates a new user under the given account. Effectively this is a user invite. -func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *types.UserInfo) (*types.UserInfo, error) { if user.IsServiceUser { - return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) + return am.createServiceUser(ctx, accountID, userID, types.StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) } return am.inviteNewUser(ctx, accountID, userID, user) } // inviteNewUser Invites a USer to a given account and creates reference in datastore -func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -291,14 +91,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, fmt.Errorf("provided user update is nil") } - invitedRole := StrRoleToUserRole(invite.Role) + invitedRole := types.StrRoleToUserRole(invite.Role) switch { case invite.Name == "": return nil, status.Errorf(status.InvalidArgument, "name can't be empty") case invite.Email == "": return nil, status.Errorf(status.InvalidArgument, "email can't be empty") - case invitedRole == UserRoleOwner: + case invitedRole == types.UserRoleOwner: return nil, status.Errorf(status.InvalidArgument, "can't invite a user with owner role") default: } @@ -348,7 +148,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - newUser := &User{ + newUser := &types.User{ Id: idpUser.ID, Role: invitedRole, AutoGroups: invite.AutoGroups, @@ -373,19 +173,19 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } -func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { - return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { + return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) } // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. -func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { +func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -409,7 +209,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. -func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { +func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -418,7 +218,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string return nil, err } - users := make([]*User, 0, len(account.Users)) + users := make([]*types.User, 0, len(account.Users)) for _, item := range account.Users { users = append(users, item) } @@ -426,7 +226,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string return users, nil } -func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *types.Account, initiatorUserID string, targetUser *types.User) { meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) delete(account.Users, targetUser.Id) @@ -458,12 +258,12 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.Errorf(status.NotFound, "target user not found") } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "unable to delete a user with owner role") } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only integration service user can delete this user") } @@ -480,7 +280,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID) } -func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { +func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) error { meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err @@ -494,13 +294,13 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *types.Account) (bool, error) { peers, err := account.FindUserPeers(targetUserID) if err != nil { return false, status.Errorf(status.Internal, "failed to find user peers") @@ -560,7 +360,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { +func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -591,7 +391,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } - pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id) + pat, err := types.CreateNewPAT(tokenName, expiresIn, executingUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } @@ -660,13 +460,13 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string } // GetPAT returns a specific PAT from a user -func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -685,13 +485,13 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -700,7 +500,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) + pats := make([]*types.PersonalAccessToken, 0, len(targetUser.PATsG)) for _, pat := range targetUser.PATsG { pats = append(pats, pat.Copy()) } @@ -709,13 +509,13 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. -func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) { return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. -func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { if update == nil { return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } @@ -723,7 +523,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists) if err != nil { return nil, err } @@ -738,7 +538,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i // SaveOrAddUsers updates existing users or adds new users to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if len(updates) == 0 { return nil, nil //nolint:nilnil } @@ -757,7 +557,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") } - updatedUsers := make([]*UserInfo, 0, len(updates)) + updatedUsers := make([]*types.UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer userIDs []string @@ -808,7 +608,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, peerGroupsAdded := make(map[string][]string) peerGroupsRemoved := make(map[string][]string) if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) @@ -840,7 +640,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -851,7 +651,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -880,11 +680,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in return eventsToStore } -func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { var eventsToStore []func() if newUser.AutoGroups != nil { - removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) - addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) + removedGroups := util.Difference(oldUser.AutoGroups, newUser.AutoGroups) + addedGroups := util.Difference(newUser.AutoGroups, oldUser.AutoGroups) removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) eventsToStore = append(eventsToStore, removedEvents...) @@ -895,7 +695,7 @@ func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, in return eventsToStore } -func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { +func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { var eventsToStore []func() for _, g := range addedGroups { group := account.GetGroup(g) @@ -922,7 +722,7 @@ func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, ini return eventsToStore } -func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { +func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { var eventsToStore []func() for _, g := range removedGroups { group := account.GetGroup(g) @@ -952,10 +752,10 @@ func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, return eventsToStore } -func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool { - if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { +func handleOwnerRoleTransfer(account *types.Account, initiatorUser, update *types.User) bool { + if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { newInitiatorUser := initiatorUser.Copy() - newInitiatorUser.Role = UserRoleAdmin + newInitiatorUser.Role = types.UserRoleAdmin account.Users[initiatorUser.Id] = newInitiatorUser return true } @@ -965,7 +765,7 @@ func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool // getUserInfo retrieves the UserInfo for a given User and Account. // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. -func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) { +func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *types.User, account *types.Account) (*types.UserInfo, error) { if !isNil(am.idpManager) && !user.IsServiceUser { userData, err := am.lookupUserInCache(ctx, user.Id, account) if err != nil { @@ -977,23 +777,23 @@ func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, acc } // validateUserUpdate validates the update operation for a user. -func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error { +func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update *types.User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { return status.Errorf(status.PermissionDenied, "admins can't change their role") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { return status.Errorf(status.PermissionDenied, "unable to block owner user") } - if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") } - if oldUser.IsServiceUser && update.Role == UserRoleOwner { + if oldUser.IsServiceUser && update.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "can't update a service user with owner role") } @@ -1012,7 +812,7 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) { start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -1039,7 +839,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u userObj := account.Users[userID] - if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { + if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == types.UserRoleOwner { account.Domain = lowerDomain err = am.Store.SaveAccount(ctx, account) if err != nil { @@ -1052,7 +852,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. -func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) { account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -1068,7 +868,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun users := make(map[string]userLoggedInOnce, len(account.Users)) usersFromIntegration := make([]*idp.UserData, 0) for _, user := range account.Users { - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { @@ -1092,7 +892,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun queriedUsers = append(queriedUsers, usersFromIntegration...) } - userInfos := make([]*UserInfo, 0) + userInfos := make([]*types.UserInfo, 0) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { @@ -1116,7 +916,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun continue } - var info *UserInfo + var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { @@ -1136,16 +936,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } } - info = &UserInfo{ + info = &types.UserInfo{ ID: localUser.Id, Email: "", Name: name, Role: string(localUser.Role), AutoGroups: localUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, - Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, + Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) @@ -1155,7 +955,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *types.Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -1183,7 +983,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } @@ -1260,13 +1060,13 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID)) continue } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user")) continue } @@ -1291,7 +1091,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta { @@ -1301,7 +1101,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) @@ -1342,8 +1142,8 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return @@ -1376,7 +1176,7 @@ func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[strin } // addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { groupPeers := make(map[string]struct{}, len(group.Peers)) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -1393,7 +1193,7 @@ func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) } // removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { // skip removing peers from group All if group.Name == "All" { return @@ -1419,7 +1219,7 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } // areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. -func areUsersLinkedToPeers(account *Account, userIDs []string) bool { +func areUsersLinkedToPeers(account *types.Account, userIDs []string) bool { for _, peer := range account.Peers { if slices.Contains(userIDs, peer.UserID) { return true diff --git a/management/server/user_test.go b/management/server/user_test.go index 498017afa1d..75d88f9c864 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,8 +10,11 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" - nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,11 +44,15 @@ const ( ) func TestUser_CreatePAT_ForSameUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -82,14 +89,18 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { } func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: false, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -104,14 +115,18 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { } func TestUser_CreatePAT_ForServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -130,11 +145,15 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { } func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -149,11 +168,15 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { } func TestUser_CreatePAT_WithEmptyName(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -168,19 +191,23 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } func TestUser_DeletePAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -204,20 +231,24 @@ func TestUser_DeletePAT(t *testing.T) { } func TestUser_GetPAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -237,13 +268,17 @@ func TestUser_GetPAT(t *testing.T) { } func TestUser_GetAllPATs(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, @@ -254,7 +289,7 @@ func TestUser_GetAllPATs(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -274,14 +309,14 @@ func TestUser_GetAllPATs(t *testing.T) { func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way - user := User{ + user := types.User{ Id: "userId", AccountID: "accountId", Role: "role", IsServiceUser: true, ServiceUserName: "servicename", AutoGroups: []string{"group1", "group2"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", @@ -340,11 +375,15 @@ func validateStruct(s interface{}) (err error) { } func TestUser_CreateServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -366,26 +405,30 @@ func TestUser_CreateServiceUser(t *testing.T) { assert.NotNil(t, account.Users[user.ID]) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) - assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs) + assert.Equal(t, map[string]*types.PersonalAccessToken{}, account.Users[user.ID].PATs) assert.Zero(t, user.Email) assert.True(t, user.IsServiceUser) assert.Equal(t, "active", user.Status) - _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) + _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, types.UserRoleOwner, mockServiceUserName, false, nil) if err == nil { t.Fatal("should return error when creating service user with owner role") } } func TestUser_CreateUser_ServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -395,7 +438,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: true, @@ -413,7 +456,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { assert.Equal(t, 2, len(account.Users)) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) assert.Equal(t, mockServiceUserName, user.Name) @@ -423,11 +466,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { } func TestUser_CreateUser_RegularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -437,7 +484,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: false, @@ -448,11 +495,15 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { } func TestUser_InviteNewUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -495,7 +546,7 @@ func TestUser_InviteNewUser(t *testing.T) { am.idpManager = &idpMock // test if new invite with regular role works - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, Email: "test@teste.com", @@ -506,9 +557,9 @@ func TestUser_InviteNewUser(t *testing.T) { assert.NoErrorf(t, err, "Invite user should not throw error") // test if new invite with owner role fails - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, - Role: string(UserRoleOwner), + Role: string(types.UserRoleOwner), Email: "test2@teste.com", IsServiceUser: false, AutoGroups: []string{"group1", "group2"}, @@ -520,13 +571,13 @@ func TestUser_InviteNewUser(t *testing.T) { func TestUser_DeleteUser_ServiceUser(t *testing.T) { tests := []struct { name string - serviceUser *User + serviceUser *types.User assertErrFunc assert.ErrorAssertionFunc assertErrMessage string }{ { name: "Can delete service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -535,7 +586,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { }, { name: "Cannot delete non-deletable service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -548,11 +599,16 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -580,11 +636,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { } func TestUser_DeleteUser_SelfDelete(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -601,38 +661,42 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { } func TestUser_DeleteUser_regularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -683,60 +747,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } func TestUser_DeleteUser_RegularUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - account.Users["user6"] = &User{ + account.Users["user6"] = &types.User{ Id: "user6", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user7"] = &User{ + account.Users["user7"] = &types.User{ Id: "user7", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user8"] = &User{ + account.Users["user8"] = &types.User{ Id: "user8", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - account.Users["user9"] = &User{ + account.Users["user9"] = &types.User{ Id: "user9", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -834,11 +902,15 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { } func TestDefaultAccountManager_GetUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -863,13 +935,17 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { } func TestDefaultAccountManager_ListUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") + account.Users["normal_user1"] = types.NewRegularUser("normal_user1") + account.Users["normal_user2"] = types.NewRegularUser("normal_user2") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -901,43 +977,43 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool expectedDashboardPermissions string }{ { name: "Regular user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, expectedDashboardPermissions: "limited", }, { name: "Admin user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, expectedDashboardPermissions: "blocked", }, { name: "Admin user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, expectedDashboardPermissions: "full", }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, expectedDashboardPermissions: "full", }, @@ -945,13 +1021,18 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -976,13 +1057,17 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { } func TestDefaultAccountManager_ExternalCache(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ + externalUser := &types.User{ Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + Role: types.UserRoleUser, + Issued: types.UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", @@ -990,7 +1075,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } account.Users[externalUser.Id] = externalUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1020,7 +1105,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) assert.NoError(t, err) assert.Equal(t, 2, len(infos)) - var user *UserInfo + var user *types.UserInfo for _, info := range infos { if info.ID == externalUser.Id { user = info @@ -1032,24 +1117,28 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestUser_IsAdmin(t *testing.T) { - user := NewAdminUser(mockUserID) + user := types.NewAdminUser(mockUserID) assert.True(t, user.HasAdminPower()) - user = NewRegularUser(mockUserID) + user = types.NewRegularUser(mockUserID) assert.False(t, user.HasAdminPower()) } func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1068,17 +1157,20 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { } func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1112,25 +1204,25 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { tt := []struct { name string initiatorID string - update *User + update *types.User expectedErr bool }{ { name: "Should_Fail_To_Update_Admin_Role", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, Blocked: false, }, }, { name: "Should_Fail_When_Admin_Blocks_Themselves", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1138,9 +1230,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Non_Existing_User", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: userID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1148,9 +1240,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1158,9 +1250,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Update_User", expectedErr: false, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1168,9 +1260,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Transfer_Owner_Role_To_User", expectedErr: false, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1178,9 +1270,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Transfer_Owner_Role_To_Service_User", expectedErr: true, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: serviceUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1188,9 +1280,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1198,9 +1290,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_User", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1208,9 +1300,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Service_User", expectedErr: true, initiatorID: serviceUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1218,9 +1310,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1228,9 +1320,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Block_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: true, }, }, @@ -1246,9 +1338,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} + account.Users[regularUserID] = types.NewRegularUser(regularUserID) + account.Users[adminUserID] = types.NewAdminUser(adminUserID) + account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) @@ -1272,22 +1364,22 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) require.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1307,11 +1399,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1330,11 +1422,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) @@ -1364,11 +1456,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) // create a user and add new peer with the user - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1390,11 +1482,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) diff --git a/management/server/users/manager.go b/management/server/users/manager.go new file mode 100644 index 00000000000..718eb6190a1 --- /dev/null +++ b/management/server/users/manager.go @@ -0,0 +1,49 @@ +package users + +import ( + "context" + "errors" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetUser(ctx context.Context, userID string) (*types.User, error) +} + +type managerImpl struct { + store store.Store +} + +type managerMock struct { +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { + return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User, error) { + switch userID { + case "adminUser": + return &types.User{Id: userID, Role: types.UserRoleAdmin}, nil + case "regularUser": + return &types.User{Id: userID, Role: types.UserRoleUser}, nil + case "ownerUser": + return &types.User{Id: userID, Role: types.UserRoleOwner}, nil + case "billingUser": + return &types.User{Id: userID, Role: types.UserRoleBillingAdmin}, nil + default: + return nil, errors.New("user not found") + } +} diff --git a/management/server/util/util.go b/management/server/util/util.go new file mode 100644 index 00000000000..ff738781feb --- /dev/null +++ b/management/server/util/util.go @@ -0,0 +1,16 @@ +package util + +// Difference returns the elements in `a` that aren't in `b`. +func Difference(a, b []string) []string { + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + var diff []string + for _, x := range a { + if _, found := mb[x]; !found { + diff = append(diff, x) + } + } + return diff +} diff --git a/route/route.go b/route/route.go index e23801e6e9e..8f3c99b4c1d 100644 --- a/route/route.go +++ b/route/route.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" "slices" + "strings" log "github.com/sirupsen/logrus" @@ -88,18 +89,18 @@ type Route struct { // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` // Network and Domains are mutually exclusive - Network netip.Prefix `gorm:"serializer:json"` - Domains domain.List `gorm:"serializer:json"` - KeepRoute bool - NetID NetID - Description string - Peer string - PeerGroups []string `gorm:"serializer:json"` - NetworkType NetworkType - Masquerade bool - Metric int - Enabled bool - Groups []string `gorm:"serializer:json"` + Network netip.Prefix `gorm:"serializer:json"` + Domains domain.List `gorm:"serializer:json"` + KeepRoute bool + NetID NetID + Description string + Peer string + PeerGroups []string `gorm:"serializer:json"` + NetworkType NetworkType + Masquerade bool + Metric int + Enabled bool + Groups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"` } @@ -111,19 +112,19 @@ func (r *Route) EventMeta() map[string]any { // Copy copies a route object func (r *Route) Copy() *Route { route := &Route{ - ID: r.ID, - Description: r.Description, - NetID: r.NetID, - Network: r.Network, - Domains: slices.Clone(r.Domains), - KeepRoute: r.KeepRoute, - NetworkType: r.NetworkType, - Peer: r.Peer, - PeerGroups: slices.Clone(r.PeerGroups), - Metric: r.Metric, - Masquerade: r.Masquerade, - Enabled: r.Enabled, - Groups: slices.Clone(r.Groups), + ID: r.ID, + Description: r.Description, + NetID: r.NetID, + Network: r.Network, + Domains: slices.Clone(r.Domains), + KeepRoute: r.KeepRoute, + NetworkType: r.NetworkType, + Peer: r.Peer, + PeerGroups: slices.Clone(r.PeerGroups), + Metric: r.Metric, + Masquerade: r.Masquerade, + Enabled: r.Enabled, + Groups: slices.Clone(r.Groups), AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route @@ -149,7 +150,7 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.PeerGroups, other.PeerGroups) && slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } @@ -170,6 +171,11 @@ func (r *Route) GetHAUniqueID() HAUniqueID { return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) } +// GetResourceID returns the Networks Resource ID from a route ID +func (r *Route) GetResourceID() string { + return strings.Split(string(r.ID), ":")[0] +} + // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { prefix, err := netip.ParsePrefix(networkString) From 82b4e58ad0b2c0c3efb310b32da46eff2fc92281 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 20 Dec 2024 16:20:50 +0100 Subject: [PATCH 14/19] Do not start DNS forwarder on client side (#3094) --- client/internal/engine.go | 44 ++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 9724e2a2257..042d384dc82 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -802,14 +802,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - var dnsRouteFeatureFlag bool - if networkMap.PeerConfig != nil { - dnsRouteFeatureFlag = networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled - } - routedDomains, routes := toRoutes(networkMap.GetRoutes()) - - e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) + // DNS forwarder + dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) + dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) + e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains) + routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } @@ -874,12 +872,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { +func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool { + if networkMap.PeerConfig != nil { + return networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled + } + return false +} + +func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } - var dnsRoutes []string routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -890,7 +894,6 @@ func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { continue } } - dnsRoutes = append(dnsRoutes, protoRoute.Domains...) convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), @@ -905,7 +908,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { } routes = append(routes, convertedRoute) } - return dnsRoutes, routes + return routes +} + +func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + var dnsRoutes []string + for _, protoRoute := range protoRoutes { + if len(protoRoute.Domains) == 0 { + continue + } + if protoRoute.Peer == myPubKey { + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) + } + } + return dnsRoutes } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { @@ -1243,7 +1263,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { if err != nil { return nil, nil, err } - _, routes := toRoutes(netMap.GetRoutes()) + routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, &dnsCfg, nil } From 7ee7ada27347bec20d0650a9c0feeb991779b2d4 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 20 Dec 2024 21:10:53 +0300 Subject: [PATCH 15/19] [management] Fix duplicate resource routes when routing peer is part of the source group (#3095) * Remove duplicate resource routes when routing peer is part of the source group Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/route_test.go | 67 ++++++++++++++++++++++++++++++ management/server/types/account.go | 6 +-- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5390cb66b94..2ef2b01732c 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2175,6 +2175,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { peerJIp = "100.65.29.65" peerKIp = "100.65.29.66" peerMIp = "100.65.29.67" + peerOIp = "100.65.29.68" ) account := &types.Account{ @@ -2256,6 +2257,20 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { IP: net.ParseIP(peerMIp), Status: &nbpeer.PeerStatus{}, }, + "peerN": { + ID: "peerN", + IP: net.ParseIP("100.65.20.18"), + Key: "peerN", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerO": { + ID: "peerO", + IP: net.ParseIP(peerOIp), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*types.Group{ "router1": { @@ -2330,6 +2345,14 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Pipeline", Peers: []string{"peerM"}, }, + "metrics": { + ID: "metrics", + Name: "Metrics", + Peers: []string{"peerN", "peerO"}, + Resources: []types.Resource{ + {ID: "resource6"}, + }, + }, }, Networks: []*networkTypes.Network{ { @@ -2352,6 +2375,10 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { ID: "network5", Name: "Pipeline Network", }, + { + ID: "network6", + Name: "Metrics Network", + }, }, NetworkRouters: []*routerTypes.NetworkRouter{ { @@ -2389,6 +2416,13 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Masquerade: false, Metric: 9999, }, + { + ID: "router6", + NetworkID: "network6", + Peer: "peerN", + Masquerade: false, + Metric: 9999, + }, }, NetworkResources: []*resourceTypes.NetworkResource{ { @@ -2426,6 +2460,13 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Type: "host", Prefix: netip.MustParsePrefix("10.12.12.1/32"), }, + { + ID: "resource6", + NetworkID: "network6", + Name: "Resource 6", + Type: "domain", + Domain: "*.google.com", + }, }, Policies: []*types.Policy{ { @@ -2527,6 +2568,24 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, }, }, + { + ID: "policyResource6", + Name: "policyResource6", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource6", + Name: "ruleResource6", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"9090"}, + Sources: []string{"metrics"}, + Destinations: []string{"metrics"}, + }, + }, + }, }, } @@ -2553,6 +2612,10 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { // which is part of the destination in the policies (policyResource3 and policyResource4) policies = account.GetPoliciesForNetworkResource("resource4") assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups") + + // Test case: Resource6 is applied to the access control groups (metrics), + policies = account.GetPoliciesForNetworkResource("resource6") + assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") }) t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { @@ -2663,5 +2726,9 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) assert.Len(t, routes, 1) assert.Len(t, sourcePeers, 0) + + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + assert.Len(t, sourcePeers, 2) }) } diff --git a/management/server/types/account.go b/management/server/types/account.go index 5d0f1201984..b36b719e491 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1330,10 +1330,8 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st // routing peer should be able to connect with all source peers if addSourcePeers { allSourcePeers = append(allSourcePeers, group.Peers...) - } - - // add routes for the resource if the peer is in the distribution group - if slices.Contains(group.Peers, peerID) { + } else if slices.Contains(group.Peers, peerID) { + // add routes for the resource if the peer is in the distribution group for peerId, router := range networkRoutingPeers { routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) } From b48cf1bf65782c68600b00541113f90bb1aa88db Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 21 Dec 2024 15:56:52 +0100 Subject: [PATCH 16/19] [client] Reduce DNS handler chain lock contention (#3099) --- client/internal/dns/handler_chain.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index c6ac3ebd61e..9302d50b171 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,6 +1,7 @@ package dns import ( + "slices" "strings" "sync" @@ -161,16 +162,19 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { log.Tracef("handling DNS request for domain=%s", qname) c.mu.RLock() - defer c.mu.RUnlock() - - log.Tracef("current handlers (%d):", len(c.handlers)) - for _, h := range c.handlers { - log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", - h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + handlers := slices.Clone(c.handlers) + c.mu.RUnlock() + + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("current handlers (%d):", len(handlers)) + for _, h := range handlers { + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + } } // Try handlers in priority order - for _, entry := range c.handlers { + for _, entry := range handlers { var matched bool switch { case entry.Pattern == ".": From e670068cabfa4288455d34d3b917eaa5a5ae2f7a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 23 Dec 2024 14:37:09 +0100 Subject: [PATCH 17/19] [management] Run test sequential (#3101) --- .github/workflows/golang-test-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 50c6cba848c..15519039844 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -183,7 +183,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) benchmark: needs: [ build-cache ] From 05930ee6b192fa9928e709a8fc74ef6e2cb779a1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:57:15 +0100 Subject: [PATCH 18/19] [client] Add firewall rules to the debug bundle (#3089) Adds the following to the debug bundle: - iptables: `iptables-save`, `iptables -v -n -L` - nftables: `nft list ruleset` or if not available formatted output from netlink (WIP) --- client/server/debug.go | 24 ++ client/server/debug_linux.go | 693 ++++++++++++++++++++++++++++++++ client/server/debug_nonlinux.go | 15 + client/server/debug_test.go | 113 ++++++ 4 files changed, 845 insertions(+) create mode 100644 client/server/debug_linux.go create mode 100644 client/server/debug_nonlinux.go diff --git a/client/server/debug.go b/client/server/debug.go index c12fd99dbf2..9dfde0367f3 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -40,6 +40,8 @@ netbird.err: Most recent, anonymized stderr log file of the NetBird client. netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. +iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. +nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states. @@ -106,6 +108,24 @@ The config.txt file contains anonymized configuration information of the NetBird - CustomDNSAddress Other non-sensitive configuration options are included without anonymization. + +Firewall Rules (Linux only) +The bundle includes two separate firewall rule files: + +iptables.txt: +- Complete iptables ruleset with packet counters using 'iptables -v -n -L' +- Includes all tables (filter, nat, mangle, raw, security) +- Shows packet and byte counters for each rule +- All IP addresses are anonymized +- Chain names, table names, and other non-sensitive information remain unchanged + +nftables.txt: +- Complete nftables ruleset obtained via 'nft -a list ruleset' +- Includes rule handle numbers and packet counters +- All tables, chains, and rules are included +- Shows packet and byte counters for each rule +- All IP addresses are anonymized +- Chain names, table names, and other non-sensitive information remain unchanged ` const ( @@ -172,6 +192,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques if err := s.addInterfaces(req, anonymizer, archive); err != nil { log.Errorf("Failed to add interfaces to debug bundle: %v", err) } + + if err := s.addFirewallRules(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add firewall rules to debug bundle: %v", err) + } } if err := s.addNetworkMap(req, anonymizer, archive); err != nil { diff --git a/client/server/debug_linux.go b/client/server/debug_linux.go new file mode 100644 index 00000000000..60bc4056167 --- /dev/null +++ b/client/server/debug_linux.go @@ -0,0 +1,693 @@ +//go:build linux && !android + +package server + +import ( + "archive/zip" + "bytes" + "encoding/binary" + "fmt" + "os/exec" + "sort" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/proto" +) + +// addFirewallRules collects and adds firewall rules to the archive +func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + log.Info("Collecting firewall rules") + // Collect and add iptables rules + iptablesRules, err := collectIPTablesRules() + if err != nil { + log.Warnf("Failed to collect iptables rules: %v", err) + } else { + if req.GetAnonymize() { + iptablesRules = anonymizer.AnonymizeString(iptablesRules) + } + if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil { + log.Warnf("Failed to add iptables rules to bundle: %v", err) + } + } + + // Collect and add nftables rules + nftablesRules, err := collectNFTablesRules() + if err != nil { + log.Warnf("Failed to collect nftables rules: %v", err) + } else { + if req.GetAnonymize() { + nftablesRules = anonymizer.AnonymizeString(nftablesRules) + } + if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil { + log.Warnf("Failed to add nftables rules to bundle: %v", err) + } + } + + return nil +} + +// collectIPTablesRules collects rules using both iptables-save and verbose listing +func collectIPTablesRules() (string, error) { + var builder strings.Builder + + // First try using iptables-save + saveOutput, err := collectIPTablesSave() + if err != nil { + log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) + } else { + builder.WriteString("=== iptables-save output ===\n") + builder.WriteString(saveOutput) + builder.WriteString("\n") + } + + // Then get verbose statistics for each table + builder.WriteString("=== iptables -v -n -L output ===\n") + + // Get list of tables + tables := []string{"filter", "nat", "mangle", "raw", "security"} + + for _, table := range tables { + builder.WriteString(fmt.Sprintf("*%s\n", table)) + + // Get verbose statistics for the entire table + stats, err := getTableStatistics(table) + if err != nil { + log.Warnf("Failed to get statistics for table %s: %v", table, err) + continue + } + builder.WriteString(stats) + builder.WriteString("\n") + } + + return builder.String(), nil +} + +// collectIPTablesSave uses iptables-save to get rule definitions +func collectIPTablesSave() (string, error) { + cmd := exec.Command("iptables-save") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) + } + + rules := stdout.String() + if strings.TrimSpace(rules) == "" { + return "", fmt.Errorf("no iptables rules found") + } + + return rules, nil +} + +// getTableStatistics gets verbose statistics for an entire table using iptables command +func getTableStatistics(table string) (string, error) { + cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) + } + + return stdout.String(), nil +} + +// collectNFTablesRules attempts to collect nftables rules using either nft command or netlink +func collectNFTablesRules() (string, error) { + // First try using nft command + rules, err := collectNFTablesFromCommand() + if err != nil { + log.Debugf("Failed to collect nftables rules using nft command: %v, falling back to netlink", err) + // Fall back to netlink + rules, err = collectNFTablesFromNetlink() + if err != nil { + return "", fmt.Errorf("collect nftables rules using both nft and netlink failed: %w", err) + } + } + return rules, nil +} + +// collectNFTablesFromCommand attempts to collect rules using nft command +func collectNFTablesFromCommand() (string, error) { + cmd := exec.Command("nft", "-a", "list", "ruleset") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute nft list ruleset: %w (stderr: %s)", err, stderr.String()) + } + + rules := stdout.String() + if strings.TrimSpace(rules) == "" { + return "", fmt.Errorf("no nftables rules found") + } + + return rules, nil +} + +// collectNFTablesFromNetlink collects rules using netlink library +func collectNFTablesFromNetlink() (string, error) { + conn, err := nftables.New() + if err != nil { + return "", fmt.Errorf("create nftables connection: %w", err) + } + + tables, err := conn.ListTables() + if err != nil { + return "", fmt.Errorf("list tables: %w", err) + } + + sortTables(tables) + return formatTables(conn, tables), nil +} + +func formatTables(conn *nftables.Conn, tables []*nftables.Table) string { + var builder strings.Builder + + for _, table := range tables { + builder.WriteString(fmt.Sprintf("table %s %s {\n", formatFamily(table.Family), table.Name)) + + chains, err := getAndSortTableChains(conn, table) + if err != nil { + log.Warnf("Failed to list chains for table %s: %v", table.Name, err) + continue + } + + // Format chains + for _, chain := range chains { + formatChain(conn, table, chain, &builder) + } + + // Format sets + if sets, err := conn.GetSets(table); err != nil { + log.Warnf("Failed to get sets for table %s: %v", table.Name, err) + } else if len(sets) > 0 { + builder.WriteString("\n") + for _, set := range sets { + builder.WriteString(formatSet(conn, set)) + } + } + + builder.WriteString("}\n") + } + + return builder.String() +} + +func getAndSortTableChains(conn *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { + chains, err := conn.ListChains() + if err != nil { + return nil, err + } + + var tableChains []*nftables.Chain + for _, chain := range chains { + if chain.Table.Name == table.Name && chain.Table.Family == table.Family { + tableChains = append(tableChains, chain) + } + } + + sort.Slice(tableChains, func(i, j int) bool { + return tableChains[i].Name < tableChains[j].Name + }) + + return tableChains, nil +} + +func formatChain(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, builder *strings.Builder) { + builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name)) + + if chain.Type != "" { + var policy string + if chain.Policy != nil { + policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy)) + } + builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n", + formatChainType(chain.Type), + formatChainHook(chain.Hooknum), + chain.Priority, + policy)) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err) + } else { + sort.Slice(rules, func(i, j int) bool { + return rules[i].Position < rules[j].Position + }) + for _, rule := range rules { + builder.WriteString(formatRule(rule)) + } + } + + builder.WriteString("\t}\n") +} + +func sortTables(tables []*nftables.Table) { + sort.Slice(tables, func(i, j int) bool { + if tables[i].Family != tables[j].Family { + return tables[i].Family < tables[j].Family + } + return tables[i].Name < tables[j].Name + }) +} + +func formatFamily(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + case nftables.TableFamilyARP: + return "arp" + case nftables.TableFamilyBridge: + return "bridge" + case nftables.TableFamilyNetdev: + return "netdev" + default: + return fmt.Sprintf("family-%d", family) + } +} + +func formatChainType(typ nftables.ChainType) string { + switch typ { + case nftables.ChainTypeFilter: + return "filter" + case nftables.ChainTypeNAT: + return "nat" + case nftables.ChainTypeRoute: + return "route" + default: + return fmt.Sprintf("type-%s", typ) + } +} + +func formatChainHook(hook *nftables.ChainHook) string { + if hook == nil { + return "none" + } + switch *hook { + case *nftables.ChainHookPrerouting: + return "prerouting" + case *nftables.ChainHookInput: + return "input" + case *nftables.ChainHookForward: + return "forward" + case *nftables.ChainHookOutput: + return "output" + case *nftables.ChainHookPostrouting: + return "postrouting" + default: + return fmt.Sprintf("hook-%d", *hook) + } +} + +func formatPolicy(policy nftables.ChainPolicy) string { + switch policy { + case nftables.ChainPolicyDrop: + return "drop" + case nftables.ChainPolicyAccept: + return "accept" + default: + return fmt.Sprintf("policy-%d", policy) + } +} + +func formatRule(rule *nftables.Rule) string { + var builder strings.Builder + builder.WriteString("\t\t") + + for i := 0; i < len(rule.Exprs); i++ { + if i > 0 { + builder.WriteString(" ") + } + i = formatExprSequence(&builder, rule.Exprs, i) + } + + builder.WriteString("\n") + return builder.String() +} + +func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { + curr := exprs[i] + + // Handle Meta + Cmp sequence + if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + if formatted := formatMetaWithCmp(meta, cmp); formatted != "" { + builder.WriteString(formatted) + return i + 1 + } + } + } + + // Handle Payload + Cmp sequence + if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + builder.WriteString(formatPayloadWithCmp(payload, cmp)) + return i + 1 + } + } + + builder.WriteString(formatExpr(curr)) + return i +} + +func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string { + switch meta.Key { + case expr.MetaKeyIIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyOIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyMARK: + if len(cmp.Data) == 4 { + val := binary.BigEndian.Uint32(cmp.Data) + return fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val) + } + } + return "" +} + +func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string { + if p.Base == expr.PayloadBaseNetworkHeader { + switch p.Offset { + case 12: // Source IP + if p.Len == 4 { + return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } else if p.Len == 2 { + return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } + case 16: // Destination IP + if p.Len == 4 { + return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } else if p.Len == 2 { + return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } + } + } + return fmt.Sprintf("%d reg%d [%d:%d] %s %v", + p.Base, p.DestRegister, p.Offset, p.Len, + formatCmpOp(cmp.Op), cmp.Data) +} + +func formatIPBytes(data []byte) string { + if len(data) == 4 { + return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) + } else if len(data) == 2 { + return fmt.Sprintf("%d.%d.0.0/16", data[0], data[1]) + } + return fmt.Sprintf("%v", data) +} + +func formatCmpOp(op expr.CmpOp) string { + switch op { + case expr.CmpOpEq: + return "==" + case expr.CmpOpNeq: + return "!=" + case expr.CmpOpLt: + return "<" + case expr.CmpOpLte: + return "<=" + case expr.CmpOpGt: + return ">" + case expr.CmpOpGte: + return ">=" + default: + return fmt.Sprintf("op-%d", op) + } +} + +// formatExpr formats an expression in nft-like syntax +func formatExpr(exp expr.Any) string { + switch e := exp.(type) { + case *expr.Meta: + return formatMeta(e) + case *expr.Cmp: + return formatCmp(e) + case *expr.Payload: + return formatPayload(e) + case *expr.Verdict: + return formatVerdict(e) + case *expr.Counter: + return fmt.Sprintf("counter packets %d bytes %d", e.Packets, e.Bytes) + case *expr.Masq: + return "masquerade" + case *expr.NAT: + return formatNat(e) + case *expr.Match: + return formatMatch(e) + case *expr.Queue: + return fmt.Sprintf("queue num %d", e.Num) + case *expr.Lookup: + return fmt.Sprintf("@%s", e.SetName) + case *expr.Bitwise: + return formatBitwise(e) + case *expr.Fib: + return formatFib(e) + case *expr.Target: + return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets + case *expr.Immediate: + if e.Register == 1 { + return formatImmediateData(e.Data) + } + return fmt.Sprintf("immediate %v", e.Data) + default: + return fmt.Sprintf("<%T>", exp) + } +} + +func formatImmediateData(data []byte) string { + // For IP addresses (4 bytes) + if len(data) == 4 { + return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) + } + return fmt.Sprintf("%v", data) +} + +func formatMeta(e *expr.Meta) string { + // Handle source register case first (meta mark set) + if e.SourceRegister { + return fmt.Sprintf("meta %s set reg %d", formatMetaKey(e.Key), e.Register) + } + + // For interface names, handle register load operation + switch e.Key { + case expr.MetaKeyIIFNAME, + expr.MetaKeyOIFNAME, + expr.MetaKeyBRIIIFNAME, + expr.MetaKeyBRIOIFNAME: + // Simply the key name with no register reference + return formatMetaKey(e.Key) + + case expr.MetaKeyMARK: + // For mark operations, we want just "mark" + return "mark" + } + + // For other meta keys, show as loading into register + return fmt.Sprintf("meta %s => reg %d", formatMetaKey(e.Key), e.Register) +} + +func formatMetaKey(key expr.MetaKey) string { + switch key { + case expr.MetaKeyLEN: + return "length" + case expr.MetaKeyPROTOCOL: + return "protocol" + case expr.MetaKeyPRIORITY: + return "priority" + case expr.MetaKeyMARK: + return "mark" + case expr.MetaKeyIIF: + return "iif" + case expr.MetaKeyOIF: + return "oif" + case expr.MetaKeyIIFNAME: + return "iifname" + case expr.MetaKeyOIFNAME: + return "oifname" + case expr.MetaKeyIIFTYPE: + return "iiftype" + case expr.MetaKeyOIFTYPE: + return "oiftype" + case expr.MetaKeySKUID: + return "skuid" + case expr.MetaKeySKGID: + return "skgid" + case expr.MetaKeyNFTRACE: + return "nftrace" + case expr.MetaKeyRTCLASSID: + return "rtclassid" + case expr.MetaKeySECMARK: + return "secmark" + case expr.MetaKeyNFPROTO: + return "nfproto" + case expr.MetaKeyL4PROTO: + return "l4proto" + case expr.MetaKeyBRIIIFNAME: + return "briifname" + case expr.MetaKeyBRIOIFNAME: + return "broifname" + case expr.MetaKeyPKTTYPE: + return "pkttype" + case expr.MetaKeyCPU: + return "cpu" + case expr.MetaKeyIIFGROUP: + return "iifgroup" + case expr.MetaKeyOIFGROUP: + return "oifgroup" + case expr.MetaKeyCGROUP: + return "cgroup" + case expr.MetaKeyPRANDOM: + return "prandom" + default: + return fmt.Sprintf("meta-%d", key) + } +} + +func formatCmp(e *expr.Cmp) string { + ops := map[expr.CmpOp]string{ + expr.CmpOpEq: "==", + expr.CmpOpNeq: "!=", + expr.CmpOpLt: "<", + expr.CmpOpLte: "<=", + expr.CmpOpGt: ">", + expr.CmpOpGte: ">=", + } + return fmt.Sprintf("%s %v", ops[e.Op], e.Data) +} + +func formatPayload(e *expr.Payload) string { + var proto string + switch e.Base { + case expr.PayloadBaseNetworkHeader: + proto = "ip" + case expr.PayloadBaseTransportHeader: + proto = "tcp" + default: + proto = fmt.Sprintf("payload-%d", e.Base) + } + return fmt.Sprintf("%s reg%d [%d:%d]", proto, e.DestRegister, e.Offset, e.Len) +} + +func formatVerdict(e *expr.Verdict) string { + switch e.Kind { + case expr.VerdictAccept: + return "accept" + case expr.VerdictDrop: + return "drop" + case expr.VerdictJump: + return fmt.Sprintf("jump %s", e.Chain) + case expr.VerdictGoto: + return fmt.Sprintf("goto %s", e.Chain) + case expr.VerdictReturn: + return "return" + default: + return fmt.Sprintf("verdict-%d", e.Kind) + } +} + +func formatNat(e *expr.NAT) string { + switch e.Type { + case expr.NATTypeSourceNAT: + return "snat" + case expr.NATTypeDestNAT: + return "dnat" + default: + return fmt.Sprintf("nat-%d", e.Type) + } +} + +func formatMatch(e *expr.Match) string { + return fmt.Sprintf("match %s rev %d", e.Name, e.Rev) +} + +func formatBitwise(e *expr.Bitwise) string { + return fmt.Sprintf("bitwise reg%d = reg%d & %v ^ %v", + e.DestRegister, e.SourceRegister, e.Mask, e.Xor) +} + +func formatFib(e *expr.Fib) string { + var flags []string + if e.FlagSADDR { + flags = append(flags, "saddr") + } + if e.FlagDADDR { + flags = append(flags, "daddr") + } + if e.FlagMARK { + flags = append(flags, "mark") + } + if e.FlagIIF { + flags = append(flags, "iif") + } + if e.FlagOIF { + flags = append(flags, "oif") + } + if e.ResultADDRTYPE { + flags = append(flags, "type") + } + return fmt.Sprintf("fib reg%d %s", e.Register, strings.Join(flags, ",")) +} + +func formatSet(conn *nftables.Conn, set *nftables.Set) string { + var builder strings.Builder + builder.WriteString(fmt.Sprintf("\tset %s {\n", set.Name)) + builder.WriteString(fmt.Sprintf("\t\ttype %s\n", formatSetKeyType(set.KeyType))) + if set.ID > 0 { + builder.WriteString(fmt.Sprintf("\t\t# handle %d\n", set.ID)) + } + + elements, err := conn.GetSetElements(set) + if err != nil { + log.Warnf("Failed to get elements for set %s: %v", set.Name, err) + } else if len(elements) > 0 { + builder.WriteString("\t\telements = {") + for i, elem := range elements { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(fmt.Sprintf("%v", elem.Key)) + } + builder.WriteString("}\n") + } + + builder.WriteString("\t}\n") + return builder.String() +} + +func formatSetKeyType(keyType nftables.SetDatatype) string { + switch keyType { + case nftables.TypeInvalid: + return "invalid" + case nftables.TypeIPAddr: + return "ipv4_addr" + case nftables.TypeIP6Addr: + return "ipv6_addr" + case nftables.TypeEtherAddr: + return "ether_addr" + case nftables.TypeInetProto: + return "inet_proto" + case nftables.TypeInetService: + return "inet_service" + case nftables.TypeMark: + return "mark" + default: + return fmt.Sprintf("type-%v", keyType) + } +} diff --git a/client/server/debug_nonlinux.go b/client/server/debug_nonlinux.go new file mode 100644 index 00000000000..c54ac9b6ebb --- /dev/null +++ b/client/server/debug_nonlinux.go @@ -0,0 +1,15 @@ +//go:build !linux || android + +package server + +import ( + "archive/zip" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/proto" +) + +// collectFirewallRules returns nothing on non-linux systems +func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + return nil +} diff --git a/client/server/debug_test.go b/client/server/debug_test.go index c8f7bae5d38..ebd0bffbc68 100644 --- a/client/server/debug_test.go +++ b/client/server/debug_test.go @@ -428,3 +428,116 @@ func isInCGNATRange(ip net.IP) bool { } return cgnat.Contains(ip) } + +func TestAnonymizeFirewallRules(t *testing.T) { + // TODO: Add ipv6 + + // Example iptables-save output + iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s 192.168.1.0/24 -j ACCEPT +-A INPUT -s 44.192.140.1/32 -j DROP +-A FORWARD -s 10.0.0.0/8 -j DROP +-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT +COMMIT + +*nat +:PREROUTING ACCEPT [0:0] +:INPUT ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +:POSTROUTING ACCEPT [0:0] +-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE +-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80 +COMMIT` + + // Example iptables -v -n -L output + iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination + 0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0 + 100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0 + +Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination + 0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0 + 25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24 + +Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination` + + // Example nftables output + nftablesRules := `table inet filter { + chain input { + type filter hook input priority filter; policy accept; + ip saddr 192.168.1.1 accept + ip saddr 44.192.140.1 drop + } + chain forward { + type filter hook forward priority filter; policy accept; + ip saddr 10.0.0.0/8 drop + ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + } + }` + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Test iptables-save anonymization + anonIptablesSave := anonymizer.AnonymizeString(iptablesSave) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonIptablesSave, "192.168.1.0/24") + assert.Contains(t, anonIptablesSave, "10.0.0.0/8") + assert.Contains(t, anonIptablesSave, "192.168.100.0/24") + assert.Contains(t, anonIptablesSave, "192.168.1.10") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonIptablesSave, "44.192.140.1") + assert.NotContains(t, anonIptablesSave, "44.192.140.0/24") + assert.NotContains(t, anonIptablesSave, "52.84.12.34") + assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range + + // Structure should be preserved + assert.Contains(t, anonIptablesSave, "*filter") + assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]") + assert.Contains(t, anonIptablesSave, "COMMIT") + assert.Contains(t, anonIptablesSave, "-j MASQUERADE") + assert.Contains(t, anonIptablesSave, "--dport 80") + + // Test iptables verbose output anonymization + anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24") + assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonIptablesVerbose, "44.192.140.1") + assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24") + assert.NotContains(t, anonIptablesVerbose, "52.84.12.34") + assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range + + // Structure and counters should be preserved + assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)") + assert.Contains(t, anonIptablesVerbose, "100 1024 DROP") + assert.Contains(t, anonIptablesVerbose, "pkts bytes target") + + // Test nftables anonymization + anonNftables := anonymizer.AnonymizeString(nftablesRules) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonNftables, "192.168.1.1") + assert.Contains(t, anonNftables, "10.0.0.0/8") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonNftables, "44.192.140.1") + assert.NotContains(t, anonNftables, "44.192.140.0/24") + assert.NotContains(t, anonNftables, "52.84.12.34") + assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range + + // Structure should be preserved + assert.Contains(t, anonNftables, "table inet filter {") + assert.Contains(t, anonNftables, "chain input {") + assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") +} From ad9f044aadadf6a77facd418373e9aad32095443 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 23 Dec 2024 18:22:17 +0100 Subject: [PATCH 19/19] [client] Add stateful userspace firewall and remove egress filters (#3093) - Add stateful firewall functionality for UDP/TCP/ICMP in userspace firewalll - Removes all egress drop rules/filters, still needs refactoring so we don't add output rules to any chains/filters. - on Linux, if the OUTPUT policy is DROP then we don't do anything about it (no extra allow rules). This is up to the user, if they don't want anything leaving their machine they'll have to manage these rules explicitly. --- client/firewall/iptables/acl_linux.go | 6 - client/firewall/iptables/manager_linux.go | 14 +- client/firewall/nftables/acl_linux.go | 53 - client/firewall/uspfilter/allow_netbird.go | 20 +- .../uspfilter/allow_netbird_windows.go | 16 + client/firewall/uspfilter/conntrack/common.go | 138 +++ .../uspfilter/conntrack/common_test.go | 115 ++ client/firewall/uspfilter/conntrack/icmp.go | 170 +++ .../firewall/uspfilter/conntrack/icmp_test.go | 39 + client/firewall/uspfilter/conntrack/tcp.go | 376 +++++++ .../firewall/uspfilter/conntrack/tcp_test.go | 311 ++++++ client/firewall/uspfilter/conntrack/udp.go | 158 +++ .../firewall/uspfilter/conntrack/udp_test.go | 243 +++++ client/firewall/uspfilter/uspfilter.go | 248 ++++- .../uspfilter/uspfilter_bench_test.go | 998 ++++++++++++++++++ client/firewall/uspfilter/uspfilter_test.go | 309 +++++- 16 files changed, 3099 insertions(+), 115 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/common.go create mode 100644 client/firewall/uspfilter/conntrack/common_test.go create mode 100644 client/firewall/uspfilter/conntrack/icmp.go create mode 100644 client/firewall/uspfilter/conntrack/icmp_test.go create mode 100644 client/firewall/uspfilter/conntrack/tcp.go create mode 100644 client/firewall/uspfilter/conntrack/tcp_test.go create mode 100644 client/firewall/uspfilter/conntrack/udp.go create mode 100644 client/firewall/uspfilter/conntrack/udp_test.go create mode 100644 client/firewall/uspfilter/uspfilter_bench_test.go diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 1c0527ebc78..d774f45381b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error { // The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. func (m *aclManager) seedInitialEntries() { - established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index adb8f20ef5c..0e1e5836f39 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error { "", ) if err != nil { - return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + return fmt.Errorf("allow netbird interface traffic: %w", err) } - _, err = m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - firewall.RuleDirectionOUT, - firewall.ActionAccept, - "", - "", - ) - return err + return nil } // Flush doesn't need to be implemented for this manager diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index abe890fb9a1..852cfec8de6 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "net" - "net/netip" "strconv" "strings" "time" @@ -28,7 +27,6 @@ const ( // filter chains contains the rules that jump to the rules chains chainNameInputFilter = "netbird-acl-input-filter" - chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" chainNamePrerouting = "netbird-rt-prerouting" @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) { return err } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addFwdAllow(chain, expr.MetaKeyOIFNAME) - m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules - m.addDropExpressions(chain, expr.MetaKeyOIFNAME) - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err) - return err - } - // netbird-acl-forward-filter chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - dstOp := expr.CmpOpNeq - expressions := []expr.Any{ - &expr.Meta{Key: iifname, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cefc81a3ce6..cc07922559d 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,7 +2,10 @@ package uspfilter -import "github.com/netbirdio/netbird/client/internal/statemanager" +import ( + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/internal/statemanager" +) // Reset firewall to the default state func (m *Manager) Reset(stateManager *statemanager.Manager) error { @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index d3732301ed5..0d55d62689c 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go new file mode 100644 index 00000000000..a4b1971bf6e --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common.go @@ -0,0 +1,138 @@ +// common.go +package conntrack + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// BaseConnTrack provides common fields and locking for all connection types +type BaseConnTrack struct { + sync.RWMutex + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access + established atomic.Bool +} + +// these small methods will be inlined by the compiler + +// UpdateLastSeen safely updates the last seen timestamp +func (b *BaseConnTrack) UpdateLastSeen() { + b.lastSeen.Store(time.Now().UnixNano()) +} + +// IsEstablished safely checks if connection is established +func (b *BaseConnTrack) IsEstablished() bool { + return b.established.Load() +} + +// SetEstablished safely sets the established state +func (b *BaseConnTrack) SetEstablished(state bool) { + b.established.Store(state) +} + +// GetLastSeen safely gets the last seen timestamp +func (b *BaseConnTrack) GetLastSeen() time.Time { + return time.Unix(0, b.lastSeen.Load()) +} + +// timeoutExceeded checks if the connection has exceeded the given timeout +func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { + lastSeen := time.Unix(0, b.lastSeen.Load()) + return time.Since(lastSeen) > timeout +} + +// IPAddr is a fixed-size IP address to avoid allocations +type IPAddr [16]byte + +// MakeIPAddr creates an IPAddr from net.IP +func MakeIPAddr(ip net.IP) (addr IPAddr) { + // Optimization: check for v4 first as it's more common + if ip4 := ip.To4(); ip4 != nil { + copy(addr[12:], ip4) + } else { + copy(addr[:], ip.To16()) + } + return addr +} + +// ConnKey uniquely identifies a connection +type ConnKey struct { + SrcIP IPAddr + DstIP IPAddr + SrcPort uint16 + DstPort uint16 +} + +// makeConnKey creates a connection key +func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + return ConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + } +} + +// ValidateIPs checks if IPs match without allocation +func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { + if ip4 := pktIP.To4(); ip4 != nil { + // Compare IPv4 addresses (last 4 bytes) + for i := 0; i < 4; i++ { + if connIP[12+i] != ip4[i] { + return false + } + } + return true + } + // Compare full IPv6 addresses + ip6 := pktIP.To16() + for i := 0; i < 16; i++ { + if connIP[i] != ip6[i] { + return false + } + } + return true +} + +// PreallocatedIPs is a pool of IP byte slices to reduce allocations +type PreallocatedIPs struct { + sync.Pool +} + +// NewPreallocatedIPs creates a new IP pool +func NewPreallocatedIPs() *PreallocatedIPs { + return &PreallocatedIPs{ + Pool: sync.Pool{ + New: func() interface{} { + ip := make(net.IP, 16) + return &ip + }, + }, + } +} + +// Get retrieves an IP from the pool +func (p *PreallocatedIPs) Get() net.IP { + return *p.Pool.Get().(*net.IP) +} + +// Put returns an IP to the pool +func (p *PreallocatedIPs) Put(ip net.IP) { + p.Pool.Put(&ip) +} + +// copyIP copies an IP address efficiently +func copyIP(dst, src net.IP) { + if len(src) == 16 { + copy(dst, src) + } else { + // Handle IPv4 + copy(dst[12:], src.To4()) + } +} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go new file mode 100644 index 00000000000..72d006def57 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -0,0 +1,115 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkIPOperations(b *testing.B) { + b.Run("MakeIPAddr", func(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = MakeIPAddr(ip) + } + }) + + b.Run("ValidateIPs", func(b *testing.B) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.1") + addr := MakeIPAddr(ip1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ValidateIPs(addr, ip2) + } + }) + + b.Run("IPPool", func(b *testing.B) { + pool := NewPreallocatedIPs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := pool.Get() + pool.Put(ip) + } + }) + +} +func BenchmarkAtomicOperations(b *testing.B) { + conn := &BaseConnTrack{} + b.Run("UpdateLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.UpdateLastSeen() + } + }) + + b.Run("IsEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.IsEstablished() + } + }) + + b.Run("SetEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.SetEstablished(i%2 == 0) + } + }) + + b.Run("GetLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.GetLastSeen() + } + }) +} + +// Memory pressure tests +func BenchmarkMemoryPressure(b *testing.B) { + b.Run("TCPHighLoad", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + } + } + }) + + b.Run("UDPHighLoad", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + } + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go new file mode 100644 index 00000000000..e0a971678f1 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -0,0 +1,170 @@ +package conntrack + +import ( + "net" + "sync" + "time" + + "github.com/google/gopacket/layers" +) + +const ( + // DefaultICMPTimeout is the default timeout for ICMP connections + DefaultICMPTimeout = 30 * time.Second + // ICMPCleanupInterval is how often we check for stale ICMP connections + ICMPCleanupInterval = 15 * time.Second +) + +// ICMPConnKey uniquely identifies an ICMP connection +type ICMPConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + Sequence uint16 // ICMP sequence number + ID uint16 // ICMP identifier +} + +// ICMPConnTrack represents an ICMP connection state +type ICMPConnTrack struct { + BaseConnTrack + Sequence uint16 + ID uint16 +} + +// ICMPTracker manages ICMP connection states +type ICMPTracker struct { + connections map[ICMPConnKey]*ICMPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewICMPTracker creates a new ICMP connection tracker +func NewICMPTracker(timeout time.Duration) *ICMPTracker { + if timeout == 0 { + timeout = DefaultICMPTimeout + } + + tracker := &ICMPTracker{ + connections: make(map[ICMPConnKey]*ICMPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(ICMPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound ICMP Echo Request +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + key := makeICMPKey(srcIP, dstIP, id, seq) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + }, + ID: id, + Sequence: seq, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { + switch icmpType { + case uint8(layers.ICMPv4TypeDestinationUnreachable), + uint8(layers.ICMPv4TypeTimeExceeded): + return true + case uint8(layers.ICMPv4TypeEchoReply): + // continue processing + default: + return false + } + + key := makeICMPKey(dstIP, srcIP, id, seq) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.ID == id && + conn.Sequence == seq +} + +func (t *ICMPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} +func (t *ICMPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *ICMPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// makeICMPKey creates an ICMP connection key +func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + return ICMPConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + ID: id, + Sequence: seq, + } +} diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go new file mode 100644 index 00000000000..21176e719d4 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -0,0 +1,39 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkICMPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go new file mode 100644 index 00000000000..e8d20f41c67 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -0,0 +1,376 @@ +package conntrack + +// TODO: Send RST packets for invalid/timed-out connections + +import ( + "net" + "sync" + "time" +) + +const ( + // MSL (Maximum Segment Lifetime) is typically 2 minutes + MSL = 2 * time.Minute + // TimeWaitTimeout (TIME-WAIT) should last 2*MSL + TimeWaitTimeout = 2 * MSL +) + +const ( + TCPSyn uint8 = 0x02 + TCPAck uint8 = 0x10 + TCPFin uint8 = 0x01 + TCPRst uint8 = 0x04 + TCPPush uint8 = 0x08 + TCPUrg uint8 = 0x20 +) + +const ( + // DefaultTCPTimeout is the default timeout for established TCP connections + DefaultTCPTimeout = 3 * time.Hour + // TCPHandshakeTimeout is timeout for TCP handshake completion + TCPHandshakeTimeout = 60 * time.Second + // TCPCleanupInterval is how often we check for stale connections + TCPCleanupInterval = 5 * time.Minute +) + +// TCPState represents the state of a TCP connection +type TCPState int + +const ( + TCPStateNew TCPState = iota + TCPStateSynSent + TCPStateSynReceived + TCPStateEstablished + TCPStateFinWait1 + TCPStateFinWait2 + TCPStateClosing + TCPStateTimeWait + TCPStateCloseWait + TCPStateLastAck + TCPStateClosed +) + +// TCPConnKey uniquely identifies a TCP connection +type TCPConnKey struct { + SrcIP [16]byte + DstIP [16]byte + SrcPort uint16 + DstPort uint16 +} + +// TCPConnTrack represents a TCP connection state +type TCPConnTrack struct { + BaseConnTrack + State TCPState +} + +// TCPTracker manages TCP connection states +type TCPTracker struct { + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + done chan struct{} + timeout time.Duration + ipPool *PreallocatedIPs +} + +// NewTCPTracker creates a new TCP connection tracker +func NewTCPTracker(timeout time.Duration) *TCPTracker { + tracker := &TCPTracker{ + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + done: make(chan struct{}), + timeout: timeout, + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound processes an outbound TCP packet and updates connection state +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + // Create key before lock + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + State: TCPStateNew, + } + conn.lastSeen.Store(now) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + + // Lock individual connection for state update + conn.Lock() + t.updateState(conn, flags, true) + conn.Unlock() + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound TCP packet matches a tracked connection +func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + + // Handle new SYN packets + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.Lock() + if _, exists := t.connections[key]; !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, dstIP) + copyIP(dstIPCopy, srcIP) + + conn := &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: dstPort, + DestPort: srcPort, + }, + State: TCPStateSynReceived, + } + conn.lastSeen.Store(time.Now().UnixNano()) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + return true + } + + // Look up existing connection + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + // Handle RST packets + if flags&TCPRst != 0 { + conn.Lock() + isEstablished := conn.IsEstablished() + if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.State = TCPStateClosed + conn.SetEstablished(false) + conn.Unlock() + return true + } + conn.Unlock() + return false + } + + // Update state + conn.Lock() + t.updateState(conn, flags, false) + conn.UpdateLastSeen() + isEstablished := conn.IsEstablished() + isValidState := t.isValidStateForFlags(conn.State, flags) + conn.Unlock() + + return isEstablished || isValidState +} + +// updateState updates the TCP connection state based on flags +func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { + // Handle RST flag specially - it always causes transition to closed + if flags&TCPRst != 0 { + conn.State = TCPStateClosed + conn.SetEstablished(false) + return + } + + switch conn.State { + case TCPStateNew: + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + conn.State = TCPStateSynSent + } + + case TCPStateSynSent: + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if isOutbound { + conn.State = TCPStateSynReceived + } else { + // Simultaneous open + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + } + + case TCPStateSynReceived: + if flags&TCPAck != 0 && flags&TCPSyn == 0 { + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + + case TCPStateEstablished: + if flags&TCPFin != 0 { + if isOutbound { + conn.State = TCPStateFinWait1 + } else { + conn.State = TCPStateCloseWait + } + conn.SetEstablished(false) + } + + case TCPStateFinWait1: + switch { + case flags&TCPFin != 0 && flags&TCPAck != 0: + // Simultaneous close - both sides sent FIN + conn.State = TCPStateClosing + case flags&TCPFin != 0: + conn.State = TCPStateFinWait2 + case flags&TCPAck != 0: + conn.State = TCPStateFinWait2 + } + + case TCPStateFinWait2: + if flags&TCPFin != 0 { + conn.State = TCPStateTimeWait + } + + case TCPStateClosing: + if flags&TCPAck != 0 { + conn.State = TCPStateTimeWait + // Keep established = false from previous state + } + + case TCPStateCloseWait: + if flags&TCPFin != 0 { + conn.State = TCPStateLastAck + } + + case TCPStateLastAck: + if flags&TCPAck != 0 { + conn.State = TCPStateClosed + } + + case TCPStateTimeWait: + // Stay in TIME-WAIT for 2MSL before transitioning to closed + // This is handled by the cleanup routine + } +} + +// isValidStateForFlags checks if the TCP flags are valid for the current connection state +func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + + switch state { + case TCPStateNew: + return flags&TCPSyn != 0 && flags&TCPAck == 0 + case TCPStateSynSent: + return flags&TCPSyn != 0 && flags&TCPAck != 0 + case TCPStateSynReceived: + return flags&TCPAck != 0 + case TCPStateEstablished: + if flags&TCPRst != 0 { + return true + } + return flags&TCPAck != 0 + case TCPStateFinWait1: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateFinWait2: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateClosing: + // In CLOSING state, we should accept the final ACK + return flags&TCPAck != 0 + case TCPStateTimeWait: + // In TIME_WAIT, we might see retransmissions + return flags&TCPAck != 0 + case TCPStateCloseWait: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateLastAck: + return flags&TCPAck != 0 + } + return false +} + +func (t *TCPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *TCPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + var timeout time.Duration + switch { + case conn.State == TCPStateTimeWait: + timeout = TimeWaitTimeout + case conn.IsEstablished(): + timeout = t.timeout + default: + timeout = TCPHandshakeTimeout + } + + lastSeen := conn.GetLastSeen() + if time.Since(lastSeen) > timeout { + // Return IPs to pool + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *TCPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + // Clean up all remaining IPs + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +func isValidFlagCombination(flags uint8) bool { + // Invalid: SYN+FIN + if flags&TCPSyn != 0 && flags&TCPFin != 0 { + return false + } + + // Invalid: RST with SYN or FIN + if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) { + return false + } + + return true +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go new file mode 100644 index 00000000000..3933c888943 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -0,0 +1,311 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTCPStateMachine(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Security Tests", func(t *testing.T) { + tests := []struct { + name string + flags uint8 + wantDrop bool + desc string + }{ + { + name: "Block unsolicited SYN-ACK", + flags: TCPSyn | TCPAck, + wantDrop: true, + desc: "Should block SYN-ACK without prior SYN", + }, + { + name: "Block invalid SYN-FIN", + flags: TCPSyn | TCPFin, + wantDrop: true, + desc: "Should block invalid SYN-FIN combination", + }, + { + name: "Block unsolicited RST", + flags: TCPRst, + wantDrop: true, + desc: "Should block RST without connection", + }, + { + name: "Block unsolicited ACK", + flags: TCPAck, + wantDrop: true, + desc: "Should block ACK without connection", + }, + { + name: "Block data without connection", + flags: TCPAck | TCPPush, + wantDrop: true, + desc: "Should block data without established connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + require.Equal(t, !tt.wantDrop, isValid, tt.desc) + }) + } + }) + + t.Run("Connection Flow Tests", func(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + desc string + }{ + { + name: "Normal Handshake", + test: func(t *testing.T) { + t.Helper() + + // Send initial SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + // Receive SYN-ACK + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + // Send ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + + // Test data transfer + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + require.True(t, valid, "Data should be allowed after handshake") + }, + }, + { + name: "Normal Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "ACK for FIN should be allowed") + + // Receive FIN from other side + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "FIN should be allowed") + + // Send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + }, + { + name: "RST During Connection", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Receive RST + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + require.True(t, valid, "RST should be allowed for established connection") + + // Verify connection is closed + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + t.Helper() + + require.False(t, valid, "Data should be blocked after RST") + }, + }, + { + name: "Simultaneous Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Both sides send FIN+ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "Simultaneous FIN should be allowed") + + // Both sides send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "Final ACKs should be allowed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + + tracker = NewTCPTracker(DefaultTCPTimeout) + tt.test(t) + }) + } + }) +} + +func TestRSTHandling(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + tests := []struct { + name string + setupState func() + sendRST func() + wantValid bool + desc string + }{ + { + name: "RST in established", + setupState: func() { + // Establish connection first + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: true, + desc: "Should accept RST for established connection", + }, + { + name: "RST without connection", + setupState: func() {}, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: false, + desc: "Should reject RST without connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + tt.sendRST() + + // Verify connection state is as expected + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + if tt.wantValid { + require.NotNil(t, conn) + require.Equal(t, TCPStateClosed, conn.State) + require.False(t, conn.IsEstablished()) + } + }) + } +} + +// Helper to establish a TCP connection +func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { + t.Helper() + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) +} + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go new file mode 100644 index 00000000000..a969a4e8425 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -0,0 +1,158 @@ +package conntrack + +import ( + "net" + "sync" + "time" +) + +const ( + // DefaultUDPTimeout is the default timeout for UDP connections + DefaultUDPTimeout = 30 * time.Second + // UDPCleanupInterval is how often we check for stale connections + UDPCleanupInterval = 15 * time.Second +) + +// UDPConnTrack represents a UDP connection state +type UDPConnTrack struct { + BaseConnTrack +} + +// UDPTracker manages UDP connection states +type UDPTracker struct { + connections map[ConnKey]*UDPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewUDPTracker creates a new UDP connection tracker +func NewUDPTracker(timeout time.Duration) *UDPTracker { + if timeout == 0 { + timeout = DefaultUDPTimeout + } + + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(UDPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound UDP connection +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound packet matches a tracked connection +func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.DestPort == srcPort && + conn.SourcePort == dstPort +} + +// cleanupRoutine periodically removes stale connections +func (t *UDPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *UDPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *UDPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// GetConnection safely retrieves a connection state +func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := t.connections[key] + if !exists { + return nil, false + } + + return conn, true +} + +// Timeout returns the configured timeout duration for the tracker +func (t *UDPTracker) Timeout() time.Duration { + return t.timeout +} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go new file mode 100644 index 00000000000..67172189069 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -0,0 +1,243 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewUDPTracker(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantTimeout time.Duration + }{ + { + name: "with custom timeout", + timeout: 1 * time.Minute, + wantTimeout: 1 * time.Minute, + }, + { + name: "with zero timeout uses default", + timeout: 0, + wantTimeout: DefaultUDPTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewUDPTracker(tt.timeout) + assert.NotNil(t, tracker) + assert.Equal(t, tt.wantTimeout, tracker.timeout) + assert.NotNil(t, tracker.connections) + assert.NotNil(t, tracker.cleanupTicker) + assert.NotNil(t, tracker.done) + }) + } +} + +func TestUDPTracker_TrackOutbound(t *testing.T) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + // Verify connection was tracked + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := tracker.connections[key] + require.True(t, exists) + assert.True(t, conn.SourceIP.Equal(srcIP)) + assert.True(t, conn.DestIP.Equal(dstIP)) + assert.Equal(t, srcPort, conn.SourcePort) + assert.Equal(t, dstPort, conn.DestPort) + assert.True(t, conn.IsEstablished()) + assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) +} + +func TestUDPTracker_IsValidInbound(t *testing.T) { + tracker := NewUDPTracker(1 * time.Second) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + // Track outbound connection + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + tests := []struct { + name string + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + sleep time.Duration + want bool + }{ + { + name: "valid inbound response", + srcIP: dstIP, // Original destination is now source + dstIP: srcIP, // Original source is now destination + srcPort: dstPort, // Original destination port is now source + dstPort: srcPort, // Original source port is now destination + sleep: 0, + want: true, + }, + { + name: "invalid source IP", + srcIP: net.ParseIP("192.168.1.4"), + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination IP", + srcIP: dstIP, + dstIP: net.ParseIP("192.168.1.4"), + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid source port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: 54321, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: 54321, + sleep: 0, + want: false, + }, + { + name: "expired connection", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 2 * time.Second, // Longer than tracker timeout + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sleep > 0 { + time.Sleep(tt.sleep) + } + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUDPTracker_Cleanup(t *testing.T) { + // Use shorter intervals for testing + timeout := 50 * time.Millisecond + cleanupInterval := 25 * time.Millisecond + + // Create tracker with custom cleanup interval + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(cleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + // Start cleanup routine + go tracker.cleanupRoutine() + + // Add some connections + connections := []struct { + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + }{ + { + srcIP: net.ParseIP("192.168.1.2"), + dstIP: net.ParseIP("192.168.1.3"), + srcPort: 12345, + dstPort: 53, + }, + { + srcIP: net.ParseIP("192.168.1.4"), + dstIP: net.ParseIP("192.168.1.5"), + srcPort: 12346, + dstPort: 53, + }, + } + + for _, conn := range connections { + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + } + + // Verify initial connections + assert.Len(t, tracker.connections, 2) + + // Wait for connection timeout and cleanup interval + time.Sleep(timeout + 2*cleanupInterval) + + tracker.mutex.RLock() + connCount := len(tracker.connections) + tracker.mutex.RUnlock() + + // Verify connections were cleaned up + assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up") + + // Properly close the tracker + tracker.Close() +} + +func BenchmarkUDPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + } + }) +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index fb726395bef..24cfd6e9691 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "net/netip" + "os" + "strconv" "sync" "github.com/google/gopacket" @@ -12,6 +14,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -19,6 +22,8 @@ import ( const layerTypeAll = 0 +const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) @@ -42,6 +47,11 @@ type Manager struct { nativeFirewall firewall.Manager mutex sync.RWMutex + + stateful bool + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker } // decoder for packages @@ -73,6 +83,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager } func create(iface IFaceMapper) (*Manager, error) { + disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + m := &Manager{ decoders: sync.Pool{ New: func() any { @@ -90,6 +102,16 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, + stateful: !disableConntrack, + } + + // Only initialize trackers if stateful mode is enabled + if disableConntrack { + log.Info("conntrack is disabled") + } else { + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } if err := iface.SetFilter(m); err != nil { @@ -249,16 +271,16 @@ func (m *Manager) Flush() error { return nil } // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte) bool { - return m.dropFilter(packetData, m.outgoingRules, false) + return m.processOutgoingHooks(packetData) } // DropIncoming filter incoming packets func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData, m.incomingRules, true) + return m.dropFilter(packetData, m.incomingRules) } -// dropFilter implements same logic for booth direction of the traffic -func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool { +// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP +func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -266,61 +288,213 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco defer m.decoders.Put(d) if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - log.Tracef("couldn't decode layer, err: %s", err) - return true + return false } if len(d.decoded) < 2 { - log.Tracef("not enough levels in network packet") - return true + return false } - ipLayer := d.decoded[0] + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { + return false + } - switch ipLayer { - case layers.LayerTypeIPv4: - if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) { - return false + // Always process UDP hooks + if d.decoded[1] == layers.LayerTypeUDP { + // Track UDP state only if enabled + if m.stateful { + m.trackUDPOutbound(d, srcIP, dstIP) } - case layers.LayerTypeIPv6: - if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) { - return false + return m.checkUDPHooks(d, dstIP, packetData) + } + + // Track other protocols only if stateful mode is enabled + if m.stateful { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.trackTCPOutbound(d, srcIP, dstIP) + case layers.LayerTypeICMPv4: + m.trackICMPOutbound(d, srcIP, dstIP) } - default: - log.Errorf("unknown layer: %v", d.decoded[0]) - return true } - var ip net.IP - switch ipLayer { + return false +} + +func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { + switch d.decoded[0] { case layers.LayerTypeIPv4: - if isIncomingPacket { - ip = d.ip4.SrcIP - } else { - ip = d.ip4.DstIP - } + return d.ip4.SrcIP, d.ip4.DstIP case layers.LayerTypeIPv6: - if isIncomingPacket { - ip = d.ip6.SrcIP - } else { - ip = d.ip6.DstIP + return d.ip6.SrcIP, d.ip6.DstIP + default: + return nil, nil + } +} + +func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) { + flags := getTCPFlags(&d.tcp) + m.tcpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + flags, + ) +} + +func getTCPFlags(tcp *layers.TCP) uint8 { + var flags uint8 + if tcp.SYN { + flags |= conntrack.TCPSyn + } + if tcp.ACK { + flags |= conntrack.TCPAck + } + if tcp.FIN { + flags |= conntrack.TCPFin + } + if tcp.RST { + flags |= conntrack.TCPRst + } + if tcp.PSH { + flags |= conntrack.TCPPush + } + if tcp.URG { + flags |= conntrack.TCPUrg + } + return flags +} + +func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) +} + +func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { + if rules, exists := m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) + } + } } } + return false +} + +func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + ) + } +} + +// dropFilter implements filtering logic for incoming packets +func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if !m.isValidPacket(d, packetData) { + return true + } + + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { + log.Errorf("unknown layer: %v", d.decoded[0]) + return true + } + + if !m.isWireguardTraffic(srcIP, dstIP) { + return false + } + + // Check connection state only if enabled + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + return false + } - filter, ok := validateRule(ip, packetData, rules[ip.String()], d) - if ok { + return m.applyRules(srcIP, packetData, rules, d) +} + +func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + log.Tracef("couldn't decode layer, err: %s", err) + return false + } + + if len(d.decoded) < 2 { + log.Tracef("not enough levels in network packet") + return false + } + return true +} + +func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { + return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) +} + +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return m.tcpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + getTCPFlags(&d.tcp), + ) + + case layers.LayerTypeUDP: + return m.udpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + case layers.LayerTypeICMPv4: + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + d.icmp4.TypeCode.Type(), + ) + + // TODO: ICMPv6 + } + + return false +} + +func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } - filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) - if ok { + + if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { return filter } - filter, ok = validateRule(ip, packetData, rules["::"], d) - if ok { + + if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { return filter } - // default policy is DROP ALL + // Default policy: DROP ALL return true } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go new file mode 100644 index 00000000000..3c661e71c70 --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -0,0 +1,998 @@ +package uspfilter + +import ( + "fmt" + "math/rand" + "net" + "os" + "strings" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface/device" +) + +// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range +func generateRandomIPs(n int) []net.IP { + ips := make([]net.IP, n) + seen := make(map[string]bool) + + for i := 0; i < n; { + ip := make(net.IP, 4) + ip[0] = 100 + ip[1] = byte(64 + rand.Intn(63)) // 64-126 + ip[2] = byte(rand.Intn(256)) + ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255 + + key := ip.String() + if !seen[key] { + ips[i] = ip + seen[key] = true + i++ + } + } + return ips +} + +func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + } + + var transportLayer gopacket.SerializableLayer + switch protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = tcp + case layers.IPProtocolUDP: + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = udp + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test")) + require.NoError(b, err) + return buf.Bytes() +} + +// BenchmarkCoreFiltering focuses on the essential performance comparisons between +// stateful and stateless filtering approaches +func BenchmarkCoreFiltering(b *testing.B) { + scenarios := []struct { + name string + stateful bool + setupFunc func(*Manager) + desc string + }{ + { + name: "stateless_single_allow_all", + stateful: false, + setupFunc: func(m *Manager) { + // Single rule allowing all traffic + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + require.NoError(b, err) + }, + desc: "Baseline: Single 'allow all' rule without connection tracking", + }, + { + name: "stateful_no_rules", + stateful: true, + setupFunc: func(m *Manager) { + // No explicit rules - rely purely on connection tracking + }, + desc: "Pure connection tracking without any rules", + }, + { + name: "stateless_explicit_return", + stateful: false, + setupFunc: func(m *Manager) { + // Add explicit rules matching return traffic pattern + for i := 0; i < 1000; i++ { // Simulate realistic ruleset size + ip := generateRandomIPs(1)[0] + _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, + &fw.Port{Values: []int{1024 + i}}, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + require.NoError(b, err) + } + }, + desc: "Explicit rules matching return traffic patterns without state", + }, + { + name: "stateful_with_established", + stateful: true, + setupFunc: func(m *Manager) { + // Add some basic rules but rely on state for established connections + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, + fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + require.NoError(b, err) + }, + desc: "Connection tracking with established connections", + }, + } + + // Test both TCP and UDP + protocols := []struct { + name string + proto layers.IPProtocol + }{ + {"TCP", layers.IPProtocolTCP}, + {"UDP", layers.IPProtocolUDP}, + } + + for _, sc := range scenarios { + for _, proto := range protocols { + b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + } else { + require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m")) + } + + // Create manager and basic setup + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Apply scenario-specific setup + sc.setupFunc(manager) + + // Generate test packets + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + srcPort := uint16(1024 + b.N%60000) + dstPort := uint16(80) + + outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto) + inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto) + + // For stateful scenarios, establish the connection + if sc.stateful { + manager.processOutgoingHooks(outbound) + } + + // Measure inbound packet processing + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } + } +} + +// BenchmarkStateScaling measures how performance scales with connection table size +func BenchmarkStateScaling(b *testing.B) { + connCounts := []int{100, 1000, 10000, 100000} + + for _, count := range connCounts { + b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Pre-populate connection table + srcIPs := generateRandomIPs(count) + dstIPs := generateRandomIPs(count) + for i := 0; i < count; i++ { + outbound := generatePacket(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, layers.IPProtocolTCP) + manager.processOutgoingHooks(outbound) + } + + // Test packet + testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) + testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) + + // First establish our test connection + manager.processOutgoingHooks(testOut) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(testIn, manager.incomingRules) + } + }) + } +} + +// BenchmarkEstablishmentOverhead measures the overhead of connection establishment +func BenchmarkEstablishmentOverhead(b *testing.B) { + scenarios := []struct { + name string + established bool + }{ + {"established", true}, + {"new", false}, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) + inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + + if sc.established { + manager.processOutgoingHooks(outbound) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic +func BenchmarkRoutedNetworkReturn(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + state string // "new", "established", "post_handshake" (TCP only) + setupFunc func(*Manager) + genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario + desc string + }{ + { + name: "allow_non_wg_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Allow non-WG: TCP new connection", + }, + { + name: "allow_non_wg_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with ACK flag for established connection + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Allow non-WG: TCP established connection", + }, + { + name: "allow_non_wg_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP new connection", + }, + { + name: "allow_non_wg_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP established connection", + }, + { + name: "stateful_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Stateful: TCP new connection", + }, + { + name: "stateful_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate established TCP packets (ACK flag) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Stateful: TCP established connection", + }, + { + name: "stateful_tcp_post_handshake", + proto: layers.IPProtocolTCP, + state: "post_handshake", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with PSH+ACK flags for data transfer + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + desc: "Stateful: TCP post-handshake data transfer", + }, + { + name: "stateful_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP new connection", + }, + { + name: "stateful_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP established connection", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + // Setup scenario + sc.setupFunc(manager) + + // Use IPs outside WG range for routed network simulation + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("8.8.8.8") + outbound, inbound := sc.genPackets(srcIP, dstIP) + + // For stateful cases and established connections + if !strings.Contains(sc.name, "allow_non_wg") || + (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { + manager.processOutgoingHooks(outbound) + + // For TCP post-handshake, simulate full handshake + if sc.state == "post_handshake" { + // SYN + syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + // ACK + ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +var scenarios = []struct { + name string + stateful bool // Whether conntrack is enabled + rules bool // Whether to add return traffic rules + routed bool // Whether to test routed network traffic + connCount int // Number of concurrent connections + desc string +}{ + { + name: "stateless_with_rules_100conns", + stateful: false, + rules: true, + routed: false, + connCount: 100, + desc: "Pure stateless with return traffic rules, 100 conns", + }, + { + name: "stateless_with_rules_1000conns", + stateful: false, + rules: true, + routed: false, + connCount: 1000, + desc: "Pure stateless with return traffic rules, 1000 conns", + }, + { + name: "stateful_no_rules_100conns", + stateful: true, + rules: false, + routed: false, + connCount: 100, + desc: "Pure stateful tracking without rules, 100 conns", + }, + { + name: "stateful_no_rules_1000conns", + stateful: true, + rules: false, + routed: false, + connCount: 1000, + desc: "Pure stateful tracking without rules, 1000 conns", + }, + { + name: "stateful_with_rules_100conns", + stateful: true, + rules: true, + routed: false, + connCount: 100, + desc: "Combined stateful + rules (current implementation), 100 conns", + }, + { + name: "stateful_with_rules_1000conns", + stateful: true, + rules: true, + routed: false, + connCount: 1000, + desc: "Combined stateful + rules (current implementation), 1000 conns", + }, + { + name: "routed_network_100conns", + stateful: true, + rules: false, + routed: true, + connCount: 100, + desc: "Routed network traffic (non-WG), 100 conns", + }, + { + name: "routed_network_1000conns", + stateful: true, + rules: false, + routed: true, + connCount: 1000, + desc: "Routed network traffic (non-WG), 1000 conns", + }, +} + +// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns +func BenchmarkLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + // Initial SYN + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + // ACK + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Prepare test packets simulating bidirectional traffic + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + // Server -> Client (inbound) + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + // Client -> Server (outbound) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + connIdx := i % sc.connCount + + // Simulate bidirectional traffic + // First outbound data + manager.processOutgoingHooks(outPackets[connIdx]) + // Then inbound response - this is what we're actually measuring + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + } +} + +// BenchmarkShortLivedConnections tests performance with many short-lived connections +func BenchmarkShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create packet patterns for a complete HTTP-like short connection: + // 1. Initial handshake (SYN, SYN-ACK, ACK) + // 2. HTTP Request (PSH+ACK from client) + // 3. HTTP Response (PSH+ACK from server) + // 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK) + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + // Generate all possible connection patterns + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + // Handshake + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + + // Data transfer + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + + // Connection teardown + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Each iteration creates a new short-lived connection + connIdx := i % sc.connCount + p := patterns[connIdx] + + // Connection establishment + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + // Data transfer + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + // Connection teardown + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + } +} + +// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel +func BenchmarkParallelLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Pre-generate test packets + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + // Each goroutine gets its own counter to distribute load + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + + // Simulate bidirectional traffic + manager.processOutgoingHooks(outPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + }) + } +} + +// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel +func BenchmarkParallelShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs and pre-generate all packet patterns + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + p := patterns[connIdx] + + // Full connection lifecycle + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + }) + } +} + +// generateTCPPacketWithFlags creates a TCP packet with specific flags +func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolTCP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + + // Set TCP flags + tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 + tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 + tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 + tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 + tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d7c93cb7f99..d3563e6f251 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "sync" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -185,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager := &Manager{ - incomingRules: map[string]RuleSet{}, - outgoingRules: map[string]RuleSet{}, - } + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -313,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to set network layer for checksum: %v", err) return } - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -325,7 +327,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.outgoingRules, false) { + if m.dropFilter(buf.Bytes(), m.outgoingRules) { t.Errorf("expected packet to be accepted") return } @@ -348,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) { if err != nil { t.Fatalf("Failed to create Manager: %s", err) } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } @@ -384,6 +389,88 @@ func TestRemovePacketHook(t *testing.T) { } } +func TestProcessOutgoingHooks(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + manager.udpTracker.Close() + manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + + hookCalled := false + hookID := manager.AddUDPPacketHook( + false, + net.ParseIP("100.10.0.100"), + 53, + func([]byte) bool { + hookCalled = true + return true + }, + ) + require.NotEmpty(t, hookID) + + // Create test UDP packet + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: net.ParseIP("100.10.0.1"), + DstIP: net.ParseIP("100.10.0.100"), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: 51334, + DstPort: 53, + } + + err = udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + payload := gopacket.Payload("test") + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload) + require.NoError(t, err) + + // Test hook gets called + result := manager.processOutgoingHooks(buf.Bytes()) + require.True(t, result) + require.True(t, hookCalled) + + // Test non-UDP packet is ignored + ipv4.Protocol = layers.IPProtocolTCP + buf = gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, opts, ipv4) + require.NoError(t, err) + + result = manager.processOutgoingHooks(buf.Bytes()) + require.False(t, result) +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { @@ -418,3 +505,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) { }) } } + +func TestStatefulFirewall_UDPTracking(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + manager.udpTracker.Close() // Close the existing tracker + manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + // Set up packet parameters + srcIP := net.ParseIP("100.10.0.1") + dstIP := net.ParseIP("100.10.0.100") + srcPort := uint16(51334) + dstPort := uint16(53) + + // Create outbound packet + outboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolUDP, + } + outboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + + err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + require.NoError(t, err) + + outboundBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err = gopacket.SerializeLayers(outboundBuf, opts, + outboundIPv4, + outboundUDP, + gopacket.Payload("test"), + ) + require.NoError(t, err) + + // Process outbound packet and verify connection tracking + drop := manager.DropOutgoing(outboundBuf.Bytes()) + require.False(t, drop, "Initial outbound packet should not be dropped") + + // Verify connection was tracked + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + + require.True(t, exists, "Connection should be tracked after outbound packet") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") + require.Equal(t, srcPort, conn.SourcePort, "Source port should match") + require.Equal(t, dstPort, conn.DestPort, "Destination port should match") + + // Create valid inbound response packet + inboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: dstIP, // Original destination is now source + DstIP: srcIP, // Original source is now destination + Protocol: layers.IPProtocolUDP, + } + inboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(dstPort), // Original destination port is now source + DstPort: layers.UDPPort(srcPort), // Original source port is now destination + } + + err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4) + require.NoError(t, err) + + inboundBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(inboundBuf, opts, + inboundIPv4, + inboundUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + // Test roundtrip response handling over time + checkPoints := []struct { + sleep time.Duration + shouldAllow bool + description string + }{ + { + sleep: 0, + shouldAllow: true, + description: "Immediate response should be allowed", + }, + { + sleep: 50 * time.Millisecond, + shouldAllow: true, + description: "Response within timeout should be allowed", + }, + { + sleep: 100 * time.Millisecond, + shouldAllow: true, + description: "Response at half timeout should be allowed", + }, + { + // tracker hasn't updated conn for 250ms -> greater than 200ms timeout + sleep: 250 * time.Millisecond, + shouldAllow: false, + description: "Response after timeout should be dropped", + }, + } + + for _, cp := range checkPoints { + time.Sleep(cp.sleep) + + drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + require.Equal(t, cp.shouldAllow, !drop, cp.description) + + // If the connection should still be valid, verify it exists + if cp.shouldAllow { + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should still exist during valid window") + require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), + "LastSeen should be updated for valid responses") + } + } + + // Test invalid response packets (while connection is expired) + invalidCases := []struct { + name string + modifyFunc func(*layers.IPv4, *layers.UDP) + description string + }{ + { + name: "wrong source IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.SrcIP = net.ParseIP("100.10.0.101") + }, + description: "Response from wrong IP should be dropped", + }, + { + name: "wrong destination IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.DstIP = net.ParseIP("100.10.0.2") + }, + description: "Response to wrong IP should be dropped", + }, + { + name: "wrong source port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.SrcPort = 54 + }, + description: "Response from wrong port should be dropped", + }, + { + name: "wrong destination port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.DstPort = 51335 + }, + description: "Response to wrong port should be dropped", + }, + } + + // Create a new outbound connection for invalid tests + drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Second outbound packet should not be dropped") + + for _, tc := range invalidCases { + t.Run(tc.name, func(t *testing.T) { + testIPv4 := *inboundIPv4 + testUDP := *inboundUDP + + tc.modifyFunc(&testIPv4, &testUDP) + + err = testUDP.SetNetworkLayerForChecksum(&testIPv4) + require.NoError(t, err) + + testBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(testBuf, opts, + &testIPv4, + &testUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + + // Verify the invalid packet is dropped + drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + require.True(t, drop, tc.description) + }) + } +}