diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000000..d207b1802b2 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.go text eol=lf diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index cc634610b4d..bd42a03f643 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -1,20 +1,36 @@ name: golangci-lint on: [pull_request] + +permissions: + contents: read + pull-requests: read + concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} cancel-in-progress: true + jobs: golangci: + strategy: + fail-fast: false + matrix: + os: [macos-latest, windows-latest, ubuntu-latest] name: lint - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + timeout-minutes: 15 steps: - name: Checkout code uses: actions/checkout@v3 - name: Install Go uses: actions/setup-go@v4 with: - go-version: "1.21.x" + go-version: "1.20.x" + cache: false - name: Install dependencies + if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v3 \ No newline at end of file + uses: golangci/golangci-lint-action@v3 + with: + version: latest + args: --timeout=12m \ No newline at end of file diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index c56d655fbd3..c24ef6933ed 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -112,6 +112,27 @@ jobs: grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" + - name: Install modules + run: go mod tidy + + - name: Build management binary + working-directory: management + run: CGO_ENABLED=1 go build -o netbird-mgmt main.go + + - name: Build management docker image + working-directory: management + run: | + docker build -t netbirdio/management:latest . + + - name: Build signal binary + working-directory: signal + run: CGO_ENABLED=0 go build -o netbird-signal main.go + + - name: Build signal docker image + working-directory: signal + run: | + docker build -t netbirdio/signal:latest . + - name: run docker compose up working-directory: infrastructure_files run: | diff --git a/client/android/login.go b/client/android/login.go index ad334541ce6..afd61055f07 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + s, ok := gstatus.FromError(err) + if !ok { + return err + } + if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { supportsSSO = false err = nil } @@ -189,7 +193,7 @@ func (a *Auth) login(urlOpener URLOpener) error { } func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config) + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) if err != nil { return nil, err } diff --git a/client/cmd/login.go b/client/cmd/login.go index a5cc3215cfd..2ddab46f35c 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "runtime" "strings" "time" @@ -82,9 +81,10 @@ var loginCmd = &cobra.Command{ client := proto.NewDaemonServiceClient(conn) loginRequest := proto.LoginRequest{ - SetupKey: setupKey, - PreSharedKey: preSharedKey, - ManagementUrl: managementURL, + SetupKey: setupKey, + PreSharedKey: preSharedKey, + ManagementUrl: managementURL, + IsLinuxDesktopClient: isLinuxRunningDesktop(), } var loginErr error @@ -165,7 +165,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C } func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop()) if err != nil { return nil, err } @@ -195,51 +195,17 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) } - browserAuthMsg := "Please do the SSO login in your browser. \n" + + cmd.Println("Please do the SSO login in your browser. \n" + "If your browser didn't open automatically, use this URL to log in:\n\n" + - verificationURIComplete + " " + codeMsg - - setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys" - - authenticateUsingBrowser := func() { - cmd.Println(browserAuthMsg) - cmd.Println("") - if err := open.Run(verificationURIComplete); err != nil { - cmd.Println(setupKeyAuthMsg) - } - } - - switch runtime.GOOS { - case "windows", "darwin": - authenticateUsingBrowser() - case "linux": - if isLinuxRunningDesktop() { - authenticateUsingBrowser() - } else { - // If current flow is PKCE, it implies the server is anticipating the redirect to localhost. - // Devices lacking browser support are incompatible with this flow.Therefore, - // these devices will need to resort to setup keys instead. - if isPKCEFlow(verificationURIComplete) { - cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys") - } else { - cmd.Println(browserAuthMsg) - } - } + verificationURIComplete + " " + codeMsg) + cmd.Println("") + if err := open.Run(verificationURIComplete); err != nil { + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } -// isLinuxRunningDesktop checks if a Linux OS is running desktop environment. +// isLinuxRunningDesktop checks if a Linux OS is running desktop environment func isLinuxRunningDesktop() bool { return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" } - -// isPKCEFlow determines if the PKCE flow is active or not, -// by checking the existence of redirect_uri inside the verification URL. -func isPKCEFlow(verificationURL string) bool { - if verificationURL == "" { - return false - } - return strings.Contains(verificationURL, "redirect_uri") -} diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 678072f0bbe..6d47021dd09 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste return nil, nil } accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index d81b671d980..e17cbb18f7b 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -150,6 +150,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CustomDNSAddress: customDNSAddressConverted, RosenpassEnabled: rosenpassEnabled, + IsLinuxDesktopClient: isLinuxRunningDesktop(), } var loginErr error diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 753282d879b..048c0fd5042 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -93,7 +93,7 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) { // AddFiltering rule to the firewall // -// If comment is empty rule ID is used as comment +// Comment will be ignored because some system this feature is not supported func (m *Manager) AddFiltering( ip net.IP, protocol fw.Protocol, @@ -123,9 +123,6 @@ func (m *Manager) AddFiltering( ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal) ruleID := uuid.New().String() - if comment == "" { - comment = ruleID - } if ipsetName != "" { rs, rsExists := m.rulesets[ipsetName] @@ -157,8 +154,7 @@ func (m *Manager) AddFiltering( // this is new ipset so we need to create firewall rule for it } - specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal, - direction, action, comment, ipsetName) + specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) if direction == fw.RuleDirectionOUT { ok, err := client.Exists("filter", ChainOutputFilterName, specs...) @@ -283,7 +279,7 @@ func (m *Manager) AllowNetbird() error { fw.RuleDirectionIN, fw.ActionAccept, "", - "allow netbird interface traffic", + "", ) if err != nil { return fmt.Errorf("failed to allow netbird interface traffic: %w", err) @@ -296,7 +292,7 @@ func (m *Manager) AllowNetbird() error { fw.RuleDirectionOUT, fw.ActionAccept, "", - "allow netbird interface traffic", + "", ) return err } @@ -362,9 +358,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error { // filterRuleSpecs returns the specs of a filtering rule func (m *Manager) filterRuleSpecs( - table string, ip net.IP, protocol string, sPort, dPort string, - direction fw.RuleDirection, action fw.Action, comment string, - ipsetName string, + ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string, ) (specs []string) { matchByIP := true // don't use IP matching if IP is ip 0.0.0.0 @@ -398,8 +392,7 @@ func (m *Manager) filterRuleSpecs( if dPort != "" { specs = append(specs, "--dport", dPort) } - specs = append(specs, "-j", m.actionToStr(action)) - return append(specs, "-m", "comment", "--comment", comment) + return append(specs, "-j", m.actionToStr(action)) } // rawClient returns corresponding iptables client for the given ip diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 794fe0958e2..82adf91b9ec 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "runtime" log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" @@ -57,25 +58,45 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } -// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. -func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { - log.Debug("loading pkce authorization flow info") +// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration +// +// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, +// and if that also fails, the authentication process is deemed unsuccessful +// +// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow +func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) { + if runtime.GOOS == "linux" && !isLinuxDesktopClient { + return authenticateWithDeviceCodeFlow(ctx, config) + } - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) - if err == nil { - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + if err != nil { + // fallback to device code flow + log.Debugf("failed to initialize pkce authentication with error: %v\n", err) + log.Debug("falling back to device code flow") + return authenticateWithDeviceCodeFlow(ctx, config) } + return pkceFlow, nil +} - log.Debugf("loading pkce authorization flow info failed with error: %v", err) - log.Debugf("falling back to device authorization flow info") +// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow +func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { + pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + if err != nil { + return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) + } + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) +} +// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow +func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { s, ok := gstatus.FromError(err) if ok && s.Code() == codes.NotFound { return nil, fmt.Errorf("no SSO provider returned from management. " + - "If you are using hosting Netbird see documentation at " + - "https://github.com/netbirdio/netbird/tree/main/management for details") + "Please proceed with setting up this device using setup keys " + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } else if ok && s.Code() == codes.Unimplemented { return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ "please update your server or use Setup Keys to login", config.ManagementURL) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index d15d493738a..32f5383d36e 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -12,7 +12,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" log "github.com/sirupsen/logrus" @@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string { } // RequestAuthInfo requests a authorization code login flow information. -func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) { +func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { state, err := randomBytesInHex(24) if err != nil { return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err) @@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) - go p.startServer(tokenChan, errChan) + parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) + if err != nil { + return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err) + } + + server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} + defer func() { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Errorf("failed to close the server: %v", err) + } + }() + + go p.startServer(server, tokenChan, errChan) select { case <-ctx.Done(): return TokenInfo{}, ctx.Err() case token := <-tokenChan: - return p.handleOAuthToken(token) + return p.parseOAuthToken(token) case err := <-errChan: return TokenInfo{}, err } } -func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { - var wg sync.WaitGroup - - parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) - if err != nil { - errChan <- fmt.Errorf("failed to parse redirect URL: %v", err) - return - } - - server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} - go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err - } - }() - - wg.Add(1) - http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - defer wg.Done() - - tokenValidatorFunc := func() (*oauth2.Token, error) { - query := req.URL.Query() - - if authError := query.Get(queryError); authError != "" { - authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) - } - - // Prevent timing attacks on state - if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") - } - - code := query.Get(queryCode) - if code == "" { - return nil, fmt.Errorf("missing code") - } - - return p.oAuthConfig.Exchange( - req.Context(), - code, - oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), - ) - } - - token, err := tokenValidatorFunc() +func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + token, err := p.handleRequest(req) if err != nil { renderPKCEFlowTmpl(w, err) errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) @@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC tokenChan <- token }) - wg.Wait() - if err := server.Shutdown(context.Background()); err != nil { - log.Errorf("error while shutting down pkce flow server: %v", err) + server.Handler = mux + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err + } +} + +func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) { + query := req.URL.Query() + + if authError := query.Get(queryError); authError != "" { + authErrorDesc := query.Get(queryErrorDesc) + return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) } + + // Prevent timing attacks on the state + if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { + return nil, fmt.Errorf("invalid state") + } + + code := query.Get(queryCode) + if code == "" { + return nil, fmt.Errorf("missing code") + } + + return p.oAuthConfig.Exchange( + req.Context(), + code, + oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), + ) } -func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) { +func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) { tokenInfo := TokenInfo{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, @@ -200,7 +197,13 @@ func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo tokenInfo.IDToken = idToken } - if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil { + // if a provider doesn't support an audience, use the Client ID for token verification + audience := p.providerConfig.Audience + if audience == "" { + audience = p.providerConfig.ClientID + } + + if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil { return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) } diff --git a/client/internal/auth/util.go b/client/internal/auth/util.go index 33a0e6e35ad..31c81d7019c 100644 --- a/client/internal/auth/util.go +++ b/client/internal/auth/util.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "io" - "reflect" "strings" ) @@ -44,15 +43,14 @@ func isValidAccessToken(token string, audience string) error { } // Audience claim of JWT can be a string or an array of strings - typ := reflect.TypeOf(claims.Audience) - switch typ.Kind() { - case reflect.String: - if claims.Audience == audience { + switch aud := claims.Audience.(type) { + case string: + if aud == audience { return nil } - case reflect.Slice: - for _, aud := range claims.Audience.([]interface{}) { - if audience == aud { + case []interface{}: + for _, audItem := range aud { + if audStr, ok := audItem.(string); ok && audStr == audience { return nil } } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9f17ff36ba3..ea4a23a8de6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1049,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { return nil, "", err } accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { return nil, "", err } diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index 2efbae97b5f..a35dacc77be 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -106,9 +106,6 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error { errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" - if config.Audience == "" { - return fmt.Errorf(errorMSGFormat, "Audience") - } if config.ClientID == "" { return fmt.Errorf(errorMSGFormat, "Client ID") } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 87ce12052ea..b1638a74c74 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v4.23.4 // source: daemon.proto package proto @@ -40,8 +40,9 @@ type LoginRequest struct { // cleanNATExternalIPs clean map list of external IPs. // This is needed because the generated code // omits initialized empty slices due to omitempty tags - CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` - CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` + CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` RosenpassEnabled bool `protobuf:"varint,8,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` } @@ -126,6 +127,13 @@ func (x *LoginRequest) GetCustomDNSAddress() []byte { return nil } +func (x *LoginRequest) GetIsLinuxDesktopClient() bool { + if x != nil { + return x.IsLinuxDesktopClient + } + return false +} + func (x *LoginRequest) GetRosenpassEnabled() bool { if x != nil { return x.RosenpassEnabled @@ -1052,6 +1060,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xc2, 0x02, 0x0a, 0x0c, 0x4c, 0x6f, + 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xca, 0x02, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 0fed0c238d4..ae85a21c04f 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -51,8 +51,9 @@ message LoginRequest { bytes customDNSAddress = 7; - bool rosenpassEnabled = 8; + bool isLinuxDesktopClient = 8; + bool rosenpassEnabled = 9; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index d5b3ebb3b85..c3090768585 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -211,7 +211,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro state.Set(internal.StatusConnecting) if msg.SetupKey == "" { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err diff --git a/client/ui/build-ui-linux.sh b/client/ui/build-ui-linux.sh new file mode 100644 index 00000000000..eab08214dbd --- /dev/null +++ b/client/ui/build-ui-linux.sh @@ -0,0 +1,5 @@ +#!/bin/bash +sudo apt update +sudo apt remove gir1.2-appindicator3-0.1 +sudo apt install -y libayatana-appindicator3-dev +go build \ No newline at end of file diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index e6b4394e8d2..9c7685db03d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -202,9 +202,10 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } _, err = client.Login(s.ctx, &proto.LoginRequest{ - ManagementUrl: s.iMngURL.Text, - AdminURL: s.iAdminURL.Text, - PreSharedKey: s.iPreSharedKey.Text, + ManagementUrl: s.iMngURL.Text, + AdminURL: s.iAdminURL.Text, + PreSharedKey: s.iPreSharedKey.Text, + IsLinuxDesktopClient: runtime.GOOS == "linux", }) if err != nil { log.Errorf("login to management URL: %v", err) @@ -233,7 +234,9 @@ func (s *serviceClient) login() error { return err } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{}) + loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + IsLinuxDesktopClient: runtime.GOOS == "linux", + }) if err != nil { log.Errorf("login to management URL with: %v", err) return err diff --git a/go.mod b/go.mod index b01140af844..19ae32e56a4 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( require ( cunicu.li/go-rosenpass v0.4.0 fyne.io/fyne/v2 v2.1.4 + github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/c-robinson/iplib v1.0.3 github.com/cilium/ebpf v0.11.0 github.com/coreos/go-iptables v0.7.0 diff --git a/go.sum b/go.sum index 5c69370123a..46daf967781 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbt github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo= +github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 h1:pami0oPhVosjOu/qRHepRmdjD6hGILF7DBr+qQZeP10= github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2/go.mod h1:jNIx5ykW1MroBuaTja9+VpglmaJOUzezumfhLlER3oY= github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 4bcec128df9..f610a9691bc 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -46,6 +46,14 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken} # PKCE authorization flow NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"} NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false} +NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE + +# Dashboard + +# The default setting is to transmit the audience to the IDP during authorization. However, +# if your IDP does not have this capability, you can turn this off by setting it to false. +NETBIRD_DASH_AUTH_USE_AUDIENCE=${NETBIRD_DASH_AUTH_USE_AUDIENCE:-true} +NETBIRD_DASH_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE # exports export NETBIRD_DOMAIN @@ -86,4 +94,7 @@ export NETBIRD_TOKEN_SOURCE export NETBIRD_AUTH_DEVICE_AUTH_SCOPE export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT -export NETBIRD_AUTH_PKCE_USE_ID_TOKEN \ No newline at end of file +export NETBIRD_AUTH_PKCE_USE_ID_TOKEN +export NETBIRD_AUTH_PKCE_AUDIENCE +export NETBIRD_DASH_AUTH_USE_AUDIENCE +export NETBIRD_DASH_AUTH_AUDIENCE \ No newline at end of file diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index 4e568b2fef4..3db79906827 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -164,6 +164,12 @@ done export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,} +# Remove audience for providers that do not support it +if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then + export NETBIRD_DASH_AUTH_AUDIENCE=none + export NETBIRD_AUTH_PKCE_AUDIENCE= +fi + env | grep NETBIRD envsubst docker-compose.yml diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b70e4cb6e0b..c5ea3ae56b0 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -12,7 +12,7 @@ services: - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT # OIDC - - AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE + - AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID - AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 6d371081693..cab471df646 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -12,7 +12,7 @@ services: - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT # OIDC - - AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE + - AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID - AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY @@ -20,6 +20,7 @@ services: - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI + - NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE # SSL - NGINX_SSL_PORT=443 # Letsencrypt diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl index e74b93b32ac..e185faa6ebd 100644 --- a/infrastructure_files/management.json.tmpl +++ b/infrastructure_files/management.json.tmpl @@ -62,7 +62,7 @@ }, "PKCEAuthorizationFlow": { "ProviderConfig": { - "Audience": "$NETBIRD_AUTH_AUDIENCE", + "Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE", "ClientID": "$NETBIRD_AUTH_CLIENT_ID", "ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET", "AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT", diff --git a/infrastructure_files/setup.env.example b/infrastructure_files/setup.env.example index 9b03ccd2d77..f9ad638465f 100644 --- a/infrastructure_files/setup.env.example +++ b/infrastructure_files/setup.env.example @@ -8,6 +8,9 @@ NETBIRD_DOMAIN="" # e.g., https://example.eu.auth0.com/.well-known/openid-configuration # ------------------------------------------- NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="" +# The default setting is to transmit the audience to the IDP during authorization. However, +# if your IDP does not have this capability, you can turn this off by setting it to false. +#NETBIRD_DASH_AUTH_USE_AUDIENCE=false NETBIRD_AUTH_AUDIENCE="" # e.g. netbird-client NETBIRD_AUTH_CLIENT_ID="" diff --git a/management/client/client_test.go b/management/client/client_test.go index deef573296e..86c598adbd9 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -61,7 +61,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 5c38167151f..f85cf225eaf 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity/sqlite" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/idp" @@ -142,12 +143,22 @@ var ( if disableSingleAccMode { mgmtSingleAccModeDomain = "" } - eventStore, err := sqlite.NewSQLiteStore(config.Datadir) + eventStore, key, err := initEventStore(config.Datadir, config.DataStoreEncryptionKey) if err != nil { - return err + return fmt.Errorf("failed to initialize database: %s", err) + } + + if config.DataStoreEncryptionKey != key { + log.Infof("update config with activity store key") + config.DataStoreEncryptionKey = key + err := updateMgmtConfig(mgmtConfig, config) + if err != nil { + return fmt.Errorf("failed to write out store encryption key: %s", err) + } } + accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore) + dnsDomain, eventStore, userDeleteFromIDPEnabled) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -287,6 +298,20 @@ var ( } ) +func initEventStore(dataDir string, key string) (activity.Store, string, error) { + var err error + if key == "" { + log.Debugf("generate new activity store encryption key") + key, err = sqlite.GenerateKey() + if err != nil { + return nil, "", err + } + } + store, err := sqlite.NewSQLiteStore(dataDir, key) + return store, key, err + +} + func notifyStop(msg string) { select { case stopCh <- 1: @@ -440,6 +465,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { return loadedConfig, err } +func updateMgmtConfig(path string, config *server.Config) error { + return util.DirectWriteJson(path, config) +} + // OIDCConfigResponse used for parsing OIDC config response type OIDCConfigResponse struct { Issuer string `json:"issuer"` diff --git a/management/cmd/root.go b/management/cmd/root.go index a149841c50b..2080a6b29f2 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -24,6 +24,7 @@ var ( disableMetrics bool disableSingleAccMode bool idpSignKeyRefreshEnabled bool + userDeleteFromIDPEnabled bool rootCmd = &cobra.Command{ Use: "netbird-mgmt", @@ -56,6 +57,7 @@ func init() { mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") + mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account") rootCmd.MarkFlagRequired("config") //nolint rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") diff --git a/management/server/account.go b/management/server/account.go index 4c707af3aff..1f8fc497be5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -62,12 +62,9 @@ type AccountManager interface { GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) MarkPATUsed(tokenID string) error GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) - AccountExists(accountId string) (*bool, error) - GetPeerByKey(peerKey string) (*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error) MarkPeerConnected(peerKey string, connected bool) error - DeletePeer(accountID, peerID, userID string) (*Peer, error) - GetPeerByIP(accountId string, peerIP string) (*Peer, error) + DeletePeer(accountID, peerID, userID string) error UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error) GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) @@ -80,26 +77,22 @@ type AccountManager interface { GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) SaveGroup(accountID, userID string, group *Group) error - UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) DeleteGroup(accountId, userId, groupID string) error ListGroups(accountId string) ([]*Group, error) GroupAddPeer(accountId, groupID, peerID string) error - GroupDeletePeer(accountId, groupID, peerKey string) error - GroupListPeers(accountId, groupID string) ([]*Peer, error) + GroupDeletePeer(accountId, groupID, peerID string) error GetPolicy(accountID, policyID, userID string) (*Policy, error) SavePolicy(accountID, userID string, policy *Policy) error DeletePolicy(accountID, policyID, userID string) error ListPolicies(accountID, userID string) ([]*Policy, error) GetRoute(accountID, routeID, userID string) (*route.Route, error) - CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) + CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) SaveRoute(accountID, userID string, route *route.Route) error - UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) DeleteRoute(accountID, routeID, userID string) error ListRoutes(accountID, userID string) ([]*route.Route, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroup(accountID, nsGroupID, userID string) error ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string @@ -133,6 +126,9 @@ type DefaultAccountManager struct { // dnsDomain is used for peer resolution. This is appended to the peer's name dnsDomain string peerLoginExpiry Scheduler + + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account + userDeleteFromIDPEnabled bool } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -204,7 +200,7 @@ type UserInfo struct { // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Route { - routes, peerDisabledRoutes := a.getEnabledAndDisabledRoutesByPeer(peerID) + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) peerRoutesMembership := make(lookupMap) for _, r := range append(routes, peerDisabledRoutes...) { peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{} @@ -212,7 +208,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Rout groupListMap := a.getPeerGroups(peerID) for _, peer := range aclPeers { - activeRoutes, _ := a.getEnabledAndDisabledRoutesByPeer(peer.ID) + activeRoutes, _ := a.getRoutingPeerRoutes(peer.ID) groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) routes = append(routes, filteredRoutes...) @@ -248,29 +244,63 @@ func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap looku return filteredRoutes } -// getEnabledAndDisabledRoutesByPeer returns the enabled and disabled lists of routes that belong to a peer. +// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Route, []*route.Route) { - var enabledRoutes []*route.Route - var disabledRoutes []*route.Route +// If the given is not a routing peer, then the lists are empty. +func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + + peer := a.GetPeer(peerID) + if peer == nil { + log.Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + return enabledRoutes, disabledRoutes + } + + // currently we support only linux routing peers + if peer.Meta.GoOS != "linux" { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[string]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + for _, r := range a.Routes { - if r.Peer == peerID { - // We need to set Peer.Key instead of Peer.ID because this object will be sent to agents as part of a network map. - // Ideally we should have a separate field for that, but fine for now. - peer := a.GetPeer(peerID) - if peer == nil { - log.Errorf("route %s has peer %s that doesn't exist under account %s", r.ID, peerID, a.Id) + for _, groupID := range r.PeerGroups { + group := a.GetGroup(groupID) + if group == nil { + log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) continue } - raut := r.Copy() - raut.Peer = peer.Key - if r.Enabled { - enabledRoutes = append(enabledRoutes, raut) - continue + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map + takeRoute(newPeerRoute, id) + break } - disabledRoutes = append(disabledRoutes, raut) + } + if r.Peer == peerID { + takeRoute(r.Copy(), peerID) } } + return enabledRoutes, disabledRoutes } @@ -286,17 +316,6 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { return routes } -// GetPeerByIP returns peer by it's IP if exists under account or nil otherwise -func (a *Account) GetPeerByIP(peerIP string) *Peer { - for _, peer := range a.Peers { - if peerIP == peer.IP.String() { - return peer - } - } - - return nil -} - // GetGroup returns a group by ID if exists, nil otherwise func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] @@ -316,7 +335,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { } peersToConnect = append(peersToConnect, p) } - // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. + routesUpdate := a.getRoutesToSync(peerID, peersToConnect) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) @@ -577,8 +596,8 @@ func (a *Account) Copy() *Account { } routes := map[string]*route.Route{} - for id, route := range a.Routes { - routes[id] = route.Copy() + for id, r := range a.Routes { + routes[id] = r.Copy() } nsGroups := map[string]*nbdns.NameServerGroup{} @@ -738,18 +757,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, - singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, + singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ - Store: store, - peersUpdateManager: peersUpdateManager, - idpManager: idpManager, - ctx: context.Background(), - cacheMux: sync.Mutex{}, - cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, - eventStore: eventStore, - peerLoginExpiry: NewDefaultScheduler(), + Store: store, + peersUpdateManager: peersUpdateManager, + idpManager: idpManager, + ctx: context.Background(), + cacheMux: sync.Mutex{}, + cacheLoading: map[string]chan struct{}{}, + dnsDomain: dnsDomain, + eventStore: eventStore, + peerLoginExpiry: NewDefaultScheduler(), + userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -874,33 +894,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() return account.GetNextPeerExpiration() } + expiredPeers := account.GetExpiredPeers() var peerIDs []string - for _, peer := range account.GetExpiredPeers() { - if peer.Status.LoginExpired { - continue - } + for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) - peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) - if err != nil { - log.Errorf("failed saving peer status while expiring peer %s", peer.ID) - return account.GetNextPeerExpiration() - } - am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain())) } log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) - if len(peerIDs) != 0 { - // this will trigger peer disconnect from the management service - am.peersUpdateManager.CloseChannels(peerIDs) - err = am.updateAccountPeers(account) - if err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", accountID) - return account.GetNextPeerExpiration() - } + if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextPeerExpiration() } + return account.GetNextPeerExpiration() } } @@ -941,6 +947,27 @@ func (am *DefaultAccountManager) warmupIDPCache() error { return err } + // If the Identity Provider does not support writing AppMetadata, + // in cases like this, we expect it to return all users in an "unset" field. + // We iterate over the users in the "unset" field, look up their AccountID in our store, and + // update their AppMetadata with the AccountID. + if unsetData, ok := userData[idp.UnsetAccountID]; ok { + for _, user := range unsetData { + accountID, err := am.Store.GetAccountByUser(user.ID) + if err == nil { + data := userData[accountID.Id] + if data == nil { + data = make([]*idp.UserData, 0, 1) + } + + user.AppMetadata.WTAccountID = accountID.Id + + userData[accountID.Id] = append(data, user) + } + } + } + delete(userData, idp.UnsetAccountID) + for accountID, users := range userData { err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) if err != nil { @@ -1007,7 +1034,36 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) { log.Debugf("account %s not found in cache, reloading", accountID) - return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID)) + accountIDString := fmt.Sprintf("%v", accountID) + + account, err := am.Store.GetAccount(accountIDString) + if err != nil { + return nil, err + } + + userData, err := am.idpManager.GetAccount(accountIDString) + if err != nil { + return nil, err + } + + dataMap := make(map[string]*idp.UserData, len(userData)) + for _, datum := range userData { + dataMap[datum.ID] = datum + } + + matchedUserData := make([]*idp.UserData, 0) + for _, user := range account.Users { + if user.IsServiceUser { + continue + } + datum, ok := dataMap[user.Id] + if !ok { + log.Warnf("user %s not found in IDP", user.Id) + continue + } + matchedUserData = append(matchedUserData, datum) + } + return matchedUserData, nil } func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) { @@ -1256,7 +1312,6 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { - unlock := am.Store.AcquireGlobalLock() user, err := am.Store.GetUserByTokenID(tokenID) if err != nil { @@ -1268,8 +1323,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { return err } - unlock() - unlock = am.Store.AcquireAccountLock(account.Id) + unlock := am.Store.AcquireAccountLock(account.Id) defer unlock() account, err = am.Store.GetAccountByUser(user.Id) @@ -1396,9 +1450,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if err := am.Store.SaveAccount(account); err != nil { log.Errorf("failed to save account: %v", err) } else { - if err := am.updateAccountPeers(account); err != nil { - log.Errorf("failed updating account peers while updating user %s", account.Id) - } + am.updateAccountPeers(account) for _, g := range addNewGroups { if group := account.GetGroup(g); group != nil { am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser, @@ -1509,26 +1561,6 @@ func isDomainValid(domain string) bool { return re.Match([]byte(domain)) } -// AccountExists checks whether account exists (returns true) or not (returns false) -func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - var res bool - _, err := am.Store.GetAccount(accountID) - if err != nil { - if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - res = false - return &res, nil - } else { - return nil, err - } - } - - res = true - return &res, nil -} - // GetDNSDomain returns the configured dnsDomain func (am *DefaultAccountManager) GetDNSDomain() string { return am.dnsDomain @@ -1605,19 +1637,3 @@ func newAccountWithId(accountID, userID, domain string) *Account { } return acc } - -func removeFromList(inputList []string, toRemove []string) []string { - toRemoveMap := make(map[string]struct{}) - for _, item := range toRemove { - toRemoveMap[item] = struct{}{} - } - - var resultList []string - for _, item := range inputList { - _, ok := toRemoveMap[item] - if !ok { - resultList = append(resultList, item) - } - } - return resultList -} diff --git a/management/server/account_test.go b/management/server/account_test.go index 64fd90524e6..331df2017dd 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -706,30 +706,6 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string) return account, nil } -func TestAccountManager_AccountExists(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - return - } - - expectedId := "test_account" - userId := "account_creator" - _, err = createAccount(manager, expectedId, userId, "") - if err != nil { - t.Fatal(err) - } - - exists, err := manager.AccountExists(expectedId) - if err != nil { - t.Fatal(err) - } - - if !*exists { - t.Errorf("expected account to exist after creation, got false") - } -} - func TestAccountManager_GetAccount(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -1062,7 +1038,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if _, err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil { + if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil { t.Errorf("delete peer: %v", err) return } @@ -1129,7 +1105,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - _, err = manager.DeletePeer(account.Id, peerKey, userID) + err = manager.DeletePeer(account.Id, peerKey, userID) if err != nil { return } @@ -1261,7 +1237,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { } account := &Account{ Peers: map[string]*Peer{ - "peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"}, + "peer-1": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, }, Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[string]*route.Route{ @@ -1385,8 +1361,9 @@ func TestAccount_Copy(t *testing.T) { }, Routes: map[string]*route.Route{ "route1": { - ID: "route1", - Groups: []string{"group1"}, + ID: "route1", + PeerGroups: []string{}, + Groups: []string{"group1"}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ @@ -2063,7 +2040,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4de667dedab..ce36f520fca 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -104,6 +104,8 @@ const ( UserBlocked // UserUnblocked indicates that a user unblocked another user UserUnblocked + // UserDeleted indicates that a user deleted another user + UserDeleted // GroupDeleted indicates that a user deleted group GroupDeleted // UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login @@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{ ServiceUserDeleted: {"Service user deleted", "service.user.delete"}, UserBlocked: {"User blocked", "user.block"}, UserUnblocked: {"User unblocked", "user.unblock"}, + UserDeleted: {"User deleted", "user.delete"}, GroupDeleted: {"Group deleted", "group.delete"}, UserLoggedInPeer: {"User logged in peer", "user.peer.login"}, PeerLoginExpired: {"Peer login expired", "peer.login.expire"}, diff --git a/management/server/activity/event.go b/management/server/activity/event.go index 17ec4a0b0b5..f212f5b21b3 100644 --- a/management/server/activity/event.go +++ b/management/server/activity/event.go @@ -18,10 +18,15 @@ type Event struct { ID uint64 // InitiatorID is the ID of an object that initiated the event (e.g., a user) InitiatorID string + // InitiatorName is the name of an object that initiated the event. + InitiatorName string + // InitiatorEmail is the email address of an object that initiated the event. + InitiatorEmail string // TargetID is the ID of an object that was effected by the event (e.g., a peer) TargetID string // AccountID is the ID of an account where the event happened AccountID string + // Meta of the event, e.g. deleted peer information like name, IP, etc Meta map[string]any } @@ -35,12 +40,14 @@ func (e *Event) Copy() *Event { } return &Event{ - Timestamp: e.Timestamp, - Activity: e.Activity, - ID: e.ID, - InitiatorID: e.InitiatorID, - TargetID: e.TargetID, - AccountID: e.AccountID, - Meta: meta, + Timestamp: e.Timestamp, + Activity: e.Activity, + ID: e.ID, + InitiatorID: e.InitiatorID, + InitiatorName: e.InitiatorName, + InitiatorEmail: e.InitiatorEmail, + TargetID: e.TargetID, + AccountID: e.AccountID, + Meta: meta, } } diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go new file mode 100644 index 00000000000..cf4dda7463c --- /dev/null +++ b/management/server/activity/sqlite/crypt.go @@ -0,0 +1,81 @@ +package sqlite + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" +) + +var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} + +type FieldEncrypt struct { + block cipher.Block +} + +func GenerateKey() (string, error) { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + return "", err + } + readableKey := base64.StdEncoding.EncodeToString(key) + return readableKey, nil +} + +func NewFieldEncrypt(key string) (*FieldEncrypt, error) { + binKey, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(binKey) + if err != nil { + return nil, err + } + ec := &FieldEncrypt{ + block: block, + } + + return ec, nil +} + +func (ec *FieldEncrypt) Encrypt(payload string) string { + plainText := pkcs5Padding([]byte(payload)) + cipherText := make([]byte, len(plainText)) + cbc := cipher.NewCBCEncrypter(ec.block, iv) + cbc.CryptBlocks(cipherText, plainText) + return base64.StdEncoding.EncodeToString(cipherText) +} + +func (ec *FieldEncrypt) Decrypt(data string) (string, error) { + cipherText, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return "", err + } + cbc := cipher.NewCBCDecrypter(ec.block, iv) + cbc.CryptBlocks(cipherText, cipherText) + payload, err := pkcs5UnPadding(cipherText) + if err != nil { + return "", err + } + + return string(payload), nil +} + +func pkcs5Padding(ciphertext []byte) []byte { + padding := aes.BlockSize - len(ciphertext)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padText...) +} + +func pkcs5UnPadding(src []byte) ([]byte, error) { + srcLen := len(src) + paddingLen := int(src[srcLen-1]) + if paddingLen >= srcLen || paddingLen > aes.BlockSize { + return nil, fmt.Errorf("padding size error") + } + return src[:srcLen-paddingLen], nil +} diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go new file mode 100644 index 00000000000..efa740921d3 --- /dev/null +++ b/management/server/activity/sqlite/crypt_test.go @@ -0,0 +1,63 @@ +package sqlite + +import ( + "testing" +) + +func TestGenerateKey(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewFieldEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.Encrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + decrypted, err := ee.Decrypt(encrypted) + if err != nil { + t.Fatalf("failed to decrypt data: %s", err) + } + + if decrypted != testData { + t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) + } +} + +func TestCorruptKey(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewFieldEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.Encrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + newKey, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + + ee, err = NewFieldEncrypt(newKey) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + res, _ := ee.Decrypt(encrypted) + if res == testData { + t.Fatalf("incorrect decryption, the result is: %s", res) + } +} diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index a4c85cf6057..6af4d4d8dbb 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -3,14 +3,14 @@ package sqlite import ( "database/sql" "encoding/json" - - "github.com/netbirdio/netbird/management/server/activity" - - // sqlite driver + "fmt" "path/filepath" "time" _ "github.com/mattn/go-sqlite3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" ) const ( @@ -25,69 +25,122 @@ const ( "meta TEXT," + " target_id TEXT);" - selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;" - selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" + - " FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;" + creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);` + + selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta + FROM events + LEFT JOIN deleted_users i ON events.initiator_id = i.id + LEFT JOIN deleted_users t ON events.target_id = t.id + WHERE account_id = ? + ORDER BY timestamp DESC LIMIT ? OFFSET ?;` + + selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta + FROM events + LEFT JOIN deleted_users i ON events.initiator_id = i.id + LEFT JOIN deleted_users t ON events.target_id = t.id + WHERE account_id = ? + ORDER BY timestamp ASC LIMIT ? OFFSET ?;` + insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + "VALUES(?, ?, ?, ?, ?, ?)" + + insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` ) // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { - db *sql.DB + db *sql.DB + fieldEncrypt *FieldEncrypt + insertStatement *sql.Stmt selectAscStatement *sql.Stmt selectDescStatement *sql.Stmt + deleteUserStmt *sql.Stmt } // NewSQLiteStore creates a new Store with an event table if not exists. -func NewSQLiteStore(dataDir string) (*Store, error) { +func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { dbFile := filepath.Join(dataDir, eventSinkDB) db, err := sql.Open("sqlite3", dbFile) if err != nil { return nil, err } + crypt, err := NewFieldEncrypt(encryptionKey) + if err != nil { + _ = db.Close() + return nil, err + } + _, err = db.Exec(createTableQuery) if err != nil { + _ = db.Close() + return nil, err + } + + _, err = db.Exec(creatTableDeletedUsersQuery) + if err != nil { + _ = db.Close() + return nil, err + } + + err = updateDeletedUsersTable(db) + if err != nil { + _ = db.Close() return nil, err } insertStmt, err := db.Prepare(insertQuery) if err != nil { + _ = db.Close() return nil, err } selectDescStmt, err := db.Prepare(selectDescQuery) if err != nil { + _ = db.Close() return nil, err } selectAscStmt, err := db.Prepare(selectAscQuery) if err != nil { + _ = db.Close() + return nil, err + } + + deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) + if err != nil { + _ = db.Close() return nil, err } - return &Store{ + s := &Store{ db: db, + fieldEncrypt: crypt, insertStatement: insertStmt, selectDescStatement: selectDescStmt, selectAscStatement: selectAscStmt, - }, nil + deleteUserStmt: deleteUserStmt, + } + + return s, nil } -func processResult(result *sql.Rows) ([]*activity.Event, error) { +func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { events := make([]*activity.Event, 0) for result.Next() { var id int64 var operation activity.Activity var timestamp time.Time var initiator string + var initiatorName *string + var initiatorEmail *string var target string + var targetUserName *string + var targetEmail *string var account string var jsonMeta string - err := result.Scan(&id, &operation, ×tamp, &initiator, &target, &account, &jsonMeta) + err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorName, &initiatorEmail, &target, &targetUserName, &targetEmail, &account, &jsonMeta) if err != nil { return nil, err } @@ -100,7 +153,27 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { } } - events = append(events, &activity.Event{ + if targetUserName != nil { + name, err := store.fieldEncrypt.Decrypt(*targetUserName) + if err != nil { + log.Errorf("failed to decrypt username for target id: %s", target) + meta["username"] = "" + } else { + meta["username"] = name + } + } + + if targetEmail != nil { + email, err := store.fieldEncrypt.Decrypt(*targetEmail) + if err != nil { + log.Errorf("failed to decrypt email address for target id: %s", target) + meta["email"] = "" + } else { + meta["email"] = email + } + } + + event := &activity.Event{ Timestamp: timestamp, Activity: operation, ID: uint64(id), @@ -108,7 +181,27 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) { TargetID: target, AccountID: account, Meta: meta, - }) + } + + if initiatorName != nil { + name, err := store.fieldEncrypt.Decrypt(*initiatorName) + if err != nil { + log.Errorf("failed to decrypt username of initiator: %s", initiator) + } else { + event.InitiatorName = name + } + } + + if initiatorEmail != nil { + email, err := store.fieldEncrypt.Decrypt(*initiatorEmail) + if err != nil { + log.Errorf("failed to decrypt email address of initiator: %s", initiator) + } else { + event.InitiatorEmail = email + } + } + + events = append(events, event) } return events, nil @@ -127,13 +220,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ } defer result.Close() //nolint - return processResult(result) + return store.processResult(result) } -// Save an event in the SQLite events table +// Save an event in the SQLite events table end encrypt the "email" element in meta map func (store *Store) Save(event *activity.Event) (*activity.Event, error) { var jsonMeta string - if event.Meta != nil { + meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event) + if err != nil { + return nil, err + } + + if meta != nil { metaBytes, err := json.Marshal(event.Meta) if err != nil { return nil, err @@ -156,6 +254,34 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) { return eventCopy, nil } +// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete +// this item from meta map +func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) { + email, ok := event.Meta["email"] + if !ok { + return event.Meta, nil + } + + name, ok := event.Meta["name"] + if !ok { + return event.Meta, nil + } + + encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) + encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) + _, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName) + if err != nil { + return nil, err + } + + if len(event.Meta) == 2 { + return nil, nil // nolint + } + delete(event.Meta, "email") + delete(event.Meta, "name") + return event.Meta, nil +} + // Close the Store func (store *Store) Close() error { if store.db != nil { @@ -163,3 +289,44 @@ func (store *Store) Close() error { } return nil } + +func updateDeletedUsersTable(db *sql.DB) error { + log.Debugf("check deleted_users table version") + rows, err := db.Query(`PRAGMA table_info(deleted_users);`) + if err != nil { + return err + } + defer rows.Close() + found := false + for rows.Next() { + var ( + cid int + name string + dataType string + notNull int + dfltVal sql.NullString + pk int + ) + err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk) + if err != nil { + return err + } + if name == "name" { + found = true + break + } + } + + err = rows.Err() + if err != nil { + return err + } + + if found { + return nil + } + + log.Debugf("update delted_users table") + _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) + return err +} diff --git a/management/server/activity/sqlite/sqlite_test.go b/management/server/activity/sqlite/sqlite_test.go index 2ca9a1e648e..f6a6f94678a 100644 --- a/management/server/activity/sqlite/sqlite_test.go +++ b/management/server/activity/sqlite/sqlite_test.go @@ -12,7 +12,8 @@ import ( func TestNewSQLiteStore(t *testing.T) { dataDir := t.TempDir() - store, err := NewSQLiteStore(dataDir) + key, _ := GenerateKey() + store, err := NewSQLiteStore(dataDir, key) if err != nil { t.Fatal(err) return diff --git a/management/server/config.go b/management/server/config.go index ea014398821..31c1cf45c5d 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -35,7 +35,8 @@ type Config struct { TURNConfig *TURNConfig Signal *Host - Datadir string + Datadir string + DataStoreEncryptionKey string HttpConfig *HttpServerConfig diff --git a/management/server/dns.go b/management/server/dns.go index 427ba40d120..252782aeaac 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -122,7 +122,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 092c52afabb..b089949b282 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false) } func createDNSStore(t *testing.T) (Store, error) { diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index a7b4239837c..0e76e58acc9 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -162,7 +162,7 @@ func (e *EphemeralManager) cleanup() { for id, p := range deletePeers { log.Debugf("delete ephemeral peer: %s", id) - _, err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) + err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) if err != nil { log.Tracef("failed to delete ephemeral peer: %s", err) } diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index a763f4cef6d..d271e5fcafe 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -29,9 +29,9 @@ type MocAccountManager struct { store *MockStore } -func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) (*Peer, error) { +func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error { delete(a.store.account.Peers, peerID) - return nil, nil //nolint:nilnil + return nil //nolint:nil } func TestNewManager(t *testing.T) { diff --git a/management/server/group.go b/management/server/group.go index 5b1d2ac9fd0..a7502134aa3 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -33,26 +33,6 @@ type Group struct { Peers []string } -const ( - // UpdateGroupName indicates a name update operation - UpdateGroupName GroupUpdateOperationType = iota - // InsertPeersToGroup indicates insert peers to group operation - InsertPeersToGroup - // RemovePeersFromGroup indicates a remove peers from group operation - RemovePeersFromGroup - // UpdateGroupPeers indicates a replacement of group peers list - UpdateGroupPeers -) - -// GroupUpdateOperationType operation type -type GroupUpdateOperationType int - -// GroupUpdateOperation operation object with type and values to be applied -type GroupUpdateOperation struct { - Type GroupUpdateOperationType - Values []string -} - // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -104,10 +84,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G return err } - err = am.updateAccountPeers(account) - if err != nil { - return err - } + am.updateAccountPeers(account) // the following snippet tracks the activity and stores the group events in the event store. // It has to happen after all the operations have been successfully performed. @@ -165,57 +142,6 @@ func difference(a, b []string) []string { return diff } -// UpdateGroup updates a group using a list of operations -func (am *DefaultAccountManager) UpdateGroup(accountID string, - groupID string, operations []GroupUpdateOperation, -) (*Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - groupToUpdate, ok := account.Groups[groupID] - if !ok { - return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID) - } - - group := groupToUpdate.Copy() - - for _, operation := range operations { - switch operation.Type { - case UpdateGroupName: - group.Name = operation.Values[0] - case UpdateGroupPeers: - group.Peers = operation.Values - case InsertPeersToGroup: - sourceList := group.Peers - resultList := removeFromList(sourceList, operation.Values) - group.Peers = append(resultList, operation.Values...) - case RemovePeersFromGroup: - sourceList := group.Peers - resultList := removeFromList(sourceList, operation.Values) - group.Peers = resultList - } - } - - account.Groups[groupID] = group - - account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - return nil, err - } - - return group, nil -} - // DeleteGroup object of the peers func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { unlock := am.Store.AcquireAccountLock(accountId) @@ -300,7 +226,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } // ListGroups objects of the peers @@ -352,11 +280,13 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) return err } - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } // GroupDeletePeer removes peer from the group -func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { +func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -372,7 +302,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str account.Network.IncSerial() for i, itemID := range group.Peers { - if itemID == peerKey { + if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) if err := am.Store.SaveAccount(account); err != nil { return err @@ -380,31 +310,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str } } - return am.updateAccountPeers(account) -} + am.updateAccountPeers(account) -// GroupListPeers returns list of the peers from the group -func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, status.Errorf(status.NotFound, "account not found") - } - - group, ok := account.Groups[groupID] - if !ok { - return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - peers := make([]*Peer, 0, len(account.Groups)) - for _, peerID := range group.Peers { - p, ok := account.Peers[peerID] - if ok { - peers = append(peers, p) - } - } - - return peers, nil + return nil } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 32b553f9be4..383cb0d1ffd 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -159,6 +159,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi select { // condition when there are some updates case update, open := <-updates: + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) + } + if !open { log.Debugf("updates channel for peer %s was closed", peerKey.String()) s.cancelPeerRoutines(peer) diff --git a/management/server/http/api/generate.sh b/management/server/http/api/generate.sh old mode 100644 new mode 100755 diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 06da0ede315..658d389f668 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -745,9 +745,15 @@ components: type: boolean example: true peer: - description: Peer Identifier associated with route + description: Peer Identifier associated with route. This property can not be set together with `peer_groups` type: string example: chacbco6lnnbn6cg5s91 + peer_groups: + description: Peers Group Identifier associated with route. This property can not be set together with `peer` + type: array + items: + type: string + example: chacbco6lnnbn6cg5s91 network: description: Network range in CIDR format type: string @@ -773,7 +779,9 @@ components: - description - network_id - enabled - - peer + # Only one property has to be set + #- peer + #- peer_groups - network - metric - masquerade @@ -922,6 +930,14 @@ components: description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. type: string example: google-oauth2|123456789012345678901 + initiator_name: + description: The name of the initiator of the event. + type: string + example: John Doe + initiator_email: + description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. + type: string + example: demo@netbird.io target_id: description: The ID of the target of the event. E.g., an ID of the peer that a user removed. type: string @@ -938,6 +954,8 @@ components: - activity - activity_code - initiator_id + - initiator_name + - initiator_email - target_id - meta responses: @@ -1134,8 +1152,8 @@ paths: '500': "$ref": "#/components/responses/internal_error" delete: - summary: Block a User - description: This method blocks a user from accessing the system, but leaves the IDP user intact. + summary: Delete a User + description: This method removes a user from accessing the system. For this leaves the IDP user intact unless the `--user-delete-from-idp` is passed to management startup. tags: [ Users ] security: - BearerAuth: [ ] diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 402aae63572..fd3eedde30c 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT. +// Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT. package api import ( @@ -164,9 +164,15 @@ type Event struct { // Id Event unique identifier Id string `json:"id"` + // InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. + InitiatorEmail string `json:"initiator_email"` + // InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event. InitiatorId string `json:"initiator_id"` + // InitiatorName The name of the initiator of the event. + InitiatorName string `json:"initiator_name"` + // Meta The metadata of the event Meta map[string]string `json:"meta"` @@ -593,8 +599,11 @@ type Route struct { // NetworkType Network type indicating if it is IPv4 or IPv6 NetworkType string `json:"network_type"` - // Peer Peer Identifier associated with route - Peer string `json:"peer"` + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` } // RouteRequest defines model for RouteRequest. @@ -620,8 +629,11 @@ type RouteRequest struct { // NetworkId Route network identifier, to group HA routes NetworkId string `json:"network_id"` - // Peer Peer Identifier associated with route - Peer string `json:"peer"` + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` } // Rule defines model for Rule. diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 1d1c176e51b..a89c206a368 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -45,14 +45,66 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteError(err, w) return } - events := make([]*api.Event, 0) - for _, e := range accountEvents { - events = append(events, toEventResponse(e)) + events := make([]*api.Event, len(accountEvents)) + for i, e := range accountEvents { + events[i] = toEventResponse(e) + } + + err = h.fillEventsWithUserInfo(events, account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return } util.WriteJSONObject(w, events) } +func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error { + // build email, name maps based on users + userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) + if err != nil { + log.Errorf("failed to get users from account: %s", err) + return err + } + + emails := make(map[string]string) + names := make(map[string]string) + for _, ui := range userInfos { + emails[ui.ID] = ui.Email + names[ui.ID] = ui.Name + } + + var ok bool + for _, event := range events { + // fill initiator + if event.InitiatorEmail == "" { + event.InitiatorEmail, ok = emails[event.InitiatorId] + if !ok { + log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) + } + } + + if event.InitiatorName == "" { + // here to allowed to be empty because in the first release we did not store the name + event.InitiatorName = names[event.InitiatorId] + } + + // fill target meta + email, ok := emails[event.TargetId] + if !ok { + continue + } + event.Meta["email"] = email + + username, ok := names[event.TargetId] + if !ok { + continue + } + event.Meta["username"] = username + } + return nil +} + func toEventResponse(event *activity.Event) *api.Event { meta := make(map[string]string) if event.Meta != nil { @@ -60,13 +112,16 @@ func toEventResponse(event *activity.Event) *api.Event { meta[s] = fmt.Sprintf("%v", a) } } - return &api.Event{ - Id: fmt.Sprint(event.ID), - InitiatorId: event.InitiatorID, - Activity: event.Activity.Message(), - ActivityCode: api.EventActivityCode(event.Activity.StringCode()), - TargetId: event.TargetID, - Timestamp: event.Timestamp, - Meta: meta, + e := &api.Event{ + Id: fmt.Sprint(event.ID), + InitiatorId: event.InitiatorID, + InitiatorName: event.InitiatorName, + InitiatorEmail: event.InitiatorEmail, + Activity: event.Activity.Message(), + ActivityCode: api.EventActivityCode(event.Activity.StringCode()), + TargetId: event.TargetID, + Timestamp: event.Timestamp, + Meta: meta, } + return e } diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index a77e44f454f..4cfad922be5 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E }, }, user, nil }, + GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + return make([]*server.UserInfo, 0), nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 44603059a2f..aad03d50bed 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -53,30 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Issued: server.GroupIssuedAPI, }, nil }, - UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { - var group server.Group - group.ID = groupID - for _, operation := range operations { - switch operation.Type { - case server.UpdateGroupName: - group.Name = operation.Values[0] - case server.UpdateGroupPeers, server.InsertPeersToGroup: - group.Peers = operation.Values - case server.RemovePeersFromGroup: - default: - return nil, fmt.Errorf("no operation") - } - } - return &group, nil - }, - GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { - for _, peer := range TestPeers { - if peer.IP.String() == peerIP { - return peer, nil - } - } - return nil, fmt.Errorf("peer not found") - }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 75fcb4c1c89..100f4b87a7f 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - nsGroupToUpdate := baseExistingNSGroup.Copy() - if nsGroupID != nsGroupToUpdate.ID { - return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) - } - for _, operation := range operations { - switch operation.Type { - case server.UpdateNameServerGroupName: - nsGroupToUpdate.Name = operation.Values[0] - case server.UpdateNameServerGroupDescription: - nsGroupToUpdate.Description = operation.Values[0] - case server.UpdateNameServerGroupNameServers: - var parsedNSList []nbdns.NameServer - for _, nsURL := range operation.Values { - parsed, err := nbdns.ParseNameServerURL(nsURL) - if err != nil { - return nil, err - } - parsedNSList = append(parsedNSList, parsed) - } - nsGroupToUpdate.NameServers = parsedNSList - } - } - return nsGroupToUpdate, nil - }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingNSAccount, testingAccount.Users["test_user"], nil }, diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 100549aadf9..adf4a972102 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -61,7 +61,7 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe } func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { - _, err := h.accountManager.DeletePeer(accountID, peerID, userID) + err := h.accountManager.DeletePeer(accountID, peerID, userID) if err != nil { util.WriteError(err, w) return diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index a6dfa9c74df..348bdbfd688 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -82,7 +82,33 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { return } - newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), req.Peer, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id) + peerId := "" + if req.Peer != nil { + peerId = *req.Peer + } + + peerGroupIds := []string{} + if req.PeerGroups != nil { + peerGroupIds = *req.PeerGroups + } + + if (peerId != "" && len(peerGroupIds) > 0) || (peerId == "" && len(peerGroupIds) == 0) { + util.WriteError(status.Errorf(status.InvalidArgument, "only one peer or peer_groups should be provided"), w) + return + } + + // do not allow non Linux peers + if peer := account.GetPeer(peerId); peer != nil { + if peer.Meta.GoOS != "linux" { + util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) + return + } + } + + newRoute, err := h.accountManager.CreateRoute( + account.Id, newPrefix.String(), peerId, peerGroupIds, + req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, + ) if err != nil { util.WriteError(err, w) return @@ -135,19 +161,49 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } + if req.Peer != nil && req.PeerGroups != nil { + util.WriteError(status.Errorf(status.InvalidArgument, "only peer or peers_group should be provided"), w) + return + } + + if req.Peer == nil && req.PeerGroups == nil { + util.WriteError(status.Errorf(status.InvalidArgument, "either peer or peers_group should be provided"), w) + return + } + + peerID := "" + if req.Peer != nil { + peerID = *req.Peer + } + + // do not allow non Linux peers + if peer := account.GetPeer(peerID); peer != nil { + if peer.Meta.GoOS != "linux" { + util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) + return + } + } + newRoute := &route.Route{ ID: routeID, Network: newPrefix, NetID: req.NetworkId, NetworkType: prefixType, Masquerade: req.Masquerade, - Peer: req.Peer, Metric: req.Metric, Description: req.Description, Enabled: req.Enabled, Groups: req.Groups, } + if req.Peer != nil { + newRoute.Peer = peerID + } + + if req.PeerGroups != nil { + newRoute.PeerGroups = *req.PeerGroups + } + err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute) if err != nil { util.WriteError(err, w) @@ -208,16 +264,21 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { } func toRouteResponse(serverRoute *route.Route) *api.Route { - return &api.Route{ + route := &api.Route{ Id: serverRoute.ID, Description: serverRoute.Description, NetworkId: serverRoute.NetID, Enabled: serverRoute.Enabled, - Peer: serverRoute.Peer, + Peer: &serverRoute.Peer, Network: serverRoute.Network.String(), NetworkType: serverRoute.NetworkType.String(), Masquerade: serverRoute.Masquerade, Metric: serverRoute.Metric, Groups: serverRoute.Groups, } + + if len(serverRoute.PeerGroups) > 0 { + route.PeerGroups = &serverRoute.PeerGroups + } + return route } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index c4270284c77..0bb4587e4f4 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/netip" - "strconv" "testing" "github.com/netbirdio/netbird/management/server/http/api" @@ -24,16 +23,23 @@ import ( ) const ( - existingRouteID = "existingRouteID" - notFoundRouteID = "notFoundRouteID" - existingPeerIP = "100.64.0.100" - existingPeerID = "peer-id" - notFoundPeerID = "nonExistingPeer" - existingPeerKey = "existingPeerKey" - testAccountID = "test_id" - existingGroupID = "testGroup" + existingRouteID = "existingRouteID" + existingRouteID2 = "existingRouteID2" // for peer_groups test + notFoundRouteID = "notFoundRouteID" + existingPeerIP1 = "100.64.0.100" + existingPeerIP2 = "100.64.0.101" + notFoundPeerID = "nonExistingPeer" + existingPeerKey = "existingPeerKey" + nonLinuxExistingPeerKey = "darwinExistingPeerKey" + testAccountID = "test_id" + existingGroupID = "testGroup" + notFoundGroupID = "nonExistingGroup" ) +var emptyString = "" +var existingPeerID = "peer-id" +var nonLinuxExistingPeerID = "darwin-peer-id" + var baseExistingRoute = &route.Route{ ID: existingRouteID, Description: "base route", @@ -52,8 +58,19 @@ var testingAccount = &server.Account{ Peers: map[string]*server.Peer{ existingPeerID: { Key: existingPeerKey, - IP: netip.MustParseAddr(existingPeerIP).AsSlice(), + IP: netip.MustParseAddr(existingPeerIP1).AsSlice(), ID: existingPeerID, + Meta: server.PeerSystemMeta{ + GoOS: "linux", + }, + }, + nonLinuxExistingPeerID: { + Key: nonLinuxExistingPeerID, + IP: netip.MustParseAddr(existingPeerIP2).AsSlice(), + ID: nonLinuxExistingPeerID, + Meta: server.PeerSystemMeta{ + GoOS: "darwin", + }, }, }, Users: map[string]*server.User{ @@ -68,17 +85,26 @@ func initRoutesTestData() *RoutesHandler { if routeID == existingRouteID { return baseExistingRoute, nil } + if routeID == existingRouteID2 { + route := baseExistingRoute.Copy() + route.PeerGroups = []string{existingGroupID} + return route, nil + } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(accountID string, network, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { + CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { + return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) + } networkType, p, _ := route.ParseNetwork(network) return &route.Route{ ID: existingRouteID, NetID: netID, Peer: peerID, + PeerGroups: peerGroups, Network: p, NetworkType: networkType, Description: description, @@ -99,47 +125,6 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { - if peerIP != existingPeerID { - return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP) - } - return &server.Peer{ - Key: existingPeerKey, - IP: netip.MustParseAddr(existingPeerID).AsSlice(), - }, nil - }, - UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) { - routeToUpdate := baseExistingRoute - if routeID != routeToUpdate.ID { - return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) - } - for _, operation := range operations { - switch operation.Type { - case server.UpdateRouteNetwork: - routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0]) - case server.UpdateRouteDescription: - routeToUpdate.Description = operation.Values[0] - case server.UpdateRouteNetworkIdentifier: - routeToUpdate.NetID = operation.Values[0] - case server.UpdateRoutePeer: - routeToUpdate.Peer = operation.Values[0] - if routeToUpdate.Peer == notFoundPeerID { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer) - } - case server.UpdateRouteMetric: - routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0]) - case server.UpdateRouteMasquerade: - routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0]) - case server.UpdateRouteEnabled: - routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0]) - case server.UpdateRouteGroups: - routeToUpdate.Groups = operation.Values - default: - return nil, fmt.Errorf("no operation") - } - } - return routeToUpdate, nil - }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingAccount, testingAccount.Users["test_user"], nil }, @@ -157,6 +142,9 @@ func initRoutesTestData() *RoutesHandler { } func TestRoutesHandlers(t *testing.T) { + baseExistingRouteWithPeerGroups := baseExistingRoute.Copy() + baseExistingRouteWithPeerGroups.PeerGroups = []string{existingGroupID} + tt := []struct { name string expectedStatus int @@ -180,6 +168,14 @@ func TestRoutesHandlers(t *testing.T) { requestPath: "/api/routes/" + notFoundRouteID, expectedStatus: http.StatusNotFound, }, + { + name: "Get Existing Route with Peer Groups", + requestType: http.MethodGet, + requestPath: "/api/routes/" + existingRouteID2, + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: toRouteResponse(baseExistingRouteWithPeerGroups), + }, { name: "Delete Existing Route", requestType: http.MethodDelete, @@ -206,13 +202,21 @@ func TestRoutesHandlers(t *testing.T) { Description: "Post", NetworkId: "awesomeNet", Network: "192.168.0.0/16", - Peer: existingPeerID, + Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, Enabled: false, Groups: []string{existingGroupID}, }, }, + { + name: "POST Non Linux Peer", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", nonLinuxExistingPeerID, existingGroupID)), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, { name: "POST Not Found Peer", requestType: http.MethodPost, @@ -237,6 +241,24 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "POST UnprocessableEntity when both peer and peer_groups are provided", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer\":\"%s\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, + { + name: "POST UnprocessableEntity when no peer and peer_groups are provided", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"groups\":[\"%s\"]}", existingPeerID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, { name: "PUT OK", requestType: http.MethodPut, @@ -249,7 +271,27 @@ func TestRoutesHandlers(t *testing.T) { Description: "Post", NetworkId: "awesomeNet", Network: "192.168.0.0/16", - Peer: existingPeerID, + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + }, + }, + { + name: "PUT OK when peer_groups provided", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: "192.168.0.0/16", + Peer: &emptyString, + PeerGroups: &[]string{existingGroupID}, NetworkType: route.IPv4NetworkString, Masquerade: false, Enabled: false, @@ -272,6 +314,14 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "PUT Non Linux Peer", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", nonLinuxExistingPeerID, existingGroupID)), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, { name: "PUT Invalid Network Identifier", requestType: http.MethodPut, @@ -288,6 +338,24 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "PUT UnprocessableEntity when both peer and peer_groups are provided", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer\":\"%s\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, + { + name: "PUT UnprocessableEntity when no peer and peer_groups are provided", + requestType: http.MethodPut, + requestPath: "/api/routes/" + existingRouteID, + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"groups\":[\"%s\"]}", existingPeerID))), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: false, + }, } p := initRoutesTestData() diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 44f4919f50d..2776273101c 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -77,6 +77,7 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { // WriteError converts an error to an JSON error response. // If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise func WriteError(err error, w http.ResponseWriter) { + log.Errorf("got a handler error: %s", err.Error()) errStatus, ok := status.FromError(err) httpStatus := http.StatusInternalServerError msg := "internal server error" diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 64ec88e9f82..d3802d8ad96 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) { return string(str), nil } -func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { +func (am *Auth0Manager) createRequest( + method string, endpoint string, body io.Reader, +) (*http.Request, error) { jwtToken, err := am.credentials.Authenticate() if err != nil { return nil, err @@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (* reqURL := am.authIssuer + endpoint - payload := strings.NewReader(payloadStr) - - req, err := http.NewRequest("POST", reqURL, payload) + req, err := http.NewRequest(method, reqURL, body) if err != nil { return nil, err } req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) - req.Header.Add("content-type", "application/json") return req, nil +} +func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { + req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr)) + if err != nil { + return nil, err + } + req.Header.Add("content-type", "application/json") + + return req, nil } // GetAllAccounts gets all registered accounts with corresponding user data. @@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { return nil } +// DeleteUser from Auth0 +func (am *Auth0Manager) DeleteUser(userID string) error { + req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) + if err != nil { + return err + } + + resp, err := am.httpClient.Do(req) + if err != nil { + log.Debugf("execute delete request: %v", err) + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + + defer func() { + err = resp.Body.Close() + if err != nil { + log.Errorf("close delete request body: %v", err) + } + }() + if resp.StatusCode != 204 { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + return nil +} + // checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // If the status is "completed", then return the downloadLink func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 0898f1c9422..ca995b2996e 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -12,9 +12,10 @@ import ( "time" "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // AuthentikManager authentik manager client instance. @@ -209,47 +210,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { - ctx, err := am.authenticationContext() - if err != nil { - return err - } - - userPk, err := strconv.ParseInt(userID, 10, 32) - if err != nil { - return err - } - - var pendingInvite bool - if appMetadata.WTPendingInvite != nil { - pendingInvite = *appMetadata.WTPendingInvite - } - - patchedUserReq := api.PatchedUserRequest{ - Attributes: map[string]interface{}{ - wtAccountID: appMetadata.WTAccountID, - wtPendingInvite: pendingInvite, - }, - } - _, resp, err := am.apiClient.CoreApi.CoreUsersPartialUpdate(ctx, int32(userPk)). - PatchedUserRequest(patchedUserReq). - Execute() - if err != nil { - return err - } - defer resp.Body.Close() - - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() - } - - if resp.StatusCode != http.StatusOK { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - return fmt.Errorf("unable to update user %s, statusCode %d", userID, resp.StatusCode) - } - +func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { return nil } @@ -282,7 +243,10 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) } - return parseAuthentikUser(*user) + userData := parseAuthentikUser(*user) + userData.AppMetadata = appMetadata + + return userData, nil } // GetAccount returns all the users for a given profile. @@ -292,8 +256,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { return nil, err } - accountFilter := fmt.Sprintf("{%q:%q}", wtAccountID, accountID) - userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Attributes(accountFilter).Execute() + userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute() if err != nil { return nil, err } @@ -312,10 +275,9 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { users := make([]*UserData, 0) for _, user := range userList.Results { - userData, err := parseAuthentikUser(user) - if err != nil { - return nil, err - } + userData := parseAuthentikUser(user) + userData.AppMetadata.WTAccountID = accountID + users = append(users, userData) } @@ -349,65 +311,16 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { indexedUsers := make(map[string][]*UserData) for _, user := range userList.Results { - userData, err := parseAuthentikUser(user) - if err != nil { - return nil, err - } - - accountID := userData.AppMetadata.WTAccountID - if accountID != "" { - if _, ok := indexedUsers[accountID]; !ok { - indexedUsers[accountID] = make([]*UserData, 0) - } - indexedUsers[accountID] = append(indexedUsers[accountID], userData) - } + userData := parseAuthentikUser(user) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } return indexedUsers, nil } // CreateUser creates a new user in authentik Idp and sends an invitation. -func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - ctx, err := am.authenticationContext() - if err != nil { - return nil, err - } - - groupID, err := am.getUserGroupByName("netbird") - if err != nil { - return nil, err - } - - defaultBoolValue := true - createUserRequest := api.UserRequest{ - Email: &email, - Name: name, - IsActive: &defaultBoolValue, - Groups: []string{groupID}, - Username: email, - Attributes: map[string]interface{}{ - wtAccountID: accountID, - wtPendingInvite: &defaultBoolValue, - }, - } - user, resp, err := am.apiClient.CoreApi.CoreUsersCreate(ctx).UserRequest(createUserRequest).Execute() - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountCreateUser() - } - - if resp.StatusCode != http.StatusCreated { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) - } - - return parseAuthentikUser(*user) +func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. @@ -437,11 +350,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { users := make([]*UserData, 0) for _, user := range userList.Results { - userData, err := parseAuthentikUser(user) - if err != nil { - return nil, err - } - users = append(users, userData) + users = append(users, parseAuthentikUser(user)) } return users, nil @@ -453,79 +362,57 @@ func (am *AuthentikManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } -func (am *AuthentikManager) authenticationContext() (context.Context, error) { - jwtToken, err := am.credentials.Authenticate() - if err != nil { - return nil, err - } - - value := map[string]api.APIKey{ - "authentik": { - Key: jwtToken.AccessToken, - Prefix: jwtToken.TokenType, - }, - } - return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil -} - -// getUserGroupByName retrieves the user group for assigning new users. -// If the group is not found, a new group with the specified name will be created. -func (am *AuthentikManager) getUserGroupByName(name string) (string, error) { +// DeleteUser from Authentik +func (am *AuthentikManager) DeleteUser(userID string) error { ctx, err := am.authenticationContext() if err != nil { - return "", err + return err } - groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute() + userPk, err := strconv.ParseInt(userID, 10, 32) if err != nil { - return "", err + return err } - defer resp.Body.Close() - if groupList != nil { - if len(groupList.Results) > 0 { - return groupList.Results[0].Pk, nil - } + resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute() + if err != nil { + return err } + defer resp.Body.Close() // nolint - createGroupRequest := api.GroupRequest{Name: name} - group, resp, err := am.apiClient.CoreApi.CoreGroupsCreate(ctx).GroupRequest(createGroupRequest).Execute() - if err != nil { - return "", err + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountDeleteUser() } - defer resp.Body.Close() - if resp.StatusCode != http.StatusCreated { - return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode) + if resp.StatusCode != http.StatusNoContent { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode) } - return group.Pk, nil + return nil } -func parseAuthentikUser(user api.User) (*UserData, error) { - var attributes struct { - AccountID string `json:"wt_account_id"` - PendingInvite bool `json:"wt_pending_invite"` - } - - helper := JsonParser{} - buf, err := helper.Marshal(user.Attributes) +func (am *AuthentikManager) authenticationContext() (context.Context, error) { + jwtToken, err := am.credentials.Authenticate() if err != nil { return nil, err } - err = helper.Unmarshal(buf, &attributes) - if err != nil { - return nil, err + value := map[string]api.APIKey{ + "authentik": { + Key: jwtToken.AccessToken, + Prefix: jwtToken.TokenType, + }, } + return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil +} +func parseAuthentikUser(user api.User) *UserData { return &UserData{ Email: *user.Email, Name: user.Name, ID: strconv.FormatInt(int64(user.Pk), 10), - AppMetadata: AppMetadata{ - WTAccountID: attributes.AccountID, - WTPendingInvite: &attributes.PendingInvite, - }, - }, nil + } } diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 7cff7d8fc2e..e4224c26d96 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -1,7 +1,6 @@ package idp import ( - "encoding/json" "fmt" "io" "net/http" @@ -11,19 +10,13 @@ import ( "time" "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" -) - -const ( - // azure extension properties template - wtAccountIDTpl = "extension_%s_wt_account_id" - wtPendingInviteTpl = "extension_%s_wt_pending_invite" - profileFields = "id,displayName,mail,userPrincipalName" - extensionFields = "id,name,targetObjects" + "github.com/netbirdio/netbird/management/server/telemetry" ) +const profileFields = "id,displayName,mail,userPrincipalName" + // AzureManager azure manager client instance. type AzureManager struct { ClientID string @@ -58,21 +51,6 @@ type AzureCredentials struct { // azureProfile represents an azure user profile. type azureProfile map[string]any -// passwordProfile represent authentication method for, -// newly created user profile. -type passwordProfile struct { - ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"` - Password string `json:"password"` -} - -// azureExtension represent custom attribute, -// that can be added to user objects in Azure Active Directory (AD). -type azureExtension struct { - Name string `json:"name"` - DataType string `json:"dataType"` - TargetObjects []string `json:"targetObjects"` -} - // NewAzureManager creates a new instance of the AzureManager. func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() @@ -115,7 +93,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) appMetrics: appMetrics, } - manager := &AzureManager{ + return &AzureManager{ ObjectID: config.ObjectID, ClientID: config.ClientID, GraphAPIEndpoint: config.GraphAPIEndpoint, @@ -123,14 +101,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) credentials: credentials, helper: helper, appMetrics: appMetrics, - } - - err := manager.configureAppMetadata() - if err != nil { - return nil, err - } - - return manager, nil + }, nil } // jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure. @@ -236,44 +207,14 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in azure AD Idp. -func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID) - if err != nil { - return nil, err - } - - body, err := am.post("users", payload) - if err != nil { - return nil, err - } - - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountCreateUser() - } - - var profile azureProfile - err = am.helper.Unmarshal(body, &profile) - if err != nil { - return nil, err - } - - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - profile[wtAccountIDField] = accountID - - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) - profile[wtPendingInviteField] = true - - return profile.userData(am.ClientID), nil +func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserDataByID requests user data from keycloak via ID. func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) - selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",") - q := url.Values{} - q.Add("$select", selectFields) + q.Add("$select", profileFields) body, err := am.get("users/"+userID, q) if err != nil { @@ -290,18 +231,17 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) return nil, err } - return profile.userData(am.ClientID), nil + userData := profile.userData() + userData.AppMetadata = appMetadata + + return userData, nil } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) - selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",") - q := url.Values{} - q.Add("$select", selectFields) + q.Add("$select", profileFields) body, err := am.get("users/"+email, q) if err != nil { @@ -319,20 +259,15 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { } users := make([]*UserData, 0) - users = append(users, profile.userData(am.ClientID)) + users = append(users, profile.userData()) return users, nil } // GetAccount returns all the users for a given profile. func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) - selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",") - q := url.Values{} - q.Add("$select", selectFields) - q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID)) + q.Add("$select", profileFields) body, err := am.get("users", q) if err != nil { @@ -351,7 +286,10 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { users := make([]*UserData, 0) for _, profile := range profiles.Value { - users = append(users, profile.userData(am.ClientID)) + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountID + + users = append(users, userData) } return users, nil @@ -360,12 +298,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) - selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",") - q := url.Values{} - q.Add("$select", selectFields) + q.Add("$select", profileFields) body, err := am.get("users", q) if err != nil { @@ -384,49 +318,40 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { indexedUsers := make(map[string][]*UserData) for _, profile := range profiles.Value { - userData := profile.userData(am.ClientID) - - accountID := userData.AppMetadata.WTAccountID - if accountID != "" { - if _, ok := indexedUsers[accountID]; !ok { - indexedUsers[accountID] = make([]*UserData, 0) - } - indexedUsers[accountID] = append(indexedUsers[accountID], userData) - } - + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } return indexedUsers, nil } // UpdateUserAppMetadata updates user app metadata based on userID. -func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { - jwtToken, err := am.credentials.Authenticate() - if err != nil { - return err - } +func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { + return nil +} - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) +// InviteUserByID resend invitations to users who haven't activated, +// their accounts prior to the expiration period. +func (am *AzureManager) InviteUserByID(_ string) error { + return fmt.Errorf("method InviteUserByID not implemented") +} - data, err := am.helper.Marshal(map[string]any{ - wtAccountIDField: appMetadata.WTAccountID, - wtPendingInviteField: appMetadata.WTPendingInvite, - }) +// DeleteUser from Azure. +func (am *AzureManager) DeleteUser(userID string) error { + jwtToken, err := am.credentials.Authenticate() if err != nil { return err } - payload := strings.NewReader(string(data)) - reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID) - req, err := http.NewRequest(http.MethodPatch, reqURL, payload) + reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID)) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) if err != nil { return err } req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("updating idp metadata for user %s", userID) + log.Debugf("delete idp user %s", userID) resp, err := am.httpClient.Do(req) if err != nil { @@ -438,92 +363,11 @@ func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMeta defer resp.Body.Close() if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() + am.appMetrics.IDPMetrics().CountDeleteUser() } if resp.StatusCode != http.StatusNoContent { - return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode) - } - - return nil -} - -// InviteUserByID resend invitations to users who haven't activated, -// their accounts prior to the expiration period. -func (am *AzureManager) InviteUserByID(_ string) error { - return fmt.Errorf("method InviteUserByID not implemented") -} - -func (am *AzureManager) getUserExtensions() ([]azureExtension, error) { - q := url.Values{} - q.Add("$select", extensionFields) - - resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID) - body, err := am.get(resource, q) - if err != nil { - return nil, err - } - - var extensions struct{ Value []azureExtension } - err = am.helper.Unmarshal(body, &extensions) - if err != nil { - return nil, err - } - - return extensions.Value, nil -} - -func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) { - extension := azureExtension{ - Name: name, - DataType: "string", - TargetObjects: []string{"User"}, - } - - payload, err := am.helper.Marshal(extension) - if err != nil { - return nil, err - } - - resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID) - body, err := am.post(resource, string(payload)) - if err != nil { - return nil, err - } - - var userExtension azureExtension - err = am.helper.Unmarshal(body, &userExtension) - if err != nil { - return nil, err - } - - return &userExtension, nil -} - -// configureAppMetadata sets up app metadata extensions if they do not exists. -func (am *AzureManager) configureAppMetadata() error { - wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID) - - extensions, err := am.getUserExtensions() - if err != nil { - return err - } - - // If the wt_account_id extension does not already exist, create it. - if !hasExtension(extensions, wtAccountIDField) { - _, err = am.createUserExtension(wtAccountID) - if err != nil { - return err - } - } - - // If the wt_pending_invite extension does not already exist, create it. - if !hasExtension(extensions, wtPendingInviteField) { - _, err = am.createUserExtension(wtPendingInvite) - if err != nil { - return err - } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) } return nil @@ -565,44 +409,8 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { return io.ReadAll(resp.Body) } -// post perform Post requests. -func (am *AzureManager) post(resource string, body string) ([]byte, error) { - jwtToken, err := am.credentials.Authenticate() - if err != nil { - return nil, err - } - - reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource) - req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) - req.Header.Add("content-type", "application/json") - - resp, err := am.httpClient.Do(req) - if err != nil { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestError() - } - - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusCreated { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - - return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode) - } - - return io.ReadAll(resp.Body) -} - // userData construct user data from keycloak profile. -func (ap azureProfile) userData(clientID string) *UserData { +func (ap azureProfile) userData() *UserData { id, ok := ap["id"].(string) if !ok { id = "" @@ -618,66 +426,9 @@ func (ap azureProfile) userData(clientID string) *UserData { name = "" } - accountIDField := extensionName(wtAccountIDTpl, clientID) - accountID, ok := ap[accountIDField].(string) - if !ok { - accountID = "" - } - - pendingInviteField := extensionName(wtPendingInviteTpl, clientID) - pendingInvite, ok := ap[pendingInviteField].(bool) - if !ok { - pendingInvite = false - } - return &UserData{ Email: email, Name: name, ID: id, - AppMetadata: AppMetadata{ - WTAccountID: accountID, - WTPendingInvite: &pendingInvite, - }, - } -} - -func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) { - wtAccountIDField := extensionName(wtAccountIDTpl, clientID) - wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID) - - req := &azureProfile{ - "accountEnabled": true, - "displayName": name, - "mailNickName": strings.Join(strings.Split(name, " "), ""), - "userPrincipalName": email, - "passwordProfile": passwordProfile{ - ForceChangePasswordNextSignIn: true, - Password: GeneratePassword(8, 1, 1, 1), - }, - wtAccountIDField: accountID, - wtPendingInviteField: true, - } - - str, err := json.Marshal(req) - if err != nil { - return "", err - } - - return string(str), nil -} - -func extensionName(extensionTpl, clientID string) string { - clientID = strings.ReplaceAll(clientID, "-", "") - return fmt.Sprintf(extensionTpl, clientID) -} - -// hasExtension checks whether a given extension by name, -// exists in an list of extensions. -func hasExtension(extensions []azureExtension, name string) bool { - for _, ext := range extensions { - if ext.Name == name { - return true - } } - return false } diff --git a/management/server/idp/azure_test.go b/management/server/idp/azure_test.go index 9d845ffbeef..b4dc96b23cd 100644 --- a/management/server/idp/azure_test.go +++ b/management/server/idp/azure_test.go @@ -8,15 +8,6 @@ import ( "github.com/stretchr/testify/assert" ) -type mockAzureCredentials struct { - jwtToken JWTToken - err error -} - -func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) { - return mc.jwtToken, mc.err -} - func TestAzureJwtStillValid(t *testing.T) { type jwtStillValidTest struct { name string @@ -124,206 +115,63 @@ func TestAzureAuthenticate(t *testing.T) { } } -func TestAzureUpdateUserAppMetadata(t *testing.T) { - type updateUserAppMetadataTest struct { - name string - inputReqBody string - expectedReqBody string - appMetadata AppMetadata - statusCode int - helper ManagerHelper - managerCreds ManagerCredentials - assertErrFunc assert.ErrorAssertionFunc - assertErrFuncMessage string - } - - appMetadata := AppMetadata{WTAccountID: "ok"} - - updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ - name: "Bad Authentication", - expectedReqBody: "", - appMetadata: appMetadata, - statusCode: 400, - helper: JsonParser{}, - managerCreds: &mockAzureCredentials{ - jwtToken: JWTToken{}, - err: fmt.Errorf("error"), - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ - name: "Bad Status Code", - expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID), - appMetadata: appMetadata, - statusCode: 400, - helper: JsonParser{}, - managerCreds: &mockAzureCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{ - name: "Bad Response Parsing", - statusCode: 400, - helper: &mockJsonParser{marshalErrorString: "error"}, - managerCreds: &mockAzureCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ - name: "Good request", - expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID), - appMetadata: appMetadata, - statusCode: 204, - helper: JsonParser{}, - managerCreds: &mockAzureCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.NoError, - assertErrFuncMessage: "shouldn't return error", - } - - invite := true - updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{ - name: "Update Pending Invite", - expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID), - appMetadata: AppMetadata{ - WTAccountID: "ok", - WTPendingInvite: &invite, - }, - statusCode: 204, - helper: JsonParser{}, - managerCreds: &mockAzureCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.NoError, - assertErrFuncMessage: "shouldn't return error", - } - - for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, - updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} { - t.Run(testCase.name, func(t *testing.T) { - reqClient := mockHTTPClient{ - resBody: testCase.inputReqBody, - code: testCase.statusCode, - } - - manager := &AzureManager{ - httpClient: &reqClient, - credentials: testCase.managerCreds, - helper: testCase.helper, - } - - err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) - testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) - - assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match") - }) - } -} - func TestAzureProfile(t *testing.T) { type azureProfileTest struct { name string - clientID string invite bool inputProfile azureProfile expectedUserData UserData } azureProfileTestCase1 := azureProfileTest{ - name: "Good Request", - clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", - invite: false, + name: "Good Request", + invite: false, inputProfile: azureProfile{ "id": "test1", "displayName": "John Doe", "userPrincipalName": "test1@test.com", - "extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1", - "extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false, }, expectedUserData: UserData{ Email: "test1@test.com", Name: "John Doe", ID: "test1", - AppMetadata: AppMetadata{ - WTAccountID: "1", - }, }, } azureProfileTestCase2 := azureProfileTest{ - name: "Missing User ID", - clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", - invite: true, + name: "Missing User ID", + invite: true, inputProfile: azureProfile{ "displayName": "John Doe", "userPrincipalName": "test2@test.com", - "extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1", - "extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true, }, expectedUserData: UserData{ Email: "test2@test.com", Name: "John Doe", - AppMetadata: AppMetadata{ - WTAccountID: "1", - }, }, } azureProfileTestCase3 := azureProfileTest{ - name: "Missing User Name", - clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", - invite: false, + name: "Missing User Name", + invite: false, inputProfile: azureProfile{ "id": "test3", "userPrincipalName": "test3@test.com", - "extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1", - "extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false, }, expectedUserData: UserData{ ID: "test3", Email: "test3@test.com", - AppMetadata: AppMetadata{ - WTAccountID: "1", - }, - }, - } - - azureProfileTestCase4 := azureProfileTest{ - name: "Missing Extension Fields", - clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", - invite: false, - inputProfile: azureProfile{ - "id": "test4", - "displayName": "John Doe", - "userPrincipalName": "test4@test.com", - }, - expectedUserData: UserData{ - ID: "test4", - Name: "John Doe", - Email: "test4@test.com", - AppMetadata: AppMetadata{}, }, } - for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} { + for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} { t.Run(testCase.name, func(t *testing.T) { testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite - userData := testCase.inputProfile.userData(testCase.clientID) + userData := testCase.inputProfile.userData() assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match") assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match") assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match") - assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match") - assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match") }) } } diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 2e65497dc41..ed2de9a4225 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -5,15 +5,14 @@ import ( "encoding/base64" "fmt" "net/http" - "strings" "time" - "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" - "google.golang.org/api/googleapi" "google.golang.org/api/option" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // GoogleWorkspaceManager Google Workspace manager client instance. @@ -73,17 +72,13 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te } service, err := admin.NewService(context.Background(), - option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope), + option.WithScopes(admin.AdminDirectoryUserReadonlyScope), option.WithCredentials(adminCredentials), ) if err != nil { return nil, err } - if err = configureAppMetadataSchema(service, config.CustomerID); err != nil { - return nil, err - } - return &GoogleWorkspaceManager{ usersService: service.Users, CustomerID: config.CustomerID, @@ -95,27 +90,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { - metadata, err := gm.helper.Marshal(appMetadata) - if err != nil { - return err - } - - user := &admin.User{ - CustomSchemas: map[string]googleapi.RawMessage{ - "app_metadata": metadata, - }, - } - - _, err = gm.usersService.Update(userID, user).Do() - if err != nil { - return err - } - - if gm.appMetrics != nil { - gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() - } - +func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { return nil } @@ -130,23 +105,23 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App gm.appMetrics.IDPMetrics().CountGetUserDataByID() } - return parseGoogleWorkspaceUser(user) + userData := parseGoogleWorkspaceUser(user) + userData.AppMetadata = appMetadata + + return userData, nil } // GetAccount returns all the users for a given profile. func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { - query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID) - usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do() + usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do() if err != nil { return nil, err } usersData := make([]*UserData, 0) for _, user := range usersList.Users { - userData, err := parseGoogleWorkspaceUser(user) - if err != nil { - return nil, err - } + userData := parseGoogleWorkspaceUser(user) + userData.AppMetadata.WTAccountID = accountID usersData = append(usersData, userData) } @@ -168,61 +143,16 @@ func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, erro indexedUsers := make(map[string][]*UserData) for _, user := range usersList.Users { - userData, err := parseGoogleWorkspaceUser(user) - if err != nil { - return nil, err - } - - accountID := userData.AppMetadata.WTAccountID - if accountID != "" { - if _, ok := indexedUsers[accountID]; !ok { - indexedUsers[accountID] = make([]*UserData, 0) - } - indexedUsers[accountID] = append(indexedUsers[accountID], userData) - } + userData := parseGoogleWorkspaceUser(user) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } return indexedUsers, nil } // CreateUser creates a new user in Google Workspace and sends an invitation. -func (gm *GoogleWorkspaceManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - invite := true - metadata := AppMetadata{ - WTAccountID: accountID, - WTPendingInvite: &invite, - } - - username := &admin.UserName{} - fields := strings.Fields(name) - if n := len(fields); n > 0 { - username.GivenName = strings.Join(fields[:n-1], " ") - username.FamilyName = fields[n-1] - } - - payload, err := gm.helper.Marshal(metadata) - if err != nil { - return nil, err - } - - user := &admin.User{ - Name: username, - PrimaryEmail: email, - CustomSchemas: map[string]googleapi.RawMessage{ - "app_metadata": payload, - }, - Password: GeneratePassword(8, 1, 1, 1), - } - user, err = gm.usersService.Insert(user).Do() - if err != nil { - return nil, err - } - - if gm.appMetrics != nil { - gm.appMetrics.IDPMetrics().CountCreateUser() - } - - return parseGoogleWorkspaceUser(user) +func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. @@ -237,13 +167,8 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err gm.appMetrics.IDPMetrics().CountGetUserByEmail() } - userData, err := parseGoogleWorkspaceUser(user) - if err != nil { - return nil, err - } - users := make([]*UserData, 0) - users = append(users, userData) + users = append(users, parseGoogleWorkspaceUser(user)) return users, nil } @@ -254,6 +179,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } +// DeleteUser from GoogleWorkspace. +func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { + if err := gm.usersService.Delete(userID).Do(); err != nil { + return err + } + + if gm.appMetrics != nil { + gm.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + // getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // If that fails, it falls back to using the default Google credentials path. @@ -268,8 +206,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) creds, err := google.CredentialsFromJSON( context.Background(), decodeKey, - admin.AdminDirectoryUserschemaScope, - admin.AdminDirectoryUserScope, + admin.AdminDirectoryUserReadonlyScope, ) if err == nil { // No need to fallback to the default Google credentials path @@ -281,8 +218,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) creds, err = google.FindDefaultCredentials( context.Background(), - admin.AdminDirectoryUserschemaScope, - admin.AdminDirectoryUserScope, + admin.AdminDirectoryUserReadonlyScope, ) if err != nil { return nil, err @@ -291,62 +227,11 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) return creds, nil } -// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace. -func configureAppMetadataSchema(service *admin.Service, customerID string) error { - schemaList, err := service.Schemas.List(customerID).Do() - if err != nil { - return err - } - - // checks if app_metadata schema is already created - for _, schema := range schemaList.Schemas { - if schema.SchemaName == "app_metadata" { - return nil - } - } - - // create new app_metadata schema - appMetadataSchema := &admin.Schema{ - SchemaName: "app_metadata", - Fields: []*admin.SchemaFieldSpec{ - { - FieldName: "wt_account_id", - FieldType: "STRING", - MultiValued: false, - }, - { - FieldName: "wt_pending_invite", - FieldType: "BOOL", - MultiValued: false, - }, - }, - } - _, err = service.Schemas.Insert(customerID, appMetadataSchema).Do() - if err != nil { - return err - } - - return nil -} - // parseGoogleWorkspaceUser parse google user to UserData. -func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) { - var appMetadata AppMetadata - - // Get app metadata from custom schemas - if user.CustomSchemas != nil { - rawMessage := user.CustomSchemas["app_metadata"] - helper := JsonParser{} - - if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil { - return nil, err - } - } - +func parseGoogleWorkspaceUser(user *admin.User) *UserData { return &UserData{ - ID: user.Id, - Email: user.PrimaryEmail, - Name: user.Name.FullName, - AppMetadata: appMetadata, - }, nil + ID: user.Id, + Email: user.PrimaryEmail, + Name: user.Name.FullName, + } } diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 3c1f4c327a4..7adb76f4044 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -9,6 +9,11 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) +const ( + // UnsetAccountID is a special key to map users without an account ID + UnsetAccountID = "unset" +) + // Manager idp manager interface type Manager interface { UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error @@ -18,6 +23,7 @@ type Manager interface { CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) GetUserByEmail(email string) ([]*UserData, error) InviteUserByID(userID string) error + DeleteUser(userID string) error } // ClientConfig defines common client configuration for all IdP manager @@ -37,10 +43,10 @@ type Config struct { ManagerType string ClientConfig *ClientConfig ExtraConfig ExtraConfig - Auth0ClientCredentials Auth0ClientConfig - AzureClientCredentials AzureClientConfig - KeycloakClientCredentials KeycloakClientConfig - ZitadelClientCredentials ZitadelClientConfig + Auth0ClientCredentials *Auth0ClientConfig + AzureClientCredentials *AzureClientConfig + KeycloakClientCredentials *KeycloakClientConfig + ZitadelClientCredentials *ZitadelClientConfig } // ManagerCredentials interface that authenticates using the credential of each type of idp @@ -96,7 +102,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) case "auth0": auth0ClientConfig := config.Auth0ClientCredentials if config.ClientConfig != nil { - auth0ClientConfig = Auth0ClientConfig{ + auth0ClientConfig = &Auth0ClientConfig{ Audience: config.ExtraConfig["Audience"], AuthIssuer: config.ClientConfig.Issuer, ClientID: config.ClientConfig.ClientID, @@ -105,11 +111,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewAuth0Manager(auth0ClientConfig, appMetrics) + return NewAuth0Manager(*auth0ClientConfig, appMetrics) case "azure": azureClientConfig := config.AzureClientCredentials if config.ClientConfig != nil { - azureClientConfig = AzureClientConfig{ + azureClientConfig = &AzureClientConfig{ ClientID: config.ClientConfig.ClientID, ClientSecret: config.ClientConfig.ClientSecret, GrantType: config.ClientConfig.GrantType, @@ -119,11 +125,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewAzureManager(azureClientConfig, appMetrics) + return NewAzureManager(*azureClientConfig, appMetrics) case "keycloak": keycloakClientConfig := config.KeycloakClientCredentials if config.ClientConfig != nil { - keycloakClientConfig = KeycloakClientConfig{ + keycloakClientConfig = &KeycloakClientConfig{ ClientID: config.ClientConfig.ClientID, ClientSecret: config.ClientConfig.ClientSecret, GrantType: config.ClientConfig.GrantType, @@ -132,11 +138,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewKeycloakManager(keycloakClientConfig, appMetrics) + return NewKeycloakManager(*keycloakClientConfig, appMetrics) case "zitadel": zitadelClientConfig := config.ZitadelClientCredentials if config.ClientConfig != nil { - zitadelClientConfig = ZitadelClientConfig{ + zitadelClientConfig = &ZitadelClientConfig{ ClientID: config.ClientConfig.ClientID, ClientSecret: config.ClientConfig.ClientSecret, GrantType: config.ClientConfig.GrantType, @@ -145,7 +151,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewZitadelManager(zitadelClientConfig, appMetrics) + return NewZitadelManager(*zitadelClientConfig, appMetrics) case "authentik": authentikConfig := AuthentikClientConfig{ Issuer: config.ClientConfig.Issuer, @@ -170,7 +176,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) CustomerID: config.ExtraConfig["CustomerId"], } return NewGoogleWorkspaceManager(googleClientConfig, appMetrics) - + case "jumpcloud": + jumpcloudConfig := JumpCloudClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + } + return NewJumpCloudManager(jumpcloudConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/jumpcloud.go b/management/server/idp/jumpcloud.go new file mode 100644 index 00000000000..0115b404982 --- /dev/null +++ b/management/server/idp/jumpcloud.go @@ -0,0 +1,257 @@ +package idp + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + v1 "github.com/TheJumpCloud/jcapi-go/v1" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +const ( + contentType = "application/json" + accept = "application/json" +) + +// JumpCloudManager JumpCloud manager client instance. +type JumpCloudManager struct { + client *v1.APIClient + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +// JumpCloudClientConfig JumpCloud manager client configurations. +type JumpCloudClientConfig struct { + APIToken string +} + +// JumpCloudCredentials JumpCloud authentication information. +type JumpCloudCredentials struct { + clientConfig JumpCloudClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +// NewJumpCloudManager creates a new instance of the JumpCloudManager. +func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.APIToken == "" { + return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing") + } + + client := v1.NewAPIClient(v1.NewConfiguration()) + credentials := &JumpCloudCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &JumpCloudManager{ + client: client, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +// Authenticate retrieves access token to use the JumpCloud user API. +func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) { + return JWTToken{}, nil +} + +func (jm *JumpCloudManager) authenticationContext() context.Context { + return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{ + Key: jm.apiToken, + }) +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { + return nil +} + +// GetUserDataByID requests user data from JumpCloud via ID. +func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { + authCtx := jm.authenticationContext() + user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) + } + + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + userData := parseJumpCloudUser(user) + userData.AppMetadata = appMetadata + + return userData, nil +} + +// GetAccount returns all the users for a given profile. +func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) { + authCtx := jm.authenticationContext() + userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode) + } + + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, user := range userList.Results { + userData := parseJumpCloudUser(user) + userData.AppMetadata.WTAccountID = accountID + + users = append(users, userData) + } + + return users, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// It returns a list of users indexed by accountID. +func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) { + authCtx := jm.authenticationContext() + userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) + } + + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, user := range userList.Results { + userData := parseJumpCloudUser(user) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +// CreateUser creates a new user in JumpCloud Idp and sends an invitation. +func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") +} + +// GetUserByEmail searches users with a given email. +// If no users have been found, this function returns an empty list. +func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) { + searchFilter := map[string]interface{}{ + "searchFilter": map[string]interface{}{ + "filter": []string{email}, + "fields": []string{"email"}, + }, + } + + authCtx := jm.authenticationContext() + userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode) + } + + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + usersData := make([]*UserData, 0) + for _, user := range userList.Results { + usersData = append(usersData, parseJumpCloudUser(user)) + } + + return usersData, nil +} + +// InviteUserByID resend invitations to users who haven't activated, +// their accounts prior to the expiration period. +func (jm *JumpCloudManager) InviteUserByID(_ string) error { + return fmt.Errorf("method InviteUserByID not implemented") +} + +// DeleteUser from jumpCloud directory +func (jm *JumpCloudManager) DeleteUser(userID string) error { + authCtx := jm.authenticationContext() + _, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) + } + + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData. +func parseJumpCloudUser(user v1.Systemuserreturn) *UserData { + names := []string{user.Firstname, user.Middlename, user.Lastname} + return &UserData{ + Email: user.Email, + Name: strings.Join(names, " "), + ID: user.Id, + } +} diff --git a/management/server/idp/jumpcloud_test.go b/management/server/idp/jumpcloud_test.go new file mode 100644 index 00000000000..1bfdcefcc70 --- /dev/null +++ b/management/server/idp/jumpcloud_test.go @@ -0,0 +1,46 @@ +package idp + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func TestNewJumpCloudManager(t *testing.T) { + type test struct { + name string + inputConfig JumpCloudClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := JumpCloudClientConfig{ + APIToken: "test123", + } + + testCase1 := test{ + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + testCase2Config := defaultTestConfig + testCase2Config.APIToken = "" + + testCase2 := test{ + name: "Missing APIToken Configuration", + inputConfig: testCase2Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + } + + for _, testCase := range []test{testCase1, testCase2} { + t.Run(testCase.name, func(t *testing.T) { + _, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{}) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + }) + } +} diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index 12ed8738955..3a6f80d0359 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -1,12 +1,10 @@ package idp import ( - "encoding/json" "fmt" "io" "net/http" "net/url" - "path" "strconv" "strings" "sync" @@ -18,11 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const ( - wtAccountID = "wt_account_id" - wtPendingInvite = "wt_pending_invite" -) - // KeycloakManager keycloak manager client instance. type KeycloakManager struct { adminEndpoint string @@ -51,28 +44,10 @@ type KeycloakCredentials struct { appMetrics telemetry.AppMetrics } -// keycloakUserCredential describe the authentication method for, -// newly created user profile. -type keycloakUserCredential struct { - Type string `json:"type"` - Value string `json:"value"` - Temporary bool `json:"temporary"` -} - // keycloakUserAttributes holds additional user data fields. type keycloakUserAttributes map[string][]string -// createUserRequest is a user create request. -type keycloakCreateUserRequest struct { - Email string `json:"email"` - Username string `json:"username"` - Enabled bool `json:"enabled"` - EmailVerified bool `json:"emailVerified"` - Credentials []keycloakUserCredential `json:"credentials"` - Attributes keycloakUserAttributes `json:"attributes"` -} - -// keycloakProfile represents an keycloak user profile response. +// keycloakProfile represents a keycloak user profile response. type keycloakProfile struct { ID string `json:"id"` CreatedTimestamp int64 `json:"createdTimestamp"` @@ -230,62 +205,8 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in keycloak Idp and sends an invite. -func (km *KeycloakManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - jwtToken, err := km.credentials.Authenticate() - if err != nil { - return nil, err - } - - invite := true - appMetadata := AppMetadata{ - WTAccountID: accountID, - WTPendingInvite: &invite, - } - - payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata) - if err != nil { - return nil, err - } - - reqURL := fmt.Sprintf("%s/users", km.adminEndpoint) - payload := strings.NewReader(payloadString) - - req, err := http.NewRequest(http.MethodPost, reqURL, payload) - if err != nil { - return nil, err - } - req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) - req.Header.Add("content-type", "application/json") - - if km.appMetrics != nil { - km.appMetrics.IDPMetrics().CountCreateUser() - } - - resp, err := km.httpClient.Do(req) - if err != nil { - if km.appMetrics != nil { - km.appMetrics.IDPMetrics().CountRequestError() - } - - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusCreated { - if km.appMetrics != nil { - km.appMetrics.IDPMetrics().CountRequestStatusError() - } - - return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) - } - - locationHeader := resp.Header.Get("location") - userID, err := extractUserIDFromLocationHeader(locationHeader) - if err != nil { - return nil, err - } - - return km.GetUserDataByID(userID, appMetadata) +func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. @@ -319,7 +240,7 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetUserDataByID requests user data from keycloak via ID. -func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) { body, err := km.get("users/"+userID, nil) if err != nil { return nil, err @@ -338,12 +259,9 @@ func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadat return profile.userData(), nil } -// GetAccount returns all the users for a given profile. +// GetAccount returns all the users for a given account profile. func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { - q := url.Values{} - q.Add("q", wtAccountID+":"+accountID) - - body, err := km.get("users", q) + profiles, err := km.fetchAllUserProfiles() if err != nil { return nil, err } @@ -352,15 +270,12 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { km.appMetrics.IDPMetrics().CountGetAccount() } - profiles := make([]keycloakProfile, 0) - err = km.helper.Unmarshal(body, &profiles) - if err != nil { - return nil, err - } - users := make([]*UserData, 0) for _, profile := range profiles { - users = append(users, profile.userData()) + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountID + + users = append(users, userData) } return users, nil @@ -369,15 +284,7 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { - totalUsers, err := km.totalUsersCount() - if err != nil { - return nil, err - } - - q := url.Values{} - q.Add("max", fmt.Sprint(*totalUsers)) - - body, err := km.get("users", q) + profiles, err := km.fetchAllUserProfiles() if err != nil { return nil, err } @@ -386,60 +293,44 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { km.appMetrics.IDPMetrics().CountGetAllAccounts() } - profiles := make([]keycloakProfile, 0) - err = km.helper.Unmarshal(body, &profiles) - if err != nil { - return nil, err - } - indexedUsers := make(map[string][]*UserData) for _, profile := range profiles { userData := profile.userData() - - accountID := userData.AppMetadata.WTAccountID - if accountID != "" { - if _, ok := indexedUsers[accountID]; !ok { - indexedUsers[accountID] = make([]*UserData, 0) - } - indexedUsers[accountID] = append(indexedUsers[accountID], userData) - } + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } return indexedUsers, nil } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { - jwtToken, err := km.credentials.Authenticate() - if err != nil { - return err - } +func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { + return nil +} - attrs := keycloakUserAttributes{} - attrs.Set(wtAccountID, appMetadata.WTAccountID) - if appMetadata.WTPendingInvite != nil { - attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite)) - } else { - attrs.Set(wtPendingInvite, "false") - } +// InviteUserByID resend invitations to users who haven't activated, +// their accounts prior to the expiration period. +func (km *KeycloakManager) InviteUserByID(_ string) error { + return fmt.Errorf("method InviteUserByID not implemented") +} - reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID) - data, err := km.helper.Marshal(map[string]any{ - "attributes": attrs, - }) +// DeleteUser from Keycloak by user ID. +func (km *KeycloakManager) DeleteUser(userID string) error { + jwtToken, err := km.credentials.Authenticate() if err != nil { return err } - payload := strings.NewReader(string(data)) - req, err := http.NewRequest(http.MethodPut, reqURL, payload) + reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID)) + req, err := http.NewRequest(http.MethodDelete, reqURL, nil) if err != nil { return err } req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("updating IdP metadata for user %s", userID) + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountDeleteUser() + } resp, err := km.httpClient.Do(req) if err != nil { @@ -448,51 +339,41 @@ func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppM } return err } - defer resp.Body.Close() + defer resp.Body.Close() // nolint - if km.appMetrics != nil { - km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() - } + // In the docs, they specified 200, but in the endpoints, they return 204 + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } - if resp.StatusCode != http.StatusNoContent { - return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode) + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) } return nil } -// InviteUserByID resend invitations to users who haven't activated, -// their accounts prior to the expiration period. -func (km *KeycloakManager) InviteUserByID(_ string) error { - return fmt.Errorf("method InviteUserByID not implemented") -} +func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { + totalUsers, err := km.totalUsersCount() + if err != nil { + return nil, err + } -func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { - attrs := keycloakUserAttributes{} - attrs.Set(wtAccountID, appMetadata.WTAccountID) - attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite)) - - req := &keycloakCreateUserRequest{ - Email: email, - Username: name, - Enabled: true, - EmailVerified: true, - Credentials: []keycloakUserCredential{ - { - Type: "password", - Value: GeneratePassword(8, 1, 1, 1), - Temporary: false, - }, - }, - Attributes: attrs, - } - - str, err := json.Marshal(req) + q := url.Values{} + q.Add("max", fmt.Sprint(*totalUsers)) + + body, err := km.get("users", q) if err != nil { - return "", err + return nil, err } - return string(str), nil + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + return profiles, nil } // get perform Get requests. @@ -547,53 +428,11 @@ func (km *KeycloakManager) totalUsersCount() (*int, error) { return &count, nil } -// extractUserIDFromLocationHeader extracts the user ID from the location, -// header once the user is created successfully -func extractUserIDFromLocationHeader(locationHeader string) (string, error) { - userURL, err := url.Parse(locationHeader) - if err != nil { - return "", err - } - - return path.Base(userURL.Path), nil -} - // userData construct user data from keycloak profile. func (kp keycloakProfile) userData() *UserData { - accountID := kp.Attributes.Get(wtAccountID) - pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite)) - if err != nil { - pendingInvite = false - } - return &UserData{ Email: kp.Email, Name: kp.Username, ID: kp.ID, - AppMetadata: AppMetadata{ - WTAccountID: accountID, - WTPendingInvite: &pendingInvite, - }, - } -} - -// Set sets the key to value. It replaces any existing -// values. -func (ka keycloakUserAttributes) Set(key, value string) { - ka[key] = []string{value} -} - -// Get returns the first value associated with the given key. -// If there are no values associated with the key, Get returns -// the empty string. -func (ka keycloakUserAttributes) Get(key string) string { - if ka == nil { - return "" - } - - values := ka[key] - if len(values) == 0 { - return "" } - return values[0] } diff --git a/management/server/idp/keycloak_test.go b/management/server/idp/keycloak_test.go index 0c33fc13754..9b6c1d3c63c 100644 --- a/management/server/idp/keycloak_test.go +++ b/management/server/idp/keycloak_test.go @@ -84,15 +84,6 @@ func TestNewKeycloakManager(t *testing.T) { } } -type mockKeycloakCredentials struct { - jwtToken JWTToken - err error -} - -func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) { - return mc.jwtToken, mc.err -} - func TestKeycloakRequestJWTToken(t *testing.T) { type requestJWTTokenTest struct { @@ -316,108 +307,3 @@ func TestKeycloakAuthenticate(t *testing.T) { }) } } - -func TestKeycloakUpdateUserAppMetadata(t *testing.T) { - type updateUserAppMetadataTest struct { - name string - inputReqBody string - expectedReqBody string - appMetadata AppMetadata - statusCode int - helper ManagerHelper - managerCreds ManagerCredentials - assertErrFunc assert.ErrorAssertionFunc - assertErrFuncMessage string - } - - appMetadata := AppMetadata{WTAccountID: "ok"} - - updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ - name: "Bad Authentication", - expectedReqBody: "", - appMetadata: appMetadata, - statusCode: 400, - helper: JsonParser{}, - managerCreds: &mockKeycloakCredentials{ - jwtToken: JWTToken{}, - err: fmt.Errorf("error"), - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ - name: "Bad Status Code", - expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID), - appMetadata: appMetadata, - statusCode: 400, - helper: JsonParser{}, - managerCreds: &mockKeycloakCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{ - name: "Bad Response Parsing", - statusCode: 400, - helper: &mockJsonParser{marshalErrorString: "error"}, - managerCreds: &mockKeycloakCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ - name: "Good request", - expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID), - appMetadata: appMetadata, - statusCode: 204, - helper: JsonParser{}, - managerCreds: &mockKeycloakCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.NoError, - assertErrFuncMessage: "shouldn't return error", - } - - invite := true - updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{ - name: "Update Pending Invite", - expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID), - appMetadata: AppMetadata{ - WTAccountID: "ok", - WTPendingInvite: &invite, - }, - statusCode: 204, - helper: JsonParser{}, - managerCreds: &mockKeycloakCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.NoError, - assertErrFuncMessage: "shouldn't return error", - } - - for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, - updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} { - t.Run(testCase.name, func(t *testing.T) { - reqClient := mockHTTPClient{ - resBody: testCase.inputReqBody, - code: testCase.statusCode, - } - - manager := &KeycloakManager{ - httpClient: &reqClient, - credentials: testCase.managerCreds, - helper: testCase.helper, - } - - err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) - testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) - - assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match") - }) - } -} diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index c6b5055d42f..3e7b9357ec6 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -8,9 +8,9 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/okta/okta-sdk-golang/v2/okta" - "github.com/okta/okta-sdk-golang/v2/okta/query" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // OktaManager okta manager client instance. @@ -76,11 +76,6 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (* return nil, err } - err = updateUserProfileSchema(client) - if err != nil { - return nil, err - } - credentials := &OktaCredentials{ clientConfig: config, httpClient: httpClient, @@ -103,49 +98,8 @@ func (oc *OktaCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in okta Idp and sends an invitation. -func (om *OktaManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - var ( - sendEmail = true - activate = true - userProfile = okta.UserProfile{ - "email": email, - "login": email, - wtAccountID: accountID, - wtPendingInvite: true, - } - ) - - fields := strings.Fields(name) - if n := len(fields); n > 0 { - userProfile["firstName"] = strings.Join(fields[:n-1], " ") - userProfile["lastName"] = fields[n-1] - } - - user, resp, err := om.client.User.CreateUser(context.Background(), - okta.CreateUserRequest{ - Profile: &userProfile, - }, - &query.Params{ - Activate: &activate, - SendEmail: &sendEmail, - }, - ) - if err != nil { - return nil, err - } - - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountCreateUser() - } - - if resp.StatusCode != http.StatusOK { - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) - } - - return parseOktaUser(user) +func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserDataByID requests user data from keycloak via ID. @@ -166,7 +120,13 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) ( return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) } - return parseOktaUser(user) + userData, err := parseOktaUser(user) + if err != nil { + return nil, err + } + userData.AppMetadata = appMetadata + + return userData, nil } // GetUserByEmail searches users with a given email. @@ -200,8 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { // GetAccount returns all the users for a given profile. func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { - search := fmt.Sprintf("profile.wt_account_id eq %q", accountID) - users, resp, err := om.client.User.ListUsers(context.Background(), &query.Params{Search: search}) + users, resp, err := om.client.User.ListUsers(context.Background(), nil) if err != nil { return nil, err } @@ -223,6 +182,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { if err != nil { return nil, err } + userData.AppMetadata.WTAccountID = accountID list = append(list, userData) } @@ -256,13 +216,7 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { return nil, err } - accountID := userData.AppMetadata.WTAccountID - if accountID != "" { - if _, ok := indexedUsers[accountID]; !ok { - indexedUsers[accountID] = make([]*UserData, 0) - } - indexedUsers[accountID] = append(indexedUsers[accountID], userData) - } + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } return indexedUsers, nil @@ -270,46 +224,6 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { - user, resp, err := om.client.User.GetUser(context.Background(), userID) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountRequestStatusError() - } - return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode) - } - - profile := *user.Profile - - if appMetadata.WTPendingInvite != nil { - profile[wtPendingInvite] = *appMetadata.WTPendingInvite - } - - if appMetadata.WTAccountID != "" { - profile[wtAccountID] = appMetadata.WTAccountID - } - - user.Profile = &profile - _, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil) - if err != nil { - fmt.Println(err.Error()) - return err - } - - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() - } - - if resp.StatusCode != http.StatusOK { - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountRequestStatusError() - } - return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode) - } - return nil } @@ -319,47 +233,23 @@ func (om *OktaManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } -// updateUserProfileSchema updates the Okta user schema to include custom fields, -// wt_account_id and wt_pending_invite. -func updateUserProfileSchema(client *okta.Client) error { - // Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird - userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}} - - _, resp, err := client.UserSchema.UpdateUserProfile( - context.Background(), - "default", - okta.UserSchema{ - Definitions: &okta.UserSchemaDefinitions{ - Custom: &okta.UserSchemaPublic{ - Id: "#custom", - Type: "object", - Properties: map[string]*okta.UserSchemaAttribute{ - wtAccountID: { - MaxLength: 100, - MinLength: 1, - Required: new(bool), - Scope: "NONE", - Title: "Wt Account Id", - Type: "string", - Permissions: userPermissions, - }, - wtPendingInvite: { - Required: new(bool), - Scope: "NONE", - Title: "Wt Pending Invite", - Type: "boolean", - Permissions: userPermissions, - }, - }, - }, - }, - }) +// DeleteUser from Okta +func (om *OktaManager) DeleteUser(userID string) error { + resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil) if err != nil { + fmt.Println(err.Error()) return err } + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountDeleteUser() + } + if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unable to update user profile schema, statusCode %d", resp.StatusCode) + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountRequestStatusError() + } + return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) } return nil @@ -368,11 +258,9 @@ func updateUserProfileSchema(client *okta.Client) error { // parseOktaUserToUserData parse okta user to UserData. func parseOktaUser(user *okta.User) (*UserData, error) { var oktaUser struct { - Email string `json:"email"` - FirstName string `json:"firstName"` - LastName string `json:"lastName"` - AccountID string `json:"wt_account_id"` - PendingInvite bool `json:"wt_pending_invite"` + Email string `json:"email"` + FirstName string `json:"firstName"` + LastName string `json:"lastName"` } if user == nil { @@ -396,9 +284,5 @@ func parseOktaUser(user *okta.User) (*UserData, error) { Email: oktaUser.Email, Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "), ID: user.Id, - AppMetadata: AppMetadata{ - WTAccountID: oktaUser.AccountID, - WTPendingInvite: &oktaUser.PendingInvite, - }, }, nil } diff --git a/management/server/idp/okta_test.go b/management/server/idp/okta_test.go index 02c28b3aefb..20df246f82a 100644 --- a/management/server/idp/okta_test.go +++ b/management/server/idp/okta_test.go @@ -1,31 +1,28 @@ package idp import ( + "testing" + "github.com/okta/okta-sdk-golang/v2/okta" "github.com/stretchr/testify/assert" - "testing" ) func TestParseOktaUser(t *testing.T) { type parseOktaUserTest struct { name string - invite bool inputProfile *okta.User expectedUserData *UserData assertErrFunc assert.ErrorAssertionFunc } parseOktaTestCase1 := parseOktaUserTest{ - name: "Good Request", - invite: true, + name: "Good Request", inputProfile: &okta.User{ Id: "123", Profile: &okta.UserProfile{ - "email": "test@example.com", - "firstName": "John", - "lastName": "Doe", - "wt_account_id": "456", - "wt_pending_invite": true, + "email": "test@example.com", + "firstName": "John", + "lastName": "Doe", }, }, expectedUserData: &UserData{ @@ -41,36 +38,17 @@ func TestParseOktaUser(t *testing.T) { parseOktaTestCase2 := parseOktaUserTest{ name: "Invalid okta user", - invite: true, inputProfile: nil, expectedUserData: nil, assertErrFunc: assert.Error, } - parseOktaTestCase3 := parseOktaUserTest{ - name: "Invalid pending invite type", - invite: false, - inputProfile: &okta.User{ - Id: "123", - Profile: &okta.UserProfile{ - "email": "test@example.com", - "firstName": "John", - "lastName": "Doe", - "wt_account_id": "456", - "wt_pending_invite": "true", - }, - }, - expectedUserData: nil, - assertErrFunc: assert.Error, - } - - for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2, parseOktaTestCase3} { + for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} { t.Run(testCase.name, func(t *testing.T) { userData, err := parseOktaUser(testCase.inputProfile) testCase.assertErrFunc(t, err, testCase.assertErrFunc) if err == nil { - testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match") } }) @@ -83,13 +61,5 @@ func userDataEqual(a, b *UserData) bool { if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID { return false } - if a.AppMetadata.WTAccountID != b.AppMetadata.WTAccountID { - return false - } - - if a.AppMetadata.WTPendingInvite != nil && b.AppMetadata.WTPendingInvite != nil && - *a.AppMetadata.WTPendingInvite != *b.AppMetadata.WTPendingInvite { - return false - } return true } diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index fce2c7b379c..5325e51bebe 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -1,20 +1,18 @@ package idp import ( - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" "net/url" - "strconv" "strings" "sync" "time" "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // ZitadelManager zitadel manager client instance. @@ -67,12 +65,6 @@ type zitadelUser struct { type zitadelAttributes map[string][]map[string]any -// zitadelMetadata holds additional user data. -type zitadelMetadata struct { - Key string `json:"key"` - Value string `json:"value"` -} - // zitadelProfile represents an zitadel user profile response. type zitadelProfile struct { ID string `json:"id"` @@ -81,7 +73,6 @@ type zitadelProfile struct { PreferredLoginName string `json:"preferredLoginName"` LoginNames []string `json:"loginNames"` Human *zitadelUser `json:"human"` - Metadata []zitadelMetadata } // NewZitadelManager creates a new instance of the ZitadelManager. @@ -234,42 +225,8 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in zitadel Idp and sends an invite. -func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { - payload, err := buildZitadelCreateUserRequestPayload(email, name) - if err != nil { - return nil, err - } - - body, err := zm.post("users/human/_import", payload) - if err != nil { - return nil, err - } - - if zm.appMetrics != nil { - zm.appMetrics.IDPMetrics().CountCreateUser() - } - - var result struct { - UserID string `json:"userId"` - } - err = zm.helper.Unmarshal(body, &result) - if err != nil { - return nil, err - } - - invite := true - appMetadata := AppMetadata{ - WTAccountID: accountID, - WTPendingInvite: &invite, - } - - // Add metadata to new user - err = zm.UpdateUserAppMetadata(result.UserID, appMetadata) - if err != nil { - return nil, err - } - - return zm.GetUserDataByID(result.UserID, appMetadata) +func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) { + return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. @@ -307,12 +264,6 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { users := make([]*UserData, 0) for _, profile := range profiles.Result { - metadata, err := zm.getUserMetadata(profile.ID) - if err != nil { - return nil, err - } - profile.Metadata = metadata - users = append(users, profile.userData()) } @@ -336,18 +287,15 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata return nil, err } - metadata, err := zm.getUserMetadata(userID) - if err != nil { - return nil, err - } - profile.User.Metadata = metadata + userData := profile.User.userData() + userData.AppMetadata = appMetadata - return profile.User.userData(), nil + return userData, nil } // GetAccount returns all the users for a given profile. func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { - accounts, err := zm.GetAllAccounts() + body, err := zm.post("users/_search", "") if err != nil { return nil, err } @@ -356,7 +304,21 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { zm.appMetrics.IDPMetrics().CountGetAccount() } - return accounts[accountID], nil + var profiles struct{ Result []zitadelProfile } + err = zm.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.Result { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountID + + users = append(users, userData) + } + + return users, nil } // GetAllAccounts gets all registered accounts with corresponding user data. @@ -379,22 +341,8 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { indexedUsers := make(map[string][]*UserData) for _, profile := range profiles.Result { - // fetch user metadata - metadata, err := zm.getUserMetadata(profile.ID) - if err != nil { - return nil, err - } - profile.Metadata = metadata - userData := profile.userData() - accountID := userData.AppMetadata.WTAccountID - - if accountID != "" { - if _, ok := indexedUsers[accountID]; !ok { - indexedUsers[accountID] = make([]*UserData, 0) - } - indexedUsers[accountID] = append(indexedUsers[accountID], userData) - } + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } return indexedUsers, nil @@ -402,42 +350,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // Metadata values are base64 encoded. -func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { - if appMetadata.WTPendingInvite == nil { - appMetadata.WTPendingInvite = new(bool) - } - pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite) - - wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID)) - wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf) - - metadata := zitadelAttributes{ - "metadata": { - { - "key": wtAccountID, - "value": wtAccountIDValue, - }, - { - "key": wtPendingInvite, - "value": wtPendingInviteValue, - }, - }, - } - payload, err := zm.helper.Marshal(metadata) - if err != nil { - return err - } - - resource := fmt.Sprintf("users/%s/metadata/_bulk", userID) - _, err = zm.post(resource, string(payload)) - if err != nil { - return err - } - - if zm.appMetrics != nil { - zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() - } - +func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { return nil } @@ -447,21 +360,18 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error { return fmt.Errorf("method InviteUserByID not implemented") } -// getUserMetadata requests user metadata from zitadel via ID. -func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) { - resource := fmt.Sprintf("users/%s/metadata/_search", userID) - body, err := zm.post(resource, "") - if err != nil { - return nil, err +// DeleteUser from Zitadel +func (zm *ZitadelManager) DeleteUser(userID string) error { + resource := fmt.Sprintf("users/%s", userID) + if err := zm.delete(resource); err != nil { + return err } - var metadata struct{ Result []zitadelMetadata } - err = zm.helper.Unmarshal(body, &metadata) - if err != nil { - return nil, err + if zm.appMetrics != nil { + zm.appMetrics.IDPMetrics().CountDeleteUser() } - return metadata.Result, nil + return nil } // post perform Post requests. @@ -500,6 +410,11 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { return io.ReadAll(resp.Body) } +// delete perform Delete requests. +func (zm *ZitadelManager) delete(_ string) error { + return nil +} + // get perform Get requests. func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { jwtToken, err := zm.credentials.Authenticate() @@ -536,38 +451,13 @@ func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { return io.ReadAll(resp.Body) } -// value returns string represented by the base64 string value. -func (zm zitadelMetadata) value() string { - value, err := base64.StdEncoding.DecodeString(zm.Value) - if err != nil { - return "" - } - - return string(value) -} - // userData construct user data from zitadel profile. func (zp zitadelProfile) userData() *UserData { var ( - email string - name string - wtAccountIDValue string - wtPendingInviteValue bool + email string + name string ) - for _, metadata := range zp.Metadata { - if metadata.Key == wtAccountID { - wtAccountIDValue = metadata.value() - } - - if metadata.Key == wtPendingInvite { - value, err := strconv.ParseBool(metadata.value()) - if err == nil { - wtPendingInviteValue = value - } - } - } - // Obtain the email for the human account and the login name, // for the machine account. if zp.Human != nil { @@ -584,39 +474,5 @@ func (zp zitadelProfile) userData() *UserData { Email: email, Name: name, ID: zp.ID, - AppMetadata: AppMetadata{ - WTAccountID: wtAccountIDValue, - WTPendingInvite: &wtPendingInviteValue, - }, } } - -func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) { - var firstName, lastName string - - words := strings.Fields(name) - if n := len(words); n > 0 { - firstName = strings.Join(words[:n-1], " ") - lastName = words[n-1] - } - - req := &zitadelUser{ - UserName: name, - Profile: zitadelUserInfo{ - FirstName: strings.TrimSpace(firstName), - LastName: strings.TrimSpace(lastName), - DisplayName: name, - }, - Email: zitadelEmail{ - Email: email, - IsEmailVerified: false, - }, - } - - str, err := json.Marshal(req) - if err != nil { - return "", err - } - - return string(str), nil -} diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index b558bba733d..9a771b36a7c 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -7,9 +7,10 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" ) func TestNewZitadelManager(t *testing.T) { @@ -63,15 +64,6 @@ func TestNewZitadelManager(t *testing.T) { } } -type mockZitadelCredentials struct { - jwtToken JWTToken - err error -} - -func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) { - return mc.jwtToken, mc.err -} - func TestZitadelRequestJWTToken(t *testing.T) { type requestJWTTokenTest struct { @@ -296,98 +288,6 @@ func TestZitadelAuthenticate(t *testing.T) { } } -func TestZitadelUpdateUserAppMetadata(t *testing.T) { - type updateUserAppMetadataTest struct { - name string - inputReqBody string - expectedReqBody string - appMetadata AppMetadata - statusCode int - helper ManagerHelper - managerCreds ManagerCredentials - assertErrFunc assert.ErrorAssertionFunc - assertErrFuncMessage string - } - - appMetadata := AppMetadata{WTAccountID: "ok"} - - updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ - name: "Bad Authentication", - expectedReqBody: "", - appMetadata: appMetadata, - statusCode: 400, - helper: JsonParser{}, - managerCreds: &mockZitadelCredentials{ - jwtToken: JWTToken{}, - err: fmt.Errorf("error"), - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ - name: "Bad Response Parsing", - statusCode: 400, - helper: &mockJsonParser{marshalErrorString: "error"}, - managerCreds: &mockZitadelCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.Error, - assertErrFuncMessage: "should return error", - } - - updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{ - name: "Good request", - expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}", - appMetadata: appMetadata, - statusCode: 200, - helper: JsonParser{}, - managerCreds: &mockZitadelCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.NoError, - assertErrFuncMessage: "shouldn't return error", - } - - invite := true - updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ - name: "Update Pending Invite", - expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}", - appMetadata: AppMetadata{ - WTAccountID: "ok", - WTPendingInvite: &invite, - }, - statusCode: 200, - helper: JsonParser{}, - managerCreds: &mockZitadelCredentials{ - jwtToken: JWTToken{}, - }, - assertErrFunc: assert.NoError, - assertErrFuncMessage: "shouldn't return error", - } - - for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, - updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} { - t.Run(testCase.name, func(t *testing.T) { - reqClient := mockHTTPClient{ - resBody: testCase.inputReqBody, - code: testCase.statusCode, - } - - manager := &ZitadelManager{ - httpClient: &reqClient, - credentials: testCase.managerCreds, - helper: testCase.helper, - } - - err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) - testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) - - assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match") - }) - } -} - func TestZitadelProfile(t *testing.T) { type azureProfileTest struct { name string @@ -418,16 +318,6 @@ func TestZitadelProfile(t *testing.T) { IsEmailVerified: true, }, }, - Metadata: []zitadelMetadata{ - { - Key: "wt_account_id", - Value: "MQ==", - }, - { - Key: "wt_pending_invite", - Value: "ZmFsc2U=", - }, - }, }, expectedUserData: UserData{ ID: "test1", @@ -451,16 +341,6 @@ func TestZitadelProfile(t *testing.T) { "machine", }, Human: nil, - Metadata: []zitadelMetadata{ - { - Key: "wt_account_id", - Value: "MQ==", - }, - { - Key: "wt_pending_invite", - Value: "dHJ1ZQ==", - }, - }, }, expectedUserData: UserData{ ID: "test2", @@ -480,8 +360,6 @@ func TestZitadelProfile(t *testing.T) { assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match") assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match") assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match") - assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match") - assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match") }) } } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 66661dbf873..b4a527e463d 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 6c93765f438..fa35cfdef4e 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager() eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", - eventStore) + eventStore, false) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 696df5f3cc1..3b3db0baa87 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -176,6 +176,7 @@ func (w *Worker) generateProperties() properties { rulesDirection map[string]int groups int routes int + routesWithRGGroups int nameservers int uiClient int version string @@ -201,6 +202,11 @@ func (w *Worker) generateProperties() properties { groups = groups + len(account.Groups) routes = routes + len(account.Routes) + for _, route := range account.Routes { + if len(route.PeerGroups) > 0 { + routesWithRGGroups++ + } + } nameservers = nameservers + len(account.NameServerGroups) for _, policy := range account.Policies { @@ -282,6 +288,7 @@ func (w *Worker) generateProperties() properties { metricsProperties["rules"] = rules metricsProperties["groups"] = groups metricsProperties["routes"] = routes + metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["nameservers"] = nameservers metricsProperties["version"] = version metricsProperties["min_active_peer_version"] = minActivePeerVersion diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go new file mode 100644 index 00000000000..c61613fd26f --- /dev/null +++ b/management/server/metrics/selfhosted_test.go @@ -0,0 +1,239 @@ +package metrics + +import ( + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/route" +) + +type mockDatasource struct{} + +// GetAllConnectedPeers returns a map of connected peer IDs for use in tests with predefined information +func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { + return map[string]struct{}{ + "1": {}, + } +} + +// GetAllAccounts returns a list of *server.Account for use in tests with predefined information +func (mockDatasource) GetAllAccounts() []*server.Account { + return []*server.Account{ + { + Id: "1", + Settings: &server.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*server.SetupKey{ + "1": { + Id: "1", + Ephemeral: true, + UsedTimes: 1, + }, + }, + Groups: map[string]*server.Group{ + "1": {}, + "2": {}, + }, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "1": {}, + }, + Peers: map[string]*server.Peer{ + "1": { + ID: "1", + UserID: "test", + SSHEnabled: true, + Meta: server.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, + }, + }, + Policies: []*server.Policy{ + { + Rules: []*server.PolicyRule{ + { + Bidirectional: true, + Protocol: server.PolicyRuleProtocolTCP, + }, + }, + }, + { + Rules: []*server.PolicyRule{ + { + Bidirectional: false, + Protocol: server.PolicyRuleProtocolTCP, + }, + }, + }, + }, + Routes: map[string]*route.Route{ + "1": { + ID: "1", + PeerGroups: make([]string, 1), + }, + }, + Users: map[string]*server.User{ + "1": { + IsServiceUser: true, + PATs: map[string]*server.PersonalAccessToken{ + "1": {}, + }, + }, + "2": { + IsServiceUser: false, + PATs: map[string]*server.PersonalAccessToken{ + "1": {}, + }, + }, + }, + }, + { + Id: "2", + Settings: &server.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*server.SetupKey{ + "1": { + Id: "1", + Ephemeral: true, + UsedTimes: 1, + }, + }, + Groups: map[string]*server.Group{ + "1": {}, + "2": {}, + }, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "1": {}, + }, + Peers: map[string]*server.Peer{ + "1": { + ID: "1", + UserID: "test", + SSHEnabled: true, + Meta: server.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, + }, + }, + Policies: []*server.Policy{ + { + Rules: []*server.PolicyRule{ + { + Bidirectional: true, + Protocol: server.PolicyRuleProtocolTCP, + }, + }, + }, + { + Rules: []*server.PolicyRule{ + { + Bidirectional: false, + Protocol: server.PolicyRuleProtocolTCP, + }, + }, + }, + }, + Routes: map[string]*route.Route{ + "1": { + ID: "1", + PeerGroups: make([]string, 1), + }, + }, + Users: map[string]*server.User{ + "1": { + IsServiceUser: true, + PATs: map[string]*server.PersonalAccessToken{ + "1": {}, + }, + }, + "2": { + IsServiceUser: false, + PATs: map[string]*server.PersonalAccessToken{ + "1": {}, + }, + }, + }, + }, + } +} + +// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties +func TestGenerateProperties(t *testing.T) { + ds := mockDatasource{} + worker := Worker{ + dataSource: ds, + connManager: ds, + } + + properties := worker.generateProperties() + + if properties["accounts"] != 2 { + t.Errorf("expected 2 accounts, got %d", properties["accounts"]) + } + if properties["peers"] != 2 { + t.Errorf("expected 2 peers, got %d", properties["peers"]) + } + if properties["routes"] != 2 { + t.Errorf("expected 2 routes, got %d", properties["routes"]) + } + if properties["rules"] != 4 { + t.Errorf("expected 4 rules, got %d", properties["rules"]) + } + if properties["users"] != 2 { + t.Errorf("expected 1 users, got %d", properties["users"]) + } + if properties["setup_keys_usage"] != 2 { + t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"]) + } + if properties["pats"] != 4 { + t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"]) + } + if properties["peers_ssh_enabled"] != 2 { + t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"]) + } + if properties["routes_with_routing_groups"] != 2 { + t.Errorf("expected 2 routes_with_routing_groups, got %d", properties["routes_with_routing_groups"]) + } + if properties["rules_protocol_tcp"] != 4 { + t.Errorf("expected 4 rules_protocol_tcp, got %d", properties["rules_protocol_tcp"]) + } + if properties["rules_direction_oneway"] != 2 { + t.Errorf("expected 2 rules_direction_oneway, got %d", properties["rules_direction_oneway"]) + } + + if properties["active_peers_last_day"] != 2 { + t.Errorf("expected 2 active_peers_last_day, got %d", properties["active_peers_last_day"]) + } + if properties["min_active_peer_version"] != "0.0.1" { + t.Errorf("expected 0.0.1 min_active_peer_version, got %s", properties["min_active_peer_version"]) + } + if properties["max_active_peer_version"] != "0.0.1" { + t.Errorf("expected 0.0.1 max_active_peer_version, got %s", properties["max_active_peer_version"]) + } + + if properties["peers_login_expiration_enabled"] != 2 { + t.Errorf("expected 2 peers_login_expiration_enabled, got %d", properties["peers_login_expiration_enabled"]) + } + + if properties["service_users"] != 2 { + t.Errorf("expected 2 service_users, got %d", properties["service_users"]) + } + + if properties["peer_os_linux"] != 2 { + t.Errorf("expected 2 peer_os_linux, got %d", properties["peer_os_linux"]) + } + + if properties["ephemeral_peers_setup_keys"] != 2 { + t.Errorf("expected 2 ephemeral_peers_setup_keys, got %d", properties["ephemeral_peers_setup_keys_usage"]) + } + + if properties["ephemeral_peers_setup_keys_usage"] != 2 { + t.Errorf("expected 2 ephemeral_peers_setup_keys_usage, got %d", properties["ephemeral_peers_setup_keys_usage"]) + } + + if properties["nameservers"] != 2 { + t.Errorf("expected 2 nameservers, got %d", properties["nameservers"]) + } + + if properties["groups"] != 4 { + t.Errorf("expected 4 groups, got %d", properties["groups"]) + } + + if properties["user_peers"] != 2 { + t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4bfa922c70f..5432b201bfb 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -20,23 +20,18 @@ type MockAccountManager struct { GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) - AccountExistsFunc func(accountId string) (*bool, error) - GetPeerByKeyFunc func(peerKey string) (*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool) error - DeletePeerFunc func(accountID, peerKey, userID string) (*server.Peer, error) - GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) + DeletePeerFunc func(accountID, peerKey, userID string) error GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error) SaveGroupFunc func(accountID, userID string, group *server.Group) error - UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) DeleteGroupFunc func(accountID, userId, groupID string) error ListGroupsFunc func(accountID string) ([]*server.Group, error) - GroupAddPeerFunc func(accountID, groupID, peerKey string) error - GroupDeletePeerFunc func(accountID, groupID, peerKey string) error - GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) + GroupAddPeerFunc func(accountID, groupID, peerID string) error + GroupDeletePeerFunc func(accountID, groupID, peerID string) error GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error) SaveRuleFunc func(accountID, userID string, rule *server.Rule) error DeleteRuleFunc func(accountID, ruleID, userID string) error @@ -51,10 +46,9 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) - CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) + CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) SaveRouteFunc func(accountID, userID string, route *route.Route) error - UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) DeleteRouteFunc func(accountID, routeID, userID string) error ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) @@ -68,7 +62,6 @@ type MockAccountManager struct { GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) @@ -93,11 +86,11 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID strin } // DeletePeer mock implementation of DeletePeer from server.AccountManager interface -func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) (*server.Peer, error) { +func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) error { if am.DeletePeerFunc != nil { return am.DeletePeerFunc(accountID, peerID, userID) } - return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") + return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") } // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface @@ -143,22 +136,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID( ) } -// AccountExists mock implementation of AccountExists from server.AccountManager interface -func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { - if am.AccountExistsFunc != nil { - return am.AccountExistsFunc(accountId) - } - return nil, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") -} - -// GetPeerByKey mocks implementation of GetPeerByKey from server.AccountManager interface -func (am *MockAccountManager) GetPeerByKey(peerKey string) (*server.Peer, error) { - if am.GetPeerByKeyFunc != nil { - return am.GetPeerByKeyFunc(peerKey) - } - return nil, status.Errorf(codes.Unimplemented, "method GetPeerByKey is not implemented") -} - // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error { if am.MarkPeerConnectedFunc != nil { @@ -167,14 +144,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface -func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) { - if am.GetPeerByIPFunc != nil { - return am.GetPeerByIPFunc(accountId, peerIP) - } - return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented") -} - // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { @@ -267,14 +236,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server. return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } -// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface -func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { - if am.UpdateGroupFunc != nil { - return am.UpdateGroupFunc(accountID, groupID, operations) - } - return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented") -} - // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error { if am.DeleteGroupFunc != nil { @@ -292,29 +253,21 @@ func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, err } // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface -func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { +func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { - return am.GroupAddPeerFunc(accountID, groupID, peerKey) + return am.GroupAddPeerFunc(accountID, groupID, peerID) } return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented") } // GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface -func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { +func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { if am.GroupDeletePeerFunc != nil { - return am.GroupDeletePeerFunc(accountID, groupID, peerKey) + return am.GroupDeletePeerFunc(accountID, groupID, peerID) } return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented") } -// GroupListPeers mock implementation of GroupListPeers from server.AccountManager interface -func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*server.Peer, error) { - if am.GroupListPeersFunc != nil { - return am.GroupListPeersFunc(accountID, groupID) - } - return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers is not implemented") -} - // GetRule mock implementation of GetRule from server.AccountManager interface func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) { if am.GetRuleFunc != nil { @@ -412,9 +365,9 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *server. } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(accountID string, network, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(accountID, network, peerID, description, netID, masquerade, metric, groups, enabled, userID) + return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } @@ -435,14 +388,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented") } -// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface -func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) { - if am.UpdateRouteFunc != nil { - return am.UpdateRouteFunc(accountID, ruleID, operations) - } - return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented") -} - // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { if am.DeleteRouteFunc != nil { @@ -533,14 +478,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr return nil } -// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - if am.UpdateNameServerGroupFunc != nil { - return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations) - } - return nil, nil -} - // DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { if am.DeleteNameServerGroupFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index eb21279451a..9af5b49adc8 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,66 +3,17 @@ package server import ( "errors" "regexp" - "strconv" "unicode/utf8" "github.com/miekg/dns" "github.com/rs/xid" - log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" ) -const ( - // UpdateNameServerGroupName indicates a nameserver group name update operation - UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota - // UpdateNameServerGroupDescription indicates a nameserver group description update operation - UpdateNameServerGroupDescription - // UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation - UpdateNameServerGroupNameServers - // UpdateNameServerGroupGroups indicates a nameserver group' groups update operation - UpdateNameServerGroupGroups - // UpdateNameServerGroupEnabled indicates a nameserver group status update operation - UpdateNameServerGroupEnabled - // UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation - UpdateNameServerGroupPrimary - // UpdateNameServerGroupDomains indicates a nameserver group' domains update operation - UpdateNameServerGroupDomains - - domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` -) - -// NameServerGroupUpdateOperationType operation type -type NameServerGroupUpdateOperationType int - -func (t NameServerGroupUpdateOperationType) String() string { - switch t { - case UpdateNameServerGroupDescription: - return "UpdateNameServerGroupDescription" - case UpdateNameServerGroupName: - return "UpdateNameServerGroupName" - case UpdateNameServerGroupNameServers: - return "UpdateNameServerGroupNameServers" - case UpdateNameServerGroupGroups: - return "UpdateNameServerGroupGroups" - case UpdateNameServerGroupEnabled: - return "UpdateNameServerGroupEnabled" - case UpdateNameServerGroupPrimary: - return "UpdateNameServerGroupPrimary" - case UpdateNameServerGroupDomains: - return "UpdateNameServerGroupDomains" - default: - return "InvalidOperation" - } -} - -// NameServerGroupUpdateOperation operation object with type and values to be applied -type NameServerGroupUpdateOperation struct { - Type NameServerGroupUpdateOperationType - Values []string -} +const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { @@ -122,11 +73,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d return nil, err } - err = am.updateAccountPeers(account) - if err != nil { - log.Error(err) - return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name) - } + am.updateAccountPeers(account) am.storeEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -161,120 +108,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n return err } - err = am.updateAccountPeers(account) - if err != nil { - log.Error(err) - return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name) - } + am.updateAccountPeers(account) am.storeEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } -// UpdateNameServerGroup updates existing nameserver group with set of operations -func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - if len(operations) == 0 { - return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty") - } - - nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID] - if !ok { - return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) - } - - newNSGroup := nsGroupToUpdate.Copy() - - for _, operation := range operations { - valuesCount := len(operation.Values) - if valuesCount < 1 { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) - } - - for _, value := range operation.Values { - if value == "" { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) - } - } - switch operation.Type { - case UpdateNameServerGroupDescription: - newNSGroup.Description = operation.Values[0] - case UpdateNameServerGroupName: - if valuesCount > 1 { - return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) - } - err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups) - if err != nil { - return nil, err - } - newNSGroup.Name = operation.Values[0] - case UpdateNameServerGroupNameServers: - var nsList []nbdns.NameServer - for _, url := range operation.Values { - ns, err := nbdns.ParseNameServerURL(url) - if err != nil { - return nil, err - } - nsList = append(nsList, ns) - } - err = validateNSList(nsList) - if err != nil { - return nil, err - } - newNSGroup.NameServers = nsList - case UpdateNameServerGroupGroups: - err = validateGroups(operation.Values, account.Groups) - if err != nil { - return nil, err - } - newNSGroup.Groups = operation.Values - case UpdateNameServerGroupEnabled: - enabled, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) - } - newNSGroup.Enabled = enabled - case UpdateNameServerGroupPrimary: - primary, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) - } - newNSGroup.Primary = primary - case UpdateNameServerGroupDomains: - err = validateDomainInput(false, operation.Values) - if err != nil { - return nil, err - } - newNSGroup.Domains = operation.Values - } - } - - account.NameServerGroups[nsGroupID] = newNSGroup - - account.Network.IncSerial() - err = am.Store.SaveAccount(account) - if err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - log.Error(err) - return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name) - } - - return newNSGroup.Copy(), nil -} - // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { @@ -298,10 +138,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use return err } - err = am.updateAccountPeers(account) - if err != nil { - return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID) - } + am.updateAccountPeers(account) am.storeEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 9d44250562f..26977116b86 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) { } } -func TestUpdateNameServerGroup(t *testing.T) { - nsGroupID := "testingNSGroup" - - existingNSGroup := &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "super", - Description: "super", - Primary: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("1.1.2.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID}, - Enabled: true, - } - - testCases := []struct { - name string - existingNSGroup *nbdns.NameServerGroup - nsGroupID string - operations []NameServerGroupUpdateOperation - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedNSGroup *nbdns.NameServerGroup - }{ - { - name: "Should Config Single Property", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"superNew"}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedNSGroup: &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "superNew", - Description: "super", - Primary: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("1.1.2.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID}, - Enabled: true, - }, - }, - { - name: "Should Config Multiple Properties", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"superNew"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDescription, - Values: []string{"superDescription"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupGroups, - Values: []string{group1ID, group2ID}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupEnabled, - Values: []string{"false"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupPrimary, - Values: []string{"false"}, - }, - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDomains, - Values: []string{validDomain}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedNSGroup: &nbdns.NameServerGroup{ - ID: nsGroupID, - Name: "superNew", - Description: "superDescription", - Primary: false, - Domains: []string{validDomain}, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("127.0.0.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - }, - Groups: []string{group1ID, group2ID}, - Enabled: false, - }, - }, - { - name: "Should Not Config On Invalid ID", - existingNSGroup: existingNSGroup, - nsGroupID: "nonExistingNSGroup", - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty Operations", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{}, - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty Values", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Empty String", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{""}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Name Large String", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid On Existing Name", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{existingNSGroupName}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid On Multiple Name Values", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupName, - Values: []string{"nameOne", "nameTwo"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Boolean", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupEnabled, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Nameservers Wrong Schema", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"https://127.0.0.1:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Nameservers Wrong IP", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://8.8.8.300:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Large Number Of Nameservers", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupNameServers, - Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid GroupID", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupGroups, - Values: []string{"nonExistingGroupID"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Domains", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupDomains, - Values: []string{invalidDomain}, - }, - }, - errFunc: require.Error, - }, - { - name: "Should Not Config On Invalid Primary Status", - existingNSGroup: existingNSGroup, - nsGroupID: existingNSGroup.ID, - operations: []NameServerGroupUpdateOperation{ - NameServerGroupUpdateOperation{ - Type: UpdateNameServerGroupPrimary, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - am, err := createNSManager(t) - if err != nil { - t.Error("failed to create account manager") - } - - account, err := initTestNSAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } - - account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - - err = am.Store.SaveAccount(account) - if err != nil { - t.Error("account should be saved") - } - - updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations) - testCase.errFunc(t, err) - - if !testCase.shouldCreate { - return - } - - testCase.expectedNSGroup.ID = updatedRoute.ID - - if !testCase.expectedNSGroup.IsEqual(updatedRoute) { - t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup) - } - - }) - } -} - func TestDeleteNameServerGroup(t *testing.T) { nsGroupID := "testingNSGroup" @@ -1061,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false) } func createNSStore(t *testing.T) (Store, error) { diff --git a/management/server/peer.go b/management/server/peer.go index f9631719f2a..e5c6e39d65d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -195,16 +195,6 @@ func (p *PeerStatus) Copy() *PeerStatus { } } -// GetPeerByKey looks up peer by its public WireGuard key -func (am *DefaultAccountManager) GetPeerByKey(peerPubKey string) (*Peer, error) { - account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) - if err != nil { - return nil, err - } - - return account.FindPeerByPubKey(peerPubKey) -} - // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) { @@ -290,10 +280,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected if oldStatus.LoginExpired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - err = am.updateAccountPeers(account) - if err != nil { - return err - } + am.updateAccountPeers(account) } return nil @@ -364,82 +351,75 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe return nil, err } - err = am.updateAccountPeers(account) - if err != nil { - return nil, err - } + am.updateAccountPeers(account) return peer, nil } -// DeletePeer removes peer from the account by its IP -func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) (*Peer, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() +// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock +func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, userID string) error { - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer %s not found", peerID) - } + // the first loop is needed to ensure all peers present under the account before modifying, otherwise + // we might have some inconsistencies + peers := make([]*Peer, 0, len(peerIDs)) + for _, peerID := range peerIDs { - account.DeletePeer(peerID) - - err = am.Store.SaveAccount(account) - if err != nil { - return nil, err + peer := account.GetPeer(peerID) + if peer == nil { + return status.Errorf(status.NotFound, "peer %s not found", peerID) + } + peers = append(peers, peer) } - err = am.peersUpdateManager.SendUpdate(peer.ID, - &UpdateMessage{ - Update: &proto.SyncResponse{ - // fill those field for backward compatibility - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - // new field - NetworkMap: &proto.NetworkMap{ - Serial: account.Network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, + // the 2nd loop performs the actual modification + for _, peer := range peers { + account.DeletePeer(peer.ID) + am.peersUpdateManager.SendUpdate(peer.ID, + &UpdateMessage{ + Update: &proto.SyncResponse{ + // fill those field for backward compatibility + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + // new field + NetworkMap: &proto.NetworkMap{ + Serial: account.Network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + }, }, - }, - }) - if err != nil { - return nil, err - } - - if err := am.updateAccountPeers(account); err != nil { - return nil, err + }) + am.peersUpdateManager.CloseChannel(peer.ID) + am.storeEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) } - am.peersUpdateManager.CloseChannel(peerID) - am.storeEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) - return peer, nil + return nil } -// GetPeerByIP returns peer by its IP -func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) { +// DeletePeer removes peer from the account by its IP +func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, err + return err } - for _, peer := range account.Peers { - if peerIP == peer.IP.String() { - return peer, nil - } + err = am.deletePeers(account, []string{peerID}, userID) + if err != nil { + return err } - return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP) + err = am.Store.SaveAccount(account) + if err != nil { + return err + } + + am.updateAccountPeers(account) + + return nil } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) @@ -609,10 +589,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) am.storeEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - err = am.updateAccountPeers(account) - if err != nil { - return nil, nil, err - } + am.updateAccountPeers(account) networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) return newPeer, networkMap, nil @@ -727,10 +704,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, } if updateRemotePeers { - err = am.updateAccountPeers(account) - if err != nil { - return nil, nil, err - } + am.updateAccountPeers(account) } return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil } @@ -804,10 +778,7 @@ func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *Peer, account *A } // trigger network map update - err = am.updateAccountPeers(account) - if err != nil { - return nil, err - } + am.updateAccountPeers(account) return peer, nil } @@ -852,7 +823,9 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) } // trigger network map update - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } // GetPeer for a given accountID, peerID and userID error if not found. @@ -909,21 +882,12 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) (*Peer, b // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { +func (am *DefaultAccountManager) updateAccountPeers(account *Account) { peers := account.GetPeers() for _, peer := range peers { - remotePeerNetworkMap, err := am.GetNetworkMap(peer.ID) - if err != nil { - return err - } - + remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) - err = am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) - if err != nil { - return err - } + am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } - - return nil } diff --git a/management/server/policy.go b/management/server/policy.go index dde0b46d8b5..308a5c3c0db 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -350,7 +350,9 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po } am.storeEvent(userID, policy.ID, accountID, action, policy.EventMeta()) - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } // DeletePolicy from the store @@ -375,7 +377,9 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string am.storeEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } // ListPolicies from the store diff --git a/management/server/route.go b/management/server/route.go index f51b7c2dbde..6b5aa982d64 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -2,7 +2,6 @@ package server import ( "net/netip" - "strconv" "unicode/utf8" "github.com/netbirdio/netbird/management/proto" @@ -10,60 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/rs/xid" - log "github.com/sirupsen/logrus" ) -const ( - // UpdateRouteDescription indicates a route description update operation - UpdateRouteDescription RouteUpdateOperationType = iota - // UpdateRouteNetwork indicates a route IP update operation - UpdateRouteNetwork - // UpdateRoutePeer indicates a route peer update operation - UpdateRoutePeer - // UpdateRouteMetric indicates a route metric update operation - UpdateRouteMetric - // UpdateRouteMasquerade indicates a route masquerade update operation - UpdateRouteMasquerade - // UpdateRouteEnabled indicates a route enabled update operation - UpdateRouteEnabled - // UpdateRouteNetworkIdentifier indicates a route net ID update operation - UpdateRouteNetworkIdentifier - // UpdateRouteGroups indicates a group list update operation - UpdateRouteGroups -) - -// RouteUpdateOperationType operation type -type RouteUpdateOperationType int - -func (t RouteUpdateOperationType) String() string { - switch t { - case UpdateRouteDescription: - return "UpdateRouteDescription" - case UpdateRouteNetwork: - return "UpdateRouteNetwork" - case UpdateRoutePeer: - return "UpdateRoutePeer" - case UpdateRouteMetric: - return "UpdateRouteMetric" - case UpdateRouteMasquerade: - return "UpdateRouteMasquerade" - case UpdateRouteEnabled: - return "UpdateRouteEnabled" - case UpdateRouteNetworkIdentifier: - return "UpdateRouteNetworkIdentifier" - case UpdateRouteGroups: - return "UpdateRouteGroups" - default: - return "InvalidOperation" - } -} - -// RouteUpdateOperation operation object with type and values to be applied -type RouteUpdateOperation struct { - Type RouteUpdateOperationType - Values []string -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { unlock := am.Store.AcquireAccountLock(accountID) @@ -91,30 +38,82 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) } -// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil -func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peerID string, prefix netip.Prefix) error { +// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. +func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error { + // routes can have both peer and peer_groups + routesWithPrefix := account.GetRoutesByPrefix(prefix) + + // lets remember all the peers and the peer groups from routesWithPrefix + seenPeers := make(map[string]bool) + seenPeerGroups := make(map[string]bool) + + for _, prefixRoute := range routesWithPrefix { + // we skip route(s) with the same network ID as we want to allow updating of the existing route + // when create a new route routeID is newly generated so nothing will be skipped + if routeID == prefixRoute.ID { + continue + } - if peerID == "" { - return nil + if prefixRoute.Peer != "" { + seenPeers[prefixRoute.ID] = true + } + for _, groupID := range prefixRoute.PeerGroups { + seenPeerGroups[groupID] = true + + group := account.GetGroup(groupID) + if group == nil { + return status.Errorf( + status.InvalidArgument, "failed to add route with prefix %s - peer group %s doesn't exist", + prefix.String(), groupID) + } + + for _, pID := range group.Peers { + seenPeers[pID] = true + } + } } - account, err := am.Store.GetAccount(accountID) - if err != nil { - return err + if peerID != "" { + // check that peerID exists and is not in any route as single peer or part of the group + peer := account.GetPeer(peerID) + if peer == nil { + return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) + } + if _, ok := seenPeers[peerID]; ok { + return status.Errorf(status.AlreadyExists, + "failed to add route with prefix %s - peer %s already has this route", prefix.String(), peerID) + } } - routesWithPrefix := account.GetRoutesByPrefix(prefix) + // check that peerGroupIDs are not in any route peerGroups list + for _, groupID := range peerGroupIDs { + group := account.GetGroup(groupID) // we validated the group existent before entering this function, o need to check again. - for _, prefixRoute := range routesWithPrefix { - if prefixRoute.Peer == peerID { - return status.Errorf(status.AlreadyExists, "failed to add route with prefix %s - peer already has this route", prefix.String()) + if _, ok := seenPeerGroups[groupID]; ok { + return status.Errorf( + status.AlreadyExists, "failed to add route with prefix %s - peer group %s already has this route", + prefix.String(), group.Name) + } + + // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix + for _, id := range group.Peers { + if _, ok := seenPeers[id]; ok { + peer := account.GetPeer(id) + if peer == nil { + return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) + } + return status.Errorf(status.AlreadyExists, + "failed to add route with prefix %s - peer %s from the group %s already has this route", + prefix.String(), peer.Name, group.Name) + } } } + return nil } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -123,19 +122,29 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID, return nil, err } - if peerID != "" { - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) - } + if peerID != "" && len(peerGroupIDs) != 0 { + return nil, status.Errorf( + status.InvalidArgument, + "peer with ID %s and peers group %s should not be provided at the same time", + peerID, peerGroupIDs) } var newRoute route.Route + newRoute.ID = xid.New().String() + prefixType, newPrefix, err := route.ParseNetwork(network) if err != nil { return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network) } - err = am.checkPrefixPeerExists(accountID, peerID, newPrefix) + + if len(peerGroupIDs) > 0 { + err = validateGroups(peerGroupIDs, account.Groups) + if err != nil { + return nil, err + } + } + + err = am.checkRoutePrefixExistsForPeers(account, peerID, newRoute.ID, peerGroupIDs, newPrefix) if err != nil { return nil, err } @@ -154,7 +163,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID, } newRoute.Peer = peerID - newRoute.ID = xid.New().String() + newRoute.PeerGroups = peerGroupIDs newRoute.Network = newPrefix newRoute.NetworkType = prefixType newRoute.Description = description @@ -175,11 +184,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID, return nil, err } - err = am.updateAccountPeers(account) - if err != nil { - log.Error(err) - return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix) - } + am.updateAccountPeers(account) am.storeEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -212,13 +217,22 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return err } - if routeToSave.Peer != "" { - peer := account.GetPeer(routeToSave.Peer) - if peer == nil { - return status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToSave.Peer) + if routeToSave.Peer != "" && len(routeToSave.PeerGroups) != 0 { + return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") + } + + if len(routeToSave.PeerGroups) > 0 { + err = validateGroups(routeToSave.PeerGroups, account.Groups) + if err != nil { + return err } } + err = am.checkRoutePrefixExistsForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network) + if err != nil { + return err + } + err = validateGroups(routeToSave.Groups, account.Groups) if err != nil { return err @@ -231,119 +245,13 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return err } - err = am.updateAccountPeers(account) - if err != nil { - return err - } + am.updateAccountPeers(account) am.storeEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta()) return nil } -// UpdateRoute updates existing route with set of operations -func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) { - unlock := am.Store.AcquireAccountLock(accountID) - defer unlock() - - account, err := am.Store.GetAccount(accountID) - if err != nil { - return nil, err - } - - routeToUpdate, ok := account.Routes[routeID] - if !ok { - return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) - } - - newRoute := routeToUpdate.Copy() - - for _, operation := range operations { - - if len(operation.Values) != 1 { - return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String()) - } - - switch operation.Type { - case UpdateRouteDescription: - newRoute.Description = operation.Values[0] - case UpdateRouteNetworkIdentifier: - if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - newRoute.NetID = operation.Values[0] - case UpdateRouteNetwork: - prefixType, prefix, err := route.ParseNetwork(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0]) - } - err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix) - if err != nil { - return nil, err - } - newRoute.Network = prefix - newRoute.NetworkType = prefixType - case UpdateRoutePeer: - if operation.Values[0] != "" { - peer := account.GetPeer(operation.Values[0]) - if peer == nil { - return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0]) - } - } - - err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network) - if err != nil { - return nil, err - } - newRoute.Peer = operation.Values[0] - case UpdateRouteMetric: - metric, err := strconv.Atoi(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0]) - } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d", - operation.Values[0], - route.MinMetric, - route.MaxMetric, - ) - } - newRoute.Metric = metric - case UpdateRouteMasquerade: - masquerade, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0]) - } - newRoute.Masquerade = masquerade - case UpdateRouteEnabled: - enabled, err := strconv.ParseBool(operation.Values[0]) - if err != nil { - return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) - } - newRoute.Enabled = enabled - case UpdateRouteGroups: - err = validateGroups(operation.Values, account.Groups) - if err != nil { - return nil, err - } - newRoute.Groups = operation.Values - } - } - - account.Routes[routeID] = newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { - return nil, err - } - - err = am.updateAccountPeers(account) - if err != nil { - return nil, status.Errorf(status.Internal, "failed to update account peers") - } - return newRoute, nil -} - // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { unlock := am.Store.AcquireAccountLock(accountID) @@ -367,7 +275,9 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) am.storeEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta()) - return am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return nil } // ListRoutes returns a list of routes from account diff --git a/management/server/route_test.go b/management/server/route_test.go index c943aee0bfa..00ef3e93a4d 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/rs/xid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" @@ -14,32 +15,46 @@ import ( const ( peer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" peer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" + peer3Key = "ayF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NaF=" + peer4Key = "ayF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5acc=" + peer5Key = "ayF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5a55=" peer1ID = "peer-1-id" peer2ID = "peer-2-id" + peer3ID = "peer-3-id" + peer4ID = "peer-4-id" + peer5ID = "peer-5-id" routeGroup1 = "routeGroup1" routeGroup2 = "routeGroup2" + routeGroup3 = "routeGroup3" // for existing route + routeGroup4 = "routeGroup4" // for existing route + routeGroupHA1 = "routeGroupHA1" + routeGroupHA2 = "routeGroupHA2" routeInvalidGroup1 = "routeInvalidGroup1" userID = "testingUser" + existingNetwork = "10.10.10.0/24" + existingRouteID = "random-id" ) func TestCreateRoute(t *testing.T) { type input struct { - network string - netID string - peerKey string - description string - masquerade bool - metric int - enabled bool - groups []string + network string + netID string + peerKey string + peerGroupIDs []string + description string + masquerade bool + metric int + enabled bool + groups []string } testCases := []struct { - name string - inputArgs input - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route + name string + inputArgs input + createInitRoute bool + shouldCreate bool + errFunc require.ErrorAssertionFunc + expectedRoute *route.Route }{ { name: "Happy Path", @@ -67,6 +82,48 @@ func TestCreateRoute(t *testing.T) { Groups: []string{routeGroup1}, }, }, + { + name: "Happy Path Peer Groups", + inputArgs: input{ + network: "192.168.0.0/16", + netID: "happy", + peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1, routeGroup2}, + }, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + }, + }, + { + name: "Both peer and peer_groups Provided Should Fail", + inputArgs: input{ + network: "192.168.0.0/16", + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + }, + errFunc: require.Error, + shouldCreate: false, + }, { name: "Bad Prefix Should Fail", inputArgs: input{ @@ -97,6 +154,38 @@ func TestCreateRoute(t *testing.T) { errFunc: require.Error, shouldCreate: false, }, + { + name: "Bad Peer already has this route", + inputArgs: input{ + network: existingNetwork, + netID: "bad", + peerKey: peer5ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + }, + createInitRoute: true, + errFunc: require.Error, + shouldCreate: false, + }, + { + name: "Bad Peers Group already has this route", + inputArgs: input{ + network: existingNetwork, + netID: "bad", + peerGroupIDs: []string{routeGroup1, routeGroup3}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + }, + createInitRoute: true, + errFunc: require.Error, + shouldCreate: false, + }, { name: "Empty Peer Should Create", inputArgs: input{ @@ -238,13 +327,26 @@ func TestCreateRoute(t *testing.T) { account, err := initTestRouteAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Errorf("failed to init testing account: %s", err) + } + + if testCase.createInitRoute { + groupAll, errInit := account.GetGroupAll() + if errInit != nil { + t.Errorf("failed to get group all: %s", errInit) + } + _, errInit = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4}, + "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID) + if errInit != nil { + t.Errorf("failed to create init route: %s", errInit) + } } outRoute, err := am.CreateRoute( account.Id, testCase.inputArgs.network, testCase.inputArgs.peerKey, + testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, @@ -272,6 +374,7 @@ func TestCreateRoute(t *testing.T) { func TestSaveRoute(t *testing.T) { validPeer := peer2ID + validUsedPeer := peer5ID invalidPeer := "nonExisting" validPrefix := netip.MustParsePrefix("192.168.0.0/24") invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34") @@ -279,18 +382,22 @@ func TestSaveRoute(t *testing.T) { invalidMetric := 99999 validNetID := "12345678901234567890qw" invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1" + validGroupHA1 := routeGroupHA1 + validGroupHA2 := routeGroupHA2 testCases := []struct { - name string - existingRoute *route.Route - newPeer *string - newMetric *int - newPrefix *netip.Prefix - newGroups []string - skipCopying bool - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route + name string + existingRoute *route.Route + createInitRoute bool + newPeer *string + newPeerGroups []string + newMetric *int + newPrefix *netip.Prefix + newGroups []string + skipCopying bool + shouldCreate bool + errFunc require.ErrorAssertionFunc + expectedRoute *route.Route }{ { name: "Happy Path", @@ -325,6 +432,55 @@ func TestSaveRoute(t *testing.T) { Groups: []string{routeGroup2}, }, }, + { + name: "Happy Path Peer Groups", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: validNetID, + NetworkType: route.IPv4Network, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPeerGroups: []string{validGroupHA1, validGroupHA2}, + newMetric: &validMetric, + newPrefix: &validPrefix, + newGroups: []string{routeGroup2}, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + ID: "testingRoute", + Network: validPrefix, + NetID: validNetID, + NetworkType: route.IPv4Network, + PeerGroups: []string{validGroupHA1, validGroupHA2}, + Description: "super", + Masquerade: false, + Metric: validMetric, + Enabled: true, + Groups: []string{routeGroup2}, + }, + }, + { + name: "Both peer and peers_roup Provided Should Fail", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: validNetID, + NetworkType: route.IPv4Network, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPeer: &validPeer, + newPeerGroups: []string{validGroupHA1}, + errFunc: require.Error, + }, { name: "Bad Prefix Should Fail", existingRoute: &route.Route{ @@ -461,6 +617,73 @@ func TestSaveRoute(t *testing.T) { newGroups: []string{routeInvalidGroup1}, errFunc: require.Error, }, + { + name: "Allow to modify existing route with new peer", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix(existingNetwork), + NetID: validNetID, + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + newPeer: &validPeer, + errFunc: require.NoError, + shouldCreate: true, + expectedRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix(existingNetwork), + NetID: validNetID, + NetworkType: route.IPv4Network, + Peer: validPeer, + PeerGroups: []string{}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + }, + { + name: "Do not allow to modify existing route with a peer from another route", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix(existingNetwork), + NetID: validNetID, + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + createInitRoute: true, + newPeer: &validUsedPeer, + errFunc: require.Error, + }, + { + name: "Do not allow to modify existing route with a peers group from another route", + existingRoute: &route.Route{ + ID: "testingRoute", + Network: netip.MustParsePrefix(existingNetwork), + NetID: validNetID, + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroup3}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + }, + createInitRoute: true, + newPeerGroups: []string{routeGroup4}, + errFunc: require.Error, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { @@ -474,6 +697,21 @@ func TestSaveRoute(t *testing.T) { t.Error("failed to init testing account") } + if testCase.createInitRoute { + account.Routes["initRoute"] = &route.Route{ + ID: "initRoute", + Network: netip.MustParsePrefix(existingNetwork), + NetID: existingRouteID, + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroup4}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + } + account.Routes[testCase.existingRoute.ID] = testCase.existingRoute err = am.Store.SaveAccount(account) @@ -488,6 +726,9 @@ func TestSaveRoute(t *testing.T) { if testCase.newPeer != nil { routeToSave.Peer = *testCase.newPeer } + if len(testCase.newPeerGroups) != 0 { + routeToSave.PeerGroups = testCase.newPeerGroups + } if testCase.newMetric != nil { routeToSave.Metric = *testCase.newMetric } @@ -524,265 +765,6 @@ func TestSaveRoute(t *testing.T) { } } -func TestUpdateRoute(t *testing.T) { - routeID := "testingRouteID" - - existingRoute := &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - } - - testCases := []struct { - name string - existingRoute *route.Route - operations []RouteUpdateOperation - shouldCreate bool - errFunc require.ErrorAssertionFunc - expectedRoute *route.Route - }{ - { - name: "Happy Path Single OPS", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{peer2ID}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: peer2ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - }, - }, - { - name: "Happy Path Multiple OPS", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteDescription, - Values: []string{"great"}, - }, - { - Type: UpdateRouteNetwork, - Values: []string{"192.168.0.0/24"}, - }, - { - Type: UpdateRoutePeer, - Values: []string{peer2ID}, - }, - { - Type: UpdateRouteMetric, - Values: []string{"3030"}, - }, - { - Type: UpdateRouteMasquerade, - Values: []string{"true"}, - }, - { - Type: UpdateRouteEnabled, - Values: []string{"false"}, - }, - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{"megaRoute"}, - }, - { - Type: UpdateRouteGroups, - Values: []string{routeGroup2}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/24"), - NetID: "megaRoute", - NetworkType: route.IPv4Network, - Peer: peer2ID, - Description: "great", - Masquerade: true, - Metric: 3030, - Enabled: false, - Groups: []string{routeGroup2}, - }, - }, - { - name: "Empty Values Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - }, - }, - errFunc: require.Error, - }, - { - name: "Multiple Values Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{peer2ID, peer1ID}, - }, - }, - errFunc: require.Error, - }, - { - name: "Bad Prefix Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetwork, - Values: []string{"192.168.0.0/34"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Bad Peer Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{"non existing Peer"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Empty Peer", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRoutePeer, - Values: []string{""}, - }, - }, - errFunc: require.NoError, - shouldCreate: true, - expectedRoute: &route.Route{ - ID: routeID, - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superRoute", - NetworkType: route.IPv4Network, - Peer: "", - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - }, - }, - { - name: "Large Network ID Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Empty Network ID Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteNetworkIdentifier, - Values: []string{""}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Metric Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteMetric, - Values: []string{"999999"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Boolean Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteMasquerade, - Values: []string{"yes"}, - }, - }, - errFunc: require.Error, - }, - { - name: "Invalid Group Should Fail", - existingRoute: existingRoute, - operations: []RouteUpdateOperation{ - { - Type: UpdateRouteGroups, - Values: []string{routeInvalidGroup1}, - }, - }, - errFunc: require.Error, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) - if err != nil { - t.Error("failed to create account manager") - } - - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } - - account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - - err = am.Store.SaveAccount(account) - if err != nil { - t.Error("account should be saved") - } - - updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations) - - testCase.errFunc(t, err) - - if !testCase.shouldCreate { - return - } - - testCase.expectedRoute.ID = updatedRoute.ID - - if !testCase.expectedRoute.IsEqual(updatedRoute) { - t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute) - } - }) - } -} - func TestDeleteRoute(t *testing.T) { testingRoute := &route.Route{ ID: "testingRoute", @@ -828,6 +810,96 @@ func TestDeleteRoute(t *testing.T) { } } +func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { + baseRoute := &route.Route{ + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + } + + am, err := createRouterManager(t) + if err != nil { + t.Error("failed to create account manager") + } + + account, err := initTestRouteAccount(t, am) + if err != nil { + t.Error("failed to init testing account") + } + + newAccountRoutes, err := am.GetNetworkMap(peer1ID) + require.NoError(t, err) + require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") + + newRoute, err := am.CreateRoute( + account.Id, baseRoute.Network.String(), baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, + baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID) + require.NoError(t, err) + require.Equal(t, newRoute.Enabled, true) + + peer1Routes, err := am.GetNetworkMap(peer1ID) + require.NoError(t, err) + assert.Len(t, peer1Routes.Routes, 1, "HA route should have 1 server route") + + peer2Routes, err := am.GetNetworkMap(peer2ID) + require.NoError(t, err) + assert.Len(t, peer2Routes.Routes, 1, "HA route should have 1 server route") + + peer4Routes, err := am.GetNetworkMap(peer4ID) + require.NoError(t, err) + assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") + + groups, err := am.ListGroups(account.Id) + require.NoError(t, err) + var groupHA1, groupHA2 *Group + for _, group := range groups { + switch group.Name { + case routeGroupHA1: + groupHA1 = group + case routeGroupHA2: + groupHA2 = group + } + } + + err = am.GroupDeletePeer(account.Id, groupHA1.ID, peer2ID) + require.NoError(t, err) + + peer2RoutesAfterDelete, err := am.GetNetworkMap(peer2ID) + require.NoError(t, err) + assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") + + err = am.GroupDeletePeer(account.Id, groupHA2.ID, peer4ID) + require.NoError(t, err) + + peer2RoutesAfterDelete, err = am.GetNetworkMap(peer2ID) + require.NoError(t, err) + assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") + + err = am.GroupAddPeer(account.Id, groupHA2.ID, peer4ID) + require.NoError(t, err) + + peer1RoutesAfterAdd, err := am.GetNetworkMap(peer1ID) + require.NoError(t, err) + assert.Len(t, peer1RoutesAfterAdd.Routes, 1, "HA route should have more than 1 route") + + peer2RoutesAfterAdd, err := am.GetNetworkMap(peer2ID) + require.NoError(t, err) + assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") + + err = am.DeleteRoute(account.Id, newRoute.ID, userID) + require.NoError(t, err) + + peer1DeletedRoute, err := am.GetNetworkMap(peer1ID) + require.NoError(t, err) + assert.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") +} + func TestGetNetworkMap_RouteSync(t *testing.T) { // no routes for peer in different groups // no routes when route is deleted @@ -858,7 +930,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, + createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID) require.NoError(t, err) @@ -940,7 +1012,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore) + return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false) } func createRouterStore(t *testing.T) (Store, error) { @@ -954,6 +1026,8 @@ func createRouterStore(t *testing.T) (Store, error) { } func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { + t.Helper() + accountID := "testingAcc" domain := "example.com" @@ -1013,6 +1087,81 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer2.ID] = peer2 + ips = account.getTakenIPs() + peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + if err != nil { + return nil, err + } + + peer3 := &Peer{ + IP: peer3IP, + ID: peer3ID, + Key: peer3Key, + Name: "test-host3@netbird.io", + UserID: userID, + Meta: PeerSystemMeta{ + Hostname: "test-host3@netbird.io", + GoOS: "darwin", + Kernel: "Darwin", + Core: "13.4.1", + Platform: "arm64", + OS: "darwin", + WtVersion: "development", + UIVersion: "development", + }, + } + account.Peers[peer3.ID] = peer3 + + ips = account.getTakenIPs() + peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + if err != nil { + return nil, err + } + + peer4 := &Peer{ + IP: peer4IP, + ID: peer4ID, + Key: peer4Key, + Name: "test-host4@netbird.io", + UserID: userID, + Meta: PeerSystemMeta{ + Hostname: "test-host4@netbird.io", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + } + account.Peers[peer4.ID] = peer4 + + ips = account.getTakenIPs() + peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + if err != nil { + return nil, err + } + + peer5 := &Peer{ + IP: peer5IP, + ID: peer5ID, + Key: peer5Key, + Name: "test-host4@netbird.io", + UserID: userID, + Meta: PeerSystemMeta{ + Hostname: "test-host4@netbird.io", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + } + account.Peers[peer5.ID] = peer5 + err = am.Store.SaveAccount(account) if err != nil { return nil, err @@ -1029,26 +1178,53 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er if err != nil { return nil, err } - - newGroup := &Group{ - ID: routeGroup1, - Name: routeGroup1, - Peers: []string{peer1.ID}, + err = am.GroupAddPeer(accountID, groupAll.ID, peer3ID) + if err != nil { + return nil, err } - err = am.SaveGroup(accountID, userID, newGroup) + err = am.GroupAddPeer(accountID, groupAll.ID, peer4ID) if err != nil { return nil, err } - newGroup = &Group{ - ID: routeGroup2, - Name: routeGroup2, - Peers: []string{peer2.ID}, + newGroup := []*Group{ + { + ID: routeGroup1, + Name: routeGroup1, + Peers: []string{peer1.ID}, + }, + { + ID: routeGroup2, + Name: routeGroup2, + Peers: []string{peer2.ID}, + }, + { + ID: routeGroup3, + Name: routeGroup3, + Peers: []string{peer5.ID}, + }, + { + ID: routeGroup4, + Name: routeGroup4, + Peers: []string{peer5.ID}, + }, + { + ID: routeGroupHA1, + Name: routeGroupHA1, + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, // we have one non Linux peer, see peer3 + }, + { + ID: routeGroupHA2, + Name: routeGroupHA2, + Peers: []string{peer1.ID, peer4.ID}, + }, } - err = am.SaveGroup(accountID, userID, newGroup) - if err != nil { - return nil, err + for _, group := range newGroup { + err = am.SaveGroup(accountID, userID, group) + if err != nil { + return nil, err + } } return am.Store.GetAccount(account.Id) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index e857230a577..6e626d08411 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -317,7 +317,9 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup } }() - return newKey, am.updateAccountPeers(account) + am.updateAccountPeers(account) + + return newKey, nil } // ListSetupKeys returns a list of all setup keys of the account diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index 4ca592179e3..25789f5c752 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -19,6 +19,7 @@ type GRPCMetrics struct { activeStreamsGauge asyncint64.Gauge syncRequestDuration syncint64.Histogram loginRequestDuration syncint64.Histogram + channelQueueLength syncint64.Histogram ctx context.Context } @@ -52,6 +53,18 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + // We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time + // Then we should be able to extract min, manx, mean and the percentiles. + // TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100) + channelQueue, err := meter.SyncInt64().Histogram( + "management.grpc.updatechannel.queue", + instrument.WithDescription("Number of update messages in the channel queue"), + instrument.WithUnit("length"), + ) + if err != nil { + return nil, err + } + return &GRPCMetrics{ meter: meter, syncRequestsCounter: syncRequestsCounter, @@ -60,6 +73,7 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro activeStreamsGauge: activeStreamsGauge, syncRequestDuration: syncRequestDuration, loginRequestDuration: loginRequestDuration, + channelQueueLength: channelQueue, ctx: ctx, }, err } @@ -100,3 +114,8 @@ func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) }, ) } + +// UpdateChannelQueueLength update the histogram that keep distribution of the update messages channel queue +func (metrics *GRPCMetrics) UpdateChannelQueueLength(len int) { + metrics.channelQueueLength.Record(metrics.ctx, int64(len)) +} diff --git a/management/server/telemetry/idp_metrics.go b/management/server/telemetry/idp_metrics.go index 67a1d9e859a..e9eee17bd3e 100644 --- a/management/server/telemetry/idp_metrics.go +++ b/management/server/telemetry/idp_metrics.go @@ -2,6 +2,7 @@ package telemetry import ( "context" + "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/metric/instrument/syncint64" @@ -13,6 +14,7 @@ type IDPMetrics struct { getUserByEmailCounter syncint64.Counter getAllAccountsCounter syncint64.Counter createUserCounter syncint64.Counter + deleteUserCounter syncint64.Counter getAccountCounter syncint64.Counter getUserByIDCounter syncint64.Counter authenticateRequestCounter syncint64.Counter @@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) if err != nil { return nil, err } + deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1")) + if err != nil { + return nil, err + } getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1")) if err != nil { return nil, err @@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) getUserByEmailCounter: getUserByEmailCounter, getAllAccountsCounter: getAllAccountsCounter, createUserCounter: createUserCounter, + deleteUserCounter: deleteUserCounter, getAccountCounter: getAccountCounter, getUserByIDCounter: getUserByIDCounter, authenticateRequestCounter: authenticateRequestCounter, @@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() { idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1) } +// CountDeleteUser ... +func (idpMetrics *IDPMetrics) CountDeleteUser() { + idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1) +} + // CountGetAllAccounts ... func (idpMetrics *IDPMetrics) CountGetAllAccounts() { idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1) diff --git a/management/server/telemetry/store_metrics.go b/management/server/telemetry/store_metrics.go index 704ef65d47f..98c13f12a99 100644 --- a/management/server/telemetry/store_metrics.go +++ b/management/server/telemetry/store_metrics.go @@ -11,37 +11,54 @@ import ( // StoreMetrics represents all metrics related to the FileStore type StoreMetrics struct { - globalLockAcquisitionDuration syncint64.Histogram - persistenceDuration syncint64.Histogram - ctx context.Context + globalLockAcquisitionDurationMicro syncint64.Histogram + globalLockAcquisitionDurationMs syncint64.Histogram + persistenceDurationMicro syncint64.Histogram + persistenceDurationMs syncint64.Histogram + ctx context.Context } // NewStoreMetrics creates an instance of StoreMetrics func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, error) { - globalLockAcquisitionDuration, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.micro", + globalLockAcquisitionDurationMicro, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.micro", instrument.WithUnit("microseconds")) if err != nil { return nil, err } - persistenceDuration, err := meter.SyncInt64().Histogram("management.store.persistence.duration.micro", + + globalLockAcquisitionDurationMs, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.ms") + if err != nil { + return nil, err + } + + persistenceDurationMicro, err := meter.SyncInt64().Histogram("management.store.persistence.duration.micro", instrument.WithUnit("microseconds")) if err != nil { return nil, err } + persistenceDurationMs, err := meter.SyncInt64().Histogram("management.store.persistence.duration.ms") + if err != nil { + return nil, err + } + return &StoreMetrics{ - globalLockAcquisitionDuration: globalLockAcquisitionDuration, - persistenceDuration: persistenceDuration, - ctx: ctx, + globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro, + globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs, + persistenceDurationMicro: persistenceDurationMicro, + persistenceDurationMs: persistenceDurationMs, + ctx: ctx, }, nil } // CountGlobalLockAcquisitionDuration counts the duration of the global lock acquisition func (metrics *StoreMetrics) CountGlobalLockAcquisitionDuration(duration time.Duration) { - metrics.globalLockAcquisitionDuration.Record(metrics.ctx, duration.Microseconds()) + metrics.globalLockAcquisitionDurationMicro.Record(metrics.ctx, duration.Microseconds()) + metrics.globalLockAcquisitionDurationMs.Record(metrics.ctx, duration.Milliseconds()) } // CountPersistenceDuration counts the duration of a store persistence operation func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) { - metrics.persistenceDuration.Record(metrics.ctx, duration.Microseconds()) + metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds()) + metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds()) } diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go index 1114aeeabac..aedcf2ee159 100644 --- a/management/server/turncredentials.go +++ b/management/server/turncredentials.go @@ -118,11 +118,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { }, } log.Debugf("sending new TURN credentials to peer %s", peerID) - err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) - if err != nil { - log.Errorf("error while sending TURN update to peer %s %v", peerID, err) - // todo maybe continue trying? - } + m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) } } }() diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 74438654735..5e6bcbb1cf0 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -29,7 +29,7 @@ func NewPeersUpdateManager() *PeersUpdateManager { } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) error { +func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) { p.channelsMux.Lock() defer p.channelsMux.Unlock() if channel, ok := p.peerChannels[peerID]; ok { @@ -39,10 +39,9 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) er default: log.Warnf("channel for peer %s is %d full", peerID, len(channel)) } - return nil + } else { + log.Debugf("peer %s has no channel", peerID) } - log.Debugf("peer %s has no channel", peerID) - return nil } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index c37cd422870..6cfb4d52fc9 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -31,10 +31,7 @@ func TestSendUpdate(t *testing.T) { if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } - err := peersUpdater.SendUpdate(peer, update1) - if err != nil { - t.Error("Error sending update: ", err) - } + peersUpdater.SendUpdate(peer, update1) select { case <-peersUpdater.peerChannels[peer]: default: @@ -42,10 +39,7 @@ func TestSendUpdate(t *testing.T) { } for range [channelBufferSize]int{} { - err = peersUpdater.SendUpdate(peer, update1) - if err != nil { - t.Errorf("got an early error sending update: %v ", err) - } + peersUpdater.SendUpdate(peer, update1) } update2 := &UpdateMessage{Update: &proto.SyncResponse{ @@ -54,10 +48,7 @@ func TestSendUpdate(t *testing.T) { }, }} - err = peersUpdater.SendUpdate(peer, update2) - if err != nil { - t.Error("update shouldn't return an error when channel buffer is full") - } + peersUpdater.SendUpdate(peer, update2) timeout := time.After(5 * time.Second) for range [channelBufferSize]int{} { select { diff --git a/management/server/user.go b/management/server/user.go index 8ee036df732..3169c784f14 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -307,8 +307,17 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( return user, nil } +func (am *DefaultAccountManager) deleteServiceUser(account *Account, initiatorUserID string, targetUser *User) { + meta := map[string]any{"name": targetUser.ServiceUserName} + am.storeEvent(initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) + delete(account.Users, targetUser.Id) +} + // DeleteUser deletes a user from the given account. func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { + if initiatorUserID == targetUserID { + return status.Errorf(status.InvalidArgument, "self deletion is not allowed") + } unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -317,36 +326,76 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return err } - targetUser := account.Users[targetUserID] - if targetUser == nil { - return status.Errorf(status.NotFound, "user not found") - } - executingUser := account.Users[initiatorUserID] if executingUser == nil { return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin { - return status.Errorf(status.PermissionDenied, "only admins can delete service users") + return status.Errorf(status.PermissionDenied, "only admins can delete users") } - if !targetUser.IsServiceUser { - return status.Errorf(status.PermissionDenied, "regular users can not be deleted") + targetUser := account.Users[targetUserID] + if targetUser == nil { + return status.Errorf(status.NotFound, "target user not found") } - meta := map[string]any{"name": targetUser.ServiceUserName} - am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) + // handle service user first and exit, no need to fetch extra data from IDP, etc + if targetUser.IsServiceUser { + am.deleteServiceUser(account, initiatorUserID, targetUser) + return am.Store.SaveAccount(account) + } - delete(account.Users, targetUserID) + return am.deleteRegularUser(account, initiatorUserID, targetUserID) +} + +func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUserID, targetUserID string) error { + tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID) + if err != nil { + log.Errorf("failed to resolve email address: %s", err) + return err + } + if !isNil(am.idpManager) { + err = am.deleteUserFromIDP(targetUserID, account.Id) + if err != nil { + log.Debugf("failed to delete user from IDP: %s", targetUserID) + return err + } + } + + err = am.deleteUserPeers(initiatorUserID, targetUserID, account) + if err != nil { + return err + } + + delete(account.Users, targetUserID) err = am.Store.SaveAccount(account) if err != nil { return err } + meta := map[string]any{"name": tuName, "email": tuEmail} + am.storeEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) + + am.updateAccountPeers(account) + return nil } +func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error { + peers, err := account.FindUserPeers(targetUserID) + if err != nil { + return status.Errorf(status.Internal, "failed to find user peers") + } + + peerIDs := make([]string, 0, len(peers)) + for _, peer := range peers { + peerIDs = append(peerIDs, peer.ID) + } + + return am.deletePeers(account, peerIDs, initiatorUserID) +} + // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { unlock := am.Store.AcquireAccountLock(accountID) @@ -609,23 +658,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd if err != nil { return nil, err } - var peerIDs []string - for _, peer := range blockedPeers { - peerIDs = append(peerIDs, peer.ID) - peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) - if err != nil { - log.Errorf("failed saving peer status while expiring peer %s", peer.ID) - return nil, err - } - } - am.peersUpdateManager.CloseChannels(peerIDs) - err = am.updateAccountPeers(account) - if err != nil { - log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID) - return nil, err + if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil { + log.Errorf("failed update expired peers: %s", err) + return nil, err } } @@ -640,9 +676,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd return nil, err } - if err := am.updateAccountPeers(account); err != nil { - log.Errorf("failed updating account peers while updating user %s", accountID) - } + am.updateAccountPeers(account) } else { if err = am.Store.SaveAccount(account); err != nil { return nil, err @@ -814,6 +848,68 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( return userInfos, nil } +// expireAndUpdatePeers expires all peers of the given user and updates them in the account +func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error { + var peerIDs []string + for _, peer := range peers { + if peer.Status.LoginExpired { + continue + } + peerIDs = append(peerIDs, peer.ID) + peer.MarkLoginExpired(true) + account.UpdatePeer(peer) + if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { + return err + } + am.storeEvent( + peer.UserID, peer.ID, account.Id, + activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), + ) + } + + if len(peerIDs) != 0 { + // this will trigger peer disconnect from the management service + am.peersUpdateManager.CloseChannels(peerIDs) + am.updateAccountPeers(account) + } + return nil +} + +func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error { + if am.userDeleteFromIDPEnabled { + log.Debugf("user %s deleted from IdP", targetUserID) + err := am.idpManager.DeleteUser(targetUserID) + if err != nil { + return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err) + } + } else { + err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{}) + if err != nil { + return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err) + } + + _, err = am.refreshCache(accountID) + if err != nil { + log.Errorf("refresh account (%q) cache: %v", accountID, err) + } + } + return nil +} + +func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) { + userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) + if err != nil { + return "", "", err + } + for _, ui := range userInfos { + if ui.ID == targetId { + return ui.Email, ui.Name, nil + } + } + + return "", "", fmt.Errorf("user info not found for user: %s", targetId) +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index b071546639b..1565814b81b 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -424,7 +424,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) } -func TestUser_DeleteUser_regularUser(t *testing.T) { +func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -439,8 +439,35 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) + if err == nil { + t.Fatalf("failed to prevent self deletion") + } +} - assert.Errorf(t, err, "Regular users can not be deleted (yet)") +func TestUser_DeleteUser_regularUser(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + targetId := "user2" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: true, + ServiceUserName: "user2username", + } + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + err = am.DeleteUser(mockAccountID, mockUserID, targetId) + if err != nil { + t.Errorf("unexpected error: %s", err) + } } func TestDefaultAccountManager_GetUser(t *testing.T) { diff --git a/release_files/install.sh b/release_files/install.sh index 3df085016ff..c553cc28a45 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -1,4 +1,3 @@ -#!/bin/sh # This code is based on the netbird-installer contribution by physk on GitHub. # Source: https://github.com/physk/netbird-installer set -e @@ -17,6 +16,12 @@ OS_TYPE="" ARCH="$(uname -m)" PACKAGE_MANAGER="bin" INSTALL_DIR="" +SUDO="" + + +if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then + SUDO="sudo" +fi get_latest_release() { if [ -n "$GITHUB_TOKEN" ]; then @@ -65,27 +70,35 @@ download_release_binary() { unzip -q -o "$BINARY_NAME" mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" else - sudo mkdir -p "$INSTALL_DIR" + ${SUDO} mkdir -p "$INSTALL_DIR" tar -xzvf "$BINARY_NAME" - sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/" + ${SUDO} mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/" fi } add_apt_repo() { - sudo apt-get update - sudo apt-get install ca-certificates gnupg -y + ${SUDO} apt-get update + ${SUDO} apt-get install ca-certificates curl gnupg -y - curl -sSL https://pkgs.wiretrustee.com/debian/public.key \ - | sudo gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg + # Remove old keys and repo source files + ${SUDO} rm -f \ + /etc/apt/sources.list.d/netbird.list \ + /etc/apt/sources.list.d/wiretrustee.list \ + /etc/apt/trusted.gpg.d/wiretrustee.gpg \ + /usr/share/keyrings/netbird-archive-keyring.gpg \ + /usr/share/keyrings/wiretrustee-archive-keyring.gpg - APT_REPO="deb [signed-by=/usr/share/keyrings/wiretrustee-archive-keyring.gpg] https://pkgs.wiretrustee.com/debian stable main" - echo "$APT_REPO" | sudo tee /etc/apt/sources.list.d/wiretrustee.list + curl -sSL https://pkgs.netbird.io/debian/public.key \ + | ${SUDO} gpg --dearmor -o /usr/share/keyrings/netbird-archive-keyring.gpg - sudo apt-get update + echo 'deb [signed-by=/usr/share/keyrings/netbird-archive-keyring.gpg] https://pkgs.netbird.io/debian stable main' \ + | ${SUDO} tee /etc/apt/sources.list.d/netbird.list + + ${SUDO} apt-get update } add_rpm_repo() { -cat <<-EOF | sudo tee /etc/yum.repos.d/netbird.repo +cat <<-EOF | ${SUDO} tee /etc/yum.repos.d/netbird.repo [NetBird] name=NetBird baseurl=https://pkgs.netbird.io/yum/ @@ -104,7 +117,7 @@ add_aur_repo() { for PKG in $INSTALL_PKGS; do if ! pacman -Q "$PKG" > /dev/null 2>&1; then # Install missing package(s) - sudo pacman -S "$PKG" --noconfirm + ${SUDO} pacman -S "$PKG" --noconfirm # Add installed package for clean up later REMOVE_PKGS="$REMOVE_PKGS $PKG" @@ -121,7 +134,7 @@ add_aur_repo() { fi # Clean up the installed packages - sudo pacman -Rs "$REMOVE_PKGS" --noconfirm + ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm } install_native_binaries() { @@ -181,90 +194,36 @@ install_netbird() { fi fi - # Identify OS name and default package manager - if type uname >/dev/null 2>&1; then - case "$(uname)" in - Linux) - OS_NAME="$(. /etc/os-release && echo "$ID")" - OS_TYPE="linux" - INSTALL_DIR="/usr/bin" - - # Allow netbird UI installation for x64 arch only - if [ "$ARCH" != "amd64" ] && [ "$ARCH" != "arm64" ] \ - && [ "$ARCH" != "x86_64" ];then - SKIP_UI_APP=true - echo "NetBird UI installation will be omitted as $ARCH is not a compactible architecture" - fi - - # Allow netbird UI installation for linux running desktop enviroment - if [ -z "$XDG_CURRENT_DESKTOP" ];then - SKIP_UI_APP=true - echo "NetBird UI installation will be omitted as Linux does not run desktop environment" - fi - - # Check the availability of a compatible package manager - if check_use_bin_variable; then - PACKAGE_MANAGER="bin" - elif [ -x "$(command -v apt)" ]; then - PACKAGE_MANAGER="apt" - echo "The installation will be performed using apt package manager" - elif [ -x "$(command -v dnf)" ]; then - PACKAGE_MANAGER="dnf" - echo "The installation will be performed using dnf package manager" - elif [ -x "$(command -v yum)" ]; then - PACKAGE_MANAGER="yum" - echo "The installation will be performed using yum package manager" - elif [ -x "$(command -v pacman)" ]; then - PACKAGE_MANAGER="pacman" - echo "The installation will be performed using pacman package manager" - fi - ;; - Darwin) - OS_NAME="macos" - OS_TYPE="darwin" - INSTALL_DIR="/usr/local/bin" - - # Check the availability of a compatible package manager - if check_use_bin_variable; then - PACKAGE_MANAGER="bin" - elif [ -x "$(command -v brew)" ]; then - PACKAGE_MANAGER="brew" - echo "The installation will be performed using brew package manager" - fi - ;; - esac - fi - # Run the installation, if a desktop environment is not detected # only the CLI will be installed case "$PACKAGE_MANAGER" in apt) add_apt_repo - sudo apt-get install netbird -y + ${SUDO} apt-get install netbird -y if ! $SKIP_UI_APP; then - sudo apt-get install netbird-ui -y + ${SUDO} apt-get install netbird-ui -y fi ;; yum) add_rpm_repo - sudo yum -y install netbird + ${SUDO} yum -y install netbird if ! $SKIP_UI_APP; then - sudo yum -y install netbird-ui + ${SUDO} yum -y install netbird-ui fi ;; dnf) add_rpm_repo - sudo dnf -y install dnf-plugin-config-manager - sudo dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo - sudo dnf -y install netbird + ${SUDO} dnf -y install dnf-plugin-config-manager + ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + ${SUDO} dnf -y install netbird if ! $SKIP_UI_APP; then - sudo dnf -y install netbird-ui + ${SUDO} dnf -y install netbird-ui fi ;; pacman) - sudo pacman -Syy + ${SUDO} pacman -Syy add_aur_repo ;; brew) @@ -297,7 +256,7 @@ install_netbird() { echo "Build and apply new configuration:" echo "" - echo "sudo nixos-rebuild switch" + echo "${SUDO} nixos-rebuild switch" exit 0 fi @@ -306,21 +265,21 @@ install_netbird() { esac # Add package manager to config - sudo mkdir -p "$CONFIG_FOLDER" - echo "package_manager=$PACKAGE_MANAGER" | sudo tee "$CONFIG_FILE" > /dev/null + ${SUDO} mkdir -p "$CONFIG_FOLDER" + echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null # Load and start netbird service - if ! sudo netbird service install 2>&1; then + if ! ${SUDO} netbird service install 2>&1; then echo "NetBird service has already been loaded" fi - if ! sudo netbird service start 2>&1; then + if ! ${SUDO} netbird service start 2>&1; then echo "NetBird service has already been started" fi echo "Installation has been finished. To connect, you need to run NetBird by executing the following command:" echo "" - echo "sudo netbird up" + echo "netbird up" } version_greater_equal() { @@ -328,7 +287,7 @@ version_greater_equal() { } is_bin_package_manager() { - if sudo test -f "$1" && sudo grep -q "package_manager=bin" "$1" ; then + if ${SUDO} test -f "$1" && ${SUDO} grep -q "package_manager=bin" "$1" ; then return 0 else return 1 @@ -351,18 +310,71 @@ update_netbird() { echo "" echo "Initiating NetBird update. This will stop the netbird service and restart it after the update" - sudo netbird service stop - sudo netbird service uninstall + ${SUDO} netbird service stop + ${SUDO} netbird service uninstall install_native_binaries - sudo netbird service install - sudo netbird service start + ${SUDO} netbird service install + ${SUDO} netbird service start fi else echo "NetBird installation was done using a package manager. Please use your system's package manager to update" fi } +# Identify OS name and default package manager +if type uname >/dev/null 2>&1; then + case "$(uname)" in + Linux) + OS_NAME="$(. /etc/os-release && echo "$ID")" + OS_TYPE="linux" + INSTALL_DIR="/usr/bin" + + # Allow netbird UI installation for x64 arch only + if [ "$ARCH" != "amd64" ] && [ "$ARCH" != "arm64" ] \ + && [ "$ARCH" != "x86_64" ];then + SKIP_UI_APP=true + echo "NetBird UI installation will be omitted as $ARCH is not a compactible architecture" + fi + + # Allow netbird UI installation for linux running desktop enviroment + if [ -z "$XDG_CURRENT_DESKTOP" ];then + SKIP_UI_APP=true + echo "NetBird UI installation will be omitted as Linux does not run desktop environment" + fi + + # Check the availability of a compatible package manager + if check_use_bin_variable; then + PACKAGE_MANAGER="bin" + elif [ -x "$(command -v apt)" ]; then + PACKAGE_MANAGER="apt" + echo "The installation will be performed using apt package manager" + elif [ -x "$(command -v dnf)" ]; then + PACKAGE_MANAGER="dnf" + echo "The installation will be performed using dnf package manager" + elif [ -x "$(command -v yum)" ]; then + PACKAGE_MANAGER="yum" + echo "The installation will be performed using yum package manager" + elif [ -x "$(command -v pacman)" ]; then + PACKAGE_MANAGER="pacman" + echo "The installation will be performed using pacman package manager" + fi + ;; + Darwin) + OS_NAME="macos" + OS_TYPE="darwin" + INSTALL_DIR="/usr/local/bin" + + # Check the availability of a compatible package manager + if check_use_bin_variable; then + PACKAGE_MANAGER="bin" + elif [ -x "$(command -v brew)" ]; then + PACKAGE_MANAGER="brew" + echo "The installation will be performed using brew package manager" + fi + ;; + esac +fi case "$1" in --update) @@ -370,4 +382,4 @@ case "$1" in ;; *) install_netbird -esac \ No newline at end of file +esac diff --git a/route/route.go b/route/route.go index 5c45e2cf58a..eb7bcba2f32 100644 --- a/route/route.go +++ b/route/route.go @@ -70,6 +70,7 @@ type Route struct { NetID string Description string Peer string + PeerGroups []string NetworkType NetworkType Masquerade bool Metric int @@ -79,7 +80,7 @@ type Route struct { // EventMeta returns activity event meta related to the route func (r *Route) EventMeta() map[string]any { - return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "peer_id": r.Peer} + return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "peer_id": r.Peer, "peer_groups": r.PeerGroups} } // Copy copies a route object @@ -91,12 +92,14 @@ func (r *Route) Copy() *Route { Network: r.Network, NetworkType: r.NetworkType, Peer: r.Peer, + PeerGroups: make([]string, len(r.PeerGroups)), Metric: r.Metric, Masquerade: r.Masquerade, Enabled: r.Enabled, Groups: make([]string, len(r.Groups)), } copy(route.Groups, r.Groups) + copy(route.PeerGroups, r.PeerGroups) return route } @@ -111,7 +114,8 @@ func (r *Route) IsEqual(other *Route) bool { other.Metric == r.Metric && other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && - compareGroupsList(r.Groups, other.Groups) + compareList(r.Groups, other.Groups) && + compareList(r.PeerGroups, other.PeerGroups) } // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 @@ -134,7 +138,7 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } -func compareGroupsList(list, other []string) bool { +func compareList(list, other []string) bool { if len(list) != len(other) { return false } diff --git a/util/file.go b/util/file.go index 022841947ee..0cbfa37ab33 100644 --- a/util/file.go +++ b/util/file.go @@ -5,6 +5,8 @@ import ( "io" "os" "path/filepath" + + log "github.com/sirupsen/logrus" ) // WriteJson writes JSON config object to a file creating parent directories if required @@ -54,6 +56,68 @@ func WriteJson(file string, obj interface{}) error { return nil } +// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file +func DirectWriteJson(file string, obj interface{}) error { + + _, _, err := prepareConfigFileDir(file) + if err != nil { + return err + } + + targetFile, err := openOrCreateFile(file) + if err != nil { + return err + } + + defer func() { + err = targetFile.Close() + if err != nil { + log.Errorf("failed to close file %s: %v", file, err) + } + }() + + // make it pretty + bs, err := json.MarshalIndent(obj, "", " ") + if err != nil { + return err + } + + err = targetFile.Truncate(0) + if err != nil { + return err + } + + _, err = targetFile.Write(bs) + if err != nil { + return err + } + + return nil +} + +func openOrCreateFile(file string) (*os.File, error) { + s, err := os.Stat(file) + if err == nil { + return os.OpenFile(file, os.O_WRONLY, s.Mode()) + } + + if !os.IsNotExist(err) { + return nil, err + } + + targetFile, err := os.Create(file) + if err != nil { + return nil, err + } + //no:lint + err = targetFile.Chmod(0640) + if err != nil { + _ = targetFile.Close() + return nil, err + } + return targetFile, nil +} + // ReadJson reads JSON config file and maps to a provided interface func ReadJson(file string, res interface{}) (interface{}, error) {