Skip to content

Commit

Permalink
Move custom fingerprinting logic from MLFlow credential tester to web…
Browse files Browse the repository at this point in the history
… service fingerprinter.

PiperOrigin-RevId: 634092796
Change-Id: I39c5d3c51254cee7222f09f9bcd77708aa5c15cd
  • Loading branch information
maoning authored and copybara-github committed May 15, 2024
1 parent cdcfeb9 commit fcbbcb8
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import com.google.tsunami.common.net.http.HttpClient;
import com.google.tsunami.common.net.http.HttpHeaders;
import com.google.tsunami.common.net.http.HttpResponse;
import com.google.tsunami.common.net.http.HttpStatus;
import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.provider.TestCredential;
import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.tester.CredentialTester;
import com.google.tsunami.proto.NetworkService;
Expand All @@ -44,6 +43,8 @@
/** Credential tester specifically for mlflow. */
public final class MlFlowCredentialTester extends CredentialTester {
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
private static final String MLFLOW_SERVICE = "mlflow";

private final HttpClient httpClient;

@Inject
Expand All @@ -63,39 +64,7 @@ public String description() {

@Override
public boolean canAccept(NetworkService networkService) {
if (!NetworkServiceUtils.isWebService(networkService)) {
return false;
}

boolean canAcceptByCustomFingerprint = false;
logger.atInfo().log("probing Mlflow ping - custom fingerprint phase");

// We want to test weak credentials against mlflow versions above 2.5 which has basic
// authentication module.these versions return a 401 status code and a link to documentation
// about how to authenticate.
var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint());
var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping");
try {
HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build());

if (apiPingResponse.status() == HttpStatus.UNAUTHORIZED
&& apiPingResponse.bodyString().isPresent()) {
canAcceptByCustomFingerprint =
apiPingResponse
.bodyString()
.get()
.contains(
"You are not authenticated. Please see "
+ "https://www.mlflow.org/docs/latest/auth/index.html"
+ "#authenticating-to-mlflow "
+ "on how to authenticate");
}
} catch (IOException e) {
logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl);
return false;
}

return canAcceptByCustomFingerprint;
return NetworkServiceUtils.getWebServiceName(networkService).equals(MLFLOW_SERVICE);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void detect_weakCredentialsExists_returnsWeakCredentials() throws Excepti
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.setServiceName("mlflow")
.build();

assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1)))
Expand All @@ -94,7 +94,7 @@ public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exc
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.setServiceName("mlflow")
.build();

assertThat(
Expand All @@ -104,13 +104,13 @@ public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exc
}

@Test
public void detect_canAccept() throws Exception {
public void detect_mlflowService_canAccept() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.setServiceName("mlflow")
.build();

assertThat(tester.canAccept(targetNetworkService)).isTrue();
Expand All @@ -124,7 +124,7 @@ public void detect_weakCredentialsExistAndMlflowInForeignLanguage_returnsFirstWe
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.setServiceName("mlflow")
.build();

assertThat(
Expand All @@ -140,7 +140,7 @@ public void detect_noWeakCredentials_returnsNoCredentials() throws Exception {
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.setServiceName("mlflow")
.build();
assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WRONG_CRED_1)))
.isEmpty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.tsunami.common.net.http.HttpRequest.get;
import static java.util.stream.Collectors.joining;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.GoogleLogger;
import com.google.tsunami.common.data.NetworkEndpointUtils;
import com.google.tsunami.common.data.NetworkServiceUtils;
import com.google.tsunami.common.net.http.HttpClient;
import com.google.tsunami.common.net.http.HttpResponse;
import com.google.tsunami.common.net.http.HttpStatus;
import com.google.tsunami.plugin.PluginType;
import com.google.tsunami.plugin.ServiceFingerprinter;
import com.google.tsunami.plugin.annotations.ForWebService;
Expand All @@ -36,6 +41,7 @@
import com.google.tsunami.plugins.fingerprinters.web.detection.SoftwareDetector.DetectedSoftware;
import com.google.tsunami.plugins.fingerprinters.web.detection.VersionDetector;
import com.google.tsunami.plugins.fingerprinters.web.detection.VersionDetector.DetectedVersion;
import com.google.tsunami.plugins.fingerprinters.web.proto.SoftwareIdentity;
import com.google.tsunami.proto.CrawlConfig;
import com.google.tsunami.proto.CrawlResult;
import com.google.tsunami.proto.FingerprintingReport;
Expand All @@ -47,8 +53,11 @@
import com.google.tsunami.proto.Version.VersionType;
import com.google.tsunami.proto.VersionSet;
import com.google.tsunami.proto.WebServiceContext;
import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import javax.inject.Inject;

/** A {@link ServiceFingerprinter} plugin that fingerprints web applications. */
Expand All @@ -69,19 +78,22 @@ public final class WebServiceFingerprinter implements ServiceFingerprinter {
private final SoftwareDetector softwareDetector;
private final VersionDetector.Factory versionDetectorFactory;
private final WebServiceFingerprinterConfigs configs;
private final HttpClient httpClient;

@Inject
WebServiceFingerprinter(
FingerprintRegistry fingerprintRegistry,
Crawler crawler,
SoftwareDetector softwareDetector,
VersionDetector.Factory versionDetectorFactory,
WebServiceFingerprinterConfigs configs) {
WebServiceFingerprinterConfigs configs,
HttpClient httpClient) {
this.fingerprintRegistry = checkNotNull(fingerprintRegistry);
this.crawler = checkNotNull(crawler);
this.softwareDetector = checkNotNull(softwareDetector);
this.versionDetectorFactory = checkNotNull(versionDetectorFactory);
this.configs = checkNotNull(configs);
this.httpClient = checkNotNull(httpClient);
}

@Override
Expand Down Expand Up @@ -119,16 +131,11 @@ public FingerprintingReport fingerprint(TargetInfo targetInfo, NetworkService ne

if (versionsBySoftware.isEmpty()) {
logger.atInfo().log(
"WebServiceFingerprinter failed to confirm running web application on '%s'.",
"WebServiceFingerprinter failed to confirm running web application on '%s' using existing"
+ " hashes. Try custom heuristics instead",
startingUrl);
return FingerprintingReport.newBuilder()
.addNetworkServices(
addWebServiceContext(
networkService,
Optional.empty(),
Optional.empty(),
crawlResultsUnderRecordingLimit))
.build();
return fingerprintWithCustomHeuristics(
networkService, startingUrl, crawlResultsUnderRecordingLimit);
} else {
logger.atInfo().log(
"WebServiceFingerprinter identified %d results for '%s'.",
Expand All @@ -148,6 +155,48 @@ public FingerprintingReport fingerprint(TargetInfo targetInfo, NetworkService ne
}
}

/**
 * Builds a {@link FingerprintingReport} from the custom heuristic checks.
 *
 * <p>When the heuristics detect nothing, the report carries the given service passed through
 * {@code addWebServiceContext} with no software and no version. Otherwise it carries one service
 * per detected software, each with its service name overwritten by that software's name.
 *
 * @param networkService the service being fingerprinted
 * @param startingUrl URL handed to the heuristic checks and used in log messages
 * @param crawlResults crawl results forwarded to {@code addWebServiceContext}
 * @return the report describing what the custom heuristics found
 */
private FingerprintingReport fingerprintWithCustomHeuristics(
    NetworkService networkService, String startingUrl, ImmutableSet<CrawlResult> crawlResults) {
  ImmutableSet<DetectedSoftware> heuristicMatches =
      detectSoftwareByCustomHeuristics(networkService, startingUrl);
  FingerprintingReport.Builder report = FingerprintingReport.newBuilder();

  // Nothing matched: report the service unchanged, without software or version details.
  if (heuristicMatches.isEmpty()) {
    logger.atInfo().log(
        "WebServiceFingerprinter failed to confirm running web application on '%s' using custom"
            + " heuristics either.",
        startingUrl);
    return report
        .addNetworkServices(
            addWebServiceContext(
                networkService, Optional.empty(), Optional.empty(), crawlResults))
        .build();
  }

  String matchedNames =
      heuristicMatches.stream()
          .map(match -> match.softwareIdentity().getSoftware())
          .collect(joining(","));
  logger.atInfo().log(
      "WebServiceFingerprinter discovered %d potential applications for '%s': [%s] using custom"
          + " heuristics.",
      heuristicMatches.size(),
      startingUrl,
      matchedNames);

  for (DetectedSoftware match : heuristicMatches) {
    // Overwrite the service name with the detected software's name.
    NetworkService renamedService =
        networkService.toBuilder()
            .setServiceName(match.softwareIdentity().getSoftware())
            .build();
    report.addNetworkServices(
        addWebServiceContext(
            renamedService, Optional.of(match), Optional.empty(), crawlResults));
  }
  return report.build();
}

private ImmutableMap<DetectedSoftware, DetectedVersion> detectSoftwareVersions(
Collection<DetectedSoftware> detectedSoftware, NetworkService networkService) {
ImmutableMap.Builder<DetectedSoftware, DetectedVersion> versionsBySoftwareBuilder =
Expand Down Expand Up @@ -222,4 +271,49 @@ private ImmutableSet<CrawlResult> crawlNetworkService(
.build();
return crawler.crawl(crawlConfig);
}

/**
 * Runs the custom heuristic checks against the service and returns an immutable snapshot of
 * whatever software they detected.
 */
private ImmutableSet<DetectedSoftware> detectSoftwareByCustomHeuristics(
    NetworkService networkService, String startingUrl) {
  // Each check appends its findings to this shared accumulator.
  Set<DetectedSoftware> findings = new HashSet<>();
  checkForMlflow(findings, networkService, startingUrl);

  return ImmutableSet.copyOf(findings);
}

/**
 * Probes the service's {@code /ping} path over plain HTTP and, when the reply is a 401 whose body
 * contains MLflow's "not authenticated" notice, adds an "mlflow" {@link DetectedSoftware} entry.
 *
 * <p>A failed request is logged as a warning and adds nothing.
 *
 * @param software mutable set that receives the detection, if any
 * @param networkService service whose network endpoint is probed
 * @param startingUrl URL recorded as the root path of the detected software
 */
private void checkForMlflow(
    Set<DetectedSoftware> software, NetworkService networkService, String startingUrl) {
  logger.atInfo().log("probing Mlflow ping - custom fingerprint phase");

  // MLflow versions above 2.5 ship a basic authentication module. Those versions answer an
  // unauthenticated request with a 401 status code and a link to the documentation that
  // explains how to authenticate; that response is the fingerprint looked for here.
  String notAuthenticatedNotice =
      "You are not authenticated. Please see "
          + "https://www.mlflow.org/docs/latest/auth/index.html"
          + "#authenticating-to-mlflow "
          + "on how to authenticate";
  String pingApiUrl =
      String.format(
          "http://%s/%s",
          NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()), "ping");

  try {
    HttpResponse pingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build());

    if (pingResponse.status() == HttpStatus.UNAUTHORIZED) {
      boolean bodyHasNotice =
          pingResponse
              .bodyString()
              .filter(body -> body.contains(notAuthenticatedNotice))
              .isPresent();
      if (bodyHasNotice) {
        software.add(
            DetectedSoftware.builder()
                .setSoftwareIdentity(SoftwareIdentity.newBuilder().setSoftware("mlflow").build())
                .setRootPath(startingUrl)
                .setContentHashes(ImmutableMap.of())
                .build());
      }
    }
  } catch (IOException e) {
    logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ public static Builder builder() {

/** Builder for {@link DetectedSoftware}. */
@AutoValue.Builder
abstract static class Builder {
public abstract static class Builder {
public abstract Builder setSoftwareIdentity(SoftwareIdentity value);
public abstract Builder setRootPath(String value);
public abstract Builder setContentHashes(ImmutableMap<CrawlResult, Hash> value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,15 @@ private CommonTestData() {}
.build();
public static final Hash SOFTWARE_3_CSS_HASH =
Hash.newBuilder().setHexString("1ebae34d06fc5a9be81b852a7c354041").build();

/**
 * Crawl result for a GET of {@code fakeUrl("/login?from")} that returned response code 200 with
 * the UTF-8 content {@code "MLFLOW"}.
 */
public static final CrawlResult SOFTWARE_4_MLFLOW =
CrawlResult.newBuilder()
.setCrawlTarget(
CrawlTarget.newBuilder().setUrl(fakeUrl("/login?from")).setHttpMethod("GET"))
.setResponseCode(200)
.setContent(ByteString.copyFromUtf8("MLFLOW"))
.build();

public static final CrawlResult UNKNOWN_CONTENT =
CrawlResult.newBuilder()
.setCrawlTarget(CrawlTarget.newBuilder().setUrl(fakeUrl("/unknown")).setHttpMethod("GET"))
Expand All @@ -157,6 +166,9 @@ private CommonTestData() {}
SoftwareIdentity.newBuilder().setSoftware("Software2").build();
public static final SoftwareIdentity SOFTWARE_IDENTITY_3 =
SoftwareIdentity.newBuilder().setSoftware("Software3").build();

/** Software identity whose software name is {@code "mlflow"}. */
public static final SoftwareIdentity SOFTWARE_IDENTITY_4 =
SoftwareIdentity.newBuilder().setSoftware("mlflow").build();
public static final FingerprintData FINGERPRINT_DATA_1 =
FingerprintData.fromProto(
Fingerprints.newBuilder()
Expand Down
Loading

0 comments on commit fcbbcb8

Please sign in to comment.