diff --git a/cmd/cloudwatch-ingestion/main.go b/cmd/cloudwatch-ingestion/main.go deleted file mode 100644 index 78798e5..0000000 --- a/cmd/cloudwatch-ingestion/main.go +++ /dev/null @@ -1,142 +0,0 @@ -package main - -import ( - "bytes" - "context" - "encoding/json" - "flag" - "fmt" - "log" - "net/http" - "os" - "strings" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" - "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" -) - -func main() { - var ( - logGroupName string - ) - flag.StringVar(&logGroupName, "log-group", "", "log group to ingest") - flag.Parse() - - if logGroupName == "" { - flag.PrintDefaults() - os.Exit(1) - } - - cfg, err := config.LoadDefaultConfig(context.TODO()) - if err != nil { - log.Fatalf("unable to load SDK config, %v", err) - } - - svc := cloudwatchlogs.NewFromConfig(cfg) - - logStreams, err := fetchLogStreams(svc, logGroupName) - if err != nil { - log.Fatalf("failed to fetch log streams: %v", err) - } - - for _, stream := range logStreams { - err := fetchLogEvents(svc, logGroupName, *stream.LogStreamName) - if err != nil { - log.Printf("failed to fetch log events for stream %s: %v", *stream.LogStreamName, err) - return - } - } -} - -func fetchLogStreams(svc *cloudwatchlogs.Client, logGroupName string) ([]types.LogStream, error) { - var allStreams []types.LogStream - nextToken := "" - - for { - input := &cloudwatchlogs.DescribeLogStreamsInput{ - LogGroupName: aws.String(logGroupName), - } - - if nextToken != "" { - input.NextToken = aws.String(nextToken) - } - - result, err := svc.DescribeLogStreams(context.TODO(), input) - if err != nil { - return nil, err - } - - allStreams = append(allStreams, result.LogStreams...) - - if result.NextToken == nil { - break - } - - nextToken = *result.NextToken - } - - return allStreams, nil -} - -func fetchLogEvents(svc *cloudwatchlogs.Client, logGroupName, logStreamName string) error { - nextToken := "" - messages := []map[string]any{} - logStreamNameSplit := strings.Split(logStreamName, "/") - logStreamWithoutRandom := strings.Join(logStreamNameSplit[:len(logStreamNameSplit)-1], "/") - - for { - input := &cloudwatchlogs.GetLogEventsInput{ - LogGroupName: aws.String(logGroupName), - LogStreamName: aws.String(logStreamName), - StartFromHead: aws.Bool(true), - } - - if nextToken != "" { - input.NextToken = aws.String(nextToken) - } - - result, err := svc.GetLogEvents(context.TODO(), input) - if err != nil { - return err - } - - for _, event := range result.Events { - seconds := float64(*event.Timestamp / 1000) - microseconds := float64(*event.Timestamp%1000) * 1000 - messages = append(messages, map[string]any{ - "date": seconds + (microseconds / 1e6), - "log": *event.Message, - "log-group": logGroupName, - "log-stream": logStreamWithoutRandom, - }) - } - - if result.NextForwardToken == nil || nextToken == *result.NextForwardToken { - break - } - - nextToken = *result.NextForwardToken - } - - if len(messages) == 0 { - return nil - } - - out, err := json.Marshal(messages) - if err != nil { - return err - } - resp, err := http.Post("http://localhost/api/observability/ingestion/json", "image/jpeg", bytes.NewBuffer(out)) - if err != nil { - return err - } - if resp.StatusCode != 200 { - return fmt.Errorf("response code is not 200") - } - - fmt.Printf("Ingested log-group %s, stream %s: %d messages\n", logGroupName, logStreamWithoutRandom, len(messages)) - - return nil -} diff --git a/cmd/observability-rest-server/main.go b/cmd/observability-rest-server/main.go deleted file mode 100644 index 20dc3d5..0000000 --- a/cmd/observability-rest-server/main.go +++ /dev/null @@ -1,18 +0,0 @@ -package main - -import ( - "flag" - - "github.com/in4it/wireguard-server/pkg/rest" -) - -func main() { - var ( - httpPort int - httpsPort int - ) - flag.IntVar(&httpPort, "http-port", 80, "http port to run server on") - flag.IntVar(&httpsPort, "https-port", 443, "https port to run server on") - flag.Parse() - rest.StartServer(httpPort, httpsPort, rest.SERVER_TYPE_OBSERVABILITY) -} diff --git a/cmd/reset-admin-password/main.go b/cmd/reset-admin-password/main.go index 292ba27..cf9a877 100644 --- a/cmd/reset-admin-password/main.go +++ b/cmd/reset-admin-password/main.go @@ -10,8 +10,8 @@ import ( "strings" "syscall" + localstorage "github.com/in4it/go-devops-platform/storage/local" "github.com/in4it/wireguard-server/pkg/commands" - localstorage "github.com/in4it/wireguard-server/pkg/storage/local" "golang.org/x/term" ) diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index b7a7f22..123ebfc 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -1,9 +1,22 @@ package main import ( + "embed" "flag" + "log" - "github.com/in4it/wireguard-server/pkg/rest" + "github.com/in4it/go-devops-platform/auth/provisioning/scim" + "github.com/in4it/go-devops-platform/licensing" + "github.com/in4it/go-devops-platform/rest" + localstorage "github.com/in4it/go-devops-platform/storage/local" + "github.com/in4it/go-devops-platform/users" + "github.com/in4it/wireguard-server/pkg/vpn" + "github.com/in4it/wireguard-server/pkg/wireguard" +) + +var ( + //go:embed static + assets embed.FS ) func main() { @@ -14,5 +27,28 @@ func main() { flag.IntVar(&httpPort, "http-port", 80, "http port to run server on") flag.IntVar(&httpsPort, "https-port", 443, "https port to run server on") flag.Parse() - rest.StartServer(httpPort, httpsPort, rest.SERVER_TYPE_VPN) + + localStorage, err := localstorage.New() + if err != nil { + log.Fatalf("couldn't initialize storage: %s", err) + } + licenseUserCount, cloudType := licensing.GetMaxUsers(localStorage) + + userStore, err := users.NewUserStore(localStorage, licenseUserCount) + if err != nil { + log.Fatalf("startup failed: userstore initialization error: %s", err) + } + + scimInstance := scim.New(localStorage, userStore, "", wireguard.DisableAllClientConfigs, wireguard.ReactivateAllClientConfigs) + + apps := map[string]rest.AppClient{ + "vpn": vpn.New(localStorage, userStore), + } + + c, err := rest.NewContext(localStorage, rest.SERVER_TYPE_VPN, userStore, scimInstance, licenseUserCount, cloudType, apps) + if err != nil { + log.Fatalf("startup failed: %s", err) + } + + rest.StartServer(httpPort, httpsPort, rest.SERVER_TYPE_VPN, localStorage, c, assets) } diff --git a/pkg/rest/resources/.gitignore b/cmd/rest-server/static/.gitignore similarity index 100% rename from pkg/rest/resources/.gitignore rename to cmd/rest-server/static/.gitignore diff --git a/go.mod b/go.mod index 20188fc..3893290 100644 --- a/go.mod +++ b/go.mod @@ -12,31 +12,33 @@ require ( github.com/packetcap/go-pcap v0.0.0-20240528124601-8c87ecf5dbc5 github.com/russellhaering/gosaml2 v0.9.1 github.com/russellhaering/goxmldsig v1.4.0 - golang.org/x/crypto v0.27.0 - golang.org/x/sys v0.25.0 - golang.org/x/term v0.24.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 + golang.org/x/term v0.25.0 ) +require github.com/in4it/go-devops-platform v0.0.0-20241015173332-a45080cabae5 // indirect + require ( - github.com/aws/aws-sdk-go-v2 v1.31.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 // indirect - github.com/aws/aws-sdk-go-v2/config v1.27.33 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.32 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.13 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 // indirect + github.com/aws/aws-sdk-go-v2 v1.32.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.43 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.41 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.2 - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.19 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.19 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.17 // indirect - github.com/aws/aws-sdk-go-v2/service/s3 v1.61.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.22.7 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.7 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.30.7 // indirect - github.com/aws/smithy-go v1.21.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 // indirect + github.com/aws/smithy-go v1.22.0 // indirect github.com/beevik/etree v1.4.1 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -45,7 +47,7 @@ require ( github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - golang.org/x/net v0.29.0 // indirect + golang.org/x/net v0.30.0 // indirect golang.org/x/sync v0.8.0 // indirect - golang.org/x/text v0.18.0 // indirect + golang.org/x/text v0.19.0 // indirect ) diff --git a/go.sum b/go.sum index b4ceaad..675bfac 100644 --- a/go.sum +++ b/go.sum @@ -2,50 +2,67 @@ github.com/aws/aws-sdk-go-v2 v1.30.5 h1:mWSRTwQAb0aLE17dSzztCVJWI9+cRMgqebndjwDy github.com/aws/aws-sdk-go-v2 v1.30.5/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= github.com/aws/aws-sdk-go-v2 v1.31.0 h1:3V05LbxTSItI5kUqNwhJrrrY1BAXxXt0sN0l72QmG5U= github.com/aws/aws-sdk-go-v2 v1.31.0/go.mod h1:ztolYtaEUtdpf9Wftr31CJfLVjOnD/CVRkKOOYgF8hA= +github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4/go.mod h1:/MQxMqci8tlqDH+pjmoLu1i0tbWCUP1hhyMRuFxpQCw= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 h1:xDAuZTn4IMm8o1LnBZvmrL8JA1io4o3YWNXgohbf20g= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5/go.mod h1:wYSv6iDS621sEFLfKvpPE2ugjTuGlAG7iROg0hLOkfc= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= github.com/aws/aws-sdk-go-v2/config v1.27.33 h1:Nof9o/MsmH4oa0s2q9a0k7tMz5x/Yj5k06lDODWz3BU= github.com/aws/aws-sdk-go-v2/config v1.27.33/go.mod h1:kEqdYzRb8dd8Sy2pOdEbExTTF5v7ozEXX0McgPE7xks= +github.com/aws/aws-sdk-go-v2/config v1.27.43/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc= github.com/aws/aws-sdk-go-v2/credentials v1.17.32 h1:7Cxhp/BnT2RcGy4VisJ9miUPecY+lyE9I8JvcZofn9I= github.com/aws/aws-sdk-go-v2/credentials v1.17.32/go.mod h1:P5/QMF3/DCHbXGEGkdbilXHsyTBX5D3HSwcrSc9p20I= +github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.13 h1:pfQ2sqNpMVK6xz2RbqLEL0GH87JOwSxPV2rzm8Zsb74= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.13/go.mod h1:NG7RXPUlqfsCLLFfi0+IpKN4sCB9D9fw/qTaSB+xRoU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 h1:pI7Bzt0BJtYA0N/JEC6B8fJ4RBrEMi1LBrkMdFYNSnQ= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17/go.mod h1:Dh5zzJYMtxfIjYW+/evjQ8uj2OyR/ve2KROHGHlSFqE= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 h1:kYQ3H1u0ANr9KEKlGs/jTLrBFPo8P8NaH/w7A01NeeM= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18/go.mod h1:r506HmK5JDUh9+Mw4CfGJGSSoqIiLCndAuqXuhbv67Y= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 h1:Mqr/V5gvrhA2gvgnF42Zh5iMiQNcOYthFYwCyrnuWlc= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17/go.mod h1:aLJpZlCmjE+V+KtN1q1uyZkfnUWpQGpbsn89XPKyzfU= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 h1:Z7IdFUONvTcvS7YuhtVxN99v2cCoHRXOS4mTr0B/pUc= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18/go.mod h1:DkKMmksZVVyat+Y+r1dEOgJEfUeA7UngIHWeKsi0yNc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21/go.mod h1:1SR0GbLlnN3QUmYaflZNiH1ql+1qrSiB2vwcJ+4UM60= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.17 h1:Roo69qTpfu8OlJ2Tb7pAYVuF0CpuUMB0IYWwYP/4DZM= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.17/go.mod h1:NcWPxQzGM1USQggaTVwz6VpqMZPX1CvDJLDh6jnOCa4= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21/go.mod h1:Q9o5h4HoIWG8XfzxqiuK/CGUbepCJ8uTlaE3bAbxytQ= github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.2 h1:q5+hHt4JBA+8K6uAvfLWpUs7ErVR0GNW0Xf5KTOl84c= github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.2/go.mod h1:3p7NzlLlJesNGovq7Vqx8+0UibawzodrBRQAbaza6pI= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.19 h1:FLMkfEiRjhgeDTCjjLoc3URo/TBkgeQbocA78lfkzSI= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.19/go.mod h1:Vx+GucNSsdhaxs3aZIKfSUjKVGsxN25nX2SRcdhuw08= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.19 h1:rfprUlsdzgl7ZL2KlXiUAoJnI/VxfHCvDFr2QDFj6u4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.19/go.mod h1:SCWkEdRq8/7EK60NcvvQ6NXKuTcchAD4ROAsC37VEZE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2/go.mod h1:fnjjWyAW/Pj5HYOxl9LJqWtEwS7W2qgcRLWP+uWbss0= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.17 h1:u+EfGmksnJc/x5tq3A+OD7LrMbSSR/5TrKLvkdy/fhY= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.17/go.mod h1:VaMx6302JHax2vHJWgRo+5n9zvbacs3bLU/23DNQrTY= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM= github.com/aws/aws-sdk-go-v2/service/s3 v1.61.2 h1:Kp6PWAlXwP1UvIflkIP6MFZYBNDCa4mFCGtxrpICVOg= github.com/aws/aws-sdk-go-v2/service/s3 v1.61.2/go.mod h1:5FmD/Dqq57gP+XwaUnd5WFPipAuzrf0HmupX27Gvjvc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8= github.com/aws/aws-sdk-go-v2/service/sso v1.22.7 h1:pIaGg+08llrP7Q5aiz9ICWbY8cqhTkyy+0SHvfzQpTc= github.com/aws/aws-sdk-go-v2/service/sso v1.22.7/go.mod h1:eEygMHnTKH/3kNp9Jr1n3PdejuSNcgwLe1dWgQtO0VQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.7 h1:/Cfdu0XV3mONYKaOt1Gr0k1KvQzkzPyiKUdlWJqy+J4= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.7/go.mod h1:bCbAxKDqNvkHxRaIMnyVPXPo+OaPRwvmgzMxbz1VKSA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI= github.com/aws/aws-sdk-go-v2/service/sts v1.30.7 h1:NKTa1eqZYw8tiHSRGpP0VtTdub/8KNk8sDkNPFaOKDE= github.com/aws/aws-sdk-go-v2/service/sts v1.30.7/go.mod h1:NXi1dIAGteSaRLqYgarlhP/Ij0cFT+qmCwiJqWh/U5o= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.2/go.mod h1:HtaiBI8CjYoNVde8arShXb94UbQQi9L4EMr6D+xGBwo= github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= github.com/beevik/etree v1.4.0 h1:oz1UedHRepuY3p4N5OjE0nK1WLCqtzHf25bxplKOHLs= github.com/beevik/etree v1.4.0/go.mod h1:cyWiXwGoasx60gHvtnEh5x8+uIjUVnjWqBvEnhnqKDA= @@ -67,6 +84,26 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopacket/gopacket v1.3.0 h1:MouZCc+ej0vnqzB0WeiaO/6+tGvb+KU7UczxoQ+X0Yc= github.com/gopacket/gopacket v1.3.0/go.mod h1:WnFrU1Xkf5lWKV38uKNR9+yYtppn+ZYzOyNqMeH4oNE= +github.com/in4it/go-devops-platform v0.0.0-20241015140612-1949bc7fd57a h1:annNJWLpGHM4C6bvJKYC/eU9G9mF4IHBjHftozKGl1g= +github.com/in4it/go-devops-platform v0.0.0-20241015140612-1949bc7fd57a/go.mod h1:NXjJaxWbmL3mCCKcfG/MgFRimD9Bw3GO6PyQrY7KpRk= +github.com/in4it/go-devops-platform v0.0.0-20241015143916-a95e92939b2d h1:CHqTnGJpP9D9K8IcwH/3Y6iDdwj2+hrNIld3xyEZeF8= +github.com/in4it/go-devops-platform v0.0.0-20241015143916-a95e92939b2d/go.mod h1:Rkybt1QK1H/UWQw4zQLUgmN7qkcus14Xwu7Po6uZyDY= +github.com/in4it/go-devops-platform v0.0.0-20241015144646-23f35360de7f h1:Ly3NBMZTWQJlOEVrEvyE49wev9cRz3Lk99ZJ1++i19I= +github.com/in4it/go-devops-platform v0.0.0-20241015144646-23f35360de7f/go.mod h1:+cwJAz5QYDgexz0GBjOqXPvlNjwxPZVsNO0VLAODDTs= +github.com/in4it/go-devops-platform v0.0.0-20241015145222-2bcb11eedbc9 h1:xtNPNaS8AQJAVZAfP6ip+X3WoyPJ1+MKonuoCMAmcnE= +github.com/in4it/go-devops-platform v0.0.0-20241015145222-2bcb11eedbc9/go.mod h1:+cwJAz5QYDgexz0GBjOqXPvlNjwxPZVsNO0VLAODDTs= +github.com/in4it/go-devops-platform v0.0.0-20241015151511-486df9498223 h1:7kVZ6gO1SmwHB1qmGUh592j1gRxV60rdXsGp1y46gYs= +github.com/in4it/go-devops-platform v0.0.0-20241015151511-486df9498223/go.mod h1:+cwJAz5QYDgexz0GBjOqXPvlNjwxPZVsNO0VLAODDTs= +github.com/in4it/go-devops-platform v0.0.0-20241015170045-c08f1c5bacbe h1:/K0LdRqsG0m5vNeL08bYNAcze3k1H8+WzObXR5qnvyY= +github.com/in4it/go-devops-platform v0.0.0-20241015170045-c08f1c5bacbe/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q= +github.com/in4it/go-devops-platform v0.0.0-20241015170530-725d20de58f6 h1:DC3GjGu5Sk1Z0SQTpynXDgwtd3ozHAf1EyRCCdRGiR8= +github.com/in4it/go-devops-platform v0.0.0-20241015170530-725d20de58f6/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q= +github.com/in4it/go-devops-platform v0.0.0-20241015171457-d0cf3d638954 h1:xLSjoVWn+ZnWsYU/4YTha/LmsmUWghKf3BJkPJw4TeE= +github.com/in4it/go-devops-platform v0.0.0-20241015171457-d0cf3d638954/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q= +github.com/in4it/go-devops-platform v0.0.0-20241015173130-0b49ea0db408 h1:iQl/CXUmPJPer9NfVYsRTOwWHcXUs1q4oJQLzecJeuI= +github.com/in4it/go-devops-platform v0.0.0-20241015173130-0b49ea0db408/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q= +github.com/in4it/go-devops-platform v0.0.0-20241015173332-a45080cabae5 h1:pcZ2PeYjk5/Bis6llji5uTqocjbIpxPlm4flLoweVt0= +github.com/in4it/go-devops-platform v0.0.0-20241015173332-a45080cabae5/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= @@ -113,10 +150,14 @@ golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -124,14 +165,20 @@ golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/auth/oidc/rand.go b/pkg/auth/oidc/rand.go deleted file mode 100644 index 8881662..0000000 --- a/pkg/auth/oidc/rand.go +++ /dev/null @@ -1,19 +0,0 @@ -package oidc - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -func GetRandomString(n int) (string, error) { - buf := make([]byte, n) - - _, err := io.ReadFull(rand.Reader, buf) - if err != nil { - return "", fmt.Errorf("crypto/rand Reader error: %s", err) - } - - return base64.RawURLEncoding.EncodeToString(buf), nil -} diff --git a/pkg/auth/oidc/redirect.go b/pkg/auth/oidc/redirect.go deleted file mode 100644 index 4eb60f0..0000000 --- a/pkg/auth/oidc/redirect.go +++ /dev/null @@ -1,26 +0,0 @@ -package oidc - -import ( - "fmt" - "strings" -) - -func GetRedirectURI(discovery Discovery, clientID, scope, callback string, enableOIDCTokenRenewal bool) (string, string, error) { - var redirectURI string - - state, err := GetRandomString(64) - if err != nil { - return redirectURI, state, fmt.Errorf("GetRandomString error: %s", err) - } - - // add offline_access to scope if oidc token renewal is true - //if enableOIDCTokenRenewal { - // scope = strings.TrimSpace(scope) + " offline_access" - //} - - scope = strings.Replace(scope, " ", "%20", -1) - - redirectURI = fmt.Sprintf("%s?client_id=%s&state=%s&scope=%s&response_type=code&redirect_uri=%s", discovery.AuthorizationEndpoint, clientID, state, scope, callback) - - return redirectURI, state, nil -} diff --git a/pkg/auth/oidc/store/cleanup.go b/pkg/auth/oidc/store/cleanup.go deleted file mode 100644 index add79f5..0000000 --- a/pkg/auth/oidc/store/cleanup.go +++ /dev/null @@ -1,50 +0,0 @@ -package oidcstore - -import ( - "slices" - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" -) - -func (store *Store) CleanupOAuth2DataForAllEntries() int { - deleted := 0 - for _, oauthData := range store.OAuth2Data { - deleted += store.CleanupOAuth2Data(oauthData) - } - return deleted -} - -func (store *Store) CleanupOAuth2Data(oauthData oidc.OAuthData) int { - keysToDelete1 := []string{} - keysToDelete2 := []string{} - keysToDelete3 := []string{} - store.Mu.Lock() - defer store.Mu.Unlock() - for k := range store.OAuth2Data { - if oauthData.CreatedAt.After(store.OAuth2Data[k].CreatedAt) { - // cleanup old oauth2 data that might be duplicates (same subject & oidc provider, but older tokens) - if store.OAuth2Data[k].ID != oauthData.ID && store.OAuth2Data[k].OIDCProviderID == oauthData.OIDCProviderID && store.OAuth2Data[k].Subject == oauthData.Subject { - keysToDelete1 = append(keysToDelete1, k) - } - // cleanup oauthdata with the same email address - if store.OAuth2Data[k].ID != oauthData.ID && store.OAuth2Data[k].UserInfo.Email == oauthData.UserInfo.Email { - keysToDelete3 = append(keysToDelete3, k) - } - } - // cleanup old oauth2 data that doesn't have a token and is stale - if store.OAuth2Data[k].Token.AccessToken == "" && store.OAuth2Data[k].CreatedAt.Add(10*time.Minute).After(time.Now()) { - keysToDelete2 = append(keysToDelete2, k) - } - } - keysToDelete := []string{} - keysToDelete = append(keysToDelete, keysToDelete1...) - keysToDelete = append(keysToDelete, keysToDelete2...) - keysToDelete = append(keysToDelete, keysToDelete3...) - slices.Sort(keysToDelete) - - for _, key := range slices.Compact(keysToDelete) { - delete(store.OAuth2Data, key) - } - return len(keysToDelete) -} diff --git a/pkg/auth/oidc/store/discovery.go b/pkg/auth/oidc/store/discovery.go deleted file mode 100644 index ee3e7f6..0000000 --- a/pkg/auth/oidc/store/discovery.go +++ /dev/null @@ -1,42 +0,0 @@ -package oidcstore - -import ( - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" -) - -func (store *Store) GetDiscoveryURI(discoveryURI string) (oidc.Discovery, error) { - var discovery oidc.Discovery - if cachedDiscovery, ok := store.getDiscoveryCache(discoveryURI); ok { - if cachedDiscovery.Expiration.After(time.Now()) { - return cachedDiscovery.Discovery, nil // cache hit, we can return - } - } - - client := http.Client{ - Timeout: 5 * time.Second, - } - resp, err := client.Get(discoveryURI) - if err != nil { - return discovery, fmt.Errorf("discoveryURL Get error: %s", err) - } - if resp.StatusCode != 200 { - return discovery, fmt.Errorf("DiscoveryURI Request unsuccesful (status code returned: %d)", resp.StatusCode) - } - decoder := json.NewDecoder(resp.Body) - err = decoder.Decode(&discovery) - if err != nil { - return discovery, fmt.Errorf("discoveryURL decode error: %s", err) - } - - store.setDiscoveryCache(discoveryURI, oidc.DiscoveryCache{ - Expiration: time.Now().Add(12 * time.Hour), // standard 12h cache - Discovery: discovery, - }) - - return discovery, nil -} diff --git a/pkg/auth/oidc/store/discovery_test.go b/pkg/auth/oidc/store/discovery_test.go deleted file mode 100644 index 31b6d9b..0000000 --- a/pkg/auth/oidc/store/discovery_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package oidcstore - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestGetDiscovery(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/discovery.json" { - discovery := oidc.Discovery{ - Issuer: "test-issuer", - } - out, err := json.Marshal(discovery) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - } - w.Write(out) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - store, err := NewStore(&memorystorage.MockMemoryStorage{}) - if err != nil { - t.Fatalf("new store error: %s", err) - } - uri := ts.URL + "/discovery.json" - discovery, err := store.GetDiscoveryURI(uri) - if err != nil { - t.Fatalf("get discovery error: %s", err) - } - if discovery.Issuer != "test-issuer" { - t.Fatalf("wrong issuer") - } - - // cached response - discovery, err = store.GetDiscoveryURI(uri) - if err != nil { - t.Fatalf("get discovery error: %s", err) - } - if discovery.Issuer != "test-issuer" { - t.Fatalf("wrong issuer") - } - if _, ok := store.DiscoveryCache[uri]; !ok { - t.Fatalf("discovery not in cache") - } -} diff --git a/pkg/auth/oidc/store/jwks.go b/pkg/auth/oidc/store/jwks.go deleted file mode 100644 index 264cb82..0000000 --- a/pkg/auth/oidc/store/jwks.go +++ /dev/null @@ -1,55 +0,0 @@ -package oidcstore - -import ( - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" -) - -func (store *Store) GetJwks(jwksURI string) (oidc.Jwks, error) { - var jwksKeys oidc.Jwks - - if cachedJwks, ok := store.JwksCache[jwksURI]; ok { - if cachedJwks.Expiration.Before(time.Now()) { - return cachedJwks.Jwks, nil // cache hit, we can return - } - } - - client := http.Client{ - Timeout: 5 * time.Second, - } - resp, err := client.Get(jwksURI) - if err != nil { - return jwksKeys, fmt.Errorf("discoveryURL Get error: %s", err) - } - if resp.StatusCode != 200 { - return jwksKeys, fmt.Errorf("DiscoveryURI Request unsuccesful (status code returned: %d)", resp.StatusCode) - } - decoder := json.NewDecoder(resp.Body) - err = decoder.Decode(&jwksKeys) - if err != nil { - return jwksKeys, fmt.Errorf("discoveryURL decode error: %s", err) - } - - store.setJwksCache(jwksURI, oidc.JwksCache{ - Expiration: time.Now().Add(20 * time.Minute), // 20 minute cache standard - Jwks: jwksKeys, - }) - - return jwksKeys, nil -} - -func (store *Store) GetAllJwks(discoveryProviders []oidc.Discovery) ([]oidc.Jwks, error) { - allJwks := make([]oidc.Jwks, len(discoveryProviders)) - for k, discovery := range discoveryProviders { - var err error - allJwks[k], err = store.GetJwks(discovery.JwksURI) - if err != nil { - return []oidc.Jwks{}, fmt.Errorf("get jwks error: %s", err) - } - } - return allJwks, nil -} diff --git a/pkg/auth/oidc/store/jwks_test.go b/pkg/auth/oidc/store/jwks_test.go deleted file mode 100644 index 94c808a..0000000 --- a/pkg/auth/oidc/store/jwks_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package oidcstore - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestGetJwks(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/jwks.json" { - jwksKeys := oidc.Jwks{ - Keys: []oidc.JwksKey{ - { - Kid: "1-2-3-4", - Kty: "kty", - }, - }, - } - out, err := json.Marshal(jwksKeys) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - } - w.Write(out) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - store, err := NewStore(&memorystorage.MockMemoryStorage{}) - if err != nil { - t.Fatalf("new store error: %s", err) - } - uri := ts.URL + "/jwks.json" - jwks, err := store.GetJwks(uri) - if err != nil { - t.Fatalf("get jwks error: %s", err) - } - if len(jwks.Keys) == 0 { - t.Fatalf("jwks is empty") - } - if jwks.Keys[0].Kid != "1-2-3-4" { - t.Fatalf("wrong kid: %s", jwks.Keys[0].Kid) - } - // cached response - - jwks, err = store.GetJwks(uri) - if err != nil { - t.Fatalf("get jwks error: %s", err) - } - if len(jwks.Keys) == 0 { - t.Fatalf("jwks is empty") - } - if jwks.Keys[0].Kid != "1-2-3-4" { - t.Fatalf("wrong kid: %s", jwks.Keys[0].Kid) - } - if _, ok := store.JwksCache[uri]; !ok { - t.Fatalf("jwks not in cache") - } - - // get all jwks - allJwks, err := store.GetAllJwks([]oidc.Discovery{{JwksURI: uri}}) - if err != nil { - t.Fatalf("get all jwks error: %s", err) - } - if len(allJwks) == 0 { - t.Fatalf("all jwks is zero") - } - if allJwks[0].Keys[0].Kid != "1-2-3-4" { - t.Fatalf("wrong kid for allJwks: %s", jwks.Keys[0].Kid) - } -} diff --git a/pkg/auth/oidc/store/renewal/new.go b/pkg/auth/oidc/store/renewal/new.go deleted file mode 100644 index 605c5c3..0000000 --- a/pkg/auth/oidc/store/renewal/new.go +++ /dev/null @@ -1,42 +0,0 @@ -package oidcrenewal - -import ( - "time" - - oidc "github.com/in4it/wireguard-server/pkg/auth/oidc" - oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" -) - -type Renewal struct { - oidcStore *oidcstore.Store - enabled bool - oidcProviders []oidc.OIDCProvider - userStore *users.UserStore - renewalTime time.Duration - storage storage.Iface -} - -func NewRenewal(storage storage.Iface, renewalTime int, contextLogLevel int, enabled bool, oidcstore *oidcstore.Store, oidcProviders []oidc.OIDCProvider, userStore *users.UserStore) (*Renewal, error) { - r := &Renewal{ - enabled: enabled, - oidcStore: oidcstore, - oidcProviders: oidcProviders, - userStore: userStore, - storage: storage, - } - logging.Loglevel = contextLogLevel - if renewalTime <= 5 { - r.renewalTime = DEFAULT_RENEWAL_TIME_MINUTES * time.Minute - } else { - r.renewalTime = time.Duration(renewalTime) * time.Minute - } - go r.Worker() - return r, nil -} - -func (r *Renewal) SetEnabled(enabled bool) { - r.enabled = enabled -} diff --git a/pkg/auth/oidc/store/renewal/refreshtoken.go b/pkg/auth/oidc/store/renewal/refreshtoken.go deleted file mode 100644 index 676cae4..0000000 --- a/pkg/auth/oidc/store/renewal/refreshtoken.go +++ /dev/null @@ -1,54 +0,0 @@ -package oidcrenewal - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" -) - -func refreshToken(discovery oidc.Discovery, refreshToken, clientID, clientSecret string) (oidc.Token, time.Time, error) { - var token oidc.Token - client := http.Client{ - Timeout: 5 * time.Second, - } - - if discovery.TokenEndpoint == "" { - return token, time.Time{}, fmt.Errorf("token endpoint is empty") - } - - payload := url.Values{ - "grant_type": {"refresh_token"}, - "refresh_token": {refreshToken}, - "client_id": {clientID}, - "client_secret": {clientSecret}, - } - - resp, err := client.PostForm(discovery.TokenEndpoint, payload) - if err != nil { - return token, time.Time{}, fmt.Errorf("tokenEndpoint PostForm error: %s", err) - } - renewalTime := time.Now() - if resp.StatusCode != 200 { - data, err := io.ReadAll(resp.Body) - if err != nil { - return token, renewalTime, fmt.Errorf("tokenEndpoint return error. statuscode: %d", resp.StatusCode) - } - return token, renewalTime, fmt.Errorf("tokenEndpoint return error (statuscode %d): %s", resp.StatusCode, data) - } - decoder := json.NewDecoder(resp.Body) - err = decoder.Decode(&token) - if err != nil { - return token, renewalTime, fmt.Errorf("tokenEndpoint decode error: %s", err) - } - if token.AccessToken == "" { - return token, renewalTime, fmt.Errorf("access token is empty") - } - - return token, renewalTime, nil - -} diff --git a/pkg/auth/oidc/store/renewal/renew.go b/pkg/auth/oidc/store/renewal/renew.go deleted file mode 100644 index b8a06c0..0000000 --- a/pkg/auth/oidc/store/renewal/renew.go +++ /dev/null @@ -1,156 +0,0 @@ -package oidcrenewal - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -func (r *Renewal) RenewAllOIDCConnections() { - // force renewal of all tokens, even if they're not expired (unless they're empty) - for key, oauth2Data := range r.oidcStore.OAuth2Data { - if oidcProvider, err := getOIDCProvider(oauth2Data.OIDCProviderID, r.oidcProviders); err == nil { - if discovery, err := r.oidcStore.GetDiscoveryURI(oidcProvider.DiscoveryURI); err == nil { - if oauth2Data.RenewalFailed || oauth2Data.Token.AccessToken == "" { - logging.DebugLog(fmt.Errorf("skipping %s (renewal already failed or access token is empty. RenewalFailed: %v, AccessToken is empty: %v)", oauth2Data.ID, oauth2Data.RenewalFailed, oauth2Data.Token.AccessToken == "")) - } else { - logging.DebugLog(fmt.Errorf("trying to renew %s", oauth2Data.ID)) - r.renew(discovery, key, oauth2Data, oidcProvider) - } - } else { - logging.DebugLog(fmt.Errorf("could not get discovery url for %s: %s", oauth2Data.ID, err)) - } - } else { - logging.DebugLog(fmt.Errorf("could not get oidcprovider for %s: %s", oauth2Data.ID, err)) - } - } -} -func (r *Renewal) renew(discovery oidc.Discovery, key string, oauth2Data oidc.OAuthData, oidcProvider oidc.OIDCProvider) { - newToken, newTokenTimestamp, err := refreshToken(discovery, oauth2Data.Token.RefreshToken, oidcProvider.ClientID, oidcProvider.ClientSecret) - if err != nil { - oauth2Data.RenewalRetries++ - logging.ErrorLog(fmt.Errorf("renewal Worker: could not refresh token for %s (attemp %d/%d): %s", oauth2Data.ID, oauth2Data.RenewalRetries, RENEWAL_RETRIES, err)) - if oauth2Data.RenewalRetries >= RENEWAL_RETRIES { - oauth2Data.RenewalFailed = true - } - err = r.oidcStore.StoreEntry(key, oauth2Data) - if err != nil { - logging.ErrorLog(fmt.Errorf("renewal Worker: [error] StoreEntry: %s", err)) - } - err = r.oidcStore.SaveOIDCStore() - if err != nil { - logging.ErrorLog(fmt.Errorf("renewal Worker: [error] SaveOIDCStore: %s", err)) - } - // suspend connections - if oauth2Data.RenewalFailed { - err = disableUser(r.storage, oauth2Data, r.userStore) - if err != nil { - logging.ErrorLog(fmt.Errorf("renewal Worker: [error] disableUser: %s", err)) - } - } - return - } - logging.DebugLog(fmt.Errorf("new token issued at %v: %+v", newToken, newTokenTimestamp)) - oauth2Data.LastTokenRenewal = newTokenTimestamp - oauth2Data.Token.AccessToken = newToken.AccessToken - oauth2Data.Token.ExpiresIn = newToken.ExpiresIn - oauth2Data.Token.RefreshToken = newToken.RefreshToken - if newToken.IDToken != "" { - oauth2Data.Token.IDToken = newToken.IDToken - } - err = r.oidcStore.StoreEntry(key, oauth2Data) - if err != nil { - logging.ErrorLog(fmt.Errorf("renewal Worker: [error] StoreEntry: %s", err)) - } - err = r.oidcStore.SaveOIDCStore() - if err != nil { - logging.ErrorLog(fmt.Errorf("renewal Worker: [error] SaveOIDCStore: %s", err)) - } -} - -func disableUser(storage storage.Iface, oauth2Data oidc.OAuthData, userStore *users.UserStore) error { - logging.DebugLog(fmt.Errorf("disable user with oidc id %s", oauth2Data.ID)) - user, err := userStore.GetUserByOIDCIDs([]string{oauth2Data.ID}) - if err != nil { - return fmt.Errorf("no user found with oidc id %s", oauth2Data.ID) - } - err = wireguard.DisableAllClientConfigs(storage, user.ID) - if err != nil { - return fmt.Errorf("DisableAllClientConfigs error for userID %s: %s", user.ID, err) - } - user.ConnectionsDisabledOnAuthFailure = true - err = userStore.UpdateUser(user) - if err != nil { - return fmt.Errorf("could not update connectionsDisabledOnAuthFailure user with userID %s: %s", user.ID, err) - } - return nil -} - -func getOIDCProvider(id string, oidcProviders []oidc.OIDCProvider) (oidc.OIDCProvider, error) { - for _, oidcProvider := range oidcProviders { - if oidcProvider.ID == id { - return oidcProvider, nil - } - } - return oidc.OIDCProvider{}, fmt.Errorf("oidc provider not found") - -} - -func getExpirationDate(token string) (time.Time, error) { - jwtSplit := strings.Split(token, ".") - if len(jwtSplit) < 2 { - return time.Time{}, fmt.Errorf("token split < 2") - } - data, err := base64.RawURLEncoding.DecodeString(jwtSplit[1]) - if err != nil { - return time.Time{}, fmt.Errorf("could not base64 decode data part of jwt") - } - var jwt jwtExp - err = json.Unmarshal(data, &jwt) - if err != nil { - return time.Time{}, fmt.Errorf("could not unmarshal jwt data") - } - if jwt.Expiration == 0 { - return time.Time{}, fmt.Errorf("exp not found in jwt data") - } - return time.Unix(jwt.Expiration, 0), nil -} - -func canRenew(renewalTime time.Duration, oauth2Data oidc.OAuthData, store *oidcstore.Store, oidcProviders []oidc.OIDCProvider) (bool, oidc.OIDCProvider, oidc.Discovery, error) { - if oauth2Data.RenewalFailed { - return false, oidc.OIDCProvider{}, oidc.Discovery{}, nil - } - if oauth2Data.Token.AccessToken == "" { - logging.DebugLog(fmt.Errorf("access token empty of oidc id %s", oauth2Data.ID)) - return false, oidc.OIDCProvider{}, oidc.Discovery{}, nil - } - expirationDate, err := getExpirationDate(oauth2Data.Token.AccessToken) - if err != nil { - return false, oidc.OIDCProvider{}, oidc.Discovery{}, fmt.Errorf("can't get expiration date of refresh_token (id:%s). error: %s", oauth2Data.ID, err) - } - - if time.Since(oauth2Data.LastTokenRenewal) > renewalTime || time.Now().After(expirationDate) { - logging.DebugLog(fmt.Errorf("going to renew token for %s", oauth2Data.ID)) - oidcProvider, err := getOIDCProvider(oauth2Data.OIDCProviderID, oidcProviders) - if err != nil { - return false, oidc.OIDCProvider{}, oidc.Discovery{}, fmt.Errorf("could not get oidcprovider for %s: %s", oauth2Data.ID, err) - } - discovery, err := store.GetDiscoveryURI(oidcProvider.DiscoveryURI) - if err != nil { - return false, oidc.OIDCProvider{}, oidc.Discovery{}, fmt.Errorf("could not get discovery url for %s: %s", oauth2Data.ID, err) - } - return true, oidcProvider, discovery, nil - } else { - logging.DebugLog(fmt.Errorf("not renewing oidc id %s. time since last token renewal: %d", oauth2Data.ID, time.Since(oauth2Data.LastTokenRenewal))) - } - return false, oidc.OIDCProvider{}, oidc.Discovery{}, nil -} diff --git a/pkg/auth/oidc/store/renewal/types.go b/pkg/auth/oidc/store/renewal/types.go deleted file mode 100644 index 202c716..0000000 --- a/pkg/auth/oidc/store/renewal/types.go +++ /dev/null @@ -1,5 +0,0 @@ -package oidcrenewal - -type jwtExp struct { - Expiration int64 `json:"exp"` -} diff --git a/pkg/auth/oidc/store/renewal/worker.go b/pkg/auth/oidc/store/renewal/worker.go deleted file mode 100644 index a4a8252..0000000 --- a/pkg/auth/oidc/store/renewal/worker.go +++ /dev/null @@ -1,48 +0,0 @@ -package oidcrenewal - -import ( - "fmt" - "log" - "time" - - "github.com/in4it/wireguard-server/pkg/logging" -) - -const WAKEUP_TIME_SECONDS = 300 // every 5 minutes we check -const DEFAULT_RENEWAL_TIME_MINUTES = 60 // every hour we want to refresh the token -const RENEWAL_RETRIES = 3 // 3 retries before we suspend a user -const RENEWAL_BACKOFF_SECONDS = 10 // time between requests to the oidc providers - -func (r *Renewal) Worker() { - // do renewal - if r.enabled { - fmt.Printf("Starting oidc renewal worker (loglevel: %d)\n", logging.Loglevel) - } - for { - if !r.enabled { - time.Sleep(WAKEUP_TIME_SECONDS * time.Second) - continue - } - deletedEntries := r.oidcStore.CleanupOAuth2DataForAllEntries() - if deletedEntries > 0 { - err := r.oidcStore.SaveOIDCStore() - if err != nil { - log.Printf("Renewal Worker: [warning] couldn't save oidc store after cleanup: %s", err) - } - } - for key, oauth2Data := range r.oidcStore.OAuth2Data { - logging.DebugLog(fmt.Errorf("running canRenew of %s", oauth2Data.ID)) - // can we renew? Do we have expiration date and it is expired? - canRenew, oidcProvider, discovery, err := canRenew(r.renewalTime, oauth2Data, r.oidcStore, r.oidcProviders) - if err != nil { - log.Printf("Renewal Worker: [warning] needsRenewal: %s", err) - } - if canRenew { - logging.DebugLog(fmt.Errorf("we can renew %s", oauth2Data.ID)) - r.renew(discovery, key, oauth2Data, oidcProvider) // error logging within function - } - time.Sleep(RENEWAL_BACKOFF_SECONDS * time.Second) - } - time.Sleep(WAKEUP_TIME_SECONDS * time.Second) - } -} diff --git a/pkg/auth/oidc/store/save.go b/pkg/auth/oidc/store/save.go deleted file mode 100644 index 67828a5..0000000 --- a/pkg/auth/oidc/store/save.go +++ /dev/null @@ -1,30 +0,0 @@ -package oidcstore - -import ( - "encoding/json" - "fmt" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" -) - -func (store *Store) SaveOIDCStore() error { - store.Mu.Lock() - defer store.Mu.Unlock() - out, err := json.Marshal(store) - if err != nil { - return fmt.Errorf("oidc store marshal error: %s", err) - } - filename := store.storage.ConfigPath("oidcstore.json") - err = store.storage.WriteFile(filename, out) - if err != nil { - return fmt.Errorf("oidcstore write error: %s", err) - } - return nil -} - -func (store *Store) SaveOAuth2Data(oauth2Data oidc.OAuthData, key string) error { - store.Mu.Lock() - store.OAuth2Data[key] = oauth2Data - store.Mu.Unlock() - return store.SaveOIDCStore() -} diff --git a/pkg/auth/oidc/store/save_test.go b/pkg/auth/oidc/store/save_test.go deleted file mode 100644 index 9e33d53..0000000 --- a/pkg/auth/oidc/store/save_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package oidcstore - -import ( - "testing" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestSave(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - store, err := NewStore(storage) - if err != nil { - t.Fatalf("error: %s", err) - } - err = store.setDiscoveryCache("test", oidc.DiscoveryCache{Discovery: oidc.Discovery{Issuer: "testissuer"}}) - if err != nil { - t.Fatalf("error: %s", err) - } - - err = store.SaveOIDCStore() - if err != nil { - t.Fatalf("error: %s", err) - } - - _, err = storage.ReadFile(storage.ConfigPath("oidcstore.json")) - if err != nil { - t.Fatalf("error: %s", err) - } - store2, err := NewStore(storage) - if err != nil { - t.Fatalf("error: %s", err) - } - - discovery, ok := store2.getDiscoveryCache("test") - if !ok { - t.Fatalf("can't find the cache") - } - if discovery.Discovery.Issuer != "testissuer" { - t.Fatalf("expected testissuer. Got: %s", discovery.Discovery.Issuer) - } -} diff --git a/pkg/auth/oidc/store/store.go b/pkg/auth/oidc/store/store.go deleted file mode 100644 index ce4480b..0000000 --- a/pkg/auth/oidc/store/store.go +++ /dev/null @@ -1,84 +0,0 @@ -package oidcstore - -import ( - "bytes" - "encoding/json" - "fmt" - "sync" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - "github.com/in4it/wireguard-server/pkg/storage" -) - -const DEFAULT_PATH = "oidcstore.json" - -var RetrieveTokenLock sync.Mutex - -func (store *Store) StoreEntry(state string, oauthData oidc.OAuthData) error { - store.Mu.Lock() - store.OAuth2Data[state] = oauthData - store.Mu.Unlock() - return nil -} - -func (store *Store) setDiscoveryCache(key string, value oidc.DiscoveryCache) error { - store.Mu.Lock() - store.DiscoveryCache[key] = value - store.Mu.Unlock() - return nil -} - -func (store *Store) setJwksCache(key string, value oidc.JwksCache) error { - store.Mu.Lock() - store.JwksCache[key] = value - store.Mu.Unlock() - return nil -} - -func (store *Store) getDiscoveryCache(key string) (oidc.DiscoveryCache, bool) { - discovery, ok := store.DiscoveryCache[key] - return discovery, ok -} - -func NewStore(storage storage.Iface) (*Store, error) { - var store *Store - - filename := storage.ConfigPath(DEFAULT_PATH) - - // check if oidc.Store exists - if !storage.FileExists(filename) { - return getEmptyOIDCStore(storage) - } - - data, err := storage.ReadFile(filename) - if err != nil { - return store, fmt.Errorf("config read error: %s", err) - } - decoder := json.NewDecoder(bytes.NewBuffer(data)) - err = decoder.Decode(&store) - if err != nil { - return store, fmt.Errorf("decode input error: %s", err) - } - if store.DiscoveryCache == nil { - store.DiscoveryCache = make(map[string]oidc.DiscoveryCache) - } - if store.JwksCache == nil { - store.JwksCache = make(map[string]oidc.JwksCache) - } - if store.OAuth2Data == nil { - store.OAuth2Data = make(map[string]oidc.OAuthData) - } - - store.storage = storage - - return store, nil -} - -func getEmptyOIDCStore(storage storage.Iface) (*Store, error) { - return &Store{ - OAuth2Data: make(map[string]oidc.OAuthData), - DiscoveryCache: make(map[string]oidc.DiscoveryCache), - JwksCache: make(map[string]oidc.JwksCache), - storage: storage, - }, nil -} diff --git a/pkg/auth/oidc/store/types.go b/pkg/auth/oidc/store/types.go deleted file mode 100644 index 0c194e6..0000000 --- a/pkg/auth/oidc/store/types.go +++ /dev/null @@ -1,16 +0,0 @@ -package oidcstore - -import ( - "sync" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - "github.com/in4it/wireguard-server/pkg/storage" -) - -type Store struct { - Mu sync.Mutex - OAuth2Data map[string]oidc.OAuthData `json:"oauth2Data"` - DiscoveryCache map[string]oidc.DiscoveryCache `json:"discoveryCache"` - JwksCache map[string]oidc.JwksCache `json:"jwksCache"` - storage storage.Iface -} diff --git a/pkg/auth/oidc/token.go b/pkg/auth/oidc/token.go deleted file mode 100644 index 1843ac9..0000000 --- a/pkg/auth/oidc/token.go +++ /dev/null @@ -1,135 +0,0 @@ -package oidc - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/mail" - "net/url" - "time" - - "github.com/go-jose/go-jose/v4" - "github.com/golang-jwt/jwt/v5" - "github.com/in4it/wireguard-server/pkg/logging" -) - -func RetrieveOAUth2DataUsingState(allOAuth2data map[string]OAuthData, state string) (OAuthData, error) { - if state == "" { - return OAuthData{}, fmt.Errorf("no state found") - } - oauthData, ok := allOAuth2data[state] - if !ok { - return OAuthData{}, fmt.Errorf("oauth data not found (is state missing?)") - } - return oauthData, nil -} - -func UpdateOAuth2DataWithToken(jwks Jwks, discovery Discovery, clientID, clientSecret, redirectURI, code, state string, oauth2Data OAuthData) (OAuthData, error) { - newOAuthData := oauth2Data - var token Token - client := http.Client{ - Timeout: 5 * time.Second, - } - - if discovery.TokenEndpoint == "" { - return newOAuthData, fmt.Errorf("token endpoint is empty") - } - - payload := url.Values{ - "grant_type": {"authorization_code"}, - "code": {code}, - "client_id": {clientID}, - "client_secret": {clientSecret}, - "redirect_uri": {redirectURI}, - } - - resp, err := client.PostForm(discovery.TokenEndpoint, payload) - if err != nil { - return newOAuthData, fmt.Errorf("tokenEndpoint PostForm error: %s", err) - } - renewalTime := time.Now() - if resp.StatusCode != 200 { - data, err := io.ReadAll(resp.Body) - if err != nil { - return newOAuthData, fmt.Errorf("tokenEndpoint return error. statuscode: %d", resp.StatusCode) - } - return newOAuthData, fmt.Errorf("tokenEndpoint return error (statuscode %d): %s", resp.StatusCode, data) - } - decoder := json.NewDecoder(resp.Body) - err = decoder.Decode(&token) - if err != nil { - return newOAuthData, fmt.Errorf("tokenEndpoint decode error: %s", err) - } - - // verify id token - parsedToken, err := jwt.Parse(token.IDToken, func(token *jwt.Token) (interface{}, error) { - publicKey, err := GetPublicKeyForToken([]Jwks{jwks}, []Discovery{discovery}, token) - if err != nil { - return nil, fmt.Errorf("GetPublicKeyForToken error: %s", err) - } - return publicKey, nil - }) - if err != nil { - logging.DebugLog(fmt.Errorf("couldn't verify id token: %s", err)) - return newOAuthData, fmt.Errorf("couldn't verify id token") - } - // remove old oauth2data matching oidcproivder and subject - claims := parsedToken.Claims.(jwt.MapClaims) - subject, ok := claims["sub"] - if !ok { - return newOAuthData, fmt.Errorf("subject missing from id token") - } - validEmail := "" - email, ok := claims["email"] - if !ok { - // check if email is in preferred_username - preferred_username, ok2 := claims["preferred_username"] - if !ok2 { - return newOAuthData, fmt.Errorf("email missing from id token (not in email / preferred_username claim)") - } else { - _, err := mail.ParseAddress(preferred_username.(string)) - if err != nil { - return newOAuthData, fmt.Errorf("email missing from id token and preferred_username is not an email address") - } - validEmail = preferred_username.(string) - } - } else { - validEmail = email.(string) - } - issuer, ok := claims["iss"] - if !ok { - return newOAuthData, fmt.Errorf("issuer missing from id token") - } - - newOAuthData.Token = token - newOAuthData.LastTokenRenewal = renewalTime - newOAuthData.Subject = subject.(string) - newOAuthData.Issuer = issuer.(string) - newOAuthData.UserInfo.Email = validEmail - return newOAuthData, nil -} - -func GetPublicKeyForToken(allJwks []Jwks, discoveryProviders []Discovery, token *jwt.Token) (any, error) { - kid, ok := token.Header["kid"] - if !ok { - return nil, fmt.Errorf("no kid found in token") - } - for _, jwks := range allJwks { - for _, key := range jwks.Keys { - if key.Kid == kid { - jsonWebKey := jose.JSONWebKey{} - singleKey, err := json.Marshal(key) - if err != nil { - return nil, fmt.Errorf("internal server error: cannot marshal key from kid endpoint: %s", err) - } - err = jsonWebKey.UnmarshalJSON(singleKey) - if err != nil { - return nil, fmt.Errorf("key from jwks import error: %s", err) - } - return jsonWebKey.Key, nil - } - } - } - return nil, fmt.Errorf("no matching kid found for token") -} diff --git a/pkg/auth/oidc/types.go b/pkg/auth/oidc/types.go deleted file mode 100644 index 3356d1a..0000000 --- a/pkg/auth/oidc/types.go +++ /dev/null @@ -1,78 +0,0 @@ -package oidc - -import ( - "time" -) - -type Discovery struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - UserinfoEndpoint string `json:"userinfo_endpoint"` - JwksURI string `json:"jwks_uri"` - ScopesSupported []string `json:"scopes_supported"` - ResponseTypesSupported []string `json:"response_types_supported"` - TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` - IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` - ClaimsSupported []string `json:"claims_supported"` - SubjectTypesSupported []string `json:"subject_types_supported"` -} - -type OIDCProvider struct { - ID string `json:"id"` - Name string `json:"name"` - ClientID string `json:"clientId"` - ClientSecret string `json:"clientSecret,omitempty"` - Scope string `json:"scope"` - DiscoveryURI string `json:"discoveryURI"` - RedirectURI string `json:"redirectURI"` - LoginURL string `json:"loginURL,omitempty"` -} - -type Token struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - IDToken string `json:"id_token"` -} - -// jwks -type Jwks struct { - Keys []JwksKey `json:"keys"` -} -type JwksKey struct { - N string `json:"n"` - E string `json:"e"` - Alg string `json:"alg"` - Use string `json:"use"` - Kid string `json:"kid"` - Kty string `json:"kty"` -} - -type DiscoveryCache struct { - Expiration time.Time `json:"expiration"` - Discovery Discovery `json:"discovery"` -} -type JwksCache struct { - Expiration time.Time `json:"expiration"` - Jwks Jwks `json:"jwks"` -} -type OAuthData struct { - ID string `json:"id"` - OIDCProviderID string `json:"oidcProviderID"` - CreatedAt time.Time `json:"createdAt"` - Subject string `json:"subject"` - Issuer string `json:"issuer"` - UserInfo UserInfo `json:"userInfo"` - Token Token `json:"token"` - AuthFailed bool `json:"authFailed"` - Suspended bool `json:"suspended"` - LastTokenRenewal time.Time `json:"lastTokenRenewal"` - RenewalFailed bool `json:"renewalFailed"` - RenewalRetries int `json:"renewalRetries"` -} - -type UserInfo struct { - Email string `json:"email"` -} diff --git a/pkg/auth/provisioning/scim/helpers.go b/pkg/auth/provisioning/scim/helpers.go deleted file mode 100644 index 8716cff..0000000 --- a/pkg/auth/provisioning/scim/helpers.go +++ /dev/null @@ -1,26 +0,0 @@ -package scim - -import ( - "fmt" - "net/http" -) - -func returnError(w http.ResponseWriter, err error, statusCode int) { - fmt.Println("========= ERROR =========") - fmt.Printf("Error: %s\n", err) - fmt.Println("=========================") - w.WriteHeader(statusCode) - w.Write([]byte(`{"error": "` + err.Error() + `"}`)) -} - -func writeWithStatus(w http.ResponseWriter, res []byte, status int) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - w.Write(res) -} - -func write(w http.ResponseWriter, res []byte) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write(res) -} diff --git a/pkg/auth/provisioning/scim/middleware.go b/pkg/auth/provisioning/scim/middleware.go deleted file mode 100644 index 10cfacb..0000000 --- a/pkg/auth/provisioning/scim/middleware.go +++ /dev/null @@ -1,31 +0,0 @@ -package scim - -import ( - "fmt" - "net/http" - "strings" -) - -func (s *scim) authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { - writeWithStatus(w, []byte(`{"error": "token not found"}`), http.StatusUnauthorized) - return - } - tokenString := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", -1) - if len(tokenString) == 0 { - returnError(w, fmt.Errorf("empty token"), http.StatusUnauthorized) - return - } - if s.Token == "" { - writeWithStatus(w, []byte(`{"error": "scim not active"}`), http.StatusUnauthorized) - return - } - if s.Token != tokenString { - writeWithStatus(w, []byte(`{"error": "authentication failed"}`), http.StatusUnauthorized) - return - } - - next.ServeHTTP(w, r) - }) -} diff --git a/pkg/auth/provisioning/scim/new.go b/pkg/auth/provisioning/scim/new.go deleted file mode 100644 index 9c1043c..0000000 --- a/pkg/auth/provisioning/scim/new.go +++ /dev/null @@ -1,15 +0,0 @@ -package scim - -import ( - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" -) - -func New(storage storage.Iface, userStore *users.UserStore, token string) *scim { - s := &scim{ - Token: token, - UserStore: userStore, - storage: storage, - } - return s -} diff --git a/pkg/auth/provisioning/scim/response.go b/pkg/auth/provisioning/scim/response.go deleted file mode 100644 index 7549e75..0000000 --- a/pkg/auth/provisioning/scim/response.go +++ /dev/null @@ -1,58 +0,0 @@ -package scim - -import ( - "encoding/json" - "fmt" - - "github.com/in4it/wireguard-server/pkg/users" -) - -func listUserResponse(users []users.User, attributes string, count, start int) ([]byte, error) { - if start != -1 && start > 1 && start <= len(users) { - users = users[start:] - } - totalResults := len(users) - if len(users) > count && count != -1 { - users = users[0:count] - } - response := UserResponse{ - TotalResults: totalResults, - ItemsPerPage: len(users), - StartIndex: start, - Schemas: getSchemas("ListResponse"), - Resources: make([]UserResource, len(users)), - } - for k := range users { - response.Resources[k] = UserResource{ - ID: users[k].ID, - UserName: users[k].Login, - } - } - out, err := json.Marshal(response) - if err != nil { - return out, fmt.Errorf("json marshal error: %s", err) - } - return out, nil -} - -func userResponse(user users.User) ([]byte, error) { - response := PostUserRequest{ - Schemas: getSchemas("User"), - Id: user.ID, - UserName: user.Login, - Active: !user.Suspended, - } - out, err := json.Marshal(response) - if err != nil { - return out, fmt.Errorf("json marshal error: %s", err) - } - return out, nil -} - -func getSchemas(responseType string) []string { - if responseType == "User" { - return []string{"urn:ietf:params:scim:schemas:core:2.0:User"} - } - return []string{"urn:ietf:params:scim:api:messages:2.0:" + responseType} - -} diff --git a/pkg/auth/provisioning/scim/router.go b/pkg/auth/provisioning/scim/router.go deleted file mode 100644 index f8f0248..0000000 --- a/pkg/auth/provisioning/scim/router.go +++ /dev/null @@ -1,20 +0,0 @@ -package scim - -import ( - "net/http" -) - -func (s *scim) GetRouter() *http.ServeMux { - mux := http.NewServeMux() - - mux.Handle("/api/scim/", s.authMiddleware(http.HandlerFunc(notFoundHandler))) - mux.Handle("/api/scim/v2/Users", s.authMiddleware(http.HandlerFunc(s.usersHandler))) - mux.Handle("/api/scim/v2/Users/{id}", s.authMiddleware(http.HandlerFunc(s.userHandler))) - - return mux -} - -func notFoundHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`{"error": "page not found"}`)) -} diff --git a/pkg/auth/provisioning/scim/types.go b/pkg/auth/provisioning/scim/types.go deleted file mode 100644 index 0794075..0000000 --- a/pkg/auth/provisioning/scim/types.go +++ /dev/null @@ -1,54 +0,0 @@ -package scim - -import ( - "net/http" - - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" -) - -type scim struct { - Token string `json:"token"` - UserStore *users.UserStore `json:"userStore"` - storage storage.Iface -} - -type Iface interface { - GetRouter() *http.ServeMux - UpdateToken(token string) -} - -type UserResponse struct { - TotalResults int `json:"totalResults"` - ItemsPerPage int `json:"itemsPerPage"` - StartIndex int `json:"startIndex"` - Schemas []string `json:"schemas"` - Resources []UserResource `json:"Resources"` -} -type UserResource struct { - ID string `json:"id"` - UserName string `json:"userName,omitempty"` -} - -type PostUserRequest struct { - Schemas []string `json:"schemas"` - UserName string `json:"userName"` - Id string `json:"id,omitempty"` - Name Name `json:"name"` - Emails []Emails `json:"emails"` - DisplayName string `json:"displayName"` - Locale string `json:"locale"` - ExternalID string `json:"externalId"` - Groups []any `json:"groups"` - Password string `json:"password"` - Active bool `json:"active"` -} -type Name struct { - GivenName string `json:"givenName"` - FamilyName string `json:"familyName"` -} -type Emails struct { - Primary bool `json:"primary"` - Value string `json:"value"` - Type string `json:"type"` -} diff --git a/pkg/auth/provisioning/scim/update.go b/pkg/auth/provisioning/scim/update.go deleted file mode 100644 index e1abc5e..0000000 --- a/pkg/auth/provisioning/scim/update.go +++ /dev/null @@ -1,5 +0,0 @@ -package scim - -func (s *scim) UpdateToken(token string) { - s.Token = token -} diff --git a/pkg/auth/provisioning/scim/users.go b/pkg/auth/provisioning/scim/users.go deleted file mode 100644 index 34c7e27..0000000 --- a/pkg/auth/provisioning/scim/users.go +++ /dev/null @@ -1,247 +0,0 @@ -package scim - -import ( - "encoding/json" - "fmt" - "net/http" - "strconv" - "strings" - - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -// handler for multiple users -func (s *scim) usersHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - s.getUsersHandler(w, r) - return - case http.MethodPost: - s.postUsersHandler(w, r) - return - default: - returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -// handler for a single user -func (s *scim) userHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - s.getUserHandler(w, r) - return - case http.MethodPut: - s.putUserHandler(w, r) - return - case http.MethodDelete: - s.deleteUserHandler(w, r) - return - default: - returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (s *scim) getUsersHandler(w http.ResponseWriter, r *http.Request) { - attributes := r.URL.Query().Get("attributes") - filter := r.URL.Query().Get("filter") - count, err := strconv.Atoi(r.URL.Query().Get("count")) - if err != nil { - count = -1 - } - start, err := strconv.Atoi(r.URL.Query().Get("startIndex")) - if err != nil { - start = 1 - } - - if filter != "" { - response, err := getUsersWithFilter(s.UserStore, attributes, filter) - if err != nil { - returnError(w, fmt.Errorf("get user with filter error: %s", err), http.StatusBadRequest) - return - } - write(w, response) - return - } - response, err := getUsersWithoutFilter(s.UserStore, attributes, count, start) - if err != nil { - returnError(w, fmt.Errorf("get user with filter error: %s", err), http.StatusBadRequest) - return - } - write(w, response) -} - -func (s *scim) getUserHandler(w http.ResponseWriter, r *http.Request) { - user, err := s.UserStore.GetUserByID(r.PathValue("id")) - if err != nil { - returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest) - return - } - - response, err := userResponse(user) - if err != nil { - returnError(w, fmt.Errorf("user response error: %s", err), http.StatusBadRequest) - return - } - - write(w, response) -} -func (s *scim) putUserHandler(w http.ResponseWriter, r *http.Request) { - user, err := s.UserStore.GetUserByID(r.PathValue("id")) - if err != nil { - returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest) - return - } - - var putUserRequest PostUserRequest - err = json.NewDecoder(r.Body).Decode(&putUserRequest) - if err != nil { - returnError(w, fmt.Errorf("unable to decode request payload"), http.StatusBadRequest) - return - } - - if !putUserRequest.Active && !user.Suspended { // user is suspended - err = wireguard.DisableAllClientConfigs(s.storage, user.ID) - if err != nil { - returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest) - return - } - } - if putUserRequest.Active && user.Suspended { // user is unsuspended - err := wireguard.ReactivateAllClientConfigs(s.storage, user.ID) - if err != nil { - returnError(w, fmt.Errorf("could not reactivate all clients for user %s: %s", user.ID, err), http.StatusBadRequest) - return - } - } - - user.Suspended = !putUserRequest.Active - username := getUsername(putUserRequest) - if user.Login != username { - if !s.UserStore.LoginExists(username) { - user.Login = username - } - } - - err = s.UserStore.UpdateUser(user) - if err != nil { - returnError(w, fmt.Errorf("user update error: %s", err), http.StatusBadRequest) - return - } - - response, err := userResponse(user) - if err != nil { - returnError(w, fmt.Errorf("user response error: %s", err), http.StatusBadRequest) - return - } - - write(w, response) -} - -func (s *scim) deleteUserHandler(w http.ResponseWriter, r *http.Request) { - user, err := s.UserStore.GetUserByID(r.PathValue("id")) - if err != nil { - returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest) - return - } - - err = wireguard.DeleteAllClientConfigs(s.storage, user.ID) - if err != nil { - returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest) - return - } - - err = s.UserStore.DeleteUserByID(user.ID) - if err != nil { - returnError(w, fmt.Errorf("user update error: %s", err), http.StatusBadRequest) - return - } - - write(w, []byte("")) -} - -func (s *scim) postUsersHandler(w http.ResponseWriter, r *http.Request) { - var postUserRequest PostUserRequest - err := json.NewDecoder(r.Body).Decode(&postUserRequest) - if err != nil { - returnError(w, fmt.Errorf("unable to decode request payload"), http.StatusBadRequest) - return - } - - username := getUsername(postUserRequest) - - if s.UserStore.LoginExists(username) { - writeWithStatus(w, []byte("user already exists"), http.StatusConflict) - return - } - - if s.UserStore.GetMaxUsers()-s.UserStore.UserCount() <= 0 { - writeWithStatus(w, []byte("no license available to add new user"), http.StatusBadRequest) - return - } - - user, err := s.UserStore.AddUser(users.User{ - Login: username, - Role: "user", - Provisioned: true, - ExternalID: postUserRequest.ExternalID, - }) - if err != nil { - returnError(w, fmt.Errorf("unable to add user: %s", err), http.StatusBadRequest) - return - } - response, err := userResponse(user) - if err != nil { - returnError(w, fmt.Errorf("unable to generate user response: %s", err), http.StatusBadRequest) - return - } - writeWithStatus(w, response, http.StatusCreated) -} - -func getUsername(postUserRequest PostUserRequest) string { - username := postUserRequest.UserName - if username == "" { - for _, email := range postUserRequest.Emails { - if email.Primary { - username = email.Value - } - } - } - return username -} - -func getUsersWithFilter(userStore *users.UserStore, attributes, filter string) ([]byte, error) { - filterSplit := strings.Split(filter, " ") - if len(filterSplit) != 3 { - return []byte{}, fmt.Errorf("invalid filter") - } - if strings.ToLower(filterSplit[0]) == "username" { - if strings.ToLower(filterSplit[1]) == "eq" { - if userStore.LoginExists(strings.Trim(filterSplit[2], `"`)) { - user, err := userStore.GetUserByLogin(strings.Trim(filterSplit[2], `"`)) - if err != nil { - return []byte{}, fmt.Errorf("get user by login error: %s", err) - } - response, err := listUserResponse([]users.User{user}, attributes, -1, -1) - if err != nil { - return []byte{}, fmt.Errorf("userResponse error: %s", err) - } - return response, nil - } - } - } - response, err := listUserResponse([]users.User{}, attributes, -1, -1) - if err != nil { - return response, fmt.Errorf("userResponse error: %s", err) - } - return response, nil -} - -func getUsersWithoutFilter(userStore *users.UserStore, attributes string, count, start int) ([]byte, error) { - users := userStore.ListUsers() - response, err := listUserResponse(users, attributes, count, start) - if err != nil { - return []byte{}, fmt.Errorf("userResponse error: %s", err) - } - return response, nil -} diff --git a/pkg/auth/saml/config.go b/pkg/auth/saml/config.go deleted file mode 100644 index 3db4e72..0000000 --- a/pkg/auth/saml/config.go +++ /dev/null @@ -1 +0,0 @@ -package saml diff --git a/pkg/auth/saml/handlers.go b/pkg/auth/saml/handlers.go deleted file mode 100644 index 946a836..0000000 --- a/pkg/auth/saml/handlers.go +++ /dev/null @@ -1,95 +0,0 @@ -package saml - -import ( - "fmt" - "net/http" - - "github.com/google/uuid" -) - -func (s *saml) samlHandler(w http.ResponseWriter, r *http.Request) { - providerID := r.PathValue("id") - - if providerID == "" { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("saml error: no provider specified\n")) - return - } - - provider, err := s.getProviderByID(providerID) - if err != nil { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte(fmt.Sprintf("saml error: can't find provider with specified id: %s", err))) - return - } - - err = s.ensureSPLoaded(provider) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(fmt.Sprintf("saml error: invalid saml configuration: %s\n", err))) - return - } - err = r.ParseForm() - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - if r.Method != "POST" { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("saml error: not a POST request\n")) - return - } - - if r.FormValue("SAMLResponse") == "" { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("saml error: empty SAMLResponse\n")) - return - } - - if _, ok := s.serviceProvider[providerID]; !ok { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("saml error: can't find provider with specified id\n")) - return - } - - assertionInfo, err := s.serviceProvider[providerID].RetrieveAssertionInfo(r.FormValue("SAMLResponse")) - if err != nil { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte(fmt.Sprintf("saml error: %s\n", err))) - return - } - - if assertionInfo.WarningInfo.InvalidTime { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("saml error: invalid time\n")) - return - } - - if assertionInfo.WarningInfo.NotInAudience { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("saml error: incorrect audience\n")) - return - } - - login := assertionInfo.NameID - notAfter := *assertionInfo.SessionNotOnOrAfter - - randomString, err := getRandomString(128) - if err != nil { - w.WriteHeader(http.StatusForbidden) - w.Write([]byte(fmt.Sprintf("saml error: could not create session: %s\n", err))) - return - } - sessionKey := SessionKey{ - ProviderID: providerID, - SessionID: randomString, - } - s.CreateSession(sessionKey, AuthenticatedUser{ - ID: uuid.New().String(), - Login: login, - ExpiresAt: notAfter, - }) - w.Header().Add("Location", fmt.Sprintf("/callback/saml/%s?code=%s", providerID, sessionKey.SessionID)) - w.WriteHeader(http.StatusFound) -} diff --git a/pkg/auth/saml/keypair.go b/pkg/auth/saml/keypair.go deleted file mode 100644 index c52a4af..0000000 --- a/pkg/auth/saml/keypair.go +++ /dev/null @@ -1,121 +0,0 @@ -package saml - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "math/big" - "time" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -type KeyPair struct { - storage storage.Iface - hostname string -} - -func NewKeyPair(storage storage.Iface, hostname string) *KeyPair { - return &KeyPair{ - storage: storage, - hostname: hostname, - } -} - -func (kp *KeyPair) GetKeyPair() (privateKey *rsa.PrivateKey, cert []byte, err error) { - if !kp.storage.FileExists(kp.storage.ConfigPath("saml/saml.key")) { - err := kp.generateKeyAndCert() - if err != nil { - return privateKey, cert, fmt.Errorf("can't generate saml key and cert: %s", err) - } - } else if !kp.storage.FileExists(kp.storage.ConfigPath("saml/saml.crt")) { - err = kp.generateKeyAndCert() - if err != nil { - return privateKey, cert, fmt.Errorf("can't generate saml key and cert: %s", err) - } - } - - certPEMBlock, err := kp.storage.ReadFile(kp.storage.ConfigPath("saml/saml.crt")) - if err != nil { - return privateKey, cert, fmt.Errorf("can't read saml certificate: %s", err) - } - keyPEMBlock, err := kp.storage.ReadFile(kp.storage.ConfigPath("saml/saml.key")) - if err != nil { - return privateKey, cert, fmt.Errorf("can't read saml key: %s", err) - } - keyPair, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) - if err != nil { - return privateKey, cert, fmt.Errorf("can't get saml keypair: %s", err) - } - - privateKey = keyPair.PrivateKey.(*rsa.PrivateKey) - cert = keyPair.Certificate[0] - - return privateKey, cert, nil -} - -func (kp *KeyPair) generateKeyAndCert() error { - var ( - certOut bytes.Buffer - keyOut bytes.Buffer - ) - - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return fmt.Errorf("private key generation failed: %s", err) - } - - if err = pem.Encode(&keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}); err != nil { - return err - } - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return fmt.Errorf("rand int error: %s", err) - } - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - CommonName: kp.hostname, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(5, 0, 0), - - IsCA: true, - - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) - if err != nil { - return fmt.Errorf("certificate creation error: %s", err) - } - if err = pem.Encode(&certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - return err - } - - // ensure storagepath exists - err = kp.storage.EnsurePath(kp.storage.ConfigPath("saml")) - if err != nil { - return fmt.Errorf("could not ensure saml path exists: %s", err) - } - - err = kp.storage.WriteFile(kp.storage.ConfigPath("saml/saml.key"), keyOut.Bytes()) - if err != nil { - return fmt.Errorf("saml key write error: %s", err) - } - err = kp.storage.WriteFile(kp.storage.ConfigPath("saml/saml.crt"), certOut.Bytes()) - if err != nil { - return fmt.Errorf("saml key write error: %s", err) - } - - return nil -} diff --git a/pkg/auth/saml/metadata.go b/pkg/auth/saml/metadata.go deleted file mode 100644 index 8be0b1f..0000000 --- a/pkg/auth/saml/metadata.go +++ /dev/null @@ -1,43 +0,0 @@ -package saml - -import ( - "encoding/xml" - "fmt" - "io" - "net/http" - "net/url" - - "github.com/russellhaering/gosaml2/types" -) - -func (s *saml) HasValidMetadataURL(metadataURL string) (bool, error) { - metadataURLParsed, err := url.Parse(metadataURL) - if err != nil { - return false, fmt.Errorf("url parse error: %s", err) - } - _, err = getMetadata(metadataURLParsed.String()) - if err != nil { - return false, fmt.Errorf("fetch metadata error: %s", err) - } - return true, nil -} - -func getMetadata(metadataURL string) (types.EntityDescriptor, error) { - metadata := types.EntityDescriptor{} - - res, err := http.Get(metadataURL) - if err != nil { - return metadata, fmt.Errorf("can't retrieve saml metadata: %s", err) - } - - rawMetadata, err := io.ReadAll(res.Body) - if err != nil { - return metadata, fmt.Errorf("can't read saml cert data: %s", err) - } - - err = xml.Unmarshal(rawMetadata, &metadata) - if err != nil { - return metadata, fmt.Errorf("can't decode saml cert data: %s", err) - } - return metadata, nil -} diff --git a/pkg/auth/saml/middleware.go b/pkg/auth/saml/middleware.go deleted file mode 100644 index 6c06349..0000000 --- a/pkg/auth/saml/middleware.go +++ /dev/null @@ -1,135 +0,0 @@ -package saml - -import ( - "crypto/x509" - "encoding/base64" - "encoding/xml" - "fmt" - "io" - "net/http" - "net/url" - - saml2 "github.com/russellhaering/gosaml2" - "github.com/russellhaering/gosaml2/types" - dsig "github.com/russellhaering/goxmldsig" -) - -const ISSUER_URL = "saml/iss" -const AUDIENCE_URL = "saml/aud" -const ACS_URL = "saml/acs" - -func (s *saml) ensureSPLoaded(provider Provider) error { - if _, ok := s.serviceProvider[provider.ID]; !ok { - err := s.loadSP(provider) - if err != nil { - return fmt.Errorf("could not load saml provider: %s", err) - } - } else { - // check if provider is up-to-date - if provider.AllowMissingAttributes != s.serviceProvider[provider.ID].AllowMissingAttributes { - s.serviceProvider[provider.ID] = nil - } - if s.serviceProvider[provider.ID] == nil { - err := s.loadSP(provider) - if err != nil { - return fmt.Errorf("could not reload saml provider: %s", err) - } - } - } - return nil -} - -func (s *saml) loadSP(provider Provider) error { - s.mu.Lock() - defer s.mu.Unlock() - idpMetadataURL, err := url.Parse(provider.MetadataURL) - if err != nil { - return fmt.Errorf("can't parse metadata url: %s", err) - } - // pull metadata - res, err := http.Get(idpMetadataURL.String()) - if err != nil { - return fmt.Errorf("can't retrieve saml metadata: %s", err) - } - - rawMetadata, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("can't read saml cert data: %s", err) - } - - metadata := &types.EntityDescriptor{} - err = xml.Unmarshal(rawMetadata, metadata) - if err != nil { - return fmt.Errorf("can't decode saml cert data: %s", err) - } - - // load certs - certStore := dsig.MemoryX509CertificateStore{} - - if metadata.IDPSSODescriptor == nil || len(metadata.IDPSSODescriptor.KeyDescriptors) == 0 { - return fmt.Errorf("keyDescriptors are empty") - } - if len(metadata.IDPSSODescriptor.SingleSignOnServices) == 0 { - return fmt.Errorf("SingleSignOnServices not found") - } - - certStore.Roots, err = getSAMLCertsFromMetadata(metadata.IDPSSODescriptor.KeyDescriptors) - if err != nil { - return fmt.Errorf("can't parse certs from metadata: %s", err) - } - - keyStore := NewKeyPair(s.storage, *s.hostname) - - sp := &saml2.SAMLServiceProvider{ - IdentityProviderSSOURL: metadata.IDPSSODescriptor.SingleSignOnServices[0].Location, - IdentityProviderIssuer: metadata.EntityID, - ServiceProviderIssuer: fmt.Sprintf("%s://%s/%s/%s", *s.protocol, *s.hostname, ISSUER_URL, provider.ID), - AssertionConsumerServiceURL: fmt.Sprintf("%s://%s/%s/%s", *s.protocol, *s.hostname, ACS_URL, provider.ID), - SignAuthnRequests: true, - AudienceURI: fmt.Sprintf("%s://%s/%s/%s", *s.protocol, *s.hostname, AUDIENCE_URL, provider.ID), - IDPCertificateStore: &certStore, - SPKeyStore: keyStore, - AllowMissingAttributes: provider.AllowMissingAttributes, - } - - s.serviceProvider[provider.ID] = sp - - return err -} - -func getSAMLCertsFromMetadata(keyDescriptors []types.KeyDescriptor) ([]*x509.Certificate, error) { - certs := []*x509.Certificate{} - - for _, kd := range keyDescriptors { - for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates { - if xcert.Data == "" { - return nil, fmt.Errorf("metadata certificate(%d) must not be empty", idx) - } - certData, err := base64.StdEncoding.DecodeString(xcert.Data) - if err != nil { - return nil, fmt.Errorf("decode error:%s", err) - } - - idpCert, err := x509.ParseCertificate(certData) - if err != nil { - return nil, fmt.Errorf("cert parse error: %s", err) - } - - certs = append(certs, idpCert) - } - } - - return certs, nil - -} - -func (s *saml) GetAuthURL(provider Provider) (string, error) { - err := s.ensureSPLoaded(provider) - if err != nil { - return "", fmt.Errorf("saml error: invalid saml configuration: %s", err) - } - if _, ok := s.serviceProvider[provider.ID]; !ok { - return "", fmt.Errorf("provider not found") - } - return s.serviceProvider[provider.ID].BuildAuthURL("") -} diff --git a/pkg/auth/saml/new.go b/pkg/auth/saml/new.go deleted file mode 100644 index 4ca68e0..0000000 --- a/pkg/auth/saml/new.go +++ /dev/null @@ -1,21 +0,0 @@ -package saml - -import ( - "github.com/in4it/wireguard-server/pkg/storage" - saml2 "github.com/russellhaering/gosaml2" -) - -func New(providers *[]Provider, storage storage.Iface, protocol, hostname *string) Iface { - s := &saml{ - Providers: providers, - sessions: make(map[SessionKey]AuthenticatedUser), - serviceProvider: make(map[string]*saml2.SAMLServiceProvider), - protocol: protocol, - hostname: hostname, - storage: storage, - } - for _, provider := range *providers { - s.loadSP(provider) - } - return s -} diff --git a/pkg/auth/saml/provider.go b/pkg/auth/saml/provider.go deleted file mode 100644 index fc00389..0000000 --- a/pkg/auth/saml/provider.go +++ /dev/null @@ -1,12 +0,0 @@ -package saml - -import "fmt" - -func (s *saml) getProviderByID(id string) (Provider, error) { - for k := range *s.Providers { - if (*s.Providers)[k].ID == id { - return (*s.Providers)[k], nil - } - } - return Provider{}, fmt.Errorf("provider not found") -} diff --git a/pkg/auth/saml/rand.go b/pkg/auth/saml/rand.go deleted file mode 100644 index 294b95a..0000000 --- a/pkg/auth/saml/rand.go +++ /dev/null @@ -1,19 +0,0 @@ -package saml - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -func getRandomString(n int) (string, error) { - buf := make([]byte, n) - - _, err := io.ReadFull(rand.Reader, buf) - if err != nil { - return "", fmt.Errorf("crypto/rand Reader error: %s", err) - } - - return base64.RawURLEncoding.EncodeToString(buf), nil -} diff --git a/pkg/auth/saml/router.go b/pkg/auth/saml/router.go deleted file mode 100644 index 185e9e1..0000000 --- a/pkg/auth/saml/router.go +++ /dev/null @@ -1,12 +0,0 @@ -package saml - -import ( - "net/http" -) - -func (s *saml) GetRouter() *http.ServeMux { - mux := http.NewServeMux() - mux.Handle("/saml/acs/{id}", http.HandlerFunc(s.samlHandler)) - - return mux -} diff --git a/pkg/auth/saml/session.go b/pkg/auth/saml/session.go deleted file mode 100644 index 72dfa86..0000000 --- a/pkg/auth/saml/session.go +++ /dev/null @@ -1,26 +0,0 @@ -package saml - -import ( - "fmt" - "time" -) - -func (s *saml) GetAuthenticatedUser(provider Provider, sessionID string) (AuthenticatedUser, error) { - sessionKey := SessionKey{ - ProviderID: provider.ID, - SessionID: sessionID, - } - if authenticatedUser, ok := s.sessions[sessionKey]; ok { - if authenticatedUser.ExpiresAt.Before(time.Now()) { - return authenticatedUser, fmt.Errorf("session is expired") - } - return authenticatedUser, nil - } - return AuthenticatedUser{}, fmt.Errorf("session not found") -} - -func (s *saml) CreateSession(key SessionKey, value AuthenticatedUser) { - s.mu.Lock() - defer s.mu.Unlock() - s.sessions[key] = value -} diff --git a/pkg/auth/saml/types.go b/pkg/auth/saml/types.go deleted file mode 100644 index 1d16f48..0000000 --- a/pkg/auth/saml/types.go +++ /dev/null @@ -1,222 +0,0 @@ -package saml - -import ( - "encoding/xml" - "net/http" - "sync" - "time" - - "github.com/in4it/wireguard-server/pkg/storage" - saml2 "github.com/russellhaering/gosaml2" -) - -type Provider struct { - ID string `json:"id"` - Name string `json:"name"` - MetadataURL string `json:"metadataURL"` - Issuer string `json:"issuer,omitempty"` - Audience string `json:"audience,omitempty"` - Acs string `json:"acs,omitempty"` - AllowMissingAttributes bool `json:"allowMissingAttributes,omitempty"` -} - -type saml struct { - Providers *[]Provider - serviceProvider map[string]*saml2.SAMLServiceProvider - sessions map[SessionKey]AuthenticatedUser - hostname *string - protocol *string - mu sync.Mutex - storage storage.Iface -} -type AuthenticatedUser struct { - ID string - Login string - ExpiresAt time.Time -} -type SessionKey struct { - ProviderID string - SessionID string -} - -type Iface interface { - GetAuthURL(provider Provider) (string, error) - GetRouter() *http.ServeMux - GetAuthenticatedUser(provider Provider, sessionID string) (AuthenticatedUser, error) - HasValidMetadataURL(metadataURL string) (bool, error) - CreateSession(key SessionKey, value AuthenticatedUser) -} - -type AuthnRequest struct { - XMLName xml.Name `xml:"AuthnRequest"` - Text string `xml:",chardata"` - Samlp string `xml:"samlp,attr"` - Saml string `xml:"saml,attr"` - ID string `xml:"ID,attr"` - Version string `xml:"Version,attr"` - ProtocolBinding string `xml:"ProtocolBinding,attr"` - AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` - IssueInstant string `xml:"IssueInstant,attr"` - Destination string `xml:"Destination,attr"` - Issuer string `xml:"Issuer"` - Signature struct { - Text string `xml:",chardata"` - Ds string `xml:"ds,attr"` - SignedInfo struct { - Text string `xml:",chardata"` - CanonicalizationMethod struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"CanonicalizationMethod"` - SignatureMethod struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"SignatureMethod"` - Reference struct { - Text string `xml:",chardata"` - URI string `xml:"URI,attr"` - Transforms struct { - Text string `xml:",chardata"` - Transform []struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"Transform"` - } `xml:"Transforms"` - DigestMethod struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"DigestMethod"` - DigestValue string `xml:"DigestValue"` - } `xml:"Reference"` - } `xml:"SignedInfo"` - SignatureValue string `xml:"SignatureValue"` - KeyInfo struct { - Text string `xml:",chardata"` - X509Data struct { - Text string `xml:",chardata"` - X509Certificate string `xml:"X509Certificate"` - } `xml:"X509Data"` - } `xml:"KeyInfo"` - } `xml:"Signature"` - NameIDPolicy struct { - Text string `xml:",chardata"` - AllowCreate string `xml:"AllowCreate,attr"` - } `xml:"NameIDPolicy"` -} - -type Response struct { - XMLName xml.Name `xml:"Response"` - Text string `xml:",chardata"` - Saml string `xml:"xmlns:saml,attr"` - Samlp string `xml:"xmlns:samlp,attr"` - ID string `xml:"ID,attr"` - Version string `xml:"Version,attr"` - IssueInstant string `xml:"IssueInstant,attr"` - Destination string `xml:"Destination,attr"` - Issuer string `xml:"saml:Issuer"` - Signature ResponseSignature `xml:"ds:Signature"` - Status struct { - Text string `xml:",chardata"` - StatusCode struct { - Text string `xml:",chardata"` - Value string `xml:"Value,attr"` - } `xml:"StatusCode"` - } `xml:"Status"` - Assertion ResponseAssertion `xml:"Assertion"` -} - -type ResponseSignature struct { - XMLName xml.Name `xml:"ds:Signature"` - Text string `xml:",chardata"` - Ds string `xml:"xmlns:ds,attr"` - SignedInfo ResponseSignatureSignedInfo `xml:"ds:SignedInfo"` - SignatureValue string `xml:"ds:SignatureValue"` - KeyInfo ResponseSignatureKeyInfo `xml:"ds:KeyInfo"` -} - -type ResponseSignatureSignedInfo struct { - Text string `xml:",chardata"` - CanonicalizationMethod struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"ds:CanonicalizationMethod"` - SignatureMethod ResponseSignatureSignedInfoSignatureMethod `xml:"ds:SignatureMethod"` - Reference ResponseSignatureSignedInfoReference `xml:"ds:Reference"` -} -type ResponseSignatureSignedInfoSignatureMethod struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` -} -type ResponseSignatureSignedInfoReference struct { - Text string `xml:",chardata"` - URI string `xml:"URI,attr"` - Transforms struct { - Text string `xml:",chardata"` - Transform []struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"ds:Transform"` - } `xml:"ds:Transforms"` - DigestMethod struct { - Text string `xml:",chardata"` - Algorithm string `xml:"Algorithm,attr"` - } `xml:"ds:DigestMethod"` - DigestValue string `xml:"ds:DigestValue"` -} -type ResponseSignatureKeyInfo struct { - Text string `xml:",chardata"` - X509Data struct { - Text string `xml:",chardata"` - X509Certificate string `xml:"ds:X509Certificate"` - } `xml:"ds:X509Data"` -} - -type ResponseConditions struct { - Text string `xml:",chardata"` - NotBefore string `xml:"NotBefore,attr"` - NotOnOrAfter string `xml:"NotOnOrAfter,attr"` - AudienceRestriction ResponseConditionsAdienceRestriction `xml:"AudienceRestriction"` -} -type ResponseConditionsAdienceRestriction struct { - Text string `xml:",chardata"` - Audience string `xml:"Audience"` -} -type ResponseSubject struct { - Text string `xml:",chardata"` - NameID struct { - Text string `xml:",chardata"` - Format string `xml:"Format,attr"` - } `xml:"NameID"` - SubjectConfirmation struct { - Text string `xml:",chardata"` - Method string `xml:"Method,attr"` - SubjectConfirmationData struct { - Text string `xml:",chardata"` - NotOnOrAfter string `xml:"NotOnOrAfter,attr"` - Recipient string `xml:"Recipient,attr"` - } `xml:"SubjectConfirmationData"` - } `xml:"SubjectConfirmation"` -} - -type ResponseAssertion struct { - Text string `xml:",chardata"` - Saml string `xml:"saml,attr"` - Xs string `xml:"xs,attr"` - Xsi string `xml:"xsi,attr"` - Version string `xml:"Version,attr"` - ID string `xml:"ID,attr"` - IssueInstant string `xml:"IssueInstant,attr"` - Issuer string `xml:"Issuer"` - Subject ResponseSubject `xml:"Subject"` - Conditions ResponseConditions `xml:"Conditions"` - AuthnStatement struct { - Text string `xml:",chardata"` - AuthnInstant string `xml:"AuthnInstant,attr"` - SessionNotOnOrAfter string `xml:"SessionNotOnOrAfter,attr"` - SessionIndex string `xml:"SessionIndex,attr"` - AuthnContext struct { - Text string `xml:",chardata"` - AuthnContextClassRef string `xml:"AuthnContextClassRef"` - } `xml:"AuthnContext"` - } `xml:"AuthnStatement"` -} diff --git a/pkg/commands/resetmfa.go b/pkg/commands/resetmfa.go index 5de910a..2ab0427 100644 --- a/pkg/commands/resetmfa.go +++ b/pkg/commands/resetmfa.go @@ -3,9 +3,9 @@ package commands import ( "fmt" - "github.com/in4it/wireguard-server/pkg/rest" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" + "github.com/in4it/go-devops-platform/rest" + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" ) func ResetAdminMFA(storage storage.Iface) error { diff --git a/pkg/commands/resetpassword.go b/pkg/commands/resetpassword.go index 2e5efa2..4182328 100644 --- a/pkg/commands/resetpassword.go +++ b/pkg/commands/resetpassword.go @@ -3,9 +3,9 @@ package commands import ( "fmt" - "github.com/in4it/wireguard-server/pkg/rest" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" + "github.com/in4it/go-devops-platform/rest" + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" ) func ResetPassword(storage storage.Iface, password string) (bool, error) { diff --git a/pkg/commands/resetpassword_test.go b/pkg/commands/resetpassword_test.go index 45972fa..2d21690 100644 --- a/pkg/commands/resetpassword_test.go +++ b/pkg/commands/resetpassword_test.go @@ -3,8 +3,8 @@ package commands import ( "testing" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" - "github.com/in4it/wireguard-server/pkg/users" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" ) func TestResetPassword(t *testing.T) { diff --git a/pkg/configmanager/refresh_darwin.go b/pkg/configmanager/refresh_darwin.go index 6ffb396..a6ea07f 100644 --- a/pkg/configmanager/refresh_darwin.go +++ b/pkg/configmanager/refresh_darwin.go @@ -8,7 +8,7 @@ import ( "fmt" "os" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" ) diff --git a/pkg/configmanager/refresh_linux.go b/pkg/configmanager/refresh_linux.go index b8311c8..af6af8b 100644 --- a/pkg/configmanager/refresh_linux.go +++ b/pkg/configmanager/refresh_linux.go @@ -8,7 +8,7 @@ import ( "fmt" "os" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" syncclients "github.com/in4it/wireguard-server/pkg/wireguard/linux/syncclients" ) diff --git a/pkg/configmanager/server.go b/pkg/configmanager/server.go index 3ef2580..f661821 100644 --- a/pkg/configmanager/server.go +++ b/pkg/configmanager/server.go @@ -5,8 +5,8 @@ import ( "log" "net/http" - "github.com/in4it/wireguard-server/pkg/storage" - localstorage "github.com/in4it/wireguard-server/pkg/storage/local" + "github.com/in4it/go-devops-platform/storage" + localstorage "github.com/in4it/go-devops-platform/storage/local" "github.com/in4it/wireguard-server/pkg/wireguard" ) diff --git a/pkg/configmanager/setupcode.go b/pkg/configmanager/setupcode.go index 4d5bd23..755482c 100644 --- a/pkg/configmanager/setupcode.go +++ b/pkg/configmanager/setupcode.go @@ -7,7 +7,7 @@ import ( "io" "os/user" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" ) func writeSetupCode(storage storage.Iface) error { diff --git a/pkg/configmanager/start_darwin.go b/pkg/configmanager/start_darwin.go index 07ea124..f4af3cd 100644 --- a/pkg/configmanager/start_darwin.go +++ b/pkg/configmanager/start_darwin.go @@ -6,7 +6,7 @@ package configmanager import ( "fmt" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" ) diff --git a/pkg/configmanager/start_linux.go b/pkg/configmanager/start_linux.go index 033ee02..50c7bd5 100644 --- a/pkg/configmanager/start_linux.go +++ b/pkg/configmanager/start_linux.go @@ -6,7 +6,7 @@ package configmanager import ( "log" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" ) diff --git a/pkg/configmanager/types.go b/pkg/configmanager/types.go index 0546805..7a73730 100644 --- a/pkg/configmanager/types.go +++ b/pkg/configmanager/types.go @@ -1,7 +1,7 @@ package configmanager import ( - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" ) diff --git a/pkg/license/aws.go b/pkg/license/aws.go deleted file mode 100644 index 4b507e7..0000000 --- a/pkg/license/aws.go +++ /dev/null @@ -1,272 +0,0 @@ -package license - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -const AWS_PRODUCT_CODE = "7h7h3bnutjn0ziamv7npi8a69" - -func getMetadataToken(client http.Client) string { - metadataEndpoint := "http://" + MetadataIP + "/latest/api/token" - - req, err := http.NewRequest("PUT", metadataEndpoint, nil) - if err != nil { - return "" - } - - req.Header.Add("X-aws-ec2-metadata-token-ttl-seconds", "21600") - - resp, err := client.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - if resp.StatusCode == 200 { - bodyBytes, _ := io.ReadAll(resp.Body) - return string(bodyBytes) - } - return "" -} - -func isOnAWSMarketPlace(client http.Client) bool { - token := getMetadataToken(client) - - instanceIdentityDocument, err := getInstanceIdentityDocument(client, token) - if err != nil { - return false - } - for _, productCode := range instanceIdentityDocument.MarketplaceProductCodes { - if productCode == AWS_PRODUCT_CODE { - return true - } - } - return false -} - -func isOnAWS(client http.Client) bool { - token := getMetadataToken(client) - - instanceIdentityDocument, err := getInstanceIdentityDocument(client, token) - if err != nil { - return false - } - return instanceIdentityDocument.AccountID != "" || instanceIdentityDocument.Version != "" -} - -func getInstanceIdentityDocument(client http.Client, token string) (InstanceIdentityDocument, error) { - var instanceIdentityDocument InstanceIdentityDocument - - endpoint := "http://" + MetadataIP + "/2022-09-24/dynamic/instance-identity/document" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return instanceIdentityDocument, err - } - if token != "" { - req.Header.Add("X-aws-ec2-metadata-token", token) - } - - resp, err := client.Do(req) - if err != nil { - return instanceIdentityDocument, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - return instanceIdentityDocument, err - } - err = json.NewDecoder(resp.Body).Decode(&instanceIdentityDocument) - if err != nil { - return instanceIdentityDocument, err - } - - return instanceIdentityDocument, nil -} - -func GetMaxUsersAWSBYOL(client http.Client, storage storage.ReadWriter) int { - userLicense := 3 - licenseKey, err := getAWSLicenseKey(storage, client) - if err != nil { - return userLicense - } - license, err := getLicense(client, licenseKey) - if err != nil { - return userLicense - } - return license.Users -} - -func getAWSLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) { - token := getMetadataToken(client) - licenseKey, err := getLicenseFromMetaData(token, client) - if err != nil || licenseKey == "" { - licenseKey, err = getLicenseKeyFromFile(storage) - if err != nil { - return "", err - } - } - - instanceIdentityDocument, err := getInstanceIdentityDocument(client, token) - if err != nil { - return "", err - } - - return generateLicenseKey(licenseKey, instanceIdentityDocument.AccountID), nil -} - -func getLicense(client http.Client, key string) (License, error) { - var license License - endpoint := licenseURL + "/" + key - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return license, err - } - - resp, err := client.Do(req) - if err != nil { - return license, err - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - return license, fmt.Errorf("statuscode %d", resp.StatusCode) - } - err = json.NewDecoder(resp.Body).Decode(&license) - if err != nil { - return license, err - } - - return license, nil - -} - -func getLicenseFromMetaData(token string, client http.Client) (string, error) { - endpoint := "http://" + MetadataIP + "/2022-09-24/meta-data/tags/instance/license" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return "", err - } - if token != "" { - req.Header.Add("X-aws-ec2-metadata-token", token) - } - - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - return "", err - } - - bodyBytes, err := io.ReadAll(resp.Body) - if resp.StatusCode != 200 { - return "", err - } - return string(bodyBytes), nil -} - -func getAWSInstanceType(client http.Client) string { - token := getMetadataToken(client) - - endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-type" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return "" - } - if token != "" { - req.Header.Add("X-aws-ec2-metadata-token", token) - } - - resp, err := client.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - if resp.StatusCode == 200 { - bodyBytes, _ := io.ReadAll(resp.Body) - return string(bodyBytes) - } - return "" -} - -func GetAWSInstanceID(client http.Client) (string, error) { - token := getMetadataToken(client) - - endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-id" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return "", err - } - if token != "" { - req.Header.Add("X-aws-ec2-metadata-token", token) - } - - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - if resp.StatusCode == 200 { - bodyBytes, _ := io.ReadAll(resp.Body) - return string(bodyBytes), err - } - return "", fmt.Errorf("received statuscode %d from aws metadata api", resp.StatusCode) -} - -func GetMaxUsersAWS(instanceType string) int { - if instanceType == "" { - return 3 - } - if strings.HasSuffix(instanceType, ".nano") { - return 3 - } - if strings.HasSuffix(instanceType, ".micro") { - return 10 - } - if strings.HasSuffix(instanceType, ".small") { - return 25 - } - if strings.HasSuffix(instanceType, ".medium") { - return 50 - } - if strings.HasSuffix(instanceType, ".large") { - return 100 - } - if strings.HasSuffix(instanceType, ".xlarge") { - return 250 - } - if strings.HasSuffix(instanceType, ".2xlarge") { - return 500 - } - if strings.HasSuffix(instanceType, ".4xlarge") { - return 1000 - } - if strings.HasSuffix(instanceType, ".8xlarge") { - return 2500 - } - if strings.HasSuffix(instanceType, ".12xlarge") { - return 5000 - } - if strings.HasSuffix(instanceType, ".16xlarge") { - return 10000 - } - if strings.HasSuffix(instanceType, ".24xlarge") { - return 10000 - } - if strings.HasSuffix(instanceType, ".32xlarge") { - return 10000 - } - if strings.HasSuffix(instanceType, ".48xlarge") { - return 10000 - } - if strings.HasSuffix(instanceType, ".metal") { - return 10000 - } - - return 3 -} diff --git a/pkg/license/azure.go b/pkg/license/azure.go deleted file mode 100644 index 18ab9de..0000000 --- a/pkg/license/azure.go +++ /dev/null @@ -1,81 +0,0 @@ -package license - -import ( - "encoding/json" - "io" - "net/http" - "regexp" - "strconv" -) - -func isOnAzure(client http.Client) bool { - req, err := http.NewRequest("GET", "http://"+MetadataIP+"/metadata/versions", nil) - if err != nil { - return false - } - - req.Header.Add("Metadata", "true") - - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == 200 -} - -func GetMaxUsersAzure(instanceType string) int { - if instanceType == "" { - return 3 - } - // patterns - versionPattern := regexp.MustCompile(`^.*v[0-9]+#`) - cpuPattern := regexp.MustCompile("[0-9]+") - - // extract amount of CPUs - instanceTypeNoVersion := versionPattern.ReplaceAllString(instanceType, "") - - instanceTypeCPUs := cpuPattern.FindAllString(instanceTypeNoVersion, -1) - - if len(instanceTypeCPUs) > 0 { - instanceTypeCPUCount, err := strconv.Atoi(instanceTypeCPUs[0]) - if err != nil { - return 3 - } - if instanceTypeCPUCount == 0 { - return 15 - } - return instanceTypeCPUCount * 25 - } - - return 3 -} -func getAzureInstanceType(client http.Client) string { - metadataEndpoint := "http://" + MetadataIP + "/metadata/instance?api-version=2021-02-01" - req, err := http.NewRequest("GET", metadataEndpoint, nil) - if err != nil { - return "" - } - - req.Header.Add("Metadata", "true") - - resp, err := client.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - return "" - } - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return "" - } - var instanceMetadata AzureInstanceMetadata - err = json.Unmarshal(bodyBytes, &instanceMetadata) - if err != nil { - return "" - } - return instanceMetadata.Compute.VMSize -} diff --git a/pkg/license/digitalocean.go b/pkg/license/digitalocean.go deleted file mode 100644 index 5be4c07..0000000 --- a/pkg/license/digitalocean.go +++ /dev/null @@ -1,121 +0,0 @@ -package license - -import ( - "bufio" - "fmt" - "io" - "net/http" - "strings" - - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" -) - -func isOnDigitalOcean(client http.Client) bool { - endpoint := "http://" + MetadataIP + "/metadata/v1/interfaces/private/0/type" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return false - } - - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == 200 -} - -func GetMaxUsersDigitalOceanBYOL(client http.Client, storage storage.ReadWriter) int { - userLicense := 3 - - licenseKey, err := getDigitalOceanLicenseKey(storage, client) - if err != nil { - logging.DebugLog(fmt.Errorf("get digitalocean license error: %s", err)) - return userLicense - } - - license, err := getLicense(client, licenseKey) - if err != nil { - logging.DebugLog(fmt.Errorf("getLicense error: %s", err)) - return userLicense - } - - return license.Users -} - -func getDigitalOceanLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) { - identifier, err := getDigitalOceanIdentifier(client) - if err != nil { - logging.DebugLog(fmt.Errorf("License generation error (identifier error): %s", err)) - return "", err - } - - licenseKey, err := getLicenseKeyFromFile(storage) - if err != nil { - return "", err - } - - return generateLicenseKey(licenseKey, identifier), nil -} - -func getDigitalOceanIdentifier(client http.Client) (string, error) { - id := "" - endpoint := "http://" + MetadataIP + "/metadata/v1/id" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return id, err - } - - resp, err := client.Do(req) - if err != nil { - return id, err - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return id, err - } - if resp.StatusCode != 200 { - return id, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body) - } - - return strings.TrimSpace(string(body)), nil - -} - -func HasDigitalOceanTagSet(client http.Client, tag string) (bool, error) { - endpoint := "http://" + MetadataIP + "/metadata/v1/tags" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return false, err - } - - resp, err := client.Do(req) - if err != nil { - return false, err - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) - if err != nil { - return false, err - } - return false, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body) - } - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - if tag == strings.TrimSpace(scanner.Text()) { - return true, nil - } - } - - if err := scanner.Err(); err != nil { - return false, err - } - - return false, nil - -} diff --git a/pkg/license/gcp.go b/pkg/license/gcp.go deleted file mode 100644 index 5e69c54..0000000 --- a/pkg/license/gcp.go +++ /dev/null @@ -1,88 +0,0 @@ -package license - -import ( - "fmt" - "io" - "net/http" - "strings" - - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" -) - -func isOnGCP(client http.Client) bool { - endpoint := "http://" + MetadataIP + "/computeMetadata/v1/" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return false - } - - req.Header.Add("Metadata-Flavor", "Google") - - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == 200 -} - -func GetMaxUsersGCPBYOL(client http.Client, storage storage.ReadWriter) int { - userLicense := 3 - - licenseKey, err := getGCPLicenseKey(storage, client) - if err != nil { - logging.DebugLog(fmt.Errorf("get gcp license error: %s", err)) - return userLicense - } - - license, err := getLicense(client, licenseKey) - if err != nil { - logging.DebugLog(fmt.Errorf("getLicense error: %s", err)) - return userLicense - } - - return license.Users -} - -func getGCPLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) { - identifier, err := getGCPIdentifier(client) - if err != nil { - logging.DebugLog(fmt.Errorf("License generation error (identifier error): %s", err)) - return "", err - } - - licenseKey, err := getLicenseKeyFromFile(storage) - if err != nil { - return "", err - } - - return generateLicenseKey(licenseKey, identifier), nil -} - -func getGCPIdentifier(client http.Client) (string, error) { - id := "" - endpoint := "http://" + MetadataIP + "/computeMetadata/v1/project/project-id" - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return id, err - } - - req.Header.Add("Metadata-Flavor", "Google") - - resp, err := client.Do(req) - if err != nil { - return id, err - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return id, err - } - if resp.StatusCode != 200 { - return id, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body) - } - - return strings.TrimSpace(string(body)), nil - -} diff --git a/pkg/license/gcp_test.go b/pkg/license/gcp_test.go deleted file mode 100644 index f52b047..0000000 --- a/pkg/license/gcp_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package license - -import ( - "crypto/sha256" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestGuessInfrastructureGCP(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/computeMetadata/v1/" { - w.WriteHeader(http.StatusOK) - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - infra := guessInfrastructure() - - if infra != "gcp" { - t.Fatalf("wrong infra returned: %s", infra) - } -} - -func TestGetMaxUsersGCPBYOL(t *testing.T) { - projectID := "gcpproject-1234567890" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/computeMetadata/v1/project/project-id" { - w.Write([]byte(projectID)) - - return - } - h := sha256.New() - h.Write([]byte(projectID)) - if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) { - w.Write([]byte(`{"users": 50}`)) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - licenseURL = ts.URL - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - mockStorage := &memorystorage.MockMemoryStorage{} - err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) - if err != nil { - t.Fatalf("writefile error: %s", err) - } - - for _, v := range []int{50} { - if v2 := GetMaxUsersGCPBYOL(http.Client{Timeout: 5 * time.Second}, mockStorage); v2 != v { - t.Fatalf("Wrong output: %d vs %d", v2, v) - } - } -} diff --git a/pkg/license/license.go b/pkg/license/license.go deleted file mode 100644 index b3cfb77..0000000 --- a/pkg/license/license.go +++ /dev/null @@ -1,167 +0,0 @@ -package license - -import ( - "crypto/sha256" - "fmt" - "net/http" - "time" - - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" - randomutils "github.com/in4it/wireguard-server/pkg/utils/random" -) - -var MetadataIP = "169.254.169.254" -var licenseURL = "https://in4it-vpn-server.s3.amazonaws.com/licenses" - -func guessInfrastructure() string { - // check whether we are on AWS, Azure, DigitalOcean or something undefined - client := http.Client{ - Timeout: 5 * time.Second, - } - - if isOnAWSMarketPlace(client) { - return "aws-marketplace" - } - - if isOnAWS(client) { - return "aws" - } - - if isOnAzure(client) { - return "azure" - } - - if isOnDigitalOcean(client) { - return "digitalocean" - } - - if isOnGCP(client) { - return "gcp" - } - - return "" // no metadata server found -} - -func GetInstanceType() (string, string) { - client := http.Client{ - Timeout: 5 * time.Second, - } - switch guessInfrastructure() { - case "azure": - return "azure", getAzureInstanceType(client) - case "aws-marketplace": - return "aws-marketplace", getAWSInstanceType(client) - case "aws": - return "aws", getAWSInstanceType(client) - case "digitalocean": - return "digitalocean", "droplet" - case "gcp": - return "gcp", "instance" - default: - return "", "" - } -} -func GetMaxUsers(storage storage.ReadWriter) (int, string) { - cloudType, instanceType := GetInstanceType() - return getMaxUsers(storage, cloudType, instanceType), cloudType -} -func getMaxUsers(storage storage.ReadWriter, cloudType, instanceType string) int { - switch cloudType { - case "azure": - return GetMaxUsersAzure(instanceType) - case "aws-marketplace": - return GetMaxUsersAWS(instanceType) - case "aws": - client := http.Client{ - Timeout: 5 * time.Second, - } - return GetMaxUsersAWSBYOL(client, storage) - case "digitalocean": - client := http.Client{ - Timeout: 5 * time.Second, - } - return GetMaxUsersDigitalOceanBYOL(client, storage) - case "": - client := http.Client{ - Timeout: 5 * time.Second, - } - return GetMaxUsersBYOLNoCloud(client, storage) - default: - return 3 - } -} - -func RefreshLicense(storage storage.ReadWriter, cloudType string, currentLicense int) int { - if cloudType == "azure" || cloudType == "aws-marketplace" { // instance types / license is not going to change without a restart - return currentLicense - } - cloudType, instanceType := GetInstanceType() - return getMaxUsers(storage, cloudType, instanceType) -} - -func GetLicenseKey(storage storage.ReadWriter, cloudType string) string { - client := http.Client{ - Timeout: 5 * time.Second, - } - switch cloudType { - case "aws": - licenseKey, err := getAWSLicenseKey(storage, client) - if err != nil { - logging.DebugLog(fmt.Errorf("getAWSLicense error: %s", err)) - return "" - } - return licenseKey - case "digitalocean": - licenseKey, err := getDigitalOceanLicenseKey(storage, client) - if err != nil { - logging.DebugLog(fmt.Errorf("getDigitalOceanLicense error: %s", err)) - return "" - } - return licenseKey - case "gcp": - licenseKey, err := getGCPLicenseKey(storage, client) - if err != nil { - logging.DebugLog(fmt.Errorf("getGCPLicenseKey error: %s", err)) - return "" - } - return licenseKey - default: - licenseKey, err := getLicenseKeyFromFile(storage) - if err != nil { - logging.DebugLog(fmt.Errorf("getLicenseKeyFromFile error: %s", err)) - return "" - } - return licenseKey - } - -} - -func generateLicenseKey(key string, identifier string) string { - h := sha256.New() - h.Write([]byte(identifier)) - bs := h.Sum(nil) - - return key + "-" + fmt.Sprintf("%x", bs) -} - -func getLicenseKeyFromFile(storage storage.ReadWriter) (string, error) { - filename := storage.ConfigPath("license.key") - - if storage.FileExists(filename) { - licenseKeyBytes, err := storage.ReadFile(filename) - if err != nil { - return "", fmt.Errorf("License read error: %s", err) - } - return string(licenseKeyBytes), nil - } - key, err := randomutils.GetRandomString(128) - if err != nil { - return "", fmt.Errorf("License generation error: %s", err) - } - err = storage.WriteFile(filename, []byte(key)) - if err != nil { - return "", fmt.Errorf("License read error: %s", err) - } - return key, nil -} diff --git a/pkg/license/license_test.go b/pkg/license/license_test.go deleted file mode 100644 index 87172fa..0000000 --- a/pkg/license/license_test.go +++ /dev/null @@ -1,413 +0,0 @@ -package license - -import ( - "crypto/sha256" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/in4it/wireguard-server/pkg/logging" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestGetMaxUsersAzure(t *testing.T) { - usersPerVCPU := 25 - testCases := map[string]int{ - "Standard_B1s": usersPerVCPU, - "Basic_A0": 15, - "Standard_D1_v2": usersPerVCPU, - "Standard_D5_v2": usersPerVCPU * 5, - "D96as_v6": usersPerVCPU * 96, - "Standard_D16pls_v5": usersPerVCPU * 16, - "Standard_DC1s_v3": usersPerVCPU * 1, - } - for k, v := range testCases { - if GetMaxUsersAzure(k) != v { - t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAzure(k), v) - } - } -} - -func TestGetMaxUsersAWSMarketplace(t *testing.T) { - testCases := map[string]int{ - "t3.medium": 50, - "t3.large": 100, - "t3.xlarge": 250, - } - for instanceType, v := range testCases { - if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws-marketplace", instanceType) != v { - t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v) - } - } -} - -func TestGetMaxUsersAWSBYOL(t *testing.T) { - accountID := "1234567890" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/versions" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.RequestURI == "/latest/api/token" { - w.Write([]byte("this is a test token")) - return - } - if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" { - w.Write([]byte(`{ - "accountId" : "` + accountID + `", - "architecture" : "x86_64", - "availabilityZone" : "us-east-1c", - "billingProducts" : null, - "devpayProductCodes" : null, - "marketplaceProductCodes" : [ "7h7h3bnutjn0ziamv7npi8a69" ], - "imageId" : "ami-12345678", - "instanceId" : "i-123456", - "instanceType" : "t3.micro", - "kernelId" : null, - "pendingTime" : "2024-06-15T08:34:50Z", - "privateIp" : "10.0.1.123", - "ramdiskId" : null, - "region" : "us-east-1", - "version" : "2017-09-30" -}`)) - - return - } - if r.RequestURI == "/2022-09-24/meta-data/tags/instance/license" { - w.Write([]byte(`license-1234556-license`)) - return - } - h := sha256.New() - h.Write([]byte(accountID)) - if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) { - w.Write([]byte(`{"users": 50}`)) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - testCases := map[string]int{ - "t3.medium": 50, - "t3.xlarge": 50, - } - licenseURL = ts.URL - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - for _, v := range testCases { - if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &memorystorage.MockMemoryStorage{}); v2 != v { - t.Fatalf("Wrong output: %d vs %d", v2, v) - } - } -} - -func TestGetMaxUsersAWS(t *testing.T) { - testCases := map[string]int{ - "t3.medium": 3, - "t3.large": 3, - "t3.xlarge": 3, - } - for instanceType, v := range testCases { - if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws", instanceType) != v { - t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v) - } - } -} - -func TestGuessInfrastructureAzure(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/versions" { - w.WriteHeader(http.StatusOK) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - infra := guessInfrastructure() - - if infra != "azure" { - t.Fatalf("wrong infra returned: %s", infra) - } -} - -func TestGuessInfrastructureAWSMarketplace(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/versions" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.RequestURI == "/latest/api/token" { - w.Write([]byte("this is a test token")) - return - } - if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" { - w.Write([]byte(`{ - "accountId" : "12345678", - "architecture" : "x86_64", - "availabilityZone" : "us-east-1c", - "billingProducts" : null, - "devpayProductCodes" : null, - "marketplaceProductCodes" : [ "7h7h3bnutjn0ziamv7npi8a69" ], - "imageId" : "ami-12345678", - "instanceId" : "i-123456", - "instanceType" : "t3.micro", - "kernelId" : null, - "pendingTime" : "2024-06-15T08:34:50Z", - "privateIp" : "10.0.1.123", - "ramdiskId" : null, - "region" : "us-east-1", - "version" : "2017-09-30" -}`)) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - infra := guessInfrastructure() - - if infra != "aws-marketplace" { - t.Fatalf("wrong infra returned: %s", infra) - } -} - -func TestGuessInfrastructureAWS(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/versions" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.RequestURI == "/latest/api/token" { - w.Write([]byte("this is a test token")) - return - } - if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" { - w.Write([]byte(`{ - "accountId" : "12345678", - "architecture" : "x86_64", - "availabilityZone" : "us-east-1c", - "billingProducts" : null, - "devpayProductCodes" : null, - "marketplaceProductCodes" : null -}`)) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - infra := guessInfrastructure() - - if infra != "aws" { - t.Fatalf("wrong infra returned: %s", infra) - } - - if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 { - t.Fatalf("wrong users returned") - } -} - -func TestGuessInfrastructureOther(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - infra := guessInfrastructure() - - if infra != "" { - t.Fatalf("wrong infra returned: %s", infra) - } -} - -func TestGetAzureInstanceType(t *testing.T) { - vmSize := "Standard_D2_v5" - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - out, err := json.Marshal(AzureInstanceMetadata{ - Compute: Compute{ - VMSize: vmSize, - }, - }) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - } - w.Write(out) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - usersPerVCPU := 25 - - users := getMaxUsers(&memorystorage.MockMemoryStorage{}, "azure", getAzureInstanceType(http.Client{Timeout: 5 * time.Second})) - - if users != usersPerVCPU*2 { - t.Fatalf("Wrong user count returned") - } -} - -func TestGetAWSInstanceType(t *testing.T) { - instanceType := "t4.xlarge" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/latest/api/token" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.RequestURI == "/latest/meta-data/instance-type" { - w.Write([]byte(instanceType)) - return - } - w.WriteHeader(http.StatusInternalServerError) - - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - users := GetMaxUsersAWS(getAWSInstanceType(http.Client{Timeout: 5 * time.Second})) - - if users != 250 { - t.Fatalf("Wrong user count returned: %d", users) - } -} - -func TestGuessInfrastructureDigitalOcean(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/v1/interfaces/private/0/type" { - w.Write([]byte(`private`)) - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer ts.Close() - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - infra := guessInfrastructure() - - if infra != "digitalocean" { - t.Fatalf("wrong infra returned: %s", infra) - } - - if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 { - t.Fatalf("wrong users returned") - } -} - -func TestGetMaxUsersDigitalOceanBYOL(t *testing.T) { - dropletID := "1234567890" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/v1/interfaces/private/0/type" { - w.Write([]byte(`private`)) - return - } - if r.RequestURI == "/metadata/v1/id" { - w.Write([]byte(dropletID)) - - return - } - h := sha256.New() - h.Write([]byte(dropletID)) - if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) { - w.Write([]byte(`{"users": 50}`)) - return - } - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - licenseURL = ts.URL - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - mockStorage := &memorystorage.MockMemoryStorage{} - err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) - if err != nil { - t.Fatalf("writefile error: %s", err) - } - for _, v := range []int{50} { - if v2 := GetMaxUsersDigitalOceanBYOL(http.Client{Timeout: 5 * time.Second}, mockStorage); v2 != v { - t.Fatalf("Wrong output: %d vs %d", v2, v) - } - } -} - -func TestGetLicenseKey(t *testing.T) { - dropletID := "1234567890" - projectID := "googleproject-12356" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/v1/interfaces/private/0/type" { - w.Write([]byte(`private`)) - return - } - if r.RequestURI == "/metadata/v1/id" { - w.Write([]byte(dropletID)) - return - } - if r.RequestURI == "/computeMetadata/v1/project/project-id" { - w.Write([]byte(projectID)) - return - } - if r.RequestURI == "/metadata/versions" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.RequestURI == "/latest/api/token" { - w.Write([]byte("this is a test token")) - return - } - if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" { - w.Write([]byte(`{ - "accountId" : "12345678", - "architecture" : "x86_64", - "availabilityZone" : "us-east-1c", - "billingProducts" : null, - "devpayProductCodes" : null, - "marketplaceProductCodes" : null -}`)) - return - } - - w.WriteHeader(http.StatusNotFound) - })) - - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - - logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR - key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") - if key == "" { - t.Fatalf("key is empty") - } - key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "aws") - if key == "" { - t.Fatalf("aws key is empty") - } - key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "digitalocean") - if key == "" { - t.Fatalf("digitalocean key is empty") - } - key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "gcp") - if key == "" { - t.Fatalf("gcp key is empty") - } -} -func TestGetLicenseKeyNoCloudProvider(t *testing.T) { - - logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR - key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") - if key == "" { - t.Fatalf("key is empty") - } -} diff --git a/pkg/license/nocloud.go b/pkg/license/nocloud.go deleted file mode 100644 index 0b6d036..0000000 --- a/pkg/license/nocloud.go +++ /dev/null @@ -1,26 +0,0 @@ -package license - -import ( - "fmt" - "net/http" - - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" -) - -func GetMaxUsersBYOLNoCloud(client http.Client, storage storage.ReadWriter) int { - userLicense := 3 - - licenseKey, err := getLicenseKeyFromFile(storage) - if err != nil { - return 3 - } - - license, err := getLicense(client, licenseKey) - if err != nil { - logging.DebugLog(fmt.Errorf("getLicense error: %s", err)) - return userLicense - } - - return license.Users -} diff --git a/pkg/license/types.go b/pkg/license/types.go deleted file mode 100644 index 9b96c10..0000000 --- a/pkg/license/types.go +++ /dev/null @@ -1,150 +0,0 @@ -package license - -import "time" - -type AzureInstanceMetadata struct { - Compute Compute `json:"compute"` - Network Network `json:"network"` -} -type OsProfile struct { - AdminUsername string `json:"adminUsername"` - ComputerName string `json:"computerName"` - DisablePasswordAuthentication string `json:"disablePasswordAuthentication"` -} -type Plan struct { - Name string `json:"name"` - Product string `json:"product"` - Publisher string `json:"publisher"` -} -type PublicKeys struct { - KeyData string `json:"keyData"` - Path string `json:"path"` -} -type SecurityProfile struct { - SecureBootEnabled string `json:"secureBootEnabled"` - VirtualTpmEnabled string `json:"virtualTpmEnabled"` -} -type ImageReference struct { - ID string `json:"id"` - Offer string `json:"offer"` - Publisher string `json:"publisher"` - Sku string `json:"sku"` - Version string `json:"version"` -} -type DiffDiskSettings struct { - Option string `json:"option"` -} -type EncryptionSettings struct { - Enabled string `json:"enabled"` -} -type Image struct { - URI string `json:"uri"` -} -type ManagedDisk struct { - ID string `json:"id"` - StorageAccountType string `json:"storageAccountType"` -} -type Vhd struct { - URI string `json:"uri"` -} -type OsDisk struct { - Caching string `json:"caching"` - CreateOption string `json:"createOption"` - DiffDiskSettings DiffDiskSettings `json:"diffDiskSettings"` - DiskSizeGB string `json:"diskSizeGB"` - EncryptionSettings EncryptionSettings `json:"encryptionSettings"` - Image Image `json:"image"` - ManagedDisk ManagedDisk `json:"managedDisk"` - Name string `json:"name"` - OsType string `json:"osType"` - Vhd Vhd `json:"vhd"` - WriteAcceleratorEnabled string `json:"writeAcceleratorEnabled"` -} -type ResourceDisk struct { - Size string `json:"size"` -} -type StorageProfile struct { - DataDisks []any `json:"dataDisks"` - ImageReference ImageReference `json:"imageReference"` - OsDisk OsDisk `json:"osDisk"` - ResourceDisk ResourceDisk `json:"resourceDisk"` -} -type Compute struct { - AzEnvironment string `json:"azEnvironment"` - CustomData string `json:"customData"` - EvictionPolicy string `json:"evictionPolicy"` - IsHostCompatibilityLayerVM string `json:"isHostCompatibilityLayerVm"` - LicenseType string `json:"licenseType"` - Location string `json:"location"` - Name string `json:"name"` - Offer string `json:"offer"` - OsProfile OsProfile `json:"osProfile"` - OsType string `json:"osType"` - PlacementGroupID string `json:"placementGroupId"` - Plan Plan `json:"plan"` - PlatformFaultDomain string `json:"platformFaultDomain"` - PlatformUpdateDomain string `json:"platformUpdateDomain"` - Priority string `json:"priority"` - Provider string `json:"provider"` - PublicKeys []PublicKeys `json:"publicKeys"` - Publisher string `json:"publisher"` - ResourceGroupName string `json:"resourceGroupName"` - ResourceID string `json:"resourceId"` - SecurityProfile SecurityProfile `json:"securityProfile"` - Sku string `json:"sku"` - StorageProfile StorageProfile `json:"storageProfile"` - SubscriptionID string `json:"subscriptionId"` - Tags string `json:"tags"` - TagsList []any `json:"tagsList"` - UserData string `json:"userData"` - Version string `json:"version"` - VMID string `json:"vmId"` - VMScaleSetName string `json:"vmScaleSetName"` - VMSize string `json:"vmSize"` - Zone string `json:"zone"` -} -type IPAddress struct { - PrivateIPAddress string `json:"privateIpAddress"` - PublicIPAddress string `json:"publicIpAddress"` -} -type Subnet struct { - Address string `json:"address"` - Prefix string `json:"prefix"` -} -type Ipv4 struct { - IPAddress []IPAddress `json:"ipAddress"` - Subnet []Subnet `json:"subnet"` -} -type Ipv6 struct { - IPAddress []any `json:"ipAddress"` -} -type Interface struct { - Ipv4 Ipv4 `json:"ipv4"` - Ipv6 Ipv6 `json:"ipv6"` - MacAddress string `json:"macAddress"` -} -type Network struct { - Interface []Interface `json:"interface"` -} - -type InstanceIdentityDocument struct { - AccountID string `json:"accountId"` - Architecture string `json:"architecture"` - AvailabilityZone string `json:"availabilityZone"` - BillingProducts any `json:"billingProducts"` - DevpayProductCodes any `json:"devpayProductCodes"` - MarketplaceProductCodes []string `json:"marketplaceProductCodes"` - ImageID string `json:"imageId"` - InstanceID string `json:"instanceId"` - InstanceType string `json:"instanceType"` - KernelID any `json:"kernelId"` - PendingTime time.Time `json:"pendingTime"` - PrivateIP string `json:"privateIp"` - RamdiskID any `json:"ramdiskId"` - Region string `json:"region"` - Version string `json:"version"` -} - -type License struct { - Users int `json:"users"` -} diff --git a/pkg/logging/log.go b/pkg/logging/log.go deleted file mode 100644 index 2b0f555..0000000 --- a/pkg/logging/log.go +++ /dev/null @@ -1,27 +0,0 @@ -package logging - -import "fmt" - -var Loglevel = 3 - -const LOG_ERROR = 1 -const LOG_INFO = 2 -const LOG_DEBUG = 16 - -func DebugLog(err error) { - if Loglevel&LOG_DEBUG == LOG_DEBUG { - fmt.Println("debug: " + err.Error()) - } -} - -func ErrorLog(err error) { - if Loglevel&LOG_ERROR == LOG_ERROR { - fmt.Println("error: " + err.Error()) - } -} - -func InfoLog(info string) { - if Loglevel&LOG_INFO == LOG_INFO { - fmt.Println("info: " + info) - } -} diff --git a/pkg/logging/log_test.go b/pkg/logging/log_test.go deleted file mode 100644 index 09a0ca8..0000000 --- a/pkg/logging/log_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package logging - -import "testing" - -func TestLog(t *testing.T) { - Loglevel = LOG_DEBUG - if Loglevel&LOG_DEBUG != LOG_DEBUG { - t.Fatalf("log level is not debugging1") - } - Loglevel = LOG_DEBUG + LOG_ERROR - if Loglevel&LOG_DEBUG != LOG_DEBUG { - t.Fatalf("log level is not debugging: %d vs %d", Loglevel&LOG_DEBUG, Loglevel) - } - Loglevel = LOG_ERROR - if Loglevel&LOG_DEBUG == LOG_DEBUG { - t.Fatalf("log level is debugging: %d vs %d", Loglevel&LOG_DEBUG, Loglevel) - } - -} diff --git a/pkg/mfa/totp/verify.go b/pkg/mfa/totp/verify.go deleted file mode 100644 index 651afb3..0000000 --- a/pkg/mfa/totp/verify.go +++ /dev/null @@ -1,62 +0,0 @@ -package totp - -import ( - "bytes" - "crypto/hmac" - "crypto/sha1" - "encoding/base32" - "encoding/binary" - "fmt" - "strings" - "time" -) - -const INTERVAL = 30 - -func GetToken(secret string, interval int64) (string, error) { - key, err := base32.StdEncoding.DecodeString(strings.ToUpper(secret)) - if err != nil { - return "", fmt.Errorf("base32 decode error: %s", err) - } - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(interval)) - hmacHash := hmac.New(sha1.New, key) - hmacHash.Write(buf) - h := hmacHash.Sum(nil) - offset := (h[19] & 15) - - var header uint32 - r := bytes.NewReader(h[offset : offset+4]) - err = binary.Read(r, binary.BigEndian, &header) - - if err != nil { - return "", fmt.Errorf("binary read error: %s", err) - } - - return fmt.Sprintf("%06d", int((int(header)&0x7fffffff)%1000000)), nil -} - -func Verify(secret, code string) (bool, error) { - token, err := GetToken(secret, time.Now().Unix()/30) - if err != nil { - return false, fmt.Errorf("GetToken error: %s", err) - } - return token == code, nil -} - -func VerifyMultipleIntervals(secret, code string, count int) (bool, error) { - return verifyMultipleIntervals(secret, code, count, time.Now()) -} - -func verifyMultipleIntervals(secret, code string, count int, now time.Time) (bool, error) { - for i := 0; i < count; i++ { - token, err := GetToken(secret, now.Add(time.Duration(i)*time.Duration(-30)*time.Second).Unix()/30) - if err != nil { - return false, fmt.Errorf("GetToken error: %s", err) - } - if token == code { - return true, nil - } - } - return false, nil -} diff --git a/pkg/mfa/totp/verify_test.go b/pkg/mfa/totp/verify_test.go deleted file mode 100644 index 5c52922..0000000 --- a/pkg/mfa/totp/verify_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package totp - -import ( - "testing" - "time" -) - -func TestVerify(t *testing.T) { // validated with https://2fa.glitch.me/ - interval := int64(57275699) - secret := "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ" - token, err := GetToken(secret, interval) - if err != nil { - t.Fatalf("error: %s", err) - } - if token != "840823" { - t.Fatalf("wrong token. Got: %s", token) - } -} - -func TestVerifyWrongSecret(t *testing.T) { - interval := int64(57275699) - secret := "wrong secret" - _, err := GetToken(secret, interval) - if err == nil { - t.Fatalf("expected error") - } -} - -func TestVerifyMultipleIntervals(t *testing.T) { - secret := "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ" - ok, err := verifyMultipleIntervals(secret, "312137", 20, time.Unix(1718272397, 0)) - if err != nil { - t.Fatalf("error: %s", err) - } - if !ok { - t.Fatalf("no token matched") - } -} - -func TestVerifyMultipleIntervalsWrongToken(t *testing.T) { - secret := "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ" - ok, err := verifyMultipleIntervals(secret, "312137", 20, time.Unix(1718272000, 0)) - if err != nil { - t.Fatalf("error: %s", err) - } - if ok { - t.Fatalf("token matched, but shouldn't have") - } -} diff --git a/pkg/observability/buffer.go b/pkg/observability/buffer.go deleted file mode 100644 index 3cc1fa0..0000000 --- a/pkg/observability/buffer.go +++ /dev/null @@ -1,178 +0,0 @@ -package observability - -import ( - "bytes" - "fmt" - "io" - "path" - "strconv" - "strings" - "time" - - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" -) - -func (o *Observability) WriteBufferToStorage(n int64) error { - o.ActiveBufferWriters.Add(1) - defer o.ActiveBufferWriters.Done() - o.WriteLock.Lock() - defer o.WriteLock.Unlock() - logging.DebugLog(fmt.Errorf("writing buffer to file. Buffer has: %d bytes", n)) - // copy first to temporary buffer (storage might have latency) - tempBuf := bytes.NewBuffer(make([]byte, 0, n)) - _, err := io.CopyN(tempBuf, o.Buffer, n) - if err != nil && err != io.EOF { - return fmt.Errorf("write error from buffer to temporary buffer: %s", err) - } - prefix := o.Buffer.ReadPrefix(n) - o.LastFlushed = time.Now() - - for _, bufferPosAndPrefix := range mergeBufferPosAndPrefix(prefix) { - now := time.Now() - filename := bufferPosAndPrefix.prefix + "/data-" + strconv.FormatInt(now.Unix(), 10) + "-" + strconv.FormatUint(o.FlushOverflowSequence.Add(1), 10) - err = ensurePath(o.Storage, filename) - if err != nil { - return fmt.Errorf("ensure path error: %s", err) - } - file, err := o.Storage.OpenFileForWriting(filename) - if err != nil { - return fmt.Errorf("open file for writing error: %s", err) - } - _, err = io.CopyN(file, tempBuf, int64(bufferPosAndPrefix.offset)) - if err != nil && err != io.EOF { - return fmt.Errorf("file write error: %s", err) - } - logging.DebugLog(fmt.Errorf("wrote file: %s", filename)) - err = file.Close() - if err != nil { - return fmt.Errorf("file close error: %s", err) - } - } - return nil -} - -func (o *Observability) monitorBuffer() { - for { - time.Sleep(FLUSH_TIME_MAX_MINUTES * time.Minute) - if time.Since(o.LastFlushed) >= (FLUSH_TIME_MAX_MINUTES*time.Minute) && o.Buffer.Len() > 0 { - if o.FlushOverflow.CompareAndSwap(false, true) { - err := o.WriteBufferToStorage(int64(o.Buffer.Len())) - o.FlushOverflow.Swap(true) - if err != nil { - logging.ErrorLog(fmt.Errorf("write log buffer to storage error: %s", err)) - continue - } - } - o.LastFlushed = time.Now() - } - } -} - -func (o *Observability) Ingest(data io.ReadCloser) error { - defer data.Close() - msgs, err := Decode(data) - if err != nil { - return fmt.Errorf("decode error: %s", err) - } - logging.DebugLog(fmt.Errorf("messages ingested: %d", len(msgs))) - if len(msgs) == 0 { - return nil // no messages to ingest - } - _, err = o.Buffer.Write(encodeMessage(msgs), FloatToDate(msgs[0].Date).Format(DATE_PREFIX)) - if err != nil { - return fmt.Errorf("write error: %s", err) - } - if o.Buffer.Len() >= o.MaxBufferSize { - if o.FlushOverflow.CompareAndSwap(false, true) { - go func() { // write to storage - if n := o.Buffer.Len(); n >= o.MaxBufferSize { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - logging.ErrorLog(fmt.Errorf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)) - } - } - o.FlushOverflow.Swap(false) - }() - } - } - return nil -} - -func (o *Observability) Flush() error { - // wait until all data is flushed - o.ActiveBufferWriters.Wait() - - // flush remaining data that hasn't been flushed - if n := o.Buffer.Len(); n >= 0 { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - return fmt.Errorf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) - } - } - return nil -} - -func (c *ConcurrentRWBuffer) Write(p []byte, prefix string) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() - c.prefix = append(c.prefix, BufferPosAndPrefix{prefix: prefix, offset: len(p)}) - return c.buffer.Write(p) -} -func (c *ConcurrentRWBuffer) Read(p []byte) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() - return c.buffer.Read(p) -} -func (c *ConcurrentRWBuffer) ReadPrefix(n int64) []BufferPosAndPrefix { - c.mu.Lock() - defer c.mu.Unlock() - totalOffset := 0 - for k, v := range c.prefix { - if int64(totalOffset+v.offset) == n { - part1 := c.prefix[:k+1] - part2 := make([]BufferPosAndPrefix, len(c.prefix[k+1:])) - copy(part2, c.prefix[k+1:]) - c.prefix = part2 - return part1 - } - totalOffset += v.offset - } - return nil -} -func (c *ConcurrentRWBuffer) Len() int { - return c.buffer.Len() -} -func (c *ConcurrentRWBuffer) Cap() int { - return c.buffer.Cap() -} - -func ensurePath(storage storage.Iface, filename string) error { - base := path.Dir(filename) - baseSplit := strings.Split(base, "/") - fullPath := "" - for _, v := range baseSplit { - fullPath = path.Join(fullPath, v) - err := storage.EnsurePath(fullPath) - if err != nil { - return err - } - } - return nil -} - -func mergeBufferPosAndPrefix(a []BufferPosAndPrefix) []BufferPosAndPrefix { - bufferPosAndPrefix := []BufferPosAndPrefix{} - for i := 0; i < len(a); i++ { - offset := a[i].offset - for y := i; y+1 < len(a) && a[y].prefix == a[y+1].prefix; y++ { - offset += a[y+1].offset - i++ - } - bufferPosAndPrefix = append(bufferPosAndPrefix, BufferPosAndPrefix{ - prefix: a[i].prefix, - offset: offset, - }) - } - return bufferPosAndPrefix -} diff --git a/pkg/observability/buffer_test.go b/pkg/observability/buffer_test.go deleted file mode 100644 index d9c89d3..0000000 --- a/pkg/observability/buffer_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package observability - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "slices" - "strconv" - "testing" - - "github.com/in4it/wireguard-server/pkg/logging" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestIngestion(t *testing.T) { - logging.Loglevel = logging.LOG_DEBUG - totalMessagesToGenerate := 20 - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, 20) - o.Storage = storage - payloads := IncomingData{} - for i := 0; i < totalMessagesToGenerate/10; i++ { - payloads = append(payloads, map[string]any{ - "date": 1720613813.197045, - "log": "this is string: " + strconv.Itoa(i), - }) - } - - for i := 0; i < totalMessagesToGenerate/len(payloads); i++ { - payloadBytes, err := json.Marshal(payloads) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - data := io.NopCloser(bytes.NewReader(payloadBytes)) - err = o.Ingest(data) - if err != nil { - t.Fatalf("ingest error: %s", err) - } - } - - err := o.Flush() - if err != nil { - t.Fatalf("flush error: %s", err) - } - - dirlist, err := storage.ReadDir("") - if err != nil { - t.Fatalf("read dir error: %s", err) - } - - totalMessages := 0 - for _, file := range dirlist { - messages, err := storage.ReadFile(file) - if err != nil { - t.Fatalf("read file error: %s", err) - } - decodedMessages := decodeMessages(messages) - totalMessages += len(decodedMessages) - } - if len(dirlist) == 0 { - t.Fatalf("expected multiple files in directory, got %d", len(dirlist)) - } - - if totalMessages != totalMessagesToGenerate { - t.Fatalf("Tried to generate total message count of: %d; got: %d", totalMessagesToGenerate, totalMessages) - } -} - -func TestIngestionMoreMessages(t *testing.T) { - t.Skip() // we can skip this for general unit testing - totalMessagesToGenerate := 10000000 // 10,000,000 - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE) - payload := IncomingData{ - { - "date": 1720613813.197045, - "log": "this is string: ", - }, - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - - for i := 0; i < totalMessagesToGenerate; i++ { - data := io.NopCloser(bytes.NewReader(payloadBytes)) - err := o.Ingest(data) - if err != nil { - t.Fatalf("ingest error: %s", err) - } - } - - err = o.Flush() - if err != nil { - t.Fatalf("flush error: %s", err) - } - - dirlist, err := storage.ReadDir("") - if err != nil { - t.Fatalf("read dir error: %s", err) - } - - totalMessages := 0 - for _, file := range dirlist { - messages, err := storage.ReadFile(file) - if err != nil { - t.Fatalf("read file error: %s", err) - } - decodedMessages := decodeMessages(messages) - totalMessages += len(decodedMessages) - } - if len(dirlist) == 0 { - t.Fatalf("expected multiple files in directory, got %d", len(dirlist)) - } - - if totalMessages != totalMessagesToGenerate { - t.Fatalf("Tried to generate total message count of: %d; got: %d", totalMessagesToGenerate, totalMessages) - } - fmt.Printf("Buffer size (read+unread): %d in %d files\n", o.Buffer.Cap(), len(dirlist)) - -} - -func BenchmarkIngest10000000(b *testing.B) { - totalMessagesToGenerate := 10000000 // 10,000,000 - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE) - payload := IncomingData{ - { - "date": 1720613813.197045, - "log": "this is string", - }, - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - b.Fatalf("marshal error: %s", err) - } - - for i := 0; i < totalMessagesToGenerate; i++ { - data := io.NopCloser(bytes.NewReader(payloadBytes)) - err := o.Ingest(data) - if err != nil { - b.Fatalf("ingest error: %s", err) - } - } - - // wait until all data is flushed - o.ActiveBufferWriters.Wait() - - // flush remaining data that hasn't been flushed - if n := o.Buffer.Len(); n >= 0 { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - b.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) - } - } -} - -func BenchmarkIngest100000000(b *testing.B) { - totalMessagesToGenerate := 10000000 // 10,000,000 - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE) - o.Storage = storage - payload := IncomingData{ - { - "date": 1720613813.197045, - "log": "this is string", - }, - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - b.Fatalf("marshal error: %s", err) - } - - for i := 0; i < totalMessagesToGenerate; i++ { - data := io.NopCloser(bytes.NewReader(payloadBytes)) - err := o.Ingest(data) - if err != nil { - b.Fatalf("ingest error: %s", err) - } - } - - // wait until all data is flushed - o.ActiveBufferWriters.Wait() - - // flush remaining data that hasn't been flushed - if n := o.Buffer.Len(); n >= 0 { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - b.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) - } - } -} - -func TestEnsurePath(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - err := ensurePath(storage, "a/b/c/filename.txt") - if err != nil { - t.Fatalf("error: %s", err) - } -} - -func TestMergeBufferPosAndPrefix(t *testing.T) { - testCase1 := []BufferPosAndPrefix{ - { - prefix: "abc", - offset: 3, - }, - { - prefix: "abc", - offset: 9, - }, - { - prefix: "abc", - offset: 2, - }, - { - prefix: "abc2", - offset: 3, - }, - { - prefix: "abc2", - offset: 2, - }, - { - prefix: "abc3", - offset: 2, - }, - } - expected1 := []BufferPosAndPrefix{ - { - prefix: "abc", - offset: 14, - }, - { - prefix: "abc2", - offset: 5, - }, - { - prefix: "abc3", - offset: 2, - }, - } - res := mergeBufferPosAndPrefix(testCase1) - if !slices.Equal(res, expected1) { - t.Fatalf("test case is not equal to expected\nGot: %+v\nExpected:%+v\n", res, expected1) - } -} - -func TestReadPrefix(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE) - o.Buffer.prefix = []BufferPosAndPrefix{ - { - prefix: "abc", - offset: 3, - }, - { - prefix: "abc", - offset: 9, - }, - { - prefix: "abc", - offset: 2, - }, - { - prefix: "abc2", - offset: 3, - }, - { - prefix: "abc2", - offset: 2, - }, - { - prefix: "abc3", - offset: 2, - }, - } - expected1 := []BufferPosAndPrefix{ - { - prefix: "abc", - offset: 3, - }, - { - prefix: "abc", - offset: 9, - }, - { - prefix: "abc", - offset: 2, - }, - } - expected2 := []BufferPosAndPrefix{ - { - prefix: "abc2", - offset: 3, - }, - } - res := o.Buffer.ReadPrefix(int64(o.Buffer.prefix[0].offset + o.Buffer.prefix[1].offset + o.Buffer.prefix[2].offset)) - if !slices.Equal(res, expected1) { - t.Fatalf("test case is not equal to expected\nGot: %+v\nExpected:%+v\n", res, expected1) - } - res2 := o.Buffer.ReadPrefix(3) - if !slices.Equal(res2, expected2) { - t.Fatalf("test case is not equal to expected\nGot: %+v\nExpected:%+v\n", res, expected2) - } -} diff --git a/pkg/observability/constants.go b/pkg/observability/constants.go deleted file mode 100644 index cb5c72f..0000000 --- a/pkg/observability/constants.go +++ /dev/null @@ -1,8 +0,0 @@ -package observability - -const MAX_BUFFER_SIZE = 1024 * 1024 // 1 MB -const FLUSH_TIME_MAX_MINUTES = 1 // should have 5 as default at release - -const TIMESTAMP_FORMAT = "2006-01-02T15:04:05" - -const DATE_PREFIX = "2006/01/02" diff --git a/pkg/observability/decoding.go b/pkg/observability/decoding.go deleted file mode 100644 index cbd9062..0000000 --- a/pkg/observability/decoding.go +++ /dev/null @@ -1,119 +0,0 @@ -package observability - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "io" - "math" - "reflect" - "strconv" -) - -func Decode(r io.Reader) ([]FluentBitMessage, error) { - var result []FluentBitMessage - var msg interface{} - - err := json.NewDecoder(r).Decode(&msg) - if err != nil { - return result, err - } - switch m1 := msg.(type) { - case []interface{}: - if len(m1) == 0 { - return result, fmt.Errorf("empty array") - } - for _, m1Element := range m1 { - switch m2 := m1Element.(type) { - case map[string]interface{}: - var fluentBitMessage FluentBitMessage - fluentBitMessage.Data = make(map[string]string) - val, ok := m2["date"] - if ok { - fluentBitMessage.Date = val.(float64) - } - for key, value := range m2 { - if key != "date" { - switch valueTyped := value.(type) { - case string: - fluentBitMessage.Data[key] = valueTyped - case float64: - fluentBitMessage.Data[key] = strconv.FormatFloat(valueTyped, 'f', -1, 64) - case []byte: - fluentBitMessage.Data[key] = string(valueTyped) - default: - fmt.Printf("no hit on type: %s", reflect.TypeOf(valueTyped)) - } - } - } - result = append(result, fluentBitMessage) - default: - return result, fmt.Errorf("invalid type: no map found in array") - } - } - default: - return result, fmt.Errorf("invalid type: no array found") - } - return result, nil -} - -func decodeMessages(msgs []byte) []FluentBitMessage { - res := []FluentBitMessage{} - recordOffset := 0 - for k := 0; k < len(msgs); k++ { - if k > recordOffset+8 && msgs[k] == 0xff && msgs[k-1] == 0xff { - res = append(res, decodeMessage(msgs[recordOffset:k])) - recordOffset = k + 1 - } - } - return res -} -func decodeMessage(data []byte) FluentBitMessage { - bits := binary.LittleEndian.Uint64(data[0:8]) - msg := FluentBitMessage{ - Date: math.Float64frombits(bits), - Data: map[string]string{}, - } - isKey := false - key := "" - start := 8 - for kk := start; kk < len(data); kk++ { - if data[kk] == 0xff { - if isKey { - isKey = false - msg.Data[key] = string(data[start+1 : kk]) - start = kk + 1 - } else { - isKey = true - key = string(data[start:kk]) - start = kk - } - } - } - // if last record didn't end with the terminator - if data[len(data)-1] != 0xff { - msg.Data[key] = string(data[start+1:]) - } - return msg -} - -func scanMessage(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - for i := 0; i < len(data); i++ { - if data[i] == 0xff && data[i-1] == 0xff { - return i + 1, data[0 : i-1], nil - } - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - if len(data) > 1 && data[len(data)-1] == 0xff && data[len(data)-2] == 0xff { - return len(data[0 : len(data)-2]), data, nil - } else { - return len(data), data, nil - } - } - // Request more data. - return 0, nil, nil -} diff --git a/pkg/observability/decoding_test.go b/pkg/observability/decoding_test.go deleted file mode 100644 index 82fa736..0000000 --- a/pkg/observability/decoding_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package observability - -import ( - "bytes" - "encoding/json" - "testing" -) - -func TestDecoding(t *testing.T) { - payload := IncomingData{ - { - "date": 1720613813.197045, - "rand_value": "rand", - }, - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("json marshal error: %s", err) - } - messages, err := Decode(bytes.NewBuffer(payloadBytes)) - if err != nil { - t.Fatalf("error: %s", err) - } - if len(messages) == 0 { - t.Fatalf("no messages returned") - } - if messages[0].Date != 1720613813.197045 { - t.Fatalf("wrong date returned") - } - val, ok := messages[0].Data["rand_value"] - if !ok { - t.Fatalf("rand_value key not found") - } - if string(val) != "rand" { - t.Fatalf("wrong data returned: %s", val) - } -} - -func TestDecodingMultiMessage(t *testing.T) { - payload := IncomingData{ - { - "date": 1727119152.0, - "container_name": "/fluentbit-nginx-1", - "source": "stdout", - "log": "/docker-entrypoint.sh: /docker-entrypoint.d/ is not empty, will attempt to perform configuration", - "container_id": "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186", - }, - { - "date": 1727119152.0, - "source": "stdout", - "log": "/docker-entrypoint.sh: Looking for shell scripts in /docker-entrypoint.d/", - "container_id": "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186", - "container_name": "/fluentbit-nginx-1", - }, - { - "date": 1727119152.0, - "container_id": "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186", - "container_name": "/fluentbit-nginx-1", - "source": "stdout", - "log": "/docker-entrypoint.sh: Launching /docker-entrypoint.d/10-listen-on-ipv6-by-default.sh", - }, - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("json marshal error: %s", err) - } - messages, err := Decode(bytes.NewBuffer(payloadBytes)) - if err != nil { - t.Fatalf("error: %s", err) - } - if len(messages) != len(payload) { - t.Fatalf("incorrect messages returned. Got %d, expected %d", len(messages), len(payload)) - } - val, ok := messages[2].Data["container_id"] - if !ok { - t.Fatalf("container_id key not found") - } - if string(val) != "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186" { - t.Fatalf("wrong data returned: %s", val) - } -} - -func TestDecodeMessages(t *testing.T) { - msgs := []FluentBitMessage{ - { - Date: 1720613813.197045, - Data: map[string]string{ - "mykey": "this is myvalue", - "second key": "this is my second value", - "third key": "this is my third value", - }, - }, - { - Date: 1720613813.197099, - Data: map[string]string{ - "second data set": "my value", - }, - }, - } - encoded := encodeMessage(msgs) - decoded := decodeMessages(encoded) - - if len(msgs) != len(decoded) { - t.Fatalf("length doesn't match") - } - for k := range decoded { - if msgs[k].Date != decoded[k].Date { - t.Fatalf("date doesn't match") - } - if len(msgs[k].Data) != len(decoded[k].Data) { - t.Fatalf("length of data doesn't match") - } - for kk := range decoded[k].Data { - if msgs[k].Data[kk] != decoded[k].Data[kk] { - t.Fatalf("key/value pair in data doesn't match: key: %s. Data: %s vs %s", kk, msgs[k].Data[kk], decoded[k].Data[kk]) - } - } - } -} - -func TestDecodeMessage(t *testing.T) { - msgs := []FluentBitMessage{ - { - Date: 1720613813.197099, - Data: map[string]string{ - "second data set": "my value", - }, - }, - } - encoded := encodeMessage(msgs) - message := decodeMessage(encoded) - - if message.Date != message.Date { - t.Fatalf("date doesn't match") - } - if len(msgs[0].Data) != len(message.Data) { - t.Fatalf("length of data doesn't match") - } - for kk := range message.Data { - if msgs[0].Data[kk] != message.Data[kk] { - t.Fatalf("key/value pair in data doesn't match: key: %s. Data: %s vs %s", kk, message.Data[kk], message.Data[kk]) - } - } -} -func TestDecodeMessageWithoutTerminator(t *testing.T) { - msgs := []FluentBitMessage{ - { - Date: 1720613813.197099, - Data: map[string]string{ - "second data set": "my value", - }, - }, - } - encoded := encodeMessage(msgs) - message := decodeMessage(bytes.TrimSuffix(encoded, []byte{0xff, 0xff})) - - if message.Date != message.Date { - t.Fatalf("date doesn't match") - } - if len(msgs[0].Data) != len(message.Data) { - t.Fatalf("length of data doesn't match: got: '%s', expected '%s'", message.Data, msgs[0].Data) - } - for kk := range message.Data { - if msgs[0].Data[kk] != message.Data[kk] { - t.Fatalf("key/value pair in data doesn't match: key: %s. Data: %s vs %s", kk, message.Data[kk], msgs[0].Data[kk]) - } - } -} - -func TestScanMessage(t *testing.T) { - msgs := []FluentBitMessage{ - { - Date: 1720613813.197045, - Data: map[string]string{ - "mykey": "this is myvalue", - "second key": "this is my second value", - "third key": "this is my third value", - }, - }, - { - Date: 1720613813.197099, - Data: map[string]string{ - "second data set": "my value", - }, - }, - } - encoded := encodeMessage(msgs) - // first record - advance, record1, err := scanMessage(encoded, false) - if err != nil { - t.Fatalf("scan lines error: %s", err) - } - firstRecord := decodeMessages(append(record1, []byte{0xff, 0xff}...)) - if len(firstRecord) == 0 { - t.Fatalf("first record is empty") - } - if firstRecord[0].Data["mykey"] != "this is myvalue" { - t.Fatalf("wrong data returned") - } - // second record - advance2, record2, err := scanMessage(encoded[advance:], false) - if err != nil { - t.Fatalf("scan lines error: %s", err) - } - secondRecord := decodeMessages(append(record2, []byte{0xff, 0xff}...)) - if len(secondRecord) == 0 { - t.Fatalf("first record is empty") - } - if secondRecord[0].Data["second data set"] != "my value" { - t.Fatalf("wrong data returned in second record") - } - // third call - advance3, record3, err := scanMessage(encoded[advance+advance2:], false) - if err != nil { - t.Fatalf("scan lines error: %s", err) - } - if advance3 != 0 { - t.Fatalf("third record should be empty. Got: %+v", record3) - } -} diff --git a/pkg/observability/encoding.go b/pkg/observability/encoding.go deleted file mode 100644 index c805b8f..0000000 --- a/pkg/observability/encoding.go +++ /dev/null @@ -1,24 +0,0 @@ -package observability - -import ( - "bytes" - "encoding/binary" - "math" -) - -func encodeMessage(msgs []FluentBitMessage) []byte { - out := bytes.Buffer{} - for _, msg := range msgs { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], math.Float64bits(msg.Date)) - out.Write(buf[:]) - for key, msgData := range msg.Data { - out.Write([]byte(key)) - out.Write([]byte{0xff}) - out.Write([]byte(msgData)) - out.Write([]byte{0xff}) - } - out.Write([]byte{0xff}) - } - return out.Bytes() -} diff --git a/pkg/observability/handlers.go b/pkg/observability/handlers.go deleted file mode 100644 index 884272f..0000000 --- a/pkg/observability/handlers.go +++ /dev/null @@ -1,96 +0,0 @@ -package observability - -import ( - "encoding/json" - "fmt" - "net/http" - "strconv" - "strings" - "time" -) - -func (o *Observability) observabilityHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) -} - -func (o *Observability) ingestionHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - w.WriteHeader(http.StatusBadRequest) - return - } - - if err := o.Ingest(r.Body); err != nil { - w.WriteHeader(http.StatusBadRequest) - fmt.Printf("error: %s", err) - return - } - w.WriteHeader(http.StatusOK) -} - -func (o *Observability) logsHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusBadRequest) - return - } - if r.FormValue("fromDate") == "" { - o.returnError(w, fmt.Errorf("no from date supplied"), http.StatusBadRequest) - return - } - fromDate, err := time.Parse("2006-01-02", r.FormValue("fromDate")) - if err != nil { - o.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) - return - } - if r.FormValue("endDate") == "" { - o.returnError(w, fmt.Errorf("no end date supplied"), http.StatusBadRequest) - return - } - endDate, err := time.Parse("2006-01-02", r.FormValue("endDate")) - if err != nil { - o.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) - return - } - offset := 0 - if r.FormValue("offset") != "" { - i, err := strconv.Atoi(r.FormValue("offset")) - if err == nil { - offset = i - } - } - maxLines := 0 - if r.FormValue("maxLines") != "" { - i, err := strconv.Atoi(r.FormValue("maxLines")) - if err == nil { - maxLines = i - } - } - pos := int64(0) - if r.FormValue("pos") != "" { - i, err := strconv.ParseInt(r.FormValue("pos"), 10, 64) - if err == nil { - pos = i - } - } - displayTags := strings.Split(r.FormValue("display-tags"), ",") - filterTagsSplit := strings.Split(r.FormValue("filter-tags"), ",") - filterTags := []KeyValue{} - for _, tag := range filterTagsSplit { - kv := strings.Split(tag, "=") - if len(kv) == 2 { - filterTags = append(filterTags, KeyValue{Key: kv[0], Value: kv[1]}) - } - } - out, err := o.getLogs(fromDate, endDate, pos, maxLines, offset, r.FormValue("search"), displayTags, filterTags) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - fmt.Printf("get logs error: %s", err) - return - } - outBytes, err := json.Marshal(out) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - fmt.Printf("marshal error: %s", err) - return - } - w.Write(outBytes) -} diff --git a/pkg/observability/handlers_test.go b/pkg/observability/handlers_test.go deleted file mode 100644 index dc212a4..0000000 --- a/pkg/observability/handlers_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package observability - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestIngestionHandler(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, 20) - o.Storage = storage - payload := IncomingData{ - { - "date": 1720613813.197045, - "log": "this is a string", - }, - } - - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req := httptest.NewRequest(http.MethodPost, "/api/observability/ingestion/json", bytes.NewReader(payloadBytes)) - w := httptest.NewRecorder() - o.ingestionHandler(w, req) - res := w.Result() - - if res.StatusCode != http.StatusOK { - t.Fatalf("expected status code OK. Got: %d", res.StatusCode) - } - - // wait until all data is flushed - o.ActiveBufferWriters.Wait() - - // flush remaining data that hasn't been flushed - if n := o.Buffer.Len(); n >= 0 { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) - } - } - - dirlist, err := storage.ReadDir("") - if err != nil { - t.Fatalf("read dir error: %s", err) - } - if len(dirlist) == 0 { - t.Fatalf("dir is empty") - } - messages, err := storage.ReadFile(dirlist[0]) - if err != nil { - t.Fatalf("read file error: %s", err) - } - decodedMessages := decodeMessages(messages) - if decodedMessages[0].Date != 1720613813.197045 { - t.Fatalf("unexpected date. Got %f, expected: %f", decodedMessages[0].Date, 1720613813.197045) - } - if decodedMessages[0].Data["log"] != "this is a string" { - t.Fatalf("unexpected log data") - } -} diff --git a/pkg/observability/helpers.go b/pkg/observability/helpers.go deleted file mode 100644 index 2c98ee3..0000000 --- a/pkg/observability/helpers.go +++ /dev/null @@ -1,31 +0,0 @@ -package observability - -import ( - "fmt" - "math" - "net/http" - "strings" - "time" -) - -func (o *Observability) returnError(w http.ResponseWriter, err error, statusCode int) { - fmt.Println("========= ERROR =========") - fmt.Printf("Error: %s\n", err) - fmt.Println("=========================") - w.WriteHeader(statusCode) - w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`)) -} - -func FloatToDate(datetime float64) time.Time { - datetimeInt := int64(datetime) - decimals := datetime - float64(datetimeInt) - nsecs := int64(math.Round(decimals * 1_000_000)) // precision to match golang's time.Time - return time.Unix(datetimeInt, nsecs*1000) -} - -func DateToFloat(datetime time.Time) float64 { - seconds := float64(datetime.Unix()) - nanoseconds := float64(datetime.Nanosecond()) / 1e9 - fmt.Printf("nanosec: %f", nanoseconds) - return seconds + nanoseconds -} diff --git a/pkg/observability/helpers_test.go b/pkg/observability/helpers_test.go deleted file mode 100644 index 195156d..0000000 --- a/pkg/observability/helpers_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package observability - -import ( - "testing" - "time" -) - -func TestFloatToDate2Way(t *testing.T) { - now := time.Now() - float := DateToFloat(now) - date := FloatToDate(float) - if date.Format(TIMESTAMP_FORMAT) != now.Format(TIMESTAMP_FORMAT) { - t.Fatalf("got: %s, expected: %s", date.Format(TIMESTAMP_FORMAT), now.Format(TIMESTAMP_FORMAT)) - } -} diff --git a/pkg/observability/logs.go b/pkg/observability/logs.go deleted file mode 100644 index 772770b..0000000 --- a/pkg/observability/logs.go +++ /dev/null @@ -1,116 +0,0 @@ -package observability - -import ( - "bufio" - "fmt" - "sort" - "strings" - "time" -) - -func (o *Observability) getLogs(fromDate, endDate time.Time, pos int64, maxLogLines, offset int, search string, displayTags []string, filterTags []KeyValue) (LogEntryResponse, error) { - logEntryResponse := LogEntryResponse{ - Enabled: true, - LogEntries: []LogEntry{}, - Tags: KeyValueInt{}, - } - - keys := make(map[KeyValue]int) - - logFiles := []string{} - - if maxLogLines == 0 { - maxLogLines = 100 - } - - for d := fromDate; d.Before(endDate) || d.Equal(endDate); d = d.AddDate(0, 0, 1) { - fileList, err := o.Storage.ReadDir(d.Format(DATE_PREFIX)) - if err != nil { - logEntryResponse.NextPos = -1 - return logEntryResponse, nil // can't read directory, return empty response - } - for _, filename := range fileList { - logFiles = append(logFiles, d.Format(DATE_PREFIX)+"/"+filename) - } - } - - fileReaders, err := o.Storage.OpenFilesFromPos(logFiles, pos) - if err != nil { - return logEntryResponse, fmt.Errorf("error while reading files: %s", err) - } - for _, fileReader := range fileReaders { - defer fileReader.Close() - } - - for _, logInputData := range fileReaders { // read multiple files - if len(logEntryResponse.LogEntries) >= maxLogLines { - break - } - scanner := bufio.NewScanner(logInputData) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - advance, token, err = scanMessage(data, atEOF) - pos += int64(advance) - return - }) - for scanner.Scan() && len(logEntryResponse.LogEntries) < maxLogLines { // read multiple lines - // decode, store as logentry - logMessage := decodeMessage(scanner.Bytes()) - logline, ok := logMessage.Data["log"] - if ok { - timestamp := FloatToDate(logMessage.Date).Add(time.Duration(offset) * time.Minute) - if search == "" || strings.Contains(logline, search) { - tags := []KeyValue{} - for _, tag := range displayTags { - if tagValue, ok := logMessage.Data[tag]; ok { - tags = append(tags, KeyValue{Key: tag, Value: tagValue}) - } - } - filterMessage := true - if len(filterTags) == 0 { - filterMessage = false - } else { - for _, filter := range filterTags { - if tagValue, ok := logMessage.Data[filter.Key]; ok { - if tagValue == filter.Value { - filterMessage = false - } - } - } - } - if !filterMessage { - logEntry := LogEntry{ - Timestamp: timestamp.Format(TIMESTAMP_FORMAT), - Data: logline, - Tags: tags, - } - logEntryResponse.LogEntries = append(logEntryResponse.LogEntries, logEntry) - for k, v := range logMessage.Data { - if k != "log" { - keys[KeyValue{Key: k, Value: v}] += 1 - } - } - } - } - } - } - if err := scanner.Err(); err != nil { - return logEntryResponse, fmt.Errorf("log file read (scanner) error: %s", err) - } - } - if len(logEntryResponse.LogEntries) < maxLogLines { - logEntryResponse.NextPos = -1 // no more records - } else { - logEntryResponse.NextPos = pos - } - - for k, v := range keys { - logEntryResponse.Tags = append(logEntryResponse.Tags, KeyValueTotal{ - Key: k.Key, - Value: k.Value, - Total: v, - }) - } - sort.Sort(logEntryResponse.Tags) - - return logEntryResponse, nil -} diff --git a/pkg/observability/logs_test.go b/pkg/observability/logs_test.go deleted file mode 100644 index 1f4e34f..0000000 --- a/pkg/observability/logs_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package observability - -import ( - "bytes" - "encoding/json" - "io" - "strconv" - "strings" - "testing" - "time" - - "github.com/in4it/wireguard-server/pkg/logging" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestGetLogs(t *testing.T) { - logging.Loglevel = logging.LOG_DEBUG - totalMessagesToGenerate := 100 - storage := &memorystorage.MockMemoryStorage{} - o := NewWithoutMonitor(storage, 20) - timestamp := DateToFloat(time.Now()) - payload := IncomingData{ - { - "date": timestamp, - "log": "this is string: ", - }, - } - - for i := 0; i < totalMessagesToGenerate; i++ { - payload[0]["log"] = "this is string: " + strconv.Itoa(i) - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - data := io.NopCloser(bytes.NewReader(payloadBytes)) - err = o.Ingest(data) - if err != nil { - t.Fatalf("ingest error: %s", err) - } - } - - // wait until all data is flushed - o.ActiveBufferWriters.Wait() - - // flush remaining data that hasn't been flushed - if n := o.Buffer.Len(); n >= 0 { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) - } - } - - now := time.Now() - maxLogLines := 100 - search := "" - - logEntryResponse, err := o.getLogs(now, now, 0, maxLogLines, 0, search, []string{}, []KeyValue{}) - if err != nil { - t.Fatalf("get logs error: %s", err) - } - if len(logEntryResponse.LogEntries) != totalMessagesToGenerate { - t.Fatalf("didn't get the same log entries as messaged we generated: got: %d, expected: %d", len(logEntryResponse.LogEntries), totalMessagesToGenerate) - } - if logEntryResponse.LogEntries[0].Timestamp != FloatToDate(timestamp).Format(TIMESTAMP_FORMAT) { - t.Fatalf("unexpected timestamp: %s vs %s", logEntryResponse.LogEntries[0].Timestamp, FloatToDate(timestamp).Format(TIMESTAMP_FORMAT)) - } -} - -func TestFloatToDate(t *testing.T) { - for i := 0; i < 10; i++ { - now := time.Now() - floatDate := float64(now.Unix()) + float64(now.Nanosecond())/1e9 - floatToDate := FloatToDate(floatDate) - if now.Unix() != floatToDate.Unix() { - t.Fatalf("times are not equal. Got: %v, expected: %v", floatToDate, now) - } - /*if now.UnixNano() != floatToDate.UnixNano() { - t.Fatalf("times are not equal. Got: %v, expected: %v", floatToDate, now) - }*/ - } -} - -func TestKeyValue(t *testing.T) { - logEntryResponse := LogEntryResponse{ - Tags: KeyValueInt{ - {Key: "k", Value: "v", Total: 4}, - }, - } - out, err := json.Marshal(logEntryResponse) - if err != nil { - t.Fatalf("error: %s", err) - } - if !strings.Contains(string(out), `"tags":[{"key":"k","value":"v","total":4}]`) { - t.Fatalf("wrong output: %s", out) - } -} diff --git a/pkg/observability/new.go b/pkg/observability/new.go deleted file mode 100644 index 7bf9f0b..0000000 --- a/pkg/observability/new.go +++ /dev/null @@ -1,25 +0,0 @@ -package observability - -import ( - "net/http" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -func New(defaultStorage storage.Iface) *Observability { - o := NewWithoutMonitor(defaultStorage, MAX_BUFFER_SIZE) - go o.monitorBuffer() - return o -} -func NewWithoutMonitor(storage storage.Iface, maxBufferSize int) *Observability { - o := &Observability{ - Buffer: &ConcurrentRWBuffer{}, - MaxBufferSize: maxBufferSize, - Storage: storage, - } - return o -} - -type Iface interface { - GetRouter() *http.ServeMux -} diff --git a/pkg/observability/router.go b/pkg/observability/router.go deleted file mode 100644 index 7085fca..0000000 --- a/pkg/observability/router.go +++ /dev/null @@ -1,12 +0,0 @@ -package observability - -import "net/http" - -func (o *Observability) GetRouter() *http.ServeMux { - mux := http.NewServeMux() - mux.Handle("/api/observability/", http.HandlerFunc(o.observabilityHandler)) - mux.Handle("/api/observability/ingestion/json", http.HandlerFunc(o.ingestionHandler)) - mux.Handle("/api/observability/logs", http.HandlerFunc(o.logsHandler)) - - return mux -} diff --git a/pkg/observability/types.go b/pkg/observability/types.go deleted file mode 100644 index b718bac..0000000 --- a/pkg/observability/types.go +++ /dev/null @@ -1,89 +0,0 @@ -package observability - -import ( - "bytes" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -type IncomingData []map[string]any - -type FluentBitMessage struct { - Date float64 `json:"date"` - Data map[string]string `json:"data"` -} - -type Observability struct { - Storage storage.Iface - Buffer *ConcurrentRWBuffer - LastFlushed time.Time - FlushOverflow atomic.Bool - FlushOverflowSequence atomic.Uint64 - ActiveBufferWriters sync.WaitGroup - WriteLock sync.Mutex - MaxBufferSize int -} - -type ConcurrentRWBuffer struct { - buffer bytes.Buffer - prefix []BufferPosAndPrefix - mu sync.Mutex -} - -type BufferPosAndPrefix struct { - prefix string - offset int -} - -type LogEntryResponse struct { - Enabled bool `json:"enabled"` - LogEntries []LogEntry `json:"logEntries"` - Tags KeyValueInt `json:"tags"` - NextPos int64 `json:"nextPos"` -} - -type LogEntry struct { - Timestamp string `json:"timestamp"` - Data string `json:"data"` - Tags []KeyValue `json:"tags"` -} - -type KeyValueInt []KeyValueTotal - -type KeyValueTotal struct { - Key string - Value string - Total int -} -type KeyValue struct { - Key string `json:"key"` - Value string `json:"value"` -} - -func (kv KeyValueInt) MarshalJSON() ([]byte, error) { - res := "[" - for _, v := range kv { - res += `{ "key" : "` + v.Key + `", "value": "` + v.Value + `", "total": ` + strconv.Itoa(v.Total) + ` },` - } - res = strings.TrimRight(res, ",") - res += "]" - return []byte(res), nil -} - -func (kv KeyValueInt) Len() int { - return len(kv) -} -func (kv KeyValueInt) Less(i, j int) bool { - if kv[i].Key == kv[j].Key { - return kv[i].Value < kv[j].Value - } - return kv[i].Key < kv[j].Key -} -func (kv KeyValueInt) Swap(i, j int) { - kv[i], kv[j] = kv[j], kv[i] -} diff --git a/pkg/rest/auditlog/logentry.go b/pkg/rest/auditlog/logentry.go deleted file mode 100644 index 6a09cb0..0000000 --- a/pkg/rest/auditlog/logentry.go +++ /dev/null @@ -1,39 +0,0 @@ -package auditlog - -import ( - "encoding/json" - "fmt" - "path" - "time" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -const TIMESTAMP_FORMAT = "2006-01-02T15:04:05" -const AUDITLOG_STATS_DIR = "stats" - -type LogEntry struct { - Timestamp LogTimestamp `json:"timestamp"` - UserID string `json:"userID"` - Action string `json:"action"` -} -type LogTimestamp time.Time - -func (t LogTimestamp) MarshalJSON() ([]byte, error) { - timestamp := fmt.Sprintf("\"%s\"", time.Time(t).Format(TIMESTAMP_FORMAT)) - return []byte(timestamp), nil -} - -func Write(storage storage.Iface, logEntry LogEntry) error { - statsPath := path.Join(AUDITLOG_STATS_DIR, "logins-"+time.Now().Format("2006-01-02")) + ".log" - logEntryBytes, err := json.Marshal(logEntry) - if err != nil { - return fmt.Errorf("could not parse log entry: %s", err) - } - err = storage.AppendFile(statsPath, logEntryBytes) - if err != nil { - return fmt.Errorf("could not append stats to file (%s): %s", statsPath, err) - } - - return nil -} diff --git a/pkg/rest/auth.go b/pkg/rest/auth.go deleted file mode 100644 index 5196ab5..0000000 --- a/pkg/rest/auth.go +++ /dev/null @@ -1,419 +0,0 @@ -package rest - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" - "time" - - "github.com/google/uuid" - "github.com/in4it/wireguard-server/pkg/auth/oidc" - oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store" - "github.com/in4it/wireguard-server/pkg/auth/saml" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/rest/login" -) - -func (c *Context) authHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - c.returnError(w, fmt.Errorf("not a post request"), http.StatusBadRequest) - return - } - - if c.LocalAuthDisabled { - c.returnError(w, fmt.Errorf("local auth is disabled in settings"), http.StatusForbidden) - return - } - - decoder := json.NewDecoder(r.Body) - var loginReq login.LoginRequest - err := decoder.Decode(&loginReq) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - - // check login attempts - tooManyLogins := login.CheckTooManyLogins(c.LoginAttempts, loginReq.Login) - if tooManyLogins { - c.returnError(w, fmt.Errorf("too many login failures, try again later"), http.StatusTooManyRequests) - return - } - - loginResponse, user, err := login.Authenticate(loginReq, c.UserStore, c.JWTKeys.PrivateKey, c.JWTKeysKID) - if err != nil { - c.returnError(w, fmt.Errorf("authentication error: %s", err), http.StatusBadRequest) - return - } - out, err := json.Marshal(loginResponse) - if err != nil { - c.returnError(w, fmt.Errorf("unable to marshal response: %s", err), http.StatusBadRequest) - return - } - if loginResponse.MFARequired { - c.write(w, out) // status ok, but unauthorized, because we need a second call with MFA code - return - } else if loginResponse.Authenticated { - login.ClearAttemptsForLogin(c.LoginAttempts, loginReq.Login) - user.LastLogin = time.Now() - err = c.UserStore.UpdateUser(user) - if err != nil { - logging.ErrorLog(fmt.Errorf("last login update error: %s", err)) - } - c.write(w, out) - } else { - // log login attempts - login.RecordAttempt(c.LoginAttempts, loginReq.Login) - // return Unauthorized - c.writeWithStatus(w, out, http.StatusUnauthorized) - } -} - -func (c *Context) oidcProviderHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - oidcProviders := make([]oidc.OIDCProvider, len(c.OIDCProviders)) - copy(oidcProviders, c.OIDCProviders) - for k := range oidcProviders { - oidcProviders[k].LoginURL = fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, strings.Replace(oidcProviders[k].RedirectURI, "/callback/", "/login/", -1)) - oidcProviders[k].RedirectURI = fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProviders[k].RedirectURI) - } - out, err := json.Marshal(oidcProviders) - if err != nil { - c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var oidcProvider oidc.OIDCProvider - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&oidcProvider) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - oidcProvider.ID = uuid.New().String() - if oidcProvider.Name == "" { - c.returnError(w, fmt.Errorf("name not set"), http.StatusBadRequest) - return - } - if oidcProvider.ClientID == "" { - c.returnError(w, fmt.Errorf("clientID not set"), http.StatusBadRequest) - return - } - if oidcProvider.ClientSecret == "" { - c.returnError(w, fmt.Errorf("clientSecret not set"), http.StatusBadRequest) - return - } - if oidcProvider.Scope == "" { - c.returnError(w, fmt.Errorf("scope not set"), http.StatusBadRequest) - return - } - if oidcProvider.DiscoveryURI == "" { - c.returnError(w, fmt.Errorf("discovery URL not set"), http.StatusBadRequest) - return - } - oidcProvider.RedirectURI = "/callback/oidc/" + oidcProvider.ID - c.OIDCProviders = append(c.OIDCProviders, oidcProvider) - out, err := json.Marshal(oidcProvider) - if err != nil { - c.returnError(w, fmt.Errorf("oidcProvider marshal error: %s", err), http.StatusBadRequest) - return - } - err = SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - return - } -} - -func (c *Context) oidcProviderElementHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodDelete: - match := -1 - for k, oidcProvider := range c.OIDCProviders { - if oidcProvider.ID == r.PathValue("id") { - match = k - } - } - if match == -1 { - c.returnError(w, fmt.Errorf("oidc provider not found"), http.StatusBadRequest) - return - } - c.OIDCProviders = append(c.OIDCProviders[:match], c.OIDCProviders[match+1:]...) - // save config (changed providers) - err := SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) - return - } - c.write(w, []byte(`{ "deleted": "`+r.PathValue("id")+`" }`)) - } -} - -func (c *Context) authMethods(w http.ResponseWriter, r *http.Request) { - response := AuthMethodsResponse{ - LocalAuthDisabled: c.LocalAuthDisabled, - OIDCProviders: make([]AuthMethodsProvider, len(c.OIDCProviders)), - } - for k, oidcProvider := range c.OIDCProviders { - response.OIDCProviders[k] = AuthMethodsProvider{ - ID: oidcProvider.ID, - Name: oidcProvider.Name, - } - } - - out, err := json.Marshal(response) - if err != nil { - c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) -} -func (c *Context) authMethodsByID(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodPost: - switch r.PathValue("method") { - case "saml": - loginResponse := login.LoginResponse{} - var samlCallback SAMLCallback - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&samlCallback) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - if samlCallback.Code == "" { - c.returnError(w, fmt.Errorf("no code provided"), http.StatusBadRequest) - return - } - var samlProvider saml.Provider - for k := range *c.SAML.Providers { - if r.PathValue("id") == (*c.SAML.Providers)[k].ID { - samlProvider = (*c.SAML.Providers)[k] - } - } - if samlProvider.ID == "" { - c.returnError(w, fmt.Errorf("saml provider not found"), http.StatusBadRequest) - return - } - - samlSession, err := c.SAML.Client.GetAuthenticatedUser(samlProvider, samlCallback.Code) - if err != nil { - c.returnError(w, fmt.Errorf("saml session not found"), http.StatusBadRequest) - return - } - - // add user to the user database (or modify existing one) - user, err := addOrModifyExternalUser(c.Storage.Client, c.UserStore, samlSession.Login, "saml", samlSession.ID) - if err != nil { - c.returnError(w, fmt.Errorf("couldn't add/modify user in database: %s", err), http.StatusBadRequest) - return - } - - if user.Suspended { - loginResponse.Suspended = true - } - - token, err := login.GetJWTTokenWithExpiration(user.Login, user.Role, c.JWTKeys.PrivateKey, c.JWTKeysKID, samlSession.ExpiresAt) - if err != nil { - c.returnError(w, fmt.Errorf("token generation failed: %s", err), http.StatusBadRequest) - return - } - loginResponse.Authenticated = true - loginResponse.Token = token - - out, err := json.Marshal(loginResponse) - if err != nil { - c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - return - default: // oidc is default - loginResponse := login.LoginResponse{} - var oidcCallback OIDCCallback - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&oidcCallback) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - for _, oidcProvider := range c.OIDCProviders { - if r.PathValue("id") == oidcProvider.ID && oidcCallback.Code != "" { // we got the code back - oidcstore.RetrieveTokenLock.Lock() - defer oidcstore.RetrieveTokenLock.Unlock() - oauth2data, err := oidc.RetrieveOAUth2DataUsingState(c.OIDCStore.OAuth2Data, oidcCallback.State) // get the oauth2 struct based on the state (key) - if err != nil { - c.returnError(w, fmt.Errorf("cannot find oauth2 data using state provided: %s", err), http.StatusBadRequest) - return - } - if oauth2data.Token.AccessToken != "" { - if oauth2data.Suspended { - loginResponse.Suspended = true - } else if c.LicenseUserCount >= c.UserStore.UserCount() { - loginResponse.NoLicense = true - } else { - loginResponse.Authenticated = true - loginResponse.Token = oauth2data.Token.AccessToken - } - out, err := json.Marshal(loginResponse) - if err != nil { - c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - return - } - // no token, let's generate a new one - discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) - if err != nil { - c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest) - return - } - jwks, err := c.OIDCStore.GetJwks(discovery.JwksURI) - if err != nil { - c.returnError(w, fmt.Errorf("get jwks error: %s", err), http.StatusBadRequest) - return - } - updatedOauth2data, err := oidc.UpdateOAuth2DataWithToken(jwks, discovery, oidcProvider.ClientID, oidcProvider.ClientSecret, c.Protocol+"://"+c.Hostname+oidcCallback.RedirectURI, oidcCallback.Code, oidcCallback.State, oauth2data) - if err != nil { - c.returnError(w, fmt.Errorf("GetTokenFromCode error: %s", err), http.StatusBadRequest) - return - } - // add user to the user database (or modify existing one) - user, err := addOrModifyExternalUser(c.Storage.Client, c.UserStore, updatedOauth2data.UserInfo.Email, "oidc", updatedOauth2data.ID) - if err != nil { - c.returnError(w, fmt.Errorf("couldn't add/modify user in database: %s", err), http.StatusBadRequest) - return - } - if user.Suspended { - loginResponse.Suspended = true - updatedOauth2data.Suspended = true - } else { - updatedOauth2data.Suspended = false - } - // save oauth data (only when we're sure it's not a suspended user) - err = c.OIDCStore.SaveOAuth2Data(updatedOauth2data, oidcCallback.State) - if err != nil { - c.returnError(w, fmt.Errorf("oidc store save failed: %s", err), http.StatusBadRequest) - return - } - // cleanup oauth2 data - c.OIDCStore.CleanupOAuth2Data(updatedOauth2data) - - // save config (changed user info) - err = SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) - return - } - - // set loginResponse - if !loginResponse.Suspended { - loginResponse.Authenticated = true - loginResponse.Token = updatedOauth2data.Token.AccessToken - } - out, err := json.Marshal(loginResponse) - if err != nil { - c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - return - } - } - } - c.returnError(w, fmt.Errorf("oidc provider not found"), http.StatusBadRequest) - case http.MethodGet: - switch r.PathValue("method") { - case "saml": - id := r.PathValue("id") - samlProviderId := -1 - for k := range *c.SAML.Providers { - if (*c.SAML.Providers)[k].ID == id { - samlProviderId = k - } - } - if samlProviderId == -1 { - c.returnError(w, fmt.Errorf("cannot find saml provider"), http.StatusBadRequest) - return - } - redirectURI, err := c.SAML.Client.GetAuthURL((*c.SAML.Providers)[samlProviderId]) - if err != nil { - c.returnError(w, fmt.Errorf("cannot get auth url"), http.StatusBadRequest) - return - } - response := AuthMethodsProvider{ - ID: (*c.SAML.Providers)[samlProviderId].ID, - Name: (*c.SAML.Providers)[samlProviderId].Name, - RedirectURI: redirectURI, - } - out, err := json.Marshal(response) - if err != nil { - c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - return - default: - id := r.PathValue("id") - for _, oidcProvider := range c.OIDCProviders { - if id == oidcProvider.ID { - callback := fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProvider.RedirectURI) - discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) - if err != nil { - c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest) - return - } - redirectURI, state, err := oidc.GetRedirectURI(discovery, oidcProvider.ClientID, oidcProvider.Scope, callback, c.EnableOIDCTokenRenewal) - if err != nil { - c.returnError(w, fmt.Errorf("GetRedirectURI error: %s", err), http.StatusBadRequest) - return - } - response := AuthMethodsProvider{ - ID: oidcProvider.ID, - Name: oidcProvider.Name, - RedirectURI: redirectURI, - } - out, err := json.Marshal(response) - if err != nil { - c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest) - return - } - newOAuthEntry := oidc.OAuthData{ - ID: uuid.NewString(), - OIDCProviderID: response.ID, - CreatedAt: time.Now(), - } - err = c.OIDCStore.SaveOAuth2Data(newOAuthEntry, state) - if err != nil { - c.returnError(w, fmt.Errorf("unable to save state to oidc store: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - return - } - } - c.returnError(w, fmt.Errorf("element not found"), http.StatusBadRequest) - } - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) oidcRenewTokensHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodPost: - c.OIDCRenewal.RenewAllOIDCConnections() - c.write(w, []byte(`{"status": "done"}`)) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} diff --git a/pkg/rest/auth_test.go b/pkg/rest/auth_test.go deleted file mode 100644 index 848e8f3..0000000 --- a/pkg/rest/auth_test.go +++ /dev/null @@ -1,946 +0,0 @@ -package rest - -import ( - "bytes" - "compress/flate" - "crypto/rand" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "net" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/in4it/wireguard-server/pkg/auth/oidc" - "github.com/in4it/wireguard-server/pkg/auth/saml" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/rest/login" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" - "github.com/in4it/wireguard-server/pkg/users" - "github.com/russellhaering/gosaml2/types" - dsigtypes "github.com/russellhaering/goxmldsig/types" -) - -func getSAMLCertWithCustomCert(singleSignOnURL string, cert string) *types.EntityDescriptor { - return &types.EntityDescriptor{ - EntityID: "https://www.idp.inv/metadata", - IDPSSODescriptor: &types.IDPSSODescriptor{ - SingleSignOnServices: []types.SingleSignOnService{ - { - Location: singleSignOnURL, - }, - }, - KeyDescriptors: []types.KeyDescriptor{ - { - KeyInfo: dsigtypes.KeyInfo{ - X509Data: dsigtypes.X509Data{ - X509Certificates: []dsigtypes.X509Certificate{ - { - Data: cert, - }, - }, - }, - }, - }, - }, - }, - } -} -func getSAMLCert(singleSignOnURL string) *types.EntityDescriptor { - cert := `MIID2jCCA0MCAg39MA0GCSqGSIb3DQEBBQUAMIGbMQswCQYDVQQGEwJKUDEOMAwG -A1UECBMFVG9reW8xEDAOBgNVBAcTB0NodW8ta3UxETAPBgNVBAoTCEZyYW5rNERE -MRgwFgYDVQQLEw9XZWJDZXJ0IFN1cHBvcnQxGDAWBgNVBAMTD0ZyYW5rNEREIFdl -YiBDQTEjMCEGCSqGSIb3DQEJARYUc3VwcG9ydEBmcmFuazRkZC5jb20wHhcNMTIw -ODIyMDUyODAwWhcNMTcwODIxMDUyODAwWjBKMQswCQYDVQQGEwJKUDEOMAwGA1UE -CAwFVG9reW8xETAPBgNVBAoMCEZyYW5rNEREMRgwFgYDVQQDDA93d3cuZXhhbXBs -ZS5jb20wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCwvWITOLeyTbS1 -Q/UacqeILIK16UHLvSymIlbbiT7mpD4SMwB343xpIlXN64fC0Y1ylT6LLeX4St7A -cJrGIV3AMmJcsDsNzgo577LqtNvnOkLH0GojisFEKQiREX6gOgq9tWSqwaENccTE -sAXuV6AQ1ST+G16s00iN92hjX9V/V66snRwTsJ/p4WRpLSdAj4272hiM19qIg9zr -h92e2rQy7E/UShW4gpOrhg2f6fcCBm+aXIga+qxaSLchcDUvPXrpIxTd/OWQ23Qh -vIEzkGbPlBA8J7Nw9KCyaxbYMBFb1i0lBjwKLjmcoihiI7PVthAOu/B71D2hKcFj -Kpfv4D1Uam/0VumKwhwuhZVNjLq1BR1FKRJ1CioLG4wCTr0LVgtvvUyhFrS+3PdU -R0T5HlAQWPMyQDHgCpbOHW0wc0hbuNeO/lS82LjieGNFxKmMBFF9lsN2zsA6Qw32 -Xkb2/EFltXCtpuOwVztdk4MDrnaDXy9zMZuqFHpv5lWTbDVwDdyEQNclYlbAEbDe -vEQo/rAOZFl94Mu63rAgLiPeZN4IdS/48or5KaQaCOe0DuAb4GWNIQ42cYQ5TsEH -Wt+FIOAMSpf9hNPjDeu1uff40DOtsiyGeX9NViqKtttaHpvd7rb2zsasbcAGUl+f -NQJj4qImPSB9ThqZqPTukEcM/NtbeQIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAIAi -gU3My8kYYniDuKEXSJmbVB+K1upHxWDA8R6KMZGXfbe5BRd8s40cY6JBYL52Tgqd -l8z5Ek8dC4NNpfpcZc/teT1WqiO2wnpGHjgMDuDL1mxCZNL422jHpiPWkWp3AuDI -c7tL1QjbfAUHAQYwmHkWgPP+T2wAv0pOt36GgMCM` - return getSAMLCertWithCustomCert(singleSignOnURL, cert) -} - -func TestAuthHandler(t *testing.T) { - c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context: %s", err) - } - c.UserStore.Empty() - _, err = c.UserStore.AddUser(users.User{ - Login: "john", - Password: "mypass", - }) - if err != nil { - t.Fatalf("Cannot create user") - } - - loginReq := login.LoginRequest{ - Login: "john", - Password: "mypass", - } - - payload, err := json.Marshal(loginReq) - if err != nil { - t.Fatal(err) - } - - req := httptest.NewRequest("POST", "http://example.com/api/auth", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.authHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - var loginResponse login.LoginResponse - - err = json.NewDecoder(resp.Body).Decode(&loginResponse) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - if !loginResponse.Authenticated { - t.Fatalf("expected to be authenticated") - } - -} - -func TestNewSAMLConnection(t *testing.T) { - // generate new keypair - kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv") - _, cert, err := kp.GetKeyPair() - if err != nil { - t.Fatalf("Can't generate new keypair: %s", err) - } - certBase64 := base64.StdEncoding.EncodeToString(cert) - - testUrl := "127.0.0.1:12347" - l, err := net.Listen("tcp", testUrl) - if err != nil { - t.Fatal(err) - } - - singleSignOnURL := "http://" + testUrl + "/auth" - audienceURL := "http://" + testUrl + "/aud" - login := "john@example.inv" - - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - requestURIParsed, _ := url.Parse(r.RequestURI) - if requestURIParsed.Path == "/auth" { - compressedSAMLReq, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest")) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(fmt.Sprintf("saml base64 decode error: %s", err))) - return - } - samlRequest := new(bytes.Buffer) - decompressor := flate.NewReader(bytes.NewReader(compressedSAMLReq)) - io.Copy(samlRequest, decompressor) - decompressor.Close() - - var authnReq saml.AuthnRequest - err = xml.Unmarshal(samlRequest.Bytes(), &authnReq) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(fmt.Sprintf("saml authn request decode error: %s", err))) - return - } - w.Write([]byte("OK")) - return - } - if r.RequestURI == "/metadata" { - out, _ := xml.Marshal(getSAMLCertWithCustomCert(singleSignOnURL, certBase64)) - w.Write(out) - return - } - w.WriteHeader(http.StatusBadRequest) - default: - w.WriteHeader(http.StatusBadRequest) - } - })) - - ts.Listener.Close() - ts.Listener = l - ts.Start() - defer ts.Close() - defer l.Close() - - // first create a new user - c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } - - // create a new SAML connection - samlProvider := saml.Provider{ - Name: "testProvider", - MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), - } - - payload, err := json.Marshal(samlProvider) - if err != nil { - t.Fatal(err) - } - - req := httptest.NewRequest("POST", "http://example.inv/api/saml-setup", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.samlSetupHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&samlProvider) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - if samlProvider.ID == "" { - t.Fatalf("Was expecting saml provider to have an ID") - } - - authURL, err := c.SAML.Client.GetAuthURL(samlProvider) - if err != nil { - t.Fatalf("cannot get Auth URL from saml: %s", err) - } - - if authURL == "" { - t.Fatalf("authURL is empty") - } - - resp, err = http.Get(authURL) - if err != nil { - t.Fatalf("http get auth url error: %s", err) - } - if resp.StatusCode != 200 { - t.Errorf("auth url get not status 200: %d", resp.StatusCode) - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("body read error: %s", err) - } - - // check SAML POST flow - tsSAML := httptest.NewServer(c.SAML.Client.GetRouter()) - defer tsSAML.Close() - - // build the SAML response - // example - /* - - - https://app.onelogin.com/saml/metadata/onelogin-id - - - - - - - - - - - 5eB3C+2/vwdigestvalue - - - sigvalue - - - MIIGETCCA/mgAcertt - - - - - - - - https://app.onelogin.com/saml/metadata/onelogin-id - - ward@in4it.io - - - - - - - https://vpn-server.in4it.io/saml/aud/provider-id - - - - - urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport - - - - - */ - samlResponse := saml.Response{ - Saml: "urn:oasis:names:tc:SAML:2.0:assertion", - Samlp: "urn:oasis:names:tc:SAML:2.0:protocol", - ID: "pfx181c1b93-cce6-a6c4-f1f7-99e539374d15", - Version: "2.0", - IssueInstant: time.Now().Format(time.RFC3339), - Destination: "http://example.inv/saml/acs/" + samlProvider.ID, - Issuer: ts.URL + "/metadata", - Signature: saml.ResponseSignature{ - Ds: "http://www.w3.org/2000/09/xmldsig#", - SignedInfo: saml.ResponseSignatureSignedInfo{ - CanonicalizationMethod: struct { - Text string "xml:\",chardata\"" - Algorithm string "xml:\"Algorithm,attr\"" - }{ - Algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#", - }, - SignatureMethod: saml.ResponseSignatureSignedInfoSignatureMethod{ - Algorithm: "http://www.w3.org/2000/09/xmldsig#rsa-sha1", - }, - Reference: saml.ResponseSignatureSignedInfoReference{ - Transforms: struct { - Text string "xml:\",chardata\"" - Transform []struct { - Text string "xml:\",chardata\"" - Algorithm string "xml:\"Algorithm,attr\"" - } "xml:\"ds:Transform\"" - }{ - Transform: []struct { - Text string "xml:\",chardata\"" - Algorithm string "xml:\"Algorithm,attr\"" - }{ - { - Algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature", - }, - { - Algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#", - }, - }, - }, - DigestMethod: struct { - Text string "xml:\",chardata\"" - Algorithm string "xml:\"Algorithm,attr\"" - }{ - Algorithm: "http://www.w3.org/2000/09/xmldsig#sha1", - }, - DigestValue: "thisisthesignature", - }, - }, - KeyInfo: saml.ResponseSignatureKeyInfo{ - X509Data: struct { - Text string "xml:\",chardata\"" - X509Certificate string "xml:\"ds:X509Certificate\"" - }{ - X509Certificate: certBase64, - }, - }, - }, - Assertion: saml.ResponseAssertion{ - Subject: saml.ResponseSubject{ - NameID: struct { - Text string "xml:\",chardata\"" - Format string "xml:\"Format,attr\"" - }{ - Text: login, - Format: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", - }, - }, - Conditions: saml.ResponseConditions{ - NotBefore: time.Now().Format(time.RFC3339), - NotOnOrAfter: time.Now().Add(10 * time.Minute).Format(time.RFC3339), - AudienceRestriction: saml.ResponseConditionsAdienceRestriction{ - Audience: audienceURL, - }, - }, - }, - } - samlResponseBytes, err := xml.Marshal(samlResponse) - if err != nil { - t.Fatalf("xml marshal error: %s", err) - } - //fmt.Printf("saml respons bytes: %s\n", samlResponseBytes) - samlResponseBytesDeflated := new(bytes.Buffer) - compressor, err := flate.NewWriter(samlResponseBytesDeflated, 1) - if err != nil { - t.Fatalf("deflate error: %s", err) - } - io.Copy(compressor, bytes.NewBuffer(samlResponseBytes)) - compressor.Close() - - samlResponseEncoded := base64.StdEncoding.EncodeToString(samlResponseBytesDeflated.Bytes()) - - form := url.Values{} - form.Add("SAMLResponse", samlResponseEncoded) - - resp, err = http.Post(tsSAML.URL+"/saml/acs/"+samlProvider.ID, "application/x-www-form-urlencoded", strings.NewReader(form.Encode())) - if err != nil { - t.Fatalf("http post acs url error: %s", err) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("body read error: %s", err) - } - - if strings.Contains(string(body), "provider not found") { - t.Fatalf("provider not found: body output: %s", body) - } - - // currently does't authenticate because of missing signatures, but we checked if the auth process kicked off - /*if resp.StatusCode != 200 { - t.Errorf("auth url get not status 200: %d", resp.StatusCode) - }*/ - -} -func TestAddModifyDeleteNewSAMLConnection(t *testing.T) { - c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } - c.Hostname = "example.inv" - c.Protocol = "https" - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata" { - out, err := xml.Marshal(getSAMLCert("http://localhost.inv")) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - w.Write(out) - return - } - w.WriteHeader(http.StatusBadRequest) - })) - - samlProvider := saml.Provider{ - Name: "testProvider", - MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), - } - - payload, err := json.Marshal(samlProvider) - if err != nil { - t.Fatal(err) - } - - req := httptest.NewRequest("POST", "http://example.com/api/saml-setup", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.samlSetupHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&samlProvider) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - if samlProvider.ID == "" { - t.Fatalf("samlprovider id is empty") - } - - // GET authmethods and see if provider exists - req = httptest.NewRequest("GET", "http://example.com/authmethods/saml/"+samlProvider.ID, nil) - req.SetPathValue("method", "saml") - req.SetPathValue("id", samlProvider.ID) - w = httptest.NewRecorder() - c.authMethodsByID(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - var authMethodsProvider AuthMethodsProvider - - err = json.NewDecoder(resp.Body).Decode(&authMethodsProvider) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - if authMethodsProvider.ID != samlProvider.ID { - t.Fatalf("authmethods provider id is different than saml provider id: %s vs %s. authMethodsProvider: %+v", authMethodsProvider.ID, samlProvider.ID, authMethodsProvider) - } - - // PUT req - samlProvider.AllowMissingAttributes = true - payload, err = json.Marshal(samlProvider) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req = httptest.NewRequest("PUT", "http://example.com/saml-setup/"+samlProvider.ID, bytes.NewBuffer(payload)) - req.SetPathValue("id", samlProvider.ID) - w = httptest.NewRecorder() - c.samlSetupElementHandler(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&samlProvider) - if err != nil { - t.Fatalf("marshal decode error: %s", err) - } - - if samlProvider.AllowMissingAttributes == false { - t.Fatalf("allow missing attributes is false") - } - - // GET on the saml endpoint to see if we can return it - req = httptest.NewRequest("GET", "http://example.com/saml-setup", nil) - w = httptest.NewRecorder() - c.samlSetupHandler(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - var samlProviders []saml.Provider - err = json.NewDecoder(resp.Body).Decode(&samlProviders) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - if len(samlProviders) == 0 { - t.Fatalf("samlProviders is zero length") - } - if samlProviders[len(samlProviders)-1].ID != samlProvider.ID { - t.Fatalf("ID doesn't match: %s vs %s ", samlProviders[len(samlProviders)-1].ID, samlProvider.ID) - } - if samlProviders[len(samlProviders)-1].Acs != fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ACS_URL, samlProvider.ID) { - t.Fatalf("ACS doesn't match") - } - if samlProviders[len(samlProviders)-1].AllowMissingAttributes == false { - t.Fatalf("allow missing attributes is false when getting all samlproviders") - } - - // delete req - req = httptest.NewRequest("DELETE", "http://example.com/saml-setup/"+samlProvider.ID, bytes.NewBuffer(payload)) - req.SetPathValue("id", samlProvider.ID) - w = httptest.NewRecorder() - c.samlSetupElementHandler(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - // list to see if really deleted - req = httptest.NewRequest("GET", "http://example.com/saml-setup", nil) - w = httptest.NewRecorder() - c.samlSetupHandler(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - var samlProviders2 []saml.Provider - err = json.NewDecoder(resp.Body).Decode(&samlProviders2) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - if len(samlProviders)-1 != len(samlProviders2) { - t.Fatalf("samlProviders has wrong length") - } - -} - -func TestSAMLCallback(t *testing.T) { - c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } - c.Hostname = "example.inv" - c.Protocol = "https" - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata" { - out, err := xml.Marshal(getSAMLCert("http://localhost.inv")) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - w.Write(out) - return - } - w.WriteHeader(http.StatusBadRequest) - })) - - samlProvider := saml.Provider{ - Name: "testProvider", - MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), - } - - payload, err := json.Marshal(samlProvider) - if err != nil { - t.Fatal(err) - } - - req := httptest.NewRequest("POST", "http://example.com/api/saml-setup", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.samlSetupHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&samlProvider) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - if samlProvider.ID == "" { - t.Fatalf("samlprovider id is empty") - } - samlCallback := SAMLCallback{ - Code: "abc", - RedirectURI: "https://localhost.inv/something", - } - payload, err = json.Marshal(samlCallback) - if err != nil { - t.Fatal(err) - } - - c.SAML.Client.CreateSession(saml.SessionKey{ProviderID: samlProvider.ID, SessionID: "abc"}, saml.AuthenticatedUser{ID: "123", Login: "john@example.com", ExpiresAt: time.Now().AddDate(0, 0, 1)}) - - req = httptest.NewRequest("POST", "http://example.com/api/authmethods/saml/"+samlProvider.ID, bytes.NewBuffer(payload)) - req.SetPathValue("method", "saml") - req.SetPathValue("id", samlProvider.ID) - w = httptest.NewRecorder() - c.authMethodsByID(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - var loginResponse login.LoginResponse - err = json.NewDecoder(resp.Body).Decode(&loginResponse) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - if !loginResponse.Authenticated { - t.Fatalf("Expected to be authenticated") - } - -} - -func TestOIDCFlow(t *testing.T) { - testUrl := "127.0.0.1:12346" - l, err := net.Listen("tcp", testUrl) - if err != nil { - t.Fatal(err) - } - - authURL := "http://" + testUrl + "/auth" - - // create a new OIDC connection - oidcProvider := oidc.OIDCProvider{ - Name: "test-oidc", - ClientID: "1-2-3-4", - ClientSecret: "9-9-9-9", - Scope: "openid", - DiscoveryURI: "http://" + testUrl + "/discovery.json", - } - jwtPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - t.Fatalf("can't generate jwt key: %s", err) - } - - // first create a new user - c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } - c.Hostname = "example.inv" - c.Protocol = "http" - logging.Loglevel = 17 - - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - code := "thisisthecode" - - switch r.Method { - case http.MethodGet: - parsedURI, _ := url.Parse(r.RequestURI) - switch parsedURI.Path { - case "/discovery.json": - discovery := oidc.Discovery{ - Issuer: "test-issuer", - AuthorizationEndpoint: authURL, - TokenEndpoint: "http://" + testUrl + "/token", - JwksURI: "http://" + testUrl + "/jwks.json", - } - out, err := json.Marshal(discovery) - if err != nil { - t.Fatalf("json marshal error: %s", err) - } - w.Write(out) - return - case "/auth": - if oidcProvider.ClientID != r.URL.Query().Get("client_id") { - w.Write([]byte("client id mismatch")) - w.WriteHeader(http.StatusBadRequest) - return - } - if oidcProvider.Scope != r.URL.Query().Get("scope") { - w.Write([]byte("scope mismatch")) - w.WriteHeader(http.StatusBadRequest) - return - } - w.Write([]byte(code)) - case "/jwks.json": - publicKey := jwtPrivateKey.PublicKey - - jwks := oidc.Jwks{ - Keys: []oidc.JwksKey{ - { - Kid: "kid-id-1234", - Alg: "RS256", - Kty: "RSA", - Use: "sig", - N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()), - E: "AQAB", - }, - }, - } - out, err := json.Marshal(jwks) - if err != nil { - w.Write([]byte("jwks marshal error")) - w.WriteHeader(http.StatusBadRequest) - return - } - w.Write(out) - default: - w.WriteHeader(http.StatusNotFound) - } - case http.MethodPost: - parsedURI, _ := url.Parse(r.RequestURI) - switch parsedURI.Path { - case "/token": - if r.FormValue("grant_type") != "authorization_code" { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("wrong grant type")) - return - } - if r.FormValue("code") != code { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("wrong code")) - return - } - if oidcProvider.ClientID != r.FormValue("client_id") { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("client id mismatch")) - return - } - if oidcProvider.ClientSecret != r.FormValue("client_secret") { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("client secret mismatch")) - return - } - if c.Protocol+"://"+c.Hostname+oidcProvider.RedirectURI != r.FormValue("redirect_uri") { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(fmt.Sprintf("redirect uri mismatch: %s vs %s", oidcProvider.RedirectURI, r.FormValue("redirect_uri")))) - return - } - token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), jwt.MapClaims{ - "iss": "test-server", - "sub": "john", - "email": "john@example.inv", - "role": "user", - "exp": time.Now().AddDate(0, 0, 1).Unix(), - "iat": time.Now().Unix(), - }) - token.Header["kid"] = "kid-id-1234" - - tokenString, err := token.SignedString(jwtPrivateKey) - if err != nil { - t.Fatalf("can't generate jwt token: %s", err) - w.WriteHeader(http.StatusBadRequest) - } - tokenRes := oidc.Token{ - AccessToken: tokenString, - IDToken: tokenString, - ExpiresIn: 180, - } - tokenBytes, _ := json.Marshal(tokenRes) - - w.Write([]byte(tokenBytes)) - default: - w.WriteHeader(http.StatusNotFound) - } - default: - w.WriteHeader(http.StatusBadRequest) - } - })) - - ts.Listener.Close() - ts.Listener = l - ts.Start() - defer ts.Close() - defer l.Close() - - payload, err := json.Marshal(oidcProvider) - if err != nil { - t.Fatal(err) - } - - // create new oidc provider - req := httptest.NewRequest("POST", "http://example.inv/api/oidc", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.oidcProviderHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&oidcProvider) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - if oidcProvider.ID == "" { - t.Fatalf("Was expecting oidc provider to have an ID") - } - - // get redirect URL - req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID, nil) - req.SetPathValue("id", oidcProvider.ID) - w = httptest.NewRecorder() - c.authMethodsByID(w, req) - - resp = w.Result() - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - var authmethodsResponse AuthMethodsProvider - - err = json.NewDecoder(resp.Body).Decode(&authmethodsResponse) - - if err != nil { - t.Fatalf("cannot decode authmethodsresponse: %s", err) - } - if !strings.HasPrefix(authmethodsResponse.RedirectURI, authURL) { - t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", authmethodsResponse.RedirectURI) - } - - redirectURIParsed, err := url.Parse(authmethodsResponse.RedirectURI) - if err != nil { - t.Fatalf("could not parse redirect URI: %s", err) - } - state := redirectURIParsed.Query().Get("state") - if state == "" { - t.Fatalf("could not obtain state") - } - res, err := http.Get(authmethodsResponse.RedirectURI) - if err != nil { - t.Fatalf("http get redirect uri error: %s", err) - } - if res.StatusCode != 200 { - t.Fatalf("redirect uri statuscode not 200: %d", res.StatusCode) - } - code, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("body read error: %s", err) - } - - callback := OIDCCallback{ - Code: string(code), - State: state, - RedirectURI: oidcProvider.RedirectURI, - } - callbackPayload, err := json.Marshal(callback) - if err != nil { - t.Fatalf("callback marshal error: %s", err) - } - // execute callback - req = httptest.NewRequest("POST", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID, bytes.NewBuffer(callbackPayload)) - req.SetPathValue("id", oidcProvider.ID) - req.SetPathValue("method", "oidc") - w = httptest.NewRecorder() - c.authMethodsByID(w, req) - - resp = w.Result() - defer resp.Body.Close() - - if resp.StatusCode != 200 { - errorMessage, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("body read error after statuscode not 200 (%d): %s", resp.StatusCode, err) - } - t.Fatalf("status code is not 200: %d, errormessage: %s", resp.StatusCode, errorMessage) - } - - var loginResponse login.LoginResponse - - err = json.NewDecoder(resp.Body).Decode(&loginResponse) - if err != nil { - t.Fatalf("cannot decode login response: %s", err) - } - - if !loginResponse.Authenticated { - t.Fatalf("not authenticated: %+v", loginResponse) - } - if loginResponse.Token == "" { - t.Fatalf("no token received: %+v", loginResponse) - } -} diff --git a/pkg/rest/config.go b/pkg/rest/config.go deleted file mode 100644 index 18b6d68..0000000 --- a/pkg/rest/config.go +++ /dev/null @@ -1,79 +0,0 @@ -package rest - -import ( - "bytes" - "encoding/json" - "fmt" - "os/user" - "sync" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -var mu sync.Mutex - -func SaveConfig(c *Context) error { - mu.Lock() - defer mu.Unlock() - cCopy := *c - cCopy.SCIM = &SCIM{ // we don't save the client, but we want the token and enabled - EnableSCIM: c.SCIM.EnableSCIM, - Token: c.SCIM.Token, - } - cCopy.SAML = &SAML{ // we don't save the client, but we want the config - Providers: c.SAML.Providers, - } - cCopy.JWTKeys = nil // we retrieve JWTKeys from pem files at startup - cCopy.OIDCStore = nil // we save this separately - cCopy.UserStore = nil // we save this separately - cCopy.OIDCRenewal = nil // we don't save this - cCopy.LoginAttempts = nil // no need to save this - cCopy.Observability = nil // no need to save the client - cCopy.Storage = nil // no need to save storage - out, err := json.Marshal(cCopy) - if err != nil { - return fmt.Errorf("context marshal error: %s", err) - } - err = c.Storage.Client.WriteFile(c.Storage.Client.ConfigPath("config.json"), out) - if err != nil { - return fmt.Errorf("config write error: %s", err) - } - // fix permissions - currentUser, err := user.Current() - if err != nil { - return fmt.Errorf("could not get current user: %s", err) - } - if currentUser.Username != "vpn" { - err = c.Storage.Client.EnsureOwnership(c.Storage.Client.ConfigPath("config.json"), "vpn") - if err != nil { - return fmt.Errorf("config write error: %s", err) - } - } - - return nil -} - -func GetConfig(storage storage.Iface) (*Context, error) { - var c *Context - - appDir := storage.GetPath() - - // check if config exists - if !storage.FileExists(storage.ConfigPath("config.json")) { - return getEmptyContext(appDir) - } - - data, err := storage.ReadFile(storage.ConfigPath("config.json")) - if err != nil { - return c, fmt.Errorf("config read error: %s", err) - } - decoder := json.NewDecoder(bytes.NewBuffer(data)) - err = decoder.Decode(&c) - if err != nil { - return c, fmt.Errorf("decode input error: %s", err) - } - - c.AppDir = appDir - - return c, nil -} diff --git a/pkg/rest/constants.go b/pkg/rest/constants.go deleted file mode 100644 index 48fe858..0000000 --- a/pkg/rest/constants.go +++ /dev/null @@ -1,4 +0,0 @@ -package rest - -const SERVER_TYPE_OBSERVABILITY = "observability" -const SERVER_TYPE_VPN = "vpn" diff --git a/pkg/rest/context.go b/pkg/rest/context.go deleted file mode 100644 index ba7a2a4..0000000 --- a/pkg/rest/context.go +++ /dev/null @@ -1,116 +0,0 @@ -package rest - -import ( - "fmt" - "sync" - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store" - oidcrenewal "github.com/in4it/wireguard-server/pkg/auth/oidc/store/renewal" - "github.com/in4it/wireguard-server/pkg/auth/provisioning/scim" - "github.com/in4it/wireguard-server/pkg/auth/saml" - "github.com/in4it/wireguard-server/pkg/license" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/observability" - "github.com/in4it/wireguard-server/pkg/rest/login" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" -) - -var muClientDownload sync.Mutex - -func newContext(storage storage.Iface, serverType string) (*Context, error) { - c, err := GetConfig(storage) - if err != nil { - return c, fmt.Errorf("getConfig error: %s", err) - } - c.ServerType = serverType - - c.Storage = &Storage{ - Client: storage, - } - - c.JWTKeys, err = getJWTKeys(storage) - if err != nil { - return c, fmt.Errorf("getJWTKeys error: %s", err) - } - c.OIDCStore, err = oidcstore.NewStore(storage) - if err != nil { - return c, fmt.Errorf("getOIDCStore error: %s", err) - } - if c.OIDCProviders == nil { - c.OIDCProviders = []oidc.OIDCProvider{} - } - - c.LicenseUserCount, c.CloudType = license.GetMaxUsers(c.Storage.Client) - go func() { // run license refresh - logging.DebugLog(fmt.Errorf("starting license refresh in background (current licenses: %d, cloud type: %s)", c.LicenseUserCount, c.CloudType)) - for { - time.Sleep(time.Hour * 24) - newLicenseCount := license.RefreshLicense(storage, c.CloudType, c.LicenseUserCount) - if newLicenseCount != c.LicenseUserCount { - logging.InfoLog(fmt.Sprintf("License changed from %d users to %d users", c.LicenseUserCount, newLicenseCount)) - c.LicenseUserCount = newLicenseCount - } - } - }() - c.UserStore, err = users.NewUserStore(c.Storage.Client, c.LicenseUserCount) - if err != nil { - return c, fmt.Errorf("userstore initialization error: %s", err) - } - - c.OIDCRenewal, err = oidcrenewal.NewRenewal(storage, c.TokenRenewalTimeMinutes, c.LogLevel, c.EnableOIDCTokenRenewal, c.OIDCStore, c.OIDCProviders, c.UserStore) - if err != nil { - return c, fmt.Errorf("oidcrenewal init error: %s", err) - } - - if c.LoginAttempts == nil { - c.LoginAttempts = make(login.Attempts) - } - - if c.SCIM == nil { - c.SCIM = &SCIM{ - Client: scim.New(storage, c.UserStore, ""), - Token: "", - EnableSCIM: false, - } - } else { - c.SCIM.Client = scim.New(storage, c.UserStore, c.SCIM.Token) - } - if c.SAML == nil { - providers := []saml.Provider{} - c.SAML = &SAML{ - Client: saml.New(&providers, storage, &c.Protocol, &c.Hostname), - Providers: &providers, - } - } else { - c.SAML.Client = saml.New(c.SAML.Providers, storage, &c.Protocol, &c.Hostname) - } - - if c.Observability == nil { - c.Observability = &Observability{ - Client: observability.New(storage), - } - } else { - c.Observability.Client = observability.New(storage) - } - - return c, nil -} - -func getEmptyContext(appDir string) (*Context, error) { - randomString, err := oidc.GetRandomString(64) - if err != nil { - return nil, fmt.Errorf("couldn't generate random string for local kid") - } - c := &Context{ - AppDir: appDir, - JWTKeysKID: randomString, - TokenRenewalTimeMinutes: oidcrenewal.DEFAULT_RENEWAL_TIME_MINUTES, - LogLevel: logging.LOG_ERROR, - SCIM: &SCIM{EnableSCIM: false}, - SAML: &SAML{Providers: &[]saml.Provider{}}, - } - return c, nil -} diff --git a/pkg/rest/helpers.go b/pkg/rest/helpers.go deleted file mode 100644 index 8bfab34..0000000 --- a/pkg/rest/helpers.go +++ /dev/null @@ -1,79 +0,0 @@ -package rest - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "strings" -) - -func (c *Context) returnError(w http.ResponseWriter, err error, statusCode int) { - fmt.Println("========= ERROR =========") - fmt.Printf("Error: %s\n", err) - fmt.Println("=========================") - sendCorsHeaders(w, "", c.Hostname, c.Protocol) - w.WriteHeader(statusCode) - w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`)) -} - -func (c *Context) write(w http.ResponseWriter, res []byte) { - sendCorsHeaders(w, "", c.Hostname, c.Protocol) - w.WriteHeader(http.StatusOK) - w.Write(res) -} -func (c *Context) writeWithStatus(w http.ResponseWriter, res []byte, status int) { - sendCorsHeaders(w, "", c.Hostname, c.Protocol) - w.WriteHeader(status) - w.Write(res) -} - -func sendCorsHeaders(w http.ResponseWriter, headers string, hostname string, protocol string) { - if hostname == "" { - w.Header().Add("Access-Control-Allow-Origin", "*") - } else { - w.Header().Add("Access-Control-Allow-Origin", fmt.Sprintf("%s://%s", protocol, hostname)) - } - w.Header().Add("Access-Control-allow-methods", "GET,HEAD,POST,PUT,OPTIONS,DELETE,PATCH") - if headers != "" { - w.Header().Add("Access-Control-Allow-Headers", headers) - } -} - -func isAlphaNumeric(str string) bool { - for _, c := range str { - if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { - return false - } - } - return true -} - -func getKidFromToken(token string) (string, error) { - jwtSplit := strings.Split(token, ".") - if len(jwtSplit) < 1 { - return "", fmt.Errorf("token split < 1") - } - data, err := base64.RawURLEncoding.DecodeString(jwtSplit[0]) - if err != nil { - return "", fmt.Errorf("could not base64 decode data part of jwt") - } - var header JwtHeader - err = json.Unmarshal(data, &header) - if err != nil { - return "", fmt.Errorf("could not unmarshal jwt data") - } - return header.Kid, nil -} - -func returnIndexOrNotFound(contents []byte) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !strings.HasPrefix(r.URL.Path, "/api") { - w.WriteHeader(http.StatusOK) - w.Write(contents) - } else { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte("404 page not found\n")) - } - }) -} diff --git a/pkg/rest/helpers_test.go b/pkg/rest/helpers_test.go deleted file mode 100644 index cfe500d..0000000 --- a/pkg/rest/helpers_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package rest - -import "testing" - -func TestIsAlphaNumeric(t *testing.T) { - if !isAlphaNumeric("abc123") { - t.Errorf("expected alphanumeric") - } - if isAlphaNumeric("abc@") { - t.Errorf("expected alphanumeric") - } -} diff --git a/pkg/rest/license.go b/pkg/rest/license.go deleted file mode 100644 index b353bbf..0000000 --- a/pkg/rest/license.go +++ /dev/null @@ -1,46 +0,0 @@ -package rest - -import ( - "encoding/json" - "fmt" - "net/http" - - "github.com/in4it/wireguard-server/pkg/license" - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -func (c *Context) licenseHandler(w http.ResponseWriter, r *http.Request) { - if r.PathValue("action") == "get-more" { - c.LicenseUserCount = license.RefreshLicense(c.Storage.Client, c.CloudType, c.LicenseUserCount) - } - - currentUserCount := c.UserStore.UserCount() - licenseResponse := LicenseResponse{LicenseUserCount: c.LicenseUserCount, CurrentUserCount: currentUserCount, CloudType: c.CloudType} - - if r.PathValue("action") == "get-more" { - licenseResponse.Key = license.GetLicenseKey(c.Storage.Client, c.CloudType) - } - - out, err := json.Marshal(licenseResponse) - if err != nil { - c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) - return - } - c.write(w, out) -} -func (c *Context) connectionLicenseHandler(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(CustomValue("user")).(users.User) - totalConnections, err := wireguard.GetConfigNumbers(c.Storage.Client, user.ID) - if err != nil { - c.returnError(w, fmt.Errorf("can't determine total connections: %s", err), http.StatusBadRequest) - return - - } - out, err := json.Marshal(ConnectionLicenseResponse{LicenseUserCount: c.LicenseUserCount, ConnectionCount: len(totalConnections)}) - if err != nil { - c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) - return - } - c.write(w, out) -} diff --git a/pkg/rest/login/attempt.go b/pkg/rest/login/attempt.go deleted file mode 100644 index 2fe2f0c..0000000 --- a/pkg/rest/login/attempt.go +++ /dev/null @@ -1,51 +0,0 @@ -package login - -import ( - "sync" - "time" -) - -var mu sync.Mutex - -type Attempts map[string][]Attempt - -type Attempt struct { - Timestamp time.Time -} - -func ClearAttemptsForLogin(attempts Attempts, login string) { - mu.Lock() - defer mu.Unlock() - attempts[login] = []Attempt{} -} - -func RecordAttempt(attempts Attempts, login string) { - mu.Lock() - defer mu.Unlock() - _, ok := attempts[login] - if !ok { - attempts[login] = []Attempt{} - } - attempts[login] = append(attempts[login], Attempt{Timestamp: time.Now()}) -} - -func CheckTooManyLogins(attempts Attempts, login string) bool { - threeMinutes := 3 * time.Minute - _, ok := attempts[login] - if ok { - loginAttempts := 0 - for _, loginAttempt := range attempts[login] { - if time.Since(loginAttempt.Timestamp) <= threeMinutes { - loginAttempts++ - } - } - if loginAttempts >= 3 { - if len(attempts[login]) > 3 { - index := len(attempts[login]) - 3 - attempts[login] = attempts[login][index:] - } - return true - } - } - return false -} diff --git a/pkg/rest/login/auth.go b/pkg/rest/login/auth.go deleted file mode 100644 index 5d0288e..0000000 --- a/pkg/rest/login/auth.go +++ /dev/null @@ -1,50 +0,0 @@ -package login - -import ( - "crypto/rsa" - "fmt" - - "github.com/in4it/wireguard-server/pkg/mfa/totp" - "github.com/in4it/wireguard-server/pkg/users" -) - -func Authenticate(loginReq LoginRequest, authIface AuthIface, jwtPrivateKey *rsa.PrivateKey, jwtKeyID string) (LoginResponse, users.User, error) { - loginResponse := LoginResponse{} - user, auth := authIface.AuthUser(loginReq.Login, loginReq.Password) - if auth && !user.Suspended { - if len(user.Factors) == 0 { // authentication without MFA - token, err := GetJWTToken(user.Login, user.Role, jwtPrivateKey, jwtKeyID) - if err != nil { - return loginResponse, user, fmt.Errorf("token generation failed: %s", err) - } - loginResponse.Authenticated = true - loginResponse.Token = token - } else { - if loginReq.FactorResponse.Name == "" { - loginResponse.Authenticated = false - loginResponse.MFARequired = true - for _, factor := range user.Factors { - loginResponse.Factors = append(loginResponse.Factors, factor.Name) - } - } else { - for _, factor := range user.Factors { - if factor.Name == loginReq.FactorResponse.Name { - ok, err := totp.Verify(factor.Secret, loginReq.FactorResponse.Code) - if err != nil { - return loginResponse, user, fmt.Errorf("MFA (totp) verify failed: %s", err) - } - if ok { // authentication with MFA - token, err := GetJWTToken(user.Login, user.Role, jwtPrivateKey, jwtKeyID) - if err != nil { - return loginResponse, user, fmt.Errorf("token generation failed: %s", err) - } - loginResponse.Authenticated = true - loginResponse.Token = token - } - } - } - } - } - } - return loginResponse, user, nil -} diff --git a/pkg/rest/login/auth_test.go b/pkg/rest/login/auth_test.go deleted file mode 100644 index 34e61be..0000000 --- a/pkg/rest/login/auth_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package login - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/base32" - "testing" - "time" - - "github.com/in4it/wireguard-server/pkg/mfa/totp" - "github.com/in4it/wireguard-server/pkg/users" -) - -type MockAuth struct { - AuthUserUser users.User - AuthUserResult bool -} - -func (m *MockAuth) AuthUser(login string, password string) (users.User, bool) { - return m.AuthUserUser, m.AuthUserResult -} - -func TestAuthenticate(t *testing.T) { - m := MockAuth{ - AuthUserUser: users.User{ - Login: "john", - }, - AuthUserResult: true, - } - loginReq := LoginRequest{ - Login: "john", - Password: "mypass", - } - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - t.Fatalf("private key error: %s", err) - } - - loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID") - if err != nil { - t.Fatalf("authentication error: %s", err) - } - if !loginResp.Authenticated { - t.Fatalf("expected to be authenticated") - } - if loginResp.Token == "" { - t.Fatalf("no token") - } -} -func TestAuthenticateMFANoToken(t *testing.T) { - m := MockAuth{ - AuthUserUser: users.User{ - Login: "john", - Factors: []users.Factor{ - { - Name: "test-factor", - Type: "test", - Secret: "secret", - }, - }, - }, - AuthUserResult: true, - } - loginReq := LoginRequest{ - Login: "john", - Password: "mypass", - } - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - t.Fatalf("private key error: %s", err) - } - - loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID") - if err != nil { - t.Fatalf("authentication error: %s", err) - } - if loginResp.Authenticated { - t.Fatalf("expected not to be authenticated") - } - if len(loginResp.Factors) == 0 { - t.Fatalf("expected to get factors") - } -} -func TestAuthenticateMFAWithToken(t *testing.T) { - secret := base32.StdEncoding.EncodeToString([]byte("secret")) - m := MockAuth{ - AuthUserUser: users.User{ - Login: "john", - Factors: []users.Factor{ - { - Name: "test-factor", - Type: "test", - Secret: secret, - }, - }, - }, - AuthUserResult: true, - } - token, err := totp.GetToken(secret, time.Now().Unix()/30) - if err != nil { - t.Fatalf("GetToken error: %s", err) - } - loginReq := LoginRequest{ - Login: "john", - Password: "mypass", - FactorResponse: FactorResponse{ - Name: "test-factor", - Code: token, - }, - } - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - t.Fatalf("private key error: %s", err) - } - - loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID") - if err != nil { - t.Fatalf("authentication error: %s", err) - } - if !loginResp.Authenticated { - t.Fatalf("expected to be authenticated") - } -} diff --git a/pkg/rest/login/jwt.go b/pkg/rest/login/jwt.go deleted file mode 100644 index bba4e09..0000000 --- a/pkg/rest/login/jwt.go +++ /dev/null @@ -1,27 +0,0 @@ -package login - -import ( - "crypto/rsa" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -func GetJWTToken(login, role string, signKey *rsa.PrivateKey, kid string) (string, error) { - return GetJWTTokenWithExpiration(login, role, signKey, kid, time.Now().Add(time.Hour*72)) -} - -func GetJWTTokenWithExpiration(login, role string, signKey *rsa.PrivateKey, kid string, expiration time.Time) (string, error) { - token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), jwt.MapClaims{ - "iss": "wireguard-server", - "sub": login, - "role": role, - "exp": expiration.Unix(), - "iat": time.Now().Unix(), - }) - token.Header["kid"] = kid - - tokenString, err := token.SignedString(signKey) - - return tokenString, err -} diff --git a/pkg/rest/login/types.go b/pkg/rest/login/types.go deleted file mode 100644 index bf8b76d..0000000 --- a/pkg/rest/login/types.go +++ /dev/null @@ -1,27 +0,0 @@ -package login - -import "github.com/in4it/wireguard-server/pkg/users" - -type AuthIface interface { - AuthUser(login string, password string) (users.User, bool) -} - -type LoginRequest struct { - Login string `json:"login"` - Password string `json:"password"` - FactorResponse FactorResponse `json:"factorResponse"` -} - -type FactorResponse struct { - Name string `json:"name"` - Code string `json:"code"` -} - -type LoginResponse struct { - Authenticated bool `json:"authenticated"` - Suspended bool `json:"suspended"` - NoLicense bool `json:"noLicense"` - Token string `json:"token,omitempty"` - MFARequired bool `json:"mfaRequired"` - Factors []string `json:"factors"` -} diff --git a/pkg/rest/middleware.go b/pkg/rest/middleware.go deleted file mode 100644 index 72b5b94..0000000 --- a/pkg/rest/middleware.go +++ /dev/null @@ -1,182 +0,0 @@ -package rest - -import ( - "context" - "fmt" - "log" - "net/http" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/in4it/wireguard-server/pkg/auth/oidc" - "github.com/in4it/wireguard-server/pkg/users" -) - -type CustomValue string - -// auth middleware - -func (c *Context) authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !c.SetupCompleted { - c.returnError(w, fmt.Errorf("setup not completed"), http.StatusUnauthorized) - return - } - if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { - c.writeWithStatus(w, []byte(`{"error": "token not found"}`), http.StatusUnauthorized) - return - } - tokenString := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", -1) - if len(tokenString) == 0 { - c.returnError(w, fmt.Errorf("empty token"), http.StatusUnauthorized) - return - } - - // determine token to parse - var tokenToParse string - // is token an access token or a jwt from local auth? - kid, _ := getKidFromToken(tokenString) - if kid == c.JWTKeysKID { // local auth token - tokenToParse = tokenString - } else { - for _, oauth2Data := range c.OIDCStore.OAuth2Data { - if oauth2Data.Token.AccessToken == tokenString { - tokenToParse = oauth2Data.Token.IDToken - } - } - if tokenToParse == "" { - c.returnError(w, fmt.Errorf("token error: access token not found (wrong token or token expired)"), http.StatusUnauthorized) - return - } - } - token, err := jwt.Parse(tokenToParse, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Header["kid"]; !ok { - return nil, fmt.Errorf("no kid header found in token") - } - if token.Header["kid"] == c.JWTKeysKID { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("local kid: unexpected signing method: %v", token.Header["alg"]) - } - return c.JWTKeys.PublicKey, nil - } - discoveryProviders := make([]oidc.Discovery, len(c.OIDCProviders)) - for k, oidcProvider := range c.OIDCProviders { - discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) - if err != nil { - return nil, fmt.Errorf("couldn't retrieve discoveryURI from OIDC Provider (check discovery URI in OIDC settings). Error: %s", err) - } - discoveryProviders[k] = discovery - } - allJwks, err := c.OIDCStore.GetAllJwks(discoveryProviders) - if err != nil { - return nil, fmt.Errorf("couldn't retrieve JWKS URL from OIDC Provider (check discovery URI in OIDC settings). Error: %s", err) - } - publicKey, err := oidc.GetPublicKeyForToken(allJwks, discoveryProviders, token) - if err != nil { - return nil, fmt.Errorf("GetPublicKeyForToken error: %s", err) - } - return publicKey, nil - }) - if err != nil { - c.returnError(w, fmt.Errorf("token error: %s", err), http.StatusUnauthorized) - return - } - token.Claims.(jwt.MapClaims)["kid"] = token.Header["kid"] - ctx := context.WithValue(r.Context(), CustomValue("claims"), token.Claims.(jwt.MapClaims)) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -// logging middleware - -// responseWriter is a minimal wrapper for http.ResponseWriter that allows the -// written HTTP status code to be captured for logging. -// MIT licensed -type responseWriter struct { - http.ResponseWriter - status int - wroteHeader bool -} - -func (rw *responseWriter) Status() int { - return rw.status -} - -func (rw *responseWriter) WriteHeader(code int) { - if rw.wroteHeader { - return - } - - rw.status = code - rw.ResponseWriter.WriteHeader(code) - rw.wroteHeader = true -} - -func (c *Context) loggingMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - wrappedResponse := &responseWriter{ResponseWriter: w} - next.ServeHTTP(wrappedResponse, r) - log.Printf("req=%s res=%d method=%s src=%s duration=%s", r.RequestURI, wrappedResponse.status, r.Method, r.RemoteAddr, time.Since(start)) - }) -} - -func (c *Context) corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - sendCorsHeaders(w, r.Header.Get("Access-Control-Request-Headers"), c.Hostname, c.Protocol) - w.WriteHeader(http.StatusNoContent) - } else { - next.ServeHTTP(w, r) - } - }) -} -func (c *Context) injectUserMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, err := c.GetUserFromRequest(r) - if err != nil { - c.returnError(w, fmt.Errorf("token error: %s", err), http.StatusUnauthorized) - return - } - - ctx := context.WithValue(r.Context(), CustomValue("user"), user) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -func (c *Context) isAdminMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(CustomValue("user")).(users.User) - if user.Role != "admin" { - c.returnError(w, fmt.Errorf("endpoint forbidden"), http.StatusForbidden) - return - } - next.ServeHTTP(w, r) - }) -} - -func (c *Context) httpsRedirectMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if c.RedirectToHttps && r.TLS == nil { - if strings.HasPrefix(r.URL.Path, "/api") { - c.returnError(w, fmt.Errorf("non-tls requests disabled"), http.StatusForbidden) - return - } - http.Redirect(w, r, fmt.Sprintf("https://%s%s", r.Host, r.RequestURI), http.StatusMovedPermanently) - return - } - next.ServeHTTP(w, r) - }) -} - -func (c *Context) isSCIMEnabled(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !c.SCIM.EnableSCIM { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`{ "error": "SCIM Not Enabled" }`)) - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/pkg/rest/profile.go b/pkg/rest/profile.go deleted file mode 100644 index e035bce..0000000 --- a/pkg/rest/profile.go +++ /dev/null @@ -1,139 +0,0 @@ -package rest - -import ( - "encoding/json" - "fmt" - "net/http" - - "github.com/in4it/wireguard-server/pkg/mfa/totp" - "github.com/in4it/wireguard-server/pkg/users" -) - -func (c *Context) profilePasswordHandler(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(CustomValue("user")).(users.User) - switch r.Method { - case http.MethodPost: - var userInput users.User - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&userInput) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - if userInput.Password == "" { - c.returnError(w, fmt.Errorf("no password supplied"), http.StatusBadRequest) - return - } - err = c.UserStore.UpdatePassword(user.ID, userInput.Password) - if err != nil { - c.returnError(w, fmt.Errorf("update password error: %s", err), http.StatusBadRequest) - return - } - - c.write(w, []byte(`{"result": "OK"}`)) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - - } -} - -func (c *Context) profileFactorsHandler(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(CustomValue("user")).(users.User) - switch r.Method { - case http.MethodGet: - factors := make([]users.Factor, len(user.Factors)) - copy(factors, user.Factors) - for k := range factors { - factors[k].Secret = "" // remove secret when outputting - } - out, err := json.Marshal(factors) - if err != nil { - c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var factor FactorRequest - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&factor) - if err != nil { - c.returnError(w, fmt.Errorf("decode factor error: %s", err), http.StatusBadRequest) - return - } - if factor.Secret == "" { - c.returnError(w, fmt.Errorf("no factor secret supplied"), http.StatusBadRequest) - return - } - if factor.Name == "" { - c.returnError(w, fmt.Errorf("no factor name supplied"), http.StatusBadRequest) - return - } - if len(factor.Name) > 16 { - c.returnError(w, fmt.Errorf("factor name too long"), http.StatusBadRequest) - return - } - if factor.Type == "" { - c.returnError(w, fmt.Errorf("no factor type supplied"), http.StatusBadRequest) - return - } - if factor.Code == "" { - c.returnError(w, fmt.Errorf("no factor code supplied"), http.StatusBadRequest) - return - } - - ok, err := totp.VerifyMultipleIntervals(factor.Secret, factor.Code, 20) - if err != nil { - c.returnError(w, fmt.Errorf("totp verify error: %s", err), http.StatusBadRequest) - return - } - - if !ok { - c.returnError(w, fmt.Errorf("code doesn't match. Try entering code again or try with a new QR code"), http.StatusBadRequest) - return - } - - user.Factors = append(user.Factors, users.Factor{Type: factor.Type, Secret: factor.Secret, Name: factor.Name}) - out, err := json.Marshal(user.Factors) - if err != nil { - c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest) - return - } - err = c.UserStore.UpdateUser(user) - if err != nil { - c.returnError(w, fmt.Errorf("coudn't update user: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodDelete: - factorName := r.PathValue("name") - if factorName == "" { - c.returnError(w, fmt.Errorf("no factor name supplied"), http.StatusBadRequest) - return - } - toDelete := -1 - for k := range user.Factors { - if user.Factors[k].Name == factorName { - toDelete = k - } - } - if toDelete == -1 { - c.returnError(w, fmt.Errorf("factor not found"), http.StatusBadRequest) - return - } - user.Factors = append(user.Factors[:toDelete], user.Factors[toDelete+1:]...) - err := c.UserStore.UpdateUser(user) - if err != nil { - c.returnError(w, fmt.Errorf("coudn't update user: %s", err), http.StatusBadRequest) - return - } - out, err := json.Marshal(user.Factors) - if err != nil { - c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - - } -} diff --git a/pkg/rest/router.go b/pkg/rest/router.go deleted file mode 100644 index 98a1e8d..0000000 --- a/pkg/rest/router.go +++ /dev/null @@ -1,69 +0,0 @@ -package rest - -import ( - "io/fs" - "net/http" -) - -func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux { - mux := http.NewServeMux() - - // static files - mux.Handle("/assets/{filename}", http.FileServer(http.FS(assets))) - mux.Handle("/index.html", returnIndexOrNotFound(indexHtml)) - mux.Handle("/favicon.ico", http.FileServer(http.FS(assets))) - - // saml authentication - mux.Handle("/saml/", c.SAML.Client.GetRouter()) - - // endpoints with no authentication - mux.Handle("/api/context", http.HandlerFunc(c.contextHandler)) - mux.Handle("/api/auth", http.HandlerFunc(c.authHandler)) - mux.Handle("/api/authmethods", http.HandlerFunc(c.authMethods)) - mux.Handle("/api/authmethods/{method}/{id}", http.HandlerFunc(c.authMethodsByID)) - mux.Handle("/api/authmethods/{id}", http.HandlerFunc(c.authMethodsByID)) - mux.Handle("/api/version", http.HandlerFunc(c.version)) - mux.Handle("/api/upgrade", http.HandlerFunc(c.upgrade)) - mux.Handle("/", returnIndexOrNotFound(indexHtml)) - - // endpoints with no authentication (observability) - if c.ServerType == SERVER_TYPE_OBSERVABILITY { - mux.Handle("/api/observability/", c.Observability.Client.GetRouter()) - } - - // scim - mux.Handle("/api/scim/", c.isSCIMEnabled(c.SCIM.Client.GetRouter())) - - // endpoints with authentication - mux.Handle("/api/userinfo", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.userinfoHandler)))) - mux.Handle("/api/profile/password", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profilePasswordHandler)))) - mux.Handle("/api/profile/factors", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profileFactorsHandler)))) - mux.Handle("/api/profile/factors/{name}", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profileFactorsHandler)))) - - // endpoint with authentication (VPN) - if c.ServerType == SERVER_TYPE_VPN { - mux.Handle("/api/connections", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.connectionsHandler)))) - mux.Handle("/api/connection/{id}", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.connectionsElementHandler)))) - mux.Handle("/api/connectionlicense", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.connectionLicenseHandler)))) - } - - // endpoints with authentication, with admin role - mux.Handle("/api/license", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.licenseHandler))))) - mux.Handle("/api/license/{action}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.licenseHandler))))) - mux.Handle("/api/oidc", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderHandler))))) - mux.Handle("/api/oidc-renew-tokens", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcRenewTokensHandler))))) - mux.Handle("/api/oidc/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderElementHandler))))) - mux.Handle("/api/setup/general", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.setupHandler))))) - mux.Handle("/api/setup/vpn", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.vpnSetupHandler))))) - mux.Handle("/api/setup/templates", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.templateSetupHandler))))) - mux.Handle("/api/setup/restart-vpn", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.restartVPNHandler))))) - mux.Handle("/api/scim-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.scimSetupHandler))))) - mux.Handle("/api/saml-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupHandler))))) - mux.Handle("/api/saml-setup/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupElementHandler))))) - mux.Handle("/api/users", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.usersHandler))))) - mux.Handle("/api/user/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userHandler))))) - mux.Handle("/api/stats/user/{date}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userStatsHandler))))) - mux.Handle("/api/stats/packetlogs/{user}/{date}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.packetLogsHandler))))) - - return mux -} diff --git a/pkg/rest/rsa.go b/pkg/rest/rsa.go deleted file mode 100644 index 3f8fd4b..0000000 --- a/pkg/rest/rsa.go +++ /dev/null @@ -1,126 +0,0 @@ -package rest - -/* - * Genarate rsa keys. (https://github.com/wardviaene/http-echo/blob/master/rsa.go) - */ - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "fmt" - "path" - - "github.com/golang-jwt/jwt/v5" - "github.com/in4it/wireguard-server/pkg/storage" -) - -type JWTKeys struct { - PrivateKey *rsa.PrivateKey `json:"privateKey,omitempty"` - PublicKey *rsa.PublicKey `json:"publicKey,omitempty"` -} - -func getJWTKeys(storage storage.Iface) (*JWTKeys, error) { - - filename := storage.ConfigPath("pki/private.pem") - filenamePublicKey := storage.ConfigPath("pki/public.pem") - - if !storage.FileExists(filename) { - err := storage.EnsurePath(path.Dir(filename)) - if err != nil { - return nil, fmt.Errorf("ensure path error: %s", err) - } - err = createJWTKeys(storage, storage.ConfigPath("pki")) - if err != nil { - return nil, fmt.Errorf("createJWTKeys error: %s", err) - } - } - - signBytes, err := storage.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("private key read error: %s", err) - } - publicBytes, err := storage.ReadFile(filenamePublicKey) - if err != nil { - return nil, fmt.Errorf("private key read error: %s", err) - } - - signKey, err := jwt.ParseRSAPrivateKeyFromPEM(signBytes) - if err != nil { - return nil, fmt.Errorf("can't parse private key: %s", err) - } - publicKey, err := jwt.ParseRSAPublicKeyFromPEM(publicBytes) - if err != nil { - return nil, fmt.Errorf("can't parse public key: %s", err) - } - return &JWTKeys{PrivateKey: signKey, PublicKey: publicKey}, nil -} - -func createJWTKeys(storage storage.Iface, path string) error { - reader := rand.Reader - bitSize := 4096 - - key, err := rsa.GenerateKey(reader, bitSize) - if err != nil { - return err - } - - publicKey := key.PublicKey - - err = savePEMKey(storage, path+"/private.pem", key) - if err != nil { - return err - } - err = savePublicPEMKey(storage, path+"/public.pem", publicKey) - if err != nil { - return err - } - - return nil -} - -func savePEMKey(storage storage.Iface, fileName string, key *rsa.PrivateKey) error { - var buf bytes.Buffer - - var privateKey = &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - } - - err := pem.Encode(&buf, privateKey) - if err != nil { - return err - } - - err = storage.WriteFile(fileName, buf.Bytes()) - if err != nil { - return fmt.Errorf("WriteFile error: %s", err) - } - return nil -} - -func savePublicPEMKey(storage storage.Iface, fileName string, pubkey rsa.PublicKey) error { - var buf bytes.Buffer - - asn1Bytes, err := x509.MarshalPKIXPublicKey(&pubkey) - if err != nil { - return err - } - - var pemkey = &pem.Block{ - Type: "PUBLIC KEY", - Bytes: asn1Bytes, - } - - err = pem.Encode(&buf, pemkey) - if err != nil { - return err - } - err = storage.WriteFile(fileName, buf.Bytes()) - if err != nil { - return fmt.Errorf("WriteFile error: %s", err) - } - return nil -} diff --git a/pkg/rest/rsa_test.go b/pkg/rest/rsa_test.go deleted file mode 100644 index ce23943..0000000 --- a/pkg/rest/rsa_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package rest - -import ( - "bytes" - "crypto/x509" - "encoding/pem" - "testing" - - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" -) - -func TestGetJWTKeys(t *testing.T) { - mockStorage := memorystorage.MockMemoryStorage{} - keys, err := getJWTKeys(&mockStorage) - if err != nil { - t.Fatalf("error: %s", err) - } - privateKeyFromFile, err := mockStorage.ReadFile(mockStorage.ConfigPath("pki/private.pem")) - if err != nil { - t.Fatalf("read error: %s", err) - } - _, err = mockStorage.ReadFile(mockStorage.ConfigPath("pki/public.pem")) - if err != nil { - t.Fatalf("read error: %s", err) - } - - var buf bytes.Buffer - var privateKey = &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(keys.PrivateKey), - } - err = pem.Encode(&buf, privateKey) - if err != nil { - t.Fatalf("pem encode error: %s", err) - } - - if !bytes.Equal(privateKeyFromFile, buf.Bytes()) { - t.Fatalf("private keys don't match") - } -} diff --git a/pkg/rest/server.go b/pkg/rest/server.go deleted file mode 100644 index a223511..0000000 --- a/pkg/rest/server.go +++ /dev/null @@ -1,78 +0,0 @@ -package rest - -import ( - "crypto/tls" - "embed" - "fmt" - "io" - "io/fs" - "log" - "net/http" - - "github.com/in4it/wireguard-server/pkg/logging" - localstorage "github.com/in4it/wireguard-server/pkg/storage/local" - "golang.org/x/crypto/acme/autocert" -) - -var ( - //go:embed static - assets embed.FS - enableTLSWaiter chan (bool) = make(chan bool) - TLSWaiterCompleted bool -) - -func StartServer(httpPort, httpsPort int, serverType string) { - localStorage, err := localstorage.New() - if err != nil { - log.Fatalf("couldn't initialize storage: %s", err) - } - c, err := newContext(localStorage, serverType) - if err != nil { - log.Fatalf("startup failed: %s", err) - } - - go handleSignals(c) - - assetsFS, err := fs.Sub(assets, "static") - if err != nil { - log.Fatalf("could not load static web assets") - } - - indexHtml, err := assetsFS.Open("index.html") - if err != nil { - log.Fatalf("could not load static web assets (index.html)") - } - indexHtmlBody, err := io.ReadAll(indexHtml) - if err != nil { - log.Fatalf("could not read static web assets (index.html)") - } - - certManager := autocert.Manager{} - - // HTTP Configuration - go func() { // start http server - log.Printf("Start http server on port %d", httpPort) - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", httpPort), certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))))))) - }() - - // TLS Configuration - if !c.EnableTLS || !canEnableTLS(c.Hostname) { - <-enableTLSWaiter - } - // only enable when TLS is enabled - - logging.DebugLog(fmt.Errorf("enabling TLS endpoint with let's encrypt for hostname '%s'", c.Hostname)) - certManager.Prompt = autocert.AcceptTOS - certManager.HostPolicy = autocert.HostWhitelist(c.Hostname) - certManager.Cache = autocert.DirCache("tls-certs") - tlsServer := &http.Server{ - Addr: fmt.Sprintf(":%d", httpsPort), - TLSConfig: &tls.Config{ - GetCertificate: certManager.GetCertificate, - }, - Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))), - } - c.Protocol = "https" - TLSWaiterCompleted = true - log.Fatal(tlsServer.ListenAndServeTLS("", "")) -} diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go deleted file mode 100644 index 719ae8f..0000000 --- a/pkg/rest/setup.go +++ /dev/null @@ -1,669 +0,0 @@ -package rest - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "net/netip" - "reflect" - "slices" - "sort" - "strconv" - "strings" - "time" - - "github.com/google/uuid" - "github.com/in4it/wireguard-server/pkg/auth/oidc" - "github.com/in4it/wireguard-server/pkg/auth/saml" - "github.com/in4it/wireguard-server/pkg/license" - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - decoder := json.NewDecoder(r.Body) - var contextReq ContextRequest - err := decoder.Decode(&contextReq) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - if !c.Storage.Client.FileExists(SETUP_CODE_FILE) { - c.SetupCompleted = true - } - if !c.SetupCompleted { - // check if tag hash is chosen - accessGranted := false - switch c.CloudType { - case "digitalocean": // check if the hashtag is set - if contextReq.TagHash != "" { - if !strings.HasPrefix(contextReq.TagHash, "vpnsecret-") { - c.returnError(w, fmt.Errorf("tag doesn't have the correct prefix. The tag needs to start with 'vpnsecret-'"), http.StatusUnauthorized) - return - } - accessGranted, err = license.HasDigitalOceanTagSet(http.Client{Timeout: 5 * time.Second}, contextReq.TagHash) - if err != nil { - c.returnError(w, fmt.Errorf("could not retrieve tags at this time: %s", err), http.StatusUnauthorized) - return - } - if !accessGranted { - c.returnError(w, fmt.Errorf("tag not found. Make sure the correct tag is attached to the droplet"), http.StatusUnauthorized) - return - } - } - case "aws": // check if the instance id is set - if contextReq.InstanceID != "" { - instanceID, err := license.GetAWSInstanceID(http.Client{Timeout: 5 * time.Second}) - if err != nil { - c.returnError(w, fmt.Errorf("could not retrieve instance id at this time: %s", err), http.StatusUnauthorized) - return - } - if strings.TrimPrefix(instanceID, "i-") == strings.TrimPrefix(contextReq.InstanceID, "i-") { - accessGranted = true - } else { - c.returnError(w, fmt.Errorf("instance id doesn't match"), http.StatusUnauthorized) - return - } - } - } - // check secret - if !accessGranted { - localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE) - if err != nil { - c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest) - return - } - if strings.TrimSpace(string(localSecret)) != contextReq.Secret { - c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized) - return - } - } - if contextReq.AdminPassword != "" { - adminUser := users.User{ - Login: "admin", - Password: contextReq.AdminPassword, - Role: "admin", - } - if c.UserStore.LoginExists("admin") { - err = c.UserStore.UpdateUser(adminUser) - if err != nil { - c.returnError(w, fmt.Errorf("could not update user: %s", err), http.StatusBadRequest) - return - } - } else { - _, err = c.UserStore.AddUser(adminUser) - if err != nil { - c.returnError(w, fmt.Errorf("could not add user: %s", err), http.StatusBadRequest) - return - } - } - - c.SetupCompleted = true - c.Hostname = contextReq.Hostname - protocol := contextReq.Protocol - protocol = strings.Replace(protocol, "http:", "http", -1) - protocol = strings.Replace(protocol, "https:", "https", -1) - c.Protocol = protocol - - err = SaveConfig(c) - if err != nil { - c.SetupCompleted = false - c.returnError(w, fmt.Errorf("unable to save file: %s", err), http.StatusBadRequest) - return - } - - // update hostname in vpn config - vpnconfig, err := wireguard.GetVPNConfig(c.Storage.Client) - if err != nil { - c.SetupCompleted = false - c.returnError(w, fmt.Errorf("unable to get vpn-config: %s", err), http.StatusBadRequest) - return - } - vpnconfig.Endpoint = c.Hostname - err = wireguard.WriteVPNConfig(c.Storage.Client, vpnconfig) - if err != nil { - c.SetupCompleted = false - c.returnError(w, fmt.Errorf("unable to write vpn-config: %s", err), http.StatusBadRequest) - return - } - } - } - } - - out, err := json.Marshal(ContextSetupResponse{SetupCompleted: c.SetupCompleted, CloudType: c.CloudType, ServerType: c.ServerType}) - if err != nil { - c.returnError(w, err, http.StatusBadRequest) - return - } - c.write(w, out) -} - -func (c *Context) setupHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - setupRequest := GeneralSetupRequest{ - Hostname: c.Hostname, - EnableTLS: c.EnableTLS, - RedirectToHttps: c.RedirectToHttps, - DisableLocalAuth: c.LocalAuthDisabled, - EnableOIDCTokenRenewal: c.EnableOIDCTokenRenewal, - } - out, err := json.Marshal(setupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var setupRequest GeneralSetupRequest - decoder := json.NewDecoder(r.Body) - decoder.Decode(&setupRequest) - if c.Hostname != setupRequest.Hostname { - c.Hostname = setupRequest.Hostname - } - if c.RedirectToHttps != setupRequest.RedirectToHttps { - c.RedirectToHttps = setupRequest.RedirectToHttps - } - if c.EnableTLS != setupRequest.EnableTLS { - if !c.EnableTLS && setupRequest.EnableTLS && !TLSWaiterCompleted && canEnableTLS(c.Hostname) { - enableTLSWaiter <- true - } - c.EnableTLS = setupRequest.EnableTLS - } - if c.LocalAuthDisabled != setupRequest.DisableLocalAuth { - c.LocalAuthDisabled = setupRequest.DisableLocalAuth - } - if c.EnableOIDCTokenRenewal != setupRequest.EnableOIDCTokenRenewal { - c.EnableOIDCTokenRenewal = setupRequest.EnableOIDCTokenRenewal - c.OIDCRenewal.SetEnabled(c.EnableOIDCTokenRenewal) - } - err := SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("could not save config to disk: %s", err), http.StatusBadRequest) - return - } - out, err := json.Marshal(setupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { - vpnConfig, err := wireguard.GetVPNConfig(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not get vpn config: %s", err), http.StatusBadRequest) - return - } - switch r.Method { - case http.MethodGet: - packetLogTypes := []string{} - for k, enabled := range vpnConfig.PacketLogsTypes { - if enabled { - packetLogTypes = append(packetLogTypes, k) - } - } - if vpnConfig.PacketLogsRetention == 0 { - vpnConfig.PacketLogsRetention = 7 - } - setupRequest := VPNSetupRequest{ - Routes: strings.Join(vpnConfig.ClientRoutes, ", "), - VPNEndpoint: vpnConfig.Endpoint, - AddressRange: vpnConfig.AddressRange.String(), - ClientAddressPrefix: vpnConfig.ClientAddressPrefix, - Port: strconv.Itoa(vpnConfig.Port), - ExternalInterface: vpnConfig.ExternalInterface, - Nameservers: strings.Join(vpnConfig.Nameservers, ","), - DisableNAT: vpnConfig.DisableNAT, - EnablePacketLogs: vpnConfig.EnablePacketLogs, - PacketLogsTypes: packetLogTypes, - PacketLogsRetention: strconv.Itoa(vpnConfig.PacketLogsRetention), - } - out, err := json.Marshal(setupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var ( - writeVPNConfig bool - rewriteClientConfigs bool - setupRequest VPNSetupRequest - ) - decoder := json.NewDecoder(r.Body) - decoder.Decode(&setupRequest) - if strings.Join(vpnConfig.ClientRoutes, ", ") != setupRequest.Routes { - networks := strings.Split(setupRequest.Routes, ",") - validatedNetworks := []string{} - for _, network := range networks { - if strings.TrimSpace(network) == "::/0" { - validatedNetworks = append(validatedNetworks, "::/0") - } else { - _, ipnet, err := net.ParseCIDR(strings.TrimSpace(network)) - if err != nil { - c.returnError(w, fmt.Errorf("client route %s in wrong format: %s", strings.TrimSpace(network), err), http.StatusBadRequest) - return - } - validatedNetworks = append(validatedNetworks, ipnet.String()) - } - } - vpnConfig.ClientRoutes = validatedNetworks - writeVPNConfig = true - rewriteClientConfigs = true - } - if vpnConfig.Endpoint != setupRequest.VPNEndpoint { - vpnConfig.Endpoint = setupRequest.VPNEndpoint - writeVPNConfig = true - rewriteClientConfigs = true - } - addressRangeParsed, err := netip.ParsePrefix(setupRequest.AddressRange) - if err != nil { - c.returnError(w, fmt.Errorf("AddressRange in wrong format: %s", err), http.StatusBadRequest) - return - } - if addressRangeParsed.String() != vpnConfig.AddressRange.String() { - vpnConfig.AddressRange = addressRangeParsed - writeVPNConfig = true - rewriteClientConfigs = true - } - if setupRequest.ClientAddressPrefix != vpnConfig.ClientAddressPrefix { - vpnConfig.ClientAddressPrefix = setupRequest.ClientAddressPrefix - writeVPNConfig = true - rewriteClientConfigs = true - } - port, err := strconv.Atoi(setupRequest.Port) - if err != nil { - c.returnError(w, fmt.Errorf("port in wrong format: %s", err), http.StatusBadRequest) - return - } - if port != vpnConfig.Port { - vpnConfig.Port = port - writeVPNConfig = true - rewriteClientConfigs = true - } - - nameservers := strings.Split(setupRequest.Nameservers, ",") - for k := range nameservers { - nameservers[k] = strings.TrimSpace(nameservers[k]) - } - if !reflect.DeepEqual(nameservers, vpnConfig.Nameservers) { - vpnConfig.Nameservers = nameservers - writeVPNConfig = true - rewriteClientConfigs = true - } - if setupRequest.ExternalInterface != vpnConfig.ExternalInterface { // don't rewrite client config - vpnConfig.ExternalInterface = setupRequest.ExternalInterface - writeVPNConfig = true - } - if setupRequest.DisableNAT != vpnConfig.DisableNAT { // don't rewrite client config - vpnConfig.DisableNAT = setupRequest.DisableNAT - writeVPNConfig = true - } - if setupRequest.EnablePacketLogs != vpnConfig.EnablePacketLogs { - vpnConfig.EnablePacketLogs = setupRequest.EnablePacketLogs - writeVPNConfig = true - } - packetLogsRention, err := strconv.Atoi(setupRequest.PacketLogsRetention) - if err != nil || packetLogsRention < 1 { - c.returnError(w, fmt.Errorf("incorrect packet log retention. Enter a number of days the logs must be kept (minimum 1)"), http.StatusBadRequest) - return - } - if packetLogsRention != vpnConfig.PacketLogsRetention { - vpnConfig.PacketLogsRetention = packetLogsRention - writeVPNConfig = true - } - - // packetlogtypes - packetLogTypes := []string{} - for k, enabled := range vpnConfig.PacketLogsTypes { - if enabled { - packetLogTypes = append(packetLogTypes, k) - } - } - sort.Strings(setupRequest.PacketLogsTypes) - sort.Strings(packetLogTypes) - if !slices.Equal(setupRequest.PacketLogsTypes, packetLogTypes) { - vpnConfig.PacketLogsTypes = make(map[string]bool) - for _, v := range setupRequest.PacketLogsTypes { - if v == "http+https" || v == "dns" || v == "tcp" { - vpnConfig.PacketLogsTypes[v] = true - } - } - writeVPNConfig = true - } - - // write vpn config if config has changed - if writeVPNConfig { - err = wireguard.WriteVPNConfig(c.Storage.Client, vpnConfig) - if err != nil { - c.returnError(w, fmt.Errorf("could write vpn config: %s", err), http.StatusBadRequest) - return - } - err = wireguard.ReloadVPNServerConfig() - if err != nil { - c.returnError(w, fmt.Errorf("unable to reload server config: %s", err), http.StatusBadRequest) - return - } - } - if rewriteClientConfigs { - // rewrite client configs - err = wireguard.UpdateClientsConfig(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest) - return - } - } - out, err := json.Marshal(setupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) templateSetupHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest) - return - } - serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest) - return - } - setupRequest := TemplateSetupRequest{ - ClientTemplate: string(clientTemplate), - ServerTemplate: string(serverTemplate), - } - out, err := json.Marshal(setupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var templateSetupRequest TemplateSetupRequest - decoder := json.NewDecoder(r.Body) - decoder.Decode(&templateSetupRequest) - clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest) - return - } - serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest) - return - } - if string(clientTemplate) != templateSetupRequest.ClientTemplate { - err = wireguard.WriteClientTemplate(c.Storage.Client, []byte(templateSetupRequest.ClientTemplate)) - if err != nil { - c.returnError(w, fmt.Errorf("WriteClientTemplate error: %s", err), http.StatusBadRequest) - return - } - // rewrite client configs - err = wireguard.UpdateClientsConfig(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest) - return - } - } - if string(serverTemplate) != templateSetupRequest.ServerTemplate { - err = wireguard.WriteServerTemplate(c.Storage.Client, []byte(templateSetupRequest.ServerTemplate)) - if err != nil { - c.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest) - return - } - } - out, err := json.Marshal(templateSetupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) restartVPNHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - c.returnError(w, fmt.Errorf("unsupported method"), http.StatusBadRequest) - return - } - client := http.Client{ - Timeout: 10 * time.Second, - } - req, err := http.NewRequest(r.Method, "http://"+wireguard.CONFIGMANAGER_URI+"/restart-vpn", nil) - if err != nil { - c.returnError(w, fmt.Errorf("restart request error: %s", err), http.StatusBadRequest) - return - } - resp, err := client.Do(req) - if err != nil { - c.returnError(w, fmt.Errorf("restart error: %s", err), http.StatusBadRequest) - return - } - if resp.StatusCode != http.StatusAccepted { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.returnError(w, fmt.Errorf("restart error: got status code: %d. Response: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest) - return - } - c.returnError(w, fmt.Errorf("restart error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest) - return - } - - defer resp.Body.Close() - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest) - return - } - - c.write(w, bodyBytes) -} - -func (c *Context) scimSetupHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - scimSetup := SCIMSetup{ - Enabled: c.SCIM.EnableSCIM, - } - if c.SCIM.EnableSCIM { - scimSetup.Token = c.SCIM.Token - scimSetup.BaseURL = fmt.Sprintf("%s://%s/%s", c.Protocol, c.Hostname, "api/scim/v2/") - } - out, err := json.Marshal(scimSetup) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal scim setup: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - saveConfig := false - var scimSetupRequest SCIMSetup - decoder := json.NewDecoder(r.Body) - decoder.Decode(&scimSetupRequest) - if scimSetupRequest.Enabled && !c.SCIM.EnableSCIM { - c.SCIM.EnableSCIM = true - saveConfig = true - } - if !scimSetupRequest.Enabled && c.SCIM.EnableSCIM { - c.SCIM.EnableSCIM = false - saveConfig = true - } - if scimSetupRequest.RegenerateToken || (scimSetupRequest.Enabled && c.SCIM.Token == "") { - // Generate new token - randomString, err := oidc.GetRandomString(64) - if err != nil { - c.returnError(w, fmt.Errorf("could not enable scim: %s", err), http.StatusBadRequest) - return - } - token := base64.StdEncoding.EncodeToString([]byte(randomString)) - scimSetupRequest.Token = token - c.SCIM.Token = token - c.SCIM.Client.UpdateToken(token) - saveConfig = true - } - if saveConfig { - // save config - err := SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("could not save config to disk: %s", err), http.StatusBadRequest) - return - } - } - out, err := json.Marshal(scimSetupRequest) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal scim setup: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) samlSetupHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - samlProviders := make([]saml.Provider, len(*c.SAML.Providers)) - copy(samlProviders, *c.SAML.Providers) - for k := range samlProviders { - samlProviders[k].Issuer = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ISSUER_URL, samlProviders[k].ID) - samlProviders[k].Audience = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.AUDIENCE_URL, samlProviders[k].ID) - samlProviders[k].Acs = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ACS_URL, samlProviders[k].ID) - } - out, err := json.Marshal(samlProviders) - if err != nil { - c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var samlProvider saml.Provider - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&samlProvider) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - samlProvider.ID = uuid.New().String() - if samlProvider.Name == "" { - c.returnError(w, fmt.Errorf("name not set"), http.StatusBadRequest) - return - } - if samlProvider.MetadataURL == "" { - c.returnError(w, fmt.Errorf("metadata URL not set"), http.StatusBadRequest) - return - } - _, err = c.SAML.Client.HasValidMetadataURL(samlProvider.MetadataURL) - if err != nil { - c.returnError(w, fmt.Errorf("metadata error: %s", err), http.StatusBadRequest) - return - } - - *c.SAML.Providers = append(*c.SAML.Providers, samlProvider) - out, err := json.Marshal(samlProvider) - if err != nil { - c.returnError(w, fmt.Errorf("samlProvider marshal error: %s", err), http.StatusBadRequest) - return - } - err = SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) samlSetupElementHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodDelete: - match := -1 - for k, samlProvider := range *c.SAML.Providers { - if samlProvider.ID == r.PathValue("id") { - match = k - } - } - if match == -1 { - c.returnError(w, fmt.Errorf("saml provider not found"), http.StatusBadRequest) - return - } - *c.SAML.Providers = append((*c.SAML.Providers)[:match], (*c.SAML.Providers)[match+1:]...) - // save config (changed providers) - err := SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) - return - } - c.write(w, []byte(`{ "deleted": "`+r.PathValue("id")+`" }`)) - case http.MethodPut: - var samlProvider saml.Provider - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&samlProvider) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - samlProviderID := -1 - for k := range *c.SAML.Providers { - if (*c.SAML.Providers)[k].ID == r.PathValue("id") { - samlProviderID = k - } - } - if samlProviderID == -1 { - c.returnError(w, fmt.Errorf("cannot find saml provider: %s", err), http.StatusBadRequest) - return - } - saveConfig := false - if (*c.SAML.Providers)[samlProviderID].AllowMissingAttributes != samlProvider.AllowMissingAttributes { - (*c.SAML.Providers)[samlProviderID].AllowMissingAttributes = samlProvider.AllowMissingAttributes - saveConfig = true - } - if (*c.SAML.Providers)[samlProviderID].MetadataURL != samlProvider.MetadataURL { - _, err := c.SAML.Client.HasValidMetadataURL(samlProvider.MetadataURL) - if err != nil { - c.returnError(w, fmt.Errorf("metadata error: %s", err), http.StatusBadRequest) - return - } - (*c.SAML.Providers)[samlProviderID].MetadataURL = samlProvider.MetadataURL - saveConfig = true - } - out, err := json.Marshal(samlProvider) - if err != nil { - c.returnError(w, fmt.Errorf("samlProvider marshal error: %s", err), http.StatusBadRequest) - return - } - if saveConfig { - err = SaveConfig(c) - if err != nil { - c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) - return - } - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} diff --git a/pkg/rest/setup_test.go b/pkg/rest/setup_test.go deleted file mode 100644 index 2e14bca..0000000 --- a/pkg/rest/setup_test.go +++ /dev/null @@ -1,293 +0,0 @@ -package rest - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/in4it/wireguard-server/pkg/license" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" - "github.com/in4it/wireguard-server/pkg/users" -) - -func TestContextHandlerSetupSecret(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - - storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) - - userStore, err := users.NewUserStore(storage, -1) - if err != nil { - t.Fatalf("new user store error") - } - c, err := getEmptyContext("appdir") - if err != nil { - t.Fatalf("cannot create empty context") - } - c.Storage = &Storage{Client: storage} - c.UserStore = userStore - - payload := ContextRequest{ - Secret: "secret setup code", - AdminPassword: "adminPassword", - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) - w := httptest.NewRecorder() - c.contextHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("got read error: %s", err) - } - - var contextSetupResponse ContextSetupResponse - err = json.Unmarshal(body, &contextSetupResponse) - if err != nil { - t.Fatalf("unmarshal error: %s", err) - } - if !contextSetupResponse.SetupCompleted { - t.Fatalf("expected setup to be completed") - } -} - -func TestContextHandlerSetupWrongSecret(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - - storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) - - userStore, err := users.NewUserStore(storage, -1) - if err != nil { - t.Fatalf("new user store error") - } - c, err := getEmptyContext("appdir") - if err != nil { - t.Fatalf("cannot create empty context") - } - c.Storage = &Storage{Client: storage} - c.UserStore = userStore - - payload := ContextRequest{ - AdminPassword: "adminPassword", - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) - w := httptest.NewRecorder() - c.contextHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 401 { - t.Fatalf("status code is not 401: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("got read error: %s", err) - } - - var contextSetupResponse ContextSetupResponse - err = json.Unmarshal(body, &contextSetupResponse) - if err != nil { - t.Fatalf("unmarshal error: %s", err) - } - if contextSetupResponse.SetupCompleted { - t.Fatalf("expected setup to not be completed") - } -} -func TestContextHandlerSetupWrongSecretPartial(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - - storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) - - userStore, err := users.NewUserStore(storage, -1) - if err != nil { - t.Fatalf("new user store error") - } - c, err := getEmptyContext("appdir") - if err != nil { - t.Fatalf("cannot create empty context") - } - c.Storage = &Storage{Client: storage} - c.UserStore = userStore - - payload := ContextRequest{ - Secret: "secret setup cod", - AdminPassword: "adminPassword", - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) - w := httptest.NewRecorder() - c.contextHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 401 { - t.Fatalf("status code is not 401: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("got read error: %s", err) - } - - var contextSetupResponse ContextSetupResponse - err = json.Unmarshal(body, &contextSetupResponse) - if err != nil { - t.Fatalf("unmarshal error: %s", err) - } - if contextSetupResponse.SetupCompleted { - t.Fatalf("expected setup to not be completed") - } -} - -func TestContextHandlerSetupAWSInstanceID(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/latest/api/token" { - w.Write([]byte("this is a test token")) - return - } - if r.RequestURI == "/latest/meta-data/instance-id" { - w.Write([]byte("i-012aaaaaaaaaaaaa1")) - return - } - w.WriteHeader(http.StatusBadRequest) - })) - defer ts.Close() - license.MetadataIP = strings.TrimPrefix(ts.URL, "http://") - - storage := &memorystorage.MockMemoryStorage{} - - storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) - - userStore, err := users.NewUserStore(storage, -1) - if err != nil { - t.Fatalf("new user store error") - } - c, err := getEmptyContext("appdir") - if err != nil { - t.Fatalf("cannot create empty context") - } - c.Storage = &Storage{Client: storage} - c.UserStore = userStore - c.CloudType = "aws" - - payload := ContextRequest{ - InstanceID: "i-012aaaaaaaaaaaaa1", - AdminPassword: "adminPassword", - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) - w := httptest.NewRecorder() - c.contextHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("got read error: %s", err) - } - - var contextSetupResponse ContextSetupResponse - err = json.Unmarshal(body, &contextSetupResponse) - if err != nil { - t.Fatalf("unmarshal error: %s", err) - } - if !contextSetupResponse.SetupCompleted { - t.Fatalf("expected setup to be completed") - } -} -func TestContextHandlerSetupDigitalOceanTag(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/v1/tags" { - w.Write([]byte("vpnsecret-this-is-a-secret-tag")) - return - } - w.WriteHeader(http.StatusBadRequest) - })) - defer ts.Close() - license.MetadataIP = strings.TrimPrefix(ts.URL, "http://") - - storage := &memorystorage.MockMemoryStorage{} - - storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) - - userStore, err := users.NewUserStore(storage, -1) - if err != nil { - t.Fatalf("new user store error") - } - c, err := getEmptyContext("appdir") - if err != nil { - t.Fatalf("cannot create empty context") - } - c.Storage = &Storage{Client: storage} - c.UserStore = userStore - c.CloudType = "digitalocean" - - payload := ContextRequest{ - TagHash: "vpnsecret-this-is-a-secret-tag", - AdminPassword: "adminPassword", - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("marshal error: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) - w := httptest.NewRecorder() - c.contextHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("got read error: %s", err) - } - - var contextSetupResponse ContextSetupResponse - err = json.Unmarshal(body, &contextSetupResponse) - if err != nil { - t.Fatalf("unmarshal error: %s", err) - } - if !contextSetupResponse.SetupCompleted { - t.Fatalf("expected setup to be completed") - } -} diff --git a/pkg/rest/signals.go b/pkg/rest/signals.go deleted file mode 100644 index 8e3dee5..0000000 --- a/pkg/rest/signals.go +++ /dev/null @@ -1,38 +0,0 @@ -package rest - -import ( - "fmt" - "log" - "os" - "os/signal" - "path" - "syscall" -) - -func handleSignals(c *Context) { - // write pid file so other process can find it - err := os.WriteFile(path.Join(c.AppDir, "rest-server.pid"), []byte(fmt.Sprintf("%d", os.Getpid())), 0664) - if err != nil { - log.Printf("Could not write pid file\n") - } - signalChannel := make(chan os.Signal, 1) - signal.Notify(signalChannel, syscall.SIGHUP) - for sig := range signalChannel { - switch sig { - case syscall.SIGHUP: - c.ReloadConfig() - } - } -} - -func (c *Context) ReloadConfig() { - newC, err := newContext(c.Storage.Client, SERVER_TYPE_VPN) - if err != nil { - log.Printf("ReloadConfig failed: %s\n", err) - } - c.AppDir = newC.AppDir - c.Hostname = newC.Hostname - c.SetupCompleted = newC.SetupCompleted - c.UserStore = newC.UserStore - log.Printf("Config Reloaded!\n") -} diff --git a/pkg/rest/static/.gitignore b/pkg/rest/static/.gitignore deleted file mode 100644 index 5e7d273..0000000 --- a/pkg/rest/static/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/pkg/rest/tls.go b/pkg/rest/tls.go deleted file mode 100644 index 5eb94b0..0000000 --- a/pkg/rest/tls.go +++ /dev/null @@ -1,17 +0,0 @@ -package rest - -import ( - "log" - "strings" -) - -func canEnableTLS(hostname string) bool { - hostnameSplit := strings.Split(hostname, ":") - if hostnameSplit[0] != "localhost" { - return true - } else { - log.Printf("Not enabling TLS with lets encrypt. Hostname is localhost") - } - - return false -} diff --git a/pkg/rest/types.go b/pkg/rest/types.go deleted file mode 100644 index d3c54c1..0000000 --- a/pkg/rest/types.go +++ /dev/null @@ -1,237 +0,0 @@ -package rest - -import ( - "time" - - "github.com/in4it/wireguard-server/pkg/auth/oidc" - oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store" - oidcrenewal "github.com/in4it/wireguard-server/pkg/auth/oidc/store/renewal" - "github.com/in4it/wireguard-server/pkg/auth/provisioning/scim" - "github.com/in4it/wireguard-server/pkg/auth/saml" - "github.com/in4it/wireguard-server/pkg/observability" - "github.com/in4it/wireguard-server/pkg/rest/login" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" -) - -const SETUP_CODE_FILE = "setup-code.txt" -const ADMIN_USER = "admin" - -type Context struct { - AppDir string `json:"appDir,omitempty"` - ServerType string `json:"serverType,omitempty"` - SetupCompleted bool `json:"setupCompleted"` - Hostname string `json:"hostname,omitempty"` - Protocol string `json:"protocol,omitempty"` - JWTKeys *JWTKeys `json:"jwtKeys,omitempty"` - JWTKeysKID string `json:"jwtKeysKid,omitempty"` - OIDCProviders []oidc.OIDCProvider `json:"oidcProviders,omitempty"` - LocalAuthDisabled bool `json:"disableLocalAuth,omitempty"` - EnableTLS bool `json:"enableTLS,omitempty"` - RedirectToHttps bool `json:"redirectToHttps,omitempty"` - EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal,omitempty"` - OIDCStore *oidcstore.Store `json:"oidcStore,omitempty"` - UserStore *users.UserStore `json:"users,omitempty"` - OIDCRenewal *oidcrenewal.Renewal `json:"oidcRenewal,omitempty"` - LoginAttempts login.Attempts `json:"loginAttempts,omitempty"` - LicenseUserCount int `json:"licenseUserCount,omitempty"` - CloudType string `json:"cloudType,omitempty"` - TokenRenewalTimeMinutes int `json:"tokenRenewalTimeMinutes,omitempty"` - LogLevel int `json:"loglevel,omitempty"` - SCIM *SCIM `json:"scim,omitempty"` - SAML *SAML `json:"saml,omitempty"` - Observability *Observability `json:"observability,omitempty"` - Storage *Storage `json:"storage,omitempty"` -} -type SCIM struct { - EnableSCIM bool `json:"enableSCIM,omitempty"` - Token string `json:"token"` - Client scim.Iface `json:"client,omitempty"` -} -type SAML struct { - Providers *[]saml.Provider `json:"providers"` - Client saml.Iface `json:"client,omitempty"` -} -type Observability struct { - Client observability.Iface `json:"client,omitempty"` -} -type Storage struct { - Client storage.Iface `json:"client,omitempty"` -} - -type ContextRequest struct { - Secret string `json:"secret"` - TagHash string `json:"tagHash"` - InstanceID string `json:"instanceID"` - AdminPassword string `json:"adminPassword"` - Hostname string `json:"hostname"` - Protocol string `json:"protocol"` -} -type ContextSetupResponse struct { - SetupCompleted bool `json:"setupCompleted"` - CloudType string `json:"cloudType"` - ServerType string `json:"serverType"` -} - -type AuthMethodsResponse struct { - LocalAuthDisabled bool `json:"localAuthDisabled"` - OIDCProviders []AuthMethodsProvider `json:"oidcProviders"` -} - -type AuthMethodsProvider struct { - ID string `json:"id"` - Name string `json:"name"` - RedirectURI string `json:"redirectURI,omitempty"` -} - -type OIDCCallback struct { - Code string `json:"code"` - State string `json:"state"` - RedirectURI string `json:"redirectURI"` -} -type SAMLCallback struct { - Code string `json:"code"` - RedirectURI string `json:"redirectURI"` -} - -type UserInfoResponse struct { - Login string `json:"login"` - Role string `json:"role"` - UserType string `json:"userType"` -} - -type GeneralSetupRequest struct { - Hostname string `json:"hostname"` - EnableTLS bool `json:"enableTLS"` - RedirectToHttps bool `json:"redirectToHttps"` - DisableLocalAuth bool `json:"disableLocalAuth"` - EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal"` -} - -type VPNSetupRequest struct { - Routes string `json:"routes"` - VPNEndpoint string `json:"vpnEndpoint"` - AddressRange string `json:"addressRange"` - ClientAddressPrefix string `json:"clientAddressPrefix"` - Port string `json:"port"` - ExternalInterface string `json:"externalInterface"` - Nameservers string `json:"nameservers"` - DisableNAT bool `json:"disableNAT"` - EnablePacketLogs bool `json:"enablePacketLogs"` - PacketLogsTypes []string `json:"packetLogsTypes"` - PacketLogsRetention string `json:"packetLogsRetention"` -} - -type TemplateSetupRequest struct { - ClientTemplate string `json:"clientTemplate"` - ServerTemplate string `json:"serverTemplate"` -} - -type NewConnectionResponse struct { - Name string `json:"name"` -} -type Connection struct { - ID string `json:"id"` - Name string `json:"name"` -} - -type LicenseResponse struct { - LicenseUserCount int `json:"licenseUserCount"` - CurrentUserCount int `json:"currentUserCount,omitempty"` - CloudType string `json:"cloudType"` - Key string `json:"key,omitempty"` -} - -type ConnectionLicenseResponse struct { - LicenseUserCount int `json:"licenseUserCount"` - ConnectionCount int `json:"connectionCount"` -} - -type JwtHeader struct { - Alg string `json:"alg"` - Typ string `json:"typ"` - Kid string `json:"kid"` -} - -type UsersResponse struct { - ID string `json:"id"` - Login string `json:"login"` - Role string `json:"role"` - OIDCID string `json:"oidcID"` - SAMLID string `json:"samlID"` - Provisioned bool `json:"provisioned"` - Suspended bool `json:"suspended"` - ConnectionsDisabledOnAuthFailure bool `json:"connectionsDisabledOnAuthFailure"` - LastTokenRenewal time.Time `json:"lastTokenRenewal,omitempty"` - LastLogin string `json:"lastLogin"` -} - -type FactorRequest struct { - Name string `json:"name"` - Type string `json:"type"` - Secret string `json:"secret"` - Code string `json:"code"` -} - -type SCIMSetup struct { - Enabled bool `json:"enabled"` - Token string `json:"token,omitempty"` - RegenerateToken bool `json:"regenerateToken,omitempty"` - BaseURL string `json:"baseURL,omitempty"` -} - -type SAMLSetup struct { - Enabled bool `json:"enabled"` - MetadataURL string `json:"metadataURL,omitempty"` - RegenerateCert bool `json:"regenerateCert,omitempty"` -} - -type UserStatsResponse struct { - ReceiveBytes UserStatsData `json:"receivedBytes"` - TransmitBytes UserStatsData `json:"transmitBytes"` - Handshakes UserStatsData `json:"handshakes"` -} -type UserStatsData struct { - Datasets UserStatsDatasets `json:"datasets"` -} -type UserStatsDatasets []UserStatsDataset -type UserStatsDataset struct { - Label string `json:"label"` - Data []UserStatsDataPoint `json:"data"` - Fill bool `json:"fill"` - BorderColor string `json:"borderColor"` - BackgroundColor string `json:"backgroundColor"` - Tension float64 `json:"tension"` - ShowLine bool `json:"showLine"` -} - -type UserStatsDataPoint struct { - X string `json:"x"` - Y float64 `json:"y"` -} - -type NewUserRequest struct { - Login string `json:"login"` - Role string `json:"role"` - Password string `json:"password,omitempty"` -} - -type LogDataResponse struct { - LogData LogData `json:"logData"` - Enabled bool `json:"enabled"` - LogTypes []string `json:"logTypes"` - Users map[string]string `json:"users"` -} - -type LogData struct { - Schema LogSchema `json:"schema"` - Data []LogRow `json:"rows"` - NextPos int64 `json:"nextPos"` -} -type LogSchema struct { - Columns map[string]string `json:"columns"` -} -type LogRow struct { - Timestamp string `json:"t"` - Data []string `json:"d"` -} diff --git a/pkg/rest/users.go b/pkg/rest/users.go deleted file mode 100644 index 2dcdcc5..0000000 --- a/pkg/rest/users.go +++ /dev/null @@ -1,275 +0,0 @@ -package rest - -import ( - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/in4it/wireguard-server/pkg/storage" - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -func (c *Context) GetUserFromRequest(r *http.Request) (users.User, error) { - claims := r.Context().Value(CustomValue("claims")).(jwt.MapClaims) - sub, ok := claims["sub"] - if !ok { - return users.User{}, fmt.Errorf("userinfoHandler: subject not found in token") - } - iss, ok := claims["iss"] - if !ok { - return users.User{}, fmt.Errorf("userinfoHandler: issuer not found in token") - } - - kid, ok := claims["kid"] - if !ok { - return users.User{}, fmt.Errorf("userinfoHandler: kid not found in token") - } - - if kid == c.JWTKeysKID { - user, err := c.UserStore.GetUserByLogin(sub.(string)) - if err != nil { - return users.User{}, fmt.Errorf("GetUserByLogin: user not found") - } - return user, nil - } else { // user comes from oidc - oauth2DataIDs := []string{} - for _, oauth2Data := range c.OIDCStore.OAuth2Data { - if oauth2Data.Issuer == iss && oauth2Data.Subject == sub { - oauth2DataIDs = append(oauth2DataIDs, oauth2Data.ID) - } - } - if len(oauth2DataIDs) == 0 { - return users.User{}, fmt.Errorf("userinfoHandler: couldn't find user in oidc database") - } - user, err := c.UserStore.GetUserByOIDCIDs(oauth2DataIDs) - if err != nil { - return user, fmt.Errorf("get user by oidc id failed: %s", err) - } - return user, nil - } -} - -func (c *Context) usersHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - users := c.UserStore.ListUsers() - userResponse := make([]UsersResponse, len(users)) - for k, user := range users { - userResponse[k].ID = user.ID - userResponse[k].Login = user.Login - userResponse[k].Role = user.Role - userResponse[k].OIDCID = user.OIDCID - userResponse[k].SAMLID = user.SAMLID - userResponse[k].Suspended = user.Suspended - userResponse[k].Provisioned = user.Provisioned - userResponse[k].ConnectionsDisabledOnAuthFailure = user.ConnectionsDisabledOnAuthFailure - if !user.LastLogin.IsZero() { - userResponse[k].LastLogin = user.LastLogin.UTC().Format(time.RFC3339) - } - for _, oauth2Data := range c.OIDCStore.OAuth2Data { - if oauth2Data.ID == user.OIDCID { - userResponse[k].LastTokenRenewal = oauth2Data.LastTokenRenewal - } - } - } - out, err := json.Marshal(userResponse) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - var user NewUserRequest - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&user) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - if !isAlphaNumeric(user.Login) { - c.returnError(w, fmt.Errorf("login not valid"), http.StatusBadRequest) - return - } - if user.Login == "" { - c.returnError(w, fmt.Errorf("login is empty"), http.StatusBadRequest) - return - } - if user.Password == "" { - c.returnError(w, fmt.Errorf("password is empty"), http.StatusBadRequest) - return - } - if user.Role != "user" && user.Role != "admin" { - c.returnError(w, fmt.Errorf("invalid role"), http.StatusBadRequest) - return - } - if c.UserStore.UserCount() >= c.LicenseUserCount { - c.returnError(w, fmt.Errorf("no more licenses available"), http.StatusBadRequest) - return - } - - newUser, err := c.UserStore.AddUser(users.User{Login: user.Login, Password: user.Password, Role: user.Role}) - if err != nil { - c.returnError(w, fmt.Errorf("add user error: %s", err), http.StatusBadRequest) - return - } - out, err := json.Marshal(newUser) - if err != nil { - c.returnError(w, fmt.Errorf("new user marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) userHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodDelete: - userID := r.PathValue("id") - err := c.UserStore.DeleteUserByID(userID) - if err != nil { - c.returnError(w, fmt.Errorf("delete user error: %s", err), http.StatusBadRequest) - return - } - err = wireguard.DeleteAllClientConfigs(c.Storage.Client, userID) - if err != nil { - c.returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", userID, err), http.StatusBadRequest) - return - } - c.write(w, []byte(`{"deleted": "`+userID+`"}`)) - case http.MethodPatch: - dbUser, err := c.UserStore.GetUserByID(r.PathValue("id")) - if err != nil { - c.returnError(w, fmt.Errorf("user not found: %s", err), http.StatusBadRequest) - return - } - var user users.User - decoder := json.NewDecoder(r.Body) - err = decoder.Decode(&user) - if err != nil { - c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) - return - } - updateUser := false - if user.Role != "" && dbUser.Role != user.Role { - dbUser.Role = user.Role - updateUser = true - } - if dbUser.Suspended != user.Suspended { - dbUser.Suspended = user.Suspended - updateUser = true - if user.Suspended { // user is now suspended - err := wireguard.DisableAllClientConfigs(c.Storage.Client, user.ID) - if err != nil { - c.returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest) - return - } - } else { // user is now unsuspended - err := wireguard.ReactivateAllClientConfigs(c.Storage.Client, user.ID) - if err != nil { - c.returnError(w, fmt.Errorf("could not reactivate all clients for user %s: %s", user.ID, err), http.StatusBadRequest) - return - } - } - } - if updateUser { - err = c.UserStore.UpdateUser(dbUser) - if err != nil { - c.returnError(w, fmt.Errorf("update user error: %s", err), http.StatusBadRequest) - return - } - } - if user.Password != "" { - err = c.UserStore.UpdatePassword(user.ID, user.Password) - if err != nil { - c.returnError(w, fmt.Errorf("update password error: %s", err), http.StatusBadRequest) - return - } - } - out, err := json.Marshal(dbUser) - if err != nil { - c.returnError(w, fmt.Errorf("marshal dbuser error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func addOrModifyExternalUser(storage storage.Iface, userStore *users.UserStore, login, authType, externalAuthID string) (users.User, error) { - if userStore.LoginExists(login) { - existingUser, err := userStore.GetUserByLogin(login) - if err != nil { - return existingUser, fmt.Errorf("couldn't find existing user in database: %s", login) - } - - if authType == "oidc" { - existingUser.OIDCID = externalAuthID - } - if authType == "saml" { - existingUser.SAMLID = externalAuthID - } - - if existingUser.ConnectionsDisabledOnAuthFailure { // we can enable connections again after auth - err := wireguard.ReactivateAllClientConfigs(storage, existingUser.ID) - if err != nil { - return existingUser, fmt.Errorf("could not reactivate all clients for user %s: %s", existingUser.ID, err) - } - existingUser.ConnectionsDisabledOnAuthFailure = false - } - - existingUser.LastLogin = time.Now() - - err = userStore.UpdateUser(existingUser) - if err != nil { - return existingUser, fmt.Errorf("couldn't update user: %s", login) - } - return existingUser, nil - } else { - newUser := users.User{ - Login: login, - Role: "user", - } - if authType == "oidc" { - newUser.OIDCID = externalAuthID - } - if authType == "saml" { - newUser.SAMLID = externalAuthID - } - - newUser.LastLogin = time.Now() - - newUserAdded, err := userStore.AddUser(newUser) - if err != nil { - return newUserAdded, fmt.Errorf("could not add user: %s", err) - } - return newUserAdded, nil - } -} - -func (c *Context) userinfoHandler(w http.ResponseWriter, r *http.Request) { - var response UserInfoResponse - - user := r.Context().Value(CustomValue("user")).(users.User) - - response.Login = user.Login - response.Role = user.Role - if user.OIDCID == "" { - response.UserType = "local" - } else { - response.UserType = "oidc" - } - - out, err := json.Marshal(response) - if err != nil { - c.returnError(w, fmt.Errorf("cannot marshal userinfo response: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - -} diff --git a/pkg/rest/users_test.go b/pkg/rest/users_test.go deleted file mode 100644 index c516691..0000000 --- a/pkg/rest/users_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package rest - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "net/http/httptest" - "path" - "strings" - "testing" - - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { - l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI) - if err != nil { - t.Fatal(err) - } - - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodPost: - if r.RequestURI == "/refresh-clients" { - w.WriteHeader(http.StatusAccepted) - w.Write([]byte("OK")) - return - } - if r.RequestURI == "/refresh-server-config" { - w.WriteHeader(http.StatusAccepted) - w.Write([]byte("OK")) - return - } - w.WriteHeader(http.StatusBadRequest) - default: - w.WriteHeader(http.StatusBadRequest) - } - })) - - ts.Listener.Close() - ts.Listener = l - ts.Start() - defer ts.Close() - defer l.Close() - - // first create a new user - storage := &memorystorage.MockMemoryStorage{} - - c, err := newContext(storage, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } - - err = c.UserStore.Empty() - if err != nil { - t.Fatalf("Cannot create context") - } - - // create a user - user := users.User{ - Login: "john", - Role: "user", - Password: "xyz", - } - payload, err := json.Marshal(user) - if err != nil { - t.Fatalf("Cannot create payload: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/users", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.usersHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - err = json.NewDecoder(resp.Body).Decode(&user) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - - // generate VPN config - _, err = wireguard.CreateNewVPNConfig(c.Storage.Client) - if err != nil { - t.Fatalf("Cannot create vpn config: %s", err) - } - - req = httptest.NewRequest("POST", "http://example.com/connections", nil) - w = httptest.NewRecorder() - c.connectionsHandler(w, req.WithContext(context.WithValue(context.Background(), CustomValue("user"), user))) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - connectionID := fmt.Sprintf("%s-1", user.ID) - - userConfigFilename := storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connectionID+".json")) - configBytes, err := storage.ReadFile(userConfigFilename) - if err != nil { - t.Fatalf("could not read user config file") - } - - var config wireguard.PeerConfig - err = json.Unmarshal(configBytes, &config) - if err != nil { - t.Fatalf("could not parse config: %s", err) - } - if config.Disabled { - t.Fatalf("VPN connection is disabled. Expected not disabled") - } - - req = httptest.NewRequest("GET", "http://example.com/connection/"+connectionID, nil) - req.SetPathValue("id", connectionID) - w = httptest.NewRecorder() - c.connectionsElementHandler(w, req.WithContext(context.WithValue(context.Background(), CustomValue("user"), user))) - - resp = w.Result() - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("readall error: %s", err) - } - if !strings.Contains(string(body), "[Interface]") { - t.Fatalf("output doesn't look like a wireguard client config: %s", body) - } - - req = httptest.NewRequest("DELETE", "http://example.com/user/"+user.ID, nil) - req.SetPathValue("id", user.ID) - w = httptest.NewRecorder() - c.userHandler(w, req) - - resp = w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - _, err = storage.ReadFile(userConfigFilename) - if err == nil { - t.Fatalf("could read user config file, expected not to") - } -} - -func TestCreateUser(t *testing.T) { - // first create a new user - storage := &memorystorage.MockMemoryStorage{} - - c, err := newContext(storage, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } - - err = c.UserStore.Empty() - if err != nil { - t.Fatalf("Cannot create context") - } - - // create a user - payload := []byte(`{"id": "", "login": "testuser", "password": "tttt213", "role": "user", "oidcID": "", "samlID": "", "lastLogin": "", "provisioned": false, "role":"user","samlID":"","suspended":false}`) - req := httptest.NewRequest("POST", "http://example.com/users", bytes.NewBuffer(payload)) - w := httptest.NewRecorder() - c.usersHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 200 { - t.Fatalf("status code is not 200: %d", resp.StatusCode) - } - - defer resp.Body.Close() - - var user users.User - err = json.NewDecoder(resp.Body).Decode(&user) - if err != nil { - t.Fatalf("Cannot decode response from create user: %s", err) - } - -} diff --git a/pkg/rest/version.go b/pkg/rest/version.go deleted file mode 100644 index c4b9e70..0000000 --- a/pkg/rest/version.go +++ /dev/null @@ -1,67 +0,0 @@ -package rest - -import ( - _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -//go:generate cp -r ../../latest ./resources/version -//go:embed resources/version - -var version string - -func (c *Context) version(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - out, err := json.Marshal(map[string]string{"version": strings.TrimSpace(version)}) - if err != nil { - c.returnError(w, fmt.Errorf("version marshal error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} - -func (c *Context) upgrade(w http.ResponseWriter, r *http.Request) { - client := http.Client{ - Timeout: 10 * time.Second, - } - req, err := http.NewRequest(r.Method, "http://"+wireguard.CONFIGMANAGER_URI+"/upgrade", nil) - if err != nil { - c.returnError(w, fmt.Errorf("upgrade request error: %s", err), http.StatusBadRequest) - return - } - resp, err := client.Do(req) - if err != nil { - c.returnError(w, fmt.Errorf("upgrade error: %s", err), http.StatusBadRequest) - return - } - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.returnError(w, fmt.Errorf("upgrade error: got status code: %d. Respons: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest) - return - } - c.returnError(w, fmt.Errorf("upgrade error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest) - return - } - - defer resp.Body.Close() - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest) - return - } - - c.write(w, bodyBytes) - -} diff --git a/pkg/rest/vpn.go b/pkg/rest/vpn.go deleted file mode 100644 index afdd92d..0000000 --- a/pkg/rest/vpn.go +++ /dev/null @@ -1,118 +0,0 @@ -package rest - -import ( - "encoding/json" - "fmt" - "net/http" - "path" - "strings" - - "github.com/in4it/wireguard-server/pkg/users" - "github.com/in4it/wireguard-server/pkg/wireguard" -) - -func (c *Context) connectionsHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - user := r.Context().Value(CustomValue("user")).(users.User) - - clients, err := c.Storage.Client.ReadDir(c.Storage.Client.ConfigPath(wireguard.VPN_CLIENTS_DIR)) - if err != nil { - c.returnError(w, fmt.Errorf("cannot list connections for user: %s", err), http.StatusBadRequest) - return - } - - connectionList := []string{} - for _, clientFilename := range clients { - if wireguard.HasClientUserID(clientFilename, user.ID) { - connectionList = append(connectionList, clientFilename) - } - } - peerConfigs := make([]wireguard.PeerConfig, len(connectionList)) - for k, connection := range connectionList { - var peerConfig wireguard.PeerConfig - filename := c.Storage.Client.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connection)) - toDeleteFileContents, err := c.Storage.Client.ReadFile(filename) - if err != nil { - c.returnError(w, fmt.Errorf("can't read file %s: %s", filename, err), http.StatusBadRequest) - return - } - err = json.Unmarshal(toDeleteFileContents, &peerConfig) - if err != nil { - c.returnError(w, fmt.Errorf("can't unmarshal file %s: %s", filename, err), http.StatusBadRequest) - return - } - peerConfigs[k] = peerConfig - } - connections := make([]Connection, len(peerConfigs)) - for k := range peerConfigs { - connections[k] = Connection{ - ID: peerConfigs[k].ID, - Name: peerConfigs[k].Name, - } - } - out, err := json.Marshal(connections) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal list connection response: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodPost: - muClientDownload.Lock() - defer muClientDownload.Unlock() - user := r.Context().Value(CustomValue("user")).(users.User) - peerConfig, err := wireguard.NewEmptyClientConfig(c.Storage.Client, user.ID) - if err != nil { - c.returnError(w, fmt.Errorf("could not generate client vpn config: %s", err), http.StatusBadRequest) - return - } - newConnectionResponse := NewConnectionResponse{Name: peerConfig.Name} - out, err := json.Marshal(newConnectionResponse) - if err != nil { - c.returnError(w, fmt.Errorf("could not marshal new connection response: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} -func (c *Context) connectionsElementHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - user := r.Context().Value(CustomValue("user")).(users.User) - if !strings.HasPrefix(r.PathValue("id"), user.ID) { - c.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest) - return - } - if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") { - c.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest) - return - } - out, err := wireguard.GenerateNewClientConfig(c.Storage.Client, r.PathValue("id"), user.ID) - if err != nil { - c.returnError(w, fmt.Errorf("GetClientConfig error: %s", err), http.StatusBadRequest) - return - } - c.write(w, out) - case http.MethodDelete: - user := r.Context().Value(CustomValue("user")).(users.User) - if !strings.HasPrefix(r.PathValue("id"), user.ID) { - c.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest) - return - } - if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") { - c.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest) - return - } - err := wireguard.DeleteClientConfig(c.Storage.Client, r.PathValue("id"), user.ID) - if err != nil { - c.returnError(w, fmt.Errorf("DeleteClientConfig error: %s", err), http.StatusBadRequest) - return - } - c.write(w, []byte(`{"deleted": "`+r.PathValue("id")+`"}`)) - - default: - c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) - } -} diff --git a/pkg/storage/iface.go b/pkg/storage/iface.go deleted file mode 100644 index f412fff..0000000 --- a/pkg/storage/iface.go +++ /dev/null @@ -1,34 +0,0 @@ -package storage - -import ( - "io" - "io/fs" -) - -type Iface interface { - GetPath() string - EnsurePath(path string) error - EnsureOwnership(filename, login string) error - ReadDir(name string) ([]string, error) - Remove(name string) error - Rename(oldName, newName string) error - AppendFile(name string, data []byte) error - EnsurePermissions(name string, mode fs.FileMode) error - FileInfo(name string) (fs.FileInfo, error) - ReadWriter - Seeker -} - -type ReadWriter interface { - ReadFile(name string) ([]byte, error) - WriteFile(name string, data []byte) error - FileExists(filename string) bool - ConfigPath(filename string) string - OpenFile(name string) (io.ReadCloser, error) - OpenFileForWriting(name string) (io.WriteCloser, error) - OpenFileForAppending(name string) (io.WriteCloser, error) -} - -type Seeker interface { - OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) -} diff --git a/pkg/storage/local/constants.go b/pkg/storage/local/constants.go deleted file mode 100644 index 938df52..0000000 --- a/pkg/storage/local/constants.go +++ /dev/null @@ -1,4 +0,0 @@ -package localstorage - -const CONFIG_PATH = "config" -const VPN_CLIENTS_DIR = "clients" diff --git a/pkg/storage/local/new.go b/pkg/storage/local/new.go deleted file mode 100644 index be9b35c..0000000 --- a/pkg/storage/local/new.go +++ /dev/null @@ -1,44 +0,0 @@ -package localstorage - -import ( - "fmt" - "os" - "path" -) - -func New() (*LocalStorage, error) { - pwd, err := os.Executable() - if err != nil { - return nil, fmt.Errorf("os Executable error: %s", err) - } - - pathname := path.Dir(pwd) - storage, err := NewWithPath(pathname) - if err != nil { - return storage, err - } - err = storage.EnsurePath(CONFIG_PATH) - if err != nil { - return nil, fmt.Errorf("cannot create storage directories: %s", err) - } - err = storage.EnsurePath(path.Join(CONFIG_PATH, VPN_CLIENTS_DIR)) - if err != nil { - return nil, fmt.Errorf("cannot create storage directories: %s", err) - } - err = storage.EnsureOwnership(CONFIG_PATH, "vpn") - if err != nil { - return nil, fmt.Errorf("cannot ensure vpn ownership of config directory: %s", err) - } - err = storage.EnsureOwnership(path.Join(CONFIG_PATH, VPN_CLIENTS_DIR), "vpn") - if err != nil { - return nil, fmt.Errorf("cannot ensure vpn ownership of config directory: %s", err) - } - - return NewWithPath(pathname) -} - -func NewWithPath(pathname string) (*LocalStorage, error) { - return &LocalStorage{ - path: pathname, - }, nil -} diff --git a/pkg/storage/local/path.go b/pkg/storage/local/path.go deleted file mode 100644 index 214b341..0000000 --- a/pkg/storage/local/path.go +++ /dev/null @@ -1,96 +0,0 @@ -package localstorage - -import ( - "errors" - "fmt" - "io/fs" - "os" - "os/user" - "path" - "strconv" - - "github.com/in4it/wireguard-server/pkg/logging" -) - -func (l *LocalStorage) FileExists(filename string) bool { - if _, err := os.Stat(path.Join(l.path, filename)); errors.Is(err, os.ErrNotExist) { - return false - } - return true -} - -func (l *LocalStorage) ConfigPath(filename string) string { - return path.Join(CONFIG_PATH, filename) -} - -func (l *LocalStorage) GetPath() string { - return l.path -} - -func (l *LocalStorage) EnsurePath(pathname string) error { - fullPathname := path.Join(l.path, pathname) - if _, err := os.Stat(fullPathname); errors.Is(err, os.ErrNotExist) { - err := os.Mkdir(fullPathname, 0700) - if err != nil { - return fmt.Errorf("create directory error: %s", err) - } - } - return nil -} - -func (l *LocalStorage) EnsureOwnership(filename, login string) error { - currentUser, err := user.Current() - if err != nil { - return fmt.Errorf("could not get current user: %s", err) - } - if currentUser.Username != "root" { - logging.DebugLog(fmt.Errorf("cannot ensure ownership of file %s when not user root (current user: %s)", filename, currentUser.Username)) - return nil - } - vpnUser, err := user.Lookup(login) - if err != nil { - return fmt.Errorf("user lookup error (vpn): %s", err) - } - vpnUserUid, err := strconv.Atoi(vpnUser.Uid) - if err != nil { - return fmt.Errorf("user lookup error (uid): %s", err) - } - vpnUserGid, err := strconv.Atoi(vpnUser.Gid) - if err != nil { - return fmt.Errorf("user lookup error (gid): %s", err) - } - - err = os.Chown(path.Join(l.path, filename), vpnUserUid, vpnUserGid) - if err != nil { - return fmt.Errorf("vpn chown error: %s", err) - } - return nil -} - -func (l *LocalStorage) ReadDir(pathname string) ([]string, error) { - res, err := os.ReadDir(path.Join(l.path, pathname)) - if err != nil { - return []string{}, err - } - resNames := make([]string, len(res)) - for k, v := range res { - resNames[k] = v.Name() - } - return resNames, nil -} - -func (l *LocalStorage) Remove(name string) error { - return os.Remove(path.Join(l.path, name)) -} - -func (l *LocalStorage) Rename(oldName, newName string) error { - return os.Rename(path.Join(l.path, oldName), path.Join(l.path, newName)) -} - -func (l *LocalStorage) EnsurePermissions(name string, mode fs.FileMode) error { - return os.Chmod(path.Join(l.path, name), mode) -} - -func (l *LocalStorage) FileInfo(name string) (fs.FileInfo, error) { - return os.Stat(name) -} diff --git a/pkg/storage/local/read.go b/pkg/storage/local/read.go deleted file mode 100644 index 3a2e6ce..0000000 --- a/pkg/storage/local/read.go +++ /dev/null @@ -1,48 +0,0 @@ -package localstorage - -import ( - "fmt" - "io" - "os" - "path" -) - -func (l *LocalStorage) ReadFile(name string) ([]byte, error) { - return os.ReadFile(path.Join(l.path, name)) -} - -func (l *LocalStorage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) { - readers := []io.ReadCloser{} - if pos < 0 { - return readers, nil - } - for _, name := range names { - file, err := os.Open(path.Join(l.path, name)) - if err != nil { - return nil, fmt.Errorf("cannot open file (%s): %s", name, err) - } - stat, err := file.Stat() - if err != nil { - return nil, fmt.Errorf("cannot get file stat (%s): %s", name, err) - } - if stat.Size() <= pos { - pos -= stat.Size() - } else { - _, err := file.Seek(pos, 0) - if err != nil { - return nil, fmt.Errorf("could not seek to pos (file: %s): %s", name, err) - } - pos = 0 - readers = append(readers, file) - } - } - return readers, nil -} - -func (l *LocalStorage) OpenFile(name string) (io.ReadCloser, error) { - file, err := os.Open(path.Join(l.path, name)) - if err != nil { - return nil, fmt.Errorf("cannot open file (%s): %s", name, err) - } - return file, nil -} diff --git a/pkg/storage/local/read_test.go b/pkg/storage/local/read_test.go deleted file mode 100644 index c40e6b4..0000000 --- a/pkg/storage/local/read_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package localstorage - -import ( - "bytes" - "io" - "os" - "path" - "testing" -) - -func TestOpenFilesFromPos(t *testing.T) { - pwd, err := os.Executable() - if err != nil { - t.Fatalf("os Executable error: %s", err) - } - l := LocalStorage{ - path: path.Dir(pwd), - } - contents1 := []byte(`this is the first file`) - contents2 := []byte(`this is the second file`) - err = l.WriteFile("1.txt", contents1) - if err != nil { - t.Fatalf("write file error: %s", err) - } - err = l.WriteFile("2.txt", contents2) - if err != nil { - t.Fatalf("write file error: %s", err) - } - t.Cleanup(func() { - err = os.Remove(path.Join(l.path, "1.txt")) - if err != nil { - t.Fatalf("file delete error: %s", err) - } - err = os.Remove(path.Join(l.path, "2.txt")) - if err != nil { - t.Fatalf("file delete error: %s", err) - } - }) - expected := []string{ - "this is the first filethis is the second file", - "is the first filethis is the second file", - "this is the second file", - "ethis is the second file", - "his is the second file", - "", - "", - "", - } - expextedOpenFiles := []int{ - 2, - 2, - 1, - 2, - 1, - 0, - 0, - 0, - } - tests := []int64{ - 0, - 5, - int64(len(contents1)), - int64(len(contents1) - 1), - int64(len(contents1) + 1), - int64(len(contents1) + len(contents2)), - int64(len(contents1) + len(contents2) + 1), - -5, - } - for k, pos := range tests { - files, err := l.OpenFilesFromPos([]string{"1.txt", "2.txt"}, pos) - if err != nil { - t.Fatalf("open file error: %s", err) - } - contents := bytes.NewBuffer([]byte{}) - for _, file := range files { - defer file.Close() - body, err := io.ReadAll(file) - if err != nil { - t.Fatalf("could not read file: %s", err) - } - contents.Write(body) - } - if expected[k] != contents.String() { - t.Fatalf("unexpected output: expected '%s' got '%s'", expected[k], contents.String()) - } - if expextedOpenFiles[k] != len(files) { - t.Fatalf("unexpected open files: expected %d got %d", expextedOpenFiles[k], len(files)) - } - } - -} diff --git a/pkg/storage/local/types.go b/pkg/storage/local/types.go deleted file mode 100644 index 103b254..0000000 --- a/pkg/storage/local/types.go +++ /dev/null @@ -1,5 +0,0 @@ -package localstorage - -type LocalStorage struct { - path string -} diff --git a/pkg/storage/local/write.go b/pkg/storage/local/write.go deleted file mode 100644 index 5eb3d19..0000000 --- a/pkg/storage/local/write.go +++ /dev/null @@ -1,41 +0,0 @@ -package localstorage - -import ( - "fmt" - "io" - "os" - "path" -) - -func (l *LocalStorage) WriteFile(name string, data []byte) error { - return os.WriteFile(path.Join(l.path, name), data, 0600) -} - -func (l *LocalStorage) AppendFile(name string, data []byte) error { - f, err := os.OpenFile(path.Join(l.path, name), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err := f.Write(data); err != nil { - return err - } - - return nil -} - -func (l *LocalStorage) OpenFileForWriting(name string) (io.WriteCloser, error) { - file, err := os.Create(path.Join(l.path, name)) - if err != nil { - return nil, fmt.Errorf("cannot open file (%s): %s", name, err) - } - return file, nil -} - -func (l *LocalStorage) OpenFileForAppending(name string) (io.WriteCloser, error) { - file, err := os.OpenFile(path.Join(l.path, name), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0660) - if err != nil { - return nil, err - } - return file, nil -} diff --git a/pkg/storage/memory/fileinfo.go b/pkg/storage/memory/fileinfo.go deleted file mode 100644 index b9fc5dd..0000000 --- a/pkg/storage/memory/fileinfo.go +++ /dev/null @@ -1,34 +0,0 @@ -package memorystorage - -import ( - "io/fs" - "time" -) - -type FileInfo struct { - NameOut string // base name of the file - SizeOut int64 // length in bytes for regular files; system-dependent for others - ModeOut fs.FileMode // file mode bits - ModTimeOut time.Time // modification time - IsDirOut bool // abbreviation for Mode().IsDir() - SysOut any // underlying data source (can return nil) -} - -func (f FileInfo) Name() string { - return f.NameOut -} -func (f FileInfo) Size() int64 { - return f.SizeOut -} -func (f FileInfo) Mode() fs.FileMode { - return f.ModeOut -} -func (f FileInfo) ModTime() time.Time { - return f.ModTimeOut -} -func (f FileInfo) IsDir() bool { - return f.IsDirOut -} -func (f FileInfo) Sys() any { - return nil -} diff --git a/pkg/storage/memory/storage.go b/pkg/storage/memory/storage.go deleted file mode 100644 index eb7a6c1..0000000 --- a/pkg/storage/memory/storage.go +++ /dev/null @@ -1,201 +0,0 @@ -package memorystorage - -import ( - "bytes" - "fmt" - "io" - "io/fs" - "os" - "path" - "strings" - "sync" -) - -type MockReadWriterData []byte - -func (m *MockReadWriterData) Close() error { - return nil -} -func (m *MockReadWriterData) Write(p []byte) (nn int, err error) { - *m = append(*m, p...) - return len(p), nil -} - -type MockMemoryStorage struct { - FileInfoData map[string]*FileInfo - Data map[string]*MockReadWriterData - Mu sync.Mutex -} - -func (m *MockMemoryStorage) ConfigPath(filename string) string { - return path.Join("config", filename) -} -func (m *MockMemoryStorage) Rename(oldName, newName string) error { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - _, ok := m.Data[oldName] - if !ok { - return fmt.Errorf("file doesn't exist") - } - m.Data[newName] = m.Data[oldName] - delete(m.Data, oldName) - return nil -} -func (m *MockMemoryStorage) FileExists(name string) bool { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - _, ok := m.Data[name] - return ok -} - -func (m *MockMemoryStorage) ReadFile(name string) ([]byte, error) { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - val, ok := m.Data[name] - if !ok { - return nil, fmt.Errorf("file does not exist") - } - return *val, nil -} -func (m *MockMemoryStorage) WriteFile(name string, data []byte) error { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - m.Data[name] = (*MockReadWriterData)(&data) - return nil -} -func (m *MockMemoryStorage) AppendFile(name string, data []byte) error { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - if m.Data[name] == nil { - m.Data[name] = (*MockReadWriterData)(&data) - } else { - *m.Data[name] = append(*m.Data[name], data...) - } - - return nil -} - -func (m *MockMemoryStorage) GetPath() string { - pwd, _ := os.Executable() - return path.Dir(pwd) -} - -func (m *MockMemoryStorage) EnsurePath(pathname string) error { - return nil -} - -func (m *MockMemoryStorage) EnsureOwnership(filename, login string) error { - return nil -} - -func (m *MockMemoryStorage) ReadDir(path string) ([]string, error) { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - res := []string{} - for k := range m.Data { - if path == "" { - res = append(res, strings.TrimSuffix(k, "/")) - } else if strings.HasPrefix(k, path+"/") { - res = append(res, strings.ReplaceAll(k, path+"/", "")) - } - } - return res, nil -} - -func (m *MockMemoryStorage) Remove(name string) error { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - _, ok := m.Data[name] - if !ok { - return fmt.Errorf("file does not exist") - } - delete(m.Data, name) - return nil -} - -func (m *MockMemoryStorage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - if pos > 0 { - return nil, fmt.Errorf("pos > 0 not implemented") - } - readClosers := []io.ReadCloser{} - for _, name := range names { - val, ok := m.Data[name] - if !ok { - return nil, fmt.Errorf("file does not exist") - } - readClosers = append(readClosers, io.NopCloser(bytes.NewBuffer(*val))) - } - return readClosers, nil -} -func (m *MockMemoryStorage) OpenFile(name string) (io.ReadCloser, error) { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - val, ok := m.Data[name] - if !ok { - return nil, fmt.Errorf("file does not exist") - } - - return io.NopCloser(bytes.NewBuffer(*val)), nil -} -func (m *MockMemoryStorage) OpenFileForWriting(name string) (io.WriteCloser, error) { - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - m.Data[name] = (*MockReadWriterData)(&[]byte{}) - return m.Data[name], nil -} -func (m *MockMemoryStorage) OpenFileForAppending(name string) (io.WriteCloser, error) { - m.Mu.Lock() - defer m.Mu.Unlock() - if m.Data == nil { - m.Data = make(map[string]*MockReadWriterData) - } - val, ok := m.Data[name] - if !ok { - m.Data[name] = (*MockReadWriterData)(&[]byte{}) - return m.Data[name], nil - } - m.Data[name] = (*MockReadWriterData)(val) - return m.Data[name], nil -} -func (m *MockMemoryStorage) EnsurePermissions(name string, mode fs.FileMode) error { - return nil -} -func (m *MockMemoryStorage) FileInfo(name string) (fs.FileInfo, error) { - m.Mu.Lock() - defer m.Mu.Unlock() - val, ok := m.FileInfoData[name] - if !ok { - return FileInfo{}, fmt.Errorf("couldn't get file info for: %s", name) - } - return val, nil -} diff --git a/pkg/storage/s3/list.go b/pkg/storage/s3/list.go deleted file mode 100644 index 5cadfdc..0000000 --- a/pkg/storage/s3/list.go +++ /dev/null @@ -1,25 +0,0 @@ -package s3storage - -import ( - "context" - "fmt" - "strings" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" -) - -func (s *S3Storage) ReadDir(pathname string) ([]string, error) { - objectList, err := s.s3Client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{ - Bucket: aws.String(s.bucketname), - Prefix: aws.String(s.prefix + "/" + strings.TrimLeft(pathname, "/")), - }) - if err != nil { - return []string{}, fmt.Errorf("list object error: %s", err) - } - res := make([]string, len(objectList.Contents)) - for k, object := range objectList.Contents { - res[k] = *object.Key - } - return res, nil -} diff --git a/pkg/storage/s3/new.go b/pkg/storage/s3/new.go deleted file mode 100644 index 7e0c7f7..0000000 --- a/pkg/storage/s3/new.go +++ /dev/null @@ -1,23 +0,0 @@ -package s3storage - -import ( - "context" - "fmt" - - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/s3" -) - -func New(bucketname, prefix string) (*S3Storage, error) { - sdkConfig, err := config.LoadDefaultConfig(context.TODO()) - if err != nil { - return nil, fmt.Errorf("config load error: %s", err) - } - s3Client := s3.NewFromConfig(sdkConfig) - - return &S3Storage{ - bucketname: bucketname, - prefix: prefix, - s3Client: s3Client, - }, nil -} diff --git a/pkg/storage/s3/path.go b/pkg/storage/s3/path.go deleted file mode 100644 index b8db5ce..0000000 --- a/pkg/storage/s3/path.go +++ /dev/null @@ -1,42 +0,0 @@ -package s3storage - -import ( - "io/fs" - "strings" -) - -func (l *S3Storage) FileExists(filename string) bool { - return false -} - -func (l *S3Storage) ConfigPath(filename string) string { - return CONFIG_PATH + "/" + strings.TrimLeft(filename, "/") -} - -func (s *S3Storage) GetPath() string { - return s.prefix -} - -func (l *S3Storage) EnsurePath(pathname string) error { - return nil -} - -func (l *S3Storage) EnsureOwnership(filename, login string) error { - return nil -} - -func (l *S3Storage) Remove(name string) error { - return nil -} - -func (l *S3Storage) Rename(oldName, newName string) error { - return nil -} - -func (l *S3Storage) EnsurePermissions(name string, mode fs.FileMode) error { - return nil -} - -func (l *S3Storage) FileInfo(name string) (fs.FileInfo, error) { - return nil, nil -} diff --git a/pkg/storage/s3/read.go b/pkg/storage/s3/read.go deleted file mode 100644 index 91b0ffe..0000000 --- a/pkg/storage/s3/read.go +++ /dev/null @@ -1,17 +0,0 @@ -package s3storage - -import ( - "io" -) - -func (l *S3Storage) ReadFile(name string) ([]byte, error) { - return nil, nil -} - -func (l *S3Storage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) { - return nil, nil -} - -func (l *S3Storage) OpenFile(name string) (io.ReadCloser, error) { - return nil, nil -} diff --git a/pkg/storage/s3/types.go b/pkg/storage/s3/types.go deleted file mode 100644 index 3ff10eb..0000000 --- a/pkg/storage/s3/types.go +++ /dev/null @@ -1,11 +0,0 @@ -package s3storage - -import "github.com/aws/aws-sdk-go-v2/service/s3" - -const CONFIG_PATH = "config" - -type S3Storage struct { - bucketname string - prefix string - s3Client *s3.Client -} diff --git a/pkg/storage/s3/write.go b/pkg/storage/s3/write.go deleted file mode 100644 index 815b819..0000000 --- a/pkg/storage/s3/write.go +++ /dev/null @@ -1,35 +0,0 @@ -package s3storage - -import ( - "bytes" - "context" - "fmt" - "io" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" -) - -func (s *S3Storage) WriteFile(name string, data []byte) error { - _, err := s.s3Client.PutObject(context.TODO(), &s3.PutObjectInput{ - Bucket: aws.String(s.bucketname), - Key: aws.String(name), - Body: bytes.NewReader(data), - }) - if err != nil { - return fmt.Errorf("put object error: %s", err) - } - return nil -} - -func (s *S3Storage) AppendFile(name string, data []byte) error { - return nil -} - -func (s *S3Storage) OpenFileForWriting(name string) (io.WriteCloser, error) { - return nil, nil -} - -func (s *S3Storage) OpenFileForAppending(name string) (io.WriteCloser, error) { - return nil, nil -} diff --git a/pkg/users/localusers.go b/pkg/users/localusers.go deleted file mode 100644 index 55be548..0000000 --- a/pkg/users/localusers.go +++ /dev/null @@ -1,185 +0,0 @@ -package users - -import ( - "fmt" - - "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" -) - -func (u *UserStore) AddUser(user User) (User, error) { - if user.Login == "" { - return user, fmt.Errorf("login cannot be empty") - } - existingUsers := u.ListUsers() - for _, existingUser := range existingUsers { - if existingUser.Login == user.Login { - return User{}, fmt.Errorf("user with login '%s' already exists", user.Login) - } - } - user.ID = uuid.NewString() - if user.Password != "" { - hashedPassword, err := HashPassword(user.Password) - if err != nil { - return user, fmt.Errorf("HashPassword error: %s", err) - } - user.Password = hashedPassword - } - u.Users = append(u.Users, user) - if u.autoSave { - return user, u.SaveUsers() - } - return user, nil -} - -func (u *UserStore) AddUsers(users []User) ([]User, error) { - createdUsers := []User{} - existingUsers := u.ListUsers() - for k := range users { - for _, existingUser := range existingUsers { - if existingUser.Login == users[k].Login { - return createdUsers, fmt.Errorf("user with login '%s' already exists", users[k].Login) - } - } - users[k].ID = uuid.NewString() - hashedPassword, err := HashPassword(users[k].Password) - if err != nil { - return createdUsers, fmt.Errorf("HashPassword error: %s", err) - } - users[k].Password = hashedPassword - u.Users = append(u.Users, users[k]) - existingUsers = append(existingUsers, users[k]) - createdUsers = append(createdUsers, users[k]) - } - if u.autoSave { - return createdUsers, u.SaveUsers() - } - return createdUsers, nil -} - -func (u *UserStore) GetUserByID(id string) (User, error) { - for _, user := range u.Users { - if user.ID == id { - user.Password = "" - return user, nil - } - } - return User{}, fmt.Errorf("User not found") -} - -func (u *UserStore) GetUserByLogin(login string) (User, error) { - for _, user := range u.Users { - if user.Login == login { - user.Password = "" - return user, nil - } - } - return User{}, fmt.Errorf("User not found") -} -func (u *UserStore) DeleteUserByLogin(login string) error { - for k, user := range u.Users { - if user.Login == login { - u.Users = append(u.Users[:k], u.Users[k+1:]...) - if u.autoSave { - return u.SaveUsers() - } - return nil - } - } - return fmt.Errorf("User not found") -} - -func (u *UserStore) DeleteUserByID(id string) error { - for k, user := range u.Users { - if user.ID == id { - u.Users = append(u.Users[:k], u.Users[k+1:]...) - if u.autoSave { - return u.SaveUsers() - } - return nil - } - } - return fmt.Errorf("User not found") -} - -func (u *UserStore) AuthUser(login, password string) (User, bool) { - for _, user := range u.Users { - if user.Login == login { - passwordMatch := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) - if passwordMatch == nil { - return user, true - } - } - } - return User{}, false -} - -func (u *UserStore) LoginExists(login string) bool { - for _, user := range u.Users { - if user.Login == login { - return true - } - } - return false -} -func (u *UserStore) UpdateUser(user User) error { - for k, existingUser := range u.Users { - if existingUser.Login == user.Login { - password := existingUser.Password // we keep the password - user.Password = password - u.Users[k] = user - if u.autoSave { - return u.SaveUsers() - } else { - return nil - } - } - } - return fmt.Errorf("user not found in database: %s", user.Login) -} -func (u *UserStore) UpdatePassword(userID string, password string) error { - for k, existingUser := range u.Users { - if existingUser.ID == userID { - hashedPassword, err := HashPassword(password) - if err != nil { - return fmt.Errorf("HashPassword error: %s", err) - } - u.Users[k].Password = hashedPassword - if u.autoSave { - return u.SaveUsers() - } else { - return nil - } - } - } - return fmt.Errorf("user not found in database: userID %s", userID) -} - -func HashPassword(password string) (string, error) { - adminPasswordHashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) - if err != nil { - return "", fmt.Errorf("unable to set password: %s", err) - } - return string(adminPasswordHashed), nil -} - -func (u *UserStore) ListUsers() []User { - users := make([]User, len(u.Users)) - for k, user := range u.Users { - user.Password = "" - users[k] = user - } - return users -} - -func (u *UserStore) UserCount() int { - return len(u.Users) -} - -func (u *UserStore) Empty() error { - u.Users = []User{} - if u.autoSave { - return u.SaveUsers() - } - return nil -} diff --git a/pkg/users/maxusers.go b/pkg/users/maxusers.go deleted file mode 100644 index ffc0221..0000000 --- a/pkg/users/maxusers.go +++ /dev/null @@ -1,5 +0,0 @@ -package users - -func (u *UserStore) GetMaxUsers() int { - return u.maxUsers -} diff --git a/pkg/users/oidcusers.go b/pkg/users/oidcusers.go deleted file mode 100644 index 66a6bfa..0000000 --- a/pkg/users/oidcusers.go +++ /dev/null @@ -1,15 +0,0 @@ -package users - -import "fmt" - -func (u *UserStore) GetUserByOIDCIDs(oidcIDs []string) (User, error) { - for _, user := range u.Users { - for _, oidcID := range oidcIDs { - if user.OIDCID == oidcID { - return user, nil - } - } - - } - return User{}, fmt.Errorf("User not found") -} diff --git a/pkg/users/save.go b/pkg/users/save.go deleted file mode 100644 index f35085f..0000000 --- a/pkg/users/save.go +++ /dev/null @@ -1,36 +0,0 @@ -package users - -import ( - "encoding/json" - "fmt" - "os/user" - "sync" -) - -var UserStoreMu sync.Mutex - -func (u *UserStore) SaveUsers() error { - UserStoreMu.Lock() - defer UserStoreMu.Unlock() - out, err := json.Marshal(u.Users) - if err != nil { - return fmt.Errorf("user store marshal error: %s", err) - } - err = u.storage.WriteFile(u.storage.ConfigPath(USERSTORE_FILENAME), out) - if err != nil { - return fmt.Errorf("user store write error: %s", err) - } - // fix permissions - currentUser, err := user.Current() - if err != nil { - return fmt.Errorf("could not get current user: %s", err) - } - if currentUser.Username != "vpn" { - err = u.storage.EnsureOwnership(u.storage.ConfigPath(USERSTORE_FILENAME), "vpn") - if err != nil { - return fmt.Errorf("ensure ownership error (userstore): %s", err) - } - } - - return nil -} diff --git a/pkg/users/store.go b/pkg/users/store.go deleted file mode 100644 index ca0ee62..0000000 --- a/pkg/users/store.go +++ /dev/null @@ -1,35 +0,0 @@ -package users - -import ( - "bytes" - "encoding/json" - "fmt" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -const USERSTORE_FILENAME = "users.json" - -func NewUserStore(storage storage.Iface, maxUsers int) (*UserStore, error) { - userStore := &UserStore{ - autoSave: true, - maxUsers: maxUsers, - storage: storage, - } - - if !userStore.storage.FileExists(userStore.storage.ConfigPath(USERSTORE_FILENAME)) { - userStore.Users = []User{} - return userStore, nil - } - - data, err := userStore.storage.ReadFile(userStore.storage.ConfigPath(USERSTORE_FILENAME)) - if err != nil { - return userStore, fmt.Errorf("config read error: %s", err) - } - decoder := json.NewDecoder(bytes.NewBuffer(data)) - err = decoder.Decode(&userStore.Users) - if err != nil { - return userStore, fmt.Errorf("decode input error: %s", err) - } - return userStore, nil -} diff --git a/pkg/users/types.go b/pkg/users/types.go deleted file mode 100644 index 0c2bc65..0000000 --- a/pkg/users/types.go +++ /dev/null @@ -1,35 +0,0 @@ -package users - -import ( - "time" - - "github.com/in4it/wireguard-server/pkg/storage" -) - -type UserStore struct { - Users []User `json:"users"` - autoSave bool - maxUsers int - storage storage.Iface -} - -type User struct { - ID string `json:"id"` - Login string `json:"login"` - Role string `json:"role"` - OIDCID string `json:"oidcID,omitempty"` - SAMLID string `json:"samlID,omitempty"` - Provisioned bool `json:"provisioned,omitempty"` - Password string `json:"password,omitempty"` - Suspended bool `json:"suspended"` - ConnectionsDisabledOnAuthFailure bool `json:"connectionsDisabledOnAuthFailure"` - Factors []Factor `json:"factors"` - ExternalID string `json:"externalID,omitempty"` - LastLogin time.Time `json:"lastLogin"` -} - -type Factor struct { - Name string `json:"name"` - Type string `json:"type"` - Secret string `json:"secret"` -} diff --git a/pkg/utils/date/compare.go b/pkg/utils/date/compare.go deleted file mode 100644 index b6a901e..0000000 --- a/pkg/utils/date/compare.go +++ /dev/null @@ -1,9 +0,0 @@ -package dateutils - -import "time" - -func DateEqual(date1, date2 time.Time) bool { - y1, m1, d1 := date1.Date() - y2, m2, d2 := date2.Date() - return y1 == y2 && m1 == m2 && d1 == d2 -} diff --git a/pkg/utils/random/rand.go b/pkg/utils/random/rand.go deleted file mode 100644 index 7645ba3..0000000 --- a/pkg/utils/random/rand.go +++ /dev/null @@ -1,19 +0,0 @@ -package randomutils - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -func GetRandomString(n int) (string, error) { - buf := make([]byte, n) - - _, err := io.ReadFull(rand.Reader, buf) - if err != nil { - return "", fmt.Errorf("crypto/rand Reader error: %s", err) - } - - return base64.RawURLEncoding.EncodeToString(buf), nil -} diff --git a/pkg/vpn/helpers.go b/pkg/vpn/helpers.go new file mode 100644 index 0000000..92c1340 --- /dev/null +++ b/pkg/vpn/helpers.go @@ -0,0 +1,38 @@ +package vpn + +import ( + "fmt" + "net/http" + "strings" +) + +func (v *VPN) returnError(w http.ResponseWriter, err error, statusCode int) { + fmt.Println("========= ERROR =========") + fmt.Printf("Error: %s\n", err) + fmt.Println("=========================") + w.WriteHeader(statusCode) + w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`)) +} + +func (v *VPN) write(w http.ResponseWriter, res []byte) { + sendCorsHeaders(w, "", v.Hostname, v.Protocol) + w.WriteHeader(http.StatusOK) + w.Write(res) +} +func (v *VPN) writeWithStatus(w http.ResponseWriter, res []byte, status int) { + sendCorsHeaders(w, "", v.Hostname, v.Protocol) + w.WriteHeader(status) + w.Write(res) +} + +func sendCorsHeaders(w http.ResponseWriter, headers string, hostname string, protocol string) { + if hostname == "" { + w.Header().Add("Access-Control-Allow-Origin", "*") + } else { + w.Header().Add("Access-Control-Allow-Origin", fmt.Sprintf("%s://%s", protocol, hostname)) + } + w.Header().Add("Access-Control-allow-methods", "GET,HEAD,POST,PUT,OPTIONS,DELETE,PATCH") + if headers != "" { + w.Header().Add("Access-Control-Allow-Headers", headers) + } +} diff --git a/pkg/vpn/new.go b/pkg/vpn/new.go new file mode 100644 index 0000000..d7bef01 --- /dev/null +++ b/pkg/vpn/new.go @@ -0,0 +1,13 @@ +package vpn + +import ( + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" +) + +func New(defaultStorage storage.Iface, userStore *users.UserStore) *VPN { + return &VPN{ + Storage: defaultStorage, + UserStore: userStore, + } +} diff --git a/pkg/vpn/router.go b/pkg/vpn/router.go new file mode 100644 index 0000000..c1134c0 --- /dev/null +++ b/pkg/vpn/router.go @@ -0,0 +1,26 @@ +package vpn + +import ( + "net/http" + + "github.com/in4it/go-devops-platform/rest" +) + +func (v *VPN) GetRouter() *http.ServeMux { + mux := http.NewServeMux() + + mux.Handle("/api/vpn/connections", http.HandlerFunc(v.connectionsHandler)) + mux.Handle("/api/vpn/connection/{id}", http.HandlerFunc(v.connectionsElementHandler)) + mux.Handle("/api/vpn/connectionlicense", http.HandlerFunc(v.connectionLicenseHandler)) + + mux.Handle("/api/vpn/stats/user/{date}", rest.IsAdminMiddleware(http.HandlerFunc(v.userStatsHandler))) + mux.Handle("/api/vpn/stats/packetlogs/{user}/{date}", rest.IsAdminMiddleware(http.HandlerFunc(v.packetLogsHandler))) + + mux.Handle("/api/vpn/setup/vpn", rest.IsAdminMiddleware(http.HandlerFunc(v.vpnSetupHandler))) + mux.Handle("/api/vpn/setup/templates", rest.IsAdminMiddleware(http.HandlerFunc(v.templateSetupHandler))) + mux.Handle("/api/vpn/setup/restart-vpn", rest.IsAdminMiddleware(http.HandlerFunc(v.restartVPNHandler))) + + mux.Handle("/api/vpn/version", http.HandlerFunc(v.version)) + + return mux +} diff --git a/pkg/vpn/setup.go b/pkg/vpn/setup.go new file mode 100644 index 0000000..470dcf1 --- /dev/null +++ b/pkg/vpn/setup.go @@ -0,0 +1,300 @@ +package vpn + +import ( + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "reflect" + "slices" + "sort" + "strconv" + "strings" + "time" + + "github.com/in4it/wireguard-server/pkg/wireguard" +) + +func (v *VPN) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { + vpnConfig, err := wireguard.GetVPNConfig(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not get vpn config: %s", err), http.StatusBadRequest) + return + } + switch r.Method { + case http.MethodGet: + packetLogTypes := []string{} + for k, enabled := range vpnConfig.PacketLogsTypes { + if enabled { + packetLogTypes = append(packetLogTypes, k) + } + } + if vpnConfig.PacketLogsRetention == 0 { + vpnConfig.PacketLogsRetention = 7 + } + setupRequest := VPNSetupRequest{ + Routes: strings.Join(vpnConfig.ClientRoutes, ", "), + VPNEndpoint: vpnConfig.Endpoint, + AddressRange: vpnConfig.AddressRange.String(), + ClientAddressPrefix: vpnConfig.ClientAddressPrefix, + Port: strconv.Itoa(vpnConfig.Port), + ExternalInterface: vpnConfig.ExternalInterface, + Nameservers: strings.Join(vpnConfig.Nameservers, ","), + DisableNAT: vpnConfig.DisableNAT, + EnablePacketLogs: vpnConfig.EnablePacketLogs, + PacketLogsTypes: packetLogTypes, + PacketLogsRetention: strconv.Itoa(vpnConfig.PacketLogsRetention), + } + out, err := json.Marshal(setupRequest) + if err != nil { + v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + case http.MethodPost: + var ( + writeVPNConfig bool + rewriteClientConfigs bool + setupRequest VPNSetupRequest + ) + decoder := json.NewDecoder(r.Body) + decoder.Decode(&setupRequest) + if strings.Join(vpnConfig.ClientRoutes, ", ") != setupRequest.Routes { + networks := strings.Split(setupRequest.Routes, ",") + validatedNetworks := []string{} + for _, network := range networks { + if strings.TrimSpace(network) == "::/0" { + validatedNetworks = append(validatedNetworks, "::/0") + } else { + _, ipnet, err := net.ParseCIDR(strings.TrimSpace(network)) + if err != nil { + v.returnError(w, fmt.Errorf("client route %s in wrong format: %s", strings.TrimSpace(network), err), http.StatusBadRequest) + return + } + validatedNetworks = append(validatedNetworks, ipnet.String()) + } + } + vpnConfig.ClientRoutes = validatedNetworks + writeVPNConfig = true + rewriteClientConfigs = true + } + if vpnConfig.Endpoint != setupRequest.VPNEndpoint { + vpnConfig.Endpoint = setupRequest.VPNEndpoint + writeVPNConfig = true + rewriteClientConfigs = true + } + addressRangeParsed, err := netip.ParsePrefix(setupRequest.AddressRange) + if err != nil { + v.returnError(w, fmt.Errorf("AddressRange in wrong format: %s", err), http.StatusBadRequest) + return + } + if addressRangeParsed.String() != vpnConfig.AddressRange.String() { + vpnConfig.AddressRange = addressRangeParsed + writeVPNConfig = true + rewriteClientConfigs = true + } + if setupRequest.ClientAddressPrefix != vpnConfig.ClientAddressPrefix { + vpnConfig.ClientAddressPrefix = setupRequest.ClientAddressPrefix + writeVPNConfig = true + rewriteClientConfigs = true + } + port, err := strconv.Atoi(setupRequest.Port) + if err != nil { + v.returnError(w, fmt.Errorf("port in wrong format: %s", err), http.StatusBadRequest) + return + } + if port != vpnConfig.Port { + vpnConfig.Port = port + writeVPNConfig = true + rewriteClientConfigs = true + } + + nameservers := strings.Split(setupRequest.Nameservers, ",") + for k := range nameservers { + nameservers[k] = strings.TrimSpace(nameservers[k]) + } + if !reflect.DeepEqual(nameservers, vpnConfig.Nameservers) { + vpnConfig.Nameservers = nameservers + writeVPNConfig = true + rewriteClientConfigs = true + } + if setupRequest.ExternalInterface != vpnConfig.ExternalInterface { // don't rewrite client config + vpnConfig.ExternalInterface = setupRequest.ExternalInterface + writeVPNConfig = true + } + if setupRequest.DisableNAT != vpnConfig.DisableNAT { // don't rewrite client config + vpnConfig.DisableNAT = setupRequest.DisableNAT + writeVPNConfig = true + } + if setupRequest.EnablePacketLogs != vpnConfig.EnablePacketLogs { + vpnConfig.EnablePacketLogs = setupRequest.EnablePacketLogs + writeVPNConfig = true + } + packetLogsRention, err := strconv.Atoi(setupRequest.PacketLogsRetention) + if err != nil || packetLogsRention < 1 { + v.returnError(w, fmt.Errorf("incorrect packet log retention. Enter a number of days the logs must be kept (minimum 1)"), http.StatusBadRequest) + return + } + if packetLogsRention != vpnConfig.PacketLogsRetention { + vpnConfig.PacketLogsRetention = packetLogsRention + writeVPNConfig = true + } + + // packetlogtypes + packetLogTypes := []string{} + for k, enabled := range vpnConfig.PacketLogsTypes { + if enabled { + packetLogTypes = append(packetLogTypes, k) + } + } + sort.Strings(setupRequest.PacketLogsTypes) + sort.Strings(packetLogTypes) + if !slices.Equal(setupRequest.PacketLogsTypes, packetLogTypes) { + vpnConfig.PacketLogsTypes = make(map[string]bool) + for _, v := range setupRequest.PacketLogsTypes { + if v == "http+https" || v == "dns" || v == "tcp" { + vpnConfig.PacketLogsTypes[v] = true + } + } + writeVPNConfig = true + } + + // write vpn config if config has changed + if writeVPNConfig { + err = wireguard.WriteVPNConfig(v.Storage, vpnConfig) + if err != nil { + v.returnError(w, fmt.Errorf("could write vpn config: %s", err), http.StatusBadRequest) + return + } + err = wireguard.ReloadVPNServerConfig() + if err != nil { + v.returnError(w, fmt.Errorf("unable to reload server config: %s", err), http.StatusBadRequest) + return + } + } + if rewriteClientConfigs { + // rewrite client configs + err = wireguard.UpdateClientsConfig(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest) + return + } + } + out, err := json.Marshal(setupRequest) + if err != nil { + v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + default: + v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (v *VPN) templateSetupHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + clientTemplate, err := wireguard.GetClientTemplate(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest) + return + } + serverTemplate, err := wireguard.GetServerTemplate(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest) + return + } + setupRequest := TemplateSetupRequest{ + ClientTemplate: string(clientTemplate), + ServerTemplate: string(serverTemplate), + } + out, err := json.Marshal(setupRequest) + if err != nil { + v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + case http.MethodPost: + var templateSetupRequest TemplateSetupRequest + decoder := json.NewDecoder(r.Body) + decoder.Decode(&templateSetupRequest) + clientTemplate, err := wireguard.GetClientTemplate(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest) + return + } + serverTemplate, err := wireguard.GetServerTemplate(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest) + return + } + if string(clientTemplate) != templateSetupRequest.ClientTemplate { + err = wireguard.WriteClientTemplate(v.Storage, []byte(templateSetupRequest.ClientTemplate)) + if err != nil { + v.returnError(w, fmt.Errorf("WriteClientTemplate error: %s", err), http.StatusBadRequest) + return + } + // rewrite client configs + err = wireguard.UpdateClientsConfig(v.Storage) + if err != nil { + v.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest) + return + } + } + if string(serverTemplate) != templateSetupRequest.ServerTemplate { + err = wireguard.WriteServerTemplate(v.Storage, []byte(templateSetupRequest.ServerTemplate)) + if err != nil { + v.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest) + return + } + } + out, err := json.Marshal(templateSetupRequest) + if err != nil { + v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + default: + v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (v *VPN) restartVPNHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + v.returnError(w, fmt.Errorf("unsupported method"), http.StatusBadRequest) + return + } + client := http.Client{ + Timeout: 10 * time.Second, + } + req, err := http.NewRequest(r.Method, "http://"+wireguard.CONFIGMANAGER_URI+"/restart-vpn", nil) + if err != nil { + v.returnError(w, fmt.Errorf("restart request error: %s", err), http.StatusBadRequest) + return + } + resp, err := client.Do(req) + if err != nil { + v.returnError(w, fmt.Errorf("restart error: %s", err), http.StatusBadRequest) + return + } + if resp.StatusCode != http.StatusAccepted { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + v.returnError(w, fmt.Errorf("restart error: got status code: %d. Response: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest) + return + } + v.returnError(w, fmt.Errorf("restart error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest) + return + } + + defer resp.Body.Close() + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + v.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest) + return + } + + v.write(w, bodyBytes) +} diff --git a/pkg/rest/stats.go b/pkg/vpn/stats.go similarity index 88% rename from pkg/rest/stats.go rename to pkg/vpn/stats.go index e7d6abc..2e47a1d 100644 --- a/pkg/rest/stats.go +++ b/pkg/vpn/stats.go @@ -1,4 +1,4 @@ -package rest +package vpn import ( "bufio" @@ -15,21 +15,21 @@ import ( "strings" "time" - "github.com/in4it/wireguard-server/pkg/storage" - dateutils "github.com/in4it/wireguard-server/pkg/utils/date" + "github.com/in4it/go-devops-platform/storage" + dateutils "github.com/in4it/go-devops-platform/utils/date" "github.com/in4it/wireguard-server/pkg/wireguard" ) const MAX_LOG_OUTPUT_LINES = 100 -func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { +func (v *VPN) userStatsHandler(w http.ResponseWriter, r *http.Request) { if r.PathValue("date") == "" { - c.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest) return } date, err := time.Parse("2006-01-02", r.PathValue("date")) if err != nil { - c.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) return } unitAdjustment := int64(1) @@ -49,7 +49,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { } } // get all users - users := c.UserStore.ListUsers() + users := v.UserStore.ListUsers() userMap := make(map[string]string) for _, user := range users { userMap[user.ID] = user.Login @@ -65,10 +65,10 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { } logData := bytes.NewBuffer([]byte{}) for _, statsFile := range statsFiles { - if c.Storage.Client.FileExists(statsFile) { - fileLogData, err := c.Storage.Client.ReadFile(statsFile) + if v.Storage.FileExists(statsFile) { + fileLogData, err := v.Storage.ReadFile(statsFile) if err != nil { - c.returnError(w, fmt.Errorf("readfile error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("readfile error: %s", err), http.StatusBadRequest) return } logData.Write(fileLogData) @@ -149,7 +149,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { } if err := scanner.Err(); err != nil { - c.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest) return } userStatsResponse.ReceiveBytes = UserStatsData{ @@ -210,39 +210,39 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { out, err := json.Marshal(userStatsResponse) if err != nil { - c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) return } - c.write(w, out) + v.write(w, out) } -func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) { - vpnConfig, err := wireguard.GetVPNConfig(c.Storage.Client) +func (v *VPN) packetLogsHandler(w http.ResponseWriter, r *http.Request) { + vpnConfig, err := wireguard.GetVPNConfig(v.Storage) if err != nil { - c.returnError(w, fmt.Errorf("get vpn config error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("get vpn config error: %s", err), http.StatusBadRequest) return } if !vpnConfig.EnablePacketLogs { // packet logs is disabled out, err := json.Marshal(LogDataResponse{Enabled: false}) if err != nil { - c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) return } - c.write(w, out) + v.write(w, out) return } userID := r.PathValue("user") if userID == "" { - c.returnError(w, fmt.Errorf("no user supplied"), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("no user supplied"), http.StatusBadRequest) return } if r.PathValue("date") == "" { - c.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest) return } date, err := time.Parse("2006-01-02", r.PathValue("date")) if err != nil { - c.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) return } offset := 0 @@ -261,7 +261,7 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) { } search := r.FormValue("search") // get all users - users := c.UserStore.ListUsers() + users := v.UserStore.ListUsers() userMap := make(map[string]string) for _, user := range users { userMap[user.ID] = user.Login @@ -292,14 +292,14 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) { if !dateutils.DateEqual(time.Now(), date) { // date is in local timezone, and we are UTC, so also read next file statsFiles = append(statsFiles, path.Join(wireguard.VPN_STATS_DIR, wireguard.VPN_PACKETLOGGER_DIR, userID+"-"+date.AddDate(0, 0, 1).Format("2006-01-02")+".log")) } - statsFiles, err = getCompressedFilesAndRemoveNonExistent(c.Storage.Client, statsFiles) + statsFiles, err = getCompressedFilesAndRemoveNonExistent(v.Storage, statsFiles) if err != nil { - c.returnError(w, fmt.Errorf("unable to get files for reading: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("unable to get files for reading: %s", err), http.StatusBadRequest) return } - fileReaders, err := c.Storage.Client.OpenFilesFromPos(statsFiles, pos) + fileReaders, err := v.Storage.OpenFilesFromPos(statsFiles, pos) if err != nil { - c.returnError(w, fmt.Errorf("error while reading files: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("error while reading files: %s", err), http.StatusBadRequest) return } for _, fileReader := range fileReaders { @@ -334,7 +334,7 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) { } } if err := scanner.Err(); err != nil { - c.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest) return } } @@ -355,10 +355,10 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) { out, err := json.Marshal(LogDataResponse{Enabled: true, LogData: logData, LogTypes: packetLogTypes, Users: userMap}) if err != nil { - c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) + v.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) return } - c.write(w, out) + v.write(w, out) } func getCompressedFilesAndRemoveNonExistent(storage storage.Iface, files []string) ([]string, error) { diff --git a/pkg/rest/stats_test.go b/pkg/vpn/stats_test.go similarity index 92% rename from pkg/rest/stats_test.go rename to pkg/vpn/stats_test.go index b484606..b04c0d0 100644 --- a/pkg/rest/stats_test.go +++ b/pkg/vpn/stats_test.go @@ -1,4 +1,4 @@ -package rest +package vpn import ( "bytes" @@ -11,7 +11,8 @@ import ( "testing" "time" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -19,10 +20,8 @@ func TestUserStatsHandler(t *testing.T) { storage := &memorystorage.MockMemoryStorage{} - c, err := newContext(storage, SERVER_TYPE_VPN) - if err != nil { - t.Fatalf("Cannot create context") - } + v := New(storage, &users.UserStore{}) + testData := `2024-08-23T19:29:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,12729136,24348520,2024-08-23T18:30:42 2024-08-23T19:34:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,13391716,25162108,2024-08-23T19:33:38 2024-08-23T19:39:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,14419152,27496068,2024-08-23T19:37:39 @@ -34,7 +33,7 @@ func TestUserStatsHandler(t *testing.T) { 2024-08-23T20:09:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,39928520,85171728,2024-08-23T20:08:54` statsFile := path.Join(wireguard.VPN_STATS_DIR, "user-"+time.Now().Format("2006-01-02")) + ".log" - err = c.Storage.Client.WriteFile(statsFile, []byte(strings.ReplaceAll(testData, "2024-08-23", time.Now().Format("2006-01-02")))) + err := v.Storage.WriteFile(statsFile, []byte(strings.ReplaceAll(testData, "2024-08-23", time.Now().Format("2006-01-02")))) if err != nil { t.Fatalf("Cannot write test file") } @@ -42,7 +41,7 @@ func TestUserStatsHandler(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/stats/user", nil) req.SetPathValue("date", time.Now().Format("2006-01-02")) w := httptest.NewRecorder() - c.userStatsHandler(w, req) + v.userStatsHandler(w, req) resp := w.Result() diff --git a/pkg/vpn/types.go b/pkg/vpn/types.go new file mode 100644 index 0000000..d70f370 --- /dev/null +++ b/pkg/vpn/types.go @@ -0,0 +1,89 @@ +package vpn + +import ( + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" +) + +type VPN struct { + Storage storage.Iface + UserStore *users.UserStore + Hostname string + Protocol string +} + +type NewConnectionResponse struct { + Name string `json:"name"` +} +type Connection struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type UserStatsResponse struct { + ReceiveBytes UserStatsData `json:"receivedBytes"` + TransmitBytes UserStatsData `json:"transmitBytes"` + Handshakes UserStatsData `json:"handshakes"` +} +type UserStatsData struct { + Datasets UserStatsDatasets `json:"datasets"` +} +type UserStatsDatasets []UserStatsDataset +type UserStatsDataset struct { + Label string `json:"label"` + Data []UserStatsDataPoint `json:"data"` + Fill bool `json:"fill"` + BorderColor string `json:"borderColor"` + BackgroundColor string `json:"backgroundColor"` + Tension float64 `json:"tension"` + ShowLine bool `json:"showLine"` +} + +type UserStatsDataPoint struct { + X string `json:"x"` + Y float64 `json:"y"` +} + +type LogDataResponse struct { + LogData LogData `json:"logData"` + Enabled bool `json:"enabled"` + LogTypes []string `json:"logTypes"` + Users map[string]string `json:"users"` +} + +type LogData struct { + Schema LogSchema `json:"schema"` + Data []LogRow `json:"rows"` + NextPos int64 `json:"nextPos"` +} +type LogSchema struct { + Columns map[string]string `json:"columns"` +} +type LogRow struct { + Timestamp string `json:"t"` + Data []string `json:"d"` +} + +type VPNSetupRequest struct { + Routes string `json:"routes"` + VPNEndpoint string `json:"vpnEndpoint"` + AddressRange string `json:"addressRange"` + ClientAddressPrefix string `json:"clientAddressPrefix"` + Port string `json:"port"` + ExternalInterface string `json:"externalInterface"` + Nameservers string `json:"nameservers"` + DisableNAT bool `json:"disableNAT"` + EnablePacketLogs bool `json:"enablePacketLogs"` + PacketLogsTypes []string `json:"packetLogsTypes"` + PacketLogsRetention string `json:"packetLogsRetention"` +} + +type TemplateSetupRequest struct { + ClientTemplate string `json:"clientTemplate"` + ServerTemplate string `json:"serverTemplate"` +} + +type ConnectionLicenseResponse struct { + LicenseUserCount int `json:"licenseUserCount"` + ConnectionCount int `json:"connectionCount"` +} diff --git a/pkg/rest/types_sort.go b/pkg/vpn/types_sort.go similarity index 94% rename from pkg/rest/types_sort.go rename to pkg/vpn/types_sort.go index ea7e9f7..34e5744 100644 --- a/pkg/rest/types_sort.go +++ b/pkg/vpn/types_sort.go @@ -1,4 +1,4 @@ -package rest +package vpn func (u UserStatsDatasets) Len() int { return len(u) diff --git a/pkg/vpn/version.go b/pkg/vpn/version.go new file mode 100644 index 0000000..3118ad5 --- /dev/null +++ b/pkg/vpn/version.go @@ -0,0 +1,27 @@ +package vpn + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +//go:generate cp -r ../../latest ./resources/version +//go:embed resources/version + +var version string + +func (v *VPN) version(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + out, err := json.Marshal(map[string]string{"version": strings.TrimSpace(version)}) + if err != nil { + v.returnError(w, fmt.Errorf("version marshal error: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + default: + v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} diff --git a/pkg/vpn/vpn.go b/pkg/vpn/vpn.go new file mode 100644 index 0000000..42f8bcb --- /dev/null +++ b/pkg/vpn/vpn.go @@ -0,0 +1,139 @@ +package vpn + +import ( + "encoding/json" + "fmt" + "net/http" + "path" + "strings" + "sync" + + "github.com/in4it/go-devops-platform/rest" + "github.com/in4it/go-devops-platform/users" + "github.com/in4it/wireguard-server/pkg/wireguard" +) + +var muClientDownload sync.Mutex + +func (v *VPN) connectionsHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + user := r.Context().Value(rest.CustomValue("user")).(users.User) + + clients, err := v.Storage.ReadDir(v.Storage.ConfigPath(wireguard.VPN_CLIENTS_DIR)) + if err != nil { + v.returnError(w, fmt.Errorf("cannot list connections for user: %s", err), http.StatusBadRequest) + return + } + + connectionList := []string{} + for _, clientFilename := range clients { + if wireguard.HasClientUserID(clientFilename, user.ID) { + connectionList = append(connectionList, clientFilename) + } + } + peerConfigs := make([]wireguard.PeerConfig, len(connectionList)) + for k, connection := range connectionList { + var peerConfig wireguard.PeerConfig + filename := v.Storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connection)) + toDeleteFileContents, err := v.Storage.ReadFile(filename) + if err != nil { + v.returnError(w, fmt.Errorf("can't read file %s: %s", filename, err), http.StatusBadRequest) + return + } + err = json.Unmarshal(toDeleteFileContents, &peerConfig) + if err != nil { + v.returnError(w, fmt.Errorf("can't unmarshal file %s: %s", filename, err), http.StatusBadRequest) + return + } + peerConfigs[k] = peerConfig + } + connections := make([]Connection, len(peerConfigs)) + for k := range peerConfigs { + connections[k] = Connection{ + ID: peerConfigs[k].ID, + Name: peerConfigs[k].Name, + } + } + out, err := json.Marshal(connections) + if err != nil { + v.returnError(w, fmt.Errorf("could not marshal list connection response: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + case http.MethodPost: + muClientDownload.Lock() + defer muClientDownload.Unlock() + user := r.Context().Value(rest.CustomValue("user")).(users.User) + peerConfig, err := wireguard.NewEmptyClientConfig(v.Storage, user.ID) + if err != nil { + v.returnError(w, fmt.Errorf("could not generate client vpn config: %s", err), http.StatusBadRequest) + return + } + newConnectionResponse := NewConnectionResponse{Name: peerConfig.Name} + out, err := json.Marshal(newConnectionResponse) + if err != nil { + v.returnError(w, fmt.Errorf("could not marshal new connection response: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + default: + v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} +func (v *VPN) connectionsElementHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + user := r.Context().Value(rest.CustomValue("user")).(users.User) + if !strings.HasPrefix(r.PathValue("id"), user.ID) { + v.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest) + return + } + if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") { + v.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest) + return + } + out, err := wireguard.GenerateNewClientConfig(v.Storage, r.PathValue("id"), user.ID) + if err != nil { + v.returnError(w, fmt.Errorf("GetClientConfig error: %s", err), http.StatusBadRequest) + return + } + v.write(w, out) + case http.MethodDelete: + user := r.Context().Value(rest.CustomValue("user")).(users.User) + if !strings.HasPrefix(r.PathValue("id"), user.ID) { + v.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest) + return + } + if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") { + v.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest) + return + } + err := wireguard.DeleteClientConfig(v.Storage, r.PathValue("id"), user.ID) + if err != nil { + v.returnError(w, fmt.Errorf("DeleteClientConfig error: %s", err), http.StatusBadRequest) + return + } + v.write(w, []byte(`{"deleted": "`+r.PathValue("id")+`"}`)) + + default: + v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (v *VPN) connectionLicenseHandler(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(rest.CustomValue("user")).(users.User) + licenseUserCount := r.Context().Value(rest.CustomValue("licenseUserCount")).(int) + totalConnections, err := wireguard.GetConfigNumbers(v.Storage, user.ID) + if err != nil { + v.returnError(w, fmt.Errorf("can't determine total connections: %s", err), http.StatusBadRequest) + return + + } + out, err := json.Marshal(ConnectionLicenseResponse{LicenseUserCount: licenseUserCount, ConnectionCount: len(totalConnections)}) + if err != nil { + v.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) + return + } + v.write(w, out) +} diff --git a/pkg/auth/provisioning/scim/users_test.go b/pkg/vpn/vpn_test.go similarity index 53% rename from pkg/auth/provisioning/scim/users_test.go rename to pkg/vpn/vpn_test.go index 2c399ac..013e8e9 100644 --- a/pkg/auth/provisioning/scim/users_test.go +++ b/pkg/vpn/vpn_test.go @@ -1,7 +1,8 @@ -package scim +package vpn import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -9,158 +10,19 @@ import ( "net/http" "net/http/httptest" "path" + "strings" "testing" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" - "github.com/in4it/wireguard-server/pkg/users" + "github.com/in4it/go-devops-platform/auth/provisioning/scim" + "github.com/in4it/go-devops-platform/rest" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) const USERSTORE_MAX_USERS = 1000 -func TestUsersGetCount100EmptyResult(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - - userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) - if err != nil { - t.Fatalf("cannot create new user store") - } - userStore.Empty() - if err != nil { - t.Fatalf("cannot empty user store") - } - - s := New(storage, userStore, "token") - req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?count=100&startIndex=1&", nil) - w := httptest.NewRecorder() - s.getUsersHandler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - response, err := listUserResponse([]users.User{}, "", 100, 1) - if err != nil { - t.Fatalf("userResponse error: %s", err) - } - if string(body) != string(response) { - t.Fatalf("expected empty input. Got %s\n", string(body)) - } -} - -func TestUsersGetCount10(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) - if err != nil { - t.Fatalf("cannot create new user store") - } - err = userStore.Empty() - if err != nil { - t.Fatalf("cannot empty user store") - } - totalUserCount := 150 - usersToCreate := make([]users.User, totalUserCount) - for i := 0; i < totalUserCount; i++ { - usersToCreate[i] = users.User{ - Login: fmt.Sprintf("user-%d@domain.inv", i), - } - } - users, err := userStore.AddUsers(usersToCreate) - if err != nil { - t.Fatalf("cannot create users: %s", err) - } - s := New(storage, userStore, "token") - req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?count=10&startIndex=1&", nil) - w := httptest.NewRecorder() - s.getUsersHandler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - response, err := listUserResponse(users, "", 10, 1) - if err != nil { - t.Fatalf("userResponse error: %s", err) - } - if string(body) != string(response) { - t.Fatalf("Unexpected output: Got: %s\nExpected: %s\n\n", string(body), string(response)) - } -} - -func TestUsersGetCount10Start5(t *testing.T) { - count := 10 - start := 5 - storage := &memorystorage.MockMemoryStorage{} - userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) - if err != nil { - t.Fatalf("cannot create new user store") - } - err = userStore.Empty() - if err != nil { - t.Fatalf("cannot empty user store") - } - totalUserCount := 150 - usersToCreate := make([]users.User, totalUserCount) - for i := 0; i < totalUserCount; i++ { - usersToCreate[i] = users.User{ - Login: fmt.Sprintf("user-%d@domain.inv", i), - } - } - users, err := userStore.AddUsers(usersToCreate) - if err != nil { - t.Fatalf("cannot create users: %s", err) - } - s := New(storage, userStore, "token") - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/api/scim/v2/Users?count=%d&startIndex=%d&", count, start), nil) - w := httptest.NewRecorder() - s.getUsersHandler(w, req) - - resp := w.Result() - - var userResponse UserResponse - err = json.NewDecoder(resp.Body).Decode(&userResponse) - if err != nil { - t.Fatalf("Could not decode output: %s", err) - } - if userResponse.TotalResults != totalUserCount-start { - t.Fatalf("Wrong user count: %d", userResponse.TotalResults) - } - if userResponse.ItemsPerPage != count { - t.Fatalf("Wrong page count: %d", userResponse.TotalResults) - } - if userResponse.StartIndex != start { - t.Fatalf("Wrong user start: %d", userResponse.StartIndex) - } - if len(userResponse.Resources) != count { - t.Fatalf("Wrong response count: %d", len(userResponse.Resources)) - } - if userResponse.Resources[0].UserName != users[5].Login { - t.Fatalf("Wrong first login: %s (actual) vs %s (expected)", userResponse.Resources[0].UserName, users[5].Login) - } -} - -func TestUsersGetNonExistentUser(t *testing.T) { - userStore, err := users.NewUserStore(&memorystorage.MockMemoryStorage{}, USERSTORE_MAX_USERS) - if err != nil { - t.Fatalf("cannot create new user stoer") - } - - s := New(&memorystorage.MockMemoryStorage{}, userStore, "token") - req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?filter=userName+eq+%22ward%40in4it.io%22&", nil) - w := httptest.NewRecorder() - s.getUsersHandler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - response, err := listUserResponse([]users.User{}, "", -1, -1) - if err != nil { - t.Fatalf("userResponse error: %s", err) - } - if string(body) != string(response) { - t.Fatalf("expected empty input. Got %s\n", string(body)) - } -} - -func TestAddUser(t *testing.T) { +func TestSCIMCreateUserConnectionDeleteUserFlow(t *testing.T) { storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { @@ -170,54 +32,7 @@ func TestAddUser(t *testing.T) { if err != nil { t.Fatalf("cannot empty user store") } - s := New(storage, userStore, "token") - payload := PostUserRequest{ - UserName: "john@domain.inv", - Name: Name{ - GivenName: "John", - FamilyName: "Doe", - }, - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("cannot marshal payload: %s", err) - } - req := httptest.NewRequest("POST", "http://example.com/api/scim/v2/Users?", bytes.NewBuffer(payloadBytes)) - w := httptest.NewRecorder() - s.postUsersHandler(w, req) - - resp := w.Result() - - if resp.StatusCode != 201 { - t.Fatalf("User not added. StatusCode: %d", resp.StatusCode) - } - - var postUserRequest PostUserRequest - err = json.NewDecoder(resp.Body).Decode(&postUserRequest) - if err != nil { - t.Fatalf("Could not decode output: %s", err) - } - - if postUserRequest.Id == "" { - t.Fatalf("id is empty: %s", err) - } - if postUserRequest.UserName != payload.UserName { - t.Fatalf("username mismatch: %s (actual) vs %s (expected)", postUserRequest.UserName, payload.UserName) - } - -} - -func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { - storage := &memorystorage.MockMemoryStorage{} - userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) - if err != nil { - t.Fatalf("cannot create new user store: %s", err) - } - userStore.Empty() - if err != nil { - t.Fatalf("cannot empty user store") - } - s := New(storage, userStore, "token") + s := scim.New(storage, userStore, "token", nil, nil) l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI) if err != nil { @@ -250,9 +65,9 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { defer l.Close() // create a user - payload := PostUserRequest{ + payload := scim.PostUserRequest{ UserName: "john@domain.inv", - Name: Name{ + Name: scim.Name{ GivenName: "John", FamilyName: "Doe", }, @@ -263,7 +78,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { } req := httptest.NewRequest("POST", "http://example.com/api/scim/v2/Users?", bytes.NewBuffer(payloadBytes)) w := httptest.NewRecorder() - s.postUsersHandler(w, req) + s.PostUsersHandler(w, req) resp := w.Result() @@ -273,7 +88,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { defer resp.Body.Close() - var postUserRequest PostUserRequest + var postUserRequest scim.PostUserRequest err = json.NewDecoder(resp.Body).Decode(&postUserRequest) if err != nil { t.Fatalf("Could not decode output: %s", err) @@ -314,7 +129,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { req = httptest.NewRequest("DELETE", "http://example.com/api/scim/v2/Users/"+user.ID, nil) req.SetPathValue("id", user.ID) w = httptest.NewRecorder() - s.deleteUserHandler(w, req) + s.DeleteUserHandler(w, req) resp = w.Result() @@ -338,7 +153,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { if err != nil { t.Fatalf("cannot empty user store") } - s := New(storage, userStore, "token") + s := scim.New(storage, userStore, "token", nil, nil) l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI) if err != nil { @@ -371,9 +186,9 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { defer l.Close() // create a user - payload := PostUserRequest{ + payload := scim.PostUserRequest{ UserName: "john@domain.inv", - Name: Name{ + Name: scim.Name{ GivenName: "John", FamilyName: "Doe", }, @@ -384,7 +199,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { } req := httptest.NewRequest("POST", "http://example.com/api/scim/v2/Users?", bytes.NewBuffer(payloadBytes)) w := httptest.NewRecorder() - s.postUsersHandler(w, req) + s.PostUsersHandler(w, req) resp := w.Result() @@ -394,7 +209,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { defer resp.Body.Close() - var postUserRequest PostUserRequest + var postUserRequest scim.PostUserRequest err = json.NewDecoder(resp.Body).Decode(&postUserRequest) if err != nil { t.Fatalf("Could not decode output: %s", err) @@ -444,7 +259,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { req = httptest.NewRequest("PUT", "http://example.com/api/scim/v2/Users/"+user.ID, bytes.NewBuffer(payloadBytes)) req.SetPathValue("id", user.ID) w = httptest.NewRecorder() - s.putUserHandler(w, req) + s.PutUserHandler(w, req) resp = w.Result() @@ -466,3 +281,119 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { t.Fatalf("VPN connection is enabled. Expected disabled") } } + +func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { + l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI) + if err != nil { + t.Fatal(err) + } + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + if r.RequestURI == "/refresh-clients" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + // first create a new user + storage := &memorystorage.MockMemoryStorage{} + + v := New(storage, &users.UserStore{}) + + err = v.UserStore.Empty() + if err != nil { + t.Fatalf("Cannot create context") + } + + // create a user + userToCreate := users.User{ + Login: "john", + Role: "user", + Password: "xyz", + } + user, err := v.UserStore.AddUser(userToCreate) + if err != nil { + t.Fatalf("user creation error: %s", err) + } + + // generate VPN config + _, err = wireguard.CreateNewVPNConfig(v.Storage) + if err != nil { + t.Fatalf("Cannot create vpn config: %s", err) + } + + req := httptest.NewRequest("POST", "http://example.com/connections", nil) + w := httptest.NewRecorder() + v.connectionsHandler(w, req.WithContext(context.WithValue(context.Background(), rest.CustomValue("user"), user))) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + connectionID := fmt.Sprintf("%s-1", user.ID) + + userConfigFilename := storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connectionID+".json")) + configBytes, err := storage.ReadFile(userConfigFilename) + if err != nil { + t.Fatalf("could not read user config file") + } + + var config wireguard.PeerConfig + err = json.Unmarshal(configBytes, &config) + if err != nil { + t.Fatalf("could not parse config: %s", err) + } + if config.Disabled { + t.Fatalf("VPN connection is disabled. Expected not disabled") + } + + req = httptest.NewRequest("GET", "http://example.com/connection/"+connectionID, nil) + req.SetPathValue("id", connectionID) + w = httptest.NewRecorder() + v.connectionsElementHandler(w, req.WithContext(context.WithValue(context.Background(), rest.CustomValue("user"), user))) + + resp = w.Result() + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("readall error: %s", err) + } + if !strings.Contains(string(body), "[Interface]") { + t.Fatalf("output doesn't look like a wireguard client config: %s", body) + } + + err = v.UserStore.DeleteUserByID(user.ID) + if err != nil { + t.Fatalf("user deletion error: %s", err) + } + + _, err = storage.ReadFile(userConfigFilename) + if err == nil { + t.Fatalf("could read user config file, expected not to") + } +} diff --git a/pkg/wireguard/ip.go b/pkg/wireguard/ip.go index 3def8ad..6d3101d 100644 --- a/pkg/wireguard/ip.go +++ b/pkg/wireguard/ip.go @@ -8,7 +8,7 @@ import ( "path" "strings" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" ) func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix, addressPrefix string) (net.IP, error) { diff --git a/pkg/wireguard/linux/syncclients/cleanup.go b/pkg/wireguard/linux/syncclients/cleanup.go index 85a14db..c877d1b 100644 --- a/pkg/wireguard/linux/syncclients/cleanup.go +++ b/pkg/wireguard/linux/syncclients/cleanup.go @@ -8,7 +8,7 @@ import ( "fmt" "path" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux" ) diff --git a/pkg/wireguard/linux/syncclients/process.go b/pkg/wireguard/linux/syncclients/process.go index 0da83f2..8e5327f 100644 --- a/pkg/wireguard/linux/syncclients/process.go +++ b/pkg/wireguard/linux/syncclients/process.go @@ -7,7 +7,7 @@ import ( "fmt" "log" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux" ) diff --git a/pkg/wireguard/linux/syncclients/server.go b/pkg/wireguard/linux/syncclients/server.go index 4675079..42082a8 100644 --- a/pkg/wireguard/linux/syncclients/server.go +++ b/pkg/wireguard/linux/syncclients/server.go @@ -8,7 +8,7 @@ import ( "path" "strings" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux" ) diff --git a/pkg/wireguard/linux/syncclients/wg.go b/pkg/wireguard/linux/syncclients/wg.go index bc36826..567fedf 100644 --- a/pkg/wireguard/linux/syncclients/wg.go +++ b/pkg/wireguard/linux/syncclients/wg.go @@ -9,7 +9,7 @@ import ( "path" "strings" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard" ) diff --git a/pkg/wireguard/packetlogger.go b/pkg/wireguard/packetlogger.go index 978acfd..b40069b 100644 --- a/pkg/wireguard/packetlogger.go +++ b/pkg/wireguard/packetlogger.go @@ -19,9 +19,9 @@ import ( "github.com/gopacket/gopacket" "github.com/gopacket/gopacket/layers" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" - dateutils "github.com/in4it/wireguard-server/pkg/utils/date" + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/storage" + dateutils "github.com/in4it/go-devops-platform/utils/date" "github.com/packetcap/go-pcap" "golang.org/x/sys/unix" ) diff --git a/pkg/wireguard/packetlogger_test.go b/pkg/wireguard/packetlogger_test.go index a89c047..62267e3 100644 --- a/pkg/wireguard/packetlogger_test.go +++ b/pkg/wireguard/packetlogger_test.go @@ -13,9 +13,9 @@ import ( "testing" "time" - localstorage "github.com/in4it/wireguard-server/pkg/storage/local" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" - dateutils "github.com/in4it/wireguard-server/pkg/utils/date" + localstorage "github.com/in4it/go-devops-platform/storage/local" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + dateutils "github.com/in4it/go-devops-platform/utils/date" ) func TestParsePacket(t *testing.T) { diff --git a/pkg/wireguard/stats_linux.go b/pkg/wireguard/stats_linux.go index e4c0bfc..7963741 100644 --- a/pkg/wireguard/stats_linux.go +++ b/pkg/wireguard/stats_linux.go @@ -11,8 +11,8 @@ import ( "strings" "time" - "github.com/in4it/wireguard-server/pkg/logging" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard/linux/stats" ) diff --git a/pkg/wireguard/vpnconfig.go b/pkg/wireguard/vpnconfig.go index 8b11516..21ce0f6 100644 --- a/pkg/wireguard/vpnconfig.go +++ b/pkg/wireguard/vpnconfig.go @@ -15,7 +15,7 @@ import ( "sync" "time" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" "github.com/in4it/wireguard-server/pkg/wireguard/network" ) diff --git a/pkg/wireguard/wireguardclientconfig.go b/pkg/wireguard/wireguardclientconfig.go index e65d77c..c6831a8 100644 --- a/pkg/wireguard/wireguardclientconfig.go +++ b/pkg/wireguard/wireguardclientconfig.go @@ -15,7 +15,8 @@ import ( "text/template" "time" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" ) var clientConfigMutex sync.Mutex @@ -377,7 +378,7 @@ func DeleteClientConfig(storage storage.Iface, connectionID, userID string) erro } return nil } -func DisableAllClientConfigs(storage storage.Iface, userID string) error { +func DisableAllClientConfigs(storage storage.Iface, user users.User) error { clientConfigMutex.Lock() defer clientConfigMutex.Unlock() clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR)) @@ -387,7 +388,7 @@ func DisableAllClientConfigs(storage storage.Iface, userID string) error { toDelete := []string{} for _, clientFilename := range clients { - if HasClientUserID(clientFilename, userID) { + if HasClientUserID(clientFilename, user.ID) { toDelete = append(toDelete, clientFilename) } } @@ -438,7 +439,7 @@ func DisableAllClientConfigs(storage storage.Iface, userID string) error { } return nil } -func ReactivateAllClientConfigs(storage storage.Iface, userID string) error { +func ReactivateAllClientConfigs(storage storage.Iface, user users.User) error { clientConfigMutex.Lock() defer clientConfigMutex.Unlock() clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR)) @@ -448,7 +449,7 @@ func ReactivateAllClientConfigs(storage storage.Iface, userID string) error { toAdd := []string{} for _, clientFilename := range clients { - if HasClientUserID(clientFilename, userID) { + if HasClientUserID(clientFilename, user.ID) { toAdd = append(toAdd, clientFilename) } } diff --git a/pkg/wireguard/wireguardclientconfig_test.go b/pkg/wireguard/wireguardclientconfig_test.go index 3a33628..2489eca 100644 --- a/pkg/wireguard/wireguardclientconfig_test.go +++ b/pkg/wireguard/wireguardclientconfig_test.go @@ -12,7 +12,8 @@ import ( "testing" "time" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" ) func TestGetNextFreeIPFromList(t *testing.T) { @@ -482,7 +483,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Errorf("Peer config is disabled") } - err = DisableAllClientConfigs(storage, "2-2-2-2") + err = DisableAllClientConfigs(storage, users.User{ID: "2-2-2-2"}) if err != nil { t.Fatalf("DisableAllClientConfigs error: %s", err) } @@ -504,7 +505,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Errorf("peer config not disabled") } - err = ReactivateAllClientConfigs(storage, "2-2-2-2") + err = ReactivateAllClientConfigs(storage, users.User{ID: "2-2-2-2"}) if err != nil { t.Fatalf("DisableAllClientConfigs error: %s", err) } diff --git a/pkg/wireguard/wireguardserverconfig.go b/pkg/wireguard/wireguardserverconfig.go index be0155a..5b1f354 100644 --- a/pkg/wireguard/wireguardserverconfig.go +++ b/pkg/wireguard/wireguardserverconfig.go @@ -7,7 +7,7 @@ import ( "path" "text/template" - "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/go-devops-platform/storage" ) func WriteWireGuardServerConfig(storage storage.Iface) error { diff --git a/pkg/wireguard/wireguardserverconfig_test.go b/pkg/wireguard/wireguardserverconfig_test.go index 3552bf6..dcde1ab 100644 --- a/pkg/wireguard/wireguardserverconfig_test.go +++ b/pkg/wireguard/wireguardserverconfig_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" ) func TestWriteWireGuardServerConfig(t *testing.T) {