Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the result of notification hooks #616

Merged
merged 3 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ func runActionWithTimeout(

// RunAll run all the actions in the outputs and returns the end result.
func (r *Registry) RunAll(result map[string]any) map[string]any {
if len(result) == 0 {
return result
}

if _, exists := result[sdkAct.Outputs]; !exists {
r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is")
return result
Expand Down
33 changes: 27 additions & 6 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func StopGracefully(
defer cancel()

//nolint:contextcheck
_, err := pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"signal": currentSignal},
v1.HookName_HOOK_NAME_ON_SIGNAL,
Expand All @@ -118,6 +118,9 @@ func StopGracefully(
logger.Error().Err(err).Msg("Failed to run OnSignal hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
}
}

logger.Info().Msg("GatewayD is shutting down")
Expand Down Expand Up @@ -434,6 +437,9 @@ var runCmd = &cobra.Command{
logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks")
span.RecordError(err)
}
if updatedGlobalConfig != nil {
updatedGlobalConfig = pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig)
}

// If the config was modified by the plugins, merge it with the one loaded from the file.
// Only global configuration is merged, which means that plugins cannot modify the plugin
Expand Down Expand Up @@ -606,12 +612,15 @@ var runCmd = &cobra.Command{
defer cancel()

if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]any); ok {
_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
} else {
logger.Error().Msg("Failed to get loggers from config")
}
Expand Down Expand Up @@ -767,12 +776,15 @@ var runCmd = &cobra.Command{
"backoffMultiplier": clientConfig.BackoffMultiplier,
"disableBackoffCaps": clientConfig.DisableBackoffCaps,
}
_, err := pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewClient hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}

err = pools[configGroupName][configBlockName].Put(client.ID, client)
if err != nil {
Expand Down Expand Up @@ -822,14 +834,17 @@ var runCmd = &cobra.Command{
context.Background(), conf.Plugin.Timeout)
defer cancel()

_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"name": configBlockName, "size": currentPoolSize},
v1.HookName_HOOK_NAME_ON_NEW_POOL)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewPool hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
}
}

Expand Down Expand Up @@ -877,12 +892,15 @@ var runCmd = &cobra.Command{
defer cancel()

if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
} else {
logger.Error().Msg("Failed to get proxy from config")
}
Expand Down Expand Up @@ -948,12 +966,15 @@ var runCmd = &cobra.Command{
defer cancel()

if data, ok := conf.GlobalKoanf.Get("servers").(map[string]any); ok {
_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewServer hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
} else {
logger.Error().Msg("Failed to get the servers configuration")
}
Expand Down
16 changes: 13 additions & 3 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate
defer cancel()

// Run the OnTrafficToServer hooks.
_, err = pr.PluginRegistry.Run(
result, err = pr.PluginRegistry.Run(
pluginTimeoutCtx,
trafficData(
conn.Conn(),
Expand All @@ -463,8 +463,11 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate
pr.Logger.Error().Err(err).Msg("Error running hook")
span.RecordError(err)
}
span.AddEvent("Ran the OnTrafficToServer hooks")
if result != nil {
_ = pr.PluginRegistry.ActRegistry.RunAll(result)
}

span.AddEvent("Ran the OnTrafficToServer hooks")
metrics.ProxyPassThroughsToServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()

return nil
Expand Down Expand Up @@ -558,6 +561,9 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
pr.Logger.Error().Err(err).Msg("Error running hook")
span.RecordError(err)
}
if result != nil {
result = pr.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTrafficFromServer hooks")

// If the hook modified the response, use the modified response.
Expand All @@ -575,7 +581,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout)
defer cancel()

_, err = pr.PluginRegistry.Run(
result, err = pr.PluginRegistry.Run(
pluginTimeoutCtx,
trafficData(
conn.Conn(),
Expand All @@ -597,6 +603,10 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
pr.Logger.Error().Err(err).Msg("Error running hook")
span.RecordError(err)
}
if result != nil {
_ = pr.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTrafficToClient hooks")

if errVerdict != nil {
span.RecordError(errVerdict)
Expand Down
47 changes: 38 additions & 9 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,17 @@ func (s *Server) OnBoot() Action {
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()
// Run the OnBooting hooks.
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"status": fmt.Sprint(s.Status)},
v1.HookName_HOOK_NAME_ON_BOOTING)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnBooting hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnBooting hooks")

// Set the server status to running.
Expand All @@ -117,14 +120,17 @@ func (s *Server) OnBoot() Action {
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()

_, err = s.PluginRegistry.Run(
result, err = s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"status": fmt.Sprint(s.Status)},
v1.HookName_HOOK_NAME_ON_BOOTED)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnBooted hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnBooted hooks")

s.Logger.Debug().Msg("GatewayD booted")
Expand All @@ -150,12 +156,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
"remote": RemoteAddr(conn.Conn()),
},
}
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx, onOpeningData, v1.HookName_HOOK_NAME_ON_OPENING)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnOpening hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnOpening hooks")

// Attempt to retrieve the next proxy.
Expand Down Expand Up @@ -195,12 +204,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
"remote": RemoteAddr(conn.Conn()),
},
}
_, err = s.PluginRegistry.Run(
result, err = s.PluginRegistry.Run(
pluginTimeoutCtx, onOpenedData, v1.HookName_HOOK_NAME_ON_OPENED)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnOpened hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnOpened hooks")

metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Inc()
Expand Down Expand Up @@ -231,12 +243,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action {
if err != nil {
data["error"] = err.Error()
}
_, gatewaydErr := s.PluginRegistry.Run(
result, gatewaydErr := s.PluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSING)
if gatewaydErr != nil {
s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook")
span.RecordError(gatewaydErr)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnClosing hooks")

// Shutdown the server if there are no more connections and the server is stopped.
Expand Down Expand Up @@ -291,12 +306,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action {
if err != nil {
data["error"] = err.Error()
}
_, gatewaydErr = s.PluginRegistry.Run(
result, gatewaydErr = s.PluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSED)
if gatewaydErr != nil {
s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook")
span.RecordError(gatewaydErr)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnClosed hooks")

metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Dec()
Expand All @@ -320,12 +338,15 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti
"remote": RemoteAddr(conn.Conn()),
},
}
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx, onTrafficData, v1.HookName_HOOK_NAME_ON_TRAFFIC)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnTraffic hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTraffic hooks")

stack := NewStack()
Expand Down Expand Up @@ -391,14 +412,17 @@ func (s *Server) OnShutdown() {
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()
// Run the OnShutdown hooks.
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"connections": s.CountConnections()},
v1.HookName_HOOK_NAME_ON_SHUTDOWN)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnShutdown hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnShutdown hooks")

// Shutdown proxies.
Expand All @@ -424,14 +448,17 @@ func (s *Server) OnTick() (time.Duration, Action) {
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()
// Run the OnTick hooks.
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"connections": s.CountConnections()},
v1.HookName_HOOK_NAME_ON_TICK)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnTick hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTick hooks")

// TODO: Investigate whether to move schedulers here or not
Expand Down Expand Up @@ -474,6 +501,8 @@ func (s *Server) Run() *gerr.GatewayDError {
span.AddEvent("Ran the OnRun hooks")

if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)

if errMsg, ok := result["error"].(string); ok && errMsg != "" {
s.Logger.Error().Str("error", errMsg).Msg("Error in hook")
}
Expand Down