Skip to content

Commit

Permalink
Merge branch 'main' into feature/add-ios-support
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer committed Dec 14, 2023
2 parents bacc322 + f73a2e2 commit 0df30be
Show file tree
Hide file tree
Showing 21 changed files with 244 additions and 64 deletions.
1 change: 1 addition & 0 deletions .github/workflows/sync-tag.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
10 changes: 9 additions & 1 deletion client/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ var loginCmd = &cobra.Command{
AdminURL: adminURL,
ConfigPath: configPath,
}
if preSharedKey != "" {
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}

Expand Down Expand Up @@ -151,13 +151,21 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
jwtToken = tokenInfo.GetTokenToUse()
}

var lastError error

err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})

if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}

if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion client/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

const (
externalIPMapFlag = "external-ip-map"
preSharedKeyFlag = "preshared-key"
dnsResolverAddress = "dns-resolver-address"
)

Expand Down Expand Up @@ -94,7 +95,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
Expand Down
50 changes: 42 additions & 8 deletions client/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ type statusOutputOverview struct {
}

var (
detailFlag bool
ipv4Flag bool
jsonFlag bool
yamlFlag bool
ipsFilter []string
statusFilter string
ipsFilterMap map[string]struct{}
detailFlag bool
ipv4Flag bool
jsonFlag bool
yamlFlag bool
ipsFilter []string
prefixNamesFilter []string
statusFilter string
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
)

var statusCmd = &cobra.Command{
Expand All @@ -83,12 +85,14 @@ var statusCmd = &cobra.Command{

func init() {
ipsFilterMap = make(map[string]struct{})
prefixNamesFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
}

Expand Down Expand Up @@ -172,8 +176,12 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
}

func parseFilters() error {

switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
}
Expand All @@ -185,11 +193,26 @@ func parseFilters() error {
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
}
ipsFilterMap[addr] = struct{}{}
enableDetailFlagWhenFilterFlag()
}
}

if len(prefixNamesFilter) > 0 {
for _, name := range prefixNamesFilter {
prefixNamesFilterMap[strings.ToLower(name)] = struct{}{}
}
enableDetailFlagWhenFilterFlag()
}

return nil
}

func enableDetailFlagWhenFilterFlag() {
if !detailFlag && !jsonFlag && !yamlFlag {
detailFlag = true
}
}

func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()

Expand Down Expand Up @@ -415,6 +438,7 @@ func parsePeers(peers peersStateOutput) string {
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := false

if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
Expand All @@ -431,5 +455,15 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
ipEval = true
}
}
return statusEval || ipEval

if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true
break
}
}
}

return statusEval || ipEval || nameEval
}
3 changes: 2 additions & 1 deletion client/cmd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
}
if preSharedKey != "" {

if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}

Expand Down
9 changes: 3 additions & 6 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,9 @@ func update(input ConfigInput) (*Config, error) {
}

if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
if *input.PreSharedKey != "" {
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
*input.PreSharedKey, config.PreSharedKey)
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
}

if config.SSHKey == "" {
Expand Down
20 changes: 3 additions & 17 deletions client/internal/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import (
"path/filepath"
"testing"

"github.com/netbirdio/netbird/util"
"github.com/stretchr/testify/assert"

"github.com/netbirdio/netbird/util"
)

func TestGetConfig(t *testing.T) {
Expand Down Expand Up @@ -60,22 +61,7 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)

// case 4: new empty pre-shared key config -> fetch it
newPreSharedKey := ""
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: path,
PreSharedKey: &newPreSharedKey,
})
if err != nil {
return
}

assert.Equal(t, config.ManagementURL.String(), managementURL)
assert.Equal(t, config.PreSharedKey, preSharedKey)

// case 5: existing config, but new managementURL has been provided -> update config
// case 4: existing config, but new managementURL has been provided -> update config
newManagementURL := "https://test.newManagement.url:33071"
config, err = UpdateOrCreateConfig(ConfigInput{
ManagementURL: newManagementURL,
Expand Down
5 changes: 4 additions & 1 deletion management/cmd/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ var (
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled

if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
}

tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
Expand Down
5 changes: 3 additions & 2 deletions management/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (

const (
// ExitSetupFailed defines exit code
ExitSetupFailed = 1
ExitSetupFailed = 1
idpSignKeyRefreshEnabledFlagName = "idp-sign-key-refresh-enabled"
)

var (
Expand Down Expand Up @@ -62,7 +63,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max length is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, idpSignKeyRefreshEnabledFlagName, false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint

Expand Down
53 changes: 52 additions & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (

"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/netbirdio/management-integrations/additions"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/management-integrations/additions"

"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
Expand Down Expand Up @@ -66,6 +67,7 @@ type AccountManager interface {
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(accountID, userID string) error
MarkPATUsed(tokenID string) error
Expand Down Expand Up @@ -164,6 +166,9 @@ type Settings struct {
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string

// JWTAllowGroups list of groups to which users are allowed access
JWTAllowGroups []string `gorm:"serializer:json"`

// Extra is a dictionary of Account settings
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
}
Expand All @@ -176,6 +181,7 @@ func (s *Settings) Copy() *Settings {
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
Expand Down Expand Up @@ -1693,6 +1699,39 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain
}

// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return err
}

// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)

if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}

if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}

return nil
}

// addAllGroup to account object if it doesn't exists
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
Expand Down Expand Up @@ -1764,3 +1803,15 @@ func newAccountWithId(accountID, userID, domain string) *Account {
}
return acc
}

// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}
6 changes: 5 additions & 1 deletion management/server/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}

if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error())
}

return claims.UserId, nil
}

Expand Down Expand Up @@ -312,7 +316,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
userID, err = s.validateToken(loginReq.GetJwtToken())
if err != nil {
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
return nil, mapError(err)
return nil, err
}
}
var sshKey []byte
Expand Down
9 changes: 9 additions & 0 deletions management/server/http/accounts_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}

updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
if err != nil {
Expand Down Expand Up @@ -128,12 +131,18 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
}

func toAccountResponse(account *server.Account) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}

settings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
}

if account.Settings.Extra != nil {
Expand Down
Loading

0 comments on commit 0df30be

Please sign in to comment.