Skip to content

Commit

Permalink
apply google format
Browse files Browse the repository at this point in the history
  • Loading branch information
lanced00m committed Apr 30, 2024
1 parent 0de8e9a commit 2576c7f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.testers.mlflow;

import static com.google.common.base.Preconditions.checkNotNull;
Expand All @@ -26,7 +27,6 @@
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
import com.google.protobuf.ByteString;
import com.google.tsunami.common.data.NetworkEndpointUtils;
import com.google.tsunami.common.data.NetworkServiceUtils;
import com.google.tsunami.common.net.http.HttpClient;
Expand All @@ -36,152 +36,149 @@
import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.provider.TestCredential;
import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.tester.CredentialTester;
import com.google.tsunami.proto.NetworkService;

import java.io.IOException;
import java.util.Base64;
import java.util.List;
import javax.inject.Inject;

/**
* Credential tester specifically for mlflow.
*/
/** Credential tester specifically for mlflow. */
public final class MlFlowCredentialTester extends CredentialTester {
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
private final HttpClient httpClient;

@Inject
MlFlowCredentialTester(HttpClient httpClient) {
this.httpClient = checkNotNull(httpClient);
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
private final HttpClient httpClient;

@Inject
MlFlowCredentialTester(HttpClient httpClient) {
this.httpClient = checkNotNull(httpClient);
}

@Override
public String name() {
return "MlFlowCredentialTester";
}

@Override
public String description() {
return "MlFlow credential tester.";
}

@Override
public boolean canAccept(NetworkService networkService) {
if (!NetworkServiceUtils.isWebService(networkService)) {
return false;
}

@Override
public String name() {
return "MlFlowCredentialTester";
boolean canAcceptByCustomFingerprint = false;
logger.atInfo().log("probing Mlflow ping - custom fingerprint phase");

// we want to test mlflow versions above 2.5 which has basic authentication module
// these versions returned a 401 status code and a link to documentation about how to
// authenticate.
var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint());
var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping");
try {
HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build());

if (apiPingResponse.status() == HttpStatus.UNAUTHORIZED
&& apiPingResponse.bodyString().isPresent()) {
canAcceptByCustomFingerprint =
apiPingResponse
.bodyString()
.get()
.contains(
"You are not authenticated. "
+ "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow "
+ "on how to authenticate");
}
} catch (IOException e) {
logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl);
return false;
}

@Override
public String description() {
return "MlFlow credential tester.";
return canAcceptByCustomFingerprint;
}

@Override
public boolean batched() {
return true;
}

@Override
public ImmutableList<TestCredential> testValidCredentials(
NetworkService networkService, List<TestCredential> credentials) {
// Always return 1st weak credential to gracefully handle no auth configured case, where we
// return empty credential instead of all the weak credentials
return credentials.stream()
.filter(cred -> isMlFlowAccessible(networkService, cred))
.findFirst()
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

private boolean isMlFlowAccessible(NetworkService networkService, TestCredential credential) {
var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint());
var url =
String.format(
"http://%s/%s?username=%s",
uriAuthority, "api/2.0/mlflow/users/get", credential.username());
try {
logger.atInfo().log(
"url: %s, username: %s, password: %s",
url, credential.username(), credential.password().orElse(""));
HttpResponse response = sendRequestWithCredentials(url, credential);
return response.status().isSuccess()
&& response
.bodyString()
.map(MlFlowCredentialTester::bodyContainsSuccessfulUserInfo)
.orElse(false);
} catch (IOException e) {
logger.atWarning().withCause(e).log("Unable to query '%s'.", url);
return false;
}

@Override
public boolean canAccept(NetworkService networkService) {
if (!NetworkServiceUtils.isWebService(networkService)) {
return false;
}

boolean canAcceptByCustomFingerprint = false;
logger.atInfo().log("probing Mlflow ping - custom fingerprint phase");

// we want to test mlflow versions above 2.5 which has basic authentication module
// these versions returned a 401 status code and a link to documentation about how to
// authenticate.
var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint());
var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping");
try {
HttpResponse apiPingResponse = httpClient.send(get(pingApiUrl).withEmptyHeaders().build());

if (apiPingResponse.status() == HttpStatus.UNAUTHORIZED
&& apiPingResponse.bodyString().isPresent()) {
canAcceptByCustomFingerprint =
apiPingResponse
.bodyString()
.get()
.contains(
"You are not authenticated. "
+ "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow "
+ "on how to authenticate");
}
} catch (IOException e) {
logger.atWarning().withCause(e).log("Unable to query '%s'.", pingApiUrl);
return false;
}

return canAcceptByCustomFingerprint;
}

private HttpResponse sendRequestWithCredentials(String url, TestCredential credential)
throws IOException {
// For testing no-auth configured case, no auth header is passed in
if (Strings.isNullOrEmpty(credential.username())
&& Strings.isNullOrEmpty(credential.password().orElse(""))) {
return httpClient.send(post(url).withEmptyHeaders().build());
}

@Override
public boolean batched() {
return httpClient.send(
get(url)
.setHeaders(
HttpHeaders.builder()
.addHeader(
"Authorization",
"basic "
+ Base64.getEncoder()
.encodeToString(
(credential.username() + ":" + credential.password().orElse(""))
.getBytes(UTF_8)))
.build())
.build());
}

/**
* A successful authenticated request to the /api/2.0/mlflow/users/get?username=admin endpoint
* returns a JSON with a root key like the following:
* {"user":{"experiment_permissions":[],"id":1,"is_admin":true,"registered_model_permissions":[],
* "username":"admin"}}
*/
private static boolean bodyContainsSuccessfulUserInfo(String responseBody) {
try {
JsonObject response = JsonParser.parseString(responseBody).getAsJsonObject();

if (response.has("user")) {
logger.atInfo().log("Successfully received a mlflow user info");
return true;
} else {
return false;
}
} catch (JsonSyntaxException e) {
logger.atWarning().withCause(e).log(
"An error occurred while parsing the json response: %s", responseBody);
return false;
}

@Override
public ImmutableList<TestCredential> testValidCredentials(
NetworkService networkService, List<TestCredential> credentials) {
// Always return 1st weak credential to gracefully handle no auth configured case, where we
// return empty credential instead of all the weak credentials
return credentials.stream()
.filter(cred -> isMlFlowAccessible(networkService, cred))
.findFirst()
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

private boolean isMlFlowAccessible(NetworkService networkService, TestCredential credential) {
var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint());
var url =
String.format(
"http://%s/%s?username=%s",
uriAuthority, "api/2.0/mlflow/users/get", credential.username());
try {
logger.atInfo().log(
"url: %s, username: %s, password: %s",
url, credential.username(), credential.password().orElse(""));
HttpResponse response = sendRequestWithCredentials(url, credential);
return response.status().isSuccess()
&& response
.bodyString()
.map(MlFlowCredentialTester::bodyContainsSuccessfulUserInfo)
.orElse(false);
} catch (IOException e) {
logger.atWarning().withCause(e).log("Unable to query '%s'.", url);
return false;
}
}

private HttpResponse sendRequestWithCredentials(String url, TestCredential credential)
throws IOException {
// For testing no-auth configured case, no auth header is passed in
if (Strings.isNullOrEmpty(credential.username())
&& Strings.isNullOrEmpty(credential.password().orElse(""))) {
return httpClient.send(post(url).withEmptyHeaders().build());
}

return httpClient.send(
get(url)
.setHeaders(
HttpHeaders.builder()
.addHeader(
"Authorization",
"basic "
+ Base64.getEncoder()
.encodeToString(
(credential.username() + ":" + credential.password().orElse(""))
.getBytes(UTF_8)))
.build())
.build());
}

/**
* A successful authenticated request to the /api/2.0/mlflow/users/get?username=admin endpoint
* returns a JSON with a root key like the following:
* {"user":{"experiment_permissions":[],"id":1,"is_admin":true,"registered_model_permissions":[],
* "username":"admin"}}
*/
private static boolean bodyContainsSuccessfulUserInfo(String responseBody) {
try {
JsonObject response = JsonParser.parseString(responseBody).getAsJsonObject();

if (response.has("user")) {
logger.atInfo().log("Successfully received a mlflow user info");
return true;
} else {
return false;
}
} catch (JsonSyntaxException e) {
logger.atWarning().withCause(e).log(
"An error occurred while parsing the json response: %s", responseBody);
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.testers.mlflow;

import static com.google.common.truth.Truth.assertThat;
import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.flogger.GoogleLogger;
import com.google.inject.Guice;
Expand All @@ -34,19 +41,12 @@
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

import javax.inject.Inject;
import java.io.IOException;
import java.sql.Connection;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.truth.Truth.assertThat;
import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

/**
* Tests for {@link MlFlowCredentialTester}.
*/
Expand Down

0 comments on commit 2576c7f

Please sign in to comment.