Skip to content

Commit

Permalink
Use the result of notification hooks by running actions
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Oct 7, 2024
1 parent 0b56d1e commit 2e498e9
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 18 deletions.
33 changes: 27 additions & 6 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func StopGracefully(
defer cancel()

//nolint:contextcheck
_, err := pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"signal": currentSignal},
v1.HookName_HOOK_NAME_ON_SIGNAL,
Expand All @@ -118,6 +118,9 @@ func StopGracefully(
logger.Error().Err(err).Msg("Failed to run OnSignal hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)

Check failure on line 122 in cmd/run.go

View workflow job for this annotation

GitHub Actions / Test GatewayD

Function `RunAll->Run->publishTask` should pass the context parameter (contextcheck)
}
}

logger.Info().Msg("GatewayD is shutting down")
Expand Down Expand Up @@ -434,6 +437,9 @@ var runCmd = &cobra.Command{
logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks")
span.RecordError(err)
}
if updatedGlobalConfig != nil {
updatedGlobalConfig = pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig)
}

// If the config was modified by the plugins, merge it with the one loaded from the file.
// Only global configuration is merged, which means that plugins cannot modify the plugin
Expand Down Expand Up @@ -606,12 +612,15 @@ var runCmd = &cobra.Command{
defer cancel()

if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]any); ok {
_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
} else {
logger.Error().Msg("Failed to get loggers from config")
}
Expand Down Expand Up @@ -767,12 +776,15 @@ var runCmd = &cobra.Command{
"backoffMultiplier": clientConfig.BackoffMultiplier,
"disableBackoffCaps": clientConfig.DisableBackoffCaps,
}
_, err := pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewClient hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}

err = pools[configGroupName][configBlockName].Put(client.ID, client)
if err != nil {
Expand Down Expand Up @@ -822,14 +834,17 @@ var runCmd = &cobra.Command{
context.Background(), conf.Plugin.Timeout)
defer cancel()

_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"name": configBlockName, "size": currentPoolSize},
v1.HookName_HOOK_NAME_ON_NEW_POOL)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewPool hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
}
}

Expand Down Expand Up @@ -877,12 +892,15 @@ var runCmd = &cobra.Command{
defer cancel()

if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
} else {
logger.Error().Msg("Failed to get proxy from config")
}
Expand Down Expand Up @@ -948,12 +966,15 @@ var runCmd = &cobra.Command{
defer cancel()

if data, ok := conf.GlobalKoanf.Get("servers").(map[string]any); ok {
_, err = pluginRegistry.Run(
result, err := pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER)
if err != nil {
logger.Error().Err(err).Msg("Failed to run OnNewServer hooks")
span.RecordError(err)
}
if result != nil {
_ = pluginRegistry.ActRegistry.RunAll(result)
}
} else {
logger.Error().Msg("Failed to get the servers configuration")
}
Expand Down
16 changes: 13 additions & 3 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate
defer cancel()

// Run the OnTrafficToServer hooks.
_, err = pr.PluginRegistry.Run(
result, err = pr.PluginRegistry.Run(
pluginTimeoutCtx,
trafficData(
conn.Conn(),
Expand All @@ -463,8 +463,11 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate
pr.Logger.Error().Err(err).Msg("Error running hook")
span.RecordError(err)
}
span.AddEvent("Ran the OnTrafficToServer hooks")
if result != nil {
_ = pr.PluginRegistry.ActRegistry.RunAll(result)
}

span.AddEvent("Ran the OnTrafficToServer hooks")
metrics.ProxyPassThroughsToServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()

return nil
Expand Down Expand Up @@ -558,6 +561,9 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
pr.Logger.Error().Err(err).Msg("Error running hook")
span.RecordError(err)
}
if result != nil {
result = pr.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTrafficFromServer hooks")

// If the hook modified the response, use the modified response.
Expand All @@ -575,7 +581,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout)
defer cancel()

_, err = pr.PluginRegistry.Run(
result, err = pr.PluginRegistry.Run(
pluginTimeoutCtx,
trafficData(
conn.Conn(),
Expand All @@ -597,6 +603,10 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
pr.Logger.Error().Err(err).Msg("Error running hook")
span.RecordError(err)
}
if result != nil {
_ = pr.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTrafficToClient hooks")

if errVerdict != nil {
span.RecordError(errVerdict)
Expand Down
47 changes: 38 additions & 9 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,17 @@ func (s *Server) OnBoot() Action {
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()
// Run the OnBooting hooks.
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"status": fmt.Sprint(s.Status)},
v1.HookName_HOOK_NAME_ON_BOOTING)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnBooting hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnBooting hooks")

// Set the server status to running.
Expand All @@ -117,14 +120,17 @@ func (s *Server) OnBoot() Action {
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()

_, err = s.PluginRegistry.Run(
result, err = s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"status": fmt.Sprint(s.Status)},
v1.HookName_HOOK_NAME_ON_BOOTED)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnBooted hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnBooted hooks")

s.Logger.Debug().Msg("GatewayD booted")
Expand All @@ -150,12 +156,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
"remote": RemoteAddr(conn.Conn()),
},
}
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx, onOpeningData, v1.HookName_HOOK_NAME_ON_OPENING)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnOpening hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnOpening hooks")

// Attempt to retrieve the next proxy.
Expand Down Expand Up @@ -195,12 +204,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
"remote": RemoteAddr(conn.Conn()),
},
}
_, err = s.PluginRegistry.Run(
result, err = s.PluginRegistry.Run(
pluginTimeoutCtx, onOpenedData, v1.HookName_HOOK_NAME_ON_OPENED)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnOpened hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnOpened hooks")

metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Inc()
Expand Down Expand Up @@ -231,12 +243,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action {
if err != nil {
data["error"] = err.Error()
}
_, gatewaydErr := s.PluginRegistry.Run(
result, gatewaydErr := s.PluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSING)
if gatewaydErr != nil {
s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook")
span.RecordError(gatewaydErr)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnClosing hooks")

// Shutdown the server if there are no more connections and the server is stopped.
Expand Down Expand Up @@ -291,12 +306,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action {
if err != nil {
data["error"] = err.Error()
}
_, gatewaydErr = s.PluginRegistry.Run(
result, gatewaydErr = s.PluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSED)
if gatewaydErr != nil {
s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook")
span.RecordError(gatewaydErr)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnClosed hooks")

metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Dec()
Expand All @@ -320,12 +338,15 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti
"remote": RemoteAddr(conn.Conn()),
},
}
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx, onTrafficData, v1.HookName_HOOK_NAME_ON_TRAFFIC)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnTraffic hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTraffic hooks")

stack := NewStack()
Expand Down Expand Up @@ -391,14 +412,17 @@ func (s *Server) OnShutdown() {
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()
// Run the OnShutdown hooks.
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"connections": s.CountConnections()},
v1.HookName_HOOK_NAME_ON_SHUTDOWN)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnShutdown hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnShutdown hooks")

// Shutdown proxies.
Expand All @@ -424,14 +448,17 @@ func (s *Server) OnTick() (time.Duration, Action) {
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
defer cancel()
// Run the OnTick hooks.
_, err := s.PluginRegistry.Run(
result, err := s.PluginRegistry.Run(
pluginTimeoutCtx,
map[string]any{"connections": s.CountConnections()},
v1.HookName_HOOK_NAME_ON_TICK)
if err != nil {
s.Logger.Error().Err(err).Msg("Failed to run OnTick hook")
span.RecordError(err)
}
if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)
}
span.AddEvent("Ran the OnTick hooks")

// TODO: Investigate whether to move schedulers here or not
Expand Down Expand Up @@ -474,6 +501,8 @@ func (s *Server) Run() *gerr.GatewayDError {
span.AddEvent("Ran the OnRun hooks")

if result != nil {
_ = s.PluginRegistry.ActRegistry.RunAll(result)

if errMsg, ok := result["error"].(string); ok && errMsg != "" {
s.Logger.Error().Str("error", errMsg).Msg("Error in hook")
}
Expand Down

0 comments on commit 2e498e9

Please sign in to comment.