diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 6cadf56..50bf310 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -18,6 +18,6 @@ jobs:
- name: Go generate
run: |
go generate ./...
- touch pkg/rest/static/placeholder # we're not building frontend, so we put a placeholder
+ touch cmd/rest-server/static/placeholder # we're not building frontend, so we put a placeholder
- name: Test
run: make test
diff --git a/cmd/cloudwatch-ingestion/main.go b/cmd/cloudwatch-ingestion/main.go
deleted file mode 100644
index 78798e5..0000000
--- a/cmd/cloudwatch-ingestion/main.go
+++ /dev/null
@@ -1,142 +0,0 @@
-package main
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "flag"
- "fmt"
- "log"
- "net/http"
- "os"
- "strings"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/config"
- "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
- "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
-)
-
-func main() {
- var (
- logGroupName string
- )
- flag.StringVar(&logGroupName, "log-group", "", "log group to ingest")
- flag.Parse()
-
- if logGroupName == "" {
- flag.PrintDefaults()
- os.Exit(1)
- }
-
- cfg, err := config.LoadDefaultConfig(context.TODO())
- if err != nil {
- log.Fatalf("unable to load SDK config, %v", err)
- }
-
- svc := cloudwatchlogs.NewFromConfig(cfg)
-
- logStreams, err := fetchLogStreams(svc, logGroupName)
- if err != nil {
- log.Fatalf("failed to fetch log streams: %v", err)
- }
-
- for _, stream := range logStreams {
- err := fetchLogEvents(svc, logGroupName, *stream.LogStreamName)
- if err != nil {
- log.Printf("failed to fetch log events for stream %s: %v", *stream.LogStreamName, err)
- return
- }
- }
-}
-
-func fetchLogStreams(svc *cloudwatchlogs.Client, logGroupName string) ([]types.LogStream, error) {
- var allStreams []types.LogStream
- nextToken := ""
-
- for {
- input := &cloudwatchlogs.DescribeLogStreamsInput{
- LogGroupName: aws.String(logGroupName),
- }
-
- if nextToken != "" {
- input.NextToken = aws.String(nextToken)
- }
-
- result, err := svc.DescribeLogStreams(context.TODO(), input)
- if err != nil {
- return nil, err
- }
-
- allStreams = append(allStreams, result.LogStreams...)
-
- if result.NextToken == nil {
- break
- }
-
- nextToken = *result.NextToken
- }
-
- return allStreams, nil
-}
-
-func fetchLogEvents(svc *cloudwatchlogs.Client, logGroupName, logStreamName string) error {
- nextToken := ""
- messages := []map[string]any{}
- logStreamNameSplit := strings.Split(logStreamName, "/")
- logStreamWithoutRandom := strings.Join(logStreamNameSplit[:len(logStreamNameSplit)-1], "/")
-
- for {
- input := &cloudwatchlogs.GetLogEventsInput{
- LogGroupName: aws.String(logGroupName),
- LogStreamName: aws.String(logStreamName),
- StartFromHead: aws.Bool(true),
- }
-
- if nextToken != "" {
- input.NextToken = aws.String(nextToken)
- }
-
- result, err := svc.GetLogEvents(context.TODO(), input)
- if err != nil {
- return err
- }
-
- for _, event := range result.Events {
- seconds := float64(*event.Timestamp / 1000)
- microseconds := float64(*event.Timestamp%1000) * 1000
- messages = append(messages, map[string]any{
- "date": seconds + (microseconds / 1e6),
- "log": *event.Message,
- "log-group": logGroupName,
- "log-stream": logStreamWithoutRandom,
- })
- }
-
- if result.NextForwardToken == nil || nextToken == *result.NextForwardToken {
- break
- }
-
- nextToken = *result.NextForwardToken
- }
-
- if len(messages) == 0 {
- return nil
- }
-
- out, err := json.Marshal(messages)
- if err != nil {
- return err
- }
- resp, err := http.Post("http://localhost/api/observability/ingestion/json", "image/jpeg", bytes.NewBuffer(out))
- if err != nil {
- return err
- }
- if resp.StatusCode != 200 {
- return fmt.Errorf("response code is not 200")
- }
-
- fmt.Printf("Ingested log-group %s, stream %s: %d messages\n", logGroupName, logStreamWithoutRandom, len(messages))
-
- return nil
-}
diff --git a/cmd/observability-rest-server/main.go b/cmd/observability-rest-server/main.go
deleted file mode 100644
index 20dc3d5..0000000
--- a/cmd/observability-rest-server/main.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package main
-
-import (
- "flag"
-
- "github.com/in4it/wireguard-server/pkg/rest"
-)
-
-func main() {
- var (
- httpPort int
- httpsPort int
- )
- flag.IntVar(&httpPort, "http-port", 80, "http port to run server on")
- flag.IntVar(&httpsPort, "https-port", 443, "https port to run server on")
- flag.Parse()
- rest.StartServer(httpPort, httpsPort, rest.SERVER_TYPE_OBSERVABILITY)
-}
diff --git a/cmd/reset-admin-password/main.go b/cmd/reset-admin-password/main.go
index 292ba27..cf9a877 100644
--- a/cmd/reset-admin-password/main.go
+++ b/cmd/reset-admin-password/main.go
@@ -10,8 +10,8 @@ import (
"strings"
"syscall"
+ localstorage "github.com/in4it/go-devops-platform/storage/local"
"github.com/in4it/wireguard-server/pkg/commands"
- localstorage "github.com/in4it/wireguard-server/pkg/storage/local"
"golang.org/x/term"
)
diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go
index b7a7f22..bce3055 100644
--- a/cmd/rest-server/main.go
+++ b/cmd/rest-server/main.go
@@ -1,9 +1,22 @@
package main
import (
+ "embed"
"flag"
+ "log"
- "github.com/in4it/wireguard-server/pkg/rest"
+ "github.com/in4it/go-devops-platform/auth/provisioning/scim"
+ "github.com/in4it/go-devops-platform/licensing"
+ "github.com/in4it/go-devops-platform/rest"
+ localstorage "github.com/in4it/go-devops-platform/storage/local"
+ "github.com/in4it/go-devops-platform/users"
+ "github.com/in4it/wireguard-server/pkg/vpn"
+ "github.com/in4it/wireguard-server/pkg/wireguard"
+)
+
+var (
+ //go:embed static
+ assets embed.FS
)
func main() {
@@ -14,5 +27,32 @@ func main() {
flag.IntVar(&httpPort, "http-port", 80, "http port to run server on")
flag.IntVar(&httpsPort, "https-port", 443, "https port to run server on")
flag.Parse()
- rest.StartServer(httpPort, httpsPort, rest.SERVER_TYPE_VPN)
+
+ localStorage, err := localstorage.New()
+ if err != nil {
+ log.Fatalf("couldn't initialize storage: %s", err)
+ }
+ licenseUserCount, cloudType := licensing.GetMaxUsers(localStorage)
+
+ userStore, err := users.NewUserStoreWithHooks(localStorage, licenseUserCount, users.UserHooks{
+ DisableFunc: wireguard.DisableAllClientConfigs,
+ DeleteFunc: wireguard.DeleteAllClientConfigs,
+ ReactivateFunc: wireguard.ReactivateAllClientConfigs,
+ })
+ if err != nil {
+ log.Fatalf("startup failed: userstore initialization error: %s", err)
+ }
+
+ scimInstance := scim.New(localStorage, userStore, "")
+
+ apps := map[string]rest.AppClient{
+ "vpn": vpn.New(localStorage, userStore),
+ }
+
+ c, err := rest.NewContext(localStorage, rest.SERVER_TYPE_VPN, userStore, scimInstance, licenseUserCount, cloudType, apps)
+ if err != nil {
+ log.Fatalf("startup failed: %s", err)
+ }
+
+ rest.StartServer(httpPort, httpsPort, rest.SERVER_TYPE_VPN, localStorage, c, assets)
}
diff --git a/pkg/rest/resources/.gitignore b/cmd/rest-server/static/.gitignore
similarity index 100%
rename from pkg/rest/resources/.gitignore
rename to cmd/rest-server/static/.gitignore
diff --git a/go.mod b/go.mod
index 20188fc..64b81b5 100644
--- a/go.mod
+++ b/go.mod
@@ -12,31 +12,33 @@ require (
github.com/packetcap/go-pcap v0.0.0-20240528124601-8c87ecf5dbc5
github.com/russellhaering/gosaml2 v0.9.1
github.com/russellhaering/goxmldsig v1.4.0
- golang.org/x/crypto v0.27.0
- golang.org/x/sys v0.25.0
- golang.org/x/term v0.24.0
+ golang.org/x/crypto v0.28.0
+ golang.org/x/sys v0.26.0
+ golang.org/x/term v0.25.0
)
+require github.com/in4it/go-devops-platform v0.0.0-20241015191315-e2f711a32e69 // indirect
+
require (
- github.com/aws/aws-sdk-go-v2 v1.31.0 // indirect
- github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 // indirect
- github.com/aws/aws-sdk-go-v2/config v1.27.33 // indirect
- github.com/aws/aws-sdk-go-v2/credentials v1.17.32 // indirect
- github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.13 // indirect
- github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 // indirect
- github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 // indirect
+ github.com/aws/aws-sdk-go-v2 v1.32.2 // indirect
+ github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect
+ github.com/aws/aws-sdk-go-v2/config v1.27.43 // indirect
+ github.com/aws/aws-sdk-go-v2/credentials v1.17.41 // indirect
+ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
- github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.17 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect
github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.2
- github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect
- github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.19 // indirect
- github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.19 // indirect
- github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.17 // indirect
- github.com/aws/aws-sdk-go-v2/service/s3 v1.61.2 // indirect
- github.com/aws/aws-sdk-go-v2/service/sso v1.22.7 // indirect
- github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.7 // indirect
- github.com/aws/aws-sdk-go-v2/service/sts v1.30.7 // indirect
- github.com/aws/smithy-go v1.21.0 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 // indirect
+ github.com/aws/smithy-go v1.22.0 // indirect
github.com/beevik/etree v1.4.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
@@ -45,7 +47,7 @@ require (
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
- golang.org/x/net v0.29.0 // indirect
+ golang.org/x/net v0.30.0 // indirect
golang.org/x/sync v0.8.0 // indirect
- golang.org/x/text v0.18.0 // indirect
+ golang.org/x/text v0.19.0 // indirect
)
diff --git a/go.sum b/go.sum
index b4ceaad..d3f98d7 100644
--- a/go.sum
+++ b/go.sum
@@ -2,50 +2,67 @@ github.com/aws/aws-sdk-go-v2 v1.30.5 h1:mWSRTwQAb0aLE17dSzztCVJWI9+cRMgqebndjwDy
github.com/aws/aws-sdk-go-v2 v1.30.5/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0=
github.com/aws/aws-sdk-go-v2 v1.31.0 h1:3V05LbxTSItI5kUqNwhJrrrY1BAXxXt0sN0l72QmG5U=
github.com/aws/aws-sdk-go-v2 v1.31.0/go.mod h1:ztolYtaEUtdpf9Wftr31CJfLVjOnD/CVRkKOOYgF8hA=
+github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4/go.mod h1:/MQxMqci8tlqDH+pjmoLu1i0tbWCUP1hhyMRuFxpQCw=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 h1:xDAuZTn4IMm8o1LnBZvmrL8JA1io4o3YWNXgohbf20g=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5/go.mod h1:wYSv6iDS621sEFLfKvpPE2ugjTuGlAG7iROg0hLOkfc=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA=
github.com/aws/aws-sdk-go-v2/config v1.27.33 h1:Nof9o/MsmH4oa0s2q9a0k7tMz5x/Yj5k06lDODWz3BU=
github.com/aws/aws-sdk-go-v2/config v1.27.33/go.mod h1:kEqdYzRb8dd8Sy2pOdEbExTTF5v7ozEXX0McgPE7xks=
+github.com/aws/aws-sdk-go-v2/config v1.27.43/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc=
github.com/aws/aws-sdk-go-v2/credentials v1.17.32 h1:7Cxhp/BnT2RcGy4VisJ9miUPecY+lyE9I8JvcZofn9I=
github.com/aws/aws-sdk-go-v2/credentials v1.17.32/go.mod h1:P5/QMF3/DCHbXGEGkdbilXHsyTBX5D3HSwcrSc9p20I=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.13 h1:pfQ2sqNpMVK6xz2RbqLEL0GH87JOwSxPV2rzm8Zsb74=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.13/go.mod h1:NG7RXPUlqfsCLLFfi0+IpKN4sCB9D9fw/qTaSB+xRoU=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 h1:pI7Bzt0BJtYA0N/JEC6B8fJ4RBrEMi1LBrkMdFYNSnQ=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17/go.mod h1:Dh5zzJYMtxfIjYW+/evjQ8uj2OyR/ve2KROHGHlSFqE=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 h1:kYQ3H1u0ANr9KEKlGs/jTLrBFPo8P8NaH/w7A01NeeM=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18/go.mod h1:r506HmK5JDUh9+Mw4CfGJGSSoqIiLCndAuqXuhbv67Y=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 h1:Mqr/V5gvrhA2gvgnF42Zh5iMiQNcOYthFYwCyrnuWlc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17/go.mod h1:aLJpZlCmjE+V+KtN1q1uyZkfnUWpQGpbsn89XPKyzfU=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 h1:Z7IdFUONvTcvS7YuhtVxN99v2cCoHRXOS4mTr0B/pUc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18/go.mod h1:DkKMmksZVVyat+Y+r1dEOgJEfUeA7UngIHWeKsi0yNc=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21/go.mod h1:1SR0GbLlnN3QUmYaflZNiH1ql+1qrSiB2vwcJ+4UM60=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.17 h1:Roo69qTpfu8OlJ2Tb7pAYVuF0CpuUMB0IYWwYP/4DZM=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.17/go.mod h1:NcWPxQzGM1USQggaTVwz6VpqMZPX1CvDJLDh6jnOCa4=
+github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21/go.mod h1:Q9o5h4HoIWG8XfzxqiuK/CGUbepCJ8uTlaE3bAbxytQ=
github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.2 h1:q5+hHt4JBA+8K6uAvfLWpUs7ErVR0GNW0Xf5KTOl84c=
github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.2/go.mod h1:3p7NzlLlJesNGovq7Vqx8+0UibawzodrBRQAbaza6pI=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.19 h1:FLMkfEiRjhgeDTCjjLoc3URo/TBkgeQbocA78lfkzSI=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.19/go.mod h1:Vx+GucNSsdhaxs3aZIKfSUjKVGsxN25nX2SRcdhuw08=
+github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.19 h1:rfprUlsdzgl7ZL2KlXiUAoJnI/VxfHCvDFr2QDFj6u4=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.19/go.mod h1:SCWkEdRq8/7EK60NcvvQ6NXKuTcchAD4ROAsC37VEZE=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2/go.mod h1:fnjjWyAW/Pj5HYOxl9LJqWtEwS7W2qgcRLWP+uWbss0=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.17 h1:u+EfGmksnJc/x5tq3A+OD7LrMbSSR/5TrKLvkdy/fhY=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.17/go.mod h1:VaMx6302JHax2vHJWgRo+5n9zvbacs3bLU/23DNQrTY=
+github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM=
github.com/aws/aws-sdk-go-v2/service/s3 v1.61.2 h1:Kp6PWAlXwP1UvIflkIP6MFZYBNDCa4mFCGtxrpICVOg=
github.com/aws/aws-sdk-go-v2/service/s3 v1.61.2/go.mod h1:5FmD/Dqq57gP+XwaUnd5WFPipAuzrf0HmupX27Gvjvc=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8=
github.com/aws/aws-sdk-go-v2/service/sso v1.22.7 h1:pIaGg+08llrP7Q5aiz9ICWbY8cqhTkyy+0SHvfzQpTc=
github.com/aws/aws-sdk-go-v2/service/sso v1.22.7/go.mod h1:eEygMHnTKH/3kNp9Jr1n3PdejuSNcgwLe1dWgQtO0VQ=
+github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.7 h1:/Cfdu0XV3mONYKaOt1Gr0k1KvQzkzPyiKUdlWJqy+J4=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.7/go.mod h1:bCbAxKDqNvkHxRaIMnyVPXPo+OaPRwvmgzMxbz1VKSA=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI=
github.com/aws/aws-sdk-go-v2/service/sts v1.30.7 h1:NKTa1eqZYw8tiHSRGpP0VtTdub/8KNk8sDkNPFaOKDE=
github.com/aws/aws-sdk-go-v2/service/sts v1.30.7/go.mod h1:NXi1dIAGteSaRLqYgarlhP/Ij0cFT+qmCwiJqWh/U5o=
+github.com/aws/aws-sdk-go-v2/service/sts v1.32.2/go.mod h1:HtaiBI8CjYoNVde8arShXb94UbQQi9L4EMr6D+xGBwo=
github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA=
github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
+github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=
github.com/beevik/etree v1.4.0 h1:oz1UedHRepuY3p4N5OjE0nK1WLCqtzHf25bxplKOHLs=
github.com/beevik/etree v1.4.0/go.mod h1:cyWiXwGoasx60gHvtnEh5x8+uIjUVnjWqBvEnhnqKDA=
@@ -67,6 +84,30 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopacket/gopacket v1.3.0 h1:MouZCc+ej0vnqzB0WeiaO/6+tGvb+KU7UczxoQ+X0Yc=
github.com/gopacket/gopacket v1.3.0/go.mod h1:WnFrU1Xkf5lWKV38uKNR9+yYtppn+ZYzOyNqMeH4oNE=
+github.com/in4it/go-devops-platform v0.0.0-20241015140612-1949bc7fd57a h1:annNJWLpGHM4C6bvJKYC/eU9G9mF4IHBjHftozKGl1g=
+github.com/in4it/go-devops-platform v0.0.0-20241015140612-1949bc7fd57a/go.mod h1:NXjJaxWbmL3mCCKcfG/MgFRimD9Bw3GO6PyQrY7KpRk=
+github.com/in4it/go-devops-platform v0.0.0-20241015143916-a95e92939b2d h1:CHqTnGJpP9D9K8IcwH/3Y6iDdwj2+hrNIld3xyEZeF8=
+github.com/in4it/go-devops-platform v0.0.0-20241015143916-a95e92939b2d/go.mod h1:Rkybt1QK1H/UWQw4zQLUgmN7qkcus14Xwu7Po6uZyDY=
+github.com/in4it/go-devops-platform v0.0.0-20241015144646-23f35360de7f h1:Ly3NBMZTWQJlOEVrEvyE49wev9cRz3Lk99ZJ1++i19I=
+github.com/in4it/go-devops-platform v0.0.0-20241015144646-23f35360de7f/go.mod h1:+cwJAz5QYDgexz0GBjOqXPvlNjwxPZVsNO0VLAODDTs=
+github.com/in4it/go-devops-platform v0.0.0-20241015145222-2bcb11eedbc9 h1:xtNPNaS8AQJAVZAfP6ip+X3WoyPJ1+MKonuoCMAmcnE=
+github.com/in4it/go-devops-platform v0.0.0-20241015145222-2bcb11eedbc9/go.mod h1:+cwJAz5QYDgexz0GBjOqXPvlNjwxPZVsNO0VLAODDTs=
+github.com/in4it/go-devops-platform v0.0.0-20241015151511-486df9498223 h1:7kVZ6gO1SmwHB1qmGUh592j1gRxV60rdXsGp1y46gYs=
+github.com/in4it/go-devops-platform v0.0.0-20241015151511-486df9498223/go.mod h1:+cwJAz5QYDgexz0GBjOqXPvlNjwxPZVsNO0VLAODDTs=
+github.com/in4it/go-devops-platform v0.0.0-20241015170045-c08f1c5bacbe h1:/K0LdRqsG0m5vNeL08bYNAcze3k1H8+WzObXR5qnvyY=
+github.com/in4it/go-devops-platform v0.0.0-20241015170045-c08f1c5bacbe/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
+github.com/in4it/go-devops-platform v0.0.0-20241015170530-725d20de58f6 h1:DC3GjGu5Sk1Z0SQTpynXDgwtd3ozHAf1EyRCCdRGiR8=
+github.com/in4it/go-devops-platform v0.0.0-20241015170530-725d20de58f6/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
+github.com/in4it/go-devops-platform v0.0.0-20241015171457-d0cf3d638954 h1:xLSjoVWn+ZnWsYU/4YTha/LmsmUWghKf3BJkPJw4TeE=
+github.com/in4it/go-devops-platform v0.0.0-20241015171457-d0cf3d638954/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
+github.com/in4it/go-devops-platform v0.0.0-20241015173130-0b49ea0db408 h1:iQl/CXUmPJPer9NfVYsRTOwWHcXUs1q4oJQLzecJeuI=
+github.com/in4it/go-devops-platform v0.0.0-20241015173130-0b49ea0db408/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
+github.com/in4it/go-devops-platform v0.0.0-20241015173332-a45080cabae5 h1:pcZ2PeYjk5/Bis6llji5uTqocjbIpxPlm4flLoweVt0=
+github.com/in4it/go-devops-platform v0.0.0-20241015173332-a45080cabae5/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
+github.com/in4it/go-devops-platform v0.0.0-20241015191019-e2183445416a h1:y0pHhwuhAZYB4v99PHYScjyPv45mfCM5iiUWL6rAhIg=
+github.com/in4it/go-devops-platform v0.0.0-20241015191019-e2183445416a/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
+github.com/in4it/go-devops-platform v0.0.0-20241015191315-e2f711a32e69 h1:/1BTKr5cqiQKAf8KMhR8z6Pup7eqzDbh32qQb+xh4Tw=
+github.com/in4it/go-devops-platform v0.0.0-20241015191315-e2f711a32e69/go.mod h1:xugWZer+8U7DcIWlE95SiPvtVPmJzhB9YCYiIScLK5Q=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
@@ -113,10 +154,14 @@ golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
+golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
+golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
+golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
+golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -124,14 +169,20 @@ golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
+golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
+golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
+golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
+golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
diff --git a/pkg/auth/oidc/rand.go b/pkg/auth/oidc/rand.go
deleted file mode 100644
index 8881662..0000000
--- a/pkg/auth/oidc/rand.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package oidc
-
-import (
- "crypto/rand"
- "encoding/base64"
- "fmt"
- "io"
-)
-
-func GetRandomString(n int) (string, error) {
- buf := make([]byte, n)
-
- _, err := io.ReadFull(rand.Reader, buf)
- if err != nil {
- return "", fmt.Errorf("crypto/rand Reader error: %s", err)
- }
-
- return base64.RawURLEncoding.EncodeToString(buf), nil
-}
diff --git a/pkg/auth/oidc/redirect.go b/pkg/auth/oidc/redirect.go
deleted file mode 100644
index 4eb60f0..0000000
--- a/pkg/auth/oidc/redirect.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package oidc
-
-import (
- "fmt"
- "strings"
-)
-
-func GetRedirectURI(discovery Discovery, clientID, scope, callback string, enableOIDCTokenRenewal bool) (string, string, error) {
- var redirectURI string
-
- state, err := GetRandomString(64)
- if err != nil {
- return redirectURI, state, fmt.Errorf("GetRandomString error: %s", err)
- }
-
- // add offline_access to scope if oidc token renewal is true
- //if enableOIDCTokenRenewal {
- // scope = strings.TrimSpace(scope) + " offline_access"
- //}
-
- scope = strings.Replace(scope, " ", "%20", -1)
-
- redirectURI = fmt.Sprintf("%s?client_id=%s&state=%s&scope=%s&response_type=code&redirect_uri=%s", discovery.AuthorizationEndpoint, clientID, state, scope, callback)
-
- return redirectURI, state, nil
-}
diff --git a/pkg/auth/oidc/store/cleanup.go b/pkg/auth/oidc/store/cleanup.go
deleted file mode 100644
index add79f5..0000000
--- a/pkg/auth/oidc/store/cleanup.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package oidcstore
-
-import (
- "slices"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
-)
-
-func (store *Store) CleanupOAuth2DataForAllEntries() int {
- deleted := 0
- for _, oauthData := range store.OAuth2Data {
- deleted += store.CleanupOAuth2Data(oauthData)
- }
- return deleted
-}
-
-func (store *Store) CleanupOAuth2Data(oauthData oidc.OAuthData) int {
- keysToDelete1 := []string{}
- keysToDelete2 := []string{}
- keysToDelete3 := []string{}
- store.Mu.Lock()
- defer store.Mu.Unlock()
- for k := range store.OAuth2Data {
- if oauthData.CreatedAt.After(store.OAuth2Data[k].CreatedAt) {
- // cleanup old oauth2 data that might be duplicates (same subject & oidc provider, but older tokens)
- if store.OAuth2Data[k].ID != oauthData.ID && store.OAuth2Data[k].OIDCProviderID == oauthData.OIDCProviderID && store.OAuth2Data[k].Subject == oauthData.Subject {
- keysToDelete1 = append(keysToDelete1, k)
- }
- // cleanup oauthdata with the same email address
- if store.OAuth2Data[k].ID != oauthData.ID && store.OAuth2Data[k].UserInfo.Email == oauthData.UserInfo.Email {
- keysToDelete3 = append(keysToDelete3, k)
- }
- }
- // cleanup old oauth2 data that doesn't have a token and is stale
- if store.OAuth2Data[k].Token.AccessToken == "" && store.OAuth2Data[k].CreatedAt.Add(10*time.Minute).After(time.Now()) {
- keysToDelete2 = append(keysToDelete2, k)
- }
- }
- keysToDelete := []string{}
- keysToDelete = append(keysToDelete, keysToDelete1...)
- keysToDelete = append(keysToDelete, keysToDelete2...)
- keysToDelete = append(keysToDelete, keysToDelete3...)
- slices.Sort(keysToDelete)
-
- for _, key := range slices.Compact(keysToDelete) {
- delete(store.OAuth2Data, key)
- }
- return len(keysToDelete)
-}
diff --git a/pkg/auth/oidc/store/discovery.go b/pkg/auth/oidc/store/discovery.go
deleted file mode 100644
index ee3e7f6..0000000
--- a/pkg/auth/oidc/store/discovery.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package oidcstore
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
-)
-
-func (store *Store) GetDiscoveryURI(discoveryURI string) (oidc.Discovery, error) {
- var discovery oidc.Discovery
- if cachedDiscovery, ok := store.getDiscoveryCache(discoveryURI); ok {
- if cachedDiscovery.Expiration.After(time.Now()) {
- return cachedDiscovery.Discovery, nil // cache hit, we can return
- }
- }
-
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- resp, err := client.Get(discoveryURI)
- if err != nil {
- return discovery, fmt.Errorf("discoveryURL Get error: %s", err)
- }
- if resp.StatusCode != 200 {
- return discovery, fmt.Errorf("DiscoveryURI Request unsuccesful (status code returned: %d)", resp.StatusCode)
- }
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&discovery)
- if err != nil {
- return discovery, fmt.Errorf("discoveryURL decode error: %s", err)
- }
-
- store.setDiscoveryCache(discoveryURI, oidc.DiscoveryCache{
- Expiration: time.Now().Add(12 * time.Hour), // standard 12h cache
- Discovery: discovery,
- })
-
- return discovery, nil
-}
diff --git a/pkg/auth/oidc/store/discovery_test.go b/pkg/auth/oidc/store/discovery_test.go
deleted file mode 100644
index 31b6d9b..0000000
--- a/pkg/auth/oidc/store/discovery_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package oidcstore
-
-import (
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestGetDiscovery(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/discovery.json" {
- discovery := oidc.Discovery{
- Issuer: "test-issuer",
- }
- out, err := json.Marshal(discovery)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- }
- w.Write(out)
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- store, err := NewStore(&memorystorage.MockMemoryStorage{})
- if err != nil {
- t.Fatalf("new store error: %s", err)
- }
- uri := ts.URL + "/discovery.json"
- discovery, err := store.GetDiscoveryURI(uri)
- if err != nil {
- t.Fatalf("get discovery error: %s", err)
- }
- if discovery.Issuer != "test-issuer" {
- t.Fatalf("wrong issuer")
- }
-
- // cached response
- discovery, err = store.GetDiscoveryURI(uri)
- if err != nil {
- t.Fatalf("get discovery error: %s", err)
- }
- if discovery.Issuer != "test-issuer" {
- t.Fatalf("wrong issuer")
- }
- if _, ok := store.DiscoveryCache[uri]; !ok {
- t.Fatalf("discovery not in cache")
- }
-}
diff --git a/pkg/auth/oidc/store/jwks.go b/pkg/auth/oidc/store/jwks.go
deleted file mode 100644
index 264cb82..0000000
--- a/pkg/auth/oidc/store/jwks.go
+++ /dev/null
@@ -1,55 +0,0 @@
-package oidcstore
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
-)
-
-func (store *Store) GetJwks(jwksURI string) (oidc.Jwks, error) {
- var jwksKeys oidc.Jwks
-
- if cachedJwks, ok := store.JwksCache[jwksURI]; ok {
- if cachedJwks.Expiration.Before(time.Now()) {
- return cachedJwks.Jwks, nil // cache hit, we can return
- }
- }
-
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- resp, err := client.Get(jwksURI)
- if err != nil {
- return jwksKeys, fmt.Errorf("discoveryURL Get error: %s", err)
- }
- if resp.StatusCode != 200 {
- return jwksKeys, fmt.Errorf("DiscoveryURI Request unsuccesful (status code returned: %d)", resp.StatusCode)
- }
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&jwksKeys)
- if err != nil {
- return jwksKeys, fmt.Errorf("discoveryURL decode error: %s", err)
- }
-
- store.setJwksCache(jwksURI, oidc.JwksCache{
- Expiration: time.Now().Add(20 * time.Minute), // 20 minute cache standard
- Jwks: jwksKeys,
- })
-
- return jwksKeys, nil
-}
-
-func (store *Store) GetAllJwks(discoveryProviders []oidc.Discovery) ([]oidc.Jwks, error) {
- allJwks := make([]oidc.Jwks, len(discoveryProviders))
- for k, discovery := range discoveryProviders {
- var err error
- allJwks[k], err = store.GetJwks(discovery.JwksURI)
- if err != nil {
- return []oidc.Jwks{}, fmt.Errorf("get jwks error: %s", err)
- }
- }
- return allJwks, nil
-}
diff --git a/pkg/auth/oidc/store/jwks_test.go b/pkg/auth/oidc/store/jwks_test.go
deleted file mode 100644
index 94c808a..0000000
--- a/pkg/auth/oidc/store/jwks_test.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package oidcstore
-
-import (
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestGetJwks(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/jwks.json" {
- jwksKeys := oidc.Jwks{
- Keys: []oidc.JwksKey{
- {
- Kid: "1-2-3-4",
- Kty: "kty",
- },
- },
- }
- out, err := json.Marshal(jwksKeys)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- }
- w.Write(out)
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- store, err := NewStore(&memorystorage.MockMemoryStorage{})
- if err != nil {
- t.Fatalf("new store error: %s", err)
- }
- uri := ts.URL + "/jwks.json"
- jwks, err := store.GetJwks(uri)
- if err != nil {
- t.Fatalf("get jwks error: %s", err)
- }
- if len(jwks.Keys) == 0 {
- t.Fatalf("jwks is empty")
- }
- if jwks.Keys[0].Kid != "1-2-3-4" {
- t.Fatalf("wrong kid: %s", jwks.Keys[0].Kid)
- }
- // cached response
-
- jwks, err = store.GetJwks(uri)
- if err != nil {
- t.Fatalf("get jwks error: %s", err)
- }
- if len(jwks.Keys) == 0 {
- t.Fatalf("jwks is empty")
- }
- if jwks.Keys[0].Kid != "1-2-3-4" {
- t.Fatalf("wrong kid: %s", jwks.Keys[0].Kid)
- }
- if _, ok := store.JwksCache[uri]; !ok {
- t.Fatalf("jwks not in cache")
- }
-
- // get all jwks
- allJwks, err := store.GetAllJwks([]oidc.Discovery{{JwksURI: uri}})
- if err != nil {
- t.Fatalf("get all jwks error: %s", err)
- }
- if len(allJwks) == 0 {
- t.Fatalf("all jwks is zero")
- }
- if allJwks[0].Keys[0].Kid != "1-2-3-4" {
- t.Fatalf("wrong kid for allJwks: %s", jwks.Keys[0].Kid)
- }
-}
diff --git a/pkg/auth/oidc/store/renewal/new.go b/pkg/auth/oidc/store/renewal/new.go
deleted file mode 100644
index 605c5c3..0000000
--- a/pkg/auth/oidc/store/renewal/new.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package oidcrenewal
-
-import (
- "time"
-
- oidc "github.com/in4it/wireguard-server/pkg/auth/oidc"
- oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-type Renewal struct {
- oidcStore *oidcstore.Store
- enabled bool
- oidcProviders []oidc.OIDCProvider
- userStore *users.UserStore
- renewalTime time.Duration
- storage storage.Iface
-}
-
-func NewRenewal(storage storage.Iface, renewalTime int, contextLogLevel int, enabled bool, oidcstore *oidcstore.Store, oidcProviders []oidc.OIDCProvider, userStore *users.UserStore) (*Renewal, error) {
- r := &Renewal{
- enabled: enabled,
- oidcStore: oidcstore,
- oidcProviders: oidcProviders,
- userStore: userStore,
- storage: storage,
- }
- logging.Loglevel = contextLogLevel
- if renewalTime <= 5 {
- r.renewalTime = DEFAULT_RENEWAL_TIME_MINUTES * time.Minute
- } else {
- r.renewalTime = time.Duration(renewalTime) * time.Minute
- }
- go r.Worker()
- return r, nil
-}
-
-func (r *Renewal) SetEnabled(enabled bool) {
- r.enabled = enabled
-}
diff --git a/pkg/auth/oidc/store/renewal/refreshtoken.go b/pkg/auth/oidc/store/renewal/refreshtoken.go
deleted file mode 100644
index 676cae4..0000000
--- a/pkg/auth/oidc/store/renewal/refreshtoken.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package oidcrenewal
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
-)
-
-func refreshToken(discovery oidc.Discovery, refreshToken, clientID, clientSecret string) (oidc.Token, time.Time, error) {
- var token oidc.Token
- client := http.Client{
- Timeout: 5 * time.Second,
- }
-
- if discovery.TokenEndpoint == "" {
- return token, time.Time{}, fmt.Errorf("token endpoint is empty")
- }
-
- payload := url.Values{
- "grant_type": {"refresh_token"},
- "refresh_token": {refreshToken},
- "client_id": {clientID},
- "client_secret": {clientSecret},
- }
-
- resp, err := client.PostForm(discovery.TokenEndpoint, payload)
- if err != nil {
- return token, time.Time{}, fmt.Errorf("tokenEndpoint PostForm error: %s", err)
- }
- renewalTime := time.Now()
- if resp.StatusCode != 200 {
- data, err := io.ReadAll(resp.Body)
- if err != nil {
- return token, renewalTime, fmt.Errorf("tokenEndpoint return error. statuscode: %d", resp.StatusCode)
- }
- return token, renewalTime, fmt.Errorf("tokenEndpoint return error (statuscode %d): %s", resp.StatusCode, data)
- }
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&token)
- if err != nil {
- return token, renewalTime, fmt.Errorf("tokenEndpoint decode error: %s", err)
- }
- if token.AccessToken == "" {
- return token, renewalTime, fmt.Errorf("access token is empty")
- }
-
- return token, renewalTime, nil
-
-}
diff --git a/pkg/auth/oidc/store/renewal/renew.go b/pkg/auth/oidc/store/renewal/renew.go
deleted file mode 100644
index b8a06c0..0000000
--- a/pkg/auth/oidc/store/renewal/renew.go
+++ /dev/null
@@ -1,156 +0,0 @@
-package oidcrenewal
-
-import (
- "encoding/base64"
- "encoding/json"
- "fmt"
- "strings"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-func (r *Renewal) RenewAllOIDCConnections() {
- // force renewal of all tokens, even if they're not expired (unless they're empty)
- for key, oauth2Data := range r.oidcStore.OAuth2Data {
- if oidcProvider, err := getOIDCProvider(oauth2Data.OIDCProviderID, r.oidcProviders); err == nil {
- if discovery, err := r.oidcStore.GetDiscoveryURI(oidcProvider.DiscoveryURI); err == nil {
- if oauth2Data.RenewalFailed || oauth2Data.Token.AccessToken == "" {
- logging.DebugLog(fmt.Errorf("skipping %s (renewal already failed or access token is empty. RenewalFailed: %v, AccessToken is empty: %v)", oauth2Data.ID, oauth2Data.RenewalFailed, oauth2Data.Token.AccessToken == ""))
- } else {
- logging.DebugLog(fmt.Errorf("trying to renew %s", oauth2Data.ID))
- r.renew(discovery, key, oauth2Data, oidcProvider)
- }
- } else {
- logging.DebugLog(fmt.Errorf("could not get discovery url for %s: %s", oauth2Data.ID, err))
- }
- } else {
- logging.DebugLog(fmt.Errorf("could not get oidcprovider for %s: %s", oauth2Data.ID, err))
- }
- }
-}
-func (r *Renewal) renew(discovery oidc.Discovery, key string, oauth2Data oidc.OAuthData, oidcProvider oidc.OIDCProvider) {
- newToken, newTokenTimestamp, err := refreshToken(discovery, oauth2Data.Token.RefreshToken, oidcProvider.ClientID, oidcProvider.ClientSecret)
- if err != nil {
- oauth2Data.RenewalRetries++
- logging.ErrorLog(fmt.Errorf("renewal Worker: could not refresh token for %s (attemp %d/%d): %s", oauth2Data.ID, oauth2Data.RenewalRetries, RENEWAL_RETRIES, err))
- if oauth2Data.RenewalRetries >= RENEWAL_RETRIES {
- oauth2Data.RenewalFailed = true
- }
- err = r.oidcStore.StoreEntry(key, oauth2Data)
- if err != nil {
- logging.ErrorLog(fmt.Errorf("renewal Worker: [error] StoreEntry: %s", err))
- }
- err = r.oidcStore.SaveOIDCStore()
- if err != nil {
- logging.ErrorLog(fmt.Errorf("renewal Worker: [error] SaveOIDCStore: %s", err))
- }
- // suspend connections
- if oauth2Data.RenewalFailed {
- err = disableUser(r.storage, oauth2Data, r.userStore)
- if err != nil {
- logging.ErrorLog(fmt.Errorf("renewal Worker: [error] disableUser: %s", err))
- }
- }
- return
- }
- logging.DebugLog(fmt.Errorf("new token issued at %v: %+v", newToken, newTokenTimestamp))
- oauth2Data.LastTokenRenewal = newTokenTimestamp
- oauth2Data.Token.AccessToken = newToken.AccessToken
- oauth2Data.Token.ExpiresIn = newToken.ExpiresIn
- oauth2Data.Token.RefreshToken = newToken.RefreshToken
- if newToken.IDToken != "" {
- oauth2Data.Token.IDToken = newToken.IDToken
- }
- err = r.oidcStore.StoreEntry(key, oauth2Data)
- if err != nil {
- logging.ErrorLog(fmt.Errorf("renewal Worker: [error] StoreEntry: %s", err))
- }
- err = r.oidcStore.SaveOIDCStore()
- if err != nil {
- logging.ErrorLog(fmt.Errorf("renewal Worker: [error] SaveOIDCStore: %s", err))
- }
-}
-
-func disableUser(storage storage.Iface, oauth2Data oidc.OAuthData, userStore *users.UserStore) error {
- logging.DebugLog(fmt.Errorf("disable user with oidc id %s", oauth2Data.ID))
- user, err := userStore.GetUserByOIDCIDs([]string{oauth2Data.ID})
- if err != nil {
- return fmt.Errorf("no user found with oidc id %s", oauth2Data.ID)
- }
- err = wireguard.DisableAllClientConfigs(storage, user.ID)
- if err != nil {
- return fmt.Errorf("DisableAllClientConfigs error for userID %s: %s", user.ID, err)
- }
- user.ConnectionsDisabledOnAuthFailure = true
- err = userStore.UpdateUser(user)
- if err != nil {
- return fmt.Errorf("could not update connectionsDisabledOnAuthFailure user with userID %s: %s", user.ID, err)
- }
- return nil
-}
-
-func getOIDCProvider(id string, oidcProviders []oidc.OIDCProvider) (oidc.OIDCProvider, error) {
- for _, oidcProvider := range oidcProviders {
- if oidcProvider.ID == id {
- return oidcProvider, nil
- }
- }
- return oidc.OIDCProvider{}, fmt.Errorf("oidc provider not found")
-
-}
-
-func getExpirationDate(token string) (time.Time, error) {
- jwtSplit := strings.Split(token, ".")
- if len(jwtSplit) < 2 {
- return time.Time{}, fmt.Errorf("token split < 2")
- }
- data, err := base64.RawURLEncoding.DecodeString(jwtSplit[1])
- if err != nil {
- return time.Time{}, fmt.Errorf("could not base64 decode data part of jwt")
- }
- var jwt jwtExp
- err = json.Unmarshal(data, &jwt)
- if err != nil {
- return time.Time{}, fmt.Errorf("could not unmarshal jwt data")
- }
- if jwt.Expiration == 0 {
- return time.Time{}, fmt.Errorf("exp not found in jwt data")
- }
- return time.Unix(jwt.Expiration, 0), nil
-}
-
-func canRenew(renewalTime time.Duration, oauth2Data oidc.OAuthData, store *oidcstore.Store, oidcProviders []oidc.OIDCProvider) (bool, oidc.OIDCProvider, oidc.Discovery, error) {
- if oauth2Data.RenewalFailed {
- return false, oidc.OIDCProvider{}, oidc.Discovery{}, nil
- }
- if oauth2Data.Token.AccessToken == "" {
- logging.DebugLog(fmt.Errorf("access token empty of oidc id %s", oauth2Data.ID))
- return false, oidc.OIDCProvider{}, oidc.Discovery{}, nil
- }
- expirationDate, err := getExpirationDate(oauth2Data.Token.AccessToken)
- if err != nil {
- return false, oidc.OIDCProvider{}, oidc.Discovery{}, fmt.Errorf("can't get expiration date of refresh_token (id:%s). error: %s", oauth2Data.ID, err)
- }
-
- if time.Since(oauth2Data.LastTokenRenewal) > renewalTime || time.Now().After(expirationDate) {
- logging.DebugLog(fmt.Errorf("going to renew token for %s", oauth2Data.ID))
- oidcProvider, err := getOIDCProvider(oauth2Data.OIDCProviderID, oidcProviders)
- if err != nil {
- return false, oidc.OIDCProvider{}, oidc.Discovery{}, fmt.Errorf("could not get oidcprovider for %s: %s", oauth2Data.ID, err)
- }
- discovery, err := store.GetDiscoveryURI(oidcProvider.DiscoveryURI)
- if err != nil {
- return false, oidc.OIDCProvider{}, oidc.Discovery{}, fmt.Errorf("could not get discovery url for %s: %s", oauth2Data.ID, err)
- }
- return true, oidcProvider, discovery, nil
- } else {
- logging.DebugLog(fmt.Errorf("not renewing oidc id %s. time since last token renewal: %d", oauth2Data.ID, time.Since(oauth2Data.LastTokenRenewal)))
- }
- return false, oidc.OIDCProvider{}, oidc.Discovery{}, nil
-}
diff --git a/pkg/auth/oidc/store/renewal/types.go b/pkg/auth/oidc/store/renewal/types.go
deleted file mode 100644
index 202c716..0000000
--- a/pkg/auth/oidc/store/renewal/types.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package oidcrenewal
-
-type jwtExp struct {
- Expiration int64 `json:"exp"`
-}
diff --git a/pkg/auth/oidc/store/renewal/worker.go b/pkg/auth/oidc/store/renewal/worker.go
deleted file mode 100644
index a4a8252..0000000
--- a/pkg/auth/oidc/store/renewal/worker.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package oidcrenewal
-
-import (
- "fmt"
- "log"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/logging"
-)
-
-const WAKEUP_TIME_SECONDS = 300 // every 5 minutes we check
-const DEFAULT_RENEWAL_TIME_MINUTES = 60 // every hour we want to refresh the token
-const RENEWAL_RETRIES = 3 // 3 retries before we suspend a user
-const RENEWAL_BACKOFF_SECONDS = 10 // time between requests to the oidc providers
-
-func (r *Renewal) Worker() {
- // do renewal
- if r.enabled {
- fmt.Printf("Starting oidc renewal worker (loglevel: %d)\n", logging.Loglevel)
- }
- for {
- if !r.enabled {
- time.Sleep(WAKEUP_TIME_SECONDS * time.Second)
- continue
- }
- deletedEntries := r.oidcStore.CleanupOAuth2DataForAllEntries()
- if deletedEntries > 0 {
- err := r.oidcStore.SaveOIDCStore()
- if err != nil {
- log.Printf("Renewal Worker: [warning] couldn't save oidc store after cleanup: %s", err)
- }
- }
- for key, oauth2Data := range r.oidcStore.OAuth2Data {
- logging.DebugLog(fmt.Errorf("running canRenew of %s", oauth2Data.ID))
- // can we renew? Do we have expiration date and it is expired?
- canRenew, oidcProvider, discovery, err := canRenew(r.renewalTime, oauth2Data, r.oidcStore, r.oidcProviders)
- if err != nil {
- log.Printf("Renewal Worker: [warning] needsRenewal: %s", err)
- }
- if canRenew {
- logging.DebugLog(fmt.Errorf("we can renew %s", oauth2Data.ID))
- r.renew(discovery, key, oauth2Data, oidcProvider) // error logging within function
- }
- time.Sleep(RENEWAL_BACKOFF_SECONDS * time.Second)
- }
- time.Sleep(WAKEUP_TIME_SECONDS * time.Second)
- }
-}
diff --git a/pkg/auth/oidc/store/save.go b/pkg/auth/oidc/store/save.go
deleted file mode 100644
index 67828a5..0000000
--- a/pkg/auth/oidc/store/save.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package oidcstore
-
-import (
- "encoding/json"
- "fmt"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
-)
-
-func (store *Store) SaveOIDCStore() error {
- store.Mu.Lock()
- defer store.Mu.Unlock()
- out, err := json.Marshal(store)
- if err != nil {
- return fmt.Errorf("oidc store marshal error: %s", err)
- }
- filename := store.storage.ConfigPath("oidcstore.json")
- err = store.storage.WriteFile(filename, out)
- if err != nil {
- return fmt.Errorf("oidcstore write error: %s", err)
- }
- return nil
-}
-
-func (store *Store) SaveOAuth2Data(oauth2Data oidc.OAuthData, key string) error {
- store.Mu.Lock()
- store.OAuth2Data[key] = oauth2Data
- store.Mu.Unlock()
- return store.SaveOIDCStore()
-}
diff --git a/pkg/auth/oidc/store/save_test.go b/pkg/auth/oidc/store/save_test.go
deleted file mode 100644
index 9e33d53..0000000
--- a/pkg/auth/oidc/store/save_test.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package oidcstore
-
-import (
- "testing"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestSave(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- store, err := NewStore(storage)
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- err = store.setDiscoveryCache("test", oidc.DiscoveryCache{Discovery: oidc.Discovery{Issuer: "testissuer"}})
- if err != nil {
- t.Fatalf("error: %s", err)
- }
-
- err = store.SaveOIDCStore()
- if err != nil {
- t.Fatalf("error: %s", err)
- }
-
- _, err = storage.ReadFile(storage.ConfigPath("oidcstore.json"))
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- store2, err := NewStore(storage)
- if err != nil {
- t.Fatalf("error: %s", err)
- }
-
- discovery, ok := store2.getDiscoveryCache("test")
- if !ok {
- t.Fatalf("can't find the cache")
- }
- if discovery.Discovery.Issuer != "testissuer" {
- t.Fatalf("expected testissuer. Got: %s", discovery.Discovery.Issuer)
- }
-}
diff --git a/pkg/auth/oidc/store/store.go b/pkg/auth/oidc/store/store.go
deleted file mode 100644
index ce4480b..0000000
--- a/pkg/auth/oidc/store/store.go
+++ /dev/null
@@ -1,84 +0,0 @@
-package oidcstore
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "sync"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-const DEFAULT_PATH = "oidcstore.json"
-
-var RetrieveTokenLock sync.Mutex
-
-func (store *Store) StoreEntry(state string, oauthData oidc.OAuthData) error {
- store.Mu.Lock()
- store.OAuth2Data[state] = oauthData
- store.Mu.Unlock()
- return nil
-}
-
-func (store *Store) setDiscoveryCache(key string, value oidc.DiscoveryCache) error {
- store.Mu.Lock()
- store.DiscoveryCache[key] = value
- store.Mu.Unlock()
- return nil
-}
-
-func (store *Store) setJwksCache(key string, value oidc.JwksCache) error {
- store.Mu.Lock()
- store.JwksCache[key] = value
- store.Mu.Unlock()
- return nil
-}
-
-func (store *Store) getDiscoveryCache(key string) (oidc.DiscoveryCache, bool) {
- discovery, ok := store.DiscoveryCache[key]
- return discovery, ok
-}
-
-func NewStore(storage storage.Iface) (*Store, error) {
- var store *Store
-
- filename := storage.ConfigPath(DEFAULT_PATH)
-
- // check if oidc.Store exists
- if !storage.FileExists(filename) {
- return getEmptyOIDCStore(storage)
- }
-
- data, err := storage.ReadFile(filename)
- if err != nil {
- return store, fmt.Errorf("config read error: %s", err)
- }
- decoder := json.NewDecoder(bytes.NewBuffer(data))
- err = decoder.Decode(&store)
- if err != nil {
- return store, fmt.Errorf("decode input error: %s", err)
- }
- if store.DiscoveryCache == nil {
- store.DiscoveryCache = make(map[string]oidc.DiscoveryCache)
- }
- if store.JwksCache == nil {
- store.JwksCache = make(map[string]oidc.JwksCache)
- }
- if store.OAuth2Data == nil {
- store.OAuth2Data = make(map[string]oidc.OAuthData)
- }
-
- store.storage = storage
-
- return store, nil
-}
-
-func getEmptyOIDCStore(storage storage.Iface) (*Store, error) {
- return &Store{
- OAuth2Data: make(map[string]oidc.OAuthData),
- DiscoveryCache: make(map[string]oidc.DiscoveryCache),
- JwksCache: make(map[string]oidc.JwksCache),
- storage: storage,
- }, nil
-}
diff --git a/pkg/auth/oidc/store/types.go b/pkg/auth/oidc/store/types.go
deleted file mode 100644
index 0c194e6..0000000
--- a/pkg/auth/oidc/store/types.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package oidcstore
-
-import (
- "sync"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-type Store struct {
- Mu sync.Mutex
- OAuth2Data map[string]oidc.OAuthData `json:"oauth2Data"`
- DiscoveryCache map[string]oidc.DiscoveryCache `json:"discoveryCache"`
- JwksCache map[string]oidc.JwksCache `json:"jwksCache"`
- storage storage.Iface
-}
diff --git a/pkg/auth/oidc/token.go b/pkg/auth/oidc/token.go
deleted file mode 100644
index 1843ac9..0000000
--- a/pkg/auth/oidc/token.go
+++ /dev/null
@@ -1,135 +0,0 @@
-package oidc
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/mail"
- "net/url"
- "time"
-
- "github.com/go-jose/go-jose/v4"
- "github.com/golang-jwt/jwt/v5"
- "github.com/in4it/wireguard-server/pkg/logging"
-)
-
-func RetrieveOAUth2DataUsingState(allOAuth2data map[string]OAuthData, state string) (OAuthData, error) {
- if state == "" {
- return OAuthData{}, fmt.Errorf("no state found")
- }
- oauthData, ok := allOAuth2data[state]
- if !ok {
- return OAuthData{}, fmt.Errorf("oauth data not found (is state missing?)")
- }
- return oauthData, nil
-}
-
-func UpdateOAuth2DataWithToken(jwks Jwks, discovery Discovery, clientID, clientSecret, redirectURI, code, state string, oauth2Data OAuthData) (OAuthData, error) {
- newOAuthData := oauth2Data
- var token Token
- client := http.Client{
- Timeout: 5 * time.Second,
- }
-
- if discovery.TokenEndpoint == "" {
- return newOAuthData, fmt.Errorf("token endpoint is empty")
- }
-
- payload := url.Values{
- "grant_type": {"authorization_code"},
- "code": {code},
- "client_id": {clientID},
- "client_secret": {clientSecret},
- "redirect_uri": {redirectURI},
- }
-
- resp, err := client.PostForm(discovery.TokenEndpoint, payload)
- if err != nil {
- return newOAuthData, fmt.Errorf("tokenEndpoint PostForm error: %s", err)
- }
- renewalTime := time.Now()
- if resp.StatusCode != 200 {
- data, err := io.ReadAll(resp.Body)
- if err != nil {
- return newOAuthData, fmt.Errorf("tokenEndpoint return error. statuscode: %d", resp.StatusCode)
- }
- return newOAuthData, fmt.Errorf("tokenEndpoint return error (statuscode %d): %s", resp.StatusCode, data)
- }
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&token)
- if err != nil {
- return newOAuthData, fmt.Errorf("tokenEndpoint decode error: %s", err)
- }
-
- // verify id token
- parsedToken, err := jwt.Parse(token.IDToken, func(token *jwt.Token) (interface{}, error) {
- publicKey, err := GetPublicKeyForToken([]Jwks{jwks}, []Discovery{discovery}, token)
- if err != nil {
- return nil, fmt.Errorf("GetPublicKeyForToken error: %s", err)
- }
- return publicKey, nil
- })
- if err != nil {
- logging.DebugLog(fmt.Errorf("couldn't verify id token: %s", err))
- return newOAuthData, fmt.Errorf("couldn't verify id token")
- }
- // remove old oauth2data matching oidcproivder and subject
- claims := parsedToken.Claims.(jwt.MapClaims)
- subject, ok := claims["sub"]
- if !ok {
- return newOAuthData, fmt.Errorf("subject missing from id token")
- }
- validEmail := ""
- email, ok := claims["email"]
- if !ok {
- // check if email is in preferred_username
- preferred_username, ok2 := claims["preferred_username"]
- if !ok2 {
- return newOAuthData, fmt.Errorf("email missing from id token (not in email / preferred_username claim)")
- } else {
- _, err := mail.ParseAddress(preferred_username.(string))
- if err != nil {
- return newOAuthData, fmt.Errorf("email missing from id token and preferred_username is not an email address")
- }
- validEmail = preferred_username.(string)
- }
- } else {
- validEmail = email.(string)
- }
- issuer, ok := claims["iss"]
- if !ok {
- return newOAuthData, fmt.Errorf("issuer missing from id token")
- }
-
- newOAuthData.Token = token
- newOAuthData.LastTokenRenewal = renewalTime
- newOAuthData.Subject = subject.(string)
- newOAuthData.Issuer = issuer.(string)
- newOAuthData.UserInfo.Email = validEmail
- return newOAuthData, nil
-}
-
-func GetPublicKeyForToken(allJwks []Jwks, discoveryProviders []Discovery, token *jwt.Token) (any, error) {
- kid, ok := token.Header["kid"]
- if !ok {
- return nil, fmt.Errorf("no kid found in token")
- }
- for _, jwks := range allJwks {
- for _, key := range jwks.Keys {
- if key.Kid == kid {
- jsonWebKey := jose.JSONWebKey{}
- singleKey, err := json.Marshal(key)
- if err != nil {
- return nil, fmt.Errorf("internal server error: cannot marshal key from kid endpoint: %s", err)
- }
- err = jsonWebKey.UnmarshalJSON(singleKey)
- if err != nil {
- return nil, fmt.Errorf("key from jwks import error: %s", err)
- }
- return jsonWebKey.Key, nil
- }
- }
- }
- return nil, fmt.Errorf("no matching kid found for token")
-}
diff --git a/pkg/auth/oidc/types.go b/pkg/auth/oidc/types.go
deleted file mode 100644
index 3356d1a..0000000
--- a/pkg/auth/oidc/types.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package oidc
-
-import (
- "time"
-)
-
-type Discovery struct {
- Issuer string `json:"issuer"`
- AuthorizationEndpoint string `json:"authorization_endpoint"`
- TokenEndpoint string `json:"token_endpoint"`
- UserinfoEndpoint string `json:"userinfo_endpoint"`
- JwksURI string `json:"jwks_uri"`
- ScopesSupported []string `json:"scopes_supported"`
- ResponseTypesSupported []string `json:"response_types_supported"`
- TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
- IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
- ClaimsSupported []string `json:"claims_supported"`
- SubjectTypesSupported []string `json:"subject_types_supported"`
-}
-
-type OIDCProvider struct {
- ID string `json:"id"`
- Name string `json:"name"`
- ClientID string `json:"clientId"`
- ClientSecret string `json:"clientSecret,omitempty"`
- Scope string `json:"scope"`
- DiscoveryURI string `json:"discoveryURI"`
- RedirectURI string `json:"redirectURI"`
- LoginURL string `json:"loginURL,omitempty"`
-}
-
-type Token struct {
- AccessToken string `json:"access_token"`
- TokenType string `json:"token_type"`
- RefreshToken string `json:"refresh_token"`
- ExpiresIn int `json:"expires_in"`
- IDToken string `json:"id_token"`
-}
-
-// jwks
-type Jwks struct {
- Keys []JwksKey `json:"keys"`
-}
-type JwksKey struct {
- N string `json:"n"`
- E string `json:"e"`
- Alg string `json:"alg"`
- Use string `json:"use"`
- Kid string `json:"kid"`
- Kty string `json:"kty"`
-}
-
-type DiscoveryCache struct {
- Expiration time.Time `json:"expiration"`
- Discovery Discovery `json:"discovery"`
-}
-type JwksCache struct {
- Expiration time.Time `json:"expiration"`
- Jwks Jwks `json:"jwks"`
-}
-type OAuthData struct {
- ID string `json:"id"`
- OIDCProviderID string `json:"oidcProviderID"`
- CreatedAt time.Time `json:"createdAt"`
- Subject string `json:"subject"`
- Issuer string `json:"issuer"`
- UserInfo UserInfo `json:"userInfo"`
- Token Token `json:"token"`
- AuthFailed bool `json:"authFailed"`
- Suspended bool `json:"suspended"`
- LastTokenRenewal time.Time `json:"lastTokenRenewal"`
- RenewalFailed bool `json:"renewalFailed"`
- RenewalRetries int `json:"renewalRetries"`
-}
-
-type UserInfo struct {
- Email string `json:"email"`
-}
diff --git a/pkg/auth/provisioning/scim/helpers.go b/pkg/auth/provisioning/scim/helpers.go
deleted file mode 100644
index 8716cff..0000000
--- a/pkg/auth/provisioning/scim/helpers.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package scim
-
-import (
- "fmt"
- "net/http"
-)
-
-func returnError(w http.ResponseWriter, err error, statusCode int) {
- fmt.Println("========= ERROR =========")
- fmt.Printf("Error: %s\n", err)
- fmt.Println("=========================")
- w.WriteHeader(statusCode)
- w.Write([]byte(`{"error": "` + err.Error() + `"}`))
-}
-
-func writeWithStatus(w http.ResponseWriter, res []byte, status int) {
- w.Header().Add("Content-Type", "application/json")
- w.WriteHeader(status)
- w.Write(res)
-}
-
-func write(w http.ResponseWriter, res []byte) {
- w.Header().Add("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- w.Write(res)
-}
diff --git a/pkg/auth/provisioning/scim/middleware.go b/pkg/auth/provisioning/scim/middleware.go
deleted file mode 100644
index 10cfacb..0000000
--- a/pkg/auth/provisioning/scim/middleware.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package scim
-
-import (
- "fmt"
- "net/http"
- "strings"
-)
-
-func (s *scim) authMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
- writeWithStatus(w, []byte(`{"error": "token not found"}`), http.StatusUnauthorized)
- return
- }
- tokenString := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", -1)
- if len(tokenString) == 0 {
- returnError(w, fmt.Errorf("empty token"), http.StatusUnauthorized)
- return
- }
- if s.Token == "" {
- writeWithStatus(w, []byte(`{"error": "scim not active"}`), http.StatusUnauthorized)
- return
- }
- if s.Token != tokenString {
- writeWithStatus(w, []byte(`{"error": "authentication failed"}`), http.StatusUnauthorized)
- return
- }
-
- next.ServeHTTP(w, r)
- })
-}
diff --git a/pkg/auth/provisioning/scim/new.go b/pkg/auth/provisioning/scim/new.go
deleted file mode 100644
index 9c1043c..0000000
--- a/pkg/auth/provisioning/scim/new.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package scim
-
-import (
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-func New(storage storage.Iface, userStore *users.UserStore, token string) *scim {
- s := &scim{
- Token: token,
- UserStore: userStore,
- storage: storage,
- }
- return s
-}
diff --git a/pkg/auth/provisioning/scim/response.go b/pkg/auth/provisioning/scim/response.go
deleted file mode 100644
index 7549e75..0000000
--- a/pkg/auth/provisioning/scim/response.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package scim
-
-import (
- "encoding/json"
- "fmt"
-
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-func listUserResponse(users []users.User, attributes string, count, start int) ([]byte, error) {
- if start != -1 && start > 1 && start <= len(users) {
- users = users[start:]
- }
- totalResults := len(users)
- if len(users) > count && count != -1 {
- users = users[0:count]
- }
- response := UserResponse{
- TotalResults: totalResults,
- ItemsPerPage: len(users),
- StartIndex: start,
- Schemas: getSchemas("ListResponse"),
- Resources: make([]UserResource, len(users)),
- }
- for k := range users {
- response.Resources[k] = UserResource{
- ID: users[k].ID,
- UserName: users[k].Login,
- }
- }
- out, err := json.Marshal(response)
- if err != nil {
- return out, fmt.Errorf("json marshal error: %s", err)
- }
- return out, nil
-}
-
-func userResponse(user users.User) ([]byte, error) {
- response := PostUserRequest{
- Schemas: getSchemas("User"),
- Id: user.ID,
- UserName: user.Login,
- Active: !user.Suspended,
- }
- out, err := json.Marshal(response)
- if err != nil {
- return out, fmt.Errorf("json marshal error: %s", err)
- }
- return out, nil
-}
-
-func getSchemas(responseType string) []string {
- if responseType == "User" {
- return []string{"urn:ietf:params:scim:schemas:core:2.0:User"}
- }
- return []string{"urn:ietf:params:scim:api:messages:2.0:" + responseType}
-
-}
diff --git a/pkg/auth/provisioning/scim/router.go b/pkg/auth/provisioning/scim/router.go
deleted file mode 100644
index f8f0248..0000000
--- a/pkg/auth/provisioning/scim/router.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package scim
-
-import (
- "net/http"
-)
-
-func (s *scim) GetRouter() *http.ServeMux {
- mux := http.NewServeMux()
-
- mux.Handle("/api/scim/", s.authMiddleware(http.HandlerFunc(notFoundHandler)))
- mux.Handle("/api/scim/v2/Users", s.authMiddleware(http.HandlerFunc(s.usersHandler)))
- mux.Handle("/api/scim/v2/Users/{id}", s.authMiddleware(http.HandlerFunc(s.userHandler)))
-
- return mux
-}
-
-func notFoundHandler(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNotFound)
- w.Write([]byte(`{"error": "page not found"}`))
-}
diff --git a/pkg/auth/provisioning/scim/types.go b/pkg/auth/provisioning/scim/types.go
deleted file mode 100644
index 0794075..0000000
--- a/pkg/auth/provisioning/scim/types.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package scim
-
-import (
- "net/http"
-
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-type scim struct {
- Token string `json:"token"`
- UserStore *users.UserStore `json:"userStore"`
- storage storage.Iface
-}
-
-type Iface interface {
- GetRouter() *http.ServeMux
- UpdateToken(token string)
-}
-
-type UserResponse struct {
- TotalResults int `json:"totalResults"`
- ItemsPerPage int `json:"itemsPerPage"`
- StartIndex int `json:"startIndex"`
- Schemas []string `json:"schemas"`
- Resources []UserResource `json:"Resources"`
-}
-type UserResource struct {
- ID string `json:"id"`
- UserName string `json:"userName,omitempty"`
-}
-
-type PostUserRequest struct {
- Schemas []string `json:"schemas"`
- UserName string `json:"userName"`
- Id string `json:"id,omitempty"`
- Name Name `json:"name"`
- Emails []Emails `json:"emails"`
- DisplayName string `json:"displayName"`
- Locale string `json:"locale"`
- ExternalID string `json:"externalId"`
- Groups []any `json:"groups"`
- Password string `json:"password"`
- Active bool `json:"active"`
-}
-type Name struct {
- GivenName string `json:"givenName"`
- FamilyName string `json:"familyName"`
-}
-type Emails struct {
- Primary bool `json:"primary"`
- Value string `json:"value"`
- Type string `json:"type"`
-}
diff --git a/pkg/auth/provisioning/scim/update.go b/pkg/auth/provisioning/scim/update.go
deleted file mode 100644
index e1abc5e..0000000
--- a/pkg/auth/provisioning/scim/update.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package scim
-
-func (s *scim) UpdateToken(token string) {
- s.Token = token
-}
diff --git a/pkg/auth/provisioning/scim/users.go b/pkg/auth/provisioning/scim/users.go
deleted file mode 100644
index 34c7e27..0000000
--- a/pkg/auth/provisioning/scim/users.go
+++ /dev/null
@@ -1,247 +0,0 @@
-package scim
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "strconv"
- "strings"
-
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-// handler for multiple users
-func (s *scim) usersHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- s.getUsersHandler(w, r)
- return
- case http.MethodPost:
- s.postUsersHandler(w, r)
- return
- default:
- returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-// handler for a single user
-func (s *scim) userHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- s.getUserHandler(w, r)
- return
- case http.MethodPut:
- s.putUserHandler(w, r)
- return
- case http.MethodDelete:
- s.deleteUserHandler(w, r)
- return
- default:
- returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (s *scim) getUsersHandler(w http.ResponseWriter, r *http.Request) {
- attributes := r.URL.Query().Get("attributes")
- filter := r.URL.Query().Get("filter")
- count, err := strconv.Atoi(r.URL.Query().Get("count"))
- if err != nil {
- count = -1
- }
- start, err := strconv.Atoi(r.URL.Query().Get("startIndex"))
- if err != nil {
- start = 1
- }
-
- if filter != "" {
- response, err := getUsersWithFilter(s.UserStore, attributes, filter)
- if err != nil {
- returnError(w, fmt.Errorf("get user with filter error: %s", err), http.StatusBadRequest)
- return
- }
- write(w, response)
- return
- }
- response, err := getUsersWithoutFilter(s.UserStore, attributes, count, start)
- if err != nil {
- returnError(w, fmt.Errorf("get user with filter error: %s", err), http.StatusBadRequest)
- return
- }
- write(w, response)
-}
-
-func (s *scim) getUserHandler(w http.ResponseWriter, r *http.Request) {
- user, err := s.UserStore.GetUserByID(r.PathValue("id"))
- if err != nil {
- returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest)
- return
- }
-
- response, err := userResponse(user)
- if err != nil {
- returnError(w, fmt.Errorf("user response error: %s", err), http.StatusBadRequest)
- return
- }
-
- write(w, response)
-}
-func (s *scim) putUserHandler(w http.ResponseWriter, r *http.Request) {
- user, err := s.UserStore.GetUserByID(r.PathValue("id"))
- if err != nil {
- returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest)
- return
- }
-
- var putUserRequest PostUserRequest
- err = json.NewDecoder(r.Body).Decode(&putUserRequest)
- if err != nil {
- returnError(w, fmt.Errorf("unable to decode request payload"), http.StatusBadRequest)
- return
- }
-
- if !putUserRequest.Active && !user.Suspended { // user is suspended
- err = wireguard.DisableAllClientConfigs(s.storage, user.ID)
- if err != nil {
- returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest)
- return
- }
- }
- if putUserRequest.Active && user.Suspended { // user is unsuspended
- err := wireguard.ReactivateAllClientConfigs(s.storage, user.ID)
- if err != nil {
- returnError(w, fmt.Errorf("could not reactivate all clients for user %s: %s", user.ID, err), http.StatusBadRequest)
- return
- }
- }
-
- user.Suspended = !putUserRequest.Active
- username := getUsername(putUserRequest)
- if user.Login != username {
- if !s.UserStore.LoginExists(username) {
- user.Login = username
- }
- }
-
- err = s.UserStore.UpdateUser(user)
- if err != nil {
- returnError(w, fmt.Errorf("user update error: %s", err), http.StatusBadRequest)
- return
- }
-
- response, err := userResponse(user)
- if err != nil {
- returnError(w, fmt.Errorf("user response error: %s", err), http.StatusBadRequest)
- return
- }
-
- write(w, response)
-}
-
-func (s *scim) deleteUserHandler(w http.ResponseWriter, r *http.Request) {
- user, err := s.UserStore.GetUserByID(r.PathValue("id"))
- if err != nil {
- returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest)
- return
- }
-
- err = wireguard.DeleteAllClientConfigs(s.storage, user.ID)
- if err != nil {
- returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest)
- return
- }
-
- err = s.UserStore.DeleteUserByID(user.ID)
- if err != nil {
- returnError(w, fmt.Errorf("user update error: %s", err), http.StatusBadRequest)
- return
- }
-
- write(w, []byte(""))
-}
-
-func (s *scim) postUsersHandler(w http.ResponseWriter, r *http.Request) {
- var postUserRequest PostUserRequest
- err := json.NewDecoder(r.Body).Decode(&postUserRequest)
- if err != nil {
- returnError(w, fmt.Errorf("unable to decode request payload"), http.StatusBadRequest)
- return
- }
-
- username := getUsername(postUserRequest)
-
- if s.UserStore.LoginExists(username) {
- writeWithStatus(w, []byte("user already exists"), http.StatusConflict)
- return
- }
-
- if s.UserStore.GetMaxUsers()-s.UserStore.UserCount() <= 0 {
- writeWithStatus(w, []byte("no license available to add new user"), http.StatusBadRequest)
- return
- }
-
- user, err := s.UserStore.AddUser(users.User{
- Login: username,
- Role: "user",
- Provisioned: true,
- ExternalID: postUserRequest.ExternalID,
- })
- if err != nil {
- returnError(w, fmt.Errorf("unable to add user: %s", err), http.StatusBadRequest)
- return
- }
- response, err := userResponse(user)
- if err != nil {
- returnError(w, fmt.Errorf("unable to generate user response: %s", err), http.StatusBadRequest)
- return
- }
- writeWithStatus(w, response, http.StatusCreated)
-}
-
-func getUsername(postUserRequest PostUserRequest) string {
- username := postUserRequest.UserName
- if username == "" {
- for _, email := range postUserRequest.Emails {
- if email.Primary {
- username = email.Value
- }
- }
- }
- return username
-}
-
-func getUsersWithFilter(userStore *users.UserStore, attributes, filter string) ([]byte, error) {
- filterSplit := strings.Split(filter, " ")
- if len(filterSplit) != 3 {
- return []byte{}, fmt.Errorf("invalid filter")
- }
- if strings.ToLower(filterSplit[0]) == "username" {
- if strings.ToLower(filterSplit[1]) == "eq" {
- if userStore.LoginExists(strings.Trim(filterSplit[2], `"`)) {
- user, err := userStore.GetUserByLogin(strings.Trim(filterSplit[2], `"`))
- if err != nil {
- return []byte{}, fmt.Errorf("get user by login error: %s", err)
- }
- response, err := listUserResponse([]users.User{user}, attributes, -1, -1)
- if err != nil {
- return []byte{}, fmt.Errorf("userResponse error: %s", err)
- }
- return response, nil
- }
- }
- }
- response, err := listUserResponse([]users.User{}, attributes, -1, -1)
- if err != nil {
- return response, fmt.Errorf("userResponse error: %s", err)
- }
- return response, nil
-}
-
-func getUsersWithoutFilter(userStore *users.UserStore, attributes string, count, start int) ([]byte, error) {
- users := userStore.ListUsers()
- response, err := listUserResponse(users, attributes, count, start)
- if err != nil {
- return []byte{}, fmt.Errorf("userResponse error: %s", err)
- }
- return response, nil
-}
diff --git a/pkg/auth/saml/config.go b/pkg/auth/saml/config.go
deleted file mode 100644
index 3db4e72..0000000
--- a/pkg/auth/saml/config.go
+++ /dev/null
@@ -1 +0,0 @@
-package saml
diff --git a/pkg/auth/saml/handlers.go b/pkg/auth/saml/handlers.go
deleted file mode 100644
index 946a836..0000000
--- a/pkg/auth/saml/handlers.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package saml
-
-import (
- "fmt"
- "net/http"
-
- "github.com/google/uuid"
-)
-
-func (s *saml) samlHandler(w http.ResponseWriter, r *http.Request) {
- providerID := r.PathValue("id")
-
- if providerID == "" {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte("saml error: no provider specified\n"))
- return
- }
-
- provider, err := s.getProviderByID(providerID)
- if err != nil {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte(fmt.Sprintf("saml error: can't find provider with specified id: %s", err)))
- return
- }
-
- err = s.ensureSPLoaded(provider)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte(fmt.Sprintf("saml error: invalid saml configuration: %s\n", err)))
- return
- }
- err = r.ParseForm()
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- return
- }
-
- if r.Method != "POST" {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte("saml error: not a POST request\n"))
- return
- }
-
- if r.FormValue("SAMLResponse") == "" {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte("saml error: empty SAMLResponse\n"))
- return
- }
-
- if _, ok := s.serviceProvider[providerID]; !ok {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte("saml error: can't find provider with specified id\n"))
- return
- }
-
- assertionInfo, err := s.serviceProvider[providerID].RetrieveAssertionInfo(r.FormValue("SAMLResponse"))
- if err != nil {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte(fmt.Sprintf("saml error: %s\n", err)))
- return
- }
-
- if assertionInfo.WarningInfo.InvalidTime {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte("saml error: invalid time\n"))
- return
- }
-
- if assertionInfo.WarningInfo.NotInAudience {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte("saml error: incorrect audience\n"))
- return
- }
-
- login := assertionInfo.NameID
- notAfter := *assertionInfo.SessionNotOnOrAfter
-
- randomString, err := getRandomString(128)
- if err != nil {
- w.WriteHeader(http.StatusForbidden)
- w.Write([]byte(fmt.Sprintf("saml error: could not create session: %s\n", err)))
- return
- }
- sessionKey := SessionKey{
- ProviderID: providerID,
- SessionID: randomString,
- }
- s.CreateSession(sessionKey, AuthenticatedUser{
- ID: uuid.New().String(),
- Login: login,
- ExpiresAt: notAfter,
- })
- w.Header().Add("Location", fmt.Sprintf("/callback/saml/%s?code=%s", providerID, sessionKey.SessionID))
- w.WriteHeader(http.StatusFound)
-}
diff --git a/pkg/auth/saml/keypair.go b/pkg/auth/saml/keypair.go
deleted file mode 100644
index c52a4af..0000000
--- a/pkg/auth/saml/keypair.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package saml
-
-import (
- "bytes"
- "crypto/rand"
- "crypto/rsa"
- "crypto/tls"
- "crypto/x509"
- "crypto/x509/pkix"
- "encoding/pem"
- "fmt"
- "math/big"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-type KeyPair struct {
- storage storage.Iface
- hostname string
-}
-
-func NewKeyPair(storage storage.Iface, hostname string) *KeyPair {
- return &KeyPair{
- storage: storage,
- hostname: hostname,
- }
-}
-
-func (kp *KeyPair) GetKeyPair() (privateKey *rsa.PrivateKey, cert []byte, err error) {
- if !kp.storage.FileExists(kp.storage.ConfigPath("saml/saml.key")) {
- err := kp.generateKeyAndCert()
- if err != nil {
- return privateKey, cert, fmt.Errorf("can't generate saml key and cert: %s", err)
- }
- } else if !kp.storage.FileExists(kp.storage.ConfigPath("saml/saml.crt")) {
- err = kp.generateKeyAndCert()
- if err != nil {
- return privateKey, cert, fmt.Errorf("can't generate saml key and cert: %s", err)
- }
- }
-
- certPEMBlock, err := kp.storage.ReadFile(kp.storage.ConfigPath("saml/saml.crt"))
- if err != nil {
- return privateKey, cert, fmt.Errorf("can't read saml certificate: %s", err)
- }
- keyPEMBlock, err := kp.storage.ReadFile(kp.storage.ConfigPath("saml/saml.key"))
- if err != nil {
- return privateKey, cert, fmt.Errorf("can't read saml key: %s", err)
- }
- keyPair, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
- if err != nil {
- return privateKey, cert, fmt.Errorf("can't get saml keypair: %s", err)
- }
-
- privateKey = keyPair.PrivateKey.(*rsa.PrivateKey)
- cert = keyPair.Certificate[0]
-
- return privateKey, cert, nil
-}
-
-func (kp *KeyPair) generateKeyAndCert() error {
- var (
- certOut bytes.Buffer
- keyOut bytes.Buffer
- )
-
- privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
- if err != nil {
- return fmt.Errorf("private key generation failed: %s", err)
- }
-
- if err = pem.Encode(&keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}); err != nil {
- return err
- }
-
- serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
- serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
- if err != nil {
- return fmt.Errorf("rand int error: %s", err)
- }
- template := &x509.Certificate{
- SerialNumber: serialNumber,
- Subject: pkix.Name{
- CommonName: kp.hostname,
- },
- NotBefore: time.Now(),
- NotAfter: time.Now().AddDate(5, 0, 0),
-
- IsCA: true,
-
- KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
- ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
- BasicConstraintsValid: true,
- }
-
- derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
- if err != nil {
- return fmt.Errorf("certificate creation error: %s", err)
- }
- if err = pem.Encode(&certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
- return err
- }
-
- // ensure storagepath exists
- err = kp.storage.EnsurePath(kp.storage.ConfigPath("saml"))
- if err != nil {
- return fmt.Errorf("could not ensure saml path exists: %s", err)
- }
-
- err = kp.storage.WriteFile(kp.storage.ConfigPath("saml/saml.key"), keyOut.Bytes())
- if err != nil {
- return fmt.Errorf("saml key write error: %s", err)
- }
- err = kp.storage.WriteFile(kp.storage.ConfigPath("saml/saml.crt"), certOut.Bytes())
- if err != nil {
- return fmt.Errorf("saml key write error: %s", err)
- }
-
- return nil
-}
diff --git a/pkg/auth/saml/metadata.go b/pkg/auth/saml/metadata.go
deleted file mode 100644
index 8be0b1f..0000000
--- a/pkg/auth/saml/metadata.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package saml
-
-import (
- "encoding/xml"
- "fmt"
- "io"
- "net/http"
- "net/url"
-
- "github.com/russellhaering/gosaml2/types"
-)
-
-func (s *saml) HasValidMetadataURL(metadataURL string) (bool, error) {
- metadataURLParsed, err := url.Parse(metadataURL)
- if err != nil {
- return false, fmt.Errorf("url parse error: %s", err)
- }
- _, err = getMetadata(metadataURLParsed.String())
- if err != nil {
- return false, fmt.Errorf("fetch metadata error: %s", err)
- }
- return true, nil
-}
-
-func getMetadata(metadataURL string) (types.EntityDescriptor, error) {
- metadata := types.EntityDescriptor{}
-
- res, err := http.Get(metadataURL)
- if err != nil {
- return metadata, fmt.Errorf("can't retrieve saml metadata: %s", err)
- }
-
- rawMetadata, err := io.ReadAll(res.Body)
- if err != nil {
- return metadata, fmt.Errorf("can't read saml cert data: %s", err)
- }
-
- err = xml.Unmarshal(rawMetadata, &metadata)
- if err != nil {
- return metadata, fmt.Errorf("can't decode saml cert data: %s", err)
- }
- return metadata, nil
-}
diff --git a/pkg/auth/saml/middleware.go b/pkg/auth/saml/middleware.go
deleted file mode 100644
index 6c06349..0000000
--- a/pkg/auth/saml/middleware.go
+++ /dev/null
@@ -1,135 +0,0 @@
-package saml
-
-import (
- "crypto/x509"
- "encoding/base64"
- "encoding/xml"
- "fmt"
- "io"
- "net/http"
- "net/url"
-
- saml2 "github.com/russellhaering/gosaml2"
- "github.com/russellhaering/gosaml2/types"
- dsig "github.com/russellhaering/goxmldsig"
-)
-
-const ISSUER_URL = "saml/iss"
-const AUDIENCE_URL = "saml/aud"
-const ACS_URL = "saml/acs"
-
-func (s *saml) ensureSPLoaded(provider Provider) error {
- if _, ok := s.serviceProvider[provider.ID]; !ok {
- err := s.loadSP(provider)
- if err != nil {
- return fmt.Errorf("could not load saml provider: %s", err)
- }
- } else {
- // check if provider is up-to-date
- if provider.AllowMissingAttributes != s.serviceProvider[provider.ID].AllowMissingAttributes {
- s.serviceProvider[provider.ID] = nil
- }
- if s.serviceProvider[provider.ID] == nil {
- err := s.loadSP(provider)
- if err != nil {
- return fmt.Errorf("could not reload saml provider: %s", err)
- }
- }
- }
- return nil
-}
-
-func (s *saml) loadSP(provider Provider) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- idpMetadataURL, err := url.Parse(provider.MetadataURL)
- if err != nil {
- return fmt.Errorf("can't parse metadata url: %s", err)
- }
- // pull metadata
- res, err := http.Get(idpMetadataURL.String())
- if err != nil {
- return fmt.Errorf("can't retrieve saml metadata: %s", err)
- }
-
- rawMetadata, err := io.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("can't read saml cert data: %s", err)
- }
-
- metadata := &types.EntityDescriptor{}
- err = xml.Unmarshal(rawMetadata, metadata)
- if err != nil {
- return fmt.Errorf("can't decode saml cert data: %s", err)
- }
-
- // load certs
- certStore := dsig.MemoryX509CertificateStore{}
-
- if metadata.IDPSSODescriptor == nil || len(metadata.IDPSSODescriptor.KeyDescriptors) == 0 {
- return fmt.Errorf("keyDescriptors are empty")
- }
- if len(metadata.IDPSSODescriptor.SingleSignOnServices) == 0 {
- return fmt.Errorf("SingleSignOnServices not found")
- }
-
- certStore.Roots, err = getSAMLCertsFromMetadata(metadata.IDPSSODescriptor.KeyDescriptors)
- if err != nil {
- return fmt.Errorf("can't parse certs from metadata: %s", err)
- }
-
- keyStore := NewKeyPair(s.storage, *s.hostname)
-
- sp := &saml2.SAMLServiceProvider{
- IdentityProviderSSOURL: metadata.IDPSSODescriptor.SingleSignOnServices[0].Location,
- IdentityProviderIssuer: metadata.EntityID,
- ServiceProviderIssuer: fmt.Sprintf("%s://%s/%s/%s", *s.protocol, *s.hostname, ISSUER_URL, provider.ID),
- AssertionConsumerServiceURL: fmt.Sprintf("%s://%s/%s/%s", *s.protocol, *s.hostname, ACS_URL, provider.ID),
- SignAuthnRequests: true,
- AudienceURI: fmt.Sprintf("%s://%s/%s/%s", *s.protocol, *s.hostname, AUDIENCE_URL, provider.ID),
- IDPCertificateStore: &certStore,
- SPKeyStore: keyStore,
- AllowMissingAttributes: provider.AllowMissingAttributes,
- }
-
- s.serviceProvider[provider.ID] = sp
-
- return err
-}
-
-func getSAMLCertsFromMetadata(keyDescriptors []types.KeyDescriptor) ([]*x509.Certificate, error) {
- certs := []*x509.Certificate{}
-
- for _, kd := range keyDescriptors {
- for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates {
- if xcert.Data == "" {
- return nil, fmt.Errorf("metadata certificate(%d) must not be empty", idx)
- }
- certData, err := base64.StdEncoding.DecodeString(xcert.Data)
- if err != nil {
- return nil, fmt.Errorf("decode error:%s", err)
- }
-
- idpCert, err := x509.ParseCertificate(certData)
- if err != nil {
- return nil, fmt.Errorf("cert parse error: %s", err)
- }
-
- certs = append(certs, idpCert)
- }
- }
-
- return certs, nil
-
-}
-
-func (s *saml) GetAuthURL(provider Provider) (string, error) {
- err := s.ensureSPLoaded(provider)
- if err != nil {
- return "", fmt.Errorf("saml error: invalid saml configuration: %s", err)
- }
- if _, ok := s.serviceProvider[provider.ID]; !ok {
- return "", fmt.Errorf("provider not found")
- }
- return s.serviceProvider[provider.ID].BuildAuthURL("")
-}
diff --git a/pkg/auth/saml/new.go b/pkg/auth/saml/new.go
deleted file mode 100644
index 4ca68e0..0000000
--- a/pkg/auth/saml/new.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package saml
-
-import (
- "github.com/in4it/wireguard-server/pkg/storage"
- saml2 "github.com/russellhaering/gosaml2"
-)
-
-func New(providers *[]Provider, storage storage.Iface, protocol, hostname *string) Iface {
- s := &saml{
- Providers: providers,
- sessions: make(map[SessionKey]AuthenticatedUser),
- serviceProvider: make(map[string]*saml2.SAMLServiceProvider),
- protocol: protocol,
- hostname: hostname,
- storage: storage,
- }
- for _, provider := range *providers {
- s.loadSP(provider)
- }
- return s
-}
diff --git a/pkg/auth/saml/provider.go b/pkg/auth/saml/provider.go
deleted file mode 100644
index fc00389..0000000
--- a/pkg/auth/saml/provider.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package saml
-
-import "fmt"
-
-func (s *saml) getProviderByID(id string) (Provider, error) {
- for k := range *s.Providers {
- if (*s.Providers)[k].ID == id {
- return (*s.Providers)[k], nil
- }
- }
- return Provider{}, fmt.Errorf("provider not found")
-}
diff --git a/pkg/auth/saml/rand.go b/pkg/auth/saml/rand.go
deleted file mode 100644
index 294b95a..0000000
--- a/pkg/auth/saml/rand.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package saml
-
-import (
- "crypto/rand"
- "encoding/base64"
- "fmt"
- "io"
-)
-
-func getRandomString(n int) (string, error) {
- buf := make([]byte, n)
-
- _, err := io.ReadFull(rand.Reader, buf)
- if err != nil {
- return "", fmt.Errorf("crypto/rand Reader error: %s", err)
- }
-
- return base64.RawURLEncoding.EncodeToString(buf), nil
-}
diff --git a/pkg/auth/saml/router.go b/pkg/auth/saml/router.go
deleted file mode 100644
index 185e9e1..0000000
--- a/pkg/auth/saml/router.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package saml
-
-import (
- "net/http"
-)
-
-func (s *saml) GetRouter() *http.ServeMux {
- mux := http.NewServeMux()
- mux.Handle("/saml/acs/{id}", http.HandlerFunc(s.samlHandler))
-
- return mux
-}
diff --git a/pkg/auth/saml/session.go b/pkg/auth/saml/session.go
deleted file mode 100644
index 72dfa86..0000000
--- a/pkg/auth/saml/session.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package saml
-
-import (
- "fmt"
- "time"
-)
-
-func (s *saml) GetAuthenticatedUser(provider Provider, sessionID string) (AuthenticatedUser, error) {
- sessionKey := SessionKey{
- ProviderID: provider.ID,
- SessionID: sessionID,
- }
- if authenticatedUser, ok := s.sessions[sessionKey]; ok {
- if authenticatedUser.ExpiresAt.Before(time.Now()) {
- return authenticatedUser, fmt.Errorf("session is expired")
- }
- return authenticatedUser, nil
- }
- return AuthenticatedUser{}, fmt.Errorf("session not found")
-}
-
-func (s *saml) CreateSession(key SessionKey, value AuthenticatedUser) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.sessions[key] = value
-}
diff --git a/pkg/auth/saml/types.go b/pkg/auth/saml/types.go
deleted file mode 100644
index 1d16f48..0000000
--- a/pkg/auth/saml/types.go
+++ /dev/null
@@ -1,222 +0,0 @@
-package saml
-
-import (
- "encoding/xml"
- "net/http"
- "sync"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/storage"
- saml2 "github.com/russellhaering/gosaml2"
-)
-
-type Provider struct {
- ID string `json:"id"`
- Name string `json:"name"`
- MetadataURL string `json:"metadataURL"`
- Issuer string `json:"issuer,omitempty"`
- Audience string `json:"audience,omitempty"`
- Acs string `json:"acs,omitempty"`
- AllowMissingAttributes bool `json:"allowMissingAttributes,omitempty"`
-}
-
-type saml struct {
- Providers *[]Provider
- serviceProvider map[string]*saml2.SAMLServiceProvider
- sessions map[SessionKey]AuthenticatedUser
- hostname *string
- protocol *string
- mu sync.Mutex
- storage storage.Iface
-}
-type AuthenticatedUser struct {
- ID string
- Login string
- ExpiresAt time.Time
-}
-type SessionKey struct {
- ProviderID string
- SessionID string
-}
-
-type Iface interface {
- GetAuthURL(provider Provider) (string, error)
- GetRouter() *http.ServeMux
- GetAuthenticatedUser(provider Provider, sessionID string) (AuthenticatedUser, error)
- HasValidMetadataURL(metadataURL string) (bool, error)
- CreateSession(key SessionKey, value AuthenticatedUser)
-}
-
-type AuthnRequest struct {
- XMLName xml.Name `xml:"AuthnRequest"`
- Text string `xml:",chardata"`
- Samlp string `xml:"samlp,attr"`
- Saml string `xml:"saml,attr"`
- ID string `xml:"ID,attr"`
- Version string `xml:"Version,attr"`
- ProtocolBinding string `xml:"ProtocolBinding,attr"`
- AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"`
- IssueInstant string `xml:"IssueInstant,attr"`
- Destination string `xml:"Destination,attr"`
- Issuer string `xml:"Issuer"`
- Signature struct {
- Text string `xml:",chardata"`
- Ds string `xml:"ds,attr"`
- SignedInfo struct {
- Text string `xml:",chardata"`
- CanonicalizationMethod struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"CanonicalizationMethod"`
- SignatureMethod struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"SignatureMethod"`
- Reference struct {
- Text string `xml:",chardata"`
- URI string `xml:"URI,attr"`
- Transforms struct {
- Text string `xml:",chardata"`
- Transform []struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"Transform"`
- } `xml:"Transforms"`
- DigestMethod struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"DigestMethod"`
- DigestValue string `xml:"DigestValue"`
- } `xml:"Reference"`
- } `xml:"SignedInfo"`
- SignatureValue string `xml:"SignatureValue"`
- KeyInfo struct {
- Text string `xml:",chardata"`
- X509Data struct {
- Text string `xml:",chardata"`
- X509Certificate string `xml:"X509Certificate"`
- } `xml:"X509Data"`
- } `xml:"KeyInfo"`
- } `xml:"Signature"`
- NameIDPolicy struct {
- Text string `xml:",chardata"`
- AllowCreate string `xml:"AllowCreate,attr"`
- } `xml:"NameIDPolicy"`
-}
-
-type Response struct {
- XMLName xml.Name `xml:"Response"`
- Text string `xml:",chardata"`
- Saml string `xml:"xmlns:saml,attr"`
- Samlp string `xml:"xmlns:samlp,attr"`
- ID string `xml:"ID,attr"`
- Version string `xml:"Version,attr"`
- IssueInstant string `xml:"IssueInstant,attr"`
- Destination string `xml:"Destination,attr"`
- Issuer string `xml:"saml:Issuer"`
- Signature ResponseSignature `xml:"ds:Signature"`
- Status struct {
- Text string `xml:",chardata"`
- StatusCode struct {
- Text string `xml:",chardata"`
- Value string `xml:"Value,attr"`
- } `xml:"StatusCode"`
- } `xml:"Status"`
- Assertion ResponseAssertion `xml:"Assertion"`
-}
-
-type ResponseSignature struct {
- XMLName xml.Name `xml:"ds:Signature"`
- Text string `xml:",chardata"`
- Ds string `xml:"xmlns:ds,attr"`
- SignedInfo ResponseSignatureSignedInfo `xml:"ds:SignedInfo"`
- SignatureValue string `xml:"ds:SignatureValue"`
- KeyInfo ResponseSignatureKeyInfo `xml:"ds:KeyInfo"`
-}
-
-type ResponseSignatureSignedInfo struct {
- Text string `xml:",chardata"`
- CanonicalizationMethod struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"ds:CanonicalizationMethod"`
- SignatureMethod ResponseSignatureSignedInfoSignatureMethod `xml:"ds:SignatureMethod"`
- Reference ResponseSignatureSignedInfoReference `xml:"ds:Reference"`
-}
-type ResponseSignatureSignedInfoSignatureMethod struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
-}
-type ResponseSignatureSignedInfoReference struct {
- Text string `xml:",chardata"`
- URI string `xml:"URI,attr"`
- Transforms struct {
- Text string `xml:",chardata"`
- Transform []struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"ds:Transform"`
- } `xml:"ds:Transforms"`
- DigestMethod struct {
- Text string `xml:",chardata"`
- Algorithm string `xml:"Algorithm,attr"`
- } `xml:"ds:DigestMethod"`
- DigestValue string `xml:"ds:DigestValue"`
-}
-type ResponseSignatureKeyInfo struct {
- Text string `xml:",chardata"`
- X509Data struct {
- Text string `xml:",chardata"`
- X509Certificate string `xml:"ds:X509Certificate"`
- } `xml:"ds:X509Data"`
-}
-
-type ResponseConditions struct {
- Text string `xml:",chardata"`
- NotBefore string `xml:"NotBefore,attr"`
- NotOnOrAfter string `xml:"NotOnOrAfter,attr"`
- AudienceRestriction ResponseConditionsAdienceRestriction `xml:"AudienceRestriction"`
-}
-type ResponseConditionsAdienceRestriction struct {
- Text string `xml:",chardata"`
- Audience string `xml:"Audience"`
-}
-type ResponseSubject struct {
- Text string `xml:",chardata"`
- NameID struct {
- Text string `xml:",chardata"`
- Format string `xml:"Format,attr"`
- } `xml:"NameID"`
- SubjectConfirmation struct {
- Text string `xml:",chardata"`
- Method string `xml:"Method,attr"`
- SubjectConfirmationData struct {
- Text string `xml:",chardata"`
- NotOnOrAfter string `xml:"NotOnOrAfter,attr"`
- Recipient string `xml:"Recipient,attr"`
- } `xml:"SubjectConfirmationData"`
- } `xml:"SubjectConfirmation"`
-}
-
-type ResponseAssertion struct {
- Text string `xml:",chardata"`
- Saml string `xml:"saml,attr"`
- Xs string `xml:"xs,attr"`
- Xsi string `xml:"xsi,attr"`
- Version string `xml:"Version,attr"`
- ID string `xml:"ID,attr"`
- IssueInstant string `xml:"IssueInstant,attr"`
- Issuer string `xml:"Issuer"`
- Subject ResponseSubject `xml:"Subject"`
- Conditions ResponseConditions `xml:"Conditions"`
- AuthnStatement struct {
- Text string `xml:",chardata"`
- AuthnInstant string `xml:"AuthnInstant,attr"`
- SessionNotOnOrAfter string `xml:"SessionNotOnOrAfter,attr"`
- SessionIndex string `xml:"SessionIndex,attr"`
- AuthnContext struct {
- Text string `xml:",chardata"`
- AuthnContextClassRef string `xml:"AuthnContextClassRef"`
- } `xml:"AuthnContext"`
- } `xml:"AuthnStatement"`
-}
diff --git a/pkg/commands/resetmfa.go b/pkg/commands/resetmfa.go
index 5de910a..2ab0427 100644
--- a/pkg/commands/resetmfa.go
+++ b/pkg/commands/resetmfa.go
@@ -3,9 +3,9 @@ package commands
import (
"fmt"
- "github.com/in4it/wireguard-server/pkg/rest"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
+ "github.com/in4it/go-devops-platform/rest"
+ "github.com/in4it/go-devops-platform/storage"
+ "github.com/in4it/go-devops-platform/users"
)
func ResetAdminMFA(storage storage.Iface) error {
diff --git a/pkg/commands/resetpassword.go b/pkg/commands/resetpassword.go
index 2e5efa2..4182328 100644
--- a/pkg/commands/resetpassword.go
+++ b/pkg/commands/resetpassword.go
@@ -3,9 +3,9 @@ package commands
import (
"fmt"
- "github.com/in4it/wireguard-server/pkg/rest"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
+ "github.com/in4it/go-devops-platform/rest"
+ "github.com/in4it/go-devops-platform/storage"
+ "github.com/in4it/go-devops-platform/users"
)
func ResetPassword(storage storage.Iface, password string) (bool, error) {
diff --git a/pkg/commands/resetpassword_test.go b/pkg/commands/resetpassword_test.go
index 45972fa..2d21690 100644
--- a/pkg/commands/resetpassword_test.go
+++ b/pkg/commands/resetpassword_test.go
@@ -3,8 +3,8 @@ package commands
import (
"testing"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
- "github.com/in4it/wireguard-server/pkg/users"
+ memorystorage "github.com/in4it/go-devops-platform/storage/memory"
+ "github.com/in4it/go-devops-platform/users"
)
func TestResetPassword(t *testing.T) {
diff --git a/pkg/configmanager/refresh_darwin.go b/pkg/configmanager/refresh_darwin.go
index 6ffb396..a6ea07f 100644
--- a/pkg/configmanager/refresh_darwin.go
+++ b/pkg/configmanager/refresh_darwin.go
@@ -8,7 +8,7 @@ import (
"fmt"
"os"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
diff --git a/pkg/configmanager/refresh_linux.go b/pkg/configmanager/refresh_linux.go
index b8311c8..af6af8b 100644
--- a/pkg/configmanager/refresh_linux.go
+++ b/pkg/configmanager/refresh_linux.go
@@ -8,7 +8,7 @@ import (
"fmt"
"os"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
syncclients "github.com/in4it/wireguard-server/pkg/wireguard/linux/syncclients"
)
diff --git a/pkg/configmanager/server.go b/pkg/configmanager/server.go
index 3ef2580..f661821 100644
--- a/pkg/configmanager/server.go
+++ b/pkg/configmanager/server.go
@@ -5,8 +5,8 @@ import (
"log"
"net/http"
- "github.com/in4it/wireguard-server/pkg/storage"
- localstorage "github.com/in4it/wireguard-server/pkg/storage/local"
+ "github.com/in4it/go-devops-platform/storage"
+ localstorage "github.com/in4it/go-devops-platform/storage/local"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
diff --git a/pkg/configmanager/setupcode.go b/pkg/configmanager/setupcode.go
index 4d5bd23..755482c 100644
--- a/pkg/configmanager/setupcode.go
+++ b/pkg/configmanager/setupcode.go
@@ -7,7 +7,7 @@ import (
"io"
"os/user"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
)
func writeSetupCode(storage storage.Iface) error {
diff --git a/pkg/configmanager/start_darwin.go b/pkg/configmanager/start_darwin.go
index 07ea124..f4af3cd 100644
--- a/pkg/configmanager/start_darwin.go
+++ b/pkg/configmanager/start_darwin.go
@@ -6,7 +6,7 @@ package configmanager
import (
"fmt"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
diff --git a/pkg/configmanager/start_linux.go b/pkg/configmanager/start_linux.go
index 033ee02..50c7bd5 100644
--- a/pkg/configmanager/start_linux.go
+++ b/pkg/configmanager/start_linux.go
@@ -6,7 +6,7 @@ package configmanager
import (
"log"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
diff --git a/pkg/configmanager/types.go b/pkg/configmanager/types.go
index 0546805..7a73730 100644
--- a/pkg/configmanager/types.go
+++ b/pkg/configmanager/types.go
@@ -1,7 +1,7 @@
package configmanager
import (
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
diff --git a/pkg/license/aws.go b/pkg/license/aws.go
deleted file mode 100644
index 4b507e7..0000000
--- a/pkg/license/aws.go
+++ /dev/null
@@ -1,272 +0,0 @@
-package license
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-const AWS_PRODUCT_CODE = "7h7h3bnutjn0ziamv7npi8a69"
-
-func getMetadataToken(client http.Client) string {
- metadataEndpoint := "http://" + MetadataIP + "/latest/api/token"
-
- req, err := http.NewRequest("PUT", metadataEndpoint, nil)
- if err != nil {
- return ""
- }
-
- req.Header.Add("X-aws-ec2-metadata-token-ttl-seconds", "21600")
-
- resp, err := client.Do(req)
- if err != nil {
- return ""
- }
- defer resp.Body.Close()
- if resp.StatusCode == 200 {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return string(bodyBytes)
- }
- return ""
-}
-
-func isOnAWSMarketPlace(client http.Client) bool {
- token := getMetadataToken(client)
-
- instanceIdentityDocument, err := getInstanceIdentityDocument(client, token)
- if err != nil {
- return false
- }
- for _, productCode := range instanceIdentityDocument.MarketplaceProductCodes {
- if productCode == AWS_PRODUCT_CODE {
- return true
- }
- }
- return false
-}
-
-func isOnAWS(client http.Client) bool {
- token := getMetadataToken(client)
-
- instanceIdentityDocument, err := getInstanceIdentityDocument(client, token)
- if err != nil {
- return false
- }
- return instanceIdentityDocument.AccountID != "" || instanceIdentityDocument.Version != ""
-}
-
-func getInstanceIdentityDocument(client http.Client, token string) (InstanceIdentityDocument, error) {
- var instanceIdentityDocument InstanceIdentityDocument
-
- endpoint := "http://" + MetadataIP + "/2022-09-24/dynamic/instance-identity/document"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return instanceIdentityDocument, err
- }
- if token != "" {
- req.Header.Add("X-aws-ec2-metadata-token", token)
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return instanceIdentityDocument, err
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- return instanceIdentityDocument, err
- }
- err = json.NewDecoder(resp.Body).Decode(&instanceIdentityDocument)
- if err != nil {
- return instanceIdentityDocument, err
- }
-
- return instanceIdentityDocument, nil
-}
-
-func GetMaxUsersAWSBYOL(client http.Client, storage storage.ReadWriter) int {
- userLicense := 3
- licenseKey, err := getAWSLicenseKey(storage, client)
- if err != nil {
- return userLicense
- }
- license, err := getLicense(client, licenseKey)
- if err != nil {
- return userLicense
- }
- return license.Users
-}
-
-func getAWSLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) {
- token := getMetadataToken(client)
- licenseKey, err := getLicenseFromMetaData(token, client)
- if err != nil || licenseKey == "" {
- licenseKey, err = getLicenseKeyFromFile(storage)
- if err != nil {
- return "", err
- }
- }
-
- instanceIdentityDocument, err := getInstanceIdentityDocument(client, token)
- if err != nil {
- return "", err
- }
-
- return generateLicenseKey(licenseKey, instanceIdentityDocument.AccountID), nil
-}
-
-func getLicense(client http.Client, key string) (License, error) {
- var license License
- endpoint := licenseURL + "/" + key
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return license, err
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return license, err
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- return license, fmt.Errorf("statuscode %d", resp.StatusCode)
- }
- err = json.NewDecoder(resp.Body).Decode(&license)
- if err != nil {
- return license, err
- }
-
- return license, nil
-
-}
-
-func getLicenseFromMetaData(token string, client http.Client) (string, error) {
- endpoint := "http://" + MetadataIP + "/2022-09-24/meta-data/tags/instance/license"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return "", err
- }
- if token != "" {
- req.Header.Add("X-aws-ec2-metadata-token", token)
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- return "", err
- }
-
- bodyBytes, err := io.ReadAll(resp.Body)
- if resp.StatusCode != 200 {
- return "", err
- }
- return string(bodyBytes), nil
-}
-
-func getAWSInstanceType(client http.Client) string {
- token := getMetadataToken(client)
-
- endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-type"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return ""
- }
- if token != "" {
- req.Header.Add("X-aws-ec2-metadata-token", token)
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return ""
- }
- defer resp.Body.Close()
- if resp.StatusCode == 200 {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return string(bodyBytes)
- }
- return ""
-}
-
-func GetAWSInstanceID(client http.Client) (string, error) {
- token := getMetadataToken(client)
-
- endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-id"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return "", err
- }
- if token != "" {
- req.Header.Add("X-aws-ec2-metadata-token", token)
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
- if resp.StatusCode == 200 {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return string(bodyBytes), err
- }
- return "", fmt.Errorf("received statuscode %d from aws metadata api", resp.StatusCode)
-}
-
-func GetMaxUsersAWS(instanceType string) int {
- if instanceType == "" {
- return 3
- }
- if strings.HasSuffix(instanceType, ".nano") {
- return 3
- }
- if strings.HasSuffix(instanceType, ".micro") {
- return 10
- }
- if strings.HasSuffix(instanceType, ".small") {
- return 25
- }
- if strings.HasSuffix(instanceType, ".medium") {
- return 50
- }
- if strings.HasSuffix(instanceType, ".large") {
- return 100
- }
- if strings.HasSuffix(instanceType, ".xlarge") {
- return 250
- }
- if strings.HasSuffix(instanceType, ".2xlarge") {
- return 500
- }
- if strings.HasSuffix(instanceType, ".4xlarge") {
- return 1000
- }
- if strings.HasSuffix(instanceType, ".8xlarge") {
- return 2500
- }
- if strings.HasSuffix(instanceType, ".12xlarge") {
- return 5000
- }
- if strings.HasSuffix(instanceType, ".16xlarge") {
- return 10000
- }
- if strings.HasSuffix(instanceType, ".24xlarge") {
- return 10000
- }
- if strings.HasSuffix(instanceType, ".32xlarge") {
- return 10000
- }
- if strings.HasSuffix(instanceType, ".48xlarge") {
- return 10000
- }
- if strings.HasSuffix(instanceType, ".metal") {
- return 10000
- }
-
- return 3
-}
diff --git a/pkg/license/azure.go b/pkg/license/azure.go
deleted file mode 100644
index 18ab9de..0000000
--- a/pkg/license/azure.go
+++ /dev/null
@@ -1,81 +0,0 @@
-package license
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "regexp"
- "strconv"
-)
-
-func isOnAzure(client http.Client) bool {
- req, err := http.NewRequest("GET", "http://"+MetadataIP+"/metadata/versions", nil)
- if err != nil {
- return false
- }
-
- req.Header.Add("Metadata", "true")
-
- resp, err := client.Do(req)
- if err != nil {
- return false
- }
- defer resp.Body.Close()
- return resp.StatusCode == 200
-}
-
-func GetMaxUsersAzure(instanceType string) int {
- if instanceType == "" {
- return 3
- }
- // patterns
- versionPattern := regexp.MustCompile(`^.*v[0-9]+#`)
- cpuPattern := regexp.MustCompile("[0-9]+")
-
- // extract amount of CPUs
- instanceTypeNoVersion := versionPattern.ReplaceAllString(instanceType, "")
-
- instanceTypeCPUs := cpuPattern.FindAllString(instanceTypeNoVersion, -1)
-
- if len(instanceTypeCPUs) > 0 {
- instanceTypeCPUCount, err := strconv.Atoi(instanceTypeCPUs[0])
- if err != nil {
- return 3
- }
- if instanceTypeCPUCount == 0 {
- return 15
- }
- return instanceTypeCPUCount * 25
- }
-
- return 3
-}
-func getAzureInstanceType(client http.Client) string {
- metadataEndpoint := "http://" + MetadataIP + "/metadata/instance?api-version=2021-02-01"
- req, err := http.NewRequest("GET", metadataEndpoint, nil)
- if err != nil {
- return ""
- }
-
- req.Header.Add("Metadata", "true")
-
- resp, err := client.Do(req)
- if err != nil {
- return ""
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- return ""
- }
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- return ""
- }
- var instanceMetadata AzureInstanceMetadata
- err = json.Unmarshal(bodyBytes, &instanceMetadata)
- if err != nil {
- return ""
- }
- return instanceMetadata.Compute.VMSize
-}
diff --git a/pkg/license/digitalocean.go b/pkg/license/digitalocean.go
deleted file mode 100644
index 5be4c07..0000000
--- a/pkg/license/digitalocean.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package license
-
-import (
- "bufio"
- "fmt"
- "io"
- "net/http"
- "strings"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-func isOnDigitalOcean(client http.Client) bool {
- endpoint := "http://" + MetadataIP + "/metadata/v1/interfaces/private/0/type"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return false
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return false
- }
- defer resp.Body.Close()
- return resp.StatusCode == 200
-}
-
-func GetMaxUsersDigitalOceanBYOL(client http.Client, storage storage.ReadWriter) int {
- userLicense := 3
-
- licenseKey, err := getDigitalOceanLicenseKey(storage, client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("get digitalocean license error: %s", err))
- return userLicense
- }
-
- license, err := getLicense(client, licenseKey)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getLicense error: %s", err))
- return userLicense
- }
-
- return license.Users
-}
-
-func getDigitalOceanLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) {
- identifier, err := getDigitalOceanIdentifier(client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("License generation error (identifier error): %s", err))
- return "", err
- }
-
- licenseKey, err := getLicenseKeyFromFile(storage)
- if err != nil {
- return "", err
- }
-
- return generateLicenseKey(licenseKey, identifier), nil
-}
-
-func getDigitalOceanIdentifier(client http.Client) (string, error) {
- id := ""
- endpoint := "http://" + MetadataIP + "/metadata/v1/id"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return id, err
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return id, err
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return id, err
- }
- if resp.StatusCode != 200 {
- return id, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body)
- }
-
- return strings.TrimSpace(string(body)), nil
-
-}
-
-func HasDigitalOceanTagSet(client http.Client, tag string) (bool, error) {
- endpoint := "http://" + MetadataIP + "/metadata/v1/tags"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return false, err
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return false, err
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return false, err
- }
- return false, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body)
- }
-
- scanner := bufio.NewScanner(resp.Body)
- for scanner.Scan() {
- if tag == strings.TrimSpace(scanner.Text()) {
- return true, nil
- }
- }
-
- if err := scanner.Err(); err != nil {
- return false, err
- }
-
- return false, nil
-
-}
diff --git a/pkg/license/gcp.go b/pkg/license/gcp.go
deleted file mode 100644
index 5e69c54..0000000
--- a/pkg/license/gcp.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package license
-
-import (
- "fmt"
- "io"
- "net/http"
- "strings"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-func isOnGCP(client http.Client) bool {
- endpoint := "http://" + MetadataIP + "/computeMetadata/v1/"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return false
- }
-
- req.Header.Add("Metadata-Flavor", "Google")
-
- resp, err := client.Do(req)
- if err != nil {
- return false
- }
- defer resp.Body.Close()
- return resp.StatusCode == 200
-}
-
-func GetMaxUsersGCPBYOL(client http.Client, storage storage.ReadWriter) int {
- userLicense := 3
-
- licenseKey, err := getGCPLicenseKey(storage, client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("get gcp license error: %s", err))
- return userLicense
- }
-
- license, err := getLicense(client, licenseKey)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getLicense error: %s", err))
- return userLicense
- }
-
- return license.Users
-}
-
-func getGCPLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) {
- identifier, err := getGCPIdentifier(client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("License generation error (identifier error): %s", err))
- return "", err
- }
-
- licenseKey, err := getLicenseKeyFromFile(storage)
- if err != nil {
- return "", err
- }
-
- return generateLicenseKey(licenseKey, identifier), nil
-}
-
-func getGCPIdentifier(client http.Client) (string, error) {
- id := ""
- endpoint := "http://" + MetadataIP + "/computeMetadata/v1/project/project-id"
- req, err := http.NewRequest("GET", endpoint, nil)
- if err != nil {
- return id, err
- }
-
- req.Header.Add("Metadata-Flavor", "Google")
-
- resp, err := client.Do(req)
- if err != nil {
- return id, err
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return id, err
- }
- if resp.StatusCode != 200 {
- return id, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body)
- }
-
- return strings.TrimSpace(string(body)), nil
-
-}
diff --git a/pkg/license/gcp_test.go b/pkg/license/gcp_test.go
deleted file mode 100644
index f52b047..0000000
--- a/pkg/license/gcp_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package license
-
-import (
- "crypto/sha256"
- "fmt"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestGuessInfrastructureGCP(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/computeMetadata/v1/" {
- w.WriteHeader(http.StatusOK)
- return
- }
- w.WriteHeader(http.StatusNotFound)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- infra := guessInfrastructure()
-
- if infra != "gcp" {
- t.Fatalf("wrong infra returned: %s", infra)
- }
-}
-
-func TestGetMaxUsersGCPBYOL(t *testing.T) {
- projectID := "gcpproject-1234567890"
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/computeMetadata/v1/project/project-id" {
- w.Write([]byte(projectID))
-
- return
- }
- h := sha256.New()
- h.Write([]byte(projectID))
- if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) {
- w.Write([]byte(`{"users": 50}`))
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- licenseURL = ts.URL
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- mockStorage := &memorystorage.MockMemoryStorage{}
- err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license"))
- if err != nil {
- t.Fatalf("writefile error: %s", err)
- }
-
- for _, v := range []int{50} {
- if v2 := GetMaxUsersGCPBYOL(http.Client{Timeout: 5 * time.Second}, mockStorage); v2 != v {
- t.Fatalf("Wrong output: %d vs %d", v2, v)
- }
- }
-}
diff --git a/pkg/license/license.go b/pkg/license/license.go
deleted file mode 100644
index b3cfb77..0000000
--- a/pkg/license/license.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package license
-
-import (
- "crypto/sha256"
- "fmt"
- "net/http"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
- randomutils "github.com/in4it/wireguard-server/pkg/utils/random"
-)
-
-var MetadataIP = "169.254.169.254"
-var licenseURL = "https://in4it-vpn-server.s3.amazonaws.com/licenses"
-
-func guessInfrastructure() string {
- // check whether we are on AWS, Azure, DigitalOcean or something undefined
- client := http.Client{
- Timeout: 5 * time.Second,
- }
-
- if isOnAWSMarketPlace(client) {
- return "aws-marketplace"
- }
-
- if isOnAWS(client) {
- return "aws"
- }
-
- if isOnAzure(client) {
- return "azure"
- }
-
- if isOnDigitalOcean(client) {
- return "digitalocean"
- }
-
- if isOnGCP(client) {
- return "gcp"
- }
-
- return "" // no metadata server found
-}
-
-func GetInstanceType() (string, string) {
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- switch guessInfrastructure() {
- case "azure":
- return "azure", getAzureInstanceType(client)
- case "aws-marketplace":
- return "aws-marketplace", getAWSInstanceType(client)
- case "aws":
- return "aws", getAWSInstanceType(client)
- case "digitalocean":
- return "digitalocean", "droplet"
- case "gcp":
- return "gcp", "instance"
- default:
- return "", ""
- }
-}
-func GetMaxUsers(storage storage.ReadWriter) (int, string) {
- cloudType, instanceType := GetInstanceType()
- return getMaxUsers(storage, cloudType, instanceType), cloudType
-}
-func getMaxUsers(storage storage.ReadWriter, cloudType, instanceType string) int {
- switch cloudType {
- case "azure":
- return GetMaxUsersAzure(instanceType)
- case "aws-marketplace":
- return GetMaxUsersAWS(instanceType)
- case "aws":
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- return GetMaxUsersAWSBYOL(client, storage)
- case "digitalocean":
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- return GetMaxUsersDigitalOceanBYOL(client, storage)
- case "":
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- return GetMaxUsersBYOLNoCloud(client, storage)
- default:
- return 3
- }
-}
-
-func RefreshLicense(storage storage.ReadWriter, cloudType string, currentLicense int) int {
- if cloudType == "azure" || cloudType == "aws-marketplace" { // instance types / license is not going to change without a restart
- return currentLicense
- }
- cloudType, instanceType := GetInstanceType()
- return getMaxUsers(storage, cloudType, instanceType)
-}
-
-func GetLicenseKey(storage storage.ReadWriter, cloudType string) string {
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- switch cloudType {
- case "aws":
- licenseKey, err := getAWSLicenseKey(storage, client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getAWSLicense error: %s", err))
- return ""
- }
- return licenseKey
- case "digitalocean":
- licenseKey, err := getDigitalOceanLicenseKey(storage, client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getDigitalOceanLicense error: %s", err))
- return ""
- }
- return licenseKey
- case "gcp":
- licenseKey, err := getGCPLicenseKey(storage, client)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getGCPLicenseKey error: %s", err))
- return ""
- }
- return licenseKey
- default:
- licenseKey, err := getLicenseKeyFromFile(storage)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getLicenseKeyFromFile error: %s", err))
- return ""
- }
- return licenseKey
- }
-
-}
-
-func generateLicenseKey(key string, identifier string) string {
- h := sha256.New()
- h.Write([]byte(identifier))
- bs := h.Sum(nil)
-
- return key + "-" + fmt.Sprintf("%x", bs)
-}
-
-func getLicenseKeyFromFile(storage storage.ReadWriter) (string, error) {
- filename := storage.ConfigPath("license.key")
-
- if storage.FileExists(filename) {
- licenseKeyBytes, err := storage.ReadFile(filename)
- if err != nil {
- return "", fmt.Errorf("License read error: %s", err)
- }
- return string(licenseKeyBytes), nil
- }
- key, err := randomutils.GetRandomString(128)
- if err != nil {
- return "", fmt.Errorf("License generation error: %s", err)
- }
- err = storage.WriteFile(filename, []byte(key))
- if err != nil {
- return "", fmt.Errorf("License read error: %s", err)
- }
- return key, nil
-}
diff --git a/pkg/license/license_test.go b/pkg/license/license_test.go
deleted file mode 100644
index 87172fa..0000000
--- a/pkg/license/license_test.go
+++ /dev/null
@@ -1,413 +0,0 @@
-package license
-
-import (
- "crypto/sha256"
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestGetMaxUsersAzure(t *testing.T) {
- usersPerVCPU := 25
- testCases := map[string]int{
- "Standard_B1s": usersPerVCPU,
- "Basic_A0": 15,
- "Standard_D1_v2": usersPerVCPU,
- "Standard_D5_v2": usersPerVCPU * 5,
- "D96as_v6": usersPerVCPU * 96,
- "Standard_D16pls_v5": usersPerVCPU * 16,
- "Standard_DC1s_v3": usersPerVCPU * 1,
- }
- for k, v := range testCases {
- if GetMaxUsersAzure(k) != v {
- t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAzure(k), v)
- }
- }
-}
-
-func TestGetMaxUsersAWSMarketplace(t *testing.T) {
- testCases := map[string]int{
- "t3.medium": 50,
- "t3.large": 100,
- "t3.xlarge": 250,
- }
- for instanceType, v := range testCases {
- if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws-marketplace", instanceType) != v {
- t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v)
- }
- }
-}
-
-func TestGetMaxUsersAWSBYOL(t *testing.T) {
- accountID := "1234567890"
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/versions" {
- w.WriteHeader(http.StatusForbidden)
- return
- }
- if r.RequestURI == "/latest/api/token" {
- w.Write([]byte("this is a test token"))
- return
- }
- if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" {
- w.Write([]byte(`{
- "accountId" : "` + accountID + `",
- "architecture" : "x86_64",
- "availabilityZone" : "us-east-1c",
- "billingProducts" : null,
- "devpayProductCodes" : null,
- "marketplaceProductCodes" : [ "7h7h3bnutjn0ziamv7npi8a69" ],
- "imageId" : "ami-12345678",
- "instanceId" : "i-123456",
- "instanceType" : "t3.micro",
- "kernelId" : null,
- "pendingTime" : "2024-06-15T08:34:50Z",
- "privateIp" : "10.0.1.123",
- "ramdiskId" : null,
- "region" : "us-east-1",
- "version" : "2017-09-30"
-}`))
-
- return
- }
- if r.RequestURI == "/2022-09-24/meta-data/tags/instance/license" {
- w.Write([]byte(`license-1234556-license`))
- return
- }
- h := sha256.New()
- h.Write([]byte(accountID))
- if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) {
- w.Write([]byte(`{"users": 50}`))
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- testCases := map[string]int{
- "t3.medium": 50,
- "t3.xlarge": 50,
- }
- licenseURL = ts.URL
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
- for _, v := range testCases {
- if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &memorystorage.MockMemoryStorage{}); v2 != v {
- t.Fatalf("Wrong output: %d vs %d", v2, v)
- }
- }
-}
-
-func TestGetMaxUsersAWS(t *testing.T) {
- testCases := map[string]int{
- "t3.medium": 3,
- "t3.large": 3,
- "t3.xlarge": 3,
- }
- for instanceType, v := range testCases {
- if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws", instanceType) != v {
- t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v)
- }
- }
-}
-
-func TestGuessInfrastructureAzure(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/versions" {
- w.WriteHeader(http.StatusOK)
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- infra := guessInfrastructure()
-
- if infra != "azure" {
- t.Fatalf("wrong infra returned: %s", infra)
- }
-}
-
-func TestGuessInfrastructureAWSMarketplace(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/versions" {
- w.WriteHeader(http.StatusForbidden)
- return
- }
- if r.RequestURI == "/latest/api/token" {
- w.Write([]byte("this is a test token"))
- return
- }
- if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" {
- w.Write([]byte(`{
- "accountId" : "12345678",
- "architecture" : "x86_64",
- "availabilityZone" : "us-east-1c",
- "billingProducts" : null,
- "devpayProductCodes" : null,
- "marketplaceProductCodes" : [ "7h7h3bnutjn0ziamv7npi8a69" ],
- "imageId" : "ami-12345678",
- "instanceId" : "i-123456",
- "instanceType" : "t3.micro",
- "kernelId" : null,
- "pendingTime" : "2024-06-15T08:34:50Z",
- "privateIp" : "10.0.1.123",
- "ramdiskId" : null,
- "region" : "us-east-1",
- "version" : "2017-09-30"
-}`))
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- infra := guessInfrastructure()
-
- if infra != "aws-marketplace" {
- t.Fatalf("wrong infra returned: %s", infra)
- }
-}
-
-func TestGuessInfrastructureAWS(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/versions" {
- w.WriteHeader(http.StatusForbidden)
- return
- }
- if r.RequestURI == "/latest/api/token" {
- w.Write([]byte("this is a test token"))
- return
- }
- if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" {
- w.Write([]byte(`{
- "accountId" : "12345678",
- "architecture" : "x86_64",
- "availabilityZone" : "us-east-1c",
- "billingProducts" : null,
- "devpayProductCodes" : null,
- "marketplaceProductCodes" : null
-}`))
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- infra := guessInfrastructure()
-
- if infra != "aws" {
- t.Fatalf("wrong infra returned: %s", infra)
- }
-
- if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 {
- t.Fatalf("wrong users returned")
- }
-}
-
-func TestGuessInfrastructureOther(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- infra := guessInfrastructure()
-
- if infra != "" {
- t.Fatalf("wrong infra returned: %s", infra)
- }
-}
-
-func TestGetAzureInstanceType(t *testing.T) {
- vmSize := "Standard_D2_v5"
-
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- out, err := json.Marshal(AzureInstanceMetadata{
- Compute: Compute{
- VMSize: vmSize,
- },
- })
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- }
- w.Write(out)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- usersPerVCPU := 25
-
- users := getMaxUsers(&memorystorage.MockMemoryStorage{}, "azure", getAzureInstanceType(http.Client{Timeout: 5 * time.Second}))
-
- if users != usersPerVCPU*2 {
- t.Fatalf("Wrong user count returned")
- }
-}
-
-func TestGetAWSInstanceType(t *testing.T) {
- instanceType := "t4.xlarge"
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/latest/api/token" {
- w.WriteHeader(http.StatusForbidden)
- return
- }
- if r.RequestURI == "/latest/meta-data/instance-type" {
- w.Write([]byte(instanceType))
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
-
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- users := GetMaxUsersAWS(getAWSInstanceType(http.Client{Timeout: 5 * time.Second}))
-
- if users != 250 {
- t.Fatalf("Wrong user count returned: %d", users)
- }
-}
-
-func TestGuessInfrastructureDigitalOcean(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/v1/interfaces/private/0/type" {
- w.Write([]byte(`private`))
- return
- }
- w.WriteHeader(http.StatusNotFound)
- }))
- defer ts.Close()
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- infra := guessInfrastructure()
-
- if infra != "digitalocean" {
- t.Fatalf("wrong infra returned: %s", infra)
- }
-
- if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 {
- t.Fatalf("wrong users returned")
- }
-}
-
-func TestGetMaxUsersDigitalOceanBYOL(t *testing.T) {
- dropletID := "1234567890"
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/v1/interfaces/private/0/type" {
- w.Write([]byte(`private`))
- return
- }
- if r.RequestURI == "/metadata/v1/id" {
- w.Write([]byte(dropletID))
-
- return
- }
- h := sha256.New()
- h.Write([]byte(dropletID))
- if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) {
- w.Write([]byte(`{"users": 50}`))
- return
- }
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- licenseURL = ts.URL
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- mockStorage := &memorystorage.MockMemoryStorage{}
- err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license"))
- if err != nil {
- t.Fatalf("writefile error: %s", err)
- }
- for _, v := range []int{50} {
- if v2 := GetMaxUsersDigitalOceanBYOL(http.Client{Timeout: 5 * time.Second}, mockStorage); v2 != v {
- t.Fatalf("Wrong output: %d vs %d", v2, v)
- }
- }
-}
-
-func TestGetLicenseKey(t *testing.T) {
- dropletID := "1234567890"
- projectID := "googleproject-12356"
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/v1/interfaces/private/0/type" {
- w.Write([]byte(`private`))
- return
- }
- if r.RequestURI == "/metadata/v1/id" {
- w.Write([]byte(dropletID))
- return
- }
- if r.RequestURI == "/computeMetadata/v1/project/project-id" {
- w.Write([]byte(projectID))
- return
- }
- if r.RequestURI == "/metadata/versions" {
- w.WriteHeader(http.StatusForbidden)
- return
- }
- if r.RequestURI == "/latest/api/token" {
- w.Write([]byte("this is a test token"))
- return
- }
- if r.RequestURI == "/2022-09-24/dynamic/instance-identity/document" {
- w.Write([]byte(`{
- "accountId" : "12345678",
- "architecture" : "x86_64",
- "availabilityZone" : "us-east-1c",
- "billingProducts" : null,
- "devpayProductCodes" : null,
- "marketplaceProductCodes" : null
-}`))
- return
- }
-
- w.WriteHeader(http.StatusNotFound)
- }))
-
- MetadataIP = strings.Replace(ts.URL, "http://", "", -1)
-
- logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR
- key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "")
- if key == "" {
- t.Fatalf("key is empty")
- }
- key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "aws")
- if key == "" {
- t.Fatalf("aws key is empty")
- }
- key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "digitalocean")
- if key == "" {
- t.Fatalf("digitalocean key is empty")
- }
- key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "gcp")
- if key == "" {
- t.Fatalf("gcp key is empty")
- }
-}
-func TestGetLicenseKeyNoCloudProvider(t *testing.T) {
-
- logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR
- key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "")
- if key == "" {
- t.Fatalf("key is empty")
- }
-}
diff --git a/pkg/license/nocloud.go b/pkg/license/nocloud.go
deleted file mode 100644
index 0b6d036..0000000
--- a/pkg/license/nocloud.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package license
-
-import (
- "fmt"
- "net/http"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-func GetMaxUsersBYOLNoCloud(client http.Client, storage storage.ReadWriter) int {
- userLicense := 3
-
- licenseKey, err := getLicenseKeyFromFile(storage)
- if err != nil {
- return 3
- }
-
- license, err := getLicense(client, licenseKey)
- if err != nil {
- logging.DebugLog(fmt.Errorf("getLicense error: %s", err))
- return userLicense
- }
-
- return license.Users
-}
diff --git a/pkg/license/types.go b/pkg/license/types.go
deleted file mode 100644
index 9b96c10..0000000
--- a/pkg/license/types.go
+++ /dev/null
@@ -1,150 +0,0 @@
-package license
-
-import "time"
-
-type AzureInstanceMetadata struct {
- Compute Compute `json:"compute"`
- Network Network `json:"network"`
-}
-type OsProfile struct {
- AdminUsername string `json:"adminUsername"`
- ComputerName string `json:"computerName"`
- DisablePasswordAuthentication string `json:"disablePasswordAuthentication"`
-}
-type Plan struct {
- Name string `json:"name"`
- Product string `json:"product"`
- Publisher string `json:"publisher"`
-}
-type PublicKeys struct {
- KeyData string `json:"keyData"`
- Path string `json:"path"`
-}
-type SecurityProfile struct {
- SecureBootEnabled string `json:"secureBootEnabled"`
- VirtualTpmEnabled string `json:"virtualTpmEnabled"`
-}
-type ImageReference struct {
- ID string `json:"id"`
- Offer string `json:"offer"`
- Publisher string `json:"publisher"`
- Sku string `json:"sku"`
- Version string `json:"version"`
-}
-type DiffDiskSettings struct {
- Option string `json:"option"`
-}
-type EncryptionSettings struct {
- Enabled string `json:"enabled"`
-}
-type Image struct {
- URI string `json:"uri"`
-}
-type ManagedDisk struct {
- ID string `json:"id"`
- StorageAccountType string `json:"storageAccountType"`
-}
-type Vhd struct {
- URI string `json:"uri"`
-}
-type OsDisk struct {
- Caching string `json:"caching"`
- CreateOption string `json:"createOption"`
- DiffDiskSettings DiffDiskSettings `json:"diffDiskSettings"`
- DiskSizeGB string `json:"diskSizeGB"`
- EncryptionSettings EncryptionSettings `json:"encryptionSettings"`
- Image Image `json:"image"`
- ManagedDisk ManagedDisk `json:"managedDisk"`
- Name string `json:"name"`
- OsType string `json:"osType"`
- Vhd Vhd `json:"vhd"`
- WriteAcceleratorEnabled string `json:"writeAcceleratorEnabled"`
-}
-type ResourceDisk struct {
- Size string `json:"size"`
-}
-type StorageProfile struct {
- DataDisks []any `json:"dataDisks"`
- ImageReference ImageReference `json:"imageReference"`
- OsDisk OsDisk `json:"osDisk"`
- ResourceDisk ResourceDisk `json:"resourceDisk"`
-}
-type Compute struct {
- AzEnvironment string `json:"azEnvironment"`
- CustomData string `json:"customData"`
- EvictionPolicy string `json:"evictionPolicy"`
- IsHostCompatibilityLayerVM string `json:"isHostCompatibilityLayerVm"`
- LicenseType string `json:"licenseType"`
- Location string `json:"location"`
- Name string `json:"name"`
- Offer string `json:"offer"`
- OsProfile OsProfile `json:"osProfile"`
- OsType string `json:"osType"`
- PlacementGroupID string `json:"placementGroupId"`
- Plan Plan `json:"plan"`
- PlatformFaultDomain string `json:"platformFaultDomain"`
- PlatformUpdateDomain string `json:"platformUpdateDomain"`
- Priority string `json:"priority"`
- Provider string `json:"provider"`
- PublicKeys []PublicKeys `json:"publicKeys"`
- Publisher string `json:"publisher"`
- ResourceGroupName string `json:"resourceGroupName"`
- ResourceID string `json:"resourceId"`
- SecurityProfile SecurityProfile `json:"securityProfile"`
- Sku string `json:"sku"`
- StorageProfile StorageProfile `json:"storageProfile"`
- SubscriptionID string `json:"subscriptionId"`
- Tags string `json:"tags"`
- TagsList []any `json:"tagsList"`
- UserData string `json:"userData"`
- Version string `json:"version"`
- VMID string `json:"vmId"`
- VMScaleSetName string `json:"vmScaleSetName"`
- VMSize string `json:"vmSize"`
- Zone string `json:"zone"`
-}
-type IPAddress struct {
- PrivateIPAddress string `json:"privateIpAddress"`
- PublicIPAddress string `json:"publicIpAddress"`
-}
-type Subnet struct {
- Address string `json:"address"`
- Prefix string `json:"prefix"`
-}
-type Ipv4 struct {
- IPAddress []IPAddress `json:"ipAddress"`
- Subnet []Subnet `json:"subnet"`
-}
-type Ipv6 struct {
- IPAddress []any `json:"ipAddress"`
-}
-type Interface struct {
- Ipv4 Ipv4 `json:"ipv4"`
- Ipv6 Ipv6 `json:"ipv6"`
- MacAddress string `json:"macAddress"`
-}
-type Network struct {
- Interface []Interface `json:"interface"`
-}
-
-type InstanceIdentityDocument struct {
- AccountID string `json:"accountId"`
- Architecture string `json:"architecture"`
- AvailabilityZone string `json:"availabilityZone"`
- BillingProducts any `json:"billingProducts"`
- DevpayProductCodes any `json:"devpayProductCodes"`
- MarketplaceProductCodes []string `json:"marketplaceProductCodes"`
- ImageID string `json:"imageId"`
- InstanceID string `json:"instanceId"`
- InstanceType string `json:"instanceType"`
- KernelID any `json:"kernelId"`
- PendingTime time.Time `json:"pendingTime"`
- PrivateIP string `json:"privateIp"`
- RamdiskID any `json:"ramdiskId"`
- Region string `json:"region"`
- Version string `json:"version"`
-}
-
-type License struct {
- Users int `json:"users"`
-}
diff --git a/pkg/logging/log.go b/pkg/logging/log.go
deleted file mode 100644
index 2b0f555..0000000
--- a/pkg/logging/log.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package logging
-
-import "fmt"
-
-var Loglevel = 3
-
-const LOG_ERROR = 1
-const LOG_INFO = 2
-const LOG_DEBUG = 16
-
-func DebugLog(err error) {
- if Loglevel&LOG_DEBUG == LOG_DEBUG {
- fmt.Println("debug: " + err.Error())
- }
-}
-
-func ErrorLog(err error) {
- if Loglevel&LOG_ERROR == LOG_ERROR {
- fmt.Println("error: " + err.Error())
- }
-}
-
-func InfoLog(info string) {
- if Loglevel&LOG_INFO == LOG_INFO {
- fmt.Println("info: " + info)
- }
-}
diff --git a/pkg/logging/log_test.go b/pkg/logging/log_test.go
deleted file mode 100644
index 09a0ca8..0000000
--- a/pkg/logging/log_test.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package logging
-
-import "testing"
-
-func TestLog(t *testing.T) {
- Loglevel = LOG_DEBUG
- if Loglevel&LOG_DEBUG != LOG_DEBUG {
- t.Fatalf("log level is not debugging1")
- }
- Loglevel = LOG_DEBUG + LOG_ERROR
- if Loglevel&LOG_DEBUG != LOG_DEBUG {
- t.Fatalf("log level is not debugging: %d vs %d", Loglevel&LOG_DEBUG, Loglevel)
- }
- Loglevel = LOG_ERROR
- if Loglevel&LOG_DEBUG == LOG_DEBUG {
- t.Fatalf("log level is debugging: %d vs %d", Loglevel&LOG_DEBUG, Loglevel)
- }
-
-}
diff --git a/pkg/mfa/totp/verify.go b/pkg/mfa/totp/verify.go
deleted file mode 100644
index 651afb3..0000000
--- a/pkg/mfa/totp/verify.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package totp
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base32"
- "encoding/binary"
- "fmt"
- "strings"
- "time"
-)
-
-const INTERVAL = 30
-
-func GetToken(secret string, interval int64) (string, error) {
- key, err := base32.StdEncoding.DecodeString(strings.ToUpper(secret))
- if err != nil {
- return "", fmt.Errorf("base32 decode error: %s", err)
- }
- buf := make([]byte, 8)
- binary.BigEndian.PutUint64(buf, uint64(interval))
- hmacHash := hmac.New(sha1.New, key)
- hmacHash.Write(buf)
- h := hmacHash.Sum(nil)
- offset := (h[19] & 15)
-
- var header uint32
- r := bytes.NewReader(h[offset : offset+4])
- err = binary.Read(r, binary.BigEndian, &header)
-
- if err != nil {
- return "", fmt.Errorf("binary read error: %s", err)
- }
-
- return fmt.Sprintf("%06d", int((int(header)&0x7fffffff)%1000000)), nil
-}
-
-func Verify(secret, code string) (bool, error) {
- token, err := GetToken(secret, time.Now().Unix()/30)
- if err != nil {
- return false, fmt.Errorf("GetToken error: %s", err)
- }
- return token == code, nil
-}
-
-func VerifyMultipleIntervals(secret, code string, count int) (bool, error) {
- return verifyMultipleIntervals(secret, code, count, time.Now())
-}
-
-func verifyMultipleIntervals(secret, code string, count int, now time.Time) (bool, error) {
- for i := 0; i < count; i++ {
- token, err := GetToken(secret, now.Add(time.Duration(i)*time.Duration(-30)*time.Second).Unix()/30)
- if err != nil {
- return false, fmt.Errorf("GetToken error: %s", err)
- }
- if token == code {
- return true, nil
- }
- }
- return false, nil
-}
diff --git a/pkg/mfa/totp/verify_test.go b/pkg/mfa/totp/verify_test.go
deleted file mode 100644
index 5c52922..0000000
--- a/pkg/mfa/totp/verify_test.go
+++ /dev/null
@@ -1,49 +0,0 @@
-package totp
-
-import (
- "testing"
- "time"
-)
-
-func TestVerify(t *testing.T) { // validated with https://2fa.glitch.me/
- interval := int64(57275699)
- secret := "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
- token, err := GetToken(secret, interval)
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- if token != "840823" {
- t.Fatalf("wrong token. Got: %s", token)
- }
-}
-
-func TestVerifyWrongSecret(t *testing.T) {
- interval := int64(57275699)
- secret := "wrong secret"
- _, err := GetToken(secret, interval)
- if err == nil {
- t.Fatalf("expected error")
- }
-}
-
-func TestVerifyMultipleIntervals(t *testing.T) {
- secret := "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
- ok, err := verifyMultipleIntervals(secret, "312137", 20, time.Unix(1718272397, 0))
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- if !ok {
- t.Fatalf("no token matched")
- }
-}
-
-func TestVerifyMultipleIntervalsWrongToken(t *testing.T) {
- secret := "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
- ok, err := verifyMultipleIntervals(secret, "312137", 20, time.Unix(1718272000, 0))
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- if ok {
- t.Fatalf("token matched, but shouldn't have")
- }
-}
diff --git a/pkg/observability/buffer.go b/pkg/observability/buffer.go
deleted file mode 100644
index 3cc1fa0..0000000
--- a/pkg/observability/buffer.go
+++ /dev/null
@@ -1,178 +0,0 @@
-package observability
-
-import (
- "bytes"
- "fmt"
- "io"
- "path"
- "strconv"
- "strings"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-func (o *Observability) WriteBufferToStorage(n int64) error {
- o.ActiveBufferWriters.Add(1)
- defer o.ActiveBufferWriters.Done()
- o.WriteLock.Lock()
- defer o.WriteLock.Unlock()
- logging.DebugLog(fmt.Errorf("writing buffer to file. Buffer has: %d bytes", n))
- // copy first to temporary buffer (storage might have latency)
- tempBuf := bytes.NewBuffer(make([]byte, 0, n))
- _, err := io.CopyN(tempBuf, o.Buffer, n)
- if err != nil && err != io.EOF {
- return fmt.Errorf("write error from buffer to temporary buffer: %s", err)
- }
- prefix := o.Buffer.ReadPrefix(n)
- o.LastFlushed = time.Now()
-
- for _, bufferPosAndPrefix := range mergeBufferPosAndPrefix(prefix) {
- now := time.Now()
- filename := bufferPosAndPrefix.prefix + "/data-" + strconv.FormatInt(now.Unix(), 10) + "-" + strconv.FormatUint(o.FlushOverflowSequence.Add(1), 10)
- err = ensurePath(o.Storage, filename)
- if err != nil {
- return fmt.Errorf("ensure path error: %s", err)
- }
- file, err := o.Storage.OpenFileForWriting(filename)
- if err != nil {
- return fmt.Errorf("open file for writing error: %s", err)
- }
- _, err = io.CopyN(file, tempBuf, int64(bufferPosAndPrefix.offset))
- if err != nil && err != io.EOF {
- return fmt.Errorf("file write error: %s", err)
- }
- logging.DebugLog(fmt.Errorf("wrote file: %s", filename))
- err = file.Close()
- if err != nil {
- return fmt.Errorf("file close error: %s", err)
- }
- }
- return nil
-}
-
-func (o *Observability) monitorBuffer() {
- for {
- time.Sleep(FLUSH_TIME_MAX_MINUTES * time.Minute)
- if time.Since(o.LastFlushed) >= (FLUSH_TIME_MAX_MINUTES*time.Minute) && o.Buffer.Len() > 0 {
- if o.FlushOverflow.CompareAndSwap(false, true) {
- err := o.WriteBufferToStorage(int64(o.Buffer.Len()))
- o.FlushOverflow.Swap(true)
- if err != nil {
- logging.ErrorLog(fmt.Errorf("write log buffer to storage error: %s", err))
- continue
- }
- }
- o.LastFlushed = time.Now()
- }
- }
-}
-
-func (o *Observability) Ingest(data io.ReadCloser) error {
- defer data.Close()
- msgs, err := Decode(data)
- if err != nil {
- return fmt.Errorf("decode error: %s", err)
- }
- logging.DebugLog(fmt.Errorf("messages ingested: %d", len(msgs)))
- if len(msgs) == 0 {
- return nil // no messages to ingest
- }
- _, err = o.Buffer.Write(encodeMessage(msgs), FloatToDate(msgs[0].Date).Format(DATE_PREFIX))
- if err != nil {
- return fmt.Errorf("write error: %s", err)
- }
- if o.Buffer.Len() >= o.MaxBufferSize {
- if o.FlushOverflow.CompareAndSwap(false, true) {
- go func() { // write to storage
- if n := o.Buffer.Len(); n >= o.MaxBufferSize {
- err := o.WriteBufferToStorage(int64(n))
- if err != nil {
- logging.ErrorLog(fmt.Errorf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err))
- }
- }
- o.FlushOverflow.Swap(false)
- }()
- }
- }
- return nil
-}
-
-func (o *Observability) Flush() error {
- // wait until all data is flushed
- o.ActiveBufferWriters.Wait()
-
- // flush remaining data that hasn't been flushed
- if n := o.Buffer.Len(); n >= 0 {
- err := o.WriteBufferToStorage(int64(n))
- if err != nil {
- return fmt.Errorf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)
- }
- }
- return nil
-}
-
-func (c *ConcurrentRWBuffer) Write(p []byte, prefix string) (n int, err error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.prefix = append(c.prefix, BufferPosAndPrefix{prefix: prefix, offset: len(p)})
- return c.buffer.Write(p)
-}
-func (c *ConcurrentRWBuffer) Read(p []byte) (n int, err error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.buffer.Read(p)
-}
-func (c *ConcurrentRWBuffer) ReadPrefix(n int64) []BufferPosAndPrefix {
- c.mu.Lock()
- defer c.mu.Unlock()
- totalOffset := 0
- for k, v := range c.prefix {
- if int64(totalOffset+v.offset) == n {
- part1 := c.prefix[:k+1]
- part2 := make([]BufferPosAndPrefix, len(c.prefix[k+1:]))
- copy(part2, c.prefix[k+1:])
- c.prefix = part2
- return part1
- }
- totalOffset += v.offset
- }
- return nil
-}
-func (c *ConcurrentRWBuffer) Len() int {
- return c.buffer.Len()
-}
-func (c *ConcurrentRWBuffer) Cap() int {
- return c.buffer.Cap()
-}
-
-func ensurePath(storage storage.Iface, filename string) error {
- base := path.Dir(filename)
- baseSplit := strings.Split(base, "/")
- fullPath := ""
- for _, v := range baseSplit {
- fullPath = path.Join(fullPath, v)
- err := storage.EnsurePath(fullPath)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func mergeBufferPosAndPrefix(a []BufferPosAndPrefix) []BufferPosAndPrefix {
- bufferPosAndPrefix := []BufferPosAndPrefix{}
- for i := 0; i < len(a); i++ {
- offset := a[i].offset
- for y := i; y+1 < len(a) && a[y].prefix == a[y+1].prefix; y++ {
- offset += a[y+1].offset
- i++
- }
- bufferPosAndPrefix = append(bufferPosAndPrefix, BufferPosAndPrefix{
- prefix: a[i].prefix,
- offset: offset,
- })
- }
- return bufferPosAndPrefix
-}
diff --git a/pkg/observability/buffer_test.go b/pkg/observability/buffer_test.go
deleted file mode 100644
index d9c89d3..0000000
--- a/pkg/observability/buffer_test.go
+++ /dev/null
@@ -1,307 +0,0 @@
-package observability
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "slices"
- "strconv"
- "testing"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestIngestion(t *testing.T) {
- logging.Loglevel = logging.LOG_DEBUG
- totalMessagesToGenerate := 20
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, 20)
- o.Storage = storage
- payloads := IncomingData{}
- for i := 0; i < totalMessagesToGenerate/10; i++ {
- payloads = append(payloads, map[string]any{
- "date": 1720613813.197045,
- "log": "this is string: " + strconv.Itoa(i),
- })
- }
-
- for i := 0; i < totalMessagesToGenerate/len(payloads); i++ {
- payloadBytes, err := json.Marshal(payloads)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- data := io.NopCloser(bytes.NewReader(payloadBytes))
- err = o.Ingest(data)
- if err != nil {
- t.Fatalf("ingest error: %s", err)
- }
- }
-
- err := o.Flush()
- if err != nil {
- t.Fatalf("flush error: %s", err)
- }
-
- dirlist, err := storage.ReadDir("")
- if err != nil {
- t.Fatalf("read dir error: %s", err)
- }
-
- totalMessages := 0
- for _, file := range dirlist {
- messages, err := storage.ReadFile(file)
- if err != nil {
- t.Fatalf("read file error: %s", err)
- }
- decodedMessages := decodeMessages(messages)
- totalMessages += len(decodedMessages)
- }
- if len(dirlist) == 0 {
- t.Fatalf("expected multiple files in directory, got %d", len(dirlist))
- }
-
- if totalMessages != totalMessagesToGenerate {
- t.Fatalf("Tried to generate total message count of: %d; got: %d", totalMessagesToGenerate, totalMessages)
- }
-}
-
-func TestIngestionMoreMessages(t *testing.T) {
- t.Skip() // we can skip this for general unit testing
- totalMessagesToGenerate := 10000000 // 10,000,000
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE)
- payload := IncomingData{
- {
- "date": 1720613813.197045,
- "log": "this is string: ",
- },
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
-
- for i := 0; i < totalMessagesToGenerate; i++ {
- data := io.NopCloser(bytes.NewReader(payloadBytes))
- err := o.Ingest(data)
- if err != nil {
- t.Fatalf("ingest error: %s", err)
- }
- }
-
- err = o.Flush()
- if err != nil {
- t.Fatalf("flush error: %s", err)
- }
-
- dirlist, err := storage.ReadDir("")
- if err != nil {
- t.Fatalf("read dir error: %s", err)
- }
-
- totalMessages := 0
- for _, file := range dirlist {
- messages, err := storage.ReadFile(file)
- if err != nil {
- t.Fatalf("read file error: %s", err)
- }
- decodedMessages := decodeMessages(messages)
- totalMessages += len(decodedMessages)
- }
- if len(dirlist) == 0 {
- t.Fatalf("expected multiple files in directory, got %d", len(dirlist))
- }
-
- if totalMessages != totalMessagesToGenerate {
- t.Fatalf("Tried to generate total message count of: %d; got: %d", totalMessagesToGenerate, totalMessages)
- }
- fmt.Printf("Buffer size (read+unread): %d in %d files\n", o.Buffer.Cap(), len(dirlist))
-
-}
-
-func BenchmarkIngest10000000(b *testing.B) {
- totalMessagesToGenerate := 10000000 // 10,000,000
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE)
- payload := IncomingData{
- {
- "date": 1720613813.197045,
- "log": "this is string",
- },
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- b.Fatalf("marshal error: %s", err)
- }
-
- for i := 0; i < totalMessagesToGenerate; i++ {
- data := io.NopCloser(bytes.NewReader(payloadBytes))
- err := o.Ingest(data)
- if err != nil {
- b.Fatalf("ingest error: %s", err)
- }
- }
-
- // wait until all data is flushed
- o.ActiveBufferWriters.Wait()
-
- // flush remaining data that hasn't been flushed
- if n := o.Buffer.Len(); n >= 0 {
- err := o.WriteBufferToStorage(int64(n))
- if err != nil {
- b.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)
- }
- }
-}
-
-func BenchmarkIngest100000000(b *testing.B) {
- totalMessagesToGenerate := 10000000 // 10,000,000
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE)
- o.Storage = storage
- payload := IncomingData{
- {
- "date": 1720613813.197045,
- "log": "this is string",
- },
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- b.Fatalf("marshal error: %s", err)
- }
-
- for i := 0; i < totalMessagesToGenerate; i++ {
- data := io.NopCloser(bytes.NewReader(payloadBytes))
- err := o.Ingest(data)
- if err != nil {
- b.Fatalf("ingest error: %s", err)
- }
- }
-
- // wait until all data is flushed
- o.ActiveBufferWriters.Wait()
-
- // flush remaining data that hasn't been flushed
- if n := o.Buffer.Len(); n >= 0 {
- err := o.WriteBufferToStorage(int64(n))
- if err != nil {
- b.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)
- }
- }
-}
-
-func TestEnsurePath(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- err := ensurePath(storage, "a/b/c/filename.txt")
- if err != nil {
- t.Fatalf("error: %s", err)
- }
-}
-
-func TestMergeBufferPosAndPrefix(t *testing.T) {
- testCase1 := []BufferPosAndPrefix{
- {
- prefix: "abc",
- offset: 3,
- },
- {
- prefix: "abc",
- offset: 9,
- },
- {
- prefix: "abc",
- offset: 2,
- },
- {
- prefix: "abc2",
- offset: 3,
- },
- {
- prefix: "abc2",
- offset: 2,
- },
- {
- prefix: "abc3",
- offset: 2,
- },
- }
- expected1 := []BufferPosAndPrefix{
- {
- prefix: "abc",
- offset: 14,
- },
- {
- prefix: "abc2",
- offset: 5,
- },
- {
- prefix: "abc3",
- offset: 2,
- },
- }
- res := mergeBufferPosAndPrefix(testCase1)
- if !slices.Equal(res, expected1) {
- t.Fatalf("test case is not equal to expected\nGot: %+v\nExpected:%+v\n", res, expected1)
- }
-}
-
-func TestReadPrefix(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, MAX_BUFFER_SIZE)
- o.Buffer.prefix = []BufferPosAndPrefix{
- {
- prefix: "abc",
- offset: 3,
- },
- {
- prefix: "abc",
- offset: 9,
- },
- {
- prefix: "abc",
- offset: 2,
- },
- {
- prefix: "abc2",
- offset: 3,
- },
- {
- prefix: "abc2",
- offset: 2,
- },
- {
- prefix: "abc3",
- offset: 2,
- },
- }
- expected1 := []BufferPosAndPrefix{
- {
- prefix: "abc",
- offset: 3,
- },
- {
- prefix: "abc",
- offset: 9,
- },
- {
- prefix: "abc",
- offset: 2,
- },
- }
- expected2 := []BufferPosAndPrefix{
- {
- prefix: "abc2",
- offset: 3,
- },
- }
- res := o.Buffer.ReadPrefix(int64(o.Buffer.prefix[0].offset + o.Buffer.prefix[1].offset + o.Buffer.prefix[2].offset))
- if !slices.Equal(res, expected1) {
- t.Fatalf("test case is not equal to expected\nGot: %+v\nExpected:%+v\n", res, expected1)
- }
- res2 := o.Buffer.ReadPrefix(3)
- if !slices.Equal(res2, expected2) {
- t.Fatalf("test case is not equal to expected\nGot: %+v\nExpected:%+v\n", res, expected2)
- }
-}
diff --git a/pkg/observability/constants.go b/pkg/observability/constants.go
deleted file mode 100644
index cb5c72f..0000000
--- a/pkg/observability/constants.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package observability
-
-const MAX_BUFFER_SIZE = 1024 * 1024 // 1 MB
-const FLUSH_TIME_MAX_MINUTES = 1 // should have 5 as default at release
-
-const TIMESTAMP_FORMAT = "2006-01-02T15:04:05"
-
-const DATE_PREFIX = "2006/01/02"
diff --git a/pkg/observability/decoding.go b/pkg/observability/decoding.go
deleted file mode 100644
index cbd9062..0000000
--- a/pkg/observability/decoding.go
+++ /dev/null
@@ -1,119 +0,0 @@
-package observability
-
-import (
- "encoding/binary"
- "encoding/json"
- "fmt"
- "io"
- "math"
- "reflect"
- "strconv"
-)
-
-func Decode(r io.Reader) ([]FluentBitMessage, error) {
- var result []FluentBitMessage
- var msg interface{}
-
- err := json.NewDecoder(r).Decode(&msg)
- if err != nil {
- return result, err
- }
- switch m1 := msg.(type) {
- case []interface{}:
- if len(m1) == 0 {
- return result, fmt.Errorf("empty array")
- }
- for _, m1Element := range m1 {
- switch m2 := m1Element.(type) {
- case map[string]interface{}:
- var fluentBitMessage FluentBitMessage
- fluentBitMessage.Data = make(map[string]string)
- val, ok := m2["date"]
- if ok {
- fluentBitMessage.Date = val.(float64)
- }
- for key, value := range m2 {
- if key != "date" {
- switch valueTyped := value.(type) {
- case string:
- fluentBitMessage.Data[key] = valueTyped
- case float64:
- fluentBitMessage.Data[key] = strconv.FormatFloat(valueTyped, 'f', -1, 64)
- case []byte:
- fluentBitMessage.Data[key] = string(valueTyped)
- default:
- fmt.Printf("no hit on type: %s", reflect.TypeOf(valueTyped))
- }
- }
- }
- result = append(result, fluentBitMessage)
- default:
- return result, fmt.Errorf("invalid type: no map found in array")
- }
- }
- default:
- return result, fmt.Errorf("invalid type: no array found")
- }
- return result, nil
-}
-
-func decodeMessages(msgs []byte) []FluentBitMessage {
- res := []FluentBitMessage{}
- recordOffset := 0
- for k := 0; k < len(msgs); k++ {
- if k > recordOffset+8 && msgs[k] == 0xff && msgs[k-1] == 0xff {
- res = append(res, decodeMessage(msgs[recordOffset:k]))
- recordOffset = k + 1
- }
- }
- return res
-}
-func decodeMessage(data []byte) FluentBitMessage {
- bits := binary.LittleEndian.Uint64(data[0:8])
- msg := FluentBitMessage{
- Date: math.Float64frombits(bits),
- Data: map[string]string{},
- }
- isKey := false
- key := ""
- start := 8
- for kk := start; kk < len(data); kk++ {
- if data[kk] == 0xff {
- if isKey {
- isKey = false
- msg.Data[key] = string(data[start+1 : kk])
- start = kk + 1
- } else {
- isKey = true
- key = string(data[start:kk])
- start = kk
- }
- }
- }
- // if last record didn't end with the terminator
- if data[len(data)-1] != 0xff {
- msg.Data[key] = string(data[start+1:])
- }
- return msg
-}
-
-func scanMessage(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- for i := 0; i < len(data); i++ {
- if data[i] == 0xff && data[i-1] == 0xff {
- return i + 1, data[0 : i-1], nil
- }
- }
- // If we're at EOF, we have a final, non-terminated line. Return it.
- if atEOF {
- if len(data) > 1 && data[len(data)-1] == 0xff && data[len(data)-2] == 0xff {
- return len(data[0 : len(data)-2]), data, nil
- } else {
- return len(data), data, nil
- }
- }
- // Request more data.
- return 0, nil, nil
-}
diff --git a/pkg/observability/decoding_test.go b/pkg/observability/decoding_test.go
deleted file mode 100644
index 82fa736..0000000
--- a/pkg/observability/decoding_test.go
+++ /dev/null
@@ -1,220 +0,0 @@
-package observability
-
-import (
- "bytes"
- "encoding/json"
- "testing"
-)
-
-func TestDecoding(t *testing.T) {
- payload := IncomingData{
- {
- "date": 1720613813.197045,
- "rand_value": "rand",
- },
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("json marshal error: %s", err)
- }
- messages, err := Decode(bytes.NewBuffer(payloadBytes))
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- if len(messages) == 0 {
- t.Fatalf("no messages returned")
- }
- if messages[0].Date != 1720613813.197045 {
- t.Fatalf("wrong date returned")
- }
- val, ok := messages[0].Data["rand_value"]
- if !ok {
- t.Fatalf("rand_value key not found")
- }
- if string(val) != "rand" {
- t.Fatalf("wrong data returned: %s", val)
- }
-}
-
-func TestDecodingMultiMessage(t *testing.T) {
- payload := IncomingData{
- {
- "date": 1727119152.0,
- "container_name": "/fluentbit-nginx-1",
- "source": "stdout",
- "log": "/docker-entrypoint.sh: /docker-entrypoint.d/ is not empty, will attempt to perform configuration",
- "container_id": "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186",
- },
- {
- "date": 1727119152.0,
- "source": "stdout",
- "log": "/docker-entrypoint.sh: Looking for shell scripts in /docker-entrypoint.d/",
- "container_id": "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186",
- "container_name": "/fluentbit-nginx-1",
- },
- {
- "date": 1727119152.0,
- "container_id": "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186",
- "container_name": "/fluentbit-nginx-1",
- "source": "stdout",
- "log": "/docker-entrypoint.sh: Launching /docker-entrypoint.d/10-listen-on-ipv6-by-default.sh",
- },
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("json marshal error: %s", err)
- }
- messages, err := Decode(bytes.NewBuffer(payloadBytes))
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- if len(messages) != len(payload) {
- t.Fatalf("incorrect messages returned. Got %d, expected %d", len(messages), len(payload))
- }
- val, ok := messages[2].Data["container_id"]
- if !ok {
- t.Fatalf("container_id key not found")
- }
- if string(val) != "7a9c8ae0ca6c5434b778fa0a2e74e038710b3f18dedb3478235291832f121186" {
- t.Fatalf("wrong data returned: %s", val)
- }
-}
-
-func TestDecodeMessages(t *testing.T) {
- msgs := []FluentBitMessage{
- {
- Date: 1720613813.197045,
- Data: map[string]string{
- "mykey": "this is myvalue",
- "second key": "this is my second value",
- "third key": "this is my third value",
- },
- },
- {
- Date: 1720613813.197099,
- Data: map[string]string{
- "second data set": "my value",
- },
- },
- }
- encoded := encodeMessage(msgs)
- decoded := decodeMessages(encoded)
-
- if len(msgs) != len(decoded) {
- t.Fatalf("length doesn't match")
- }
- for k := range decoded {
- if msgs[k].Date != decoded[k].Date {
- t.Fatalf("date doesn't match")
- }
- if len(msgs[k].Data) != len(decoded[k].Data) {
- t.Fatalf("length of data doesn't match")
- }
- for kk := range decoded[k].Data {
- if msgs[k].Data[kk] != decoded[k].Data[kk] {
- t.Fatalf("key/value pair in data doesn't match: key: %s. Data: %s vs %s", kk, msgs[k].Data[kk], decoded[k].Data[kk])
- }
- }
- }
-}
-
-func TestDecodeMessage(t *testing.T) {
- msgs := []FluentBitMessage{
- {
- Date: 1720613813.197099,
- Data: map[string]string{
- "second data set": "my value",
- },
- },
- }
- encoded := encodeMessage(msgs)
- message := decodeMessage(encoded)
-
- if message.Date != message.Date {
- t.Fatalf("date doesn't match")
- }
- if len(msgs[0].Data) != len(message.Data) {
- t.Fatalf("length of data doesn't match")
- }
- for kk := range message.Data {
- if msgs[0].Data[kk] != message.Data[kk] {
- t.Fatalf("key/value pair in data doesn't match: key: %s. Data: %s vs %s", kk, message.Data[kk], message.Data[kk])
- }
- }
-}
-func TestDecodeMessageWithoutTerminator(t *testing.T) {
- msgs := []FluentBitMessage{
- {
- Date: 1720613813.197099,
- Data: map[string]string{
- "second data set": "my value",
- },
- },
- }
- encoded := encodeMessage(msgs)
- message := decodeMessage(bytes.TrimSuffix(encoded, []byte{0xff, 0xff}))
-
- if message.Date != message.Date {
- t.Fatalf("date doesn't match")
- }
- if len(msgs[0].Data) != len(message.Data) {
- t.Fatalf("length of data doesn't match: got: '%s', expected '%s'", message.Data, msgs[0].Data)
- }
- for kk := range message.Data {
- if msgs[0].Data[kk] != message.Data[kk] {
- t.Fatalf("key/value pair in data doesn't match: key: %s. Data: %s vs %s", kk, message.Data[kk], msgs[0].Data[kk])
- }
- }
-}
-
-func TestScanMessage(t *testing.T) {
- msgs := []FluentBitMessage{
- {
- Date: 1720613813.197045,
- Data: map[string]string{
- "mykey": "this is myvalue",
- "second key": "this is my second value",
- "third key": "this is my third value",
- },
- },
- {
- Date: 1720613813.197099,
- Data: map[string]string{
- "second data set": "my value",
- },
- },
- }
- encoded := encodeMessage(msgs)
- // first record
- advance, record1, err := scanMessage(encoded, false)
- if err != nil {
- t.Fatalf("scan lines error: %s", err)
- }
- firstRecord := decodeMessages(append(record1, []byte{0xff, 0xff}...))
- if len(firstRecord) == 0 {
- t.Fatalf("first record is empty")
- }
- if firstRecord[0].Data["mykey"] != "this is myvalue" {
- t.Fatalf("wrong data returned")
- }
- // second record
- advance2, record2, err := scanMessage(encoded[advance:], false)
- if err != nil {
- t.Fatalf("scan lines error: %s", err)
- }
- secondRecord := decodeMessages(append(record2, []byte{0xff, 0xff}...))
- if len(secondRecord) == 0 {
- t.Fatalf("first record is empty")
- }
- if secondRecord[0].Data["second data set"] != "my value" {
- t.Fatalf("wrong data returned in second record")
- }
- // third call
- advance3, record3, err := scanMessage(encoded[advance+advance2:], false)
- if err != nil {
- t.Fatalf("scan lines error: %s", err)
- }
- if advance3 != 0 {
- t.Fatalf("third record should be empty. Got: %+v", record3)
- }
-}
diff --git a/pkg/observability/encoding.go b/pkg/observability/encoding.go
deleted file mode 100644
index c805b8f..0000000
--- a/pkg/observability/encoding.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package observability
-
-import (
- "bytes"
- "encoding/binary"
- "math"
-)
-
-func encodeMessage(msgs []FluentBitMessage) []byte {
- out := bytes.Buffer{}
- for _, msg := range msgs {
- var buf [8]byte
- binary.LittleEndian.PutUint64(buf[:], math.Float64bits(msg.Date))
- out.Write(buf[:])
- for key, msgData := range msg.Data {
- out.Write([]byte(key))
- out.Write([]byte{0xff})
- out.Write([]byte(msgData))
- out.Write([]byte{0xff})
- }
- out.Write([]byte{0xff})
- }
- return out.Bytes()
-}
diff --git a/pkg/observability/handlers.go b/pkg/observability/handlers.go
deleted file mode 100644
index 884272f..0000000
--- a/pkg/observability/handlers.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package observability
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "strconv"
- "strings"
- "time"
-)
-
-func (o *Observability) observabilityHandler(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNotFound)
-}
-
-func (o *Observability) ingestionHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- w.WriteHeader(http.StatusBadRequest)
- return
- }
-
- if err := o.Ingest(r.Body); err != nil {
- w.WriteHeader(http.StatusBadRequest)
- fmt.Printf("error: %s", err)
- return
- }
- w.WriteHeader(http.StatusOK)
-}
-
-func (o *Observability) logsHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- if r.FormValue("fromDate") == "" {
- o.returnError(w, fmt.Errorf("no from date supplied"), http.StatusBadRequest)
- return
- }
- fromDate, err := time.Parse("2006-01-02", r.FormValue("fromDate"))
- if err != nil {
- o.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest)
- return
- }
- if r.FormValue("endDate") == "" {
- o.returnError(w, fmt.Errorf("no end date supplied"), http.StatusBadRequest)
- return
- }
- endDate, err := time.Parse("2006-01-02", r.FormValue("endDate"))
- if err != nil {
- o.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest)
- return
- }
- offset := 0
- if r.FormValue("offset") != "" {
- i, err := strconv.Atoi(r.FormValue("offset"))
- if err == nil {
- offset = i
- }
- }
- maxLines := 0
- if r.FormValue("maxLines") != "" {
- i, err := strconv.Atoi(r.FormValue("maxLines"))
- if err == nil {
- maxLines = i
- }
- }
- pos := int64(0)
- if r.FormValue("pos") != "" {
- i, err := strconv.ParseInt(r.FormValue("pos"), 10, 64)
- if err == nil {
- pos = i
- }
- }
- displayTags := strings.Split(r.FormValue("display-tags"), ",")
- filterTagsSplit := strings.Split(r.FormValue("filter-tags"), ",")
- filterTags := []KeyValue{}
- for _, tag := range filterTagsSplit {
- kv := strings.Split(tag, "=")
- if len(kv) == 2 {
- filterTags = append(filterTags, KeyValue{Key: kv[0], Value: kv[1]})
- }
- }
- out, err := o.getLogs(fromDate, endDate, pos, maxLines, offset, r.FormValue("search"), displayTags, filterTags)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- fmt.Printf("get logs error: %s", err)
- return
- }
- outBytes, err := json.Marshal(out)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- fmt.Printf("marshal error: %s", err)
- return
- }
- w.Write(outBytes)
-}
diff --git a/pkg/observability/handlers_test.go b/pkg/observability/handlers_test.go
deleted file mode 100644
index dc212a4..0000000
--- a/pkg/observability/handlers_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package observability
-
-import (
- "bytes"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
-
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestIngestionHandler(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, 20)
- o.Storage = storage
- payload := IncomingData{
- {
- "date": 1720613813.197045,
- "log": "this is a string",
- },
- }
-
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req := httptest.NewRequest(http.MethodPost, "/api/observability/ingestion/json", bytes.NewReader(payloadBytes))
- w := httptest.NewRecorder()
- o.ingestionHandler(w, req)
- res := w.Result()
-
- if res.StatusCode != http.StatusOK {
- t.Fatalf("expected status code OK. Got: %d", res.StatusCode)
- }
-
- // wait until all data is flushed
- o.ActiveBufferWriters.Wait()
-
- // flush remaining data that hasn't been flushed
- if n := o.Buffer.Len(); n >= 0 {
- err := o.WriteBufferToStorage(int64(n))
- if err != nil {
- t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)
- }
- }
-
- dirlist, err := storage.ReadDir("")
- if err != nil {
- t.Fatalf("read dir error: %s", err)
- }
- if len(dirlist) == 0 {
- t.Fatalf("dir is empty")
- }
- messages, err := storage.ReadFile(dirlist[0])
- if err != nil {
- t.Fatalf("read file error: %s", err)
- }
- decodedMessages := decodeMessages(messages)
- if decodedMessages[0].Date != 1720613813.197045 {
- t.Fatalf("unexpected date. Got %f, expected: %f", decodedMessages[0].Date, 1720613813.197045)
- }
- if decodedMessages[0].Data["log"] != "this is a string" {
- t.Fatalf("unexpected log data")
- }
-}
diff --git a/pkg/observability/helpers.go b/pkg/observability/helpers.go
deleted file mode 100644
index 2c98ee3..0000000
--- a/pkg/observability/helpers.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package observability
-
-import (
- "fmt"
- "math"
- "net/http"
- "strings"
- "time"
-)
-
-func (o *Observability) returnError(w http.ResponseWriter, err error, statusCode int) {
- fmt.Println("========= ERROR =========")
- fmt.Printf("Error: %s\n", err)
- fmt.Println("=========================")
- w.WriteHeader(statusCode)
- w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`))
-}
-
-func FloatToDate(datetime float64) time.Time {
- datetimeInt := int64(datetime)
- decimals := datetime - float64(datetimeInt)
- nsecs := int64(math.Round(decimals * 1_000_000)) // precision to match golang's time.Time
- return time.Unix(datetimeInt, nsecs*1000)
-}
-
-func DateToFloat(datetime time.Time) float64 {
- seconds := float64(datetime.Unix())
- nanoseconds := float64(datetime.Nanosecond()) / 1e9
- fmt.Printf("nanosec: %f", nanoseconds)
- return seconds + nanoseconds
-}
diff --git a/pkg/observability/helpers_test.go b/pkg/observability/helpers_test.go
deleted file mode 100644
index 195156d..0000000
--- a/pkg/observability/helpers_test.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package observability
-
-import (
- "testing"
- "time"
-)
-
-func TestFloatToDate2Way(t *testing.T) {
- now := time.Now()
- float := DateToFloat(now)
- date := FloatToDate(float)
- if date.Format(TIMESTAMP_FORMAT) != now.Format(TIMESTAMP_FORMAT) {
- t.Fatalf("got: %s, expected: %s", date.Format(TIMESTAMP_FORMAT), now.Format(TIMESTAMP_FORMAT))
- }
-}
diff --git a/pkg/observability/logs.go b/pkg/observability/logs.go
deleted file mode 100644
index 772770b..0000000
--- a/pkg/observability/logs.go
+++ /dev/null
@@ -1,116 +0,0 @@
-package observability
-
-import (
- "bufio"
- "fmt"
- "sort"
- "strings"
- "time"
-)
-
-func (o *Observability) getLogs(fromDate, endDate time.Time, pos int64, maxLogLines, offset int, search string, displayTags []string, filterTags []KeyValue) (LogEntryResponse, error) {
- logEntryResponse := LogEntryResponse{
- Enabled: true,
- LogEntries: []LogEntry{},
- Tags: KeyValueInt{},
- }
-
- keys := make(map[KeyValue]int)
-
- logFiles := []string{}
-
- if maxLogLines == 0 {
- maxLogLines = 100
- }
-
- for d := fromDate; d.Before(endDate) || d.Equal(endDate); d = d.AddDate(0, 0, 1) {
- fileList, err := o.Storage.ReadDir(d.Format(DATE_PREFIX))
- if err != nil {
- logEntryResponse.NextPos = -1
- return logEntryResponse, nil // can't read directory, return empty response
- }
- for _, filename := range fileList {
- logFiles = append(logFiles, d.Format(DATE_PREFIX)+"/"+filename)
- }
- }
-
- fileReaders, err := o.Storage.OpenFilesFromPos(logFiles, pos)
- if err != nil {
- return logEntryResponse, fmt.Errorf("error while reading files: %s", err)
- }
- for _, fileReader := range fileReaders {
- defer fileReader.Close()
- }
-
- for _, logInputData := range fileReaders { // read multiple files
- if len(logEntryResponse.LogEntries) >= maxLogLines {
- break
- }
- scanner := bufio.NewScanner(logInputData)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- advance, token, err = scanMessage(data, atEOF)
- pos += int64(advance)
- return
- })
- for scanner.Scan() && len(logEntryResponse.LogEntries) < maxLogLines { // read multiple lines
- // decode, store as logentry
- logMessage := decodeMessage(scanner.Bytes())
- logline, ok := logMessage.Data["log"]
- if ok {
- timestamp := FloatToDate(logMessage.Date).Add(time.Duration(offset) * time.Minute)
- if search == "" || strings.Contains(logline, search) {
- tags := []KeyValue{}
- for _, tag := range displayTags {
- if tagValue, ok := logMessage.Data[tag]; ok {
- tags = append(tags, KeyValue{Key: tag, Value: tagValue})
- }
- }
- filterMessage := true
- if len(filterTags) == 0 {
- filterMessage = false
- } else {
- for _, filter := range filterTags {
- if tagValue, ok := logMessage.Data[filter.Key]; ok {
- if tagValue == filter.Value {
- filterMessage = false
- }
- }
- }
- }
- if !filterMessage {
- logEntry := LogEntry{
- Timestamp: timestamp.Format(TIMESTAMP_FORMAT),
- Data: logline,
- Tags: tags,
- }
- logEntryResponse.LogEntries = append(logEntryResponse.LogEntries, logEntry)
- for k, v := range logMessage.Data {
- if k != "log" {
- keys[KeyValue{Key: k, Value: v}] += 1
- }
- }
- }
- }
- }
- }
- if err := scanner.Err(); err != nil {
- return logEntryResponse, fmt.Errorf("log file read (scanner) error: %s", err)
- }
- }
- if len(logEntryResponse.LogEntries) < maxLogLines {
- logEntryResponse.NextPos = -1 // no more records
- } else {
- logEntryResponse.NextPos = pos
- }
-
- for k, v := range keys {
- logEntryResponse.Tags = append(logEntryResponse.Tags, KeyValueTotal{
- Key: k.Key,
- Value: k.Value,
- Total: v,
- })
- }
- sort.Sort(logEntryResponse.Tags)
-
- return logEntryResponse, nil
-}
diff --git a/pkg/observability/logs_test.go b/pkg/observability/logs_test.go
deleted file mode 100644
index 1f4e34f..0000000
--- a/pkg/observability/logs_test.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package observability
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "strconv"
- "strings"
- "testing"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestGetLogs(t *testing.T) {
- logging.Loglevel = logging.LOG_DEBUG
- totalMessagesToGenerate := 100
- storage := &memorystorage.MockMemoryStorage{}
- o := NewWithoutMonitor(storage, 20)
- timestamp := DateToFloat(time.Now())
- payload := IncomingData{
- {
- "date": timestamp,
- "log": "this is string: ",
- },
- }
-
- for i := 0; i < totalMessagesToGenerate; i++ {
- payload[0]["log"] = "this is string: " + strconv.Itoa(i)
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- data := io.NopCloser(bytes.NewReader(payloadBytes))
- err = o.Ingest(data)
- if err != nil {
- t.Fatalf("ingest error: %s", err)
- }
- }
-
- // wait until all data is flushed
- o.ActiveBufferWriters.Wait()
-
- // flush remaining data that hasn't been flushed
- if n := o.Buffer.Len(); n >= 0 {
- err := o.WriteBufferToStorage(int64(n))
- if err != nil {
- t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)
- }
- }
-
- now := time.Now()
- maxLogLines := 100
- search := ""
-
- logEntryResponse, err := o.getLogs(now, now, 0, maxLogLines, 0, search, []string{}, []KeyValue{})
- if err != nil {
- t.Fatalf("get logs error: %s", err)
- }
- if len(logEntryResponse.LogEntries) != totalMessagesToGenerate {
- t.Fatalf("didn't get the same log entries as messaged we generated: got: %d, expected: %d", len(logEntryResponse.LogEntries), totalMessagesToGenerate)
- }
- if logEntryResponse.LogEntries[0].Timestamp != FloatToDate(timestamp).Format(TIMESTAMP_FORMAT) {
- t.Fatalf("unexpected timestamp: %s vs %s", logEntryResponse.LogEntries[0].Timestamp, FloatToDate(timestamp).Format(TIMESTAMP_FORMAT))
- }
-}
-
-func TestFloatToDate(t *testing.T) {
- for i := 0; i < 10; i++ {
- now := time.Now()
- floatDate := float64(now.Unix()) + float64(now.Nanosecond())/1e9
- floatToDate := FloatToDate(floatDate)
- if now.Unix() != floatToDate.Unix() {
- t.Fatalf("times are not equal. Got: %v, expected: %v", floatToDate, now)
- }
- /*if now.UnixNano() != floatToDate.UnixNano() {
- t.Fatalf("times are not equal. Got: %v, expected: %v", floatToDate, now)
- }*/
- }
-}
-
-func TestKeyValue(t *testing.T) {
- logEntryResponse := LogEntryResponse{
- Tags: KeyValueInt{
- {Key: "k", Value: "v", Total: 4},
- },
- }
- out, err := json.Marshal(logEntryResponse)
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- if !strings.Contains(string(out), `"tags":[{"key":"k","value":"v","total":4}]`) {
- t.Fatalf("wrong output: %s", out)
- }
-}
diff --git a/pkg/observability/new.go b/pkg/observability/new.go
deleted file mode 100644
index 7bf9f0b..0000000
--- a/pkg/observability/new.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package observability
-
-import (
- "net/http"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-func New(defaultStorage storage.Iface) *Observability {
- o := NewWithoutMonitor(defaultStorage, MAX_BUFFER_SIZE)
- go o.monitorBuffer()
- return o
-}
-func NewWithoutMonitor(storage storage.Iface, maxBufferSize int) *Observability {
- o := &Observability{
- Buffer: &ConcurrentRWBuffer{},
- MaxBufferSize: maxBufferSize,
- Storage: storage,
- }
- return o
-}
-
-type Iface interface {
- GetRouter() *http.ServeMux
-}
diff --git a/pkg/observability/router.go b/pkg/observability/router.go
deleted file mode 100644
index 7085fca..0000000
--- a/pkg/observability/router.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package observability
-
-import "net/http"
-
-func (o *Observability) GetRouter() *http.ServeMux {
- mux := http.NewServeMux()
- mux.Handle("/api/observability/", http.HandlerFunc(o.observabilityHandler))
- mux.Handle("/api/observability/ingestion/json", http.HandlerFunc(o.ingestionHandler))
- mux.Handle("/api/observability/logs", http.HandlerFunc(o.logsHandler))
-
- return mux
-}
diff --git a/pkg/observability/types.go b/pkg/observability/types.go
deleted file mode 100644
index b718bac..0000000
--- a/pkg/observability/types.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package observability
-
-import (
- "bytes"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-type IncomingData []map[string]any
-
-type FluentBitMessage struct {
- Date float64 `json:"date"`
- Data map[string]string `json:"data"`
-}
-
-type Observability struct {
- Storage storage.Iface
- Buffer *ConcurrentRWBuffer
- LastFlushed time.Time
- FlushOverflow atomic.Bool
- FlushOverflowSequence atomic.Uint64
- ActiveBufferWriters sync.WaitGroup
- WriteLock sync.Mutex
- MaxBufferSize int
-}
-
-type ConcurrentRWBuffer struct {
- buffer bytes.Buffer
- prefix []BufferPosAndPrefix
- mu sync.Mutex
-}
-
-type BufferPosAndPrefix struct {
- prefix string
- offset int
-}
-
-type LogEntryResponse struct {
- Enabled bool `json:"enabled"`
- LogEntries []LogEntry `json:"logEntries"`
- Tags KeyValueInt `json:"tags"`
- NextPos int64 `json:"nextPos"`
-}
-
-type LogEntry struct {
- Timestamp string `json:"timestamp"`
- Data string `json:"data"`
- Tags []KeyValue `json:"tags"`
-}
-
-type KeyValueInt []KeyValueTotal
-
-type KeyValueTotal struct {
- Key string
- Value string
- Total int
-}
-type KeyValue struct {
- Key string `json:"key"`
- Value string `json:"value"`
-}
-
-func (kv KeyValueInt) MarshalJSON() ([]byte, error) {
- res := "["
- for _, v := range kv {
- res += `{ "key" : "` + v.Key + `", "value": "` + v.Value + `", "total": ` + strconv.Itoa(v.Total) + ` },`
- }
- res = strings.TrimRight(res, ",")
- res += "]"
- return []byte(res), nil
-}
-
-func (kv KeyValueInt) Len() int {
- return len(kv)
-}
-func (kv KeyValueInt) Less(i, j int) bool {
- if kv[i].Key == kv[j].Key {
- return kv[i].Value < kv[j].Value
- }
- return kv[i].Key < kv[j].Key
-}
-func (kv KeyValueInt) Swap(i, j int) {
- kv[i], kv[j] = kv[j], kv[i]
-}
diff --git a/pkg/rest/auditlog/logentry.go b/pkg/rest/auditlog/logentry.go
deleted file mode 100644
index 6a09cb0..0000000
--- a/pkg/rest/auditlog/logentry.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package auditlog
-
-import (
- "encoding/json"
- "fmt"
- "path"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-const TIMESTAMP_FORMAT = "2006-01-02T15:04:05"
-const AUDITLOG_STATS_DIR = "stats"
-
-type LogEntry struct {
- Timestamp LogTimestamp `json:"timestamp"`
- UserID string `json:"userID"`
- Action string `json:"action"`
-}
-type LogTimestamp time.Time
-
-func (t LogTimestamp) MarshalJSON() ([]byte, error) {
- timestamp := fmt.Sprintf("\"%s\"", time.Time(t).Format(TIMESTAMP_FORMAT))
- return []byte(timestamp), nil
-}
-
-func Write(storage storage.Iface, logEntry LogEntry) error {
- statsPath := path.Join(AUDITLOG_STATS_DIR, "logins-"+time.Now().Format("2006-01-02")) + ".log"
- logEntryBytes, err := json.Marshal(logEntry)
- if err != nil {
- return fmt.Errorf("could not parse log entry: %s", err)
- }
- err = storage.AppendFile(statsPath, logEntryBytes)
- if err != nil {
- return fmt.Errorf("could not append stats to file (%s): %s", statsPath, err)
- }
-
- return nil
-}
diff --git a/pkg/rest/auth.go b/pkg/rest/auth.go
deleted file mode 100644
index 5196ab5..0000000
--- a/pkg/rest/auth.go
+++ /dev/null
@@ -1,419 +0,0 @@
-package rest
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "strings"
- "time"
-
- "github.com/google/uuid"
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store"
- "github.com/in4it/wireguard-server/pkg/auth/saml"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/rest/login"
-)
-
-func (c *Context) authHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- c.returnError(w, fmt.Errorf("not a post request"), http.StatusBadRequest)
- return
- }
-
- if c.LocalAuthDisabled {
- c.returnError(w, fmt.Errorf("local auth is disabled in settings"), http.StatusForbidden)
- return
- }
-
- decoder := json.NewDecoder(r.Body)
- var loginReq login.LoginRequest
- err := decoder.Decode(&loginReq)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
-
- // check login attempts
- tooManyLogins := login.CheckTooManyLogins(c.LoginAttempts, loginReq.Login)
- if tooManyLogins {
- c.returnError(w, fmt.Errorf("too many login failures, try again later"), http.StatusTooManyRequests)
- return
- }
-
- loginResponse, user, err := login.Authenticate(loginReq, c.UserStore, c.JWTKeys.PrivateKey, c.JWTKeysKID)
- if err != nil {
- c.returnError(w, fmt.Errorf("authentication error: %s", err), http.StatusBadRequest)
- return
- }
- out, err := json.Marshal(loginResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("unable to marshal response: %s", err), http.StatusBadRequest)
- return
- }
- if loginResponse.MFARequired {
- c.write(w, out) // status ok, but unauthorized, because we need a second call with MFA code
- return
- } else if loginResponse.Authenticated {
- login.ClearAttemptsForLogin(c.LoginAttempts, loginReq.Login)
- user.LastLogin = time.Now()
- err = c.UserStore.UpdateUser(user)
- if err != nil {
- logging.ErrorLog(fmt.Errorf("last login update error: %s", err))
- }
- c.write(w, out)
- } else {
- // log login attempts
- login.RecordAttempt(c.LoginAttempts, loginReq.Login)
- // return Unauthorized
- c.writeWithStatus(w, out, http.StatusUnauthorized)
- }
-}
-
-func (c *Context) oidcProviderHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- oidcProviders := make([]oidc.OIDCProvider, len(c.OIDCProviders))
- copy(oidcProviders, c.OIDCProviders)
- for k := range oidcProviders {
- oidcProviders[k].LoginURL = fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, strings.Replace(oidcProviders[k].RedirectURI, "/callback/", "/login/", -1))
- oidcProviders[k].RedirectURI = fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProviders[k].RedirectURI)
- }
- out, err := json.Marshal(oidcProviders)
- if err != nil {
- c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var oidcProvider oidc.OIDCProvider
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&oidcProvider)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- oidcProvider.ID = uuid.New().String()
- if oidcProvider.Name == "" {
- c.returnError(w, fmt.Errorf("name not set"), http.StatusBadRequest)
- return
- }
- if oidcProvider.ClientID == "" {
- c.returnError(w, fmt.Errorf("clientID not set"), http.StatusBadRequest)
- return
- }
- if oidcProvider.ClientSecret == "" {
- c.returnError(w, fmt.Errorf("clientSecret not set"), http.StatusBadRequest)
- return
- }
- if oidcProvider.Scope == "" {
- c.returnError(w, fmt.Errorf("scope not set"), http.StatusBadRequest)
- return
- }
- if oidcProvider.DiscoveryURI == "" {
- c.returnError(w, fmt.Errorf("discovery URL not set"), http.StatusBadRequest)
- return
- }
- oidcProvider.RedirectURI = "/callback/oidc/" + oidcProvider.ID
- c.OIDCProviders = append(c.OIDCProviders, oidcProvider)
- out, err := json.Marshal(oidcProvider)
- if err != nil {
- c.returnError(w, fmt.Errorf("oidcProvider marshal error: %s", err), http.StatusBadRequest)
- return
- }
- err = SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- return
- }
-}
-
-func (c *Context) oidcProviderElementHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodDelete:
- match := -1
- for k, oidcProvider := range c.OIDCProviders {
- if oidcProvider.ID == r.PathValue("id") {
- match = k
- }
- }
- if match == -1 {
- c.returnError(w, fmt.Errorf("oidc provider not found"), http.StatusBadRequest)
- return
- }
- c.OIDCProviders = append(c.OIDCProviders[:match], c.OIDCProviders[match+1:]...)
- // save config (changed providers)
- err := SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, []byte(`{ "deleted": "`+r.PathValue("id")+`" }`))
- }
-}
-
-func (c *Context) authMethods(w http.ResponseWriter, r *http.Request) {
- response := AuthMethodsResponse{
- LocalAuthDisabled: c.LocalAuthDisabled,
- OIDCProviders: make([]AuthMethodsProvider, len(c.OIDCProviders)),
- }
- for k, oidcProvider := range c.OIDCProviders {
- response.OIDCProviders[k] = AuthMethodsProvider{
- ID: oidcProvider.ID,
- Name: oidcProvider.Name,
- }
- }
-
- out, err := json.Marshal(response)
- if err != nil {
- c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
-}
-func (c *Context) authMethodsByID(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodPost:
- switch r.PathValue("method") {
- case "saml":
- loginResponse := login.LoginResponse{}
- var samlCallback SAMLCallback
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&samlCallback)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- if samlCallback.Code == "" {
- c.returnError(w, fmt.Errorf("no code provided"), http.StatusBadRequest)
- return
- }
- var samlProvider saml.Provider
- for k := range *c.SAML.Providers {
- if r.PathValue("id") == (*c.SAML.Providers)[k].ID {
- samlProvider = (*c.SAML.Providers)[k]
- }
- }
- if samlProvider.ID == "" {
- c.returnError(w, fmt.Errorf("saml provider not found"), http.StatusBadRequest)
- return
- }
-
- samlSession, err := c.SAML.Client.GetAuthenticatedUser(samlProvider, samlCallback.Code)
- if err != nil {
- c.returnError(w, fmt.Errorf("saml session not found"), http.StatusBadRequest)
- return
- }
-
- // add user to the user database (or modify existing one)
- user, err := addOrModifyExternalUser(c.Storage.Client, c.UserStore, samlSession.Login, "saml", samlSession.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("couldn't add/modify user in database: %s", err), http.StatusBadRequest)
- return
- }
-
- if user.Suspended {
- loginResponse.Suspended = true
- }
-
- token, err := login.GetJWTTokenWithExpiration(user.Login, user.Role, c.JWTKeys.PrivateKey, c.JWTKeysKID, samlSession.ExpiresAt)
- if err != nil {
- c.returnError(w, fmt.Errorf("token generation failed: %s", err), http.StatusBadRequest)
- return
- }
- loginResponse.Authenticated = true
- loginResponse.Token = token
-
- out, err := json.Marshal(loginResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- return
- default: // oidc is default
- loginResponse := login.LoginResponse{}
- var oidcCallback OIDCCallback
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&oidcCallback)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- for _, oidcProvider := range c.OIDCProviders {
- if r.PathValue("id") == oidcProvider.ID && oidcCallback.Code != "" { // we got the code back
- oidcstore.RetrieveTokenLock.Lock()
- defer oidcstore.RetrieveTokenLock.Unlock()
- oauth2data, err := oidc.RetrieveOAUth2DataUsingState(c.OIDCStore.OAuth2Data, oidcCallback.State) // get the oauth2 struct based on the state (key)
- if err != nil {
- c.returnError(w, fmt.Errorf("cannot find oauth2 data using state provided: %s", err), http.StatusBadRequest)
- return
- }
- if oauth2data.Token.AccessToken != "" {
- if oauth2data.Suspended {
- loginResponse.Suspended = true
- } else if c.LicenseUserCount >= c.UserStore.UserCount() {
- loginResponse.NoLicense = true
- } else {
- loginResponse.Authenticated = true
- loginResponse.Token = oauth2data.Token.AccessToken
- }
- out, err := json.Marshal(loginResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- return
- }
- // no token, let's generate a new one
- discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI)
- if err != nil {
- c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest)
- return
- }
- jwks, err := c.OIDCStore.GetJwks(discovery.JwksURI)
- if err != nil {
- c.returnError(w, fmt.Errorf("get jwks error: %s", err), http.StatusBadRequest)
- return
- }
- updatedOauth2data, err := oidc.UpdateOAuth2DataWithToken(jwks, discovery, oidcProvider.ClientID, oidcProvider.ClientSecret, c.Protocol+"://"+c.Hostname+oidcCallback.RedirectURI, oidcCallback.Code, oidcCallback.State, oauth2data)
- if err != nil {
- c.returnError(w, fmt.Errorf("GetTokenFromCode error: %s", err), http.StatusBadRequest)
- return
- }
- // add user to the user database (or modify existing one)
- user, err := addOrModifyExternalUser(c.Storage.Client, c.UserStore, updatedOauth2data.UserInfo.Email, "oidc", updatedOauth2data.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("couldn't add/modify user in database: %s", err), http.StatusBadRequest)
- return
- }
- if user.Suspended {
- loginResponse.Suspended = true
- updatedOauth2data.Suspended = true
- } else {
- updatedOauth2data.Suspended = false
- }
- // save oauth data (only when we're sure it's not a suspended user)
- err = c.OIDCStore.SaveOAuth2Data(updatedOauth2data, oidcCallback.State)
- if err != nil {
- c.returnError(w, fmt.Errorf("oidc store save failed: %s", err), http.StatusBadRequest)
- return
- }
- // cleanup oauth2 data
- c.OIDCStore.CleanupOAuth2Data(updatedOauth2data)
-
- // save config (changed user info)
- err = SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest)
- return
- }
-
- // set loginResponse
- if !loginResponse.Suspended {
- loginResponse.Authenticated = true
- loginResponse.Token = updatedOauth2data.Token.AccessToken
- }
- out, err := json.Marshal(loginResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- return
- }
- }
- }
- c.returnError(w, fmt.Errorf("oidc provider not found"), http.StatusBadRequest)
- case http.MethodGet:
- switch r.PathValue("method") {
- case "saml":
- id := r.PathValue("id")
- samlProviderId := -1
- for k := range *c.SAML.Providers {
- if (*c.SAML.Providers)[k].ID == id {
- samlProviderId = k
- }
- }
- if samlProviderId == -1 {
- c.returnError(w, fmt.Errorf("cannot find saml provider"), http.StatusBadRequest)
- return
- }
- redirectURI, err := c.SAML.Client.GetAuthURL((*c.SAML.Providers)[samlProviderId])
- if err != nil {
- c.returnError(w, fmt.Errorf("cannot get auth url"), http.StatusBadRequest)
- return
- }
- response := AuthMethodsProvider{
- ID: (*c.SAML.Providers)[samlProviderId].ID,
- Name: (*c.SAML.Providers)[samlProviderId].Name,
- RedirectURI: redirectURI,
- }
- out, err := json.Marshal(response)
- if err != nil {
- c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- return
- default:
- id := r.PathValue("id")
- for _, oidcProvider := range c.OIDCProviders {
- if id == oidcProvider.ID {
- callback := fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProvider.RedirectURI)
- discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI)
- if err != nil {
- c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest)
- return
- }
- redirectURI, state, err := oidc.GetRedirectURI(discovery, oidcProvider.ClientID, oidcProvider.Scope, callback, c.EnableOIDCTokenRenewal)
- if err != nil {
- c.returnError(w, fmt.Errorf("GetRedirectURI error: %s", err), http.StatusBadRequest)
- return
- }
- response := AuthMethodsProvider{
- ID: oidcProvider.ID,
- Name: oidcProvider.Name,
- RedirectURI: redirectURI,
- }
- out, err := json.Marshal(response)
- if err != nil {
- c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest)
- return
- }
- newOAuthEntry := oidc.OAuthData{
- ID: uuid.NewString(),
- OIDCProviderID: response.ID,
- CreatedAt: time.Now(),
- }
- err = c.OIDCStore.SaveOAuth2Data(newOAuthEntry, state)
- if err != nil {
- c.returnError(w, fmt.Errorf("unable to save state to oidc store: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- return
- }
- }
- c.returnError(w, fmt.Errorf("element not found"), http.StatusBadRequest)
- }
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) oidcRenewTokensHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodPost:
- c.OIDCRenewal.RenewAllOIDCConnections()
- c.write(w, []byte(`{"status": "done"}`))
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
diff --git a/pkg/rest/auth_test.go b/pkg/rest/auth_test.go
deleted file mode 100644
index 848e8f3..0000000
--- a/pkg/rest/auth_test.go
+++ /dev/null
@@ -1,946 +0,0 @@
-package rest
-
-import (
- "bytes"
- "compress/flate"
- "crypto/rand"
- "crypto/rsa"
- "encoding/base64"
- "encoding/json"
- "encoding/xml"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- "github.com/in4it/wireguard-server/pkg/auth/saml"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/rest/login"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/russellhaering/gosaml2/types"
- dsigtypes "github.com/russellhaering/goxmldsig/types"
-)
-
-func getSAMLCertWithCustomCert(singleSignOnURL string, cert string) *types.EntityDescriptor {
- return &types.EntityDescriptor{
- EntityID: "https://www.idp.inv/metadata",
- IDPSSODescriptor: &types.IDPSSODescriptor{
- SingleSignOnServices: []types.SingleSignOnService{
- {
- Location: singleSignOnURL,
- },
- },
- KeyDescriptors: []types.KeyDescriptor{
- {
- KeyInfo: dsigtypes.KeyInfo{
- X509Data: dsigtypes.X509Data{
- X509Certificates: []dsigtypes.X509Certificate{
- {
- Data: cert,
- },
- },
- },
- },
- },
- },
- },
- }
-}
-func getSAMLCert(singleSignOnURL string) *types.EntityDescriptor {
- cert := `MIID2jCCA0MCAg39MA0GCSqGSIb3DQEBBQUAMIGbMQswCQYDVQQGEwJKUDEOMAwG
-A1UECBMFVG9reW8xEDAOBgNVBAcTB0NodW8ta3UxETAPBgNVBAoTCEZyYW5rNERE
-MRgwFgYDVQQLEw9XZWJDZXJ0IFN1cHBvcnQxGDAWBgNVBAMTD0ZyYW5rNEREIFdl
-YiBDQTEjMCEGCSqGSIb3DQEJARYUc3VwcG9ydEBmcmFuazRkZC5jb20wHhcNMTIw
-ODIyMDUyODAwWhcNMTcwODIxMDUyODAwWjBKMQswCQYDVQQGEwJKUDEOMAwGA1UE
-CAwFVG9reW8xETAPBgNVBAoMCEZyYW5rNEREMRgwFgYDVQQDDA93d3cuZXhhbXBs
-ZS5jb20wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCwvWITOLeyTbS1
-Q/UacqeILIK16UHLvSymIlbbiT7mpD4SMwB343xpIlXN64fC0Y1ylT6LLeX4St7A
-cJrGIV3AMmJcsDsNzgo577LqtNvnOkLH0GojisFEKQiREX6gOgq9tWSqwaENccTE
-sAXuV6AQ1ST+G16s00iN92hjX9V/V66snRwTsJ/p4WRpLSdAj4272hiM19qIg9zr
-h92e2rQy7E/UShW4gpOrhg2f6fcCBm+aXIga+qxaSLchcDUvPXrpIxTd/OWQ23Qh
-vIEzkGbPlBA8J7Nw9KCyaxbYMBFb1i0lBjwKLjmcoihiI7PVthAOu/B71D2hKcFj
-Kpfv4D1Uam/0VumKwhwuhZVNjLq1BR1FKRJ1CioLG4wCTr0LVgtvvUyhFrS+3PdU
-R0T5HlAQWPMyQDHgCpbOHW0wc0hbuNeO/lS82LjieGNFxKmMBFF9lsN2zsA6Qw32
-Xkb2/EFltXCtpuOwVztdk4MDrnaDXy9zMZuqFHpv5lWTbDVwDdyEQNclYlbAEbDe
-vEQo/rAOZFl94Mu63rAgLiPeZN4IdS/48or5KaQaCOe0DuAb4GWNIQ42cYQ5TsEH
-Wt+FIOAMSpf9hNPjDeu1uff40DOtsiyGeX9NViqKtttaHpvd7rb2zsasbcAGUl+f
-NQJj4qImPSB9ThqZqPTukEcM/NtbeQIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAIAi
-gU3My8kYYniDuKEXSJmbVB+K1upHxWDA8R6KMZGXfbe5BRd8s40cY6JBYL52Tgqd
-l8z5Ek8dC4NNpfpcZc/teT1WqiO2wnpGHjgMDuDL1mxCZNL422jHpiPWkWp3AuDI
-c7tL1QjbfAUHAQYwmHkWgPP+T2wAv0pOt36GgMCM`
- return getSAMLCertWithCustomCert(singleSignOnURL, cert)
-}
-
-func TestAuthHandler(t *testing.T) {
- c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context: %s", err)
- }
- c.UserStore.Empty()
- _, err = c.UserStore.AddUser(users.User{
- Login: "john",
- Password: "mypass",
- })
- if err != nil {
- t.Fatalf("Cannot create user")
- }
-
- loginReq := login.LoginRequest{
- Login: "john",
- Password: "mypass",
- }
-
- payload, err := json.Marshal(loginReq)
- if err != nil {
- t.Fatal(err)
- }
-
- req := httptest.NewRequest("POST", "http://example.com/api/auth", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.authHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- var loginResponse login.LoginResponse
-
- err = json.NewDecoder(resp.Body).Decode(&loginResponse)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- if !loginResponse.Authenticated {
- t.Fatalf("expected to be authenticated")
- }
-
-}
-
-func TestNewSAMLConnection(t *testing.T) {
- // generate new keypair
- kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv")
- _, cert, err := kp.GetKeyPair()
- if err != nil {
- t.Fatalf("Can't generate new keypair: %s", err)
- }
- certBase64 := base64.StdEncoding.EncodeToString(cert)
-
- testUrl := "127.0.0.1:12347"
- l, err := net.Listen("tcp", testUrl)
- if err != nil {
- t.Fatal(err)
- }
-
- singleSignOnURL := "http://" + testUrl + "/auth"
- audienceURL := "http://" + testUrl + "/aud"
- login := "john@example.inv"
-
- ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- requestURIParsed, _ := url.Parse(r.RequestURI)
- if requestURIParsed.Path == "/auth" {
- compressedSAMLReq, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest"))
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte(fmt.Sprintf("saml base64 decode error: %s", err)))
- return
- }
- samlRequest := new(bytes.Buffer)
- decompressor := flate.NewReader(bytes.NewReader(compressedSAMLReq))
- io.Copy(samlRequest, decompressor)
- decompressor.Close()
-
- var authnReq saml.AuthnRequest
- err = xml.Unmarshal(samlRequest.Bytes(), &authnReq)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte(fmt.Sprintf("saml authn request decode error: %s", err)))
- return
- }
- w.Write([]byte("OK"))
- return
- }
- if r.RequestURI == "/metadata" {
- out, _ := xml.Marshal(getSAMLCertWithCustomCert(singleSignOnURL, certBase64))
- w.Write(out)
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- default:
- w.WriteHeader(http.StatusBadRequest)
- }
- }))
-
- ts.Listener.Close()
- ts.Listener = l
- ts.Start()
- defer ts.Close()
- defer l.Close()
-
- // first create a new user
- c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
-
- // create a new SAML connection
- samlProvider := saml.Provider{
- Name: "testProvider",
- MetadataURL: fmt.Sprintf("%s/metadata", ts.URL),
- }
-
- payload, err := json.Marshal(samlProvider)
- if err != nil {
- t.Fatal(err)
- }
-
- req := httptest.NewRequest("POST", "http://example.inv/api/saml-setup", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.samlSetupHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- err = json.NewDecoder(resp.Body).Decode(&samlProvider)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- if samlProvider.ID == "" {
- t.Fatalf("Was expecting saml provider to have an ID")
- }
-
- authURL, err := c.SAML.Client.GetAuthURL(samlProvider)
- if err != nil {
- t.Fatalf("cannot get Auth URL from saml: %s", err)
- }
-
- if authURL == "" {
- t.Fatalf("authURL is empty")
- }
-
- resp, err = http.Get(authURL)
- if err != nil {
- t.Fatalf("http get auth url error: %s", err)
- }
- if resp.StatusCode != 200 {
- t.Errorf("auth url get not status 200: %d", resp.StatusCode)
- }
- _, err = io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("body read error: %s", err)
- }
-
- // check SAML POST flow
- tsSAML := httptest.NewServer(c.SAML.Client.GetRouter())
- defer tsSAML.Close()
-
- // build the SAML response
- // example
- /*
-
-
- https://app.onelogin.com/saml/metadata/onelogin-id
-
-
-
-
-
-
-
-
-
-
- 5eB3C+2/vwdigestvalue
-
-
- sigvalue
-
-
- MIIGETCCA/mgAcertt
-
-
-
-
-
-
-
- https://app.onelogin.com/saml/metadata/onelogin-id
-
- ward@in4it.io
-
-
-
-
-
-
- https://vpn-server.in4it.io/saml/aud/provider-id
-
-
-
-
- urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport
-
-
-
-
- */
- samlResponse := saml.Response{
- Saml: "urn:oasis:names:tc:SAML:2.0:assertion",
- Samlp: "urn:oasis:names:tc:SAML:2.0:protocol",
- ID: "pfx181c1b93-cce6-a6c4-f1f7-99e539374d15",
- Version: "2.0",
- IssueInstant: time.Now().Format(time.RFC3339),
- Destination: "http://example.inv/saml/acs/" + samlProvider.ID,
- Issuer: ts.URL + "/metadata",
- Signature: saml.ResponseSignature{
- Ds: "http://www.w3.org/2000/09/xmldsig#",
- SignedInfo: saml.ResponseSignatureSignedInfo{
- CanonicalizationMethod: struct {
- Text string "xml:\",chardata\""
- Algorithm string "xml:\"Algorithm,attr\""
- }{
- Algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#",
- },
- SignatureMethod: saml.ResponseSignatureSignedInfoSignatureMethod{
- Algorithm: "http://www.w3.org/2000/09/xmldsig#rsa-sha1",
- },
- Reference: saml.ResponseSignatureSignedInfoReference{
- Transforms: struct {
- Text string "xml:\",chardata\""
- Transform []struct {
- Text string "xml:\",chardata\""
- Algorithm string "xml:\"Algorithm,attr\""
- } "xml:\"ds:Transform\""
- }{
- Transform: []struct {
- Text string "xml:\",chardata\""
- Algorithm string "xml:\"Algorithm,attr\""
- }{
- {
- Algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature",
- },
- {
- Algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#",
- },
- },
- },
- DigestMethod: struct {
- Text string "xml:\",chardata\""
- Algorithm string "xml:\"Algorithm,attr\""
- }{
- Algorithm: "http://www.w3.org/2000/09/xmldsig#sha1",
- },
- DigestValue: "thisisthesignature",
- },
- },
- KeyInfo: saml.ResponseSignatureKeyInfo{
- X509Data: struct {
- Text string "xml:\",chardata\""
- X509Certificate string "xml:\"ds:X509Certificate\""
- }{
- X509Certificate: certBase64,
- },
- },
- },
- Assertion: saml.ResponseAssertion{
- Subject: saml.ResponseSubject{
- NameID: struct {
- Text string "xml:\",chardata\""
- Format string "xml:\"Format,attr\""
- }{
- Text: login,
- Format: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
- },
- },
- Conditions: saml.ResponseConditions{
- NotBefore: time.Now().Format(time.RFC3339),
- NotOnOrAfter: time.Now().Add(10 * time.Minute).Format(time.RFC3339),
- AudienceRestriction: saml.ResponseConditionsAdienceRestriction{
- Audience: audienceURL,
- },
- },
- },
- }
- samlResponseBytes, err := xml.Marshal(samlResponse)
- if err != nil {
- t.Fatalf("xml marshal error: %s", err)
- }
- //fmt.Printf("saml respons bytes: %s\n", samlResponseBytes)
- samlResponseBytesDeflated := new(bytes.Buffer)
- compressor, err := flate.NewWriter(samlResponseBytesDeflated, 1)
- if err != nil {
- t.Fatalf("deflate error: %s", err)
- }
- io.Copy(compressor, bytes.NewBuffer(samlResponseBytes))
- compressor.Close()
-
- samlResponseEncoded := base64.StdEncoding.EncodeToString(samlResponseBytesDeflated.Bytes())
-
- form := url.Values{}
- form.Add("SAMLResponse", samlResponseEncoded)
-
- resp, err = http.Post(tsSAML.URL+"/saml/acs/"+samlProvider.ID, "application/x-www-form-urlencoded", strings.NewReader(form.Encode()))
- if err != nil {
- t.Fatalf("http post acs url error: %s", err)
- }
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("body read error: %s", err)
- }
-
- if strings.Contains(string(body), "provider not found") {
- t.Fatalf("provider not found: body output: %s", body)
- }
-
- // currently does't authenticate because of missing signatures, but we checked if the auth process kicked off
- /*if resp.StatusCode != 200 {
- t.Errorf("auth url get not status 200: %d", resp.StatusCode)
- }*/
-
-}
-func TestAddModifyDeleteNewSAMLConnection(t *testing.T) {
- c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
- c.Hostname = "example.inv"
- c.Protocol = "https"
-
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata" {
- out, err := xml.Marshal(getSAMLCert("http://localhost.inv"))
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- w.Write(out)
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- }))
-
- samlProvider := saml.Provider{
- Name: "testProvider",
- MetadataURL: fmt.Sprintf("%s/metadata", ts.URL),
- }
-
- payload, err := json.Marshal(samlProvider)
- if err != nil {
- t.Fatal(err)
- }
-
- req := httptest.NewRequest("POST", "http://example.com/api/saml-setup", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.samlSetupHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- err = json.NewDecoder(resp.Body).Decode(&samlProvider)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- if samlProvider.ID == "" {
- t.Fatalf("samlprovider id is empty")
- }
-
- // GET authmethods and see if provider exists
- req = httptest.NewRequest("GET", "http://example.com/authmethods/saml/"+samlProvider.ID, nil)
- req.SetPathValue("method", "saml")
- req.SetPathValue("id", samlProvider.ID)
- w = httptest.NewRecorder()
- c.authMethodsByID(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- var authMethodsProvider AuthMethodsProvider
-
- err = json.NewDecoder(resp.Body).Decode(&authMethodsProvider)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
- if authMethodsProvider.ID != samlProvider.ID {
- t.Fatalf("authmethods provider id is different than saml provider id: %s vs %s. authMethodsProvider: %+v", authMethodsProvider.ID, samlProvider.ID, authMethodsProvider)
- }
-
- // PUT req
- samlProvider.AllowMissingAttributes = true
- payload, err = json.Marshal(samlProvider)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req = httptest.NewRequest("PUT", "http://example.com/saml-setup/"+samlProvider.ID, bytes.NewBuffer(payload))
- req.SetPathValue("id", samlProvider.ID)
- w = httptest.NewRecorder()
- c.samlSetupElementHandler(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- err = json.NewDecoder(resp.Body).Decode(&samlProvider)
- if err != nil {
- t.Fatalf("marshal decode error: %s", err)
- }
-
- if samlProvider.AllowMissingAttributes == false {
- t.Fatalf("allow missing attributes is false")
- }
-
- // GET on the saml endpoint to see if we can return it
- req = httptest.NewRequest("GET", "http://example.com/saml-setup", nil)
- w = httptest.NewRecorder()
- c.samlSetupHandler(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- var samlProviders []saml.Provider
- err = json.NewDecoder(resp.Body).Decode(&samlProviders)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
- if len(samlProviders) == 0 {
- t.Fatalf("samlProviders is zero length")
- }
- if samlProviders[len(samlProviders)-1].ID != samlProvider.ID {
- t.Fatalf("ID doesn't match: %s vs %s ", samlProviders[len(samlProviders)-1].ID, samlProvider.ID)
- }
- if samlProviders[len(samlProviders)-1].Acs != fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ACS_URL, samlProvider.ID) {
- t.Fatalf("ACS doesn't match")
- }
- if samlProviders[len(samlProviders)-1].AllowMissingAttributes == false {
- t.Fatalf("allow missing attributes is false when getting all samlproviders")
- }
-
- // delete req
- req = httptest.NewRequest("DELETE", "http://example.com/saml-setup/"+samlProvider.ID, bytes.NewBuffer(payload))
- req.SetPathValue("id", samlProvider.ID)
- w = httptest.NewRecorder()
- c.samlSetupElementHandler(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- // list to see if really deleted
- req = httptest.NewRequest("GET", "http://example.com/saml-setup", nil)
- w = httptest.NewRecorder()
- c.samlSetupHandler(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- var samlProviders2 []saml.Provider
- err = json.NewDecoder(resp.Body).Decode(&samlProviders2)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
- if len(samlProviders)-1 != len(samlProviders2) {
- t.Fatalf("samlProviders has wrong length")
- }
-
-}
-
-func TestSAMLCallback(t *testing.T) {
- c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
- c.Hostname = "example.inv"
- c.Protocol = "https"
-
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata" {
- out, err := xml.Marshal(getSAMLCert("http://localhost.inv"))
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- w.Write(out)
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- }))
-
- samlProvider := saml.Provider{
- Name: "testProvider",
- MetadataURL: fmt.Sprintf("%s/metadata", ts.URL),
- }
-
- payload, err := json.Marshal(samlProvider)
- if err != nil {
- t.Fatal(err)
- }
-
- req := httptest.NewRequest("POST", "http://example.com/api/saml-setup", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.samlSetupHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- err = json.NewDecoder(resp.Body).Decode(&samlProvider)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- if samlProvider.ID == "" {
- t.Fatalf("samlprovider id is empty")
- }
- samlCallback := SAMLCallback{
- Code: "abc",
- RedirectURI: "https://localhost.inv/something",
- }
- payload, err = json.Marshal(samlCallback)
- if err != nil {
- t.Fatal(err)
- }
-
- c.SAML.Client.CreateSession(saml.SessionKey{ProviderID: samlProvider.ID, SessionID: "abc"}, saml.AuthenticatedUser{ID: "123", Login: "john@example.com", ExpiresAt: time.Now().AddDate(0, 0, 1)})
-
- req = httptest.NewRequest("POST", "http://example.com/api/authmethods/saml/"+samlProvider.ID, bytes.NewBuffer(payload))
- req.SetPathValue("method", "saml")
- req.SetPathValue("id", samlProvider.ID)
- w = httptest.NewRecorder()
- c.authMethodsByID(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- var loginResponse login.LoginResponse
- err = json.NewDecoder(resp.Body).Decode(&loginResponse)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- if !loginResponse.Authenticated {
- t.Fatalf("Expected to be authenticated")
- }
-
-}
-
-func TestOIDCFlow(t *testing.T) {
- testUrl := "127.0.0.1:12346"
- l, err := net.Listen("tcp", testUrl)
- if err != nil {
- t.Fatal(err)
- }
-
- authURL := "http://" + testUrl + "/auth"
-
- // create a new OIDC connection
- oidcProvider := oidc.OIDCProvider{
- Name: "test-oidc",
- ClientID: "1-2-3-4",
- ClientSecret: "9-9-9-9",
- Scope: "openid",
- DiscoveryURI: "http://" + testUrl + "/discovery.json",
- }
- jwtPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
- if err != nil {
- t.Fatalf("can't generate jwt key: %s", err)
- }
-
- // first create a new user
- c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
- c.Hostname = "example.inv"
- c.Protocol = "http"
- logging.Loglevel = 17
-
- ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- code := "thisisthecode"
-
- switch r.Method {
- case http.MethodGet:
- parsedURI, _ := url.Parse(r.RequestURI)
- switch parsedURI.Path {
- case "/discovery.json":
- discovery := oidc.Discovery{
- Issuer: "test-issuer",
- AuthorizationEndpoint: authURL,
- TokenEndpoint: "http://" + testUrl + "/token",
- JwksURI: "http://" + testUrl + "/jwks.json",
- }
- out, err := json.Marshal(discovery)
- if err != nil {
- t.Fatalf("json marshal error: %s", err)
- }
- w.Write(out)
- return
- case "/auth":
- if oidcProvider.ClientID != r.URL.Query().Get("client_id") {
- w.Write([]byte("client id mismatch"))
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- if oidcProvider.Scope != r.URL.Query().Get("scope") {
- w.Write([]byte("scope mismatch"))
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- w.Write([]byte(code))
- case "/jwks.json":
- publicKey := jwtPrivateKey.PublicKey
-
- jwks := oidc.Jwks{
- Keys: []oidc.JwksKey{
- {
- Kid: "kid-id-1234",
- Alg: "RS256",
- Kty: "RSA",
- Use: "sig",
- N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()),
- E: "AQAB",
- },
- },
- }
- out, err := json.Marshal(jwks)
- if err != nil {
- w.Write([]byte("jwks marshal error"))
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- w.Write(out)
- default:
- w.WriteHeader(http.StatusNotFound)
- }
- case http.MethodPost:
- parsedURI, _ := url.Parse(r.RequestURI)
- switch parsedURI.Path {
- case "/token":
- if r.FormValue("grant_type") != "authorization_code" {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte("wrong grant type"))
- return
- }
- if r.FormValue("code") != code {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte("wrong code"))
- return
- }
- if oidcProvider.ClientID != r.FormValue("client_id") {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte("client id mismatch"))
- return
- }
- if oidcProvider.ClientSecret != r.FormValue("client_secret") {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte("client secret mismatch"))
- return
- }
- if c.Protocol+"://"+c.Hostname+oidcProvider.RedirectURI != r.FormValue("redirect_uri") {
- w.WriteHeader(http.StatusBadRequest)
- w.Write([]byte(fmt.Sprintf("redirect uri mismatch: %s vs %s", oidcProvider.RedirectURI, r.FormValue("redirect_uri"))))
- return
- }
- token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), jwt.MapClaims{
- "iss": "test-server",
- "sub": "john",
- "email": "john@example.inv",
- "role": "user",
- "exp": time.Now().AddDate(0, 0, 1).Unix(),
- "iat": time.Now().Unix(),
- })
- token.Header["kid"] = "kid-id-1234"
-
- tokenString, err := token.SignedString(jwtPrivateKey)
- if err != nil {
- t.Fatalf("can't generate jwt token: %s", err)
- w.WriteHeader(http.StatusBadRequest)
- }
- tokenRes := oidc.Token{
- AccessToken: tokenString,
- IDToken: tokenString,
- ExpiresIn: 180,
- }
- tokenBytes, _ := json.Marshal(tokenRes)
-
- w.Write([]byte(tokenBytes))
- default:
- w.WriteHeader(http.StatusNotFound)
- }
- default:
- w.WriteHeader(http.StatusBadRequest)
- }
- }))
-
- ts.Listener.Close()
- ts.Listener = l
- ts.Start()
- defer ts.Close()
- defer l.Close()
-
- payload, err := json.Marshal(oidcProvider)
- if err != nil {
- t.Fatal(err)
- }
-
- // create new oidc provider
- req := httptest.NewRequest("POST", "http://example.inv/api/oidc", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.oidcProviderHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- err = json.NewDecoder(resp.Body).Decode(&oidcProvider)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- if oidcProvider.ID == "" {
- t.Fatalf("Was expecting oidc provider to have an ID")
- }
-
- // get redirect URL
- req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID, nil)
- req.SetPathValue("id", oidcProvider.ID)
- w = httptest.NewRecorder()
- c.authMethodsByID(w, req)
-
- resp = w.Result()
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- var authmethodsResponse AuthMethodsProvider
-
- err = json.NewDecoder(resp.Body).Decode(&authmethodsResponse)
-
- if err != nil {
- t.Fatalf("cannot decode authmethodsresponse: %s", err)
- }
- if !strings.HasPrefix(authmethodsResponse.RedirectURI, authURL) {
- t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", authmethodsResponse.RedirectURI)
- }
-
- redirectURIParsed, err := url.Parse(authmethodsResponse.RedirectURI)
- if err != nil {
- t.Fatalf("could not parse redirect URI: %s", err)
- }
- state := redirectURIParsed.Query().Get("state")
- if state == "" {
- t.Fatalf("could not obtain state")
- }
- res, err := http.Get(authmethodsResponse.RedirectURI)
- if err != nil {
- t.Fatalf("http get redirect uri error: %s", err)
- }
- if res.StatusCode != 200 {
- t.Fatalf("redirect uri statuscode not 200: %d", res.StatusCode)
- }
- code, err := io.ReadAll(res.Body)
- if err != nil {
- t.Fatalf("body read error: %s", err)
- }
-
- callback := OIDCCallback{
- Code: string(code),
- State: state,
- RedirectURI: oidcProvider.RedirectURI,
- }
- callbackPayload, err := json.Marshal(callback)
- if err != nil {
- t.Fatalf("callback marshal error: %s", err)
- }
- // execute callback
- req = httptest.NewRequest("POST", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID, bytes.NewBuffer(callbackPayload))
- req.SetPathValue("id", oidcProvider.ID)
- req.SetPathValue("method", "oidc")
- w = httptest.NewRecorder()
- c.authMethodsByID(w, req)
-
- resp = w.Result()
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- errorMessage, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("body read error after statuscode not 200 (%d): %s", resp.StatusCode, err)
- }
- t.Fatalf("status code is not 200: %d, errormessage: %s", resp.StatusCode, errorMessage)
- }
-
- var loginResponse login.LoginResponse
-
- err = json.NewDecoder(resp.Body).Decode(&loginResponse)
- if err != nil {
- t.Fatalf("cannot decode login response: %s", err)
- }
-
- if !loginResponse.Authenticated {
- t.Fatalf("not authenticated: %+v", loginResponse)
- }
- if loginResponse.Token == "" {
- t.Fatalf("no token received: %+v", loginResponse)
- }
-}
diff --git a/pkg/rest/config.go b/pkg/rest/config.go
deleted file mode 100644
index 18b6d68..0000000
--- a/pkg/rest/config.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package rest
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "os/user"
- "sync"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-var mu sync.Mutex
-
-func SaveConfig(c *Context) error {
- mu.Lock()
- defer mu.Unlock()
- cCopy := *c
- cCopy.SCIM = &SCIM{ // we don't save the client, but we want the token and enabled
- EnableSCIM: c.SCIM.EnableSCIM,
- Token: c.SCIM.Token,
- }
- cCopy.SAML = &SAML{ // we don't save the client, but we want the config
- Providers: c.SAML.Providers,
- }
- cCopy.JWTKeys = nil // we retrieve JWTKeys from pem files at startup
- cCopy.OIDCStore = nil // we save this separately
- cCopy.UserStore = nil // we save this separately
- cCopy.OIDCRenewal = nil // we don't save this
- cCopy.LoginAttempts = nil // no need to save this
- cCopy.Observability = nil // no need to save the client
- cCopy.Storage = nil // no need to save storage
- out, err := json.Marshal(cCopy)
- if err != nil {
- return fmt.Errorf("context marshal error: %s", err)
- }
- err = c.Storage.Client.WriteFile(c.Storage.Client.ConfigPath("config.json"), out)
- if err != nil {
- return fmt.Errorf("config write error: %s", err)
- }
- // fix permissions
- currentUser, err := user.Current()
- if err != nil {
- return fmt.Errorf("could not get current user: %s", err)
- }
- if currentUser.Username != "vpn" {
- err = c.Storage.Client.EnsureOwnership(c.Storage.Client.ConfigPath("config.json"), "vpn")
- if err != nil {
- return fmt.Errorf("config write error: %s", err)
- }
- }
-
- return nil
-}
-
-func GetConfig(storage storage.Iface) (*Context, error) {
- var c *Context
-
- appDir := storage.GetPath()
-
- // check if config exists
- if !storage.FileExists(storage.ConfigPath("config.json")) {
- return getEmptyContext(appDir)
- }
-
- data, err := storage.ReadFile(storage.ConfigPath("config.json"))
- if err != nil {
- return c, fmt.Errorf("config read error: %s", err)
- }
- decoder := json.NewDecoder(bytes.NewBuffer(data))
- err = decoder.Decode(&c)
- if err != nil {
- return c, fmt.Errorf("decode input error: %s", err)
- }
-
- c.AppDir = appDir
-
- return c, nil
-}
diff --git a/pkg/rest/constants.go b/pkg/rest/constants.go
deleted file mode 100644
index 48fe858..0000000
--- a/pkg/rest/constants.go
+++ /dev/null
@@ -1,4 +0,0 @@
-package rest
-
-const SERVER_TYPE_OBSERVABILITY = "observability"
-const SERVER_TYPE_VPN = "vpn"
diff --git a/pkg/rest/context.go b/pkg/rest/context.go
deleted file mode 100644
index ba7a2a4..0000000
--- a/pkg/rest/context.go
+++ /dev/null
@@ -1,116 +0,0 @@
-package rest
-
-import (
- "fmt"
- "sync"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store"
- oidcrenewal "github.com/in4it/wireguard-server/pkg/auth/oidc/store/renewal"
- "github.com/in4it/wireguard-server/pkg/auth/provisioning/scim"
- "github.com/in4it/wireguard-server/pkg/auth/saml"
- "github.com/in4it/wireguard-server/pkg/license"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/observability"
- "github.com/in4it/wireguard-server/pkg/rest/login"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-var muClientDownload sync.Mutex
-
-func newContext(storage storage.Iface, serverType string) (*Context, error) {
- c, err := GetConfig(storage)
- if err != nil {
- return c, fmt.Errorf("getConfig error: %s", err)
- }
- c.ServerType = serverType
-
- c.Storage = &Storage{
- Client: storage,
- }
-
- c.JWTKeys, err = getJWTKeys(storage)
- if err != nil {
- return c, fmt.Errorf("getJWTKeys error: %s", err)
- }
- c.OIDCStore, err = oidcstore.NewStore(storage)
- if err != nil {
- return c, fmt.Errorf("getOIDCStore error: %s", err)
- }
- if c.OIDCProviders == nil {
- c.OIDCProviders = []oidc.OIDCProvider{}
- }
-
- c.LicenseUserCount, c.CloudType = license.GetMaxUsers(c.Storage.Client)
- go func() { // run license refresh
- logging.DebugLog(fmt.Errorf("starting license refresh in background (current licenses: %d, cloud type: %s)", c.LicenseUserCount, c.CloudType))
- for {
- time.Sleep(time.Hour * 24)
- newLicenseCount := license.RefreshLicense(storage, c.CloudType, c.LicenseUserCount)
- if newLicenseCount != c.LicenseUserCount {
- logging.InfoLog(fmt.Sprintf("License changed from %d users to %d users", c.LicenseUserCount, newLicenseCount))
- c.LicenseUserCount = newLicenseCount
- }
- }
- }()
- c.UserStore, err = users.NewUserStore(c.Storage.Client, c.LicenseUserCount)
- if err != nil {
- return c, fmt.Errorf("userstore initialization error: %s", err)
- }
-
- c.OIDCRenewal, err = oidcrenewal.NewRenewal(storage, c.TokenRenewalTimeMinutes, c.LogLevel, c.EnableOIDCTokenRenewal, c.OIDCStore, c.OIDCProviders, c.UserStore)
- if err != nil {
- return c, fmt.Errorf("oidcrenewal init error: %s", err)
- }
-
- if c.LoginAttempts == nil {
- c.LoginAttempts = make(login.Attempts)
- }
-
- if c.SCIM == nil {
- c.SCIM = &SCIM{
- Client: scim.New(storage, c.UserStore, ""),
- Token: "",
- EnableSCIM: false,
- }
- } else {
- c.SCIM.Client = scim.New(storage, c.UserStore, c.SCIM.Token)
- }
- if c.SAML == nil {
- providers := []saml.Provider{}
- c.SAML = &SAML{
- Client: saml.New(&providers, storage, &c.Protocol, &c.Hostname),
- Providers: &providers,
- }
- } else {
- c.SAML.Client = saml.New(c.SAML.Providers, storage, &c.Protocol, &c.Hostname)
- }
-
- if c.Observability == nil {
- c.Observability = &Observability{
- Client: observability.New(storage),
- }
- } else {
- c.Observability.Client = observability.New(storage)
- }
-
- return c, nil
-}
-
-func getEmptyContext(appDir string) (*Context, error) {
- randomString, err := oidc.GetRandomString(64)
- if err != nil {
- return nil, fmt.Errorf("couldn't generate random string for local kid")
- }
- c := &Context{
- AppDir: appDir,
- JWTKeysKID: randomString,
- TokenRenewalTimeMinutes: oidcrenewal.DEFAULT_RENEWAL_TIME_MINUTES,
- LogLevel: logging.LOG_ERROR,
- SCIM: &SCIM{EnableSCIM: false},
- SAML: &SAML{Providers: &[]saml.Provider{}},
- }
- return c, nil
-}
diff --git a/pkg/rest/helpers.go b/pkg/rest/helpers.go
deleted file mode 100644
index 8bfab34..0000000
--- a/pkg/rest/helpers.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package rest
-
-import (
- "encoding/base64"
- "encoding/json"
- "fmt"
- "net/http"
- "strings"
-)
-
-func (c *Context) returnError(w http.ResponseWriter, err error, statusCode int) {
- fmt.Println("========= ERROR =========")
- fmt.Printf("Error: %s\n", err)
- fmt.Println("=========================")
- sendCorsHeaders(w, "", c.Hostname, c.Protocol)
- w.WriteHeader(statusCode)
- w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`))
-}
-
-func (c *Context) write(w http.ResponseWriter, res []byte) {
- sendCorsHeaders(w, "", c.Hostname, c.Protocol)
- w.WriteHeader(http.StatusOK)
- w.Write(res)
-}
-func (c *Context) writeWithStatus(w http.ResponseWriter, res []byte, status int) {
- sendCorsHeaders(w, "", c.Hostname, c.Protocol)
- w.WriteHeader(status)
- w.Write(res)
-}
-
-func sendCorsHeaders(w http.ResponseWriter, headers string, hostname string, protocol string) {
- if hostname == "" {
- w.Header().Add("Access-Control-Allow-Origin", "*")
- } else {
- w.Header().Add("Access-Control-Allow-Origin", fmt.Sprintf("%s://%s", protocol, hostname))
- }
- w.Header().Add("Access-Control-allow-methods", "GET,HEAD,POST,PUT,OPTIONS,DELETE,PATCH")
- if headers != "" {
- w.Header().Add("Access-Control-Allow-Headers", headers)
- }
-}
-
-func isAlphaNumeric(str string) bool {
- for _, c := range str {
- if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) {
- return false
- }
- }
- return true
-}
-
-func getKidFromToken(token string) (string, error) {
- jwtSplit := strings.Split(token, ".")
- if len(jwtSplit) < 1 {
- return "", fmt.Errorf("token split < 1")
- }
- data, err := base64.RawURLEncoding.DecodeString(jwtSplit[0])
- if err != nil {
- return "", fmt.Errorf("could not base64 decode data part of jwt")
- }
- var header JwtHeader
- err = json.Unmarshal(data, &header)
- if err != nil {
- return "", fmt.Errorf("could not unmarshal jwt data")
- }
- return header.Kid, nil
-}
-
-func returnIndexOrNotFound(contents []byte) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if !strings.HasPrefix(r.URL.Path, "/api") {
- w.WriteHeader(http.StatusOK)
- w.Write(contents)
- } else {
- w.WriteHeader(http.StatusNotFound)
- w.Write([]byte("404 page not found\n"))
- }
- })
-}
diff --git a/pkg/rest/helpers_test.go b/pkg/rest/helpers_test.go
deleted file mode 100644
index cfe500d..0000000
--- a/pkg/rest/helpers_test.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package rest
-
-import "testing"
-
-func TestIsAlphaNumeric(t *testing.T) {
- if !isAlphaNumeric("abc123") {
- t.Errorf("expected alphanumeric")
- }
- if isAlphaNumeric("abc@") {
- t.Errorf("expected alphanumeric")
- }
-}
diff --git a/pkg/rest/license.go b/pkg/rest/license.go
deleted file mode 100644
index b353bbf..0000000
--- a/pkg/rest/license.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package rest
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
-
- "github.com/in4it/wireguard-server/pkg/license"
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-func (c *Context) licenseHandler(w http.ResponseWriter, r *http.Request) {
- if r.PathValue("action") == "get-more" {
- c.LicenseUserCount = license.RefreshLicense(c.Storage.Client, c.CloudType, c.LicenseUserCount)
- }
-
- currentUserCount := c.UserStore.UserCount()
- licenseResponse := LicenseResponse{LicenseUserCount: c.LicenseUserCount, CurrentUserCount: currentUserCount, CloudType: c.CloudType}
-
- if r.PathValue("action") == "get-more" {
- licenseResponse.Key = license.GetLicenseKey(c.Storage.Client, c.CloudType)
- }
-
- out, err := json.Marshal(licenseResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest)
- return
- }
- c.write(w, out)
-}
-func (c *Context) connectionLicenseHandler(w http.ResponseWriter, r *http.Request) {
- user := r.Context().Value(CustomValue("user")).(users.User)
- totalConnections, err := wireguard.GetConfigNumbers(c.Storage.Client, user.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("can't determine total connections: %s", err), http.StatusBadRequest)
- return
-
- }
- out, err := json.Marshal(ConnectionLicenseResponse{LicenseUserCount: c.LicenseUserCount, ConnectionCount: len(totalConnections)})
- if err != nil {
- c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest)
- return
- }
- c.write(w, out)
-}
diff --git a/pkg/rest/login/attempt.go b/pkg/rest/login/attempt.go
deleted file mode 100644
index 2fe2f0c..0000000
--- a/pkg/rest/login/attempt.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package login
-
-import (
- "sync"
- "time"
-)
-
-var mu sync.Mutex
-
-type Attempts map[string][]Attempt
-
-type Attempt struct {
- Timestamp time.Time
-}
-
-func ClearAttemptsForLogin(attempts Attempts, login string) {
- mu.Lock()
- defer mu.Unlock()
- attempts[login] = []Attempt{}
-}
-
-func RecordAttempt(attempts Attempts, login string) {
- mu.Lock()
- defer mu.Unlock()
- _, ok := attempts[login]
- if !ok {
- attempts[login] = []Attempt{}
- }
- attempts[login] = append(attempts[login], Attempt{Timestamp: time.Now()})
-}
-
-func CheckTooManyLogins(attempts Attempts, login string) bool {
- threeMinutes := 3 * time.Minute
- _, ok := attempts[login]
- if ok {
- loginAttempts := 0
- for _, loginAttempt := range attempts[login] {
- if time.Since(loginAttempt.Timestamp) <= threeMinutes {
- loginAttempts++
- }
- }
- if loginAttempts >= 3 {
- if len(attempts[login]) > 3 {
- index := len(attempts[login]) - 3
- attempts[login] = attempts[login][index:]
- }
- return true
- }
- }
- return false
-}
diff --git a/pkg/rest/login/auth.go b/pkg/rest/login/auth.go
deleted file mode 100644
index 5d0288e..0000000
--- a/pkg/rest/login/auth.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package login
-
-import (
- "crypto/rsa"
- "fmt"
-
- "github.com/in4it/wireguard-server/pkg/mfa/totp"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-func Authenticate(loginReq LoginRequest, authIface AuthIface, jwtPrivateKey *rsa.PrivateKey, jwtKeyID string) (LoginResponse, users.User, error) {
- loginResponse := LoginResponse{}
- user, auth := authIface.AuthUser(loginReq.Login, loginReq.Password)
- if auth && !user.Suspended {
- if len(user.Factors) == 0 { // authentication without MFA
- token, err := GetJWTToken(user.Login, user.Role, jwtPrivateKey, jwtKeyID)
- if err != nil {
- return loginResponse, user, fmt.Errorf("token generation failed: %s", err)
- }
- loginResponse.Authenticated = true
- loginResponse.Token = token
- } else {
- if loginReq.FactorResponse.Name == "" {
- loginResponse.Authenticated = false
- loginResponse.MFARequired = true
- for _, factor := range user.Factors {
- loginResponse.Factors = append(loginResponse.Factors, factor.Name)
- }
- } else {
- for _, factor := range user.Factors {
- if factor.Name == loginReq.FactorResponse.Name {
- ok, err := totp.Verify(factor.Secret, loginReq.FactorResponse.Code)
- if err != nil {
- return loginResponse, user, fmt.Errorf("MFA (totp) verify failed: %s", err)
- }
- if ok { // authentication with MFA
- token, err := GetJWTToken(user.Login, user.Role, jwtPrivateKey, jwtKeyID)
- if err != nil {
- return loginResponse, user, fmt.Errorf("token generation failed: %s", err)
- }
- loginResponse.Authenticated = true
- loginResponse.Token = token
- }
- }
- }
- }
- }
- }
- return loginResponse, user, nil
-}
diff --git a/pkg/rest/login/auth_test.go b/pkg/rest/login/auth_test.go
deleted file mode 100644
index 34e61be..0000000
--- a/pkg/rest/login/auth_test.go
+++ /dev/null
@@ -1,123 +0,0 @@
-package login
-
-import (
- "crypto/rand"
- "crypto/rsa"
- "encoding/base32"
- "testing"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/mfa/totp"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-type MockAuth struct {
- AuthUserUser users.User
- AuthUserResult bool
-}
-
-func (m *MockAuth) AuthUser(login string, password string) (users.User, bool) {
- return m.AuthUserUser, m.AuthUserResult
-}
-
-func TestAuthenticate(t *testing.T) {
- m := MockAuth{
- AuthUserUser: users.User{
- Login: "john",
- },
- AuthUserResult: true,
- }
- loginReq := LoginRequest{
- Login: "john",
- Password: "mypass",
- }
- privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
- if err != nil {
- t.Fatalf("private key error: %s", err)
- }
-
- loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID")
- if err != nil {
- t.Fatalf("authentication error: %s", err)
- }
- if !loginResp.Authenticated {
- t.Fatalf("expected to be authenticated")
- }
- if loginResp.Token == "" {
- t.Fatalf("no token")
- }
-}
-func TestAuthenticateMFANoToken(t *testing.T) {
- m := MockAuth{
- AuthUserUser: users.User{
- Login: "john",
- Factors: []users.Factor{
- {
- Name: "test-factor",
- Type: "test",
- Secret: "secret",
- },
- },
- },
- AuthUserResult: true,
- }
- loginReq := LoginRequest{
- Login: "john",
- Password: "mypass",
- }
- privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
- if err != nil {
- t.Fatalf("private key error: %s", err)
- }
-
- loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID")
- if err != nil {
- t.Fatalf("authentication error: %s", err)
- }
- if loginResp.Authenticated {
- t.Fatalf("expected not to be authenticated")
- }
- if len(loginResp.Factors) == 0 {
- t.Fatalf("expected to get factors")
- }
-}
-func TestAuthenticateMFAWithToken(t *testing.T) {
- secret := base32.StdEncoding.EncodeToString([]byte("secret"))
- m := MockAuth{
- AuthUserUser: users.User{
- Login: "john",
- Factors: []users.Factor{
- {
- Name: "test-factor",
- Type: "test",
- Secret: secret,
- },
- },
- },
- AuthUserResult: true,
- }
- token, err := totp.GetToken(secret, time.Now().Unix()/30)
- if err != nil {
- t.Fatalf("GetToken error: %s", err)
- }
- loginReq := LoginRequest{
- Login: "john",
- Password: "mypass",
- FactorResponse: FactorResponse{
- Name: "test-factor",
- Code: token,
- },
- }
- privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
- if err != nil {
- t.Fatalf("private key error: %s", err)
- }
-
- loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID")
- if err != nil {
- t.Fatalf("authentication error: %s", err)
- }
- if !loginResp.Authenticated {
- t.Fatalf("expected to be authenticated")
- }
-}
diff --git a/pkg/rest/login/jwt.go b/pkg/rest/login/jwt.go
deleted file mode 100644
index bba4e09..0000000
--- a/pkg/rest/login/jwt.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package login
-
-import (
- "crypto/rsa"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
-)
-
-func GetJWTToken(login, role string, signKey *rsa.PrivateKey, kid string) (string, error) {
- return GetJWTTokenWithExpiration(login, role, signKey, kid, time.Now().Add(time.Hour*72))
-}
-
-func GetJWTTokenWithExpiration(login, role string, signKey *rsa.PrivateKey, kid string, expiration time.Time) (string, error) {
- token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), jwt.MapClaims{
- "iss": "wireguard-server",
- "sub": login,
- "role": role,
- "exp": expiration.Unix(),
- "iat": time.Now().Unix(),
- })
- token.Header["kid"] = kid
-
- tokenString, err := token.SignedString(signKey)
-
- return tokenString, err
-}
diff --git a/pkg/rest/login/types.go b/pkg/rest/login/types.go
deleted file mode 100644
index bf8b76d..0000000
--- a/pkg/rest/login/types.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package login
-
-import "github.com/in4it/wireguard-server/pkg/users"
-
-type AuthIface interface {
- AuthUser(login string, password string) (users.User, bool)
-}
-
-type LoginRequest struct {
- Login string `json:"login"`
- Password string `json:"password"`
- FactorResponse FactorResponse `json:"factorResponse"`
-}
-
-type FactorResponse struct {
- Name string `json:"name"`
- Code string `json:"code"`
-}
-
-type LoginResponse struct {
- Authenticated bool `json:"authenticated"`
- Suspended bool `json:"suspended"`
- NoLicense bool `json:"noLicense"`
- Token string `json:"token,omitempty"`
- MFARequired bool `json:"mfaRequired"`
- Factors []string `json:"factors"`
-}
diff --git a/pkg/rest/middleware.go b/pkg/rest/middleware.go
deleted file mode 100644
index 72b5b94..0000000
--- a/pkg/rest/middleware.go
+++ /dev/null
@@ -1,182 +0,0 @@
-package rest
-
-import (
- "context"
- "fmt"
- "log"
- "net/http"
- "strings"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-type CustomValue string
-
-// auth middleware
-
-func (c *Context) authMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if !c.SetupCompleted {
- c.returnError(w, fmt.Errorf("setup not completed"), http.StatusUnauthorized)
- return
- }
- if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
- c.writeWithStatus(w, []byte(`{"error": "token not found"}`), http.StatusUnauthorized)
- return
- }
- tokenString := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", -1)
- if len(tokenString) == 0 {
- c.returnError(w, fmt.Errorf("empty token"), http.StatusUnauthorized)
- return
- }
-
- // determine token to parse
- var tokenToParse string
- // is token an access token or a jwt from local auth?
- kid, _ := getKidFromToken(tokenString)
- if kid == c.JWTKeysKID { // local auth token
- tokenToParse = tokenString
- } else {
- for _, oauth2Data := range c.OIDCStore.OAuth2Data {
- if oauth2Data.Token.AccessToken == tokenString {
- tokenToParse = oauth2Data.Token.IDToken
- }
- }
- if tokenToParse == "" {
- c.returnError(w, fmt.Errorf("token error: access token not found (wrong token or token expired)"), http.StatusUnauthorized)
- return
- }
- }
- token, err := jwt.Parse(tokenToParse, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Header["kid"]; !ok {
- return nil, fmt.Errorf("no kid header found in token")
- }
- if token.Header["kid"] == c.JWTKeysKID {
- if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
- return nil, fmt.Errorf("local kid: unexpected signing method: %v", token.Header["alg"])
- }
- return c.JWTKeys.PublicKey, nil
- }
- discoveryProviders := make([]oidc.Discovery, len(c.OIDCProviders))
- for k, oidcProvider := range c.OIDCProviders {
- discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI)
- if err != nil {
- return nil, fmt.Errorf("couldn't retrieve discoveryURI from OIDC Provider (check discovery URI in OIDC settings). Error: %s", err)
- }
- discoveryProviders[k] = discovery
- }
- allJwks, err := c.OIDCStore.GetAllJwks(discoveryProviders)
- if err != nil {
- return nil, fmt.Errorf("couldn't retrieve JWKS URL from OIDC Provider (check discovery URI in OIDC settings). Error: %s", err)
- }
- publicKey, err := oidc.GetPublicKeyForToken(allJwks, discoveryProviders, token)
- if err != nil {
- return nil, fmt.Errorf("GetPublicKeyForToken error: %s", err)
- }
- return publicKey, nil
- })
- if err != nil {
- c.returnError(w, fmt.Errorf("token error: %s", err), http.StatusUnauthorized)
- return
- }
- token.Claims.(jwt.MapClaims)["kid"] = token.Header["kid"]
- ctx := context.WithValue(r.Context(), CustomValue("claims"), token.Claims.(jwt.MapClaims))
- next.ServeHTTP(w, r.WithContext(ctx))
- })
-}
-
-// logging middleware
-
-// responseWriter is a minimal wrapper for http.ResponseWriter that allows the
-// written HTTP status code to be captured for logging.
-// MIT licensed
-type responseWriter struct {
- http.ResponseWriter
- status int
- wroteHeader bool
-}
-
-func (rw *responseWriter) Status() int {
- return rw.status
-}
-
-func (rw *responseWriter) WriteHeader(code int) {
- if rw.wroteHeader {
- return
- }
-
- rw.status = code
- rw.ResponseWriter.WriteHeader(code)
- rw.wroteHeader = true
-}
-
-func (c *Context) loggingMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
- wrappedResponse := &responseWriter{ResponseWriter: w}
- next.ServeHTTP(wrappedResponse, r)
- log.Printf("req=%s res=%d method=%s src=%s duration=%s", r.RequestURI, wrappedResponse.status, r.Method, r.RemoteAddr, time.Since(start))
- })
-}
-
-func (c *Context) corsMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method == http.MethodOptions {
- sendCorsHeaders(w, r.Header.Get("Access-Control-Request-Headers"), c.Hostname, c.Protocol)
- w.WriteHeader(http.StatusNoContent)
- } else {
- next.ServeHTTP(w, r)
- }
- })
-}
-func (c *Context) injectUserMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user, err := c.GetUserFromRequest(r)
- if err != nil {
- c.returnError(w, fmt.Errorf("token error: %s", err), http.StatusUnauthorized)
- return
- }
-
- ctx := context.WithValue(r.Context(), CustomValue("user"), user)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
-}
-
-func (c *Context) isAdminMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user := r.Context().Value(CustomValue("user")).(users.User)
- if user.Role != "admin" {
- c.returnError(w, fmt.Errorf("endpoint forbidden"), http.StatusForbidden)
- return
- }
- next.ServeHTTP(w, r)
- })
-}
-
-func (c *Context) httpsRedirectMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if c.RedirectToHttps && r.TLS == nil {
- if strings.HasPrefix(r.URL.Path, "/api") {
- c.returnError(w, fmt.Errorf("non-tls requests disabled"), http.StatusForbidden)
- return
- }
- http.Redirect(w, r, fmt.Sprintf("https://%s%s", r.Host, r.RequestURI), http.StatusMovedPermanently)
- return
- }
- next.ServeHTTP(w, r)
- })
-}
-
-func (c *Context) isSCIMEnabled(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if !c.SCIM.EnableSCIM {
- w.WriteHeader(http.StatusNotFound)
- w.Write([]byte(`{ "error": "SCIM Not Enabled" }`))
- return
- }
- next.ServeHTTP(w, r)
- })
-}
diff --git a/pkg/rest/profile.go b/pkg/rest/profile.go
deleted file mode 100644
index e035bce..0000000
--- a/pkg/rest/profile.go
+++ /dev/null
@@ -1,139 +0,0 @@
-package rest
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
-
- "github.com/in4it/wireguard-server/pkg/mfa/totp"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-func (c *Context) profilePasswordHandler(w http.ResponseWriter, r *http.Request) {
- user := r.Context().Value(CustomValue("user")).(users.User)
- switch r.Method {
- case http.MethodPost:
- var userInput users.User
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&userInput)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- if userInput.Password == "" {
- c.returnError(w, fmt.Errorf("no password supplied"), http.StatusBadRequest)
- return
- }
- err = c.UserStore.UpdatePassword(user.ID, userInput.Password)
- if err != nil {
- c.returnError(w, fmt.Errorf("update password error: %s", err), http.StatusBadRequest)
- return
- }
-
- c.write(w, []byte(`{"result": "OK"}`))
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
-
- }
-}
-
-func (c *Context) profileFactorsHandler(w http.ResponseWriter, r *http.Request) {
- user := r.Context().Value(CustomValue("user")).(users.User)
- switch r.Method {
- case http.MethodGet:
- factors := make([]users.Factor, len(user.Factors))
- copy(factors, user.Factors)
- for k := range factors {
- factors[k].Secret = "" // remove secret when outputting
- }
- out, err := json.Marshal(factors)
- if err != nil {
- c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var factor FactorRequest
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&factor)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode factor error: %s", err), http.StatusBadRequest)
- return
- }
- if factor.Secret == "" {
- c.returnError(w, fmt.Errorf("no factor secret supplied"), http.StatusBadRequest)
- return
- }
- if factor.Name == "" {
- c.returnError(w, fmt.Errorf("no factor name supplied"), http.StatusBadRequest)
- return
- }
- if len(factor.Name) > 16 {
- c.returnError(w, fmt.Errorf("factor name too long"), http.StatusBadRequest)
- return
- }
- if factor.Type == "" {
- c.returnError(w, fmt.Errorf("no factor type supplied"), http.StatusBadRequest)
- return
- }
- if factor.Code == "" {
- c.returnError(w, fmt.Errorf("no factor code supplied"), http.StatusBadRequest)
- return
- }
-
- ok, err := totp.VerifyMultipleIntervals(factor.Secret, factor.Code, 20)
- if err != nil {
- c.returnError(w, fmt.Errorf("totp verify error: %s", err), http.StatusBadRequest)
- return
- }
-
- if !ok {
- c.returnError(w, fmt.Errorf("code doesn't match. Try entering code again or try with a new QR code"), http.StatusBadRequest)
- return
- }
-
- user.Factors = append(user.Factors, users.Factor{Type: factor.Type, Secret: factor.Secret, Name: factor.Name})
- out, err := json.Marshal(user.Factors)
- if err != nil {
- c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest)
- return
- }
- err = c.UserStore.UpdateUser(user)
- if err != nil {
- c.returnError(w, fmt.Errorf("coudn't update user: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodDelete:
- factorName := r.PathValue("name")
- if factorName == "" {
- c.returnError(w, fmt.Errorf("no factor name supplied"), http.StatusBadRequest)
- return
- }
- toDelete := -1
- for k := range user.Factors {
- if user.Factors[k].Name == factorName {
- toDelete = k
- }
- }
- if toDelete == -1 {
- c.returnError(w, fmt.Errorf("factor not found"), http.StatusBadRequest)
- return
- }
- user.Factors = append(user.Factors[:toDelete], user.Factors[toDelete+1:]...)
- err := c.UserStore.UpdateUser(user)
- if err != nil {
- c.returnError(w, fmt.Errorf("coudn't update user: %s", err), http.StatusBadRequest)
- return
- }
- out, err := json.Marshal(user.Factors)
- if err != nil {
- c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
-
- }
-}
diff --git a/pkg/rest/router.go b/pkg/rest/router.go
deleted file mode 100644
index 98a1e8d..0000000
--- a/pkg/rest/router.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package rest
-
-import (
- "io/fs"
- "net/http"
-)
-
-func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux {
- mux := http.NewServeMux()
-
- // static files
- mux.Handle("/assets/{filename}", http.FileServer(http.FS(assets)))
- mux.Handle("/index.html", returnIndexOrNotFound(indexHtml))
- mux.Handle("/favicon.ico", http.FileServer(http.FS(assets)))
-
- // saml authentication
- mux.Handle("/saml/", c.SAML.Client.GetRouter())
-
- // endpoints with no authentication
- mux.Handle("/api/context", http.HandlerFunc(c.contextHandler))
- mux.Handle("/api/auth", http.HandlerFunc(c.authHandler))
- mux.Handle("/api/authmethods", http.HandlerFunc(c.authMethods))
- mux.Handle("/api/authmethods/{method}/{id}", http.HandlerFunc(c.authMethodsByID))
- mux.Handle("/api/authmethods/{id}", http.HandlerFunc(c.authMethodsByID))
- mux.Handle("/api/version", http.HandlerFunc(c.version))
- mux.Handle("/api/upgrade", http.HandlerFunc(c.upgrade))
- mux.Handle("/", returnIndexOrNotFound(indexHtml))
-
- // endpoints with no authentication (observability)
- if c.ServerType == SERVER_TYPE_OBSERVABILITY {
- mux.Handle("/api/observability/", c.Observability.Client.GetRouter())
- }
-
- // scim
- mux.Handle("/api/scim/", c.isSCIMEnabled(c.SCIM.Client.GetRouter()))
-
- // endpoints with authentication
- mux.Handle("/api/userinfo", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.userinfoHandler))))
- mux.Handle("/api/profile/password", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profilePasswordHandler))))
- mux.Handle("/api/profile/factors", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profileFactorsHandler))))
- mux.Handle("/api/profile/factors/{name}", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profileFactorsHandler))))
-
- // endpoint with authentication (VPN)
- if c.ServerType == SERVER_TYPE_VPN {
- mux.Handle("/api/connections", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.connectionsHandler))))
- mux.Handle("/api/connection/{id}", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.connectionsElementHandler))))
- mux.Handle("/api/connectionlicense", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.connectionLicenseHandler))))
- }
-
- // endpoints with authentication, with admin role
- mux.Handle("/api/license", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.licenseHandler)))))
- mux.Handle("/api/license/{action}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.licenseHandler)))))
- mux.Handle("/api/oidc", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderHandler)))))
- mux.Handle("/api/oidc-renew-tokens", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcRenewTokensHandler)))))
- mux.Handle("/api/oidc/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderElementHandler)))))
- mux.Handle("/api/setup/general", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.setupHandler)))))
- mux.Handle("/api/setup/vpn", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.vpnSetupHandler)))))
- mux.Handle("/api/setup/templates", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.templateSetupHandler)))))
- mux.Handle("/api/setup/restart-vpn", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.restartVPNHandler)))))
- mux.Handle("/api/scim-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.scimSetupHandler)))))
- mux.Handle("/api/saml-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupHandler)))))
- mux.Handle("/api/saml-setup/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupElementHandler)))))
- mux.Handle("/api/users", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.usersHandler)))))
- mux.Handle("/api/user/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userHandler)))))
- mux.Handle("/api/stats/user/{date}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userStatsHandler)))))
- mux.Handle("/api/stats/packetlogs/{user}/{date}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.packetLogsHandler)))))
-
- return mux
-}
diff --git a/pkg/rest/rsa.go b/pkg/rest/rsa.go
deleted file mode 100644
index 3f8fd4b..0000000
--- a/pkg/rest/rsa.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package rest
-
-/*
- * Genarate rsa keys. (https://github.com/wardviaene/http-echo/blob/master/rsa.go)
- */
-
-import (
- "bytes"
- "crypto/rand"
- "crypto/rsa"
- "crypto/x509"
- "encoding/pem"
- "fmt"
- "path"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-type JWTKeys struct {
- PrivateKey *rsa.PrivateKey `json:"privateKey,omitempty"`
- PublicKey *rsa.PublicKey `json:"publicKey,omitempty"`
-}
-
-func getJWTKeys(storage storage.Iface) (*JWTKeys, error) {
-
- filename := storage.ConfigPath("pki/private.pem")
- filenamePublicKey := storage.ConfigPath("pki/public.pem")
-
- if !storage.FileExists(filename) {
- err := storage.EnsurePath(path.Dir(filename))
- if err != nil {
- return nil, fmt.Errorf("ensure path error: %s", err)
- }
- err = createJWTKeys(storage, storage.ConfigPath("pki"))
- if err != nil {
- return nil, fmt.Errorf("createJWTKeys error: %s", err)
- }
- }
-
- signBytes, err := storage.ReadFile(filename)
- if err != nil {
- return nil, fmt.Errorf("private key read error: %s", err)
- }
- publicBytes, err := storage.ReadFile(filenamePublicKey)
- if err != nil {
- return nil, fmt.Errorf("private key read error: %s", err)
- }
-
- signKey, err := jwt.ParseRSAPrivateKeyFromPEM(signBytes)
- if err != nil {
- return nil, fmt.Errorf("can't parse private key: %s", err)
- }
- publicKey, err := jwt.ParseRSAPublicKeyFromPEM(publicBytes)
- if err != nil {
- return nil, fmt.Errorf("can't parse public key: %s", err)
- }
- return &JWTKeys{PrivateKey: signKey, PublicKey: publicKey}, nil
-}
-
-func createJWTKeys(storage storage.Iface, path string) error {
- reader := rand.Reader
- bitSize := 4096
-
- key, err := rsa.GenerateKey(reader, bitSize)
- if err != nil {
- return err
- }
-
- publicKey := key.PublicKey
-
- err = savePEMKey(storage, path+"/private.pem", key)
- if err != nil {
- return err
- }
- err = savePublicPEMKey(storage, path+"/public.pem", publicKey)
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func savePEMKey(storage storage.Iface, fileName string, key *rsa.PrivateKey) error {
- var buf bytes.Buffer
-
- var privateKey = &pem.Block{
- Type: "RSA PRIVATE KEY",
- Bytes: x509.MarshalPKCS1PrivateKey(key),
- }
-
- err := pem.Encode(&buf, privateKey)
- if err != nil {
- return err
- }
-
- err = storage.WriteFile(fileName, buf.Bytes())
- if err != nil {
- return fmt.Errorf("WriteFile error: %s", err)
- }
- return nil
-}
-
-func savePublicPEMKey(storage storage.Iface, fileName string, pubkey rsa.PublicKey) error {
- var buf bytes.Buffer
-
- asn1Bytes, err := x509.MarshalPKIXPublicKey(&pubkey)
- if err != nil {
- return err
- }
-
- var pemkey = &pem.Block{
- Type: "PUBLIC KEY",
- Bytes: asn1Bytes,
- }
-
- err = pem.Encode(&buf, pemkey)
- if err != nil {
- return err
- }
- err = storage.WriteFile(fileName, buf.Bytes())
- if err != nil {
- return fmt.Errorf("WriteFile error: %s", err)
- }
- return nil
-}
diff --git a/pkg/rest/rsa_test.go b/pkg/rest/rsa_test.go
deleted file mode 100644
index ce23943..0000000
--- a/pkg/rest/rsa_test.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package rest
-
-import (
- "bytes"
- "crypto/x509"
- "encoding/pem"
- "testing"
-
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
-)
-
-func TestGetJWTKeys(t *testing.T) {
- mockStorage := memorystorage.MockMemoryStorage{}
- keys, err := getJWTKeys(&mockStorage)
- if err != nil {
- t.Fatalf("error: %s", err)
- }
- privateKeyFromFile, err := mockStorage.ReadFile(mockStorage.ConfigPath("pki/private.pem"))
- if err != nil {
- t.Fatalf("read error: %s", err)
- }
- _, err = mockStorage.ReadFile(mockStorage.ConfigPath("pki/public.pem"))
- if err != nil {
- t.Fatalf("read error: %s", err)
- }
-
- var buf bytes.Buffer
- var privateKey = &pem.Block{
- Type: "RSA PRIVATE KEY",
- Bytes: x509.MarshalPKCS1PrivateKey(keys.PrivateKey),
- }
- err = pem.Encode(&buf, privateKey)
- if err != nil {
- t.Fatalf("pem encode error: %s", err)
- }
-
- if !bytes.Equal(privateKeyFromFile, buf.Bytes()) {
- t.Fatalf("private keys don't match")
- }
-}
diff --git a/pkg/rest/server.go b/pkg/rest/server.go
deleted file mode 100644
index a223511..0000000
--- a/pkg/rest/server.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package rest
-
-import (
- "crypto/tls"
- "embed"
- "fmt"
- "io"
- "io/fs"
- "log"
- "net/http"
-
- "github.com/in4it/wireguard-server/pkg/logging"
- localstorage "github.com/in4it/wireguard-server/pkg/storage/local"
- "golang.org/x/crypto/acme/autocert"
-)
-
-var (
- //go:embed static
- assets embed.FS
- enableTLSWaiter chan (bool) = make(chan bool)
- TLSWaiterCompleted bool
-)
-
-func StartServer(httpPort, httpsPort int, serverType string) {
- localStorage, err := localstorage.New()
- if err != nil {
- log.Fatalf("couldn't initialize storage: %s", err)
- }
- c, err := newContext(localStorage, serverType)
- if err != nil {
- log.Fatalf("startup failed: %s", err)
- }
-
- go handleSignals(c)
-
- assetsFS, err := fs.Sub(assets, "static")
- if err != nil {
- log.Fatalf("could not load static web assets")
- }
-
- indexHtml, err := assetsFS.Open("index.html")
- if err != nil {
- log.Fatalf("could not load static web assets (index.html)")
- }
- indexHtmlBody, err := io.ReadAll(indexHtml)
- if err != nil {
- log.Fatalf("could not read static web assets (index.html)")
- }
-
- certManager := autocert.Manager{}
-
- // HTTP Configuration
- go func() { // start http server
- log.Printf("Start http server on port %d", httpPort)
- log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", httpPort), certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody)))))))
- }()
-
- // TLS Configuration
- if !c.EnableTLS || !canEnableTLS(c.Hostname) {
- <-enableTLSWaiter
- }
- // only enable when TLS is enabled
-
- logging.DebugLog(fmt.Errorf("enabling TLS endpoint with let's encrypt for hostname '%s'", c.Hostname))
- certManager.Prompt = autocert.AcceptTOS
- certManager.HostPolicy = autocert.HostWhitelist(c.Hostname)
- certManager.Cache = autocert.DirCache("tls-certs")
- tlsServer := &http.Server{
- Addr: fmt.Sprintf(":%d", httpsPort),
- TLSConfig: &tls.Config{
- GetCertificate: certManager.GetCertificate,
- },
- Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))),
- }
- c.Protocol = "https"
- TLSWaiterCompleted = true
- log.Fatal(tlsServer.ListenAndServeTLS("", ""))
-}
diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go
deleted file mode 100644
index 719ae8f..0000000
--- a/pkg/rest/setup.go
+++ /dev/null
@@ -1,669 +0,0 @@
-package rest
-
-import (
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/netip"
- "reflect"
- "slices"
- "sort"
- "strconv"
- "strings"
- "time"
-
- "github.com/google/uuid"
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- "github.com/in4it/wireguard-server/pkg/auth/saml"
- "github.com/in4it/wireguard-server/pkg/license"
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method == http.MethodPost {
- decoder := json.NewDecoder(r.Body)
- var contextReq ContextRequest
- err := decoder.Decode(&contextReq)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- if !c.Storage.Client.FileExists(SETUP_CODE_FILE) {
- c.SetupCompleted = true
- }
- if !c.SetupCompleted {
- // check if tag hash is chosen
- accessGranted := false
- switch c.CloudType {
- case "digitalocean": // check if the hashtag is set
- if contextReq.TagHash != "" {
- if !strings.HasPrefix(contextReq.TagHash, "vpnsecret-") {
- c.returnError(w, fmt.Errorf("tag doesn't have the correct prefix. The tag needs to start with 'vpnsecret-'"), http.StatusUnauthorized)
- return
- }
- accessGranted, err = license.HasDigitalOceanTagSet(http.Client{Timeout: 5 * time.Second}, contextReq.TagHash)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not retrieve tags at this time: %s", err), http.StatusUnauthorized)
- return
- }
- if !accessGranted {
- c.returnError(w, fmt.Errorf("tag not found. Make sure the correct tag is attached to the droplet"), http.StatusUnauthorized)
- return
- }
- }
- case "aws": // check if the instance id is set
- if contextReq.InstanceID != "" {
- instanceID, err := license.GetAWSInstanceID(http.Client{Timeout: 5 * time.Second})
- if err != nil {
- c.returnError(w, fmt.Errorf("could not retrieve instance id at this time: %s", err), http.StatusUnauthorized)
- return
- }
- if strings.TrimPrefix(instanceID, "i-") == strings.TrimPrefix(contextReq.InstanceID, "i-") {
- accessGranted = true
- } else {
- c.returnError(w, fmt.Errorf("instance id doesn't match"), http.StatusUnauthorized)
- return
- }
- }
- }
- // check secret
- if !accessGranted {
- localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE)
- if err != nil {
- c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest)
- return
- }
- if strings.TrimSpace(string(localSecret)) != contextReq.Secret {
- c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized)
- return
- }
- }
- if contextReq.AdminPassword != "" {
- adminUser := users.User{
- Login: "admin",
- Password: contextReq.AdminPassword,
- Role: "admin",
- }
- if c.UserStore.LoginExists("admin") {
- err = c.UserStore.UpdateUser(adminUser)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not update user: %s", err), http.StatusBadRequest)
- return
- }
- } else {
- _, err = c.UserStore.AddUser(adminUser)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not add user: %s", err), http.StatusBadRequest)
- return
- }
- }
-
- c.SetupCompleted = true
- c.Hostname = contextReq.Hostname
- protocol := contextReq.Protocol
- protocol = strings.Replace(protocol, "http:", "http", -1)
- protocol = strings.Replace(protocol, "https:", "https", -1)
- c.Protocol = protocol
-
- err = SaveConfig(c)
- if err != nil {
- c.SetupCompleted = false
- c.returnError(w, fmt.Errorf("unable to save file: %s", err), http.StatusBadRequest)
- return
- }
-
- // update hostname in vpn config
- vpnconfig, err := wireguard.GetVPNConfig(c.Storage.Client)
- if err != nil {
- c.SetupCompleted = false
- c.returnError(w, fmt.Errorf("unable to get vpn-config: %s", err), http.StatusBadRequest)
- return
- }
- vpnconfig.Endpoint = c.Hostname
- err = wireguard.WriteVPNConfig(c.Storage.Client, vpnconfig)
- if err != nil {
- c.SetupCompleted = false
- c.returnError(w, fmt.Errorf("unable to write vpn-config: %s", err), http.StatusBadRequest)
- return
- }
- }
- }
- }
-
- out, err := json.Marshal(ContextSetupResponse{SetupCompleted: c.SetupCompleted, CloudType: c.CloudType, ServerType: c.ServerType})
- if err != nil {
- c.returnError(w, err, http.StatusBadRequest)
- return
- }
- c.write(w, out)
-}
-
-func (c *Context) setupHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- setupRequest := GeneralSetupRequest{
- Hostname: c.Hostname,
- EnableTLS: c.EnableTLS,
- RedirectToHttps: c.RedirectToHttps,
- DisableLocalAuth: c.LocalAuthDisabled,
- EnableOIDCTokenRenewal: c.EnableOIDCTokenRenewal,
- }
- out, err := json.Marshal(setupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var setupRequest GeneralSetupRequest
- decoder := json.NewDecoder(r.Body)
- decoder.Decode(&setupRequest)
- if c.Hostname != setupRequest.Hostname {
- c.Hostname = setupRequest.Hostname
- }
- if c.RedirectToHttps != setupRequest.RedirectToHttps {
- c.RedirectToHttps = setupRequest.RedirectToHttps
- }
- if c.EnableTLS != setupRequest.EnableTLS {
- if !c.EnableTLS && setupRequest.EnableTLS && !TLSWaiterCompleted && canEnableTLS(c.Hostname) {
- enableTLSWaiter <- true
- }
- c.EnableTLS = setupRequest.EnableTLS
- }
- if c.LocalAuthDisabled != setupRequest.DisableLocalAuth {
- c.LocalAuthDisabled = setupRequest.DisableLocalAuth
- }
- if c.EnableOIDCTokenRenewal != setupRequest.EnableOIDCTokenRenewal {
- c.EnableOIDCTokenRenewal = setupRequest.EnableOIDCTokenRenewal
- c.OIDCRenewal.SetEnabled(c.EnableOIDCTokenRenewal)
- }
- err := SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not save config to disk: %s", err), http.StatusBadRequest)
- return
- }
- out, err := json.Marshal(setupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
- vpnConfig, err := wireguard.GetVPNConfig(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not get vpn config: %s", err), http.StatusBadRequest)
- return
- }
- switch r.Method {
- case http.MethodGet:
- packetLogTypes := []string{}
- for k, enabled := range vpnConfig.PacketLogsTypes {
- if enabled {
- packetLogTypes = append(packetLogTypes, k)
- }
- }
- if vpnConfig.PacketLogsRetention == 0 {
- vpnConfig.PacketLogsRetention = 7
- }
- setupRequest := VPNSetupRequest{
- Routes: strings.Join(vpnConfig.ClientRoutes, ", "),
- VPNEndpoint: vpnConfig.Endpoint,
- AddressRange: vpnConfig.AddressRange.String(),
- ClientAddressPrefix: vpnConfig.ClientAddressPrefix,
- Port: strconv.Itoa(vpnConfig.Port),
- ExternalInterface: vpnConfig.ExternalInterface,
- Nameservers: strings.Join(vpnConfig.Nameservers, ","),
- DisableNAT: vpnConfig.DisableNAT,
- EnablePacketLogs: vpnConfig.EnablePacketLogs,
- PacketLogsTypes: packetLogTypes,
- PacketLogsRetention: strconv.Itoa(vpnConfig.PacketLogsRetention),
- }
- out, err := json.Marshal(setupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var (
- writeVPNConfig bool
- rewriteClientConfigs bool
- setupRequest VPNSetupRequest
- )
- decoder := json.NewDecoder(r.Body)
- decoder.Decode(&setupRequest)
- if strings.Join(vpnConfig.ClientRoutes, ", ") != setupRequest.Routes {
- networks := strings.Split(setupRequest.Routes, ",")
- validatedNetworks := []string{}
- for _, network := range networks {
- if strings.TrimSpace(network) == "::/0" {
- validatedNetworks = append(validatedNetworks, "::/0")
- } else {
- _, ipnet, err := net.ParseCIDR(strings.TrimSpace(network))
- if err != nil {
- c.returnError(w, fmt.Errorf("client route %s in wrong format: %s", strings.TrimSpace(network), err), http.StatusBadRequest)
- return
- }
- validatedNetworks = append(validatedNetworks, ipnet.String())
- }
- }
- vpnConfig.ClientRoutes = validatedNetworks
- writeVPNConfig = true
- rewriteClientConfigs = true
- }
- if vpnConfig.Endpoint != setupRequest.VPNEndpoint {
- vpnConfig.Endpoint = setupRequest.VPNEndpoint
- writeVPNConfig = true
- rewriteClientConfigs = true
- }
- addressRangeParsed, err := netip.ParsePrefix(setupRequest.AddressRange)
- if err != nil {
- c.returnError(w, fmt.Errorf("AddressRange in wrong format: %s", err), http.StatusBadRequest)
- return
- }
- if addressRangeParsed.String() != vpnConfig.AddressRange.String() {
- vpnConfig.AddressRange = addressRangeParsed
- writeVPNConfig = true
- rewriteClientConfigs = true
- }
- if setupRequest.ClientAddressPrefix != vpnConfig.ClientAddressPrefix {
- vpnConfig.ClientAddressPrefix = setupRequest.ClientAddressPrefix
- writeVPNConfig = true
- rewriteClientConfigs = true
- }
- port, err := strconv.Atoi(setupRequest.Port)
- if err != nil {
- c.returnError(w, fmt.Errorf("port in wrong format: %s", err), http.StatusBadRequest)
- return
- }
- if port != vpnConfig.Port {
- vpnConfig.Port = port
- writeVPNConfig = true
- rewriteClientConfigs = true
- }
-
- nameservers := strings.Split(setupRequest.Nameservers, ",")
- for k := range nameservers {
- nameservers[k] = strings.TrimSpace(nameservers[k])
- }
- if !reflect.DeepEqual(nameservers, vpnConfig.Nameservers) {
- vpnConfig.Nameservers = nameservers
- writeVPNConfig = true
- rewriteClientConfigs = true
- }
- if setupRequest.ExternalInterface != vpnConfig.ExternalInterface { // don't rewrite client config
- vpnConfig.ExternalInterface = setupRequest.ExternalInterface
- writeVPNConfig = true
- }
- if setupRequest.DisableNAT != vpnConfig.DisableNAT { // don't rewrite client config
- vpnConfig.DisableNAT = setupRequest.DisableNAT
- writeVPNConfig = true
- }
- if setupRequest.EnablePacketLogs != vpnConfig.EnablePacketLogs {
- vpnConfig.EnablePacketLogs = setupRequest.EnablePacketLogs
- writeVPNConfig = true
- }
- packetLogsRention, err := strconv.Atoi(setupRequest.PacketLogsRetention)
- if err != nil || packetLogsRention < 1 {
- c.returnError(w, fmt.Errorf("incorrect packet log retention. Enter a number of days the logs must be kept (minimum 1)"), http.StatusBadRequest)
- return
- }
- if packetLogsRention != vpnConfig.PacketLogsRetention {
- vpnConfig.PacketLogsRetention = packetLogsRention
- writeVPNConfig = true
- }
-
- // packetlogtypes
- packetLogTypes := []string{}
- for k, enabled := range vpnConfig.PacketLogsTypes {
- if enabled {
- packetLogTypes = append(packetLogTypes, k)
- }
- }
- sort.Strings(setupRequest.PacketLogsTypes)
- sort.Strings(packetLogTypes)
- if !slices.Equal(setupRequest.PacketLogsTypes, packetLogTypes) {
- vpnConfig.PacketLogsTypes = make(map[string]bool)
- for _, v := range setupRequest.PacketLogsTypes {
- if v == "http+https" || v == "dns" || v == "tcp" {
- vpnConfig.PacketLogsTypes[v] = true
- }
- }
- writeVPNConfig = true
- }
-
- // write vpn config if config has changed
- if writeVPNConfig {
- err = wireguard.WriteVPNConfig(c.Storage.Client, vpnConfig)
- if err != nil {
- c.returnError(w, fmt.Errorf("could write vpn config: %s", err), http.StatusBadRequest)
- return
- }
- err = wireguard.ReloadVPNServerConfig()
- if err != nil {
- c.returnError(w, fmt.Errorf("unable to reload server config: %s", err), http.StatusBadRequest)
- return
- }
- }
- if rewriteClientConfigs {
- // rewrite client configs
- err = wireguard.UpdateClientsConfig(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest)
- return
- }
- }
- out, err := json.Marshal(setupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) templateSetupHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest)
- return
- }
- serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest)
- return
- }
- setupRequest := TemplateSetupRequest{
- ClientTemplate: string(clientTemplate),
- ServerTemplate: string(serverTemplate),
- }
- out, err := json.Marshal(setupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var templateSetupRequest TemplateSetupRequest
- decoder := json.NewDecoder(r.Body)
- decoder.Decode(&templateSetupRequest)
- clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest)
- return
- }
- serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest)
- return
- }
- if string(clientTemplate) != templateSetupRequest.ClientTemplate {
- err = wireguard.WriteClientTemplate(c.Storage.Client, []byte(templateSetupRequest.ClientTemplate))
- if err != nil {
- c.returnError(w, fmt.Errorf("WriteClientTemplate error: %s", err), http.StatusBadRequest)
- return
- }
- // rewrite client configs
- err = wireguard.UpdateClientsConfig(c.Storage.Client)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest)
- return
- }
- }
- if string(serverTemplate) != templateSetupRequest.ServerTemplate {
- err = wireguard.WriteServerTemplate(c.Storage.Client, []byte(templateSetupRequest.ServerTemplate))
- if err != nil {
- c.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest)
- return
- }
- }
- out, err := json.Marshal(templateSetupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) restartVPNHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- c.returnError(w, fmt.Errorf("unsupported method"), http.StatusBadRequest)
- return
- }
- client := http.Client{
- Timeout: 10 * time.Second,
- }
- req, err := http.NewRequest(r.Method, "http://"+wireguard.CONFIGMANAGER_URI+"/restart-vpn", nil)
- if err != nil {
- c.returnError(w, fmt.Errorf("restart request error: %s", err), http.StatusBadRequest)
- return
- }
- resp, err := client.Do(req)
- if err != nil {
- c.returnError(w, fmt.Errorf("restart error: %s", err), http.StatusBadRequest)
- return
- }
- if resp.StatusCode != http.StatusAccepted {
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.returnError(w, fmt.Errorf("restart error: got status code: %d. Response: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest)
- return
- }
- c.returnError(w, fmt.Errorf("restart error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest)
- return
- }
-
- defer resp.Body.Close()
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest)
- return
- }
-
- c.write(w, bodyBytes)
-}
-
-func (c *Context) scimSetupHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- scimSetup := SCIMSetup{
- Enabled: c.SCIM.EnableSCIM,
- }
- if c.SCIM.EnableSCIM {
- scimSetup.Token = c.SCIM.Token
- scimSetup.BaseURL = fmt.Sprintf("%s://%s/%s", c.Protocol, c.Hostname, "api/scim/v2/")
- }
- out, err := json.Marshal(scimSetup)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal scim setup: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- saveConfig := false
- var scimSetupRequest SCIMSetup
- decoder := json.NewDecoder(r.Body)
- decoder.Decode(&scimSetupRequest)
- if scimSetupRequest.Enabled && !c.SCIM.EnableSCIM {
- c.SCIM.EnableSCIM = true
- saveConfig = true
- }
- if !scimSetupRequest.Enabled && c.SCIM.EnableSCIM {
- c.SCIM.EnableSCIM = false
- saveConfig = true
- }
- if scimSetupRequest.RegenerateToken || (scimSetupRequest.Enabled && c.SCIM.Token == "") {
- // Generate new token
- randomString, err := oidc.GetRandomString(64)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not enable scim: %s", err), http.StatusBadRequest)
- return
- }
- token := base64.StdEncoding.EncodeToString([]byte(randomString))
- scimSetupRequest.Token = token
- c.SCIM.Token = token
- c.SCIM.Client.UpdateToken(token)
- saveConfig = true
- }
- if saveConfig {
- // save config
- err := SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not save config to disk: %s", err), http.StatusBadRequest)
- return
- }
- }
- out, err := json.Marshal(scimSetupRequest)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal scim setup: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) samlSetupHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- samlProviders := make([]saml.Provider, len(*c.SAML.Providers))
- copy(samlProviders, *c.SAML.Providers)
- for k := range samlProviders {
- samlProviders[k].Issuer = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ISSUER_URL, samlProviders[k].ID)
- samlProviders[k].Audience = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.AUDIENCE_URL, samlProviders[k].ID)
- samlProviders[k].Acs = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ACS_URL, samlProviders[k].ID)
- }
- out, err := json.Marshal(samlProviders)
- if err != nil {
- c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var samlProvider saml.Provider
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&samlProvider)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- samlProvider.ID = uuid.New().String()
- if samlProvider.Name == "" {
- c.returnError(w, fmt.Errorf("name not set"), http.StatusBadRequest)
- return
- }
- if samlProvider.MetadataURL == "" {
- c.returnError(w, fmt.Errorf("metadata URL not set"), http.StatusBadRequest)
- return
- }
- _, err = c.SAML.Client.HasValidMetadataURL(samlProvider.MetadataURL)
- if err != nil {
- c.returnError(w, fmt.Errorf("metadata error: %s", err), http.StatusBadRequest)
- return
- }
-
- *c.SAML.Providers = append(*c.SAML.Providers, samlProvider)
- out, err := json.Marshal(samlProvider)
- if err != nil {
- c.returnError(w, fmt.Errorf("samlProvider marshal error: %s", err), http.StatusBadRequest)
- return
- }
- err = SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
-
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) samlSetupElementHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodDelete:
- match := -1
- for k, samlProvider := range *c.SAML.Providers {
- if samlProvider.ID == r.PathValue("id") {
- match = k
- }
- }
- if match == -1 {
- c.returnError(w, fmt.Errorf("saml provider not found"), http.StatusBadRequest)
- return
- }
- *c.SAML.Providers = append((*c.SAML.Providers)[:match], (*c.SAML.Providers)[match+1:]...)
- // save config (changed providers)
- err := SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, []byte(`{ "deleted": "`+r.PathValue("id")+`" }`))
- case http.MethodPut:
- var samlProvider saml.Provider
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&samlProvider)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- samlProviderID := -1
- for k := range *c.SAML.Providers {
- if (*c.SAML.Providers)[k].ID == r.PathValue("id") {
- samlProviderID = k
- }
- }
- if samlProviderID == -1 {
- c.returnError(w, fmt.Errorf("cannot find saml provider: %s", err), http.StatusBadRequest)
- return
- }
- saveConfig := false
- if (*c.SAML.Providers)[samlProviderID].AllowMissingAttributes != samlProvider.AllowMissingAttributes {
- (*c.SAML.Providers)[samlProviderID].AllowMissingAttributes = samlProvider.AllowMissingAttributes
- saveConfig = true
- }
- if (*c.SAML.Providers)[samlProviderID].MetadataURL != samlProvider.MetadataURL {
- _, err := c.SAML.Client.HasValidMetadataURL(samlProvider.MetadataURL)
- if err != nil {
- c.returnError(w, fmt.Errorf("metadata error: %s", err), http.StatusBadRequest)
- return
- }
- (*c.SAML.Providers)[samlProviderID].MetadataURL = samlProvider.MetadataURL
- saveConfig = true
- }
- out, err := json.Marshal(samlProvider)
- if err != nil {
- c.returnError(w, fmt.Errorf("samlProvider marshal error: %s", err), http.StatusBadRequest)
- return
- }
- if saveConfig {
- err = SaveConfig(c)
- if err != nil {
- c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest)
- return
- }
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
diff --git a/pkg/rest/setup_test.go b/pkg/rest/setup_test.go
deleted file mode 100644
index 2e14bca..0000000
--- a/pkg/rest/setup_test.go
+++ /dev/null
@@ -1,293 +0,0 @@
-package rest
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/in4it/wireguard-server/pkg/license"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-func TestContextHandlerSetupSecret(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
-
- storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`))
-
- userStore, err := users.NewUserStore(storage, -1)
- if err != nil {
- t.Fatalf("new user store error")
- }
- c, err := getEmptyContext("appdir")
- if err != nil {
- t.Fatalf("cannot create empty context")
- }
- c.Storage = &Storage{Client: storage}
- c.UserStore = userStore
-
- payload := ContextRequest{
- Secret: "secret setup code",
- AdminPassword: "adminPassword",
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes))
- w := httptest.NewRecorder()
- c.contextHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("got read error: %s", err)
- }
-
- var contextSetupResponse ContextSetupResponse
- err = json.Unmarshal(body, &contextSetupResponse)
- if err != nil {
- t.Fatalf("unmarshal error: %s", err)
- }
- if !contextSetupResponse.SetupCompleted {
- t.Fatalf("expected setup to be completed")
- }
-}
-
-func TestContextHandlerSetupWrongSecret(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
-
- storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`))
-
- userStore, err := users.NewUserStore(storage, -1)
- if err != nil {
- t.Fatalf("new user store error")
- }
- c, err := getEmptyContext("appdir")
- if err != nil {
- t.Fatalf("cannot create empty context")
- }
- c.Storage = &Storage{Client: storage}
- c.UserStore = userStore
-
- payload := ContextRequest{
- AdminPassword: "adminPassword",
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes))
- w := httptest.NewRecorder()
- c.contextHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 401 {
- t.Fatalf("status code is not 401: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("got read error: %s", err)
- }
-
- var contextSetupResponse ContextSetupResponse
- err = json.Unmarshal(body, &contextSetupResponse)
- if err != nil {
- t.Fatalf("unmarshal error: %s", err)
- }
- if contextSetupResponse.SetupCompleted {
- t.Fatalf("expected setup to not be completed")
- }
-}
-func TestContextHandlerSetupWrongSecretPartial(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
-
- storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`))
-
- userStore, err := users.NewUserStore(storage, -1)
- if err != nil {
- t.Fatalf("new user store error")
- }
- c, err := getEmptyContext("appdir")
- if err != nil {
- t.Fatalf("cannot create empty context")
- }
- c.Storage = &Storage{Client: storage}
- c.UserStore = userStore
-
- payload := ContextRequest{
- Secret: "secret setup cod",
- AdminPassword: "adminPassword",
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes))
- w := httptest.NewRecorder()
- c.contextHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 401 {
- t.Fatalf("status code is not 401: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("got read error: %s", err)
- }
-
- var contextSetupResponse ContextSetupResponse
- err = json.Unmarshal(body, &contextSetupResponse)
- if err != nil {
- t.Fatalf("unmarshal error: %s", err)
- }
- if contextSetupResponse.SetupCompleted {
- t.Fatalf("expected setup to not be completed")
- }
-}
-
-func TestContextHandlerSetupAWSInstanceID(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/latest/api/token" {
- w.Write([]byte("this is a test token"))
- return
- }
- if r.RequestURI == "/latest/meta-data/instance-id" {
- w.Write([]byte("i-012aaaaaaaaaaaaa1"))
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- }))
- defer ts.Close()
- license.MetadataIP = strings.TrimPrefix(ts.URL, "http://")
-
- storage := &memorystorage.MockMemoryStorage{}
-
- storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`))
-
- userStore, err := users.NewUserStore(storage, -1)
- if err != nil {
- t.Fatalf("new user store error")
- }
- c, err := getEmptyContext("appdir")
- if err != nil {
- t.Fatalf("cannot create empty context")
- }
- c.Storage = &Storage{Client: storage}
- c.UserStore = userStore
- c.CloudType = "aws"
-
- payload := ContextRequest{
- InstanceID: "i-012aaaaaaaaaaaaa1",
- AdminPassword: "adminPassword",
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes))
- w := httptest.NewRecorder()
- c.contextHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("got read error: %s", err)
- }
-
- var contextSetupResponse ContextSetupResponse
- err = json.Unmarshal(body, &contextSetupResponse)
- if err != nil {
- t.Fatalf("unmarshal error: %s", err)
- }
- if !contextSetupResponse.SetupCompleted {
- t.Fatalf("expected setup to be completed")
- }
-}
-func TestContextHandlerSetupDigitalOceanTag(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI == "/metadata/v1/tags" {
- w.Write([]byte("vpnsecret-this-is-a-secret-tag"))
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- }))
- defer ts.Close()
- license.MetadataIP = strings.TrimPrefix(ts.URL, "http://")
-
- storage := &memorystorage.MockMemoryStorage{}
-
- storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`))
-
- userStore, err := users.NewUserStore(storage, -1)
- if err != nil {
- t.Fatalf("new user store error")
- }
- c, err := getEmptyContext("appdir")
- if err != nil {
- t.Fatalf("cannot create empty context")
- }
- c.Storage = &Storage{Client: storage}
- c.UserStore = userStore
- c.CloudType = "digitalocean"
-
- payload := ContextRequest{
- TagHash: "vpnsecret-this-is-a-secret-tag",
- AdminPassword: "adminPassword",
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("marshal error: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes))
- w := httptest.NewRecorder()
- c.contextHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("got read error: %s", err)
- }
-
- var contextSetupResponse ContextSetupResponse
- err = json.Unmarshal(body, &contextSetupResponse)
- if err != nil {
- t.Fatalf("unmarshal error: %s", err)
- }
- if !contextSetupResponse.SetupCompleted {
- t.Fatalf("expected setup to be completed")
- }
-}
diff --git a/pkg/rest/signals.go b/pkg/rest/signals.go
deleted file mode 100644
index 8e3dee5..0000000
--- a/pkg/rest/signals.go
+++ /dev/null
@@ -1,38 +0,0 @@
-package rest
-
-import (
- "fmt"
- "log"
- "os"
- "os/signal"
- "path"
- "syscall"
-)
-
-func handleSignals(c *Context) {
- // write pid file so other process can find it
- err := os.WriteFile(path.Join(c.AppDir, "rest-server.pid"), []byte(fmt.Sprintf("%d", os.Getpid())), 0664)
- if err != nil {
- log.Printf("Could not write pid file\n")
- }
- signalChannel := make(chan os.Signal, 1)
- signal.Notify(signalChannel, syscall.SIGHUP)
- for sig := range signalChannel {
- switch sig {
- case syscall.SIGHUP:
- c.ReloadConfig()
- }
- }
-}
-
-func (c *Context) ReloadConfig() {
- newC, err := newContext(c.Storage.Client, SERVER_TYPE_VPN)
- if err != nil {
- log.Printf("ReloadConfig failed: %s\n", err)
- }
- c.AppDir = newC.AppDir
- c.Hostname = newC.Hostname
- c.SetupCompleted = newC.SetupCompleted
- c.UserStore = newC.UserStore
- log.Printf("Config Reloaded!\n")
-}
diff --git a/pkg/rest/tls.go b/pkg/rest/tls.go
deleted file mode 100644
index 5eb94b0..0000000
--- a/pkg/rest/tls.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package rest
-
-import (
- "log"
- "strings"
-)
-
-func canEnableTLS(hostname string) bool {
- hostnameSplit := strings.Split(hostname, ":")
- if hostnameSplit[0] != "localhost" {
- return true
- } else {
- log.Printf("Not enabling TLS with lets encrypt. Hostname is localhost")
- }
-
- return false
-}
diff --git a/pkg/rest/types.go b/pkg/rest/types.go
deleted file mode 100644
index d3c54c1..0000000
--- a/pkg/rest/types.go
+++ /dev/null
@@ -1,237 +0,0 @@
-package rest
-
-import (
- "time"
-
- "github.com/in4it/wireguard-server/pkg/auth/oidc"
- oidcstore "github.com/in4it/wireguard-server/pkg/auth/oidc/store"
- oidcrenewal "github.com/in4it/wireguard-server/pkg/auth/oidc/store/renewal"
- "github.com/in4it/wireguard-server/pkg/auth/provisioning/scim"
- "github.com/in4it/wireguard-server/pkg/auth/saml"
- "github.com/in4it/wireguard-server/pkg/observability"
- "github.com/in4it/wireguard-server/pkg/rest/login"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
-)
-
-const SETUP_CODE_FILE = "setup-code.txt"
-const ADMIN_USER = "admin"
-
-type Context struct {
- AppDir string `json:"appDir,omitempty"`
- ServerType string `json:"serverType,omitempty"`
- SetupCompleted bool `json:"setupCompleted"`
- Hostname string `json:"hostname,omitempty"`
- Protocol string `json:"protocol,omitempty"`
- JWTKeys *JWTKeys `json:"jwtKeys,omitempty"`
- JWTKeysKID string `json:"jwtKeysKid,omitempty"`
- OIDCProviders []oidc.OIDCProvider `json:"oidcProviders,omitempty"`
- LocalAuthDisabled bool `json:"disableLocalAuth,omitempty"`
- EnableTLS bool `json:"enableTLS,omitempty"`
- RedirectToHttps bool `json:"redirectToHttps,omitempty"`
- EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal,omitempty"`
- OIDCStore *oidcstore.Store `json:"oidcStore,omitempty"`
- UserStore *users.UserStore `json:"users,omitempty"`
- OIDCRenewal *oidcrenewal.Renewal `json:"oidcRenewal,omitempty"`
- LoginAttempts login.Attempts `json:"loginAttempts,omitempty"`
- LicenseUserCount int `json:"licenseUserCount,omitempty"`
- CloudType string `json:"cloudType,omitempty"`
- TokenRenewalTimeMinutes int `json:"tokenRenewalTimeMinutes,omitempty"`
- LogLevel int `json:"loglevel,omitempty"`
- SCIM *SCIM `json:"scim,omitempty"`
- SAML *SAML `json:"saml,omitempty"`
- Observability *Observability `json:"observability,omitempty"`
- Storage *Storage `json:"storage,omitempty"`
-}
-type SCIM struct {
- EnableSCIM bool `json:"enableSCIM,omitempty"`
- Token string `json:"token"`
- Client scim.Iface `json:"client,omitempty"`
-}
-type SAML struct {
- Providers *[]saml.Provider `json:"providers"`
- Client saml.Iface `json:"client,omitempty"`
-}
-type Observability struct {
- Client observability.Iface `json:"client,omitempty"`
-}
-type Storage struct {
- Client storage.Iface `json:"client,omitempty"`
-}
-
-type ContextRequest struct {
- Secret string `json:"secret"`
- TagHash string `json:"tagHash"`
- InstanceID string `json:"instanceID"`
- AdminPassword string `json:"adminPassword"`
- Hostname string `json:"hostname"`
- Protocol string `json:"protocol"`
-}
-type ContextSetupResponse struct {
- SetupCompleted bool `json:"setupCompleted"`
- CloudType string `json:"cloudType"`
- ServerType string `json:"serverType"`
-}
-
-type AuthMethodsResponse struct {
- LocalAuthDisabled bool `json:"localAuthDisabled"`
- OIDCProviders []AuthMethodsProvider `json:"oidcProviders"`
-}
-
-type AuthMethodsProvider struct {
- ID string `json:"id"`
- Name string `json:"name"`
- RedirectURI string `json:"redirectURI,omitempty"`
-}
-
-type OIDCCallback struct {
- Code string `json:"code"`
- State string `json:"state"`
- RedirectURI string `json:"redirectURI"`
-}
-type SAMLCallback struct {
- Code string `json:"code"`
- RedirectURI string `json:"redirectURI"`
-}
-
-type UserInfoResponse struct {
- Login string `json:"login"`
- Role string `json:"role"`
- UserType string `json:"userType"`
-}
-
-type GeneralSetupRequest struct {
- Hostname string `json:"hostname"`
- EnableTLS bool `json:"enableTLS"`
- RedirectToHttps bool `json:"redirectToHttps"`
- DisableLocalAuth bool `json:"disableLocalAuth"`
- EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal"`
-}
-
-type VPNSetupRequest struct {
- Routes string `json:"routes"`
- VPNEndpoint string `json:"vpnEndpoint"`
- AddressRange string `json:"addressRange"`
- ClientAddressPrefix string `json:"clientAddressPrefix"`
- Port string `json:"port"`
- ExternalInterface string `json:"externalInterface"`
- Nameservers string `json:"nameservers"`
- DisableNAT bool `json:"disableNAT"`
- EnablePacketLogs bool `json:"enablePacketLogs"`
- PacketLogsTypes []string `json:"packetLogsTypes"`
- PacketLogsRetention string `json:"packetLogsRetention"`
-}
-
-type TemplateSetupRequest struct {
- ClientTemplate string `json:"clientTemplate"`
- ServerTemplate string `json:"serverTemplate"`
-}
-
-type NewConnectionResponse struct {
- Name string `json:"name"`
-}
-type Connection struct {
- ID string `json:"id"`
- Name string `json:"name"`
-}
-
-type LicenseResponse struct {
- LicenseUserCount int `json:"licenseUserCount"`
- CurrentUserCount int `json:"currentUserCount,omitempty"`
- CloudType string `json:"cloudType"`
- Key string `json:"key,omitempty"`
-}
-
-type ConnectionLicenseResponse struct {
- LicenseUserCount int `json:"licenseUserCount"`
- ConnectionCount int `json:"connectionCount"`
-}
-
-type JwtHeader struct {
- Alg string `json:"alg"`
- Typ string `json:"typ"`
- Kid string `json:"kid"`
-}
-
-type UsersResponse struct {
- ID string `json:"id"`
- Login string `json:"login"`
- Role string `json:"role"`
- OIDCID string `json:"oidcID"`
- SAMLID string `json:"samlID"`
- Provisioned bool `json:"provisioned"`
- Suspended bool `json:"suspended"`
- ConnectionsDisabledOnAuthFailure bool `json:"connectionsDisabledOnAuthFailure"`
- LastTokenRenewal time.Time `json:"lastTokenRenewal,omitempty"`
- LastLogin string `json:"lastLogin"`
-}
-
-type FactorRequest struct {
- Name string `json:"name"`
- Type string `json:"type"`
- Secret string `json:"secret"`
- Code string `json:"code"`
-}
-
-type SCIMSetup struct {
- Enabled bool `json:"enabled"`
- Token string `json:"token,omitempty"`
- RegenerateToken bool `json:"regenerateToken,omitempty"`
- BaseURL string `json:"baseURL,omitempty"`
-}
-
-type SAMLSetup struct {
- Enabled bool `json:"enabled"`
- MetadataURL string `json:"metadataURL,omitempty"`
- RegenerateCert bool `json:"regenerateCert,omitempty"`
-}
-
-type UserStatsResponse struct {
- ReceiveBytes UserStatsData `json:"receivedBytes"`
- TransmitBytes UserStatsData `json:"transmitBytes"`
- Handshakes UserStatsData `json:"handshakes"`
-}
-type UserStatsData struct {
- Datasets UserStatsDatasets `json:"datasets"`
-}
-type UserStatsDatasets []UserStatsDataset
-type UserStatsDataset struct {
- Label string `json:"label"`
- Data []UserStatsDataPoint `json:"data"`
- Fill bool `json:"fill"`
- BorderColor string `json:"borderColor"`
- BackgroundColor string `json:"backgroundColor"`
- Tension float64 `json:"tension"`
- ShowLine bool `json:"showLine"`
-}
-
-type UserStatsDataPoint struct {
- X string `json:"x"`
- Y float64 `json:"y"`
-}
-
-type NewUserRequest struct {
- Login string `json:"login"`
- Role string `json:"role"`
- Password string `json:"password,omitempty"`
-}
-
-type LogDataResponse struct {
- LogData LogData `json:"logData"`
- Enabled bool `json:"enabled"`
- LogTypes []string `json:"logTypes"`
- Users map[string]string `json:"users"`
-}
-
-type LogData struct {
- Schema LogSchema `json:"schema"`
- Data []LogRow `json:"rows"`
- NextPos int64 `json:"nextPos"`
-}
-type LogSchema struct {
- Columns map[string]string `json:"columns"`
-}
-type LogRow struct {
- Timestamp string `json:"t"`
- Data []string `json:"d"`
-}
diff --git a/pkg/rest/users.go b/pkg/rest/users.go
deleted file mode 100644
index 2dcdcc5..0000000
--- a/pkg/rest/users.go
+++ /dev/null
@@ -1,275 +0,0 @@
-package rest
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/in4it/wireguard-server/pkg/storage"
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-func (c *Context) GetUserFromRequest(r *http.Request) (users.User, error) {
- claims := r.Context().Value(CustomValue("claims")).(jwt.MapClaims)
- sub, ok := claims["sub"]
- if !ok {
- return users.User{}, fmt.Errorf("userinfoHandler: subject not found in token")
- }
- iss, ok := claims["iss"]
- if !ok {
- return users.User{}, fmt.Errorf("userinfoHandler: issuer not found in token")
- }
-
- kid, ok := claims["kid"]
- if !ok {
- return users.User{}, fmt.Errorf("userinfoHandler: kid not found in token")
- }
-
- if kid == c.JWTKeysKID {
- user, err := c.UserStore.GetUserByLogin(sub.(string))
- if err != nil {
- return users.User{}, fmt.Errorf("GetUserByLogin: user not found")
- }
- return user, nil
- } else { // user comes from oidc
- oauth2DataIDs := []string{}
- for _, oauth2Data := range c.OIDCStore.OAuth2Data {
- if oauth2Data.Issuer == iss && oauth2Data.Subject == sub {
- oauth2DataIDs = append(oauth2DataIDs, oauth2Data.ID)
- }
- }
- if len(oauth2DataIDs) == 0 {
- return users.User{}, fmt.Errorf("userinfoHandler: couldn't find user in oidc database")
- }
- user, err := c.UserStore.GetUserByOIDCIDs(oauth2DataIDs)
- if err != nil {
- return user, fmt.Errorf("get user by oidc id failed: %s", err)
- }
- return user, nil
- }
-}
-
-func (c *Context) usersHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- users := c.UserStore.ListUsers()
- userResponse := make([]UsersResponse, len(users))
- for k, user := range users {
- userResponse[k].ID = user.ID
- userResponse[k].Login = user.Login
- userResponse[k].Role = user.Role
- userResponse[k].OIDCID = user.OIDCID
- userResponse[k].SAMLID = user.SAMLID
- userResponse[k].Suspended = user.Suspended
- userResponse[k].Provisioned = user.Provisioned
- userResponse[k].ConnectionsDisabledOnAuthFailure = user.ConnectionsDisabledOnAuthFailure
- if !user.LastLogin.IsZero() {
- userResponse[k].LastLogin = user.LastLogin.UTC().Format(time.RFC3339)
- }
- for _, oauth2Data := range c.OIDCStore.OAuth2Data {
- if oauth2Data.ID == user.OIDCID {
- userResponse[k].LastTokenRenewal = oauth2Data.LastTokenRenewal
- }
- }
- }
- out, err := json.Marshal(userResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- var user NewUserRequest
- decoder := json.NewDecoder(r.Body)
- err := decoder.Decode(&user)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- if !isAlphaNumeric(user.Login) {
- c.returnError(w, fmt.Errorf("login not valid"), http.StatusBadRequest)
- return
- }
- if user.Login == "" {
- c.returnError(w, fmt.Errorf("login is empty"), http.StatusBadRequest)
- return
- }
- if user.Password == "" {
- c.returnError(w, fmt.Errorf("password is empty"), http.StatusBadRequest)
- return
- }
- if user.Role != "user" && user.Role != "admin" {
- c.returnError(w, fmt.Errorf("invalid role"), http.StatusBadRequest)
- return
- }
- if c.UserStore.UserCount() >= c.LicenseUserCount {
- c.returnError(w, fmt.Errorf("no more licenses available"), http.StatusBadRequest)
- return
- }
-
- newUser, err := c.UserStore.AddUser(users.User{Login: user.Login, Password: user.Password, Role: user.Role})
- if err != nil {
- c.returnError(w, fmt.Errorf("add user error: %s", err), http.StatusBadRequest)
- return
- }
- out, err := json.Marshal(newUser)
- if err != nil {
- c.returnError(w, fmt.Errorf("new user marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) userHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodDelete:
- userID := r.PathValue("id")
- err := c.UserStore.DeleteUserByID(userID)
- if err != nil {
- c.returnError(w, fmt.Errorf("delete user error: %s", err), http.StatusBadRequest)
- return
- }
- err = wireguard.DeleteAllClientConfigs(c.Storage.Client, userID)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", userID, err), http.StatusBadRequest)
- return
- }
- c.write(w, []byte(`{"deleted": "`+userID+`"}`))
- case http.MethodPatch:
- dbUser, err := c.UserStore.GetUserByID(r.PathValue("id"))
- if err != nil {
- c.returnError(w, fmt.Errorf("user not found: %s", err), http.StatusBadRequest)
- return
- }
- var user users.User
- decoder := json.NewDecoder(r.Body)
- err = decoder.Decode(&user)
- if err != nil {
- c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest)
- return
- }
- updateUser := false
- if user.Role != "" && dbUser.Role != user.Role {
- dbUser.Role = user.Role
- updateUser = true
- }
- if dbUser.Suspended != user.Suspended {
- dbUser.Suspended = user.Suspended
- updateUser = true
- if user.Suspended { // user is now suspended
- err := wireguard.DisableAllClientConfigs(c.Storage.Client, user.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest)
- return
- }
- } else { // user is now unsuspended
- err := wireguard.ReactivateAllClientConfigs(c.Storage.Client, user.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not reactivate all clients for user %s: %s", user.ID, err), http.StatusBadRequest)
- return
- }
- }
- }
- if updateUser {
- err = c.UserStore.UpdateUser(dbUser)
- if err != nil {
- c.returnError(w, fmt.Errorf("update user error: %s", err), http.StatusBadRequest)
- return
- }
- }
- if user.Password != "" {
- err = c.UserStore.UpdatePassword(user.ID, user.Password)
- if err != nil {
- c.returnError(w, fmt.Errorf("update password error: %s", err), http.StatusBadRequest)
- return
- }
- }
- out, err := json.Marshal(dbUser)
- if err != nil {
- c.returnError(w, fmt.Errorf("marshal dbuser error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func addOrModifyExternalUser(storage storage.Iface, userStore *users.UserStore, login, authType, externalAuthID string) (users.User, error) {
- if userStore.LoginExists(login) {
- existingUser, err := userStore.GetUserByLogin(login)
- if err != nil {
- return existingUser, fmt.Errorf("couldn't find existing user in database: %s", login)
- }
-
- if authType == "oidc" {
- existingUser.OIDCID = externalAuthID
- }
- if authType == "saml" {
- existingUser.SAMLID = externalAuthID
- }
-
- if existingUser.ConnectionsDisabledOnAuthFailure { // we can enable connections again after auth
- err := wireguard.ReactivateAllClientConfigs(storage, existingUser.ID)
- if err != nil {
- return existingUser, fmt.Errorf("could not reactivate all clients for user %s: %s", existingUser.ID, err)
- }
- existingUser.ConnectionsDisabledOnAuthFailure = false
- }
-
- existingUser.LastLogin = time.Now()
-
- err = userStore.UpdateUser(existingUser)
- if err != nil {
- return existingUser, fmt.Errorf("couldn't update user: %s", login)
- }
- return existingUser, nil
- } else {
- newUser := users.User{
- Login: login,
- Role: "user",
- }
- if authType == "oidc" {
- newUser.OIDCID = externalAuthID
- }
- if authType == "saml" {
- newUser.SAMLID = externalAuthID
- }
-
- newUser.LastLogin = time.Now()
-
- newUserAdded, err := userStore.AddUser(newUser)
- if err != nil {
- return newUserAdded, fmt.Errorf("could not add user: %s", err)
- }
- return newUserAdded, nil
- }
-}
-
-func (c *Context) userinfoHandler(w http.ResponseWriter, r *http.Request) {
- var response UserInfoResponse
-
- user := r.Context().Value(CustomValue("user")).(users.User)
-
- response.Login = user.Login
- response.Role = user.Role
- if user.OIDCID == "" {
- response.UserType = "local"
- } else {
- response.UserType = "oidc"
- }
-
- out, err := json.Marshal(response)
- if err != nil {
- c.returnError(w, fmt.Errorf("cannot marshal userinfo response: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
-
-}
diff --git a/pkg/rest/users_test.go b/pkg/rest/users_test.go
deleted file mode 100644
index c516691..0000000
--- a/pkg/rest/users_test.go
+++ /dev/null
@@ -1,196 +0,0 @@
-package rest
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/http/httptest"
- "path"
- "strings"
- "testing"
-
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
- l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI)
- if err != nil {
- t.Fatal(err)
- }
-
- ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodPost:
- if r.RequestURI == "/refresh-clients" {
- w.WriteHeader(http.StatusAccepted)
- w.Write([]byte("OK"))
- return
- }
- if r.RequestURI == "/refresh-server-config" {
- w.WriteHeader(http.StatusAccepted)
- w.Write([]byte("OK"))
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- default:
- w.WriteHeader(http.StatusBadRequest)
- }
- }))
-
- ts.Listener.Close()
- ts.Listener = l
- ts.Start()
- defer ts.Close()
- defer l.Close()
-
- // first create a new user
- storage := &memorystorage.MockMemoryStorage{}
-
- c, err := newContext(storage, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
-
- err = c.UserStore.Empty()
- if err != nil {
- t.Fatalf("Cannot create context")
- }
-
- // create a user
- user := users.User{
- Login: "john",
- Role: "user",
- Password: "xyz",
- }
- payload, err := json.Marshal(user)
- if err != nil {
- t.Fatalf("Cannot create payload: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/users", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.usersHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- err = json.NewDecoder(resp.Body).Decode(&user)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
- // generate VPN config
- _, err = wireguard.CreateNewVPNConfig(c.Storage.Client)
- if err != nil {
- t.Fatalf("Cannot create vpn config: %s", err)
- }
-
- req = httptest.NewRequest("POST", "http://example.com/connections", nil)
- w = httptest.NewRecorder()
- c.connectionsHandler(w, req.WithContext(context.WithValue(context.Background(), CustomValue("user"), user)))
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- connectionID := fmt.Sprintf("%s-1", user.ID)
-
- userConfigFilename := storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connectionID+".json"))
- configBytes, err := storage.ReadFile(userConfigFilename)
- if err != nil {
- t.Fatalf("could not read user config file")
- }
-
- var config wireguard.PeerConfig
- err = json.Unmarshal(configBytes, &config)
- if err != nil {
- t.Fatalf("could not parse config: %s", err)
- }
- if config.Disabled {
- t.Fatalf("VPN connection is disabled. Expected not disabled")
- }
-
- req = httptest.NewRequest("GET", "http://example.com/connection/"+connectionID, nil)
- req.SetPathValue("id", connectionID)
- w = httptest.NewRecorder()
- c.connectionsElementHandler(w, req.WithContext(context.WithValue(context.Background(), CustomValue("user"), user)))
-
- resp = w.Result()
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("readall error: %s", err)
- }
- if !strings.Contains(string(body), "[Interface]") {
- t.Fatalf("output doesn't look like a wireguard client config: %s", body)
- }
-
- req = httptest.NewRequest("DELETE", "http://example.com/user/"+user.ID, nil)
- req.SetPathValue("id", user.ID)
- w = httptest.NewRecorder()
- c.userHandler(w, req)
-
- resp = w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- _, err = storage.ReadFile(userConfigFilename)
- if err == nil {
- t.Fatalf("could read user config file, expected not to")
- }
-}
-
-func TestCreateUser(t *testing.T) {
- // first create a new user
- storage := &memorystorage.MockMemoryStorage{}
-
- c, err := newContext(storage, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
-
- err = c.UserStore.Empty()
- if err != nil {
- t.Fatalf("Cannot create context")
- }
-
- // create a user
- payload := []byte(`{"id": "", "login": "testuser", "password": "tttt213", "role": "user", "oidcID": "", "samlID": "", "lastLogin": "", "provisioned": false, "role":"user","samlID":"","suspended":false}`)
- req := httptest.NewRequest("POST", "http://example.com/users", bytes.NewBuffer(payload))
- w := httptest.NewRecorder()
- c.usersHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 200 {
- t.Fatalf("status code is not 200: %d", resp.StatusCode)
- }
-
- defer resp.Body.Close()
-
- var user users.User
- err = json.NewDecoder(resp.Body).Decode(&user)
- if err != nil {
- t.Fatalf("Cannot decode response from create user: %s", err)
- }
-
-}
diff --git a/pkg/rest/version.go b/pkg/rest/version.go
deleted file mode 100644
index c4b9e70..0000000
--- a/pkg/rest/version.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package rest
-
-import (
- _ "embed"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
-
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-//go:generate cp -r ../../latest ./resources/version
-//go:embed resources/version
-
-var version string
-
-func (c *Context) version(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- out, err := json.Marshal(map[string]string{"version": strings.TrimSpace(version)})
- if err != nil {
- c.returnError(w, fmt.Errorf("version marshal error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-
-func (c *Context) upgrade(w http.ResponseWriter, r *http.Request) {
- client := http.Client{
- Timeout: 10 * time.Second,
- }
- req, err := http.NewRequest(r.Method, "http://"+wireguard.CONFIGMANAGER_URI+"/upgrade", nil)
- if err != nil {
- c.returnError(w, fmt.Errorf("upgrade request error: %s", err), http.StatusBadRequest)
- return
- }
- resp, err := client.Do(req)
- if err != nil {
- c.returnError(w, fmt.Errorf("upgrade error: %s", err), http.StatusBadRequest)
- return
- }
- if resp.StatusCode != http.StatusOK {
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.returnError(w, fmt.Errorf("upgrade error: got status code: %d. Respons: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest)
- return
- }
- c.returnError(w, fmt.Errorf("upgrade error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest)
- return
- }
-
- defer resp.Body.Close()
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest)
- return
- }
-
- c.write(w, bodyBytes)
-
-}
diff --git a/pkg/rest/vpn.go b/pkg/rest/vpn.go
deleted file mode 100644
index afdd92d..0000000
--- a/pkg/rest/vpn.go
+++ /dev/null
@@ -1,118 +0,0 @@
-package rest
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "path"
- "strings"
-
- "github.com/in4it/wireguard-server/pkg/users"
- "github.com/in4it/wireguard-server/pkg/wireguard"
-)
-
-func (c *Context) connectionsHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- user := r.Context().Value(CustomValue("user")).(users.User)
-
- clients, err := c.Storage.Client.ReadDir(c.Storage.Client.ConfigPath(wireguard.VPN_CLIENTS_DIR))
- if err != nil {
- c.returnError(w, fmt.Errorf("cannot list connections for user: %s", err), http.StatusBadRequest)
- return
- }
-
- connectionList := []string{}
- for _, clientFilename := range clients {
- if wireguard.HasClientUserID(clientFilename, user.ID) {
- connectionList = append(connectionList, clientFilename)
- }
- }
- peerConfigs := make([]wireguard.PeerConfig, len(connectionList))
- for k, connection := range connectionList {
- var peerConfig wireguard.PeerConfig
- filename := c.Storage.Client.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connection))
- toDeleteFileContents, err := c.Storage.Client.ReadFile(filename)
- if err != nil {
- c.returnError(w, fmt.Errorf("can't read file %s: %s", filename, err), http.StatusBadRequest)
- return
- }
- err = json.Unmarshal(toDeleteFileContents, &peerConfig)
- if err != nil {
- c.returnError(w, fmt.Errorf("can't unmarshal file %s: %s", filename, err), http.StatusBadRequest)
- return
- }
- peerConfigs[k] = peerConfig
- }
- connections := make([]Connection, len(peerConfigs))
- for k := range peerConfigs {
- connections[k] = Connection{
- ID: peerConfigs[k].ID,
- Name: peerConfigs[k].Name,
- }
- }
- out, err := json.Marshal(connections)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal list connection response: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodPost:
- muClientDownload.Lock()
- defer muClientDownload.Unlock()
- user := r.Context().Value(CustomValue("user")).(users.User)
- peerConfig, err := wireguard.NewEmptyClientConfig(c.Storage.Client, user.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not generate client vpn config: %s", err), http.StatusBadRequest)
- return
- }
- newConnectionResponse := NewConnectionResponse{Name: peerConfig.Name}
- out, err := json.Marshal(newConnectionResponse)
- if err != nil {
- c.returnError(w, fmt.Errorf("could not marshal new connection response: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
-func (c *Context) connectionsElementHandler(w http.ResponseWriter, r *http.Request) {
- switch r.Method {
- case http.MethodGet:
- user := r.Context().Value(CustomValue("user")).(users.User)
- if !strings.HasPrefix(r.PathValue("id"), user.ID) {
- c.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest)
- return
- }
- if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") {
- c.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest)
- return
- }
- out, err := wireguard.GenerateNewClientConfig(c.Storage.Client, r.PathValue("id"), user.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("GetClientConfig error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, out)
- case http.MethodDelete:
- user := r.Context().Value(CustomValue("user")).(users.User)
- if !strings.HasPrefix(r.PathValue("id"), user.ID) {
- c.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest)
- return
- }
- if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") {
- c.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest)
- return
- }
- err := wireguard.DeleteClientConfig(c.Storage.Client, r.PathValue("id"), user.ID)
- if err != nil {
- c.returnError(w, fmt.Errorf("DeleteClientConfig error: %s", err), http.StatusBadRequest)
- return
- }
- c.write(w, []byte(`{"deleted": "`+r.PathValue("id")+`"}`))
-
- default:
- c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
- }
-}
diff --git a/pkg/storage/iface.go b/pkg/storage/iface.go
deleted file mode 100644
index f412fff..0000000
--- a/pkg/storage/iface.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package storage
-
-import (
- "io"
- "io/fs"
-)
-
-type Iface interface {
- GetPath() string
- EnsurePath(path string) error
- EnsureOwnership(filename, login string) error
- ReadDir(name string) ([]string, error)
- Remove(name string) error
- Rename(oldName, newName string) error
- AppendFile(name string, data []byte) error
- EnsurePermissions(name string, mode fs.FileMode) error
- FileInfo(name string) (fs.FileInfo, error)
- ReadWriter
- Seeker
-}
-
-type ReadWriter interface {
- ReadFile(name string) ([]byte, error)
- WriteFile(name string, data []byte) error
- FileExists(filename string) bool
- ConfigPath(filename string) string
- OpenFile(name string) (io.ReadCloser, error)
- OpenFileForWriting(name string) (io.WriteCloser, error)
- OpenFileForAppending(name string) (io.WriteCloser, error)
-}
-
-type Seeker interface {
- OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error)
-}
diff --git a/pkg/storage/local/constants.go b/pkg/storage/local/constants.go
deleted file mode 100644
index 938df52..0000000
--- a/pkg/storage/local/constants.go
+++ /dev/null
@@ -1,4 +0,0 @@
-package localstorage
-
-const CONFIG_PATH = "config"
-const VPN_CLIENTS_DIR = "clients"
diff --git a/pkg/storage/local/new.go b/pkg/storage/local/new.go
deleted file mode 100644
index be9b35c..0000000
--- a/pkg/storage/local/new.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package localstorage
-
-import (
- "fmt"
- "os"
- "path"
-)
-
-func New() (*LocalStorage, error) {
- pwd, err := os.Executable()
- if err != nil {
- return nil, fmt.Errorf("os Executable error: %s", err)
- }
-
- pathname := path.Dir(pwd)
- storage, err := NewWithPath(pathname)
- if err != nil {
- return storage, err
- }
- err = storage.EnsurePath(CONFIG_PATH)
- if err != nil {
- return nil, fmt.Errorf("cannot create storage directories: %s", err)
- }
- err = storage.EnsurePath(path.Join(CONFIG_PATH, VPN_CLIENTS_DIR))
- if err != nil {
- return nil, fmt.Errorf("cannot create storage directories: %s", err)
- }
- err = storage.EnsureOwnership(CONFIG_PATH, "vpn")
- if err != nil {
- return nil, fmt.Errorf("cannot ensure vpn ownership of config directory: %s", err)
- }
- err = storage.EnsureOwnership(path.Join(CONFIG_PATH, VPN_CLIENTS_DIR), "vpn")
- if err != nil {
- return nil, fmt.Errorf("cannot ensure vpn ownership of config directory: %s", err)
- }
-
- return NewWithPath(pathname)
-}
-
-func NewWithPath(pathname string) (*LocalStorage, error) {
- return &LocalStorage{
- path: pathname,
- }, nil
-}
diff --git a/pkg/storage/local/path.go b/pkg/storage/local/path.go
deleted file mode 100644
index 214b341..0000000
--- a/pkg/storage/local/path.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package localstorage
-
-import (
- "errors"
- "fmt"
- "io/fs"
- "os"
- "os/user"
- "path"
- "strconv"
-
- "github.com/in4it/wireguard-server/pkg/logging"
-)
-
-func (l *LocalStorage) FileExists(filename string) bool {
- if _, err := os.Stat(path.Join(l.path, filename)); errors.Is(err, os.ErrNotExist) {
- return false
- }
- return true
-}
-
-func (l *LocalStorage) ConfigPath(filename string) string {
- return path.Join(CONFIG_PATH, filename)
-}
-
-func (l *LocalStorage) GetPath() string {
- return l.path
-}
-
-func (l *LocalStorage) EnsurePath(pathname string) error {
- fullPathname := path.Join(l.path, pathname)
- if _, err := os.Stat(fullPathname); errors.Is(err, os.ErrNotExist) {
- err := os.Mkdir(fullPathname, 0700)
- if err != nil {
- return fmt.Errorf("create directory error: %s", err)
- }
- }
- return nil
-}
-
-func (l *LocalStorage) EnsureOwnership(filename, login string) error {
- currentUser, err := user.Current()
- if err != nil {
- return fmt.Errorf("could not get current user: %s", err)
- }
- if currentUser.Username != "root" {
- logging.DebugLog(fmt.Errorf("cannot ensure ownership of file %s when not user root (current user: %s)", filename, currentUser.Username))
- return nil
- }
- vpnUser, err := user.Lookup(login)
- if err != nil {
- return fmt.Errorf("user lookup error (vpn): %s", err)
- }
- vpnUserUid, err := strconv.Atoi(vpnUser.Uid)
- if err != nil {
- return fmt.Errorf("user lookup error (uid): %s", err)
- }
- vpnUserGid, err := strconv.Atoi(vpnUser.Gid)
- if err != nil {
- return fmt.Errorf("user lookup error (gid): %s", err)
- }
-
- err = os.Chown(path.Join(l.path, filename), vpnUserUid, vpnUserGid)
- if err != nil {
- return fmt.Errorf("vpn chown error: %s", err)
- }
- return nil
-}
-
-func (l *LocalStorage) ReadDir(pathname string) ([]string, error) {
- res, err := os.ReadDir(path.Join(l.path, pathname))
- if err != nil {
- return []string{}, err
- }
- resNames := make([]string, len(res))
- for k, v := range res {
- resNames[k] = v.Name()
- }
- return resNames, nil
-}
-
-func (l *LocalStorage) Remove(name string) error {
- return os.Remove(path.Join(l.path, name))
-}
-
-func (l *LocalStorage) Rename(oldName, newName string) error {
- return os.Rename(path.Join(l.path, oldName), path.Join(l.path, newName))
-}
-
-func (l *LocalStorage) EnsurePermissions(name string, mode fs.FileMode) error {
- return os.Chmod(path.Join(l.path, name), mode)
-}
-
-func (l *LocalStorage) FileInfo(name string) (fs.FileInfo, error) {
- return os.Stat(name)
-}
diff --git a/pkg/storage/local/read.go b/pkg/storage/local/read.go
deleted file mode 100644
index 3a2e6ce..0000000
--- a/pkg/storage/local/read.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package localstorage
-
-import (
- "fmt"
- "io"
- "os"
- "path"
-)
-
-func (l *LocalStorage) ReadFile(name string) ([]byte, error) {
- return os.ReadFile(path.Join(l.path, name))
-}
-
-func (l *LocalStorage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) {
- readers := []io.ReadCloser{}
- if pos < 0 {
- return readers, nil
- }
- for _, name := range names {
- file, err := os.Open(path.Join(l.path, name))
- if err != nil {
- return nil, fmt.Errorf("cannot open file (%s): %s", name, err)
- }
- stat, err := file.Stat()
- if err != nil {
- return nil, fmt.Errorf("cannot get file stat (%s): %s", name, err)
- }
- if stat.Size() <= pos {
- pos -= stat.Size()
- } else {
- _, err := file.Seek(pos, 0)
- if err != nil {
- return nil, fmt.Errorf("could not seek to pos (file: %s): %s", name, err)
- }
- pos = 0
- readers = append(readers, file)
- }
- }
- return readers, nil
-}
-
-func (l *LocalStorage) OpenFile(name string) (io.ReadCloser, error) {
- file, err := os.Open(path.Join(l.path, name))
- if err != nil {
- return nil, fmt.Errorf("cannot open file (%s): %s", name, err)
- }
- return file, nil
-}
diff --git a/pkg/storage/local/read_test.go b/pkg/storage/local/read_test.go
deleted file mode 100644
index c40e6b4..0000000
--- a/pkg/storage/local/read_test.go
+++ /dev/null
@@ -1,91 +0,0 @@
-package localstorage
-
-import (
- "bytes"
- "io"
- "os"
- "path"
- "testing"
-)
-
-func TestOpenFilesFromPos(t *testing.T) {
- pwd, err := os.Executable()
- if err != nil {
- t.Fatalf("os Executable error: %s", err)
- }
- l := LocalStorage{
- path: path.Dir(pwd),
- }
- contents1 := []byte(`this is the first file`)
- contents2 := []byte(`this is the second file`)
- err = l.WriteFile("1.txt", contents1)
- if err != nil {
- t.Fatalf("write file error: %s", err)
- }
- err = l.WriteFile("2.txt", contents2)
- if err != nil {
- t.Fatalf("write file error: %s", err)
- }
- t.Cleanup(func() {
- err = os.Remove(path.Join(l.path, "1.txt"))
- if err != nil {
- t.Fatalf("file delete error: %s", err)
- }
- err = os.Remove(path.Join(l.path, "2.txt"))
- if err != nil {
- t.Fatalf("file delete error: %s", err)
- }
- })
- expected := []string{
- "this is the first filethis is the second file",
- "is the first filethis is the second file",
- "this is the second file",
- "ethis is the second file",
- "his is the second file",
- "",
- "",
- "",
- }
- expextedOpenFiles := []int{
- 2,
- 2,
- 1,
- 2,
- 1,
- 0,
- 0,
- 0,
- }
- tests := []int64{
- 0,
- 5,
- int64(len(contents1)),
- int64(len(contents1) - 1),
- int64(len(contents1) + 1),
- int64(len(contents1) + len(contents2)),
- int64(len(contents1) + len(contents2) + 1),
- -5,
- }
- for k, pos := range tests {
- files, err := l.OpenFilesFromPos([]string{"1.txt", "2.txt"}, pos)
- if err != nil {
- t.Fatalf("open file error: %s", err)
- }
- contents := bytes.NewBuffer([]byte{})
- for _, file := range files {
- defer file.Close()
- body, err := io.ReadAll(file)
- if err != nil {
- t.Fatalf("could not read file: %s", err)
- }
- contents.Write(body)
- }
- if expected[k] != contents.String() {
- t.Fatalf("unexpected output: expected '%s' got '%s'", expected[k], contents.String())
- }
- if expextedOpenFiles[k] != len(files) {
- t.Fatalf("unexpected open files: expected %d got %d", expextedOpenFiles[k], len(files))
- }
- }
-
-}
diff --git a/pkg/storage/local/types.go b/pkg/storage/local/types.go
deleted file mode 100644
index 103b254..0000000
--- a/pkg/storage/local/types.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package localstorage
-
-type LocalStorage struct {
- path string
-}
diff --git a/pkg/storage/local/write.go b/pkg/storage/local/write.go
deleted file mode 100644
index 5eb3d19..0000000
--- a/pkg/storage/local/write.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package localstorage
-
-import (
- "fmt"
- "io"
- "os"
- "path"
-)
-
-func (l *LocalStorage) WriteFile(name string, data []byte) error {
- return os.WriteFile(path.Join(l.path, name), data, 0600)
-}
-
-func (l *LocalStorage) AppendFile(name string, data []byte) error {
- f, err := os.OpenFile(path.Join(l.path, name), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0660)
- if err != nil {
- return err
- }
- defer f.Close()
- if _, err := f.Write(data); err != nil {
- return err
- }
-
- return nil
-}
-
-func (l *LocalStorage) OpenFileForWriting(name string) (io.WriteCloser, error) {
- file, err := os.Create(path.Join(l.path, name))
- if err != nil {
- return nil, fmt.Errorf("cannot open file (%s): %s", name, err)
- }
- return file, nil
-}
-
-func (l *LocalStorage) OpenFileForAppending(name string) (io.WriteCloser, error) {
- file, err := os.OpenFile(path.Join(l.path, name), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0660)
- if err != nil {
- return nil, err
- }
- return file, nil
-}
diff --git a/pkg/storage/memory/fileinfo.go b/pkg/storage/memory/fileinfo.go
deleted file mode 100644
index b9fc5dd..0000000
--- a/pkg/storage/memory/fileinfo.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package memorystorage
-
-import (
- "io/fs"
- "time"
-)
-
-type FileInfo struct {
- NameOut string // base name of the file
- SizeOut int64 // length in bytes for regular files; system-dependent for others
- ModeOut fs.FileMode // file mode bits
- ModTimeOut time.Time // modification time
- IsDirOut bool // abbreviation for Mode().IsDir()
- SysOut any // underlying data source (can return nil)
-}
-
-func (f FileInfo) Name() string {
- return f.NameOut
-}
-func (f FileInfo) Size() int64 {
- return f.SizeOut
-}
-func (f FileInfo) Mode() fs.FileMode {
- return f.ModeOut
-}
-func (f FileInfo) ModTime() time.Time {
- return f.ModTimeOut
-}
-func (f FileInfo) IsDir() bool {
- return f.IsDirOut
-}
-func (f FileInfo) Sys() any {
- return nil
-}
diff --git a/pkg/storage/memory/storage.go b/pkg/storage/memory/storage.go
deleted file mode 100644
index eb7a6c1..0000000
--- a/pkg/storage/memory/storage.go
+++ /dev/null
@@ -1,201 +0,0 @@
-package memorystorage
-
-import (
- "bytes"
- "fmt"
- "io"
- "io/fs"
- "os"
- "path"
- "strings"
- "sync"
-)
-
-type MockReadWriterData []byte
-
-func (m *MockReadWriterData) Close() error {
- return nil
-}
-func (m *MockReadWriterData) Write(p []byte) (nn int, err error) {
- *m = append(*m, p...)
- return len(p), nil
-}
-
-type MockMemoryStorage struct {
- FileInfoData map[string]*FileInfo
- Data map[string]*MockReadWriterData
- Mu sync.Mutex
-}
-
-func (m *MockMemoryStorage) ConfigPath(filename string) string {
- return path.Join("config", filename)
-}
-func (m *MockMemoryStorage) Rename(oldName, newName string) error {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- _, ok := m.Data[oldName]
- if !ok {
- return fmt.Errorf("file doesn't exist")
- }
- m.Data[newName] = m.Data[oldName]
- delete(m.Data, oldName)
- return nil
-}
-func (m *MockMemoryStorage) FileExists(name string) bool {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- _, ok := m.Data[name]
- return ok
-}
-
-func (m *MockMemoryStorage) ReadFile(name string) ([]byte, error) {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- val, ok := m.Data[name]
- if !ok {
- return nil, fmt.Errorf("file does not exist")
- }
- return *val, nil
-}
-func (m *MockMemoryStorage) WriteFile(name string, data []byte) error {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- m.Data[name] = (*MockReadWriterData)(&data)
- return nil
-}
-func (m *MockMemoryStorage) AppendFile(name string, data []byte) error {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- if m.Data[name] == nil {
- m.Data[name] = (*MockReadWriterData)(&data)
- } else {
- *m.Data[name] = append(*m.Data[name], data...)
- }
-
- return nil
-}
-
-func (m *MockMemoryStorage) GetPath() string {
- pwd, _ := os.Executable()
- return path.Dir(pwd)
-}
-
-func (m *MockMemoryStorage) EnsurePath(pathname string) error {
- return nil
-}
-
-func (m *MockMemoryStorage) EnsureOwnership(filename, login string) error {
- return nil
-}
-
-func (m *MockMemoryStorage) ReadDir(path string) ([]string, error) {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- res := []string{}
- for k := range m.Data {
- if path == "" {
- res = append(res, strings.TrimSuffix(k, "/"))
- } else if strings.HasPrefix(k, path+"/") {
- res = append(res, strings.ReplaceAll(k, path+"/", ""))
- }
- }
- return res, nil
-}
-
-func (m *MockMemoryStorage) Remove(name string) error {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- _, ok := m.Data[name]
- if !ok {
- return fmt.Errorf("file does not exist")
- }
- delete(m.Data, name)
- return nil
-}
-
-func (m *MockMemoryStorage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- if pos > 0 {
- return nil, fmt.Errorf("pos > 0 not implemented")
- }
- readClosers := []io.ReadCloser{}
- for _, name := range names {
- val, ok := m.Data[name]
- if !ok {
- return nil, fmt.Errorf("file does not exist")
- }
- readClosers = append(readClosers, io.NopCloser(bytes.NewBuffer(*val)))
- }
- return readClosers, nil
-}
-func (m *MockMemoryStorage) OpenFile(name string) (io.ReadCloser, error) {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- val, ok := m.Data[name]
- if !ok {
- return nil, fmt.Errorf("file does not exist")
- }
-
- return io.NopCloser(bytes.NewBuffer(*val)), nil
-}
-func (m *MockMemoryStorage) OpenFileForWriting(name string) (io.WriteCloser, error) {
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- m.Data[name] = (*MockReadWriterData)(&[]byte{})
- return m.Data[name], nil
-}
-func (m *MockMemoryStorage) OpenFileForAppending(name string) (io.WriteCloser, error) {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- if m.Data == nil {
- m.Data = make(map[string]*MockReadWriterData)
- }
- val, ok := m.Data[name]
- if !ok {
- m.Data[name] = (*MockReadWriterData)(&[]byte{})
- return m.Data[name], nil
- }
- m.Data[name] = (*MockReadWriterData)(val)
- return m.Data[name], nil
-}
-func (m *MockMemoryStorage) EnsurePermissions(name string, mode fs.FileMode) error {
- return nil
-}
-func (m *MockMemoryStorage) FileInfo(name string) (fs.FileInfo, error) {
- m.Mu.Lock()
- defer m.Mu.Unlock()
- val, ok := m.FileInfoData[name]
- if !ok {
- return FileInfo{}, fmt.Errorf("couldn't get file info for: %s", name)
- }
- return val, nil
-}
diff --git a/pkg/storage/s3/list.go b/pkg/storage/s3/list.go
deleted file mode 100644
index 5cadfdc..0000000
--- a/pkg/storage/s3/list.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package s3storage
-
-import (
- "context"
- "fmt"
- "strings"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/s3"
-)
-
-func (s *S3Storage) ReadDir(pathname string) ([]string, error) {
- objectList, err := s.s3Client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{
- Bucket: aws.String(s.bucketname),
- Prefix: aws.String(s.prefix + "/" + strings.TrimLeft(pathname, "/")),
- })
- if err != nil {
- return []string{}, fmt.Errorf("list object error: %s", err)
- }
- res := make([]string, len(objectList.Contents))
- for k, object := range objectList.Contents {
- res[k] = *object.Key
- }
- return res, nil
-}
diff --git a/pkg/storage/s3/new.go b/pkg/storage/s3/new.go
deleted file mode 100644
index 7e0c7f7..0000000
--- a/pkg/storage/s3/new.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package s3storage
-
-import (
- "context"
- "fmt"
-
- "github.com/aws/aws-sdk-go-v2/config"
- "github.com/aws/aws-sdk-go-v2/service/s3"
-)
-
-func New(bucketname, prefix string) (*S3Storage, error) {
- sdkConfig, err := config.LoadDefaultConfig(context.TODO())
- if err != nil {
- return nil, fmt.Errorf("config load error: %s", err)
- }
- s3Client := s3.NewFromConfig(sdkConfig)
-
- return &S3Storage{
- bucketname: bucketname,
- prefix: prefix,
- s3Client: s3Client,
- }, nil
-}
diff --git a/pkg/storage/s3/path.go b/pkg/storage/s3/path.go
deleted file mode 100644
index b8db5ce..0000000
--- a/pkg/storage/s3/path.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package s3storage
-
-import (
- "io/fs"
- "strings"
-)
-
-func (l *S3Storage) FileExists(filename string) bool {
- return false
-}
-
-func (l *S3Storage) ConfigPath(filename string) string {
- return CONFIG_PATH + "/" + strings.TrimLeft(filename, "/")
-}
-
-func (s *S3Storage) GetPath() string {
- return s.prefix
-}
-
-func (l *S3Storage) EnsurePath(pathname string) error {
- return nil
-}
-
-func (l *S3Storage) EnsureOwnership(filename, login string) error {
- return nil
-}
-
-func (l *S3Storage) Remove(name string) error {
- return nil
-}
-
-func (l *S3Storage) Rename(oldName, newName string) error {
- return nil
-}
-
-func (l *S3Storage) EnsurePermissions(name string, mode fs.FileMode) error {
- return nil
-}
-
-func (l *S3Storage) FileInfo(name string) (fs.FileInfo, error) {
- return nil, nil
-}
diff --git a/pkg/storage/s3/read.go b/pkg/storage/s3/read.go
deleted file mode 100644
index 91b0ffe..0000000
--- a/pkg/storage/s3/read.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package s3storage
-
-import (
- "io"
-)
-
-func (l *S3Storage) ReadFile(name string) ([]byte, error) {
- return nil, nil
-}
-
-func (l *S3Storage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) {
- return nil, nil
-}
-
-func (l *S3Storage) OpenFile(name string) (io.ReadCloser, error) {
- return nil, nil
-}
diff --git a/pkg/storage/s3/types.go b/pkg/storage/s3/types.go
deleted file mode 100644
index 3ff10eb..0000000
--- a/pkg/storage/s3/types.go
+++ /dev/null
@@ -1,11 +0,0 @@
-package s3storage
-
-import "github.com/aws/aws-sdk-go-v2/service/s3"
-
-const CONFIG_PATH = "config"
-
-type S3Storage struct {
- bucketname string
- prefix string
- s3Client *s3.Client
-}
diff --git a/pkg/storage/s3/write.go b/pkg/storage/s3/write.go
deleted file mode 100644
index 815b819..0000000
--- a/pkg/storage/s3/write.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package s3storage
-
-import (
- "bytes"
- "context"
- "fmt"
- "io"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/s3"
-)
-
-func (s *S3Storage) WriteFile(name string, data []byte) error {
- _, err := s.s3Client.PutObject(context.TODO(), &s3.PutObjectInput{
- Bucket: aws.String(s.bucketname),
- Key: aws.String(name),
- Body: bytes.NewReader(data),
- })
- if err != nil {
- return fmt.Errorf("put object error: %s", err)
- }
- return nil
-}
-
-func (s *S3Storage) AppendFile(name string, data []byte) error {
- return nil
-}
-
-func (s *S3Storage) OpenFileForWriting(name string) (io.WriteCloser, error) {
- return nil, nil
-}
-
-func (s *S3Storage) OpenFileForAppending(name string) (io.WriteCloser, error) {
- return nil, nil
-}
diff --git a/pkg/users/localusers.go b/pkg/users/localusers.go
deleted file mode 100644
index 55be548..0000000
--- a/pkg/users/localusers.go
+++ /dev/null
@@ -1,185 +0,0 @@
-package users
-
-import (
- "fmt"
-
- "github.com/google/uuid"
- "golang.org/x/crypto/bcrypt"
-)
-
-func (u *UserStore) AddUser(user User) (User, error) {
- if user.Login == "" {
- return user, fmt.Errorf("login cannot be empty")
- }
- existingUsers := u.ListUsers()
- for _, existingUser := range existingUsers {
- if existingUser.Login == user.Login {
- return User{}, fmt.Errorf("user with login '%s' already exists", user.Login)
- }
- }
- user.ID = uuid.NewString()
- if user.Password != "" {
- hashedPassword, err := HashPassword(user.Password)
- if err != nil {
- return user, fmt.Errorf("HashPassword error: %s", err)
- }
- user.Password = hashedPassword
- }
- u.Users = append(u.Users, user)
- if u.autoSave {
- return user, u.SaveUsers()
- }
- return user, nil
-}
-
-func (u *UserStore) AddUsers(users []User) ([]User, error) {
- createdUsers := []User{}
- existingUsers := u.ListUsers()
- for k := range users {
- for _, existingUser := range existingUsers {
- if existingUser.Login == users[k].Login {
- return createdUsers, fmt.Errorf("user with login '%s' already exists", users[k].Login)
- }
- }
- users[k].ID = uuid.NewString()
- hashedPassword, err := HashPassword(users[k].Password)
- if err != nil {
- return createdUsers, fmt.Errorf("HashPassword error: %s", err)
- }
- users[k].Password = hashedPassword
- u.Users = append(u.Users, users[k])
- existingUsers = append(existingUsers, users[k])
- createdUsers = append(createdUsers, users[k])
- }
- if u.autoSave {
- return createdUsers, u.SaveUsers()
- }
- return createdUsers, nil
-}
-
-func (u *UserStore) GetUserByID(id string) (User, error) {
- for _, user := range u.Users {
- if user.ID == id {
- user.Password = ""
- return user, nil
- }
- }
- return User{}, fmt.Errorf("User not found")
-}
-
-func (u *UserStore) GetUserByLogin(login string) (User, error) {
- for _, user := range u.Users {
- if user.Login == login {
- user.Password = ""
- return user, nil
- }
- }
- return User{}, fmt.Errorf("User not found")
-}
-func (u *UserStore) DeleteUserByLogin(login string) error {
- for k, user := range u.Users {
- if user.Login == login {
- u.Users = append(u.Users[:k], u.Users[k+1:]...)
- if u.autoSave {
- return u.SaveUsers()
- }
- return nil
- }
- }
- return fmt.Errorf("User not found")
-}
-
-func (u *UserStore) DeleteUserByID(id string) error {
- for k, user := range u.Users {
- if user.ID == id {
- u.Users = append(u.Users[:k], u.Users[k+1:]...)
- if u.autoSave {
- return u.SaveUsers()
- }
- return nil
- }
- }
- return fmt.Errorf("User not found")
-}
-
-func (u *UserStore) AuthUser(login, password string) (User, bool) {
- for _, user := range u.Users {
- if user.Login == login {
- passwordMatch := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
- if passwordMatch == nil {
- return user, true
- }
- }
- }
- return User{}, false
-}
-
-func (u *UserStore) LoginExists(login string) bool {
- for _, user := range u.Users {
- if user.Login == login {
- return true
- }
- }
- return false
-}
-func (u *UserStore) UpdateUser(user User) error {
- for k, existingUser := range u.Users {
- if existingUser.Login == user.Login {
- password := existingUser.Password // we keep the password
- user.Password = password
- u.Users[k] = user
- if u.autoSave {
- return u.SaveUsers()
- } else {
- return nil
- }
- }
- }
- return fmt.Errorf("user not found in database: %s", user.Login)
-}
-func (u *UserStore) UpdatePassword(userID string, password string) error {
- for k, existingUser := range u.Users {
- if existingUser.ID == userID {
- hashedPassword, err := HashPassword(password)
- if err != nil {
- return fmt.Errorf("HashPassword error: %s", err)
- }
- u.Users[k].Password = hashedPassword
- if u.autoSave {
- return u.SaveUsers()
- } else {
- return nil
- }
- }
- }
- return fmt.Errorf("user not found in database: userID %s", userID)
-}
-
-func HashPassword(password string) (string, error) {
- adminPasswordHashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
- if err != nil {
- return "", fmt.Errorf("unable to set password: %s", err)
- }
- return string(adminPasswordHashed), nil
-}
-
-func (u *UserStore) ListUsers() []User {
- users := make([]User, len(u.Users))
- for k, user := range u.Users {
- user.Password = ""
- users[k] = user
- }
- return users
-}
-
-func (u *UserStore) UserCount() int {
- return len(u.Users)
-}
-
-func (u *UserStore) Empty() error {
- u.Users = []User{}
- if u.autoSave {
- return u.SaveUsers()
- }
- return nil
-}
diff --git a/pkg/users/maxusers.go b/pkg/users/maxusers.go
deleted file mode 100644
index ffc0221..0000000
--- a/pkg/users/maxusers.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package users
-
-func (u *UserStore) GetMaxUsers() int {
- return u.maxUsers
-}
diff --git a/pkg/users/oidcusers.go b/pkg/users/oidcusers.go
deleted file mode 100644
index 66a6bfa..0000000
--- a/pkg/users/oidcusers.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package users
-
-import "fmt"
-
-func (u *UserStore) GetUserByOIDCIDs(oidcIDs []string) (User, error) {
- for _, user := range u.Users {
- for _, oidcID := range oidcIDs {
- if user.OIDCID == oidcID {
- return user, nil
- }
- }
-
- }
- return User{}, fmt.Errorf("User not found")
-}
diff --git a/pkg/users/save.go b/pkg/users/save.go
deleted file mode 100644
index f35085f..0000000
--- a/pkg/users/save.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package users
-
-import (
- "encoding/json"
- "fmt"
- "os/user"
- "sync"
-)
-
-var UserStoreMu sync.Mutex
-
-func (u *UserStore) SaveUsers() error {
- UserStoreMu.Lock()
- defer UserStoreMu.Unlock()
- out, err := json.Marshal(u.Users)
- if err != nil {
- return fmt.Errorf("user store marshal error: %s", err)
- }
- err = u.storage.WriteFile(u.storage.ConfigPath(USERSTORE_FILENAME), out)
- if err != nil {
- return fmt.Errorf("user store write error: %s", err)
- }
- // fix permissions
- currentUser, err := user.Current()
- if err != nil {
- return fmt.Errorf("could not get current user: %s", err)
- }
- if currentUser.Username != "vpn" {
- err = u.storage.EnsureOwnership(u.storage.ConfigPath(USERSTORE_FILENAME), "vpn")
- if err != nil {
- return fmt.Errorf("ensure ownership error (userstore): %s", err)
- }
- }
-
- return nil
-}
diff --git a/pkg/users/store.go b/pkg/users/store.go
deleted file mode 100644
index ca0ee62..0000000
--- a/pkg/users/store.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package users
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-const USERSTORE_FILENAME = "users.json"
-
-func NewUserStore(storage storage.Iface, maxUsers int) (*UserStore, error) {
- userStore := &UserStore{
- autoSave: true,
- maxUsers: maxUsers,
- storage: storage,
- }
-
- if !userStore.storage.FileExists(userStore.storage.ConfigPath(USERSTORE_FILENAME)) {
- userStore.Users = []User{}
- return userStore, nil
- }
-
- data, err := userStore.storage.ReadFile(userStore.storage.ConfigPath(USERSTORE_FILENAME))
- if err != nil {
- return userStore, fmt.Errorf("config read error: %s", err)
- }
- decoder := json.NewDecoder(bytes.NewBuffer(data))
- err = decoder.Decode(&userStore.Users)
- if err != nil {
- return userStore, fmt.Errorf("decode input error: %s", err)
- }
- return userStore, nil
-}
diff --git a/pkg/users/types.go b/pkg/users/types.go
deleted file mode 100644
index 0c2bc65..0000000
--- a/pkg/users/types.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package users
-
-import (
- "time"
-
- "github.com/in4it/wireguard-server/pkg/storage"
-)
-
-type UserStore struct {
- Users []User `json:"users"`
- autoSave bool
- maxUsers int
- storage storage.Iface
-}
-
-type User struct {
- ID string `json:"id"`
- Login string `json:"login"`
- Role string `json:"role"`
- OIDCID string `json:"oidcID,omitempty"`
- SAMLID string `json:"samlID,omitempty"`
- Provisioned bool `json:"provisioned,omitempty"`
- Password string `json:"password,omitempty"`
- Suspended bool `json:"suspended"`
- ConnectionsDisabledOnAuthFailure bool `json:"connectionsDisabledOnAuthFailure"`
- Factors []Factor `json:"factors"`
- ExternalID string `json:"externalID,omitempty"`
- LastLogin time.Time `json:"lastLogin"`
-}
-
-type Factor struct {
- Name string `json:"name"`
- Type string `json:"type"`
- Secret string `json:"secret"`
-}
diff --git a/pkg/utils/date/compare.go b/pkg/utils/date/compare.go
deleted file mode 100644
index b6a901e..0000000
--- a/pkg/utils/date/compare.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package dateutils
-
-import "time"
-
-func DateEqual(date1, date2 time.Time) bool {
- y1, m1, d1 := date1.Date()
- y2, m2, d2 := date2.Date()
- return y1 == y2 && m1 == m2 && d1 == d2
-}
diff --git a/pkg/utils/random/rand.go b/pkg/utils/random/rand.go
deleted file mode 100644
index 7645ba3..0000000
--- a/pkg/utils/random/rand.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package randomutils
-
-import (
- "crypto/rand"
- "encoding/base64"
- "fmt"
- "io"
-)
-
-func GetRandomString(n int) (string, error) {
- buf := make([]byte, n)
-
- _, err := io.ReadFull(rand.Reader, buf)
- if err != nil {
- return "", fmt.Errorf("crypto/rand Reader error: %s", err)
- }
-
- return base64.RawURLEncoding.EncodeToString(buf), nil
-}
diff --git a/pkg/vpn/helpers.go b/pkg/vpn/helpers.go
new file mode 100644
index 0000000..92c1340
--- /dev/null
+++ b/pkg/vpn/helpers.go
@@ -0,0 +1,38 @@
+package vpn
+
+import (
+ "fmt"
+ "net/http"
+ "strings"
+)
+
+func (v *VPN) returnError(w http.ResponseWriter, err error, statusCode int) {
+ fmt.Println("========= ERROR =========")
+ fmt.Printf("Error: %s\n", err)
+ fmt.Println("=========================")
+ w.WriteHeader(statusCode)
+ w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`))
+}
+
+func (v *VPN) write(w http.ResponseWriter, res []byte) {
+ sendCorsHeaders(w, "", v.Hostname, v.Protocol)
+ w.WriteHeader(http.StatusOK)
+ w.Write(res)
+}
+func (v *VPN) writeWithStatus(w http.ResponseWriter, res []byte, status int) {
+ sendCorsHeaders(w, "", v.Hostname, v.Protocol)
+ w.WriteHeader(status)
+ w.Write(res)
+}
+
+func sendCorsHeaders(w http.ResponseWriter, headers string, hostname string, protocol string) {
+ if hostname == "" {
+ w.Header().Add("Access-Control-Allow-Origin", "*")
+ } else {
+ w.Header().Add("Access-Control-Allow-Origin", fmt.Sprintf("%s://%s", protocol, hostname))
+ }
+ w.Header().Add("Access-Control-allow-methods", "GET,HEAD,POST,PUT,OPTIONS,DELETE,PATCH")
+ if headers != "" {
+ w.Header().Add("Access-Control-Allow-Headers", headers)
+ }
+}
diff --git a/pkg/vpn/new.go b/pkg/vpn/new.go
new file mode 100644
index 0000000..d7bef01
--- /dev/null
+++ b/pkg/vpn/new.go
@@ -0,0 +1,13 @@
+package vpn
+
+import (
+ "github.com/in4it/go-devops-platform/storage"
+ "github.com/in4it/go-devops-platform/users"
+)
+
+func New(defaultStorage storage.Iface, userStore *users.UserStore) *VPN {
+ return &VPN{
+ Storage: defaultStorage,
+ UserStore: userStore,
+ }
+}
diff --git a/pkg/rest/static/.gitignore b/pkg/vpn/resources/.gitignore
similarity index 100%
rename from pkg/rest/static/.gitignore
rename to pkg/vpn/resources/.gitignore
diff --git a/pkg/vpn/router.go b/pkg/vpn/router.go
new file mode 100644
index 0000000..c1134c0
--- /dev/null
+++ b/pkg/vpn/router.go
@@ -0,0 +1,26 @@
+package vpn
+
+import (
+ "net/http"
+
+ "github.com/in4it/go-devops-platform/rest"
+)
+
+func (v *VPN) GetRouter() *http.ServeMux {
+ mux := http.NewServeMux()
+
+ mux.Handle("/api/vpn/connections", http.HandlerFunc(v.connectionsHandler))
+ mux.Handle("/api/vpn/connection/{id}", http.HandlerFunc(v.connectionsElementHandler))
+ mux.Handle("/api/vpn/connectionlicense", http.HandlerFunc(v.connectionLicenseHandler))
+
+ mux.Handle("/api/vpn/stats/user/{date}", rest.IsAdminMiddleware(http.HandlerFunc(v.userStatsHandler)))
+ mux.Handle("/api/vpn/stats/packetlogs/{user}/{date}", rest.IsAdminMiddleware(http.HandlerFunc(v.packetLogsHandler)))
+
+ mux.Handle("/api/vpn/setup/vpn", rest.IsAdminMiddleware(http.HandlerFunc(v.vpnSetupHandler)))
+ mux.Handle("/api/vpn/setup/templates", rest.IsAdminMiddleware(http.HandlerFunc(v.templateSetupHandler)))
+ mux.Handle("/api/vpn/setup/restart-vpn", rest.IsAdminMiddleware(http.HandlerFunc(v.restartVPNHandler)))
+
+ mux.Handle("/api/vpn/version", http.HandlerFunc(v.version))
+
+ return mux
+}
diff --git a/pkg/vpn/setup.go b/pkg/vpn/setup.go
new file mode 100644
index 0000000..470dcf1
--- /dev/null
+++ b/pkg/vpn/setup.go
@@ -0,0 +1,300 @@
+package vpn
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/netip"
+ "reflect"
+ "slices"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/in4it/wireguard-server/pkg/wireguard"
+)
+
+func (v *VPN) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
+ vpnConfig, err := wireguard.GetVPNConfig(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not get vpn config: %s", err), http.StatusBadRequest)
+ return
+ }
+ switch r.Method {
+ case http.MethodGet:
+ packetLogTypes := []string{}
+ for k, enabled := range vpnConfig.PacketLogsTypes {
+ if enabled {
+ packetLogTypes = append(packetLogTypes, k)
+ }
+ }
+ if vpnConfig.PacketLogsRetention == 0 {
+ vpnConfig.PacketLogsRetention = 7
+ }
+ setupRequest := VPNSetupRequest{
+ Routes: strings.Join(vpnConfig.ClientRoutes, ", "),
+ VPNEndpoint: vpnConfig.Endpoint,
+ AddressRange: vpnConfig.AddressRange.String(),
+ ClientAddressPrefix: vpnConfig.ClientAddressPrefix,
+ Port: strconv.Itoa(vpnConfig.Port),
+ ExternalInterface: vpnConfig.ExternalInterface,
+ Nameservers: strings.Join(vpnConfig.Nameservers, ","),
+ DisableNAT: vpnConfig.DisableNAT,
+ EnablePacketLogs: vpnConfig.EnablePacketLogs,
+ PacketLogsTypes: packetLogTypes,
+ PacketLogsRetention: strconv.Itoa(vpnConfig.PacketLogsRetention),
+ }
+ out, err := json.Marshal(setupRequest)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ case http.MethodPost:
+ var (
+ writeVPNConfig bool
+ rewriteClientConfigs bool
+ setupRequest VPNSetupRequest
+ )
+ decoder := json.NewDecoder(r.Body)
+ decoder.Decode(&setupRequest)
+ if strings.Join(vpnConfig.ClientRoutes, ", ") != setupRequest.Routes {
+ networks := strings.Split(setupRequest.Routes, ",")
+ validatedNetworks := []string{}
+ for _, network := range networks {
+ if strings.TrimSpace(network) == "::/0" {
+ validatedNetworks = append(validatedNetworks, "::/0")
+ } else {
+ _, ipnet, err := net.ParseCIDR(strings.TrimSpace(network))
+ if err != nil {
+ v.returnError(w, fmt.Errorf("client route %s in wrong format: %s", strings.TrimSpace(network), err), http.StatusBadRequest)
+ return
+ }
+ validatedNetworks = append(validatedNetworks, ipnet.String())
+ }
+ }
+ vpnConfig.ClientRoutes = validatedNetworks
+ writeVPNConfig = true
+ rewriteClientConfigs = true
+ }
+ if vpnConfig.Endpoint != setupRequest.VPNEndpoint {
+ vpnConfig.Endpoint = setupRequest.VPNEndpoint
+ writeVPNConfig = true
+ rewriteClientConfigs = true
+ }
+ addressRangeParsed, err := netip.ParsePrefix(setupRequest.AddressRange)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("AddressRange in wrong format: %s", err), http.StatusBadRequest)
+ return
+ }
+ if addressRangeParsed.String() != vpnConfig.AddressRange.String() {
+ vpnConfig.AddressRange = addressRangeParsed
+ writeVPNConfig = true
+ rewriteClientConfigs = true
+ }
+ if setupRequest.ClientAddressPrefix != vpnConfig.ClientAddressPrefix {
+ vpnConfig.ClientAddressPrefix = setupRequest.ClientAddressPrefix
+ writeVPNConfig = true
+ rewriteClientConfigs = true
+ }
+ port, err := strconv.Atoi(setupRequest.Port)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("port in wrong format: %s", err), http.StatusBadRequest)
+ return
+ }
+ if port != vpnConfig.Port {
+ vpnConfig.Port = port
+ writeVPNConfig = true
+ rewriteClientConfigs = true
+ }
+
+ nameservers := strings.Split(setupRequest.Nameservers, ",")
+ for k := range nameservers {
+ nameservers[k] = strings.TrimSpace(nameservers[k])
+ }
+ if !reflect.DeepEqual(nameservers, vpnConfig.Nameservers) {
+ vpnConfig.Nameservers = nameservers
+ writeVPNConfig = true
+ rewriteClientConfigs = true
+ }
+ if setupRequest.ExternalInterface != vpnConfig.ExternalInterface { // don't rewrite client config
+ vpnConfig.ExternalInterface = setupRequest.ExternalInterface
+ writeVPNConfig = true
+ }
+ if setupRequest.DisableNAT != vpnConfig.DisableNAT { // don't rewrite client config
+ vpnConfig.DisableNAT = setupRequest.DisableNAT
+ writeVPNConfig = true
+ }
+ if setupRequest.EnablePacketLogs != vpnConfig.EnablePacketLogs {
+ vpnConfig.EnablePacketLogs = setupRequest.EnablePacketLogs
+ writeVPNConfig = true
+ }
+ packetLogsRention, err := strconv.Atoi(setupRequest.PacketLogsRetention)
+ if err != nil || packetLogsRention < 1 {
+ v.returnError(w, fmt.Errorf("incorrect packet log retention. Enter a number of days the logs must be kept (minimum 1)"), http.StatusBadRequest)
+ return
+ }
+ if packetLogsRention != vpnConfig.PacketLogsRetention {
+ vpnConfig.PacketLogsRetention = packetLogsRention
+ writeVPNConfig = true
+ }
+
+ // packetlogtypes
+ packetLogTypes := []string{}
+ for k, enabled := range vpnConfig.PacketLogsTypes {
+ if enabled {
+ packetLogTypes = append(packetLogTypes, k)
+ }
+ }
+ sort.Strings(setupRequest.PacketLogsTypes)
+ sort.Strings(packetLogTypes)
+ if !slices.Equal(setupRequest.PacketLogsTypes, packetLogTypes) {
+ vpnConfig.PacketLogsTypes = make(map[string]bool)
+ for _, v := range setupRequest.PacketLogsTypes {
+ if v == "http+https" || v == "dns" || v == "tcp" {
+ vpnConfig.PacketLogsTypes[v] = true
+ }
+ }
+ writeVPNConfig = true
+ }
+
+ // write vpn config if config has changed
+ if writeVPNConfig {
+ err = wireguard.WriteVPNConfig(v.Storage, vpnConfig)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could write vpn config: %s", err), http.StatusBadRequest)
+ return
+ }
+ err = wireguard.ReloadVPNServerConfig()
+ if err != nil {
+ v.returnError(w, fmt.Errorf("unable to reload server config: %s", err), http.StatusBadRequest)
+ return
+ }
+ }
+ if rewriteClientConfigs {
+ // rewrite client configs
+ err = wireguard.UpdateClientsConfig(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest)
+ return
+ }
+ }
+ out, err := json.Marshal(setupRequest)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ default:
+ v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
+ }
+}
+
+func (v *VPN) templateSetupHandler(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ clientTemplate, err := wireguard.GetClientTemplate(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest)
+ return
+ }
+ serverTemplate, err := wireguard.GetServerTemplate(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest)
+ return
+ }
+ setupRequest := TemplateSetupRequest{
+ ClientTemplate: string(clientTemplate),
+ ServerTemplate: string(serverTemplate),
+ }
+ out, err := json.Marshal(setupRequest)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ case http.MethodPost:
+ var templateSetupRequest TemplateSetupRequest
+ decoder := json.NewDecoder(r.Body)
+ decoder.Decode(&templateSetupRequest)
+ clientTemplate, err := wireguard.GetClientTemplate(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest)
+ return
+ }
+ serverTemplate, err := wireguard.GetServerTemplate(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest)
+ return
+ }
+ if string(clientTemplate) != templateSetupRequest.ClientTemplate {
+ err = wireguard.WriteClientTemplate(v.Storage, []byte(templateSetupRequest.ClientTemplate))
+ if err != nil {
+ v.returnError(w, fmt.Errorf("WriteClientTemplate error: %s", err), http.StatusBadRequest)
+ return
+ }
+ // rewrite client configs
+ err = wireguard.UpdateClientsConfig(v.Storage)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest)
+ return
+ }
+ }
+ if string(serverTemplate) != templateSetupRequest.ServerTemplate {
+ err = wireguard.WriteServerTemplate(v.Storage, []byte(templateSetupRequest.ServerTemplate))
+ if err != nil {
+ v.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest)
+ return
+ }
+ }
+ out, err := json.Marshal(templateSetupRequest)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ default:
+ v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
+ }
+}
+
+func (v *VPN) restartVPNHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ v.returnError(w, fmt.Errorf("unsupported method"), http.StatusBadRequest)
+ return
+ }
+ client := http.Client{
+ Timeout: 10 * time.Second,
+ }
+ req, err := http.NewRequest(r.Method, "http://"+wireguard.CONFIGMANAGER_URI+"/restart-vpn", nil)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("restart request error: %s", err), http.StatusBadRequest)
+ return
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("restart error: %s", err), http.StatusBadRequest)
+ return
+ }
+ if resp.StatusCode != http.StatusAccepted {
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("restart error: got status code: %d. Response: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest)
+ return
+ }
+ v.returnError(w, fmt.Errorf("restart error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest)
+ return
+ }
+
+ defer resp.Body.Close()
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest)
+ return
+ }
+
+ v.write(w, bodyBytes)
+}
diff --git a/pkg/rest/stats.go b/pkg/vpn/stats.go
similarity index 88%
rename from pkg/rest/stats.go
rename to pkg/vpn/stats.go
index e7d6abc..2e47a1d 100644
--- a/pkg/rest/stats.go
+++ b/pkg/vpn/stats.go
@@ -1,4 +1,4 @@
-package rest
+package vpn
import (
"bufio"
@@ -15,21 +15,21 @@ import (
"strings"
"time"
- "github.com/in4it/wireguard-server/pkg/storage"
- dateutils "github.com/in4it/wireguard-server/pkg/utils/date"
+ "github.com/in4it/go-devops-platform/storage"
+ dateutils "github.com/in4it/go-devops-platform/utils/date"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
const MAX_LOG_OUTPUT_LINES = 100
-func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) {
+func (v *VPN) userStatsHandler(w http.ResponseWriter, r *http.Request) {
if r.PathValue("date") == "" {
- c.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest)
return
}
date, err := time.Parse("2006-01-02", r.PathValue("date"))
if err != nil {
- c.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest)
return
}
unitAdjustment := int64(1)
@@ -49,7 +49,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) {
}
}
// get all users
- users := c.UserStore.ListUsers()
+ users := v.UserStore.ListUsers()
userMap := make(map[string]string)
for _, user := range users {
userMap[user.ID] = user.Login
@@ -65,10 +65,10 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) {
}
logData := bytes.NewBuffer([]byte{})
for _, statsFile := range statsFiles {
- if c.Storage.Client.FileExists(statsFile) {
- fileLogData, err := c.Storage.Client.ReadFile(statsFile)
+ if v.Storage.FileExists(statsFile) {
+ fileLogData, err := v.Storage.ReadFile(statsFile)
if err != nil {
- c.returnError(w, fmt.Errorf("readfile error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("readfile error: %s", err), http.StatusBadRequest)
return
}
logData.Write(fileLogData)
@@ -149,7 +149,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) {
}
if err := scanner.Err(); err != nil {
- c.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest)
return
}
userStatsResponse.ReceiveBytes = UserStatsData{
@@ -210,39 +210,39 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) {
out, err := json.Marshal(userStatsResponse)
if err != nil {
- c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest)
return
}
- c.write(w, out)
+ v.write(w, out)
}
-func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) {
- vpnConfig, err := wireguard.GetVPNConfig(c.Storage.Client)
+func (v *VPN) packetLogsHandler(w http.ResponseWriter, r *http.Request) {
+ vpnConfig, err := wireguard.GetVPNConfig(v.Storage)
if err != nil {
- c.returnError(w, fmt.Errorf("get vpn config error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("get vpn config error: %s", err), http.StatusBadRequest)
return
}
if !vpnConfig.EnablePacketLogs { // packet logs is disabled
out, err := json.Marshal(LogDataResponse{Enabled: false})
if err != nil {
- c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest)
return
}
- c.write(w, out)
+ v.write(w, out)
return
}
userID := r.PathValue("user")
if userID == "" {
- c.returnError(w, fmt.Errorf("no user supplied"), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("no user supplied"), http.StatusBadRequest)
return
}
if r.PathValue("date") == "" {
- c.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest)
return
}
date, err := time.Parse("2006-01-02", r.PathValue("date"))
if err != nil {
- c.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest)
return
}
offset := 0
@@ -261,7 +261,7 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) {
}
search := r.FormValue("search")
// get all users
- users := c.UserStore.ListUsers()
+ users := v.UserStore.ListUsers()
userMap := make(map[string]string)
for _, user := range users {
userMap[user.ID] = user.Login
@@ -292,14 +292,14 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) {
if !dateutils.DateEqual(time.Now(), date) { // date is in local timezone, and we are UTC, so also read next file
statsFiles = append(statsFiles, path.Join(wireguard.VPN_STATS_DIR, wireguard.VPN_PACKETLOGGER_DIR, userID+"-"+date.AddDate(0, 0, 1).Format("2006-01-02")+".log"))
}
- statsFiles, err = getCompressedFilesAndRemoveNonExistent(c.Storage.Client, statsFiles)
+ statsFiles, err = getCompressedFilesAndRemoveNonExistent(v.Storage, statsFiles)
if err != nil {
- c.returnError(w, fmt.Errorf("unable to get files for reading: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("unable to get files for reading: %s", err), http.StatusBadRequest)
return
}
- fileReaders, err := c.Storage.Client.OpenFilesFromPos(statsFiles, pos)
+ fileReaders, err := v.Storage.OpenFilesFromPos(statsFiles, pos)
if err != nil {
- c.returnError(w, fmt.Errorf("error while reading files: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("error while reading files: %s", err), http.StatusBadRequest)
return
}
for _, fileReader := range fileReaders {
@@ -334,7 +334,7 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) {
}
}
if err := scanner.Err(); err != nil {
- c.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest)
return
}
}
@@ -355,10 +355,10 @@ func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) {
out, err := json.Marshal(LogDataResponse{Enabled: true, LogData: logData, LogTypes: packetLogTypes, Users: userMap})
if err != nil {
- c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest)
+ v.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest)
return
}
- c.write(w, out)
+ v.write(w, out)
}
func getCompressedFilesAndRemoveNonExistent(storage storage.Iface, files []string) ([]string, error) {
diff --git a/pkg/rest/stats_test.go b/pkg/vpn/stats_test.go
similarity index 92%
rename from pkg/rest/stats_test.go
rename to pkg/vpn/stats_test.go
index b484606..b04c0d0 100644
--- a/pkg/rest/stats_test.go
+++ b/pkg/vpn/stats_test.go
@@ -1,4 +1,4 @@
-package rest
+package vpn
import (
"bytes"
@@ -11,7 +11,8 @@ import (
"testing"
"time"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
+ memorystorage "github.com/in4it/go-devops-platform/storage/memory"
+ "github.com/in4it/go-devops-platform/users"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
@@ -19,10 +20,8 @@ func TestUserStatsHandler(t *testing.T) {
storage := &memorystorage.MockMemoryStorage{}
- c, err := newContext(storage, SERVER_TYPE_VPN)
- if err != nil {
- t.Fatalf("Cannot create context")
- }
+ v := New(storage, &users.UserStore{})
+
testData := `2024-08-23T19:29:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,12729136,24348520,2024-08-23T18:30:42
2024-08-23T19:34:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,13391716,25162108,2024-08-23T19:33:38
2024-08-23T19:39:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,14419152,27496068,2024-08-23T19:37:39
@@ -34,7 +33,7 @@ func TestUserStatsHandler(t *testing.T) {
2024-08-23T20:09:03,3df97301-5f73-407a-a26b-91829f1e7f48,1,39928520,85171728,2024-08-23T20:08:54`
statsFile := path.Join(wireguard.VPN_STATS_DIR, "user-"+time.Now().Format("2006-01-02")) + ".log"
- err = c.Storage.Client.WriteFile(statsFile, []byte(strings.ReplaceAll(testData, "2024-08-23", time.Now().Format("2006-01-02"))))
+ err := v.Storage.WriteFile(statsFile, []byte(strings.ReplaceAll(testData, "2024-08-23", time.Now().Format("2006-01-02"))))
if err != nil {
t.Fatalf("Cannot write test file")
}
@@ -42,7 +41,7 @@ func TestUserStatsHandler(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/stats/user", nil)
req.SetPathValue("date", time.Now().Format("2006-01-02"))
w := httptest.NewRecorder()
- c.userStatsHandler(w, req)
+ v.userStatsHandler(w, req)
resp := w.Result()
diff --git a/pkg/vpn/types.go b/pkg/vpn/types.go
new file mode 100644
index 0000000..d70f370
--- /dev/null
+++ b/pkg/vpn/types.go
@@ -0,0 +1,89 @@
+package vpn
+
+import (
+ "github.com/in4it/go-devops-platform/storage"
+ "github.com/in4it/go-devops-platform/users"
+)
+
+type VPN struct {
+ Storage storage.Iface
+ UserStore *users.UserStore
+ Hostname string
+ Protocol string
+}
+
+type NewConnectionResponse struct {
+ Name string `json:"name"`
+}
+type Connection struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+type UserStatsResponse struct {
+ ReceiveBytes UserStatsData `json:"receivedBytes"`
+ TransmitBytes UserStatsData `json:"transmitBytes"`
+ Handshakes UserStatsData `json:"handshakes"`
+}
+type UserStatsData struct {
+ Datasets UserStatsDatasets `json:"datasets"`
+}
+type UserStatsDatasets []UserStatsDataset
+type UserStatsDataset struct {
+ Label string `json:"label"`
+ Data []UserStatsDataPoint `json:"data"`
+ Fill bool `json:"fill"`
+ BorderColor string `json:"borderColor"`
+ BackgroundColor string `json:"backgroundColor"`
+ Tension float64 `json:"tension"`
+ ShowLine bool `json:"showLine"`
+}
+
+type UserStatsDataPoint struct {
+ X string `json:"x"`
+ Y float64 `json:"y"`
+}
+
+type LogDataResponse struct {
+ LogData LogData `json:"logData"`
+ Enabled bool `json:"enabled"`
+ LogTypes []string `json:"logTypes"`
+ Users map[string]string `json:"users"`
+}
+
+type LogData struct {
+ Schema LogSchema `json:"schema"`
+ Data []LogRow `json:"rows"`
+ NextPos int64 `json:"nextPos"`
+}
+type LogSchema struct {
+ Columns map[string]string `json:"columns"`
+}
+type LogRow struct {
+ Timestamp string `json:"t"`
+ Data []string `json:"d"`
+}
+
+type VPNSetupRequest struct {
+ Routes string `json:"routes"`
+ VPNEndpoint string `json:"vpnEndpoint"`
+ AddressRange string `json:"addressRange"`
+ ClientAddressPrefix string `json:"clientAddressPrefix"`
+ Port string `json:"port"`
+ ExternalInterface string `json:"externalInterface"`
+ Nameservers string `json:"nameservers"`
+ DisableNAT bool `json:"disableNAT"`
+ EnablePacketLogs bool `json:"enablePacketLogs"`
+ PacketLogsTypes []string `json:"packetLogsTypes"`
+ PacketLogsRetention string `json:"packetLogsRetention"`
+}
+
+type TemplateSetupRequest struct {
+ ClientTemplate string `json:"clientTemplate"`
+ ServerTemplate string `json:"serverTemplate"`
+}
+
+type ConnectionLicenseResponse struct {
+ LicenseUserCount int `json:"licenseUserCount"`
+ ConnectionCount int `json:"connectionCount"`
+}
diff --git a/pkg/rest/types_sort.go b/pkg/vpn/types_sort.go
similarity index 94%
rename from pkg/rest/types_sort.go
rename to pkg/vpn/types_sort.go
index ea7e9f7..34e5744 100644
--- a/pkg/rest/types_sort.go
+++ b/pkg/vpn/types_sort.go
@@ -1,4 +1,4 @@
-package rest
+package vpn
func (u UserStatsDatasets) Len() int {
return len(u)
diff --git a/pkg/vpn/version.go b/pkg/vpn/version.go
new file mode 100644
index 0000000..d819010
--- /dev/null
+++ b/pkg/vpn/version.go
@@ -0,0 +1,28 @@
+package vpn
+
+import (
+ _ "embed"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+)
+
+//go:generate cp -r ../../latest ./resources/version
+//go:embed resources/version
+
+var version string
+
+func (v *VPN) version(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ out, err := json.Marshal(map[string]string{"version": strings.TrimSpace(version)})
+ if err != nil {
+ v.returnError(w, fmt.Errorf("version marshal error: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ default:
+ v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
+ }
+}
diff --git a/pkg/vpn/vpn.go b/pkg/vpn/vpn.go
new file mode 100644
index 0000000..42f8bcb
--- /dev/null
+++ b/pkg/vpn/vpn.go
@@ -0,0 +1,139 @@
+package vpn
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "path"
+ "strings"
+ "sync"
+
+ "github.com/in4it/go-devops-platform/rest"
+ "github.com/in4it/go-devops-platform/users"
+ "github.com/in4it/wireguard-server/pkg/wireguard"
+)
+
+var muClientDownload sync.Mutex
+
+func (v *VPN) connectionsHandler(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ user := r.Context().Value(rest.CustomValue("user")).(users.User)
+
+ clients, err := v.Storage.ReadDir(v.Storage.ConfigPath(wireguard.VPN_CLIENTS_DIR))
+ if err != nil {
+ v.returnError(w, fmt.Errorf("cannot list connections for user: %s", err), http.StatusBadRequest)
+ return
+ }
+
+ connectionList := []string{}
+ for _, clientFilename := range clients {
+ if wireguard.HasClientUserID(clientFilename, user.ID) {
+ connectionList = append(connectionList, clientFilename)
+ }
+ }
+ peerConfigs := make([]wireguard.PeerConfig, len(connectionList))
+ for k, connection := range connectionList {
+ var peerConfig wireguard.PeerConfig
+ filename := v.Storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connection))
+ toDeleteFileContents, err := v.Storage.ReadFile(filename)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("can't read file %s: %s", filename, err), http.StatusBadRequest)
+ return
+ }
+ err = json.Unmarshal(toDeleteFileContents, &peerConfig)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("can't unmarshal file %s: %s", filename, err), http.StatusBadRequest)
+ return
+ }
+ peerConfigs[k] = peerConfig
+ }
+ connections := make([]Connection, len(peerConfigs))
+ for k := range peerConfigs {
+ connections[k] = Connection{
+ ID: peerConfigs[k].ID,
+ Name: peerConfigs[k].Name,
+ }
+ }
+ out, err := json.Marshal(connections)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not marshal list connection response: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ case http.MethodPost:
+ muClientDownload.Lock()
+ defer muClientDownload.Unlock()
+ user := r.Context().Value(rest.CustomValue("user")).(users.User)
+ peerConfig, err := wireguard.NewEmptyClientConfig(v.Storage, user.ID)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not generate client vpn config: %s", err), http.StatusBadRequest)
+ return
+ }
+ newConnectionResponse := NewConnectionResponse{Name: peerConfig.Name}
+ out, err := json.Marshal(newConnectionResponse)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("could not marshal new connection response: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ default:
+ v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
+ }
+}
+func (v *VPN) connectionsElementHandler(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ user := r.Context().Value(rest.CustomValue("user")).(users.User)
+ if !strings.HasPrefix(r.PathValue("id"), user.ID) {
+ v.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest)
+ return
+ }
+ if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") {
+ v.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest)
+ return
+ }
+ out, err := wireguard.GenerateNewClientConfig(v.Storage, r.PathValue("id"), user.ID)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("GetClientConfig error: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+ case http.MethodDelete:
+ user := r.Context().Value(rest.CustomValue("user")).(users.User)
+ if !strings.HasPrefix(r.PathValue("id"), user.ID) {
+ v.returnError(w, fmt.Errorf("connection id is in invalid format (needs to contain user id)"), http.StatusBadRequest)
+ return
+ }
+ if strings.Contains(r.PathValue("id"), ".") || strings.Contains(r.PathValue("id"), "/") {
+ v.returnError(w, fmt.Errorf("connection id contains invalid characters"), http.StatusBadRequest)
+ return
+ }
+ err := wireguard.DeleteClientConfig(v.Storage, r.PathValue("id"), user.ID)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("DeleteClientConfig error: %s", err), http.StatusBadRequest)
+ return
+ }
+ v.write(w, []byte(`{"deleted": "`+r.PathValue("id")+`"}`))
+
+ default:
+ v.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
+ }
+}
+
+func (v *VPN) connectionLicenseHandler(w http.ResponseWriter, r *http.Request) {
+ user := r.Context().Value(rest.CustomValue("user")).(users.User)
+ licenseUserCount := r.Context().Value(rest.CustomValue("licenseUserCount")).(int)
+ totalConnections, err := wireguard.GetConfigNumbers(v.Storage, user.ID)
+ if err != nil {
+ v.returnError(w, fmt.Errorf("can't determine total connections: %s", err), http.StatusBadRequest)
+ return
+
+ }
+ out, err := json.Marshal(ConnectionLicenseResponse{LicenseUserCount: licenseUserCount, ConnectionCount: len(totalConnections)})
+ if err != nil {
+ v.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest)
+ return
+ }
+ v.write(w, out)
+}
diff --git a/pkg/auth/provisioning/scim/users_test.go b/pkg/vpn/vpn_test.go
similarity index 53%
rename from pkg/auth/provisioning/scim/users_test.go
rename to pkg/vpn/vpn_test.go
index 2c399ac..a01d8e7 100644
--- a/pkg/auth/provisioning/scim/users_test.go
+++ b/pkg/vpn/vpn_test.go
@@ -1,7 +1,8 @@
-package scim
+package vpn
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"io"
@@ -9,160 +10,25 @@ import (
"net/http"
"net/http/httptest"
"path"
+ "strings"
"testing"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
- "github.com/in4it/wireguard-server/pkg/users"
+ "github.com/in4it/go-devops-platform/auth/provisioning/scim"
+ "github.com/in4it/go-devops-platform/rest"
+ memorystorage "github.com/in4it/go-devops-platform/storage/memory"
+ "github.com/in4it/go-devops-platform/users"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
const USERSTORE_MAX_USERS = 1000
-func TestUsersGetCount100EmptyResult(t *testing.T) {
+func TestSCIMCreateUserConnectionDeleteUserFlow(t *testing.T) {
storage := &memorystorage.MockMemoryStorage{}
-
- userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS)
- if err != nil {
- t.Fatalf("cannot create new user store")
- }
- userStore.Empty()
- if err != nil {
- t.Fatalf("cannot empty user store")
- }
-
- s := New(storage, userStore, "token")
- req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?count=100&startIndex=1&", nil)
- w := httptest.NewRecorder()
- s.getUsersHandler(w, req)
-
- resp := w.Result()
- body, _ := io.ReadAll(resp.Body)
-
- response, err := listUserResponse([]users.User{}, "", 100, 1)
- if err != nil {
- t.Fatalf("userResponse error: %s", err)
- }
- if string(body) != string(response) {
- t.Fatalf("expected empty input. Got %s\n", string(body))
- }
-}
-
-func TestUsersGetCount10(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS)
- if err != nil {
- t.Fatalf("cannot create new user store")
- }
- err = userStore.Empty()
- if err != nil {
- t.Fatalf("cannot empty user store")
- }
- totalUserCount := 150
- usersToCreate := make([]users.User, totalUserCount)
- for i := 0; i < totalUserCount; i++ {
- usersToCreate[i] = users.User{
- Login: fmt.Sprintf("user-%d@domain.inv", i),
- }
- }
- users, err := userStore.AddUsers(usersToCreate)
- if err != nil {
- t.Fatalf("cannot create users: %s", err)
- }
- s := New(storage, userStore, "token")
- req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?count=10&startIndex=1&", nil)
- w := httptest.NewRecorder()
- s.getUsersHandler(w, req)
-
- resp := w.Result()
- body, _ := io.ReadAll(resp.Body)
-
- response, err := listUserResponse(users, "", 10, 1)
- if err != nil {
- t.Fatalf("userResponse error: %s", err)
- }
- if string(body) != string(response) {
- t.Fatalf("Unexpected output: Got: %s\nExpected: %s\n\n", string(body), string(response))
- }
-}
-
-func TestUsersGetCount10Start5(t *testing.T) {
- count := 10
- start := 5
- storage := &memorystorage.MockMemoryStorage{}
- userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS)
- if err != nil {
- t.Fatalf("cannot create new user store")
- }
- err = userStore.Empty()
- if err != nil {
- t.Fatalf("cannot empty user store")
- }
- totalUserCount := 150
- usersToCreate := make([]users.User, totalUserCount)
- for i := 0; i < totalUserCount; i++ {
- usersToCreate[i] = users.User{
- Login: fmt.Sprintf("user-%d@domain.inv", i),
- }
- }
- users, err := userStore.AddUsers(usersToCreate)
- if err != nil {
- t.Fatalf("cannot create users: %s", err)
- }
- s := New(storage, userStore, "token")
- req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/api/scim/v2/Users?count=%d&startIndex=%d&", count, start), nil)
- w := httptest.NewRecorder()
- s.getUsersHandler(w, req)
-
- resp := w.Result()
-
- var userResponse UserResponse
- err = json.NewDecoder(resp.Body).Decode(&userResponse)
- if err != nil {
- t.Fatalf("Could not decode output: %s", err)
- }
- if userResponse.TotalResults != totalUserCount-start {
- t.Fatalf("Wrong user count: %d", userResponse.TotalResults)
- }
- if userResponse.ItemsPerPage != count {
- t.Fatalf("Wrong page count: %d", userResponse.TotalResults)
- }
- if userResponse.StartIndex != start {
- t.Fatalf("Wrong user start: %d", userResponse.StartIndex)
- }
- if len(userResponse.Resources) != count {
- t.Fatalf("Wrong response count: %d", len(userResponse.Resources))
- }
- if userResponse.Resources[0].UserName != users[5].Login {
- t.Fatalf("Wrong first login: %s (actual) vs %s (expected)", userResponse.Resources[0].UserName, users[5].Login)
- }
-}
-
-func TestUsersGetNonExistentUser(t *testing.T) {
- userStore, err := users.NewUserStore(&memorystorage.MockMemoryStorage{}, USERSTORE_MAX_USERS)
- if err != nil {
- t.Fatalf("cannot create new user stoer")
- }
-
- s := New(&memorystorage.MockMemoryStorage{}, userStore, "token")
- req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?filter=userName+eq+%22ward%40in4it.io%22&", nil)
- w := httptest.NewRecorder()
- s.getUsersHandler(w, req)
-
- resp := w.Result()
- body, _ := io.ReadAll(resp.Body)
-
- response, err := listUserResponse([]users.User{}, "", -1, -1)
- if err != nil {
- t.Fatalf("userResponse error: %s", err)
- }
- if string(body) != string(response) {
- t.Fatalf("expected empty input. Got %s\n", string(body))
- }
-}
-
-func TestAddUser(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS)
+ userStore, err := users.NewUserStoreWithHooks(storage, USERSTORE_MAX_USERS, users.UserHooks{
+ DisableFunc: wireguard.DisableAllClientConfigs,
+ DeleteFunc: wireguard.DeleteAllClientConfigs,
+ ReactivateFunc: wireguard.ReactivateAllClientConfigs,
+ })
if err != nil {
t.Fatalf("cannot create new user store: %s", err)
}
@@ -170,54 +36,7 @@ func TestAddUser(t *testing.T) {
if err != nil {
t.Fatalf("cannot empty user store")
}
- s := New(storage, userStore, "token")
- payload := PostUserRequest{
- UserName: "john@domain.inv",
- Name: Name{
- GivenName: "John",
- FamilyName: "Doe",
- },
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- t.Fatalf("cannot marshal payload: %s", err)
- }
- req := httptest.NewRequest("POST", "http://example.com/api/scim/v2/Users?", bytes.NewBuffer(payloadBytes))
- w := httptest.NewRecorder()
- s.postUsersHandler(w, req)
-
- resp := w.Result()
-
- if resp.StatusCode != 201 {
- t.Fatalf("User not added. StatusCode: %d", resp.StatusCode)
- }
-
- var postUserRequest PostUserRequest
- err = json.NewDecoder(resp.Body).Decode(&postUserRequest)
- if err != nil {
- t.Fatalf("Could not decode output: %s", err)
- }
-
- if postUserRequest.Id == "" {
- t.Fatalf("id is empty: %s", err)
- }
- if postUserRequest.UserName != payload.UserName {
- t.Fatalf("username mismatch: %s (actual) vs %s (expected)", postUserRequest.UserName, payload.UserName)
- }
-
-}
-
-func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
- storage := &memorystorage.MockMemoryStorage{}
- userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS)
- if err != nil {
- t.Fatalf("cannot create new user store: %s", err)
- }
- userStore.Empty()
- if err != nil {
- t.Fatalf("cannot empty user store")
- }
- s := New(storage, userStore, "token")
+ s := scim.New(storage, userStore, "token")
l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI)
if err != nil {
@@ -250,9 +69,9 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
defer l.Close()
// create a user
- payload := PostUserRequest{
+ payload := scim.PostUserRequest{
UserName: "john@domain.inv",
- Name: Name{
+ Name: scim.Name{
GivenName: "John",
FamilyName: "Doe",
},
@@ -263,7 +82,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
}
req := httptest.NewRequest("POST", "http://example.com/api/scim/v2/Users?", bytes.NewBuffer(payloadBytes))
w := httptest.NewRecorder()
- s.postUsersHandler(w, req)
+ s.PostUsersHandler(w, req)
resp := w.Result()
@@ -273,7 +92,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
defer resp.Body.Close()
- var postUserRequest PostUserRequest
+ var postUserRequest scim.PostUserRequest
err = json.NewDecoder(resp.Body).Decode(&postUserRequest)
if err != nil {
t.Fatalf("Could not decode output: %s", err)
@@ -314,7 +133,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
req = httptest.NewRequest("DELETE", "http://example.com/api/scim/v2/Users/"+user.ID, nil)
req.SetPathValue("id", user.ID)
w = httptest.NewRecorder()
- s.deleteUserHandler(w, req)
+ s.DeleteUserHandler(w, req)
resp = w.Result()
@@ -330,7 +149,11 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
storage := &memorystorage.MockMemoryStorage{}
- userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS)
+ userStore, err := users.NewUserStoreWithHooks(storage, USERSTORE_MAX_USERS, users.UserHooks{
+ DisableFunc: wireguard.DisableAllClientConfigs,
+ DeleteFunc: wireguard.DeleteAllClientConfigs,
+ ReactivateFunc: wireguard.ReactivateAllClientConfigs,
+ })
if err != nil {
t.Fatalf("cannot create new user store: %s", err)
}
@@ -338,7 +161,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
if err != nil {
t.Fatalf("cannot empty user store")
}
- s := New(storage, userStore, "token")
+ s := scim.New(storage, userStore, "token")
l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI)
if err != nil {
@@ -371,9 +194,9 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
defer l.Close()
// create a user
- payload := PostUserRequest{
+ payload := scim.PostUserRequest{
UserName: "john@domain.inv",
- Name: Name{
+ Name: scim.Name{
GivenName: "John",
FamilyName: "Doe",
},
@@ -384,7 +207,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
}
req := httptest.NewRequest("POST", "http://example.com/api/scim/v2/Users?", bytes.NewBuffer(payloadBytes))
w := httptest.NewRecorder()
- s.postUsersHandler(w, req)
+ s.PostUsersHandler(w, req)
resp := w.Result()
@@ -394,7 +217,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
defer resp.Body.Close()
- var postUserRequest PostUserRequest
+ var postUserRequest scim.PostUserRequest
err = json.NewDecoder(resp.Body).Decode(&postUserRequest)
if err != nil {
t.Fatalf("Could not decode output: %s", err)
@@ -444,7 +267,7 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
req = httptest.NewRequest("PUT", "http://example.com/api/scim/v2/Users/"+user.ID, bytes.NewBuffer(payloadBytes))
req.SetPathValue("id", user.ID)
w = httptest.NewRecorder()
- s.putUserHandler(w, req)
+ s.PutUserHandler(w, req)
resp = w.Result()
@@ -466,3 +289,134 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) {
t.Fatalf("VPN connection is enabled. Expected disabled")
}
}
+
+func TestCreateUserConnectionDeleteUserFlow(t *testing.T) {
+ l, err := net.Listen("tcp", wireguard.CONFIGMANAGER_URI)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodPost:
+ if r.RequestURI == "/refresh-clients" {
+ w.WriteHeader(http.StatusAccepted)
+ w.Write([]byte("OK"))
+ return
+ }
+ if r.RequestURI == "/refresh-server-config" {
+ w.WriteHeader(http.StatusAccepted)
+ w.Write([]byte("OK"))
+ return
+ }
+ w.WriteHeader(http.StatusBadRequest)
+ default:
+ w.WriteHeader(http.StatusBadRequest)
+ }
+ }))
+
+ ts.Listener.Close()
+ ts.Listener = l
+ ts.Start()
+ defer ts.Close()
+ defer l.Close()
+
+ // first create a new user
+ storage := &memorystorage.MockMemoryStorage{}
+
+ userStore, err := users.NewUserStoreWithHooks(storage, USERSTORE_MAX_USERS, users.UserHooks{
+ DisableFunc: wireguard.DisableAllClientConfigs,
+ DeleteFunc: wireguard.DeleteAllClientConfigs,
+ ReactivateFunc: wireguard.ReactivateAllClientConfigs,
+ })
+ if err != nil {
+ t.Fatalf("cannot create new user store: %s", err)
+ }
+
+ v := New(storage, userStore)
+
+ err = v.UserStore.Empty()
+ if err != nil {
+ t.Fatalf("Cannot create context")
+ }
+
+ // create a user
+ userToCreate := users.User{
+ Login: "john",
+ Role: "user",
+ Password: "xyz",
+ }
+ user, err := v.UserStore.AddUser(userToCreate)
+ if err != nil {
+ t.Fatalf("user creation error: %s", err)
+ }
+
+ // generate VPN config
+ _, err = wireguard.CreateNewVPNConfig(v.Storage)
+ if err != nil {
+ t.Fatalf("Cannot create vpn config: %s", err)
+ }
+
+ req := httptest.NewRequest("POST", "http://example.com/connections", nil)
+ w := httptest.NewRecorder()
+ v.connectionsHandler(w, req.WithContext(context.WithValue(context.Background(), rest.CustomValue("user"), user)))
+
+ resp := w.Result()
+
+ if resp.StatusCode != 200 {
+ t.Fatalf("status code is not 200: %d", resp.StatusCode)
+ }
+
+ connectionID := fmt.Sprintf("%s-1", user.ID)
+
+ userConfigFilename := storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, connectionID+".json"))
+ configBytes, err := storage.ReadFile(userConfigFilename)
+ if err != nil {
+ t.Fatalf("could not read user config file")
+ }
+
+ var config wireguard.PeerConfig
+ err = json.Unmarshal(configBytes, &config)
+ if err != nil {
+ t.Fatalf("could not parse config: %s", err)
+ }
+ if config.Disabled {
+ t.Fatalf("VPN connection is disabled. Expected not disabled")
+ }
+
+ req = httptest.NewRequest("GET", "http://example.com/connection/"+connectionID, nil)
+ req.SetPathValue("id", connectionID)
+ w = httptest.NewRecorder()
+ v.connectionsElementHandler(w, req.WithContext(context.WithValue(context.Background(), rest.CustomValue("user"), user)))
+
+ resp = w.Result()
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 200 {
+ t.Fatalf("status code is not 200: %d", resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("readall error: %s", err)
+ }
+ if !strings.Contains(string(body), "[Interface]") {
+ t.Fatalf("output doesn't look like a wireguard client config: %s", body)
+ }
+
+ err = v.UserStore.DeleteUserByID(user.ID)
+ if err != nil {
+ t.Fatalf("user deletion error: %s", err)
+ }
+
+ err = v.UserStore.UserHooks.DeleteFunc(v.Storage, users.User{ID: user.ID})
+ if err != nil {
+ t.Fatalf("could not delete all clients for user %s: %s", user.ID, err)
+
+ }
+
+ _, err = storage.ReadFile(userConfigFilename)
+ if err == nil {
+ t.Fatalf("could read user config file, expected not to")
+ }
+}
diff --git a/pkg/wireguard/ip.go b/pkg/wireguard/ip.go
index 3def8ad..6d3101d 100644
--- a/pkg/wireguard/ip.go
+++ b/pkg/wireguard/ip.go
@@ -8,7 +8,7 @@ import (
"path"
"strings"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
)
func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix, addressPrefix string) (net.IP, error) {
diff --git a/pkg/wireguard/linux/syncclients/cleanup.go b/pkg/wireguard/linux/syncclients/cleanup.go
index 85a14db..c877d1b 100644
--- a/pkg/wireguard/linux/syncclients/cleanup.go
+++ b/pkg/wireguard/linux/syncclients/cleanup.go
@@ -8,7 +8,7 @@ import (
"fmt"
"path"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux"
)
diff --git a/pkg/wireguard/linux/syncclients/process.go b/pkg/wireguard/linux/syncclients/process.go
index 0da83f2..8e5327f 100644
--- a/pkg/wireguard/linux/syncclients/process.go
+++ b/pkg/wireguard/linux/syncclients/process.go
@@ -7,7 +7,7 @@ import (
"fmt"
"log"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux"
)
diff --git a/pkg/wireguard/linux/syncclients/server.go b/pkg/wireguard/linux/syncclients/server.go
index 4675079..42082a8 100644
--- a/pkg/wireguard/linux/syncclients/server.go
+++ b/pkg/wireguard/linux/syncclients/server.go
@@ -8,7 +8,7 @@ import (
"path"
"strings"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux"
)
diff --git a/pkg/wireguard/linux/syncclients/wg.go b/pkg/wireguard/linux/syncclients/wg.go
index bc36826..567fedf 100644
--- a/pkg/wireguard/linux/syncclients/wg.go
+++ b/pkg/wireguard/linux/syncclients/wg.go
@@ -9,7 +9,7 @@ import (
"path"
"strings"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard"
)
diff --git a/pkg/wireguard/packetlogger.go b/pkg/wireguard/packetlogger.go
index 978acfd..b40069b 100644
--- a/pkg/wireguard/packetlogger.go
+++ b/pkg/wireguard/packetlogger.go
@@ -19,9 +19,9 @@ import (
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
- dateutils "github.com/in4it/wireguard-server/pkg/utils/date"
+ "github.com/in4it/go-devops-platform/logging"
+ "github.com/in4it/go-devops-platform/storage"
+ dateutils "github.com/in4it/go-devops-platform/utils/date"
"github.com/packetcap/go-pcap"
"golang.org/x/sys/unix"
)
diff --git a/pkg/wireguard/packetlogger_test.go b/pkg/wireguard/packetlogger_test.go
index a89c047..62267e3 100644
--- a/pkg/wireguard/packetlogger_test.go
+++ b/pkg/wireguard/packetlogger_test.go
@@ -13,9 +13,9 @@ import (
"testing"
"time"
- localstorage "github.com/in4it/wireguard-server/pkg/storage/local"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
- dateutils "github.com/in4it/wireguard-server/pkg/utils/date"
+ localstorage "github.com/in4it/go-devops-platform/storage/local"
+ memorystorage "github.com/in4it/go-devops-platform/storage/memory"
+ dateutils "github.com/in4it/go-devops-platform/utils/date"
)
func TestParsePacket(t *testing.T) {
diff --git a/pkg/wireguard/stats_linux.go b/pkg/wireguard/stats_linux.go
index e4c0bfc..7963741 100644
--- a/pkg/wireguard/stats_linux.go
+++ b/pkg/wireguard/stats_linux.go
@@ -11,8 +11,8 @@ import (
"strings"
"time"
- "github.com/in4it/wireguard-server/pkg/logging"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/logging"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard/linux/stats"
)
diff --git a/pkg/wireguard/vpnconfig.go b/pkg/wireguard/vpnconfig.go
index 8b11516..21ce0f6 100644
--- a/pkg/wireguard/vpnconfig.go
+++ b/pkg/wireguard/vpnconfig.go
@@ -15,7 +15,7 @@ import (
"sync"
"time"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
"github.com/in4it/wireguard-server/pkg/wireguard/network"
)
diff --git a/pkg/wireguard/wireguardclientconfig.go b/pkg/wireguard/wireguardclientconfig.go
index e65d77c..7b56009 100644
--- a/pkg/wireguard/wireguardclientconfig.go
+++ b/pkg/wireguard/wireguardclientconfig.go
@@ -15,7 +15,8 @@ import (
"text/template"
"time"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
+ "github.com/in4it/go-devops-platform/users"
)
var clientConfigMutex sync.Mutex
@@ -313,14 +314,14 @@ func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string)
return out.Bytes(), nil
}
-func DeleteAllClientConfigs(storage storage.Iface, userID string) error {
+func DeleteAllClientConfigs(storage storage.Iface, user users.User) error {
clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR))
if err != nil {
return fmt.Errorf("cannot list files in users clients directory: %s", err)
}
for _, clientFilename := range clients {
- if HasClientUserID(clientFilename, userID) {
+ if HasClientUserID(clientFilename, user.ID) {
filename := storage.ConfigPath(path.Join(VPN_CLIENTS_DIR, clientFilename))
err = storage.Remove(filename)
if err != nil {
@@ -377,7 +378,7 @@ func DeleteClientConfig(storage storage.Iface, connectionID, userID string) erro
}
return nil
}
-func DisableAllClientConfigs(storage storage.Iface, userID string) error {
+func DisableAllClientConfigs(storage storage.Iface, user users.User) error {
clientConfigMutex.Lock()
defer clientConfigMutex.Unlock()
clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR))
@@ -387,7 +388,7 @@ func DisableAllClientConfigs(storage storage.Iface, userID string) error {
toDelete := []string{}
for _, clientFilename := range clients {
- if HasClientUserID(clientFilename, userID) {
+ if HasClientUserID(clientFilename, user.ID) {
toDelete = append(toDelete, clientFilename)
}
}
@@ -438,7 +439,7 @@ func DisableAllClientConfigs(storage storage.Iface, userID string) error {
}
return nil
}
-func ReactivateAllClientConfigs(storage storage.Iface, userID string) error {
+func ReactivateAllClientConfigs(storage storage.Iface, user users.User) error {
clientConfigMutex.Lock()
defer clientConfigMutex.Unlock()
clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR))
@@ -448,7 +449,7 @@ func ReactivateAllClientConfigs(storage storage.Iface, userID string) error {
toAdd := []string{}
for _, clientFilename := range clients {
- if HasClientUserID(clientFilename, userID) {
+ if HasClientUserID(clientFilename, user.ID) {
toAdd = append(toAdd, clientFilename)
}
}
diff --git a/pkg/wireguard/wireguardclientconfig_test.go b/pkg/wireguard/wireguardclientconfig_test.go
index 3a33628..d7019c4 100644
--- a/pkg/wireguard/wireguardclientconfig_test.go
+++ b/pkg/wireguard/wireguardclientconfig_test.go
@@ -12,7 +12,8 @@ import (
"testing"
"time"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
+ memorystorage "github.com/in4it/go-devops-platform/storage/memory"
+ "github.com/in4it/go-devops-platform/users"
)
func TestGetNextFreeIPFromList(t *testing.T) {
@@ -295,7 +296,7 @@ func TestCreateAndDeleteAllClientConfig(t *testing.T) {
t.Errorf("Public key not found in client config")
}
- err = DeleteAllClientConfigs(storage, "2-2-2-2")
+ err = DeleteAllClientConfigs(storage, users.User{ID: "2-2-2-2"})
if err != nil {
t.Fatalf("DeleteAllClientConfigs error: %s", err)
}
@@ -482,7 +483,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) {
t.Errorf("Peer config is disabled")
}
- err = DisableAllClientConfigs(storage, "2-2-2-2")
+ err = DisableAllClientConfigs(storage, users.User{ID: "2-2-2-2"})
if err != nil {
t.Fatalf("DisableAllClientConfigs error: %s", err)
}
@@ -504,7 +505,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) {
t.Errorf("peer config not disabled")
}
- err = ReactivateAllClientConfigs(storage, "2-2-2-2")
+ err = ReactivateAllClientConfigs(storage, users.User{ID: "2-2-2-2"})
if err != nil {
t.Fatalf("DisableAllClientConfigs error: %s", err)
}
diff --git a/pkg/wireguard/wireguardserverconfig.go b/pkg/wireguard/wireguardserverconfig.go
index be0155a..5b1f354 100644
--- a/pkg/wireguard/wireguardserverconfig.go
+++ b/pkg/wireguard/wireguardserverconfig.go
@@ -7,7 +7,7 @@ import (
"path"
"text/template"
- "github.com/in4it/wireguard-server/pkg/storage"
+ "github.com/in4it/go-devops-platform/storage"
)
func WriteWireGuardServerConfig(storage storage.Iface) error {
diff --git a/pkg/wireguard/wireguardserverconfig_test.go b/pkg/wireguard/wireguardserverconfig_test.go
index 3552bf6..dcde1ab 100644
--- a/pkg/wireguard/wireguardserverconfig_test.go
+++ b/pkg/wireguard/wireguardserverconfig_test.go
@@ -8,7 +8,7 @@ import (
"testing"
"time"
- memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory"
+ memorystorage "github.com/in4it/go-devops-platform/storage/memory"
)
func TestWriteWireGuardServerConfig(t *testing.T) {
diff --git a/webapp/package.json b/webapp/package.json
index 41d6964..8625e78 100644
--- a/webapp/package.json
+++ b/webapp/package.json
@@ -5,7 +5,7 @@
"type": "module",
"scripts": {
"dev": "vite",
- "build": "tsc && vite build && cp -r dist/* ../pkg/rest/static",
+ "build": "tsc && vite build && cp -r dist/* ../cmd/rest-server/static",
"lint": "eslint src --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
"preview": "vite preview",
"test": "vitest"
diff --git a/webapp/src/Routes/Connection/Download.tsx b/webapp/src/Routes/Connection/Download.tsx
index f88353b..70cf320 100644
--- a/webapp/src/Routes/Connection/Download.tsx
+++ b/webapp/src/Routes/Connection/Download.tsx
@@ -10,7 +10,7 @@ type Props = {
export function Download({id, name}:Props) {
const {authInfo} = useAuthContext();
const handleDownload = () => {
- fetch(AppSettings.url + '/connection/'+id, {
+ fetch(AppSettings.url + '/vpn/connection/'+id, {
headers: {
"Authorization": "Bearer " + authInfo.token
},
diff --git a/webapp/src/Routes/Connection/ListConnections.tsx b/webapp/src/Routes/Connection/ListConnections.tsx
index cf584cd..f7b4fbb 100644
--- a/webapp/src/Routes/Connection/ListConnections.tsx
+++ b/webapp/src/Routes/Connection/ListConnections.tsx
@@ -17,7 +17,7 @@ export function ListConnections() {
const { isPending, error, data } = useQuery({
queryKey: ['connections'],
queryFn: () =>
- fetch(AppSettings.url + '/connections', {
+ fetch(AppSettings.url + '/vpn/connections', {
headers: {
"Content-Type": "application/json",
"Authorization": "Bearer " + authInfo.token
@@ -30,7 +30,7 @@ export function ListConnections() {
})
const deleteConnection = useMutation({
mutationFn: (id:string) => {
- return axios.delete(AppSettings.url + '/connection/'+id, {
+ return axios.delete(AppSettings.url + '/vpn/connection/'+id, {
headers: {
"Authorization": "Bearer " + authInfo.token
},
diff --git a/webapp/src/Routes/Connection/NewConnection.tsx b/webapp/src/Routes/Connection/NewConnection.tsx
index 0197e50..d03a8e0 100644
--- a/webapp/src/Routes/Connection/NewConnection.tsx
+++ b/webapp/src/Routes/Connection/NewConnection.tsx
@@ -13,7 +13,7 @@ export function NewConnection() {
const alertIcon =
const newConnection = useMutation({
mutationFn: () => {
- return axios.post(AppSettings.url + '/connections', {}, {
+ return axios.post(AppSettings.url + '/vpn/connections', {}, {
headers: {
"Authorization": "Bearer " + authInfo.token
},
@@ -34,7 +34,7 @@ export function NewConnection() {
const { isPending, error, data } = useQuery({
queryKey: ['connectionlicense'],
queryFn: () =>
- fetch(AppSettings.url + '/connectionlicense', {
+ fetch(AppSettings.url + '/vpn/connectionlicense', {
headers: {
"Content-Type": "application/json",
"Authorization": "Bearer " + authInfo.token
diff --git a/webapp/src/Routes/Home/UserStats.tsx b/webapp/src/Routes/Home/UserStats.tsx
index 78ed63f..a00f594 100644
--- a/webapp/src/Routes/Home/UserStats.tsx
+++ b/webapp/src/Routes/Home/UserStats.tsx
@@ -18,7 +18,7 @@ export function UserStats() {
const { isPending, error, data } = useQuery({
queryKey: ['userstats', statsDate, unit],
queryFn: () =>
- fetch(AppSettings.url + '/stats/user/' + format(statsDate === null ? new Date() : statsDate, "yyyy-MM-dd") + "?offset="+timezoneOffset+"&unit=" +unit, {
+ fetch(AppSettings.url + '/vpn/stats/user/' + format(statsDate === null ? new Date() : statsDate, "yyyy-MM-dd") + "?offset="+timezoneOffset+"&unit=" +unit, {
headers: {
"Content-Type": "application/json",
"Authorization": "Bearer " + authInfo.token
diff --git a/webapp/src/Routes/Logs/Logs.tsx b/webapp/src/Routes/Logs/Logs.tsx
index dd8a0e5..375c000 100644
--- a/webapp/src/Routes/Logs/Logs.tsx
+++ b/webapp/src/Routes/Logs/Logs.tsx
@@ -1,9 +1,9 @@
-import { Card, Container, Text, Table, Title, Button, Grid, Popover, Group, TextInput, rem, ActionIcon, Checkbox, Highlight, MultiSelect} from "@mantine/core";
+import { Container, Table, Title, Button, Grid, Popover, Group, TextInput, rem, ActionIcon, Checkbox, Highlight} from "@mantine/core";
import { AppSettings } from "../../Constants/Constants";
import { useInfiniteQuery } from "@tanstack/react-query";
import { useAuthContext } from "../../Auth/Auth";
-import { Link, useSearchParams } from "react-router-dom";
-import { TbArrowRight, TbSearch, TbSettings } from "react-icons/tb";
+import { useSearchParams } from "react-router-dom";
+import { TbArrowRight, TbSearch } from "react-icons/tb";
import { DatePickerInput } from "@mantine/dates";
import { useEffect, useState } from "react";
import React from "react";
diff --git a/webapp/src/Routes/PacketLogs/PacketLogs.tsx b/webapp/src/Routes/PacketLogs/PacketLogs.tsx
index 8f2b681..4d923ea 100644
--- a/webapp/src/Routes/PacketLogs/PacketLogs.tsx
+++ b/webapp/src/Routes/PacketLogs/PacketLogs.tsx
@@ -51,7 +51,7 @@ export function PacketLogs() {
const { isPending, fetchNextPage, hasNextPage, error, data } = useInfiniteQuery({
queryKey: ['packetlogs', user, logsDate, logType, searchParam],
queryFn: async ({ pageParam }) =>
- fetch(AppSettings.url + '/stats/packetlogs/'+(user === undefined || user === "" ? "all" : user)+'/'+(logsDate == undefined ? getDate(new Date()) : getDate(logsDate)) + "?pos="+pageParam+"&offset="+timezoneOffset+"&logtype="+encodeURIComponent(logType.join(","))+"&search="+encodeURIComponent(searchParam), {
+ fetch(AppSettings.url + '/vpn/stats/packetlogs/'+(user === undefined || user === "" ? "all" : user)+'/'+(logsDate == undefined ? getDate(new Date()) : getDate(logsDate)) + "?pos="+pageParam+"&offset="+timezoneOffset+"&logtype="+encodeURIComponent(logType.join(","))+"&search="+encodeURIComponent(searchParam), {
headers: {
"Content-Type": "application/json",
"Authorization": "Bearer " + authInfo.token
diff --git a/webapp/src/Routes/Setup/Restart.tsx b/webapp/src/Routes/Setup/Restart.tsx
index aa017c5..cbbfd0f 100644
--- a/webapp/src/Routes/Setup/Restart.tsx
+++ b/webapp/src/Routes/Setup/Restart.tsx
@@ -18,7 +18,7 @@ export function Restart() {
const alertIcon = ;
const setupMutation = useMutation({
mutationFn: () => {
- return axios.post(AppSettings.url + '/setup/restart-vpn', {}, {
+ return axios.post(AppSettings.url + '/vpn/setup/restart-vpn', {}, {
headers: {
"Authorization": "Bearer " + authInfo.token
},
diff --git a/webapp/src/Routes/Setup/TemplateSetup.tsx b/webapp/src/Routes/Setup/TemplateSetup.tsx
index d48985b..a09a385 100644
--- a/webapp/src/Routes/Setup/TemplateSetup.tsx
+++ b/webapp/src/Routes/Setup/TemplateSetup.tsx
@@ -23,7 +23,7 @@ export function TemplateSetup() {
const { isPending, error, data, isSuccess } = useQuery({
queryKey: ['templates-setup'],
queryFn: () =>
- fetch(AppSettings.url + '/setup/templates', {
+ fetch(AppSettings.url + '/vpn/setup/templates', {
headers: {
"Content-Type": "application/json",
"Authorization": "Bearer " + authInfo.token
@@ -44,7 +44,7 @@ export function TemplateSetup() {
const alertIcon = ;
const setupMutation = useMutation({
mutationFn: (setupRequest: TemplateSetupRequest) => {
- return axios.post(AppSettings.url + '/setup/templates', setupRequest, {
+ return axios.post(AppSettings.url + '/vpn/setup/templates', setupRequest, {
headers: {
"Authorization": "Bearer " + authInfo.token
},
diff --git a/webapp/src/Routes/Setup/VPNSetup.tsx b/webapp/src/Routes/Setup/VPNSetup.tsx
index ff2a55b..62ba1ba 100644
--- a/webapp/src/Routes/Setup/VPNSetup.tsx
+++ b/webapp/src/Routes/Setup/VPNSetup.tsx
@@ -35,7 +35,7 @@ export function VPNSetup() {
const { isPending, error, data, isSuccess } = useQuery({
queryKey: ['vpn-setup'],
queryFn: () =>
- fetch(AppSettings.url + '/setup/vpn', {
+ fetch(AppSettings.url + '/vpn/setup/vpn', {
headers: {
"Content-Type": "application/json",
"Authorization": "Bearer " + authInfo.token
@@ -64,7 +64,7 @@ export function VPNSetup() {
});
const setupMutation = useMutation({
mutationFn: (setupRequest: VPNSetupRequest) => {
- return axios.post(AppSettings.url + '/setup/vpn', setupRequest, {
+ return axios.post(AppSettings.url + '/vpn/setup/vpn', setupRequest, {
headers: {
"Authorization": "Bearer " + authInfo.token
},