diff --git a/.gitignore b/.gitignore index 7b517219..3893cecb 100644 --- a/.gitignore +++ b/.gitignore @@ -34,7 +34,9 @@ libtensorflow* # Test generated files cmd/test_*.yaml.bak cmd/test_*.yaml +cmd/docs/* # docker-compose gatewayd-files/ cmd/gatewayd-plugin-cache-linux-amd64-* + diff --git a/.golangci.yaml b/.golangci.yaml index ecb52445..87f9503d 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -87,6 +87,7 @@ linters-settings: - "github.com/spf13/cobra" - "github.com/knadh/koanf" - "github.com/spf13/cast" + - "github.com/jackc/pgx/v5/pgproto3" tagalign: align: false sort: false diff --git a/act/builtins_test.go b/act/builtins_test.go new file mode 100644 index 00000000..cacd6c7e --- /dev/null +++ b/act/builtins_test.go @@ -0,0 +1,131 @@ +package act + +import ( + "testing" + + sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" + "github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres" + gerr "github.com/gatewayd-io/gatewayd/errors" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Terminate_Action(t *testing.T) { + response, err := (&pgproto3.Terminate{}).Encode( + postgres.ErrorResponse( + "Request terminated", + "ERROR", + "42000", + "Policy terminated the request", + ), + ) + require.NoError(t, err) + + tests := []struct { + params []sdkAct.Parameter + result any + err error + }{ + { + params: []sdkAct.Parameter{}, + result: nil, + err: gerr.ErrLoggerRequired, + }, + { + params: []sdkAct.Parameter{ + { + Key: LoggerKey, + Value: nil, + }, + }, + result: nil, + err: gerr.ErrLoggerRequired, + }, + { + params: []sdkAct.Parameter{ + { + Key: LoggerKey, + Value: zerolog.New(nil), + }, + }, + result: true, + }, + { + params: []sdkAct.Parameter{ + { + Key: LoggerKey, + Value: zerolog.New(nil), + }, + { + Key: ResultKey, + Value: nil, + }, + }, + result: true, + }, + { + params: []sdkAct.Parameter{ + { + Key: LoggerKey, + Value: zerolog.New(nil), + }, + { + Key: ResultKey, + Value: map[string]any{}, + }, + }, + result: map[string]any{"response": response}, + }, + } + + for _, test := range tests { + t.Run("Test_Terminate_Action", func(t *testing.T) { + result, err := Terminate(nil, test.params...) + assert.ErrorIs(t, err, test.err) + assert.Equal(t, result, test.result) + }) + } +} + +func Test_Log_Action(t *testing.T) { + tests := []struct { + params []sdkAct.Parameter + result any + err error + }{ + { + params: []sdkAct.Parameter{}, + result: nil, + err: gerr.ErrLoggerRequired, + }, + { + params: []sdkAct.Parameter{ + { + Key: LoggerKey, + Value: nil, + }, + }, + result: nil, + err: gerr.ErrLoggerRequired, + }, + { + params: []sdkAct.Parameter{ + { + Key: LoggerKey, + Value: zerolog.New(nil), + }, + }, + result: true, + }, + } + + for _, test := range tests { + t.Run("Test_Log_Action", func(t *testing.T) { + result, err := Log(nil, test.params...) + assert.ErrorIs(t, err, test.err) + assert.Equal(t, result, test.result) + }) + } +} diff --git a/api/api_test.go b/api/api_test.go index cc885164..35b0f5dd 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -6,6 +6,7 @@ import ( "testing" sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" + pluginV1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" "github.com/gatewayd-io/gatewayd/act" v1 "github.com/gatewayd-io/gatewayd/api/v1" "github.com/gatewayd-io/gatewayd/config" @@ -30,7 +31,8 @@ func TestGetGlobalConfig(t *testing.T) { // Load config from the default config file. conf := config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.TODO()) + require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) api := API{ @@ -53,13 +55,18 @@ func TestGetGlobalConfigWithGroupName(t *testing.T) { // Load config from the default config file. conf := config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.TODO()) + require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) api := API{ Config: conf, } - globalConfig, err := api.GetGlobalConfig(context.Background(), &v1.Group{GroupName: nil}) + defaultGroup := config.Default + globalConfig, err := api.GetGlobalConfig( + context.Background(), + &v1.Group{GroupName: &defaultGroup}, + ) require.NoError(t, err) globalconf := globalConfig.AsMap() assert.NotEmpty(t, globalconf) @@ -71,7 +78,7 @@ func TestGetGlobalConfigWithGroupName(t *testing.T) { assert.NotEmpty(t, globalconf["servers"]) assert.NotEmpty(t, globalconf["metrics"]) assert.NotEmpty(t, globalconf["api"]) - if _, ok := globalconf["loggers"].(map[string]interface{})["default"]; !ok { + if _, ok := globalconf["loggers"].(map[string]interface{})[config.Default]; !ok { t.Errorf("loggers.default is not found") } } @@ -80,7 +87,8 @@ func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) { // Load config from the default config file. conf := config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.TODO()) + require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) api := API{ @@ -96,7 +104,8 @@ func TestGetPluginConfig(t *testing.T) { // Load config from the default config file. conf := config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.TODO()) + require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) api := API{ @@ -136,6 +145,17 @@ func TestGetPlugins(t *testing.T) { RemoteURL: "plugin-url", Checksum: "plugin-checksum", }, + Requires: []sdkPlugin.Identifier{ + { + Name: "plugin1-name", + Version: "plugin1-version", + RemoteURL: "plugin1-url", + Checksum: "plugin1-checksum", + }, + }, + Hooks: []pluginV1.HookName{ + pluginV1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_CLIENT, + }, }) api := API{ @@ -145,6 +165,11 @@ func TestGetPlugins(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, plugins) assert.NotEmpty(t, plugins.GetConfigs()) + assert.NotEmpty(t, plugins.GetConfigs()[0].GetRequires()) + assert.Equal( + t, + int32(pluginV1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_CLIENT), + plugins.GetConfigs()[0].GetHooks()[0]) } func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) { diff --git a/api/embed_swagger_test.go b/api/embed_swagger_test.go new file mode 100644 index 00000000..40065db9 --- /dev/null +++ b/api/embed_swagger_test.go @@ -0,0 +1,11 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_IsSwaggerEmbedded(t *testing.T) { + assert.False(t, IsSwaggerEmbedded()) +} diff --git a/api/healthcheck_test.go b/api/healthcheck_test.go new file mode 100644 index 00000000..2aa2a5ce --- /dev/null +++ b/api/healthcheck_test.go @@ -0,0 +1,94 @@ +package api + +import ( + "context" + "testing" + + "github.com/gatewayd-io/gatewayd/act" + "github.com/gatewayd-io/gatewayd/config" + "github.com/gatewayd-io/gatewayd/network" + "github.com/gatewayd-io/gatewayd/plugin" + "github.com/gatewayd-io/gatewayd/pool" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/health/grpc_health_v1" +) + +func Test_Healthchecker(t *testing.T) { + clientConfig := &config.Client{ + Network: config.DefaultNetwork, + Address: config.DefaultAddress, + } + client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil) + newPool := pool.NewPool(context.TODO(), 1) + require.NotNil(t, newPool) + assert.Nil(t, newPool.Put(client.ID, client)) + + proxy := network.NewProxy( + context.TODO(), + network.Proxy{ + AvailableConnections: newPool, + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &config.Client{ + Network: config.DefaultNetwork, + Address: config.DefaultAddress, + }, + Logger: zerolog.Logger{}, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) + + actRegistry := act.NewActRegistry( + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: zerolog.Logger{}, + }) + + pluginRegistry := plugin.NewRegistry( + context.TODO(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: zerolog.Logger{}, + DevMode: true, + }, + ) + + server := network.NewServer( + context.TODO(), + network.Server{ + Network: config.DefaultNetwork, + Address: config.DefaultAddress, + TickInterval: config.DefaultTickInterval, + Options: network.Option{ + EnableTicker: false, + }, + Proxy: proxy, + Logger: zerolog.Logger{}, + PluginRegistry: pluginRegistry, + PluginTimeout: config.DefaultPluginTimeout, + HandshakeTimeout: config.DefaultHandshakeTimeout, + }, + ) + + healthchecker := HealthChecker{ + Servers: map[string]*network.Server{ + config.Default: server, + }, + } + assert.NotNil(t, healthchecker) + hcr, err := healthchecker.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{}) + assert.NoError(t, err) + assert.NotNil(t, hcr) + assert.Equal(t, grpc_health_v1.HealthCheckResponse_NOT_SERVING, hcr.GetStatus()) + + err = healthchecker.Watch(&grpc_health_v1.HealthCheckRequest{}, nil) + assert.Error(t, err) + assert.Equal(t, "rpc error: code = Unimplemented desc = not implemented", err.Error()) +} diff --git a/cmd/configs.go b/cmd/configs.go index ef867725..3cc2d2b2 100644 --- a/cmd/configs.go +++ b/cmd/configs.go @@ -27,7 +27,9 @@ func generateConfig( GlobalKoanf: koanf.New("."), PluginKoanf: koanf.New("."), } - conf.LoadDefaults(context.TODO()) + if err := conf.LoadDefaults(context.TODO()); err != nil { + logger.Fatal(err) + } // Marshal the config file to YAML. var konfig *koanf.Koanf @@ -66,20 +68,32 @@ func generateConfig( } // lintConfig lints the given config file of the given type. -func lintConfig(fileType configFileType, configFile string) error { +func lintConfig(fileType configFileType, configFile string) *gerr.GatewayDError { // Load the config file and check it for errors. var conf *config.Config switch fileType { case Global: conf = config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: configFile}) - conf.LoadDefaults(context.TODO()) - conf.LoadGlobalConfigFile(context.TODO()) - conf.UnmarshalGlobalConfig(context.TODO()) + if err := conf.LoadDefaults(context.TODO()); err != nil { + return err + } + if err := conf.LoadGlobalConfigFile(context.TODO()); err != nil { + return err + } + if err := conf.UnmarshalGlobalConfig(context.TODO()); err != nil { + return err + } case Plugins: conf = config.NewConfig(context.TODO(), config.Config{PluginConfigFile: configFile}) - conf.LoadDefaults(context.TODO()) - conf.LoadPluginConfigFile(context.TODO()) - conf.UnmarshalPluginConfig(context.TODO()) + if err := conf.LoadDefaults(context.TODO()); err != nil { + return err + } + if err := conf.LoadPluginConfigFile(context.TODO()); err != nil { + return err + } + if err := conf.UnmarshalPluginConfig(context.TODO()); err != nil { + return err + } default: return gerr.ErrLintingFailed } diff --git a/cmd/gen_docs.go b/cmd/gen_docs.go index 19a08e42..45593931 100644 --- a/cmd/gen_docs.go +++ b/cmd/gen_docs.go @@ -1,10 +1,14 @@ package cmd import ( + "os" + "github.com/spf13/cobra" "github.com/spf13/cobra/doc" ) +const outputDirPermissions = 0o755 + var docOutputDir string var genDocs = &cobra.Command{ @@ -12,6 +16,12 @@ var genDocs = &cobra.Command{ Short: "Generate markdown documentation", Hidden: true, Run: func(cmd *cobra.Command, _ []string) { + // Create the output directory if it doesn't exist + if err := os.MkdirAll(docOutputDir, outputDirPermissions); err != nil { + cmd.PrintErr(err) + return + } + // Generate the markdown files err := doc.GenMarkdownTree(rootCmd, docOutputDir) if err != nil { cmd.PrintErr(err) diff --git a/cmd/gen_docs_test.go b/cmd/gen_docs_test.go new file mode 100644 index 00000000..a5fc7631 --- /dev/null +++ b/cmd/gen_docs_test.go @@ -0,0 +1,17 @@ +package cmd + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_genDocs(t *testing.T) { + _, err := executeCommandC(rootCmd, "gen-docs", "--output-dir", "./docs") + require.NoError(t, err, "genDocs should not return an error") + assert.DirExists(t, "./docs", "genDocs should create the output directory") + assert.FileExists(t, "./docs/gatewayd.md", "genDocs should create the markdown file") + require.NoError(t, os.RemoveAll("./docs")) +} diff --git a/cmd/plugin_list.go b/cmd/plugin_list.go index 0384de91..000d9616 100644 --- a/cmd/plugin_list.go +++ b/cmd/plugin_list.go @@ -57,9 +57,18 @@ func init() { func listPlugins(cmd *cobra.Command, pluginConfigFile string, onlyEnabled bool) { // Load the plugin config file. conf := config.NewConfig(context.TODO(), config.Config{PluginConfigFile: pluginConfigFile}) - conf.LoadDefaults(context.TODO()) - conf.LoadPluginConfigFile(context.TODO()) - conf.UnmarshalPluginConfig(context.TODO()) + if err := conf.LoadDefaults(context.TODO()); err != nil { + cmd.PrintErr(err) + return + } + if err := conf.LoadPluginConfigFile(context.TODO()); err != nil { + cmd.PrintErr(err) + return + } + if err := conf.UnmarshalPluginConfig(context.TODO()); err != nil { + cmd.PrintErr(err) + return + } if len(conf.Plugin.Plugins) != 0 { cmd.Printf("Total plugins: %d\n", len(conf.Plugin.Plugins)) diff --git a/cmd/plugin_test.go b/cmd/plugin_test.go index ca8b9dad..e2b8b7bf 100644 --- a/cmd/plugin_test.go +++ b/cmd/plugin_test.go @@ -18,6 +18,7 @@ Usage: gatewayd plugin [command] Available Commands: + help Help about any command init Create or overwrite the GatewayD plugins config install Install a plugin from a local archive or a GitHub repository lint Lint the GatewayD plugins config diff --git a/cmd/root_test.go b/cmd/root_test.go index 951d589f..de1ab45b 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1,18 +1,15 @@ package cmd import ( + "bytes" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_rootCmd(t *testing.T) { - output, err := executeCommandC(rootCmd) - require.NoError(t, err, "rootCmd should not return an error") - //nolint:lll - assert.Equal(t, - `GatewayD is a cloud-native database gateway and framework for building data-driven applications. It sits between your database servers and clients and proxies all their communication. +//nolint:lll +const rootHelp string = `GatewayD is a cloud-native database gateway and framework for building data-driven applications. It sits between your database servers and clients and proxies all their communication. Usage: gatewayd [command] @@ -29,7 +26,21 @@ Flags: -h, --help help for gatewayd Use "gatewayd [command] --help" for more information about a command. -`, +` + +func Test_rootCmd(t *testing.T) { + output, err := executeCommandC(rootCmd) + require.NoError(t, err, "rootCmd should not return an error") + assert.Equal(t, + rootHelp, output, "rootCmd should print the correct output") } + +func Test_Execute(t *testing.T) { + buf := new(bytes.Buffer) + rootCmd.SetOut(buf) + rootCmd.SetErr(buf) + Execute() + assert.Equal(t, rootHelp, buf.String(), "Execute should print the correct output") +} diff --git a/cmd/run.go b/cmd/run.go index c788c0e9..686e8103 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -235,7 +235,9 @@ var runCmd = &cobra.Command{ // Load global and plugin configuration. conf = config.NewConfig(runCtx, config.Config{GlobalConfigFile: globalConfigFile, PluginConfigFile: pluginConfigFile}) - conf.InitConfig(runCtx) + if err := conf.InitConfig(runCtx); err != nil { + log.Fatal(err) + } // Create and initialize loggers from the config. // Use cobra command cmd instead of os.Stdout for the console output. @@ -419,7 +421,9 @@ var runCmd = &cobra.Command{ if updatedGlobalConfig != nil { // Merge the config with the one loaded from the file (in memory). // The changes won't be persisted to disk. - conf.MergeGlobalConfig(runCtx, updatedGlobalConfig) + if err := conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil { + log.Fatal(err) + } } // Start the metrics server if enabled. diff --git a/config/config.go b/config/config.go index a38af813..5e5a0330 100644 --- a/config/config.go +++ b/config/config.go @@ -23,13 +23,14 @@ import ( ) type IConfig interface { - InitConfig(ctx context.Context) - LoadDefaults(ctx context.Context) - LoadPluginEnvVars(ctx context.Context) - LoadGlobalEnvVars(ctx context.Context) - LoadGlobalConfigFile(ctx context.Context) - LoadPluginConfigFile(ctx context.Context) - MergeGlobalConfig(ctx context.Context, updatedGlobalConfig map[string]interface{}) + InitConfig(ctx context.Context) *gerr.GatewayDError + LoadDefaults(ctx context.Context) *gerr.GatewayDError + LoadPluginEnvVars(ctx context.Context) *gerr.GatewayDError + LoadGlobalEnvVars(ctx context.Context) *gerr.GatewayDError + LoadGlobalConfigFile(ctx context.Context) *gerr.GatewayDError + LoadPluginConfigFile(ctx context.Context) *gerr.GatewayDError + MergeGlobalConfig( + ctx context.Context, updatedGlobalConfig map[string]interface{}) *gerr.GatewayDError } type Config struct { @@ -64,24 +65,42 @@ func NewConfig(ctx context.Context, config Config) *Config { } } -func (c *Config) InitConfig(ctx context.Context) { +func (c *Config) InitConfig(ctx context.Context) *gerr.GatewayDError { newCtx, span := otel.Tracer(TracerName).Start(ctx, "Initialize config") defer span.End() - c.LoadDefaults(newCtx) + if err := c.LoadDefaults(newCtx); err != nil { + return err + } - c.LoadPluginConfigFile(newCtx) - c.LoadPluginEnvVars(newCtx) - c.UnmarshalPluginConfig(newCtx) + if err := c.LoadPluginConfigFile(newCtx); err != nil { + return err + } + if err := c.LoadPluginEnvVars(newCtx); err != nil { + return err + } + if err := c.UnmarshalPluginConfig(newCtx); err != nil { + return err + } - c.LoadGlobalConfigFile(newCtx) - c.ValidateGlobalConfig(newCtx) - c.LoadGlobalEnvVars(newCtx) - c.UnmarshalGlobalConfig(newCtx) + if err := c.LoadGlobalConfigFile(newCtx); err != nil { + return err + } + if err := c.ValidateGlobalConfig(newCtx); err != nil { + return err + } + if err := c.LoadGlobalEnvVars(newCtx); err != nil { + return err + } + if err := c.UnmarshalGlobalConfig(newCtx); err != nil { + return err + } + + return nil } // LoadDefaults loads the default configuration before loading the config files. -func (c *Config) LoadDefaults(ctx context.Context) { +func (c *Config) LoadDefaults(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Load defaults") defaultLogger := Logger{ @@ -164,7 +183,8 @@ func (c *Config) LoadDefaults(ctx context.Context) { if err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to unmarshal global configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to unmarshal global configuration: %w", err)) } for configObject, configMap := range gconf { @@ -193,7 +213,7 @@ func (c *Config) LoadDefaults(ctx context.Context) { err := fmt.Errorf("unknown config object: %s", configObject) span.RecordError(err) span.End() - log.Fatal(err) + return gerr.ErrConfigParseError.Wrap(err) } } } @@ -201,7 +221,8 @@ func (c *Config) LoadDefaults(ctx context.Context) { } else if !os.IsNotExist(err) { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to read global configuration file: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to read global configuration file: %w", err)) } c.pluginDefaults = PluginConfig{ @@ -222,7 +243,8 @@ func (c *Config) LoadDefaults(ctx context.Context) { if err := c.GlobalKoanf.Load(structs.Provider(c.globalDefaults, "json"), nil); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to load default global configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to load default global configuration: %w", err)) } } @@ -230,39 +252,48 @@ func (c *Config) LoadDefaults(ctx context.Context) { if err := c.PluginKoanf.Load(structs.Provider(c.pluginDefaults, "json"), nil); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to load default plugin configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to load default plugin configuration: %w", err)) } } span.End() + + return nil } // LoadGlobalEnvVars loads the environment variables into the global configuration with the // given prefix, "GATEWAYD_". -func (c *Config) LoadGlobalEnvVars(ctx context.Context) { +func (c *Config) LoadGlobalEnvVars(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Load global environment variables") if err := c.GlobalKoanf.Load(loadEnvVars(), nil); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to load environment variables: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to load environment variables: %w", err)) } span.End() + + return nil } // LoadPluginEnvVars loads the environment variables into the plugins configuration with the // given prefix, "GATEWAYD_". -func (c *Config) LoadPluginEnvVars(ctx context.Context) { +func (c *Config) LoadPluginEnvVars(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Load plugin environment variables") if err := c.PluginKoanf.Load(loadEnvVars(), nil); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to load environment variables: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to load environment variables: %w", err)) } span.End() + + return nil } func loadEnvVars() *env.Env { @@ -272,33 +303,39 @@ func loadEnvVars() *env.Env { } // LoadGlobalConfigFile loads the plugin configuration file. -func (c *Config) LoadGlobalConfigFile(ctx context.Context) { +func (c *Config) LoadGlobalConfigFile(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Load global config file") if err := c.GlobalKoanf.Load(file.Provider(c.GlobalConfigFile), yaml.Parser()); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to load global configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to load global configuration: %w", err)) } span.End() + + return nil } // LoadPluginConfigFile loads the plugin configuration file. -func (c *Config) LoadPluginConfigFile(ctx context.Context) { +func (c *Config) LoadPluginConfigFile(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Load plugin config file") if err := c.PluginKoanf.Load(file.Provider(c.PluginConfigFile), yaml.Parser()); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to load plugin configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to load plugin configuration: %w", err)) } span.End() + + return nil } // UnmarshalGlobalConfig unmarshals the global configuration for easier access. -func (c *Config) UnmarshalGlobalConfig(ctx context.Context) { +func (c *Config) UnmarshalGlobalConfig(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Unmarshal global config") if err := c.GlobalKoanf.UnmarshalWithConf("", &c.Global, koanf.UnmarshalConf{ @@ -306,14 +343,17 @@ func (c *Config) UnmarshalGlobalConfig(ctx context.Context) { }); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to unmarshal global configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to unmarshal global configuration: %w", err)) } span.End() + + return nil } // UnmarshalPluginConfig unmarshals the plugin configuration for easier access. -func (c *Config) UnmarshalPluginConfig(ctx context.Context) { +func (c *Config) UnmarshalPluginConfig(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Unmarshal plugin config") if err := c.PluginKoanf.UnmarshalWithConf("", &c.Plugin, koanf.UnmarshalConf{ @@ -321,21 +361,25 @@ func (c *Config) UnmarshalPluginConfig(ctx context.Context) { }); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to unmarshal plugin configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to unmarshal plugin configuration: %w", err)) } span.End() + + return nil } func (c *Config) MergeGlobalConfig( ctx context.Context, updatedGlobalConfig map[string]interface{}, -) { +) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Merge global config from plugins") if err := c.GlobalKoanf.Load(confmap.Provider(updatedGlobalConfig, "."), nil); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to merge global configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to merge global configuration: %w", err)) } if err := c.GlobalKoanf.UnmarshalWithConf("", &c.Global, koanf.UnmarshalConf{ @@ -343,23 +387,24 @@ func (c *Config) MergeGlobalConfig( }); err != nil { span.RecordError(err) span.End() - log.Fatal(fmt.Errorf("failed to unmarshal global configuration: %w", err)) + return gerr.ErrConfigParseError.Wrap( + fmt.Errorf("failed to unmarshal global configuration: %w", err)) } span.End() + + return nil } -func (c *Config) ValidateGlobalConfig(ctx context.Context) { +func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { _, span := otel.Tracer(TracerName).Start(ctx, "Validate global config") var globalConfig GlobalConfig if err := c.GlobalKoanf.Unmarshal("", &globalConfig); err != nil { span.RecordError(err) span.End() - log.Fatal( - gerr.ErrValidationFailed.Wrap( - fmt.Errorf("failed to unmarshal global configuration: %w", err)), - ) + return gerr.ErrValidationFailed.Wrap( + fmt.Errorf("failed to unmarshal global configuration: %w", err)) } var errors []*gerr.GatewayDError @@ -464,8 +509,13 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) { for _, err := range errors { log.Println(err) } - span.RecordError(goerrors.New("failed to validate global configuration")) + err := goerrors.New("failed to validate global configuration") + span.RecordError(err) span.End() - log.Fatal("failed to validate global configuration") + return gerr.ErrValidationFailed.Wrap(err) } + + span.End() + + return nil } diff --git a/config/config_test.go b/config/config_test.go index b89a9212..3f4d1b81 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -6,6 +6,7 @@ import ( "github.com/knadh/koanf" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var parentDir = "../" @@ -28,10 +29,43 @@ func TestNewConfig(t *testing.T) { // TestInitConfig tests the InitConfig function, which practically tests all // the other functions. func TestInitConfig(t *testing.T) { + ctx := context.Background() + config := NewConfig(ctx, + Config{ + GlobalConfigFile: parentDir + "cmd/testdata/gatewayd.yaml", + PluginConfigFile: parentDir + PluginsConfigFilename, + }, + ) + err := config.InitConfig(ctx) + require.Nil(t, err) + assert.NotNil(t, config.Global) + assert.NotEqual(t, GlobalConfig{}, config.Global) + assert.Contains(t, config.Global.Servers, Default) + assert.Contains(t, config.Global.Servers, "test") // Test the multi-tenant configuration. + assert.NotNil(t, config.Plugin) + assert.NotEqual(t, PluginConfig{}, config.Plugin) + assert.Len(t, config.Plugin.Plugins, 1) + assert.NotNil(t, config.GlobalKoanf) + assert.NotEqual(t, config.GlobalKoanf, koanf.New(".")) + assert.Equal(t, DefaultLogLevel, config.GlobalKoanf.String("loggers.default.level")) + assert.NotNil(t, config.PluginKoanf) + assert.NotEqual(t, config.PluginKoanf, koanf.New(".")) + assert.NotNil(t, config.globalDefaults) + assert.NotEqual(t, GlobalConfig{}, config.globalDefaults) + assert.Contains(t, config.globalDefaults.Servers, Default) + assert.Contains(t, config.globalDefaults.Servers, "test") + assert.NotNil(t, config.pluginDefaults) + assert.NotEqual(t, PluginConfig{}, config.pluginDefaults) + assert.Empty(t, config.pluginDefaults.Plugins) +} + +// TestInitConfigMultiTenant tests the InitConfig function with a multi-tenant configuration. +func TestInitConfigMultiTenant(t *testing.T) { ctx := context.Background() config := NewConfig(ctx, Config{GlobalConfigFile: parentDir + GlobalConfigFilename, PluginConfigFile: parentDir + PluginsConfigFilename}) - config.InitConfig(ctx) + err := config.InitConfig(ctx) + require.Nil(t, err) assert.NotNil(t, config.Global) assert.NotEqual(t, GlobalConfig{}, config.Global) assert.Contains(t, config.Global.Servers, Default) @@ -51,23 +85,58 @@ func TestInitConfig(t *testing.T) { assert.Empty(t, config.pluginDefaults.Plugins) } +// TestInitConfigMissingFile tests the InitConfig function with a missing file. +func TestInitConfigMissingKeys(t *testing.T) { + ctx := context.Background() + config := NewConfig(ctx, + Config{ + GlobalConfigFile: "./testdata/missing_keys.yaml", + PluginConfigFile: parentDir + PluginsConfigFilename, + }, + ) + err := config.InitConfig(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), + "validation failed, OriginalError: failed to validate global configuration") +} + +// TestInitConfigMissingFile tests the InitConfig function with a missing file. +func TestInitConfigMissingFile(t *testing.T) { + ctx := context.Background() + config := NewConfig(ctx, + Config{ + GlobalConfigFile: "./testdata/missing_file.yaml", + PluginConfigFile: parentDir + PluginsConfigFilename, + }, + ) + err := config.InitConfig(ctx) + assert.Error(t, err) + assert.Contains( + t, + err.Error(), + "error parsing config, OriginalError: failed to load global configuration: "+ + "open testdata/missing_file.yaml: no such file or directory") +} + // TestMergeGlobalConfig tests the MergeGlobalConfig function. func TestMergeGlobalConfig(t *testing.T) { ctx := context.Background() config := NewConfig(ctx, Config{GlobalConfigFile: parentDir + GlobalConfigFilename, PluginConfigFile: parentDir + PluginsConfigFilename}) - config.InitConfig(ctx) + err := config.InitConfig(ctx) + require.Nil(t, err) // The default log level is info. assert.Equal(t, DefaultLogLevel, config.Global.Loggers[Default].Level) // Merge a config that sets the log level to debug. - config.MergeGlobalConfig(ctx, map[string]interface{}{ + err = config.MergeGlobalConfig(ctx, map[string]interface{}{ "loggers": map[string]interface{}{ "default": map[string]interface{}{ "level": "debug", }, }, }) + require.Nil(t, err) assert.NotNil(t, config.Global) assert.NotEqual(t, GlobalConfig{}, config.Global) // The log level should now be debug. diff --git a/config/getters_test.go b/config/getters_test.go index f53a28a1..3f71eb37 100644 --- a/config/getters_test.go +++ b/config/getters_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestGetOutput tests the GetOutput function. @@ -13,6 +14,12 @@ func TestGetOutput(t *testing.T) { assert.Equal(t, []LogOutput{Console}, logger.GetOutput()) } +// TestGetOutputWithMultipleLoggers tests the GetOutput function with multiple loggers. +func TestGetOutputWithMultipleLoggers(t *testing.T) { + logger := Logger{Output: []string{"console", "file"}} + assert.Equal(t, []LogOutput{Console, File}, logger.GetOutput()) +} + // TestGetPlugins tests the GetPlugins function. func TestGetPlugins(t *testing.T) { plugin := Plugin{Name: "plugin1"} @@ -30,7 +37,8 @@ func TestFilter(t *testing.T) { // Load config from the default config file. conf := NewConfig(context.TODO(), Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - conf.InitConfig(context.TODO()) + err := conf.InitConfig(context.TODO()) + require.Nil(t, err) assert.NotEmpty(t, conf.Global) // Filter the config. @@ -43,3 +51,17 @@ func TestFilter(t *testing.T) { assert.Contains(t, defaultGroup.Metrics, Default) assert.Contains(t, defaultGroup.Loggers, Default) } + +// TestFilterWithMissingGroupName tests the Filter function with a missing group name. +func TestFilterWithMissingGroupName(t *testing.T) { + // Load config from the default config file. + conf := NewConfig(context.TODO(), + Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) + err := conf.InitConfig(context.TODO()) + require.Nil(t, err) + assert.NotEmpty(t, conf.Global) + + // Filter the config. + defaultGroup := conf.Global.Filter("missing") + assert.Empty(t, defaultGroup) +} diff --git a/config/getters_unix_test.go b/config/getters_unix_test.go new file mode 100644 index 00000000..7b073b48 --- /dev/null +++ b/config/getters_unix_test.go @@ -0,0 +1,24 @@ +//go:build !windows +// +build !windows + +package config + +import ( + "log/syslog" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestSyslogPriority tests the GetSyslogPriority function. +func TestSyslogPriority(t *testing.T) { + logger := Logger{SyslogPriority: "warning"} + assert.Equal(t, logger.GetSyslogPriority(), syslog.LOG_DAEMON|syslog.LOG_WARNING) +} + +// TestSyslogPriorityWithInvalidPriority tests the GetSyslogPriority function with +// an invalid priority, which should return the default priority. +func TestSyslogPriorityWithInvalidPriority(t *testing.T) { + logger := Logger{SyslogPriority: "invalid"} + assert.Equal(t, logger.GetSyslogPriority(), syslog.LOG_DAEMON|syslog.LOG_INFO) +} diff --git a/config/testdata/missing_keys.yaml b/config/testdata/missing_keys.yaml new file mode 100644 index 00000000..d2311a7d --- /dev/null +++ b/config/testdata/missing_keys.yaml @@ -0,0 +1,45 @@ +# GatewayD Global Configuration + +loggers: + default: + level: info + output: ["console"] + noColor: True + # The "test" key is missing in the testdata file to test validation + # test: + # level: info + # output: ["console"] + # noColor: True + +metrics: + default: + enabled: True + test: + enabled: True + +clients: + default: + address: localhost:5432 + test: + address: localhost:5433 + +pools: + default: + size: 10 + test: + size: 10 + +proxies: + default: + healthCheckPeriod: 60s # duration + test: + healthCheckPeriod: 60s # duration + +servers: + default: + address: 0.0.0.0:15432 + test: + address: 0.0.0.0:15433 + +api: + enabled: True diff --git a/errors/errors.go b/errors/errors.go index b2f2a214..8078a6eb 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -48,6 +48,7 @@ const ( ErrCodeAsyncAction ErrCodeEvalError ErrCodeMsgEncodeError + ErrCodeConfigParseError ) var ( @@ -173,6 +174,10 @@ var ( ErrCodeMsgEncodeError, "error encoding message", nil, } + ErrConfigParseError = &GatewayDError{ + ErrCodeConfigParseError, "error parsing config", nil, + } + // Unwrapped errors. ErrLoggerRequired = errors.New("terminate action requires a logger parameter") ) diff --git a/logging/hclog_adapter_test.go b/logging/hclog_adapter_test.go index eaf35aa9..c734b9aa 100644 --- a/logging/hclog_adapter_test.go +++ b/logging/hclog_adapter_test.go @@ -124,22 +124,37 @@ func TestNewHcLogAdapter_Log(t *testing.T) { } func TestNewHcLogAdapter_GetLevel(t *testing.T) { - logger := NewLogger( - context.Background(), - LoggerConfig{ - Output: []config.LogOutput{config.Console}, - Level: zerolog.TraceLevel, - TimeFormat: zerolog.TimeFormatUnix, - NoColor: true, - }, - ) - - hcLogAdapter := NewHcLogAdapter(&logger, "test") - hcLogAdapter.SetLevel(hclog.Trace) - assert.Equal(t, hclog.Trace, hcLogAdapter.GetLevel()) - - hcLogAdapter.SetLevel(hclog.Debug) - assert.Equal(t, hclog.Debug, hcLogAdapter.GetLevel()) - assert.NotEqual(t, zerolog.DebugLevel, logger.GetLevel(), - "The logger should not be affected by the hclog adapter's loggerLevel") + levels := map[zerolog.Level]hclog.Level{ + zerolog.NoLevel: hclog.NoLevel, + zerolog.TraceLevel: hclog.Trace, + zerolog.DebugLevel: hclog.Debug, + zerolog.InfoLevel: hclog.Info, + zerolog.WarnLevel: hclog.Warn, + zerolog.ErrorLevel: hclog.Error, + zerolog.FatalLevel: hclog.Error, + zerolog.PanicLevel: hclog.Error, + zerolog.Disabled: hclog.Off, + } + + for zerologLevel, hclogLevel := range levels { + logger := NewLogger( + context.Background(), + LoggerConfig{ + Output: []config.LogOutput{config.Console}, + Level: zerologLevel, + TimeFormat: zerolog.TimeFormatUnix, + NoColor: true, + }, + ) + + hcLogAdapter := NewHcLogAdapter(&logger, "test") + hcLogAdapter.SetLevel(hclogLevel) + hcLogAdapter.Log(hclogLevel, "This is a message", "key", "value") + assert.Equal(t, hclogLevel, hcLogAdapter.GetLevel()) + + hcLogAdapter.SetLevel(hclog.Debug) + assert.Equal(t, hclog.Debug, hcLogAdapter.GetLevel()) + assert.NotEqual(t, zerolog.DebugLevel, logger.GetLevel(), + "The logger should not be affected by the hclog adapter's loggerLevel") + } } diff --git a/logging/logger_unix_test.go b/logging/logger_unix_test.go new file mode 100644 index 00000000..7620c3cf --- /dev/null +++ b/logging/logger_unix_test.go @@ -0,0 +1,44 @@ +//go:build !windows +// +build !windows + +package logging + +import ( + "context" + "log/syslog" + "testing" + "time" + + "github.com/gatewayd-io/gatewayd/config" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +func TestSyslogAndRsyslog(t *testing.T) { + go func() { + testServer("tcp", "127.0.0.1:1514") + }() + + logger := NewLogger( + context.Background(), + LoggerConfig{ + Output: []config.LogOutput{config.Syslog, config.RSyslog}, + TimeFormat: zerolog.TimeFormatUnix, + Level: zerolog.WarnLevel, + NoColor: true, + ConsoleTimeFormat: time.RFC3339, + RSyslogNetwork: "tcp", + RSyslogAddress: "localhost:1514", + SyslogPriority: syslog.LOG_DAEMON | syslog.LOG_WARNING, + FileName: "", + MaxSize: 0, + MaxBackups: 0, + MaxAge: 0, + Compress: false, + LocalTime: false, + ConsoleOut: nil, + Name: config.Default, + }, + ) + assert.NotNil(t, logger) +} diff --git a/logging/logging_helpers_test.go b/logging/logging_helpers_test.go new file mode 100644 index 00000000..29ba6212 --- /dev/null +++ b/logging/logging_helpers_test.go @@ -0,0 +1,40 @@ +package logging + +import ( + "bufio" + "log" + "net" +) + +func testServer(network, address string) { + // Create a new listener + listener, err := net.Listen(network, address) + if err != nil { + log.Println(err) + return + } + + // Accept incoming connections + conn, err := listener.Accept() + if err != nil { + log.Println(err) + return + } + + // Handle the connection + handleConnection(conn) +} + +func handleConnection(conn net.Conn) { + // Close the connection + defer conn.Close() + + // Create a new scanner + scanner := bufio.NewScanner(conn) + + // Scan the connection + for scanner.Scan() { + // Print the scanned text + log.Println(scanner.Text()) + } +} diff --git a/metrics/merger_test.go b/metrics/merger_test.go index 4cce31e9..cc8e9130 100644 --- a/metrics/merger_test.go +++ b/metrics/merger_test.go @@ -53,4 +53,12 @@ func TestMerger(t *testing.T) { gatewayd_test_total{plugin="test"} 1` assert.Contains(t, string(merger.OutputMetrics), want) + + // Remove the plugin from the merger, thus stopping the metrics + // collection from the plugin. + merger.Remove("test") + + // Test start/stop of the merger scheduler. + go merger.Start() + go merger.Stop() } diff --git a/metrics/utils_test.go b/metrics/utils_test.go new file mode 100644 index 00000000..abb63d00 --- /dev/null +++ b/metrics/utils_test.go @@ -0,0 +1,42 @@ +package metrics + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_HeaderBypassResponseWriter(t *testing.T) { + testServer := httptest.NewServer( + http.HandlerFunc(func(writer http.ResponseWriter, _ *http.Request) { + respWriter := HeaderBypassResponseWriter{writer} + respWriter.WriteHeader(http.StatusBadGateway) // This is a no-op. + sent, err := respWriter.Write([]byte("Hello, World!")) + require.NoError(t, err) + assert.Equal(t, 13, sent) + }), + ) + defer testServer.Close() + + req, err := http.NewRequestWithContext( + context.Background(), http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + assert.NotNil(t, req) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + // The WriteHeader method intentionally does nothing, to prevent a bug + // in the merging metrics that causes the headers to be written twice, + // which results in an error: "http: superfluous response.WriteHeader call". + assert.NotEqual(t, http.StatusBadGateway, resp.StatusCode) + greeting, err := io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, "Hello, World!", string(greeting)) +} diff --git a/network/conn_wrapper_test.go b/network/conn_wrapper_test.go new file mode 100644 index 00000000..96ad5bae --- /dev/null +++ b/network/conn_wrapper_test.go @@ -0,0 +1,49 @@ +package network + +import ( + "crypto/tls" + "net" + "testing" + + "github.com/gatewayd-io/gatewayd/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_ConnWrapper_NoTLS tests that the ConnWrapper correctly wraps a net.Conn. +func Test_ConnWrapper_NoTLS(t *testing.T) { + server, client := net.Pipe() + require.NotNil(t, server) + require.NotNil(t, client) + + serverWrapper := NewConnWrapper(ConnWrapper{ + NetConn: server, + HandshakeTimeout: config.DefaultHandshakeTimeout, + }) + assert.Equal(t, server, serverWrapper.Conn()) + defer serverWrapper.Close() + assert.False(t, serverWrapper.IsTLSEnabled()) + assert.Equal(t, serverWrapper.LocalAddr(), server.LocalAddr()) + assert.Equal(t, serverWrapper.RemoteAddr(), server.RemoteAddr()) + + clientWrapper := NewConnWrapper(ConnWrapper{ + NetConn: client, + HandshakeTimeout: config.DefaultHandshakeTimeout, + }) + assert.Equal(t, client, clientWrapper.Conn()) + defer clientWrapper.Close() + assert.False(t, clientWrapper.IsTLSEnabled()) + assert.Equal(t, clientWrapper.LocalAddr(), client.LocalAddr()) + assert.Equal(t, clientWrapper.RemoteAddr(), client.RemoteAddr()) +} + +// Test_ConnWrapper_TLS tests that the CreateTLSConfig function correctly +// creates a TLS config given a certificate and a private key. +func Test_CreateTLSConfig(t *testing.T) { + tlsConfig, err := CreateTLSConfig( + "../cmd/testdata/localhost.crt", "../cmd/testdata/localhost.key") + require.NoError(t, err) + assert.Equal(t, tlsConfig.ClientAuth, tls.VerifyClientCertIfGiven) + assert.NotEmpty(t, tlsConfig.Certificates[0].Certificate) + assert.NotEmpty(t, tlsConfig.Certificates[0].PrivateKey) +} diff --git a/plugin/plugin_registry_test.go b/plugin/plugin_registry_test.go index 9bb81801..4a33d910 100644 --- a/plugin/plugin_registry_test.go +++ b/plugin/plugin_registry_test.go @@ -68,6 +68,14 @@ func TestPluginRegistry(t *testing.T) { instance := reg.Get(ident) assert.Equal(t, instance, impl) + assert.Equal(t, reg.Size(), 1) + assert.True(t, reg.Exists(ident.Name, ident.Version, ident.RemoteURL)) + + reg.ForEach(func(i sdkPlugin.Identifier, p *Plugin) { + assert.Equal(t, i, ident) + assert.Equal(t, p, impl) + }) + reg.Remove(ident) assert.Empty(t, reg.List()) diff --git a/tracing/tracing_test.go b/tracing/tracing_test.go new file mode 100644 index 00000000..59369ed3 --- /dev/null +++ b/tracing/tracing_test.go @@ -0,0 +1,13 @@ +package tracing + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// Test_OTLPTracer tests the OTLPTracer function. +func Test_OTLPTracer(t *testing.T) { + shutdown := OTLPTracer(false, "localhost:4317", "gatewayd") + require.NotNil(t, shutdown) +}