diff --git a/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java b/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java index 3d7f90d2f..9b0cc5cf2 100644 --- a/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java +++ b/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java @@ -32,7 +32,6 @@ import com.google.tsunami.common.net.http.HttpClient; import com.google.tsunami.common.net.http.HttpHeaders; import com.google.tsunami.common.net.http.HttpResponse; -import com.google.tsunami.common.net.http.HttpStatus; import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.provider.TestCredential; import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.tester.CredentialTester; import com.google.tsunami.proto.NetworkService; @@ -44,6 +43,8 @@ /** Credential tester specifically for mlflow. */ public final class MlFlowCredentialTester extends CredentialTester { private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + private static final String MLFLOW_SERVICE = "mlflow"; + private final HttpClient httpClient; @Inject @@ -63,39 +64,7 @@ public String description() { @Override public boolean canAccept(NetworkService networkService) { - if (!NetworkServiceUtils.isWebService(networkService)) { - return false; - } - - boolean canAcceptByCustomFingerprint = false; - logger.atInfo().log("probing Mlflow ping - custom fingerprint phase"); - - // We want to test weak credentials against mlflow versions above 2.5 which has basic - // authentication module.these versions return a 401 status code and a link to documentation - // about how to authenticate. - var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()); - var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping"); - try { - HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build()); - - if (apiPingResponse.status() == HttpStatus.UNAUTHORIZED - && apiPingResponse.bodyString().isPresent()) { - canAcceptByCustomFingerprint = - apiPingResponse - .bodyString() - .get() - .contains( - "You are not authenticated. Please see " - + "https://www.mlflow.org/docs/latest/auth/index.html" - + "#authenticating-to-mlflow " - + "on how to authenticate"); - } - } catch (IOException e) { - logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl); - return false; - } - - return canAcceptByCustomFingerprint; + return NetworkServiceUtils.getWebServiceName(networkService).equals(MLFLOW_SERVICE); } @Override diff --git a/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java b/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java index d613df46d..e6342a5c1 100644 --- a/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java +++ b/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java @@ -79,7 +79,7 @@ public void detect_weakCredentialsExists_returnsWeakCredentials() throws Excepti NetworkService.newBuilder() .setNetworkEndpoint( forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") + .setServiceName("mlflow") .build(); assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1))) @@ -94,7 +94,7 @@ public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exc NetworkService.newBuilder() .setNetworkEndpoint( forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") + .setServiceName("mlflow") .build(); assertThat( @@ -104,13 +104,13 @@ public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exc } @Test - public void detect_canAccept() throws Exception { + public void detect_mlflowService_canAccept() throws Exception { startMockWebServer(); NetworkService targetNetworkService = NetworkService.newBuilder() .setNetworkEndpoint( forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") + .setServiceName("mlflow") .build(); assertThat(tester.canAccept(targetNetworkService)).isTrue(); @@ -124,7 +124,7 @@ public void detect_weakCredentialsExistAndMlflowInForeignLanguage_returnsFirstWe NetworkService.newBuilder() .setNetworkEndpoint( forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") + .setServiceName("mlflow") .build(); assertThat( @@ -140,7 +140,7 @@ public void detect_noWeakCredentials_returnsNoCredentials() throws Exception { NetworkService.newBuilder() .setNetworkEndpoint( forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") + .setServiceName("mlflow") .build(); assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WRONG_CRED_1))) .isEmpty(); diff --git a/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinter.java b/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinter.java index b6430a19d..835fe8f88 100644 --- a/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinter.java +++ b/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinter.java @@ -18,12 +18,17 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.tsunami.common.net.http.HttpRequest.get; import static java.util.stream.Collectors.joining; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.flogger.GoogleLogger; +import com.google.tsunami.common.data.NetworkEndpointUtils; import com.google.tsunami.common.data.NetworkServiceUtils; +import com.google.tsunami.common.net.http.HttpClient; +import com.google.tsunami.common.net.http.HttpResponse; +import com.google.tsunami.common.net.http.HttpStatus; import com.google.tsunami.plugin.PluginType; import com.google.tsunami.plugin.ServiceFingerprinter; import com.google.tsunami.plugin.annotations.ForWebService; @@ -36,6 +41,7 @@ import com.google.tsunami.plugins.fingerprinters.web.detection.SoftwareDetector.DetectedSoftware; import com.google.tsunami.plugins.fingerprinters.web.detection.VersionDetector; import com.google.tsunami.plugins.fingerprinters.web.detection.VersionDetector.DetectedVersion; +import com.google.tsunami.plugins.fingerprinters.web.proto.SoftwareIdentity; import com.google.tsunami.proto.CrawlConfig; import com.google.tsunami.proto.CrawlResult; import com.google.tsunami.proto.FingerprintingReport; @@ -47,8 +53,11 @@ import com.google.tsunami.proto.Version.VersionType; import com.google.tsunami.proto.VersionSet; import com.google.tsunami.proto.WebServiceContext; +import java.io.IOException; import java.util.Collection; +import java.util.HashSet; import java.util.Optional; +import java.util.Set; import javax.inject.Inject; /** A {@link ServiceFingerprinter} plugin that fingerprints web applications. */ @@ -69,6 +78,7 @@ public final class WebServiceFingerprinter implements ServiceFingerprinter { private final SoftwareDetector softwareDetector; private final VersionDetector.Factory versionDetectorFactory; private final WebServiceFingerprinterConfigs configs; + private final HttpClient httpClient; @Inject WebServiceFingerprinter( @@ -76,12 +86,14 @@ public final class WebServiceFingerprinter implements ServiceFingerprinter { Crawler crawler, SoftwareDetector softwareDetector, VersionDetector.Factory versionDetectorFactory, - WebServiceFingerprinterConfigs configs) { + WebServiceFingerprinterConfigs configs, + HttpClient httpClient) { this.fingerprintRegistry = checkNotNull(fingerprintRegistry); this.crawler = checkNotNull(crawler); this.softwareDetector = checkNotNull(softwareDetector); this.versionDetectorFactory = checkNotNull(versionDetectorFactory); this.configs = checkNotNull(configs); + this.httpClient = checkNotNull(httpClient); } @Override @@ -119,16 +131,11 @@ public FingerprintingReport fingerprint(TargetInfo targetInfo, NetworkService ne if (versionsBySoftware.isEmpty()) { logger.atInfo().log( - "WebServiceFingerprinter failed to confirm running web application on '%s'.", + "WebServiceFingerprinter failed to confirm running web application on '%s' using existing" + + " hashes. Try custom heuristics instead", startingUrl); - return FingerprintingReport.newBuilder() - .addNetworkServices( - addWebServiceContext( - networkService, - Optional.empty(), - Optional.empty(), - crawlResultsUnderRecordingLimit)) - .build(); + return fingerprintWithCustomHeuristics( + networkService, startingUrl, crawlResultsUnderRecordingLimit); } else { logger.atInfo().log( "WebServiceFingerprinter identified %d results for '%s'.", @@ -148,6 +155,48 @@ public FingerprintingReport fingerprint(TargetInfo targetInfo, NetworkService ne } } + private FingerprintingReport fingerprintWithCustomHeuristics( + NetworkService networkService, String startingUrl, ImmutableSet crawlResults) { + ImmutableSet detectedSoftware = + detectSoftwareByCustomHeuristics(networkService, startingUrl); + + if (detectedSoftware.isEmpty()) { + logger.atInfo().log( + "WebServiceFingerprinter failed to confirm running web application on '%s' using custom" + + " heuristics either.", + startingUrl); + return FingerprintingReport.newBuilder() + .addNetworkServices( + addWebServiceContext( + networkService, Optional.empty(), Optional.empty(), crawlResults)) + .build(); + } + + logger.atInfo().log( + "WebServiceFingerprinter discovered %d potential applications for '%s': [%s] using custom" + + " heuristics.", + detectedSoftware.size(), + startingUrl, + detectedSoftware.stream() + .map(software -> software.softwareIdentity().getSoftware()) + .collect(joining(","))); + return FingerprintingReport.newBuilder() + .addAllNetworkServices( + detectedSoftware.stream() + .map( + software -> + addWebServiceContext( + // Overwrite service name + networkService.toBuilder() + .setServiceName(software.softwareIdentity().getSoftware()) + .build(), + Optional.of(software), + Optional.empty(), + crawlResults)) + .collect(toImmutableList())) + .build(); + } + private ImmutableMap detectSoftwareVersions( Collection detectedSoftware, NetworkService networkService) { ImmutableMap.Builder versionsBySoftwareBuilder = @@ -222,4 +271,49 @@ private ImmutableSet crawlNetworkService( .build(); return crawler.crawl(crawlConfig); } + + private ImmutableSet detectSoftwareByCustomHeuristics( + NetworkService networkService, String startingUrl) { + HashSet detectedSoftware = new HashSet<>(); + + checkForMlflow(detectedSoftware, networkService, startingUrl); + return ImmutableSet.copyOf(detectedSoftware); + } + + private void checkForMlflow( + Set software, NetworkService networkService, String startingUrl) { + logger.atInfo().log("probing Mlflow ping - custom fingerprint phase"); + + // We want to test weak credentials against mlflow versions above 2.5 which has basic + // authentication module.these versions return a 401 status code and a link to documentation + // about how to authenticate. + var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()); + var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping"); + try { + HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build()); + + if (apiPingResponse.status() != HttpStatus.UNAUTHORIZED + || apiPingResponse.bodyString().isEmpty()) { + return; + } + + if (apiPingResponse + .bodyString() + .get() + .contains( + "You are not authenticated. Please see " + + "https://www.mlflow.org/docs/latest/auth/index.html" + + "#authenticating-to-mlflow " + + "on how to authenticate")) { + software.add( + DetectedSoftware.builder() + .setSoftwareIdentity(SoftwareIdentity.newBuilder().setSoftware("mlflow").build()) + .setRootPath(startingUrl) + .setContentHashes(ImmutableMap.of()) + .build()); + } + } catch (IOException e) { + logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl); + } + } } diff --git a/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/detection/SoftwareDetector.java b/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/detection/SoftwareDetector.java index e0e6741bc..c40a2c4c7 100644 --- a/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/detection/SoftwareDetector.java +++ b/google/fingerprinters/web/src/main/java/com/google/tsunami/plugins/fingerprinters/web/detection/SoftwareDetector.java @@ -240,7 +240,7 @@ public static Builder builder() { /** Builder for {@link DetectedSoftware}. */ @AutoValue.Builder - abstract static class Builder { + public abstract static class Builder { public abstract Builder setSoftwareIdentity(SoftwareIdentity value); public abstract Builder setRootPath(String value); public abstract Builder setContentHashes(ImmutableMap value); diff --git a/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/CommonTestData.java b/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/CommonTestData.java index 5fbbbd7d7..c875d4ea1 100644 --- a/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/CommonTestData.java +++ b/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/CommonTestData.java @@ -145,6 +145,15 @@ private CommonTestData() {} .build(); public static final Hash SOFTWARE_3_CSS_HASH = Hash.newBuilder().setHexString("1ebae34d06fc5a9be81b852a7c354041").build(); + + public static final CrawlResult SOFTWARE_4_MLFLOW = + CrawlResult.newBuilder() + .setCrawlTarget( + CrawlTarget.newBuilder().setUrl(fakeUrl("/login?from")).setHttpMethod("GET")) + .setResponseCode(200) + .setContent(ByteString.copyFromUtf8("MLFLOW")) + .build(); + public static final CrawlResult UNKNOWN_CONTENT = CrawlResult.newBuilder() .setCrawlTarget(CrawlTarget.newBuilder().setUrl(fakeUrl("/unknown")).setHttpMethod("GET")) @@ -157,6 +166,9 @@ private CommonTestData() {} SoftwareIdentity.newBuilder().setSoftware("Software2").build(); public static final SoftwareIdentity SOFTWARE_IDENTITY_3 = SoftwareIdentity.newBuilder().setSoftware("Software3").build(); + + public static final SoftwareIdentity SOFTWARE_IDENTITY_4 = + SoftwareIdentity.newBuilder().setSoftware("mlflow").build(); public static final FingerprintData FINGERPRINT_DATA_1 = FingerprintData.fromProto( Fingerprints.newBuilder() diff --git a/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinterTest.java b/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinterTest.java index 8681041df..86b669f26 100644 --- a/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinterTest.java +++ b/google/fingerprinters/web/src/test/java/com/google/tsunami/plugins/fingerprinters/web/WebServiceFingerprinterTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostname; +import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort; import static com.google.tsunami.common.data.NetworkEndpointUtils.forIp; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.COMMON_LIB; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.FINGERPRINT_DATA_1; @@ -29,9 +30,11 @@ import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_2_ICON; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_3_CSS; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_3_ZIP; +import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_4_MLFLOW; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_IDENTITY_1; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_IDENTITY_2; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_IDENTITY_3; +import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.SOFTWARE_IDENTITY_4; import static com.google.tsunami.plugins.fingerprinters.web.CommonTestData.fakeUrl; import com.google.common.collect.ImmutableList; @@ -42,6 +45,7 @@ import com.google.inject.Guice; import com.google.inject.Provides; import com.google.inject.assistedinject.FactoryModuleBuilder; +import com.google.tsunami.common.data.NetworkEndpointUtils; import com.google.tsunami.common.net.http.HttpClientModule; import com.google.tsunami.plugins.fingerprinters.web.WebServiceFingerprinterConfigs.WebServiceFingerprinterCliOptions; import com.google.tsunami.plugins.fingerprinters.web.crawl.Crawler; @@ -52,6 +56,7 @@ import com.google.tsunami.proto.CrawlResult; import com.google.tsunami.proto.CrawlTarget; import com.google.tsunami.proto.FingerprintingReport; +import com.google.tsunami.proto.NetworkEndpoint; import com.google.tsunami.proto.NetworkService; import com.google.tsunami.proto.ServiceContext; import com.google.tsunami.proto.Software; @@ -60,9 +65,14 @@ import com.google.tsunami.proto.Version.VersionType; import com.google.tsunami.proto.VersionSet; import com.google.tsunami.proto.WebServiceContext; +import java.io.IOException; import java.util.Collection; import java.util.List; import javax.inject.Inject; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -74,12 +84,14 @@ public final class WebServiceFingerprinterTest { private final FakeCrawler fakeCrawler = new FakeCrawler(); private WebServiceFingerprinterCliOptions cliOptions; + private MockWebServer mockWebServer; @Inject WebServiceFingerprinter fingerprinter; @Before public void setUp() { cliOptions = new WebServiceFingerprinterCliOptions(); + mockWebServer = new MockWebServer(); Guice.createInjector( new AbstractModule() { @Override @@ -326,6 +338,62 @@ public void fingerprint_whenLimitContentSize_doNotRecordLargeCrawlResult() { .doesNotContain(SOFTWARE_3_CSS); } + @Test + public void fingerprint_mlflowServiceWithBasicAuth_fillsServiceContextWithApplication() + throws Exception { + fakeCrawler.setCrawlResults(ImmutableSet.of(SOFTWARE_4_MLFLOW)); + startMockMlflowWebServer(); + NetworkEndpoint endpoint = + forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()); + NetworkService networkService = + NetworkService.newBuilder().setNetworkEndpoint(endpoint).setServiceName("http").build(); + + FingerprintingReport fingerprintingReport = + fingerprinter.fingerprint(TargetInfo.getDefaultInstance(), networkService); + + assertThat(fingerprintingReport) + .comparingExpectedFieldsOnly() + .isEqualTo( + FingerprintingReport.newBuilder() + .addNetworkServices( + networkService.toBuilder() + .setServiceName(SOFTWARE_IDENTITY_4.getSoftware()) + .setServiceContext( + ServiceContext.newBuilder() + .setWebServiceContext( + WebServiceContext.newBuilder() + .setApplicationRoot( + String.format( + "http://%s/", + NetworkEndpointUtils.toUriAuthority(endpoint))) + .setSoftware( + Software.newBuilder() + .setName(SOFTWARE_IDENTITY_4.getSoftware()))))) + .build()); + } + + private void startMockMlflowWebServer() throws IOException { + final Dispatcher dispatcher = + new Dispatcher() { + final MockResponse unauthorizedResponse = + new MockResponse() + .setResponseCode(401) + .setBody( + "You are not authenticated. " + + "Please see https://www.mlflow.org/docs/latest/auth/index.html" + + "#authenticating-to-mlflow " + + "on how to authenticate"); + + @Override + public MockResponse dispatch(RecordedRequest request) { + return unauthorizedResponse; + } + }; + mockWebServer.setDispatcher(dispatcher); + mockWebServer.start(); + mockWebServer.url("/"); + } + private static NetworkService addServiceContext( NetworkService networkService, String appRoot,