diff --git a/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java b/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java index d7c106c8f..e24b567f6 100644 --- a/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java +++ b/google/detectors/credentials/generic_weak_credential_detector/src/main/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTester.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.testers.mlflow; import static com.google.common.base.Preconditions.checkNotNull; @@ -26,7 +27,6 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; -import com.google.protobuf.ByteString; import com.google.tsunami.common.data.NetworkEndpointUtils; import com.google.tsunami.common.data.NetworkServiceUtils; import com.google.tsunami.common.net.http.HttpClient; @@ -36,151 +36,149 @@ import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.provider.TestCredential; import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.tester.CredentialTester; import com.google.tsunami.proto.NetworkService; - import java.io.IOException; import java.util.Base64; import java.util.List; import javax.inject.Inject; -/** - * Credential tester specifically for mlflow. - */ +/** Credential tester specifically for mlflow. */ public final class MlFlowCredentialTester extends CredentialTester { - private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); - private final HttpClient httpClient; - - @Inject - MlFlowCredentialTester(HttpClient httpClient) { - this.httpClient = checkNotNull(httpClient); + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + private final HttpClient httpClient; + + @Inject + MlFlowCredentialTester(HttpClient httpClient) { + this.httpClient = checkNotNull(httpClient); + } + + @Override + public String name() { + return "MlFlowCredentialTester"; + } + + @Override + public String description() { + return "MlFlow credential tester."; + } + + @Override + public boolean canAccept(NetworkService networkService) { + if (!NetworkServiceUtils.isWebService(networkService)) { + return false; } - @Override - public String name() { - return "MlFlowCredentialTester"; + boolean canAcceptByCustomFingerprint = false; + logger.atInfo().log("probing Mlflow ping - custom fingerprint phase"); + + // we want to test mlflow versions above 2.5 which has basic authentication module + // these versions returned a 401 status code and a link to documentation about how to + // authenticate. + var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()); + var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping"); + try { + HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build()); + + if (apiPingResponse.status() == HttpStatus.UNAUTHORIZED + && apiPingResponse.bodyString().isPresent()) { + canAcceptByCustomFingerprint = + apiPingResponse + .bodyString() + .get() + .contains( + "You are not authenticated. " + + "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow " + + "on how to authenticate"); + } + } catch (IOException e) { + logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl); + return false; } - @Override - public String description() { - return "MlFlow credential tester."; + return canAcceptByCustomFingerprint; + } + + @Override + public boolean batched() { + return true; + } + + @Override + public ImmutableList testValidCredentials( + NetworkService networkService, List credentials) { + // Always return 1st weak credential to gracefully handle no auth configured case, where we + // return empty credential instead of all the weak credentials + return credentials.stream() + .filter(cred -> isMlFlowAccessible(networkService, cred)) + .findFirst() + .map(ImmutableList::of) + .orElseGet(ImmutableList::of); + } + + private boolean isMlFlowAccessible(NetworkService networkService, TestCredential credential) { + var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()); + var url = + String.format( + "http://%s/%s?username=%s", + uriAuthority, "api/2.0/mlflow/users/get", credential.username()); + try { + logger.atInfo().log( + "url: %s, username: %s, password: %s", + url, credential.username(), credential.password().orElse("")); + HttpResponse response = sendRequestWithCredentials(url, credential); + return response.status().isSuccess() + && response + .bodyString() + .map(MlFlowCredentialTester::bodyContainsSuccessfulUserInfo) + .orElse(false); + } catch (IOException e) { + logger.atWarning().withCause(e).log("Unable to query '%s'.", url); + return false; } - - @Override - public boolean canAccept(NetworkService networkService) { - if (!NetworkServiceUtils.isWebService(networkService)) { - return false; - } - - boolean canAcceptByCustomFingerprint = false; - logger.atInfo().log("probing Mlflow ping - custom fingerprint phase"); - - // we want to test mlflow versions above 2.5 which has basic authentication module - // these versions returned a 401 status code and a link to documentation about how to - // authenticate. - var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()); - var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping"); - try { - HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build()); - - if (apiPingResponse.status() == HttpStatus.UNAUTHORIZED - && apiPingResponse.bodyString().isPresent()) { - canAcceptByCustomFingerprint = - apiPingResponse - .bodyString() - .get() - .contains( - "You are not authenticated. " - + "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow " - + "on how to authenticate"); - } - } catch (IOException e) { - logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl); - return false; - } - - return canAcceptByCustomFingerprint; + } + + private HttpResponse sendRequestWithCredentials(String url, TestCredential credential) + throws IOException { + // For testing no-auth configured case, no auth header is passed in + if (Strings.isNullOrEmpty(credential.username()) + && Strings.isNullOrEmpty(credential.password().orElse(""))) { + return httpClient.send(post(url).withEmptyHeaders().build()); } - @Override - public boolean batched() { + return httpClient.send( + get(url) + .setHeaders( + HttpHeaders.builder() + .addHeader( + "Authorization", + "basic " + + Base64.getEncoder() + .encodeToString( + (credential.username() + ":" + credential.password().orElse("")) + .getBytes(UTF_8))) + .build()) + .build()); + } + + /** + * A successful authenticated request to the /api/2.0/mlflow/users/get?username=admin endpoint + * returns a JSON with a root key like the following: + * {"user":{"experiment_permissions":[],"id":1,"is_admin":true,"registered_model_permissions":[], + * "username":"admin"}} + */ + private static boolean bodyContainsSuccessfulUserInfo(String responseBody) { + try { + JsonObject response = JsonParser.parseString(responseBody).getAsJsonObject(); + + if (response.has("user")) { + logger.atInfo().log("Successfully received a mlflow user info"); return true; + } else { + return false; + } + } catch (JsonSyntaxException e) { + logger.atWarning().withCause(e).log( + "An error occurred while parsing the json response: %s", responseBody); + return false; } - - @Override - public ImmutableList testValidCredentials( - NetworkService networkService, List credentials) { - // Always return 1st weak credential to gracefully handle no auth configured case, where we - // return empty credential instead of all the weak credentials - return credentials.stream() - .filter(cred -> isMlFlowAccessible(networkService, cred)) - .findFirst() - .map(ImmutableList::of) - .orElseGet(ImmutableList::of); - } - - private boolean isMlFlowAccessible(NetworkService networkService, TestCredential credential) { - var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint()); - var url = String.format("http://%s/%s", uriAuthority, "api/2.0/mlflow/users/create"); - try { - logger.atInfo().log( - "url: %s, username: %s, password: %s", - url, credential.username(), credential.password().orElse("")); - HttpResponse response = sendRequestWithCredentials(url, credential); - return response.status().isSuccess() - && response - .bodyString() - .map(MlFlowCredentialTester::bodyContainsSuccessfulUserRegistration) - .orElse(false); - } catch (IOException e) { - logger.atWarning().withCause(e).log("Unable to query '%s'.", url); - return false; - } - } - - private HttpResponse sendRequestWithCredentials(String url, TestCredential credential) - throws IOException { - // For testing no-auth configured case, no auth header is passed in - if (Strings.isNullOrEmpty(credential.username()) - && Strings.isNullOrEmpty(credential.password().orElse(""))) { - return httpClient.send(post(url).withEmptyHeaders().build()); - } - return httpClient.send( - post(url) - .setHeaders( - HttpHeaders.builder() - .addHeader( - "Authorization", - "basic " - + Base64.getEncoder() - .encodeToString( - (credential.username() + ":" + credential.password().orElse("")) - .getBytes(UTF_8))) - .build()) - .setRequestBody( - ByteString.copyFromUtf8( - "{\"username\": \"googleTsunamiSecurityScanner\", \"password\": \"googleTsunamiSecurityScanner\"}")) - .build()); - } - - /** - * A successful authenticated request to the /api/2.0/mlflow/users/create endpoint returns a JSON - * with a root key like the following: - * {"user":{"experiment_permissions":[],"id":4,"is_admin":false,"registered_model_permissions":[], - * "username":"googleTsunamiSecurityScanner"}} - */ - private static boolean bodyContainsSuccessfulUserRegistration(String responseBody) { - try { - JsonObject response = JsonParser.parseString(responseBody).getAsJsonObject(); - - if (response.has("user")) { - logger.atInfo().log("Successfully created a new mlflow user as an admin"); - return true; - } else { - return false; - } - } catch (JsonSyntaxException e) { - logger.atWarning().withCause(e).log( - "An error occurred while parsing the json response: %s", responseBody); - return false; - } - } + } } diff --git a/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java b/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java index b06f91e26..6c6424221 100644 --- a/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java +++ b/google/detectors/credentials/generic_weak_credential_detector/src/test/java/com/google/tsunami/plugins/detectors/credentials/genericweakcredentialdetector/testers/mlflow/MlFlowCredentialTesterTest.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.testers.mlflow; import com.google.common.collect.ImmutableList; @@ -34,166 +35,167 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; - import javax.inject.Inject; import java.io.IOException; import java.sql.Connection; import java.util.Objects; import java.util.Optional; - import static com.google.common.truth.Truth.assertThat; import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -/** - * Tests for {@link MlFlowCredentialTester}. - */ +/** Tests for {@link MlFlowCredentialTester}. */ @RunWith(JUnit4.class) public class MlFlowCredentialTesterTest { - @Rule - public MockitoRule rule = MockitoJUnit.rule(); - @Mock - private ConnectionProviderInterface mockConnectionProvider; - @Mock - private Connection mockConnection; - @Inject - private MlFlowCredentialTester tester; - private MockWebServer mockWebServer; - private static final TestCredential WEAK_CRED_1 = - TestCredential.create("admin", Optional.of("password")); - private static final TestCredential WEAK_CRED_2 = - TestCredential.create("username", Optional.of("password")); - private static final TestCredential WRONG_CRED_1 = - TestCredential.create("wrong", Optional.of("wrong")); - - private static final String WEAK_CRED_AUTH_1 = "basic dXNlcm5hbWU6cGFzc3dvcmQ="; - private static final String WEAK_CRED_AUTH_2 = "basic YWRtaW46cGFzc3dvcmQ="; - private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); - - @Before - public void setup() { - mockWebServer = new MockWebServer(); - Guice.createInjector(new HttpClientModule.Builder().build()).injectMembers(this); - } - - @Test - public void detect_weakCredentialsExists_returnsWeakCredentials() throws Exception { - startMockWebServer(); - NetworkService targetNetworkService = - NetworkService.newBuilder() - .setNetworkEndpoint( - forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") - .build(); - - assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1))) - .containsExactly(WEAK_CRED_1); - mockWebServer.shutdown(); - } - - @Test - public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exception { - startMockWebServer(); - NetworkService targetNetworkService = - NetworkService.newBuilder() - .setNetworkEndpoint( - forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") - .build(); - - assertThat( - tester.testValidCredentials( - targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2))) - .containsExactly(WEAK_CRED_1); - } - - @Test - public void detect_canAccept() throws Exception { - startMockWebServer(); - NetworkService targetNetworkService = - NetworkService.newBuilder() - .setNetworkEndpoint( - forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") - .build(); - - assertThat(tester.canAccept(targetNetworkService)).isTrue(); - } - - @Test - public void detect_weakCredentialsExistAndMlflowInForeignLanguage_returnsFirstWeakCredentials() - throws Exception { - startMockWebServer(); - NetworkService targetNetworkService = - NetworkService.newBuilder() - .setNetworkEndpoint( - forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") - .build(); - - assertThat( - tester.testValidCredentials( - targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2))) - .containsExactly(WEAK_CRED_1); - } - - @Test - public void detect_noWeakCredentials_returnsNoCredentials() throws Exception { - startMockWebServer(); - NetworkService targetNetworkService = - NetworkService.newBuilder() - .setNetworkEndpoint( - forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) - .setServiceName("http") - .build(); - assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WRONG_CRED_1))) - .isEmpty(); - } - - @Test - public void detect_nonMlflowService_skips() throws Exception { - when(mockConnectionProvider.getConnection(any(), any(), any())).thenReturn(mockConnection); - NetworkService targetNetworkService = - NetworkService.newBuilder() - .setNetworkEndpoint(forHostnameAndPort("example.com", 8080)) - .setServiceName("http") - .build(); - - assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1))) - .isEmpty(); - verifyNoInteractions(mockConnectionProvider); - } - - private void startMockWebServer() - throws IOException { - final Dispatcher dispatcher = new Dispatcher() { - final MockResponse unauthorizedResponse = new MockResponse().setResponseCode(401).setBody("You are not authenticated. " - + "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow " - + "on how to authenticate"); - - @Override - public MockResponse dispatch(RecordedRequest request) { - String authorizationHeader = request.getHeaders().get("Authorization"); - if (authorizationHeader == null) { - return unauthorizedResponse; - } - if (Objects.equals(request.getPath(), "/api/2.0/mlflow/users/create") && Objects.equals(request.getMethod(), "POST")) { - boolean isDefaultCredentials = authorizationHeader.equals(WEAK_CRED_AUTH_1) || authorizationHeader.equals(WEAK_CRED_AUTH_2); - if (isDefaultCredentials) { - return new MockResponse().setResponseCode(200) - .setBody("{\"user\":{\"experiment_permissions\":[],\"id\":4,\"is_admin\":false,\"registered_model_permissions\":[],\n" + - " \"username\":\"googleTsunamiSecurityScanner\"}}"); - } else { - return unauthorizedResponse; - } - } - return new MockResponse().setResponseCode(404); + @Rule public MockitoRule rule = MockitoJUnit.rule(); + @Mock private ConnectionProviderInterface mockConnectionProvider; + @Mock private Connection mockConnection; + @Inject private MlFlowCredentialTester tester; + private MockWebServer mockWebServer; + private static final TestCredential WEAK_CRED_1 = + TestCredential.create("admin", Optional.of("password")); + private static final TestCredential WEAK_CRED_2 = + TestCredential.create("username", Optional.of("password")); + private static final TestCredential WRONG_CRED_1 = + TestCredential.create("wrong", Optional.of("wrong")); + + private static final String WEAK_CRED_AUTH_1 = "basic dXNlcm5hbWU6cGFzc3dvcmQ="; + private static final String WEAK_CRED_AUTH_2 = "basic YWRtaW46cGFzc3dvcmQ="; + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + + @Before + public void setup() { + mockWebServer = new MockWebServer(); + Guice.createInjector(new HttpClientModule.Builder().build()).injectMembers(this); + } + + @Test + public void detect_weakCredentialsExists_returnsWeakCredentials() throws Exception { + startMockWebServer(); + NetworkService targetNetworkService = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) + .setServiceName("http") + .build(); + + assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1))) + .containsExactly(WEAK_CRED_1); + mockWebServer.shutdown(); + } + + @Test + public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exception { + startMockWebServer(); + NetworkService targetNetworkService = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) + .setServiceName("http") + .build(); + + assertThat( + tester.testValidCredentials( + targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2))) + .containsExactly(WEAK_CRED_1); + } + + @Test + public void detect_canAccept() throws Exception { + startMockWebServer(); + NetworkService targetNetworkService = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) + .setServiceName("http") + .build(); + + assertThat(tester.canAccept(targetNetworkService)).isTrue(); + } + + @Test + public void detect_weakCredentialsExistAndMlflowInForeignLanguage_returnsFirstWeakCredentials() + throws Exception { + startMockWebServer(); + NetworkService targetNetworkService = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) + .setServiceName("http") + .build(); + + assertThat( + tester.testValidCredentials( + targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2))) + .containsExactly(WEAK_CRED_1); + } + + @Test + public void detect_noWeakCredentials_returnsNoCredentials() throws Exception { + startMockWebServer(); + NetworkService targetNetworkService = + NetworkService.newBuilder() + .setNetworkEndpoint( + forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort())) + .setServiceName("http") + .build(); + assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WRONG_CRED_1))) + .isEmpty(); + } + + @Test + public void detect_nonMlflowService_skips() throws Exception { + when(mockConnectionProvider.getConnection(any(), any(), any())).thenReturn(mockConnection); + NetworkService targetNetworkService = + NetworkService.newBuilder() + .setNetworkEndpoint(forHostnameAndPort("example.com", 8080)) + .setServiceName("http") + .build(); + + assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1))) + .isEmpty(); + verifyNoInteractions(mockConnectionProvider); + } + + private void startMockWebServer() throws IOException { + final Dispatcher dispatcher = + new Dispatcher() { + final MockResponse unauthorizedResponse = + new MockResponse() + .setResponseCode(401) + .setBody( + "You are not authenticated. " + + "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow " + + "on how to authenticate"); + + @Override + public MockResponse dispatch(RecordedRequest request) { + String authorizationHeader = request.getHeaders().get("Authorization"); + if (authorizationHeader == null) { + return unauthorizedResponse; + } + if (request.getPath().matches("/api/2.0/mlflow/users/get\\?.*") + && Objects.equals(request.getMethod(), "GET")) { + boolean isDefaultCredentials = + authorizationHeader.equals(WEAK_CRED_AUTH_1) + || authorizationHeader.equals(WEAK_CRED_AUTH_2); + if (isDefaultCredentials) { + return new MockResponse() + .setResponseCode(200) + .setBody( + "{\"user\":{\"experiment_permissions\":[],\"id\":1,\"is_admin\":true,\"registered_model_permissions\":[]," + + "\"username\":\"admin\"}}"); + } else { + return unauthorizedResponse; + } } + return new MockResponse().setResponseCode(404); + } }; - mockWebServer.setDispatcher(dispatcher); - mockWebServer.start(); - mockWebServer.url("/"); - } + mockWebServer.setDispatcher(dispatcher); + mockWebServer.start(); + mockWebServer.url("/"); + } }