From 0b56d1ec09ddd56e61a18cf79990797b6cbf77f5 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 21:24:24 +0200 Subject: [PATCH 1/3] Check for nil or empty result --- act/registry.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/act/registry.go b/act/registry.go index 68fe97ab..e3a0b9c6 100644 --- a/act/registry.go +++ b/act/registry.go @@ -407,6 +407,10 @@ func runActionWithTimeout( // RunAll run all the actions in the outputs and returns the end result. func (r *Registry) RunAll(result map[string]any) map[string]any { + if result == nil || len(result) == 0 { + return result + } + if _, exists := result[sdkAct.Outputs]; !exists { r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is") return result From 2e498e91b434d7a0c3ad0ef3518f10b20f338633 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 21:25:09 +0200 Subject: [PATCH 2/3] Use the result of notification hooks by running actions --- cmd/run.go | 33 +++++++++++++++++++++++++++------ network/proxy.go | 16 +++++++++++++--- network/server.go | 47 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 78 insertions(+), 18 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 98a96061..bf992467 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -109,7 +109,7 @@ func StopGracefully( defer cancel() //nolint:contextcheck - _, err := pluginRegistry.Run( + result, err := pluginRegistry.Run( pluginTimeoutCtx, map[string]any{"signal": currentSignal}, v1.HookName_HOOK_NAME_ON_SIGNAL, @@ -118,6 +118,9 @@ func StopGracefully( logger.Error().Err(err).Msg("Failed to run OnSignal hooks") span.RecordError(err) } + if result != nil { + _ = pluginRegistry.ActRegistry.RunAll(result) + } } logger.Info().Msg("GatewayD is shutting down") @@ -434,6 +437,9 @@ var runCmd = &cobra.Command{ logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") span.RecordError(err) } + if updatedGlobalConfig != nil { + updatedGlobalConfig = pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig) + } // If the config was modified by the plugins, merge it with the one loaded from the file. // Only global configuration is merged, which means that plugins cannot modify the plugin @@ -606,12 +612,15 @@ var runCmd = &cobra.Command{ defer cancel() if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]any); ok { - _, err = pluginRegistry.Run( + result, err := pluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") span.RecordError(err) } + if result != nil { + _ = pluginRegistry.ActRegistry.RunAll(result) + } } else { logger.Error().Msg("Failed to get loggers from config") } @@ -767,12 +776,15 @@ var runCmd = &cobra.Command{ "backoffMultiplier": clientConfig.BackoffMultiplier, "disableBackoffCaps": clientConfig.DisableBackoffCaps, } - _, err := pluginRegistry.Run( + result, err := pluginRegistry.Run( pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") span.RecordError(err) } + if result != nil { + _ = pluginRegistry.ActRegistry.RunAll(result) + } err = pools[configGroupName][configBlockName].Put(client.ID, client) if err != nil { @@ -822,7 +834,7 @@ var runCmd = &cobra.Command{ context.Background(), conf.Plugin.Timeout) defer cancel() - _, err = pluginRegistry.Run( + result, err := pluginRegistry.Run( pluginTimeoutCtx, map[string]any{"name": configBlockName, "size": currentPoolSize}, v1.HookName_HOOK_NAME_ON_NEW_POOL) @@ -830,6 +842,9 @@ var runCmd = &cobra.Command{ logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") span.RecordError(err) } + if result != nil { + _ = pluginRegistry.ActRegistry.RunAll(result) + } } } @@ -877,12 +892,15 @@ var runCmd = &cobra.Command{ defer cancel() if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]any); ok { - _, err = pluginRegistry.Run( + result, err := pluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") span.RecordError(err) } + if result != nil { + _ = pluginRegistry.ActRegistry.RunAll(result) + } } else { logger.Error().Msg("Failed to get proxy from config") } @@ -948,12 +966,15 @@ var runCmd = &cobra.Command{ defer cancel() if data, ok := conf.GlobalKoanf.Get("servers").(map[string]any); ok { - _, err = pluginRegistry.Run( + result, err := pluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewServer hooks") span.RecordError(err) } + if result != nil { + _ = pluginRegistry.ActRegistry.RunAll(result) + } } else { logger.Error().Msg("Failed to get the servers configuration") } diff --git a/network/proxy.go b/network/proxy.go index cdf9b71f..6d60b310 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -446,7 +446,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate defer cancel() // Run the OnTrafficToServer hooks. - _, err = pr.PluginRegistry.Run( + result, err = pr.PluginRegistry.Run( pluginTimeoutCtx, trafficData( conn.Conn(), @@ -463,8 +463,11 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } - span.AddEvent("Ran the OnTrafficToServer hooks") + if result != nil { + _ = pr.PluginRegistry.ActRegistry.RunAll(result) + } + span.AddEvent("Ran the OnTrafficToServer hooks") metrics.ProxyPassThroughsToServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc() return nil @@ -558,6 +561,9 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } + if result != nil { + result = pr.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnTrafficFromServer hooks") // If the hook modified the response, use the modified response. @@ -575,7 +581,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout) defer cancel() - _, err = pr.PluginRegistry.Run( + result, err = pr.PluginRegistry.Run( pluginTimeoutCtx, trafficData( conn.Conn(), @@ -597,6 +603,10 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } + if result != nil { + _ = pr.PluginRegistry.ActRegistry.RunAll(result) + } + span.AddEvent("Ran the OnTrafficToClient hooks") if errVerdict != nil { span.RecordError(errVerdict) diff --git a/network/server.go b/network/server.go index 60af37ac..a6e98b88 100644 --- a/network/server.go +++ b/network/server.go @@ -98,7 +98,7 @@ func (s *Server) OnBoot() Action { pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnBooting hooks. - _, err := s.PluginRegistry.Run( + result, err := s.PluginRegistry.Run( pluginTimeoutCtx, map[string]any{"status": fmt.Sprint(s.Status)}, v1.HookName_HOOK_NAME_ON_BOOTING) @@ -106,6 +106,9 @@ func (s *Server) OnBoot() Action { s.Logger.Error().Err(err).Msg("Failed to run OnBooting hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnBooting hooks") // Set the server status to running. @@ -117,7 +120,7 @@ func (s *Server) OnBoot() Action { pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() - _, err = s.PluginRegistry.Run( + result, err = s.PluginRegistry.Run( pluginTimeoutCtx, map[string]any{"status": fmt.Sprint(s.Status)}, v1.HookName_HOOK_NAME_ON_BOOTED) @@ -125,6 +128,9 @@ func (s *Server) OnBoot() Action { s.Logger.Error().Err(err).Msg("Failed to run OnBooted hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnBooted hooks") s.Logger.Debug().Msg("GatewayD booted") @@ -150,12 +156,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { "remote": RemoteAddr(conn.Conn()), }, } - _, err := s.PluginRegistry.Run( + result, err := s.PluginRegistry.Run( pluginTimeoutCtx, onOpeningData, v1.HookName_HOOK_NAME_ON_OPENING) if err != nil { s.Logger.Error().Err(err).Msg("Failed to run OnOpening hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnOpening hooks") // Attempt to retrieve the next proxy. @@ -195,12 +204,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { "remote": RemoteAddr(conn.Conn()), }, } - _, err = s.PluginRegistry.Run( + result, err = s.PluginRegistry.Run( pluginTimeoutCtx, onOpenedData, v1.HookName_HOOK_NAME_ON_OPENED) if err != nil { s.Logger.Error().Err(err).Msg("Failed to run OnOpened hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnOpened hooks") metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Inc() @@ -231,12 +243,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr := s.PluginRegistry.Run( + result, gatewaydErr := s.PluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSING) if gatewaydErr != nil { s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook") span.RecordError(gatewaydErr) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnClosing hooks") // Shutdown the server if there are no more connections and the server is stopped. @@ -291,12 +306,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr = s.PluginRegistry.Run( + result, gatewaydErr = s.PluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSED) if gatewaydErr != nil { s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook") span.RecordError(gatewaydErr) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnClosed hooks") metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Dec() @@ -320,12 +338,15 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti "remote": RemoteAddr(conn.Conn()), }, } - _, err := s.PluginRegistry.Run( + result, err := s.PluginRegistry.Run( pluginTimeoutCtx, onTrafficData, v1.HookName_HOOK_NAME_ON_TRAFFIC) if err != nil { s.Logger.Error().Err(err).Msg("Failed to run OnTraffic hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnTraffic hooks") stack := NewStack() @@ -391,7 +412,7 @@ func (s *Server) OnShutdown() { pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnShutdown hooks. - _, err := s.PluginRegistry.Run( + result, err := s.PluginRegistry.Run( pluginTimeoutCtx, map[string]any{"connections": s.CountConnections()}, v1.HookName_HOOK_NAME_ON_SHUTDOWN) @@ -399,6 +420,9 @@ func (s *Server) OnShutdown() { s.Logger.Error().Err(err).Msg("Failed to run OnShutdown hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnShutdown hooks") // Shutdown proxies. @@ -424,7 +448,7 @@ func (s *Server) OnTick() (time.Duration, Action) { pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnTick hooks. - _, err := s.PluginRegistry.Run( + result, err := s.PluginRegistry.Run( pluginTimeoutCtx, map[string]any{"connections": s.CountConnections()}, v1.HookName_HOOK_NAME_ON_TICK) @@ -432,6 +456,9 @@ func (s *Server) OnTick() (time.Duration, Action) { s.Logger.Error().Err(err).Msg("Failed to run OnTick hook") span.RecordError(err) } + if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + } span.AddEvent("Ran the OnTick hooks") // TODO: Investigate whether to move schedulers here or not @@ -474,6 +501,8 @@ func (s *Server) Run() *gerr.GatewayDError { span.AddEvent("Ran the OnRun hooks") if result != nil { + _ = s.PluginRegistry.ActRegistry.RunAll(result) + if errMsg, ok := result["error"].(string); ok && errMsg != "" { s.Logger.Error().Str("error", errMsg).Msg("Error in hook") } From dc5759a4893a5182c0b8bd474da07249a0a80e7c Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 21:51:55 +0200 Subject: [PATCH 3/3] Fix linter errors --- act/registry.go | 2 +- cmd/run.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/act/registry.go b/act/registry.go index e3a0b9c6..04b093a9 100644 --- a/act/registry.go +++ b/act/registry.go @@ -407,7 +407,7 @@ func runActionWithTimeout( // RunAll run all the actions in the outputs and returns the end result. func (r *Registry) RunAll(result map[string]any) map[string]any { - if result == nil || len(result) == 0 { + if len(result) == 0 { return result } diff --git a/cmd/run.go b/cmd/run.go index bf992467..e8007f01 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -119,7 +119,7 @@ func StopGracefully( span.RecordError(err) } if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) + _ = pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck } }