diff --git a/bindings/bindings.go b/bindings/bindings.go index a52692d..bb367d0 100644 --- a/bindings/bindings.go +++ b/bindings/bindings.go @@ -27,7 +27,7 @@ const UseCFAPI bool = true type BindingConfig interface { Validate() error - ToFileSystem() (afero.Fs, error) + ToFileSystem(zerolog.Logger) (afero.Fs, error) } const ( @@ -99,8 +99,13 @@ func ReadConfigFromRegistry(key registry.Key, config any) error { if os.IsNotExist(err) { continue } - if err != nil { - return err + if err == registry.ErrUnexpectedType { + // attempt to read as multi-string + values, _, err := key.GetStringsValue(tag) + if err != nil { + return err + } + value = strings.Join(values, "\n") } fieldvalue.SetString(value) case reflect.Bool: diff --git a/bindings/proxy/client/config.go b/bindings/proxy/client/config.go index 0dada50..3e436bd 100644 --- a/bindings/proxy/client/config.go +++ b/bindings/proxy/client/config.go @@ -4,6 +4,7 @@ import ( "errors" "net/http" + "github.com/rs/zerolog" "github.com/spf13/afero" ) @@ -37,7 +38,7 @@ func (a *authenticator) RoundTrip(r *http.Request) (*http.Response, error) { return a.delegate.RoundTrip(r) } -func (c *Config) ToFileSystem() (afero.Fs, error) { +func (c *Config) ToFileSystem(logger zerolog.Logger) (afero.Fs, error) { httpclient := &http.Client{} httpclient.Transport = &authenticator{ Config: c, diff --git a/bindings/s3/config.go b/bindings/s3/config.go index 91847a9..90a0d6f 100644 --- a/bindings/s3/config.go +++ b/bindings/s3/config.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/balazsgrill/potatodrive/bindings/utils" s3 "github.com/fclairamb/afero-s3" + "github.com/rs/zerolog" "github.com/spf13/afero" ) @@ -39,7 +40,7 @@ func (c *Config) Validate() error { return nil } -func (c *Config) ToFileSystem() (afero.Fs, error) { +func (c *Config) ToFileSystem(logger zerolog.Logger) (afero.Fs, error) { sess, err := session.NewSession(&aws.Config{ Region: aws.String(c.Region), Endpoint: aws.String(c.Endpoint), diff --git a/bindings/sftp/config.go b/bindings/sftp/config.go index acc1bf8..bf33588 100644 --- a/bindings/sftp/config.go +++ b/bindings/sftp/config.go @@ -5,16 +5,18 @@ import ( "github.com/balazsgrill/potatodrive/bindings/utils" sftpclient "github.com/pkg/sftp" + "github.com/rs/zerolog" "github.com/spf13/afero" "github.com/spf13/afero/sftpfs" "golang.org/x/crypto/ssh" ) type Config struct { - User string `flag:"user,User name" reg:"User"` - Password string `flag:"password,Password" reg:"Password"` - Host string `flag:"host,Host:port" reg:"Host"` - Basepath string `flag:"basepath,Base path on remote server" reg:"Basepath"` + User string `flag:"user,User name" reg:"User"` + Password string `flag:"password,Password" reg:"Password"` + PrivateKey string `flag:"privatekey,PrivateKey" reg:"PrivateKey"` + Host string `flag:"host,Host:port" reg:"Host"` + Basepath string `flag:"basepath,Base path on remote server" reg:"Basepath"` } func (c *Config) Validate() error { @@ -24,18 +26,50 @@ func (c *Config) Validate() error { if c.User == "" { return errors.New("user is mandatory") } - if c.Password == "" { - return errors.New("password is mandatory") + if c.Password == "" && c.PrivateKey == "" { + return errors.New("password or private key is mandatory") } return nil } -func (c *Config) Connect(onDisconnect func(error)) (afero.Fs, error) { +func (c *Config) authFromKey() (ssh.AuthMethod, error) { + key, err := ssh.ParseRawPrivateKey([]byte(c.PrivateKey)) + if err != nil { + return nil, err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + return ssh.PublicKeys(signer), nil +} + +type configWithLogger struct { + Config + Logger zerolog.Logger +} + +func (c *configWithLogger) Connect(onDisconnect func(error)) (afero.Fs, error) { + var authmetods []ssh.AuthMethod + if c.Password != "" { + authmetods = append(authmetods, ssh.Password(c.Password)) + } + if c.PrivateKey != "" { + auth, err := c.authFromKey() + if err != nil { + c.Logger.Error().Err(err).Msg("SSH key auth failed") + } else { + authmetods = append(authmetods, auth) + } + } + + if len(authmetods) == 0 { + return nil, errors.New("no valid authentication method is defined") + } + config := ssh.ClientConfig{ - User: c.User, - Auth: []ssh.AuthMethod{ - ssh.Password(c.Password), - }, + User: c.User, + Auth: authmetods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } conn, err := ssh.Dial("tcp", c.Host, &config) @@ -58,10 +92,14 @@ func (c *Config) Connect(onDisconnect func(error)) (afero.Fs, error) { return sftpfs.New(client), nil } -func (c *Config) ToFileSystem() (afero.Fs, error) { +func (c *Config) ToFileSystem(logger zerolog.Logger) (afero.Fs, error) { var remote afero.Fs + cwithlogger := &configWithLogger{ + Config: *c, + Logger: logger, + } remote = &utils.ConnectingFs{ - Connect: c.Connect, + Connect: cwithlogger.Connect, } if c.Basepath != "" { remote = utils.NewBasePathFs(remote, c.Basepath) diff --git a/cmd/main/mgr.go b/cmd/main/mgr.go index 2c6c386..8cb6c87 100644 --- a/cmd/main/mgr.go +++ b/cmd/main/mgr.go @@ -46,26 +46,30 @@ func initLogger() (string, zerolog.Logger, io.Closer) { func startInstance(parentkey registry.Key, keyname string, context bindings.InstanceContext) (io.Closer, error) { key, err := registry.OpenKey(parentkey, keyname, registry.QUERY_VALUE) if err != nil { - context.Logger.Printf("Open key: %v", err) + context.Logger.Error().Msgf("Open key: %v", err) return nil, err } var basec bindings.BaseConfig err = bindings.ReadConfigFromRegistry(key, &basec) if err != nil { - context.Logger.Printf("Get base config: %v", err) + context.Logger.Error().Msgf("Get base config: %v", err) return nil, err } config := bindings.CreateConfigByType(basec.Type) - bindings.ReadConfigFromRegistry(key, config) + err = bindings.ReadConfigFromRegistry(key, config) + if err != nil { + context.Logger.Error().Msgf("Read config: %v", err) + return nil, err + } err = config.Validate() if err != nil { - context.Logger.Printf("Validate config: %v", err) + context.Logger.Error().Msgf("Validate config: %v", err) return nil, err } - fs, err := config.ToFileSystem() + fs, err := config.ToFileSystem(context.Logger) if err != nil { - context.Logger.Printf("Create file system: %v", err) + context.Logger.Error().Msgf("Create file system: %v", err) return nil, err } diff --git a/test/minio/minio_test.go b/test/minio/minio_test.go index fec0f86..446d1e2 100644 --- a/test/minio/minio_test.go +++ b/test/minio/minio_test.go @@ -75,6 +75,9 @@ func setup(t *testing.T) *testInstance { } func (instance *testInstance) start(t *testing.T) { + instancecontext := bindings.InstanceContext{ + Logger: zerolog.New(zerolog.NewTestWriter(t)), + } config := s3.Config{ Endpoint: "localhost:9000", Region: "us-east-1", // default region @@ -83,13 +86,11 @@ func (instance *testInstance) start(t *testing.T) { KeySecret: MINIO_SECRET_KEY, UseSSL: false, } - fs, err := config.ToFileSystem() + fs, err := config.ToFileSystem(instancecontext.Logger) if err != nil { t.Fatal(err) } - instancecontext := bindings.InstanceContext{ - Logger: zerolog.New(zerolog.NewTestWriter(t)), - } + uid := uuid.NewMD5(uuid.UUID{}, []byte("test")) gid := core.BytesToGuid(uid[:]) err = filesystem.RegisterRootPathSimple(*gid, instance.fsdir) diff --git a/test/proxy/proxy_test.go b/test/proxy/proxy_test.go index 0be485b..53c8ab9 100644 --- a/test/proxy/proxy_test.go +++ b/test/proxy/proxy_test.go @@ -6,10 +6,12 @@ import ( "github.com/balazsgrill/potatodrive/bindings/proxy/client" "github.com/balazsgrill/potatodrive/bindings/proxy/server" + "github.com/rs/zerolog" "github.com/spf13/afero" ) func TestProxyConnection(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) fs := afero.NewMemMapFs() mux := http.NewServeMux() mux.HandleFunc("/", server.Handler(fs)) @@ -25,7 +27,7 @@ func TestProxyConnection(t *testing.T) { KeyId: "", KeySecret: "", } - fs2, err := clientconifg.ToFileSystem() + fs2, err := clientconifg.ToFileSystem(logger) if err != nil { t.Fatal(err) }