diff --git a/client/src/pages/config/users/configman.jsx b/client/src/pages/config/users/configman.jsx
index 70c3ba20..0654725c 100644
--- a/client/src/pages/config/users/configman.jsx
+++ b/client/src/pages/config/users/configman.jsx
@@ -108,6 +108,7 @@ const ConfigManagement = () => {
GenerateMissingAuthCert: config.HTTPConfig.GenerateMissingAuthCert,
HTTPPort: config.HTTPConfig.HTTPPort,
HTTPSPort: config.HTTPConfig.HTTPSPort,
+ TrustedProxies: config.HTTPConfig.TrustedProxies && config.HTTPConfig.TrustedProxies.join(', '),
SSLEmail: config.HTTPConfig.SSLEmail,
UseWildcardCertificate: config.HTTPConfig.UseWildcardCertificate,
HTTPSCertificateMode: config.HTTPConfig.HTTPSCertificateMode,
@@ -208,6 +209,8 @@ const ConfigManagement = () => {
AllowSearchEngine: values.AllowSearchEngine,
AllowHTTPLocalIPAccess: values.AllowHTTPLocalIPAccess,
PublishMDNS: values.PublishMDNS,
+ TrustedProxies: (values.TrustedProxies && values.TrustedProxies != "") ?
+ values.TrustedProxies.split(',').map((x) => x.trim()) : [],
},
EmailConfig: {
...config.EmailConfig,
@@ -640,6 +643,19 @@ const ConfigManagement = () => {
)}
+
+
+
+
+ {t('mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText')}
+
+
+
+
{t('mgmt.config.http.allowSearchIndexCheckbox')}
diff --git a/client/src/utils/locales/en/translation.json b/client/src/utils/locales/en/translation.json
index d7a56fe8..93f7af9c 100644
--- a/client/src/utils/locales/en/translation.json
+++ b/client/src/utils/locales/en/translation.json
@@ -176,6 +176,8 @@
"mgmt.config.http.hostnameInput.HostnameLabel": "Hostname: This will be used to restrict access to your Cosmos Server (Your IP, or your domain name)",
"mgmt.config.http.hostnameInput.HostnameValidation": "Hostname is required",
"mgmt.config.http.publishMDNSCheckbox": "This allows you to publish your server on your local network using mDNS. This means all your .local domains will be available on your local network with no additional config.",
+ "mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel": "Trusted proxies allow X-Forwarded-For from IP/IP range.",
+ "mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText": "Use this setting when you have an upstream proxy server to avoid it being blocked by Shield. IPs or IP ranges separated by commas.",
"mgmt.config.email.notifyLoginCheckbox.notifyLoginLabel": "Notify Users upon Successful Login",
"mgmt.config.proxy.noRoutesConfiguredText": "No routes configured.",
"mgmt.config.proxy.originTitle": "Origin",
diff --git a/client/src/utils/locales/fr/translation.json b/client/src/utils/locales/fr/translation.json
index f1f701ee..1372a4c3 100644
--- a/client/src/utils/locales/fr/translation.json
+++ b/client/src/utils/locales/fr/translation.json
@@ -174,6 +174,8 @@
"mgmt.config.http.hostnameInput.HostnameLabel": "Nom d'hôte : Cela sera utilisé pour restreindre l'accès à votre serveur Cosmos (Votre IP, ou votre nom de domaine)",
"mgmt.config.http.hostnameInput.HostnameValidation": "Le nom d'hôte est obligatoire",
"mgmt.config.http.publishMDNSCheckbox": "Cela vous permet de publier votre serveur sur votre réseau local en utilisant mDNS. Cela signifie que tous vos domaines .local seront disponibles sur votre réseau local sans configuration supplémentaire.",
+ "mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel": "IPs/Plages IP des proxys de confiance pour l'utilisation de X-Forwarded-For.",
+ "mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText": "Utilisez ce paramètre lorsque vous avez un serveur proxy en amont pour éviter le blocage de celui-ci par le Shield. IPs ou plages IP séparées par des virgules.",
"mgmt.config.email.notifyLoginCheckbox.notifyLoginLabel": "Notifier les utilisateurs en cas de connexion réussie",
"mgmt.config.proxy.noRoutesConfiguredText": "Aucune route configurée.",
"mgmt.config.proxy.originTitle": "Origine",
diff --git a/src/httpServer.go b/src/httpServer.go
index c0ad49b3..d805179d 100644
--- a/src/httpServer.go
+++ b/src/httpServer.go
@@ -359,6 +359,8 @@ func InitServer() *mux.Router {
router := mux.NewRouter().StrictSlash(true)
+ router.Use(utils.ClientRealIP)
+
router.Use(utils.BlockBannedIPs)
router.Use(utils.Logger)
diff --git a/src/proxy/shield.go b/src/proxy/shield.go
index 58683576..8a790aaf 100644
--- a/src/proxy/shield.go
+++ b/src/proxy/shield.go
@@ -7,6 +7,7 @@ import (
"fmt"
"math"
"strconv"
+ "strings"
"github.com/azukaar/cosmos-server/src/utils"
"github.com/azukaar/cosmos-server/src/metrics"
@@ -296,14 +297,19 @@ func calculateLowestExhaustedPercentage(policy utils.SmartShieldPolicy, userCons
func GetClientID(r *http.Request, route utils.ProxyRouteConfig) string {
// when using Docker we need to get the real IP
remoteAddr, _ := utils.SplitIP(r.RemoteAddr)
- UseForwardedFor := utils.GetMainConfig().HTTPConfig.UseForwardedFor
- isTunneledIp := constellation.GetDeviceIp(route.TunnelVia) == remoteAddr
- isConstIP := utils.IsConstellationIP(remoteAddr)
- isConstTokenValid := constellation.CheckConstellationToken(r) == nil
-
- if (UseForwardedFor && r.Header.Get("x-forwarded-for") != "") ||
- (isTunneledIp && isConstIP && isConstTokenValid) {
- ip, _ := utils.SplitIP(r.Header.Get("x-forwarded-for"))
+ useForwardedForHeader := false
+ if r.Header.Get("x-forwarded-for") != "" {
+ useForwardedForHeader = utils.IsTrustedProxy(remoteAddr)
+ if !useForwardedForHeader {
+ isTunneledIp := constellation.GetDeviceIp(route.TunnelVia) == remoteAddr
+ isConstIP := utils.IsConstellationIP(remoteAddr)
+ isConstTokenValid := constellation.CheckConstellationToken(r) == nil
+ useForwardedForHeader = isTunneledIp && isConstIP && isConstTokenValid
+ }
+ }
+
+ if useForwardedForHeader {
+ ip, _ := utils.SplitIP(strings.TrimSpace(strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0]))
utils.Debug("SmartShield: Getting forwarded client ID " + ip)
return ip
} else {
diff --git a/src/utils/middleware.go b/src/utils/middleware.go
index 4fd59aec..cab86042 100644
--- a/src/utils/middleware.go
+++ b/src/utils/middleware.go
@@ -48,34 +48,49 @@ func GetIPAbuseCounter(ip string) int64 {
return atomic.LoadInt64(&counter.val)
}
+func ClientRealIP(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ clientID := GetClientIP(r)
+ if(clientID == ""){
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ ctx := context.WithValue(r.Context(), "ClientID", clientID)
+ r = r.WithContext(ctx)
+
+ next.ServeHTTP(w, r)
+ })
+}
+
func BlockBannedIPs(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- if hj, ok := w.(http.Hijacker); ok {
- conn, _, err := hj.Hijack()
- if err == nil {
- conn.Close()
- }
- }
- return
+ ip, ok := r.Context().Value("ClientID").(string)
+ if !ok {
+ if hj, ok := w.(http.Hijacker); ok {
+ conn, _, err := hj.Hijack()
+ if err == nil {
+ conn.Close()
+ }
+ }
+ return
}
nbAbuse := GetIPAbuseCounter(ip)
if nbAbuse > 275 {
- Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.")
- }
+ Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.")
+ }
if nbAbuse > 300 {
- if hj, ok := w.(http.Hijacker); ok {
- conn, _, err := hj.Hijack()
- if err == nil {
- conn.Close()
- }
- }
- return
+ if hj, ok := w.(http.Hijacker); ok {
+ conn, _, err := hj.Hijack()
+ if err == nil {
+ conn.Close()
}
+ }
+ return
+ }
next.ServeHTTP(w, r)
})
@@ -204,8 +219,8 @@ func GetIPLocation(ip string) (string, error) {
func BlockByCountryMiddleware(blockedCountries []string, CountryBlacklistIsWhitelist bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
+ ip, ok := r.Context().Value("ClientID").(string)
+ if !ok {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -289,7 +304,7 @@ func BlockPostWithoutReferer(next http.Handler) http.Handler {
Error("Blocked POST request without Referer header", nil)
http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest)
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
+ ip, _ := r.Context().Value("ClientID").(string)
if ip != "" {
TriggerEvent(
"cosmos.proxy.shield.referer",
@@ -349,7 +364,7 @@ func EnsureHostname(next http.Handler) http.Handler {
w.WriteHeader(http.StatusBadRequest)
http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest)
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
+ ip, _ := r.Context().Value("ClientID").(string)
if ip != "" {
TriggerEvent(
"cosmos.proxy.shield.hostname",
@@ -415,7 +430,7 @@ func EnsureHostnameCosmosAPI(next http.Handler) http.Handler {
w.WriteHeader(http.StatusBadRequest)
http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest)
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
+ ip, _ := r.Context().Value("ClientID").(string)
if ip != "" {
TriggerEvent(
"cosmos.proxy.shield.hostname",
@@ -469,16 +484,19 @@ func IsValidHostname(hostname string) bool {
}
func IPInRange(ipStr, cidrStr string) (bool, error) {
- _, cidrNet, err := net.ParseCIDR(cidrStr)
- if err != nil {
- return false, fmt.Errorf("parse CIDR range: %w", err)
- }
-
ip := net.ParseIP(ipStr)
if ip == nil {
return false, fmt.Errorf("parse IP: invalid IP address")
}
+ _, cidrNet, err := net.ParseCIDR(cidrStr)
+ if err != nil {
+ if ipStr == cidrStr {
+ return true, nil
+ }
+ return false, fmt.Errorf("parse CIDR range: %w", err)
+ }
+
return cidrNet.Contains(ip), nil
}
@@ -486,8 +504,9 @@ func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) fu
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
+ remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr)
+ ip, ok := r.Context().Value("ClientID").(string)
+ if (err != nil) || !ok {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
@@ -495,7 +514,7 @@ func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) fu
isUsingWhiteList := len(WhitelistInboundIPs) > 0
isInWhitelist := false
- isInConstellation := strings.HasPrefix(ip, "192.168.201.") || strings.HasPrefix(ip, "192.168.202.")
+ isInConstellation := strings.HasPrefix(remoteAddr, "192.168.201.") || strings.HasPrefix(remoteAddr, "192.168.202.")
for _, ipRange := range WhitelistInboundIPs {
Debug("Checking if " + ip + " is in " + ipRange)
diff --git a/src/utils/types.go b/src/utils/types.go
index f36e37ec..5a8f7cbe 100644
--- a/src/utils/types.go
+++ b/src/utils/types.go
@@ -178,6 +178,7 @@ type HTTPConfig struct {
UseForwardedFor bool
AllowSearchEngine bool
PublishMDNS bool
+ TrustedProxies []string
}
const (
diff --git a/src/utils/utils.go b/src/utils/utils.go
index d46da55d..dc91b008 100644
--- a/src/utils/utils.go
+++ b/src/utils/utils.go
@@ -801,11 +801,13 @@ func DownloadFileToLocation(path, url string) error {
}
func GetClientIP(req *http.Request) string {
- /*ip := req.Header.Get("X-Forwarded-For")
- if ip == "" {
- ip = req.RemoteAddr
- }*/
- return req.RemoteAddr
+ // when using Docker we need to get the real IP
+ remoteAddr, _ := SplitIP(req.RemoteAddr)
+
+ if req.Header.Get("x-forwarded-for") != "" && IsTrustedProxy(remoteAddr) {
+ remoteAddr, _ = SplitIP(strings.TrimSpace(strings.Split(req.Header.Get("X-Forwarded-For"), ",")[0]))
+ }
+ return remoteAddr
}
func IsDomain(domain string) bool {
@@ -921,6 +923,15 @@ func IsConstellationIP(ip string) bool {
return false
}
+func IsTrustedProxy(ip string) bool {
+ for _, trustedProxy := range GetMainConfig().HTTPConfig.TrustedProxies {
+ if isInRange, _ := IPInRange(ip, trustedProxy); isInRange {
+ return true
+ }
+ }
+ return false
+}
+
func SplitIP(ipPort string) (string, string) {
host, port, err := osnet.SplitHostPort(ipPort)
if err != nil {