Skip to content

Commit

Permalink
feat: add AsGoGetterURL() to connection (#303)
Browse files Browse the repository at this point in the history
* feat: add AsGoGetterURL() to connection

* feat: return file content on AsEnv()

* added tests as well

* feat: env prep

* make envprep context aware

* chore: use lowercase for connection types

* fix: use new CmdEnv

* chore: fix test
  • Loading branch information
adityathebe authored Oct 19, 2023
1 parent 1e214bd commit 748be3f
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 1 deletion.
219 changes: 218 additions & 1 deletion models/connections.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
package models

import (
"bytes"
"context"
"encoding/base64"
"fmt"
"math/rand"
"net/url"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"time"

"github.com/flanksource/duty/types"
"github.com/google/uuid"
)

// List of all connection types
const (
ConnectionTypeAWS = "aws"
ConnectionTypeAzure = "azure"
ConnectionTypeAzureDevops = "azure_devops"
ConnectionTypeDiscord = "discord"
ConnectionTypeDynatrace = "dynatrace"
ConnectionTypeElasticSearch = "elasticsearch"
ConnectionTypeEmail = "email"
ConnectionTypeGCP = "google_cloud"
ConnectionTypeGenericWebhook = "generic_webhook"
ConnectionTypeGit = "git"
ConnectionTypeGithub = "github"
ConnectionTypeGoogleChat = "google_chat"
ConnectionTypeHTTP = "http"
ConnectionTypeIFTTT = "ifttt"
ConnectionTypeJMeter = "jmeter"
ConnectionTypeKubernetes = "kubernetes"
ConnectionTypeLDAP = "ldap"
ConnectionTypeMatrix = "matrix"
ConnectionTypeMattermost = "mattermost"
ConnectionTypeMongo = "mongo"
ConnectionTypeMySQL = "mysql"
ConnectionTypeNtfy = "ntfy"
ConnectionTypeOpsGenie = "opsgenie"
ConnectionTypePostgres = "postgres"
ConnectionTypePrometheus = "prometheus"
ConnectionTypePushbullet = "pushbullet"
ConnectionTypePushover = "pushover"
ConnectionTypeRedis = "redis"
ConnectionTypeRestic = "restic"
ConnectionTypeRocketchat = "rocketchat"
ConnectionTypeSFTP = "sftp"
ConnectionTypeSlack = "slack"
ConnectionTypeSlackWebhook = "slackwebhook"
ConnectionTypeSMB = "smb"
ConnectionTypeSQLServer = "sql_server"
ConnectionTypeTeams = "teams"
ConnectionTypeTelegram = "telegram"
ConnectionTypeWebhook = "webhook"
ConnectionTypeWindows = "windows"
ConnectionTypeZulipChat = "zulip_chat"
)

type Connection struct {
ID uuid.UUID `gorm:"primaryKey;unique_index;not null;column:id" json:"id" faker:"uuid_hyphenated" `
Name string `gorm:"column:name" json:"name" faker:"name" `
Expand All @@ -25,9 +78,10 @@ type Connection struct {
}

func (c Connection) String() string {
if c.Type == "aws" {
if strings.ToLower(c.Type) == ConnectionTypeAWS {
return "AWS::" + c.Username
}

var connection string
// Obfuscate passwords of the form ' password=xxxxx ' from connectionString since
// connectionStrings are used as metric labels and we don't want to leak passwords
Expand All @@ -48,3 +102,166 @@ func (c Connection) String() string {
func (c Connection) AsMap(removeFields ...string) map[string]any {
return asMap(c, removeFields...)
}

// AsGoGetterURL returns the connection as a url that's supported by https://github.com/hashicorp/go-getter
// Connection details are added to the url as query params
func (c Connection) AsGoGetterURL() (string, error) {
parsedURL, err := url.Parse(c.URL)
if err != nil {
return "", err
}

var output string
switch strings.ReplaceAll(strings.ToLower(c.Type), " ", "_") {
case ConnectionTypeHTTP:
if c.Username != "" || c.Password != "" {
parsedURL.User = url.UserPassword(c.Username, c.Password)
}

output = parsedURL.String()

case ConnectionTypeGit:
q := parsedURL.Query()

if c.Certificate != "" {
q.Set("sshkey", base64.URLEncoding.EncodeToString([]byte(c.Certificate)))
}

if v, ok := c.Properties["ref"]; ok {
q.Set("ref", v)
}

if v, ok := c.Properties["depth"]; ok {
q.Set("depth", v)
}

parsedURL.RawQuery = q.Encode()
output = parsedURL.String()

case ConnectionTypeAWS:
q := parsedURL.Query()
q.Set("aws_access_key_id", c.Username)
q.Set("aws_access_key_secret", c.Password)

if v, ok := c.Properties["profile"]; ok {
q.Set("aws_profile", v)
}

if v, ok := c.Properties["region"]; ok {
q.Set("region", v)
}

// For S3
if v, ok := c.Properties["version"]; ok {
q.Set("version", v)
}

parsedURL.RawQuery = q.Encode()
output = parsedURL.String()
}

return output, nil
}

// AsEnv generates environment variables and a configuration file content based on the connection type.
func (c Connection) AsEnv(ctx context.Context) EnvPrep {
var envPrep = EnvPrep{
Files: make(map[string]bytes.Buffer),
}

switch strings.ReplaceAll(strings.ToLower(c.Type), " ", "_") {
case ConnectionTypeAWS:
envPrep.Env = append(envPrep.Env, fmt.Sprintf("AWS_ACCESS_KEY_ID=%s", c.Username))
envPrep.Env = append(envPrep.Env, fmt.Sprintf("AWS_SECRET_ACCESS_KEY=%s", c.Password))

// credentialFilePath :="$HOME/.aws/credentials"
credentialFilePath := filepath.Join(".creds", "aws", fmt.Sprintf("cred-%d", rand.Intn(100000000)))

var credentialFile bytes.Buffer
credentialFile.WriteString("[default]\n")
credentialFile.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", c.Username))
credentialFile.WriteString(fmt.Sprintf("aws_secret_access_key = %s\n", c.Password))

if v, ok := c.Properties["profile"]; ok {
envPrep.Env = append(envPrep.Env, fmt.Sprintf("AWS_DEFAULT_PROFILE=%s", v))
}

if v, ok := c.Properties["region"]; ok {
envPrep.Env = append(envPrep.Env, fmt.Sprintf("AWS_DEFAULT_REGION=%s", v))

credentialFile.WriteString(fmt.Sprintf("region = %s\n", v))

envPrep.CmdEnvs = append(envPrep.CmdEnvs, fmt.Sprintf("AWS_DEFAULT_REGION=%s", v))
}

envPrep.Files[credentialFilePath] = credentialFile

envPrep.CmdEnvs = append(envPrep.CmdEnvs, "AWS_EC2_METADATA_DISABLED=true") // https://github.com/aws/aws-cli/issues/5262#issuecomment-705832151
envPrep.CmdEnvs = append(envPrep.CmdEnvs, fmt.Sprintf("AWS_SHARED_CREDENTIALS_FILE=%s", credentialFilePath))

case ConnectionTypeAzure:
args := []string{"login", "--service-principal", "--username", c.Username, "--password", c.Password}
if v, ok := c.Properties["tenant"]; ok {
args = append(args, "--tenant")
args = append(args, v)
}

// login with service principal
envPrep.PreRuns = append(envPrep.PreRuns, exec.CommandContext(ctx, "az", args...))

case ConnectionTypeGCP:
var credentialFile bytes.Buffer
credentialFile.WriteString(c.Certificate)

// credentialFilePath := "$HOME/.config/gcloud/credentials"
credentialFilePath := filepath.Join(".creds", "gcp", fmt.Sprintf("cred-%d", rand.Intn(100000000)))

// to configure gcloud CLI to use the service account specified in GOOGLE_APPLICATION_CREDENTIALS,
// we need to explicitly activate it
envPrep.PreRuns = append(envPrep.PreRuns, exec.CommandContext(ctx, "gcloud", "auth", "activate-service-account", "--key-file", credentialFilePath))
envPrep.Files[credentialFilePath] = credentialFile

envPrep.CmdEnvs = append(envPrep.CmdEnvs, fmt.Sprintf("GOOGLE_APPLICATION_CREDENTIALS=%s", credentialFilePath))
}

return envPrep
}

type EnvPrep struct {
// Env is the connection credentials in environment variables
Env []string

// CmdEnvs is a list of env vars that will be passed to the command
CmdEnvs []string

// List of commands that need to be run before the actual command.
// These commands will setup the connection.
PreRuns []*exec.Cmd

// File contains the content of the configuration file based on the connection
Files map[string]bytes.Buffer
}

// Inject creates the config file & injects the necessary environment variable into the command
func (c *EnvPrep) Inject(ctx context.Context, cmd *exec.Cmd) ([]*exec.Cmd, error) {
for path, file := range c.Files {
if err := saveConfig(file.Bytes(), path); err != nil {
return nil, fmt.Errorf("error saving config to %s: %w", path, err)
}
}

cmd.Env = append(cmd.Env, c.CmdEnvs...)

return c.PreRuns, nil
}

func saveConfig(content []byte, absPath string) error {
file, err := os.Create(absPath)
if err != nil {
return err
}
defer file.Close()

_, err = file.Write(content)
return err
}
106 changes: 106 additions & 0 deletions models/connections_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package models

import (
"context"
"testing"
)

func Test_Connection_AsGoGetterURL(t *testing.T) {
testCases := []struct {
name string
connection Connection
expectedURL string
expectedError error
}{
{
name: "HTTP Connection",
connection: Connection{
Type: ConnectionTypeHTTP,
URL: "http://example.com",
Username: "testuser",
Password: "testpassword",
},
expectedURL: "http://testuser:[email protected]",
expectedError: nil,
},
{
name: "Git Connection",
connection: Connection{
Type: ConnectionTypeGit,
URL: "https://github.com/repo.git",
Certificate: "cert123",
Properties: map[string]string{"ref": "main"},
},
expectedURL: "https://github.com/repo.git?ref=main&sshkey=Y2VydDEyMw%3D%3D",
expectedError: nil,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resultURL, err := tc.connection.AsGoGetterURL()

if resultURL != tc.expectedURL {
t.Errorf("Expected URL: %s, but got: %s", tc.expectedURL, resultURL)
}

if err != tc.expectedError {
t.Errorf("Expected error: %v, but got: %v", tc.expectedError, err)
}
})
}
}

func Test_Connection_AsEnv(t *testing.T) {
testCases := []struct {
name string
connection Connection
expectedEnv []string
expectedFileContent string
}{
{
name: "AWS Connection",
connection: Connection{
Type: ConnectionTypeAWS,
Username: "awsuser",
Password: "awssecret",
Properties: map[string]string{"profile": "awsprofile", "region": "us-east-1"},
},
expectedEnv: []string{
"AWS_ACCESS_KEY_ID=awsuser",
"AWS_SECRET_ACCESS_KEY=awssecret",
"AWS_DEFAULT_PROFILE=awsprofile",
"AWS_DEFAULT_REGION=us-east-1",
},
expectedFileContent: "[default]\naws_access_key_id = awsuser\naws_secret_access_key = awssecret\nregion = us-east-1\n",
},
{
name: "GCP Connection",
connection: Connection{
Type: ConnectionTypeGCP,
Username: "gcpuser",
Certificate: `{"account": "gcpuser"}`,
},
expectedEnv: []string{},
expectedFileContent: `{"account": "gcpuser"}`,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
envPrep := tc.connection.AsEnv(context.Background())

for i, expected := range tc.expectedEnv {
if envPrep.Env[i] != expected {
t.Errorf("Expected environment variable: %s, but got: %s", expected, envPrep.Env[i])
}
}

for _, content := range envPrep.Files {
if content.String() != tc.expectedFileContent {
t.Errorf("Expected file content: %s, but got: %s", tc.expectedFileContent, content.String())
}
}
})
}
}

0 comments on commit 748be3f

Please sign in to comment.