diff --git a/gatewayd_plugin.yaml b/gatewayd_plugin.yaml index fc95bd0..80273c5 100644 --- a/gatewayd_plugin.yaml +++ b/gatewayd_plugin.yaml @@ -27,10 +27,7 @@ plugins: - METRICS_ENABLED=True - METRICS_UNIX_DOMAIN_SOCKET=/tmp/gatewayd-plugin-sql-ids-ips.sock - METRICS_PATH=/metrics - - TOKENIZER_API_ADDRESS=http://localhost:8000 - - SERVING_API_ADDRESS=http://localhost:8501 - - MODEL_NAME=sqli_model - - MODEL_VERSION=3 + - PREDICTION_API_ADDRESS=http://localhost:8000 # Threshold determine the minimum prediction confidence # required to detect an SQL injection attack. Any value # between 0 and 1 is valid, and it is inclusive. diff --git a/main.go b/main.go index 3820587..740ae6d 100644 --- a/main.go +++ b/main.go @@ -54,10 +54,7 @@ func main() { pluginInstance.Impl.EnableLibinjection = cast.ToBool(cfg["enableLibinjection"]) pluginInstance.Impl.LibinjectionPermissiveMode = cast.ToBool( cfg["libinjectionPermissiveMode"]) - pluginInstance.Impl.TokenizerAPIAddress = cast.ToString(cfg["tokenizerAPIAddress"]) - pluginInstance.Impl.ServingAPIAddress = cast.ToString(cfg["servingAPIAddress"]) - pluginInstance.Impl.ModelName = cast.ToString(cfg["modelName"]) - pluginInstance.Impl.ModelVersion = cast.ToString(cfg["modelVersion"]) + pluginInstance.Impl.PredictionAPIAddress = cast.ToString(cfg["predictionAPIAddress"]) pluginInstance.Impl.ResponseType = cast.ToString(cfg["responseType"]) pluginInstance.Impl.ErrorMessage = cast.ToString(cfg["errorMessage"]) diff --git a/plugin/constants.go b/plugin/constants.go index d2ff29b..2897925 100644 --- a/plugin/constants.go +++ b/plugin/constants.go @@ -3,12 +3,11 @@ package plugin const ( DecodedQueryField string = "decodedQuery" DetectorField string = "detector" - ScoreField string = "score" QueryField string = "query" ErrorField string = "error" IsInjectionField string = "is_injection" ResponseField string = "response" - OutputsField string = "outputs" + ConfidenceField string = "confidence" TokensField string = "tokens" StringField string = "String" ResponseTypeField string = "response_type" @@ -23,6 +22,5 @@ const ( ErrorDetail string = "Back off, you're not welcome here." LogLevel string = "error" - TokenizeAndSequencePath string = "/tokenize_and_sequence" - PredictPath string = "/v1/models/%s/versions/%s:predict" + PredictPath string = "/predict" ) diff --git a/plugin/module.go b/plugin/module.go index 0f88954..09de3d8 100644 --- a/plugin/module.go +++ b/plugin/module.go @@ -36,12 +36,8 @@ var ( "metricsUnixDomainSocket": sdkConfig.GetEnv( "METRICS_UNIX_DOMAIN_SOCKET", "/tmp/gatewayd-plugin-sql-ids-ips.sock"), "metricsEndpoint": sdkConfig.GetEnv("METRICS_ENDPOINT", "/metrics"), - "tokenizerAPIAddress": sdkConfig.GetEnv( - "TOKENIZER_API_ADDRESS", "http://localhost:8000"), - "servingAPIAddress": sdkConfig.GetEnv( - "SERVING_API_ADDRESS", "http://localhost:8501"), - "modelName": sdkConfig.GetEnv("MODEL_NAME", "sqli_model"), - "modelVersion": sdkConfig.GetEnv("MODEL_VERSION", "1"), + "predictionAPIAddress": sdkConfig.GetEnv( + "PREDICTION_API_ADDRESS", "http://localhost:8000"), "threshold": sdkConfig.GetEnv("THRESHOLD", "0.8"), "enableLibinjection": sdkConfig.GetEnv("ENABLE_LIBINJECTION", "true"), "libinjectionPermissiveMode": sdkConfig.GetEnv("LIBINJECTION_MODE", "true"), diff --git a/plugin/plugin.go b/plugin/plugin.go index 0cc72db..2be7da1 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "fmt" "github.com/carlmjohnson/requests" "github.com/corazawaf/libinjection-go" @@ -28,10 +27,7 @@ type Plugin struct { Threshold float32 EnableLibinjection bool LibinjectionPermissiveMode bool - TokenizerAPIAddress string - ServingAPIAddress string - ModelName string - ModelVersion string + PredictionAPIAddress string ResponseType string ErrorMessage string ErrorSeverity string @@ -111,36 +107,12 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S } queryString := cast.ToString(queryMap[StringField]) - var tokens map[string]any - err = requests. - URL(p.TokenizerAPIAddress). - Path(TokenizeAndSequencePath). - BodyJSON(map[string]any{ - QueryField: queryString, - }). - ToJSON(&tokens). - Fetch(context.Background()) - if err != nil { - p.Logger.Error("Failed to make POST request", ErrorField, err) - if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode { - return p.prepareResponse( - req, - map[string]any{ - QueryField: queryString, - DetectorField: Libinjection, - ErrorField: "Failed to make POST request to tokenizer API", - }, - ), nil - } - return req, nil - } - var output map[string]any err = requests. - URL(p.ServingAPIAddress). - Path(fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion)). + URL(p.PredictionAPIAddress). + Path(PredictPath). BodyJSON(map[string]any{ - "inputs": []any{cast.ToSlice(tokens[TokensField])}, + QueryField: queryString, }). ToJSON(&output). Fetch(context.Background()) @@ -152,34 +124,32 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S map[string]any{ QueryField: queryString, DetectorField: Libinjection, - ErrorField: "Failed to make POST request to serving API", + ErrorField: "Failed to make POST request to tokenizer API", }, ), nil } return req, nil } - predictions := cast.ToSlice(output[OutputsField]) - scores := cast.ToSlice(predictions[0]) - score := cast.ToFloat32(scores[0]) - p.Logger.Trace("Deep learning model prediction", ScoreField, score) + confidence := cast.ToFloat32(output[ConfidenceField]) + p.Logger.Trace("Deep learning model prediction", ConfidenceField, confidence) // Check the prediction against the threshold, // otherwise check if the query is an SQL injection using libinjection. injection := p.isSQLi(queryString) - if score >= p.Threshold { + if confidence >= p.Threshold { if p.EnableLibinjection && !injection { p.Logger.Debug("False positive detected", DetectorField, Libinjection) } Detections.With(map[string]string{DetectorField: DeepLearningModel}).Inc() - p.Logger.Warn(p.ErrorMessage, ScoreField, score, DetectorField, DeepLearningModel) + p.Logger.Warn(p.ErrorMessage, ConfidenceField, confidence, DetectorField, DeepLearningModel) return p.prepareResponse( req, map[string]any{ - QueryField: queryString, - ScoreField: score, - DetectorField: DeepLearningModel, + QueryField: queryString, + ConfidenceField: confidence, + DetectorField: DeepLearningModel, }, ), nil } else if p.EnableLibinjection && injection && !p.LibinjectionPermissiveMode { diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index 4ed0eec..fdcb19f 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -3,7 +3,6 @@ package plugin import ( "context" "encoding/json" - "fmt" "net/http" "net/http/httptest" "testing" @@ -71,28 +70,13 @@ func Test_errorResponse(t *testing.T) { func Test_OnTrafficFromClinet(t *testing.T) { p := &Plugin{ - Logger: hclog.NewNullLogger(), - ModelName: "sqli_model", - ModelVersion: "2", + Logger: hclog.NewNullLogger(), } server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case TokenizeAndSequencePath: - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - // This is the tokenized query: - // {"query":"select * from users where id = 1 or 1=1"} - resp := map[string][]float32{ - "tokens": { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12, - }, - } - data, _ := json.Marshal(resp) - _, err := w.Write(data) - require.NoError(t, err) - case fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion): + case PredictPath: w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") // This is the output of the deep learning model. @@ -107,8 +91,7 @@ func Test_OnTrafficFromClinet(t *testing.T) { ) defer server.Close() - p.TokenizerAPIAddress = server.URL - p.ServingAPIAddress = server.URL + p.PredictionAPIAddress = server.URL query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} queryBytes, err := query.Encode(nil) @@ -136,17 +119,13 @@ func Test_OnTrafficFromClinet(t *testing.T) { func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) { plugins := []*Plugin{ { - Logger: hclog.NewNullLogger(), - ModelName: "sqli_model", - ModelVersion: "2", + Logger: hclog.NewNullLogger(), // If libinjection is enabled, the response should contain the "response" field, // and the "signals" field, which means the plugin will terminate the request. EnableLibinjection: true, }, { - Logger: hclog.NewNullLogger(), - ModelName: "sqli_model", - ModelVersion: "2", + Logger: hclog.NewNullLogger(), // If libinjection is disabled, the response should not contain the "response" field, // and the "signals" field, which means the plugin will not terminate the request. EnableLibinjection: false, @@ -156,7 +135,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) { server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case TokenizeAndSequencePath: + case PredictPath: w.WriteHeader(http.StatusInternalServerError) default: w.WriteHeader(http.StatusNotFound) @@ -166,8 +145,7 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) { defer server.Close() for i := range plugins { - plugins[i].TokenizerAPIAddress = server.URL - plugins[i].ServingAPIAddress = server.URL + plugins[i].PredictionAPIAddress = server.URL query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} queryBytes, err := query.Encode(nil) @@ -204,43 +182,22 @@ func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) { func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) { plugins := []*Plugin{ { - Logger: hclog.NewNullLogger(), - ModelName: "sqli_model", - ModelVersion: "2", + Logger: hclog.NewNullLogger(), // If libinjection is disabled, the response should not contain the "response" field, // and the "signals" field, which means the plugin will not terminate the request. EnableLibinjection: false, }, { - Logger: hclog.NewNullLogger(), - ModelName: "sqli_model", - ModelVersion: "2", + Logger: hclog.NewNullLogger(), // If libinjection is enabled, the response should contain the "response" field, // and the "signals" field, which means the plugin will terminate the request. EnableLibinjection: true, }, } - - // This is the same for both plugins. - predictPath := fmt.Sprintf(PredictPath, plugins[0].ModelName, plugins[1].ModelVersion) - server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case TokenizeAndSequencePath: - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - // This is the tokenized query: - // {"query":"select * from users where id = 1 or 1=1"} - resp := map[string][]float32{ - "tokens": { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12, - }, - } - data, _ := json.Marshal(resp) - _, err := w.Write(data) - require.NoError(t, err) - case predictPath: + case PredictPath: w.WriteHeader(http.StatusInternalServerError) default: w.WriteHeader(http.StatusNotFound) @@ -250,8 +207,7 @@ func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) { defer server.Close() for i := range plugins { - plugins[i].TokenizerAPIAddress = server.URL - plugins[i].ServingAPIAddress = server.URL + plugins[i].PredictionAPIAddress = server.URL query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"} queryBytes, err := query.Encode(nil)