diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 527a6badbd6..e3e644357e7 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -10,6 +10,7 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -94,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9305c0b5a15..b81d8bd3f5e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -39,6 +39,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" relayClient "github.com/netbirdio/netbird/relay/client" @@ -1219,7 +1220,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 8df033d91f2..128de8e020f 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -20,6 +20,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/signal/proto" @@ -133,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } diff --git a/management/client/client_test.go b/management/client/client_test.go index 08300244293..8bd8af8d2aa 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -77,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 2248b52d95a..3eb52eb9012 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -41,13 +41,20 @@ import ( "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" ) @@ -266,7 +273,15 @@ var ( KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + userManager := users.NewManager(store) + settingsManager := settings.NewManager(store) + permissionsManager := permissions.NewManager(userManager, settingsManager) + resourcesManager := resources.NewManager(store, permissionsManager, accountManager) + routersManager := routers.NewManager(store, permissionsManager, accountManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager) + groupsManager := groups.NewManager(store, permissionsManager) + + httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -275,7 +290,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 070cead4df3..e60b41b4ec1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -29,20 +29,15 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server/networks" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -151,10 +146,7 @@ type AccountManager interface { GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error - GetNetworksManager() networks.Manager - GetUserManager() users.Manager - GetSettingsManager() settings.Manager - GetGroupsManager() groups.Manager + UpdateAccountPeers(ctx context.Context, accountID string) } type DefaultAccountManager struct { @@ -191,12 +183,6 @@ type DefaultAccountManager struct { integratedPeerValidator integrated_validator.IntegratedValidator metrics telemetry.AppMetrics - - groupsManager groups.Manager - networksManager networks.Manager - userManager users.Manager - settingsManager settings.Manager - permissionsManager permissions.Manager } // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. @@ -263,19 +249,11 @@ func BuildManager( integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, ) (*DefaultAccountManager, error) { - userManager := users.NewManager(store) - settingsManager := settings.NewManager(store) - permissionsManager := permissions.NewManager(userManager, settingsManager) am := &DefaultAccountManager{ Store: store, geo: geo, peersUpdateManager: peersUpdateManager, idpManager: idpManager, - networksManager: networks.NewManager(store, permissionsManager), - groupsManager: groups.NewManager(store, permissionsManager), - userManager: userManager, - settingsManager: settingsManager, - permissionsManager: permissionsManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, @@ -440,7 +418,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers { - go am.updateAccountPeers(ctx, accountID) + go am.UpdateAccountPeers(ctx, accountID) } return updatedAccount, nil @@ -1417,7 +1395,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } } @@ -1684,7 +1662,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -1749,22 +1727,6 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) GetNetworksManager() networks.Manager { - return am.networksManager -} - -func (am *DefaultAccountManager) GetUserManager() users.Manager { - return am.userManager -} - -func (am *DefaultAccountManager) GetSettingsManager() settings.Manager { - return am.settingsManager -} - -func (am *DefaultAccountManager) GetGroupsManager() groups.Manager { - return am.groupsManager -} - // addAllGroup to account object if it doesn't exist func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { diff --git a/management/server/dns.go b/management/server/dns.go index 27c27dd4708..39dc11eb247 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -136,7 +136,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -185,7 +185,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t // areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups) if err != nil { return false, err } @@ -194,7 +194,7 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Stor return true, nil } - return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) + return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups) } // validateDNSSettings validates the DNS settings. diff --git a/management/server/group.go b/management/server/group.go index cd228af6528..d433a348551 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -130,7 +130,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -296,7 +296,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -337,7 +337,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -378,7 +378,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -419,7 +419,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -644,15 +644,15 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupI return false } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { +// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. +func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) if err != nil { return false, err } for _, group := range groups { - if group.HasPeers() { + if group.HasPeers() || group.HasResources() { return true, nil } } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index cd781a35b9f..d870935896f 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" @@ -30,8 +31,9 @@ import ( // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { - accountManager AccountManager - wgKey wgtypes.Key + accountManager AccountManager + settingsManager settings.Manager + wgKey wgtypes.Key proto.UnimplementedManagementServiceServer peersUpdateManager *PeersUpdateManager config *Config @@ -48,6 +50,7 @@ func NewServer( ctx context.Context, config *Config, accountManager AccountManager, + settingsManager settings.Manager, peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, @@ -100,6 +103,7 @@ func NewServer( // peerKey -> event channel peersUpdateManager: peersUpdateManager, accountManager: accountManager, + settingsManager: settingsManager, config: config, secretsManager: secretsManager, jwtValidator: jwtValidator, @@ -481,16 +485,20 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p } } - settings, err := s.accountManager.GetSettingsManager().GetSettings(ctx, peer.AccountID, userID) + settings, err := s.settingsManager.GetSettings(ctx, accountID, userID) if err != nil { log.WithContext(ctx).Errorf("failed to get settings for account %s and user %s: %v", accountID, userID, err) - return nil, mapError(ctx, err) + } + + routingPeerDNSResolutionEnabled := false + if settings != nil { + routingPeerDNSResolutionEnabled = settings.RoutingPeerDNSResolutionEnabled } // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), settings.RoutingPeerDNSResolutionEnabled), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), routingPeerDNSResolutionEnabled), Checks: toProtocolChecks(ctx, postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) @@ -688,7 +696,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p } } - settings, err := s.accountManager.GetSettingsManager().GetSettings(ctx, peer.AccountID, peer.UserID) + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, peer.UserID) if err != nil { return status.Errorf(codes.Internal, "error handling request") } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 824a1cd8e05..3a169da9d9d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,7 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroups "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/handlers/accounts" "github.com/netbirdio/netbird/management/server/http/handlers/dns" @@ -26,6 +27,9 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbnetworks "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -39,7 +43,7 @@ type apiHandler struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -94,7 +98,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa routes.AddEndpoints(api.AccountManager, authCfg, router) dns.AddEndpoints(api.AccountManager, authCfg, router) events.AddEndpoints(api.AccountManager, authCfg, router) - networks.AddEndpoints(api.AccountManager.GetNetworksManager(), api.AccountManager.GetGroupsManager(), api.AccountManager.GetAccountIDFromToken, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index fff7a5f6da7..9875b139ca5 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -14,6 +14,8 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/status" @@ -22,17 +24,20 @@ import ( // handler is a handler that returns networks of the account type handler struct { - networksManager networks.Manager + networksManager networks.Manager + resourceManager resources.Manager + routerManager routers.Manager + groupsManager groups.Manager extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) claimsExtractor *jwtclaims.ClaimsExtractor } -func AddEndpoints(networksManager networks.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - addRouterEndpoints(networksManager.GetRouterManager(), extractFromToken, authCfg, router) - addResourceEndpoints(networksManager.GetResourceManager(), groupsManager, extractFromToken, authCfg, router) +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + addRouterEndpoints(routerManager, extractFromToken, authCfg, router) + addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) - networksHandler := newHandler(networksManager, groupsManager, extractFromToken, authCfg) + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, extractFromToken, authCfg) router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") @@ -40,9 +45,11 @@ func AddEndpoints(networksManager networks.Manager, groupsManager groups.Manager router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") } -func newHandler(networksManager networks.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { return &handler{ networksManager: networksManager, + resourceManager: resourceManager, + routerManager: routerManager, groupsManager: groupsManager, extractFromToken: extractFromToken, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -66,7 +73,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - resourceIDs, err := h.networksManager.GetResourceManager().GetAllResourceIDsInAccount(r.Context(), accountID, userID) + resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -78,7 +85,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - routers, err := h.networksManager.GetRouterManager().GetAllRoutersInAccount(r.Context(), accountID, userID) + routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -212,7 +219,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, networkID string) ([]string, []string, int, error) { - resources, err := h.networksManager.GetResourceManager().GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + resources, err := h.resourceManager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get resources in network: %w", err) } @@ -222,7 +229,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne resourceIDs = append(resourceIDs, resource.ID) } - routers, err := h.networksManager.GetRouterManager().GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + routers, err := h.routerManager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index f1d6de36198..c664237366a 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" @@ -439,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index f0f83a237b9..40514ae14db 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" @@ -552,7 +553,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 55787868e36..042137b1b02 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,15 +13,11 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server/networks" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -115,24 +111,8 @@ type MockAccountManager struct { DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error } -func (am *MockAccountManager) GetUserManager() users.Manager { - // TODO implement me - panic("implement me") -} - -func (am *MockAccountManager) GetNetworksManager() networks.Manager { - // TODO implement me - panic("implement me") -} - -func (am *MockAccountManager) GetSettingsManager() settings.Manager { - // TODO implement me - panic("implement me") -} - -func (am *MockAccountManager) GetGroupsManager() groups.Manager { - // TODO implement me - panic("implement me") +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing } func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 19acdf1bae3..1a01c7a89ca 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups) if err != nil { return err } @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return newNSGroup.Copy(), nil @@ -143,7 +143,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -172,7 +172,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups) if err != nil { return err } @@ -190,7 +190,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -249,7 +249,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) if err != nil { return false, err } @@ -258,7 +258,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store return true, nil } - return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) + return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index ddc88b05f00..d5291d9dafb 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -7,7 +7,6 @@ import ( "github.com/rs/xid" "github.com/netbirdio/netbird/management/server/networks/resources" - "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" @@ -20,23 +19,19 @@ type Manager interface { GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error - GetResourceManager() resources.Manager - GetRouterManager() routers.Manager } type managerImpl struct { store store.Store permissionsManager permissions.Manager - routersManager routers.Manager resourcesManager resources.Manager } -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, manager resources.Manager) Manager { return &managerImpl{ store: store, permissionsManager: permissionsManager, - routersManager: routers.NewManager(store, permissionsManager), - resourcesManager: resources.NewManager(store, permissionsManager), + resourcesManager: manager, } } @@ -130,11 +125,3 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) }) } - -func (m *managerImpl) GetResourceManager() resources.Manager { - return m.resourcesManager -} - -func (m *managerImpl) GetRouterManager() routers.Manager { - return m.routersManager -} diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go index 5fa2b17f188..af1ce1caeea 100644 --- a/management/server/networks/manager_test.go +++ b/management/server/networks/manager_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" @@ -21,8 +23,10 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) networks, err := manager.GetAllNetworks(ctx, accountID, userID) require.NoError(t, err) @@ -40,8 +44,10 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) networks, err := manager.GetAllNetworks(ctx, accountID, userID) require.Error(t, err) @@ -59,8 +65,10 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -78,8 +86,10 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) network, err := manager.GetNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -99,8 +109,10 @@ func Test_CreateNetworkSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) createdNetwork, err := manager.CreateNetwork(ctx, userID, network) require.NoError(t, err) @@ -120,8 +132,10 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) createdNetwork, err := manager.CreateNetwork(ctx, userID, network) require.Error(t, err) @@ -139,8 +153,10 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) err = manager.DeleteNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -157,8 +173,10 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) err = manager.DeleteNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -178,8 +196,10 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) require.NoError(t, err) @@ -200,8 +220,11 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) { t.Fatal(err) } t.Cleanup(cleanUp) + + am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + resourcesManager := resources.NewManager(s, permissionsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) require.Error(t, err) diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 907d926cd2e..e1f15c2c374 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" @@ -25,12 +26,14 @@ type Manager interface { type managerImpl struct { store store.Store permissionsManager permissions.Manager + accountManager s.AccountManager } -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, permissionsManager: permissionsManager, + accountManager: accountManager, } } @@ -94,12 +97,19 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return nil, fmt.Errorf("failed to create new network resource: %w", err) } - _, err = m.store.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + _, err = m.store.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) if err == nil { return nil, errors.New("resource already exists") } + + err = m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return nil, fmt.Errorf("failed to create network resource: %w", err) + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) - return resource, m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + return resource, nil } func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { @@ -150,7 +160,14 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc return nil, errors.New("new resource name already exists") } - return resource, m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + err = m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return nil, fmt.Errorf("failed to update network resource: %w", err) + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil } func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { @@ -165,9 +182,16 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net unlock := m.store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resourceID) }) + if err != nil { + return fmt.Errorf("failed to delete network resource: %w", err) + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil } func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error { diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index f053c56f055..e9ce8d28019 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" @@ -25,7 +26,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -44,7 +46,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -62,7 +65,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.NoError(t, err) @@ -80,7 +84,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.Error(t, err) @@ -101,7 +106,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -121,7 +127,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) @@ -146,7 +153,8 @@ func Test_CreateResourceSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(store, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(store, permissionsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.NoError(t, err) @@ -170,7 +178,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(store, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(store, permissionsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -195,7 +204,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(store, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(store, permissionsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -219,7 +229,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(store, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(store, permissionsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -247,7 +258,8 @@ func Test_UpdateResourceSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.NoError(t, err) @@ -277,7 +289,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -305,7 +318,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -333,7 +347,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -353,7 +368,8 @@ func Test_DeleteResourceSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -372,7 +388,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(store, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(store, permissionsManager, &am) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 8e6d0304362..2103beb06f6 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" + s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" @@ -25,12 +26,14 @@ type Manager interface { type managerImpl struct { store store.Store permissionsManager permissions.Manager + accountManager s.AccountManager } -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, permissionsManager: permissionsManager, + accountManager: accountManager, } } @@ -79,7 +82,14 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t router.ID = xid.New().String() - return router, m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + err = m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return nil, fmt.Errorf("failed to create network router: %w", err) + } + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil } func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { @@ -112,7 +122,14 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } - return router, m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + err = m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return nil, fmt.Errorf("failed to update network router: %w", err) + } + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil } func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { @@ -124,5 +141,12 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo return status.NewPermissionDeniedError() } - return m.store.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + err = m.store.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return fmt.Errorf("failed to delete network router: %w", err) + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil } diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 18b876b1c57..e650074cc17 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" @@ -24,7 +25,8 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -44,7 +46,8 @@ func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -65,7 +68,8 @@ func Test_GetRouterReturnsRouter(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -85,7 +89,8 @@ func Test_GetRouterReturnsPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) @@ -107,7 +112,8 @@ func Test_CreateRouterSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) createdRouter, err := manager.CreateRouter(ctx, userID, router) require.NoError(t, err) @@ -132,7 +138,8 @@ func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) createdRouter, err := manager.CreateRouter(ctx, userID, router) require.Error(t, err) @@ -153,7 +160,8 @@ func Test_DeleteRouterSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) require.NoError(t, err) @@ -172,7 +180,8 @@ func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) require.Error(t, err) @@ -193,7 +202,8 @@ func Test_UpdateRouterSuccessfully(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) updatedRouter, err := manager.UpdateRouter(ctx, userID, router) require.NoError(t, err) @@ -214,7 +224,8 @@ func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() - manager := NewManager(s, permissionsManager) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) updatedRouter, err := manager.UpdateRouter(ctx, userID, router) require.Error(t, err) diff --git a/management/server/peer.go b/management/server/peer.go index 80f4e08cc79..c2efc5edc01 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil @@ -273,7 +273,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return peer, nil @@ -353,7 +353,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -611,7 +611,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } if newGroupsAffectsPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -693,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } if peerNotValid { @@ -839,7 +839,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged || (updated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -1004,9 +1004,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) } -// updateAccountPeers updates all peers that belong to an account. +// UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 133c9909093..29461ecd829 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -882,7 +882,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index 8ae2f96d03d..45b3e93e697 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -84,7 +84,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return policy, nil @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -171,7 +171,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err } @@ -181,7 +181,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a } } - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.RuleGroups()) + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } // validatePolicy validates the policy and its rules. diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index c9329766bac..1690f8e339a 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -83,7 +83,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return postureChecks, nil @@ -182,7 +182,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St for _, policy := range policies { if slices.Contains(policy.SourcePostureChecks, postureCheckID) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.RuleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups()) if err != nil { return false, err } diff --git a/management/server/route.go b/management/server/route.go index 49d76bc43f3..1eb51aea751 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -211,7 +211,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if am.isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -297,7 +297,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -329,7 +329,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if am.isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/types/group.go b/management/server/types/group.go index 7ba4b8656bc..462f702a98d 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -111,3 +111,8 @@ func (g *Group) RemoveResource(resource Resource) bool { } return false } + +// HasResources checks if the group has any resources. +func (g *Group) HasResources() bool { + return len(g.Resources) > 0 +} diff --git a/management/server/user.go b/management/server/user.go index 9fc2464de90..457721917ac 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -294,7 +294,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil @@ -640,7 +640,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -983,7 +983,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } @@ -1091,7 +1091,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta {