diff --git a/client/src/pages/config/users/configman.jsx b/client/src/pages/config/users/configman.jsx index 70c3ba20..0654725c 100644 --- a/client/src/pages/config/users/configman.jsx +++ b/client/src/pages/config/users/configman.jsx @@ -108,6 +108,7 @@ const ConfigManagement = () => { GenerateMissingAuthCert: config.HTTPConfig.GenerateMissingAuthCert, HTTPPort: config.HTTPConfig.HTTPPort, HTTPSPort: config.HTTPConfig.HTTPSPort, + TrustedProxies: config.HTTPConfig.TrustedProxies && config.HTTPConfig.TrustedProxies.join(', '), SSLEmail: config.HTTPConfig.SSLEmail, UseWildcardCertificate: config.HTTPConfig.UseWildcardCertificate, HTTPSCertificateMode: config.HTTPConfig.HTTPSCertificateMode, @@ -208,6 +209,8 @@ const ConfigManagement = () => { AllowSearchEngine: values.AllowSearchEngine, AllowHTTPLocalIPAccess: values.AllowHTTPLocalIPAccess, PublishMDNS: values.PublishMDNS, + TrustedProxies: (values.TrustedProxies && values.TrustedProxies != "") ? + values.TrustedProxies.split(',').map((x) => x.trim()) : [], }, EmailConfig: { ...config.EmailConfig, @@ -640,6 +643,19 @@ const ConfigManagement = () => { )} + + + + + {t('mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText')}
+
+ +
+
{t('mgmt.config.http.allowSearchIndexCheckbox')}
diff --git a/client/src/utils/locales/en/translation.json b/client/src/utils/locales/en/translation.json index d7a56fe8..93f7af9c 100644 --- a/client/src/utils/locales/en/translation.json +++ b/client/src/utils/locales/en/translation.json @@ -176,6 +176,8 @@ "mgmt.config.http.hostnameInput.HostnameLabel": "Hostname: This will be used to restrict access to your Cosmos Server (Your IP, or your domain name)", "mgmt.config.http.hostnameInput.HostnameValidation": "Hostname is required", "mgmt.config.http.publishMDNSCheckbox": "This allows you to publish your server on your local network using mDNS. This means all your .local domains will be available on your local network with no additional config.", + "mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel": "Trusted proxies allow X-Forwarded-For from IP/IP range.", + "mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText": "Use this setting when you have an upstream proxy server to avoid it being blocked by Shield. IPs or IP ranges separated by commas.", "mgmt.config.email.notifyLoginCheckbox.notifyLoginLabel": "Notify Users upon Successful Login", "mgmt.config.proxy.noRoutesConfiguredText": "No routes configured.", "mgmt.config.proxy.originTitle": "Origin", diff --git a/client/src/utils/locales/fr/translation.json b/client/src/utils/locales/fr/translation.json index f1f701ee..1372a4c3 100644 --- a/client/src/utils/locales/fr/translation.json +++ b/client/src/utils/locales/fr/translation.json @@ -174,6 +174,8 @@ "mgmt.config.http.hostnameInput.HostnameLabel": "Nom d'hôte : Cela sera utilisé pour restreindre l'accès à votre serveur Cosmos (Votre IP, ou votre nom de domaine)", "mgmt.config.http.hostnameInput.HostnameValidation": "Le nom d'hôte est obligatoire", "mgmt.config.http.publishMDNSCheckbox": "Cela vous permet de publier votre serveur sur votre réseau local en utilisant mDNS. Cela signifie que tous vos domaines .local seront disponibles sur votre réseau local sans configuration supplémentaire.", + "mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel": "IPs/Plages IP des proxys de confiance pour l'utilisation de X-Forwarded-For.", + "mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText": "Utilisez ce paramètre lorsque vous avez un serveur proxy en amont pour éviter le blocage de celui-ci par le Shield. IPs ou plages IP séparées par des virgules.", "mgmt.config.email.notifyLoginCheckbox.notifyLoginLabel": "Notifier les utilisateurs en cas de connexion réussie", "mgmt.config.proxy.noRoutesConfiguredText": "Aucune route configurée.", "mgmt.config.proxy.originTitle": "Origine", diff --git a/src/httpServer.go b/src/httpServer.go index c0ad49b3..d805179d 100644 --- a/src/httpServer.go +++ b/src/httpServer.go @@ -359,6 +359,8 @@ func InitServer() *mux.Router { router := mux.NewRouter().StrictSlash(true) + router.Use(utils.ClientRealIP) + router.Use(utils.BlockBannedIPs) router.Use(utils.Logger) diff --git a/src/proxy/shield.go b/src/proxy/shield.go index 58683576..8a790aaf 100644 --- a/src/proxy/shield.go +++ b/src/proxy/shield.go @@ -7,6 +7,7 @@ import ( "fmt" "math" "strconv" + "strings" "github.com/azukaar/cosmos-server/src/utils" "github.com/azukaar/cosmos-server/src/metrics" @@ -296,14 +297,19 @@ func calculateLowestExhaustedPercentage(policy utils.SmartShieldPolicy, userCons func GetClientID(r *http.Request, route utils.ProxyRouteConfig) string { // when using Docker we need to get the real IP remoteAddr, _ := utils.SplitIP(r.RemoteAddr) - UseForwardedFor := utils.GetMainConfig().HTTPConfig.UseForwardedFor - isTunneledIp := constellation.GetDeviceIp(route.TunnelVia) == remoteAddr - isConstIP := utils.IsConstellationIP(remoteAddr) - isConstTokenValid := constellation.CheckConstellationToken(r) == nil - - if (UseForwardedFor && r.Header.Get("x-forwarded-for") != "") || - (isTunneledIp && isConstIP && isConstTokenValid) { - ip, _ := utils.SplitIP(r.Header.Get("x-forwarded-for")) + useForwardedForHeader := false + if r.Header.Get("x-forwarded-for") != "" { + useForwardedForHeader = utils.IsTrustedProxy(remoteAddr) + if !useForwardedForHeader { + isTunneledIp := constellation.GetDeviceIp(route.TunnelVia) == remoteAddr + isConstIP := utils.IsConstellationIP(remoteAddr) + isConstTokenValid := constellation.CheckConstellationToken(r) == nil + useForwardedForHeader = isTunneledIp && isConstIP && isConstTokenValid + } + } + + if useForwardedForHeader { + ip, _ := utils.SplitIP(strings.TrimSpace(strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0])) utils.Debug("SmartShield: Getting forwarded client ID " + ip) return ip } else { diff --git a/src/utils/middleware.go b/src/utils/middleware.go index 4fd59aec..cab86042 100644 --- a/src/utils/middleware.go +++ b/src/utils/middleware.go @@ -48,34 +48,49 @@ func GetIPAbuseCounter(ip string) int64 { return atomic.LoadInt64(&counter.val) } +func ClientRealIP(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientID := GetClientIP(r) + if(clientID == ""){ + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + ctx := context.WithValue(r.Context(), "ClientID", clientID) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) +} + func BlockBannedIPs(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - if hj, ok := w.(http.Hijacker); ok { - conn, _, err := hj.Hijack() - if err == nil { - conn.Close() - } - } - return + ip, ok := r.Context().Value("ClientID").(string) + if !ok { + if hj, ok := w.(http.Hijacker); ok { + conn, _, err := hj.Hijack() + if err == nil { + conn.Close() + } + } + return } nbAbuse := GetIPAbuseCounter(ip) if nbAbuse > 275 { - Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.") - } + Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.") + } if nbAbuse > 300 { - if hj, ok := w.(http.Hijacker); ok { - conn, _, err := hj.Hijack() - if err == nil { - conn.Close() - } - } - return + if hj, ok := w.(http.Hijacker); ok { + conn, _, err := hj.Hijack() + if err == nil { + conn.Close() } + } + return + } next.ServeHTTP(w, r) }) @@ -204,8 +219,8 @@ func GetIPLocation(ip string) (string, error) { func BlockByCountryMiddleware(blockedCountries []string, CountryBlacklistIsWhitelist bool) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { + ip, ok := r.Context().Value("ClientID").(string) + if !ok { http.Error(w, "Invalid request", http.StatusBadRequest) return } @@ -289,7 +304,7 @@ func BlockPostWithoutReferer(next http.Handler) http.Handler { Error("Blocked POST request without Referer header", nil) http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest) - ip, _, _ := net.SplitHostPort(r.RemoteAddr) + ip, _ := r.Context().Value("ClientID").(string) if ip != "" { TriggerEvent( "cosmos.proxy.shield.referer", @@ -349,7 +364,7 @@ func EnsureHostname(next http.Handler) http.Handler { w.WriteHeader(http.StatusBadRequest) http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest) - ip, _, _ := net.SplitHostPort(r.RemoteAddr) + ip, _ := r.Context().Value("ClientID").(string) if ip != "" { TriggerEvent( "cosmos.proxy.shield.hostname", @@ -415,7 +430,7 @@ func EnsureHostnameCosmosAPI(next http.Handler) http.Handler { w.WriteHeader(http.StatusBadRequest) http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest) - ip, _, _ := net.SplitHostPort(r.RemoteAddr) + ip, _ := r.Context().Value("ClientID").(string) if ip != "" { TriggerEvent( "cosmos.proxy.shield.hostname", @@ -469,16 +484,19 @@ func IsValidHostname(hostname string) bool { } func IPInRange(ipStr, cidrStr string) (bool, error) { - _, cidrNet, err := net.ParseCIDR(cidrStr) - if err != nil { - return false, fmt.Errorf("parse CIDR range: %w", err) - } - ip := net.ParseIP(ipStr) if ip == nil { return false, fmt.Errorf("parse IP: invalid IP address") } + _, cidrNet, err := net.ParseCIDR(cidrStr) + if err != nil { + if ipStr == cidrStr { + return true, nil + } + return false, fmt.Errorf("parse CIDR range: %w", err) + } + return cidrNet.Contains(ip), nil } @@ -486,8 +504,9 @@ func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) fu return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { + remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr) + ip, ok := r.Context().Value("ClientID").(string) + if (err != nil) || !ok { http.Error(w, "Invalid request", http.StatusBadRequest) return } @@ -495,7 +514,7 @@ func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) fu isUsingWhiteList := len(WhitelistInboundIPs) > 0 isInWhitelist := false - isInConstellation := strings.HasPrefix(ip, "192.168.201.") || strings.HasPrefix(ip, "192.168.202.") + isInConstellation := strings.HasPrefix(remoteAddr, "192.168.201.") || strings.HasPrefix(remoteAddr, "192.168.202.") for _, ipRange := range WhitelistInboundIPs { Debug("Checking if " + ip + " is in " + ipRange) diff --git a/src/utils/types.go b/src/utils/types.go index f36e37ec..5a8f7cbe 100644 --- a/src/utils/types.go +++ b/src/utils/types.go @@ -178,6 +178,7 @@ type HTTPConfig struct { UseForwardedFor bool AllowSearchEngine bool PublishMDNS bool + TrustedProxies []string } const ( diff --git a/src/utils/utils.go b/src/utils/utils.go index d46da55d..dc91b008 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -801,11 +801,13 @@ func DownloadFileToLocation(path, url string) error { } func GetClientIP(req *http.Request) string { - /*ip := req.Header.Get("X-Forwarded-For") - if ip == "" { - ip = req.RemoteAddr - }*/ - return req.RemoteAddr + // when using Docker we need to get the real IP + remoteAddr, _ := SplitIP(req.RemoteAddr) + + if req.Header.Get("x-forwarded-for") != "" && IsTrustedProxy(remoteAddr) { + remoteAddr, _ = SplitIP(strings.TrimSpace(strings.Split(req.Header.Get("X-Forwarded-For"), ",")[0])) + } + return remoteAddr } func IsDomain(domain string) bool { @@ -921,6 +923,15 @@ func IsConstellationIP(ip string) bool { return false } +func IsTrustedProxy(ip string) bool { + for _, trustedProxy := range GetMainConfig().HTTPConfig.TrustedProxies { + if isInRange, _ := IPInRange(ip, trustedProxy); isInRange { + return true + } + } + return false +} + func SplitIP(ipPort string) (string, string) { host, port, err := osnet.SplitHostPort(ipPort) if err != nil {