Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mlflow weak credential tester #455

Merged
merged 7 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ public boolean canAccept(NetworkService networkService) {
boolean canAcceptByCustomFingerprint = false;
logger.atInfo().log("probing Mlflow ping - custom fingerprint phase");

// we want to test mlflow versions above 2.5 which has basic authentication module
// these versions returned a 401 status code and a link to documentation about how to
// authenticate.
// We want to test weak credentials against mlflow versions above 2.5 which has basic
// authentication module.these versions return a 401 status code and a link to documentation
// about how to authenticate.
var uriAuthority = NetworkEndpointUtils.toUriAuthority(networkService.getNetworkEndpoint());
var pingApiUrl = String.format("http://%s/%s", uriAuthority, "ping");
try {
Expand All @@ -85,8 +85,9 @@ public boolean canAccept(NetworkService networkService) {
.bodyString()
.get()
.contains(
"You are not authenticated. "
+ "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow "
"You are not authenticated. Please see "
+ "https://www.mlflow.org/docs/latest/auth/index.html"
+ "#authenticating-to-mlflow"
pisqu4red marked this conversation as resolved.
Show resolved Hide resolved
+ "on how to authenticate");
}
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.flogger.GoogleLogger;
import com.google.inject.Guice;
import com.google.tsunami.common.net.db.ConnectionProviderInterface;
import com.google.tsunami.common.net.http.HttpClientModule;
import com.google.tsunami.plugins.detectors.credentials.genericweakcredentialdetector.provider.TestCredential;
import com.google.tsunami.proto.NetworkService;
import java.io.IOException;
import java.sql.Connection;
import java.util.Objects;
import java.util.Optional;
import javax.inject.Inject;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
Expand All @@ -41,168 +45,160 @@
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
import javax.inject.Inject;
import java.io.IOException;
import java.sql.Connection;
import java.util.Objects;
import java.util.Optional;

/**
* Tests for {@link MlFlowCredentialTester}.
*/
/** Tests for {@link MlFlowCredentialTester}. */
@RunWith(JUnit4.class)
public class MlFlowCredentialTesterTest {
@Rule
public MockitoRule rule = MockitoJUnit.rule();
@Mock
private ConnectionProviderInterface mockConnectionProvider;
@Mock
private Connection mockConnection;
@Inject
private MlFlowCredentialTester tester;
private MockWebServer mockWebServer;
private static final TestCredential WEAK_CRED_1 =
TestCredential.create("admin", Optional.of("password"));
private static final TestCredential WEAK_CRED_2 =
TestCredential.create("username", Optional.of("password"));
private static final TestCredential WRONG_CRED_1 =
TestCredential.create("wrong", Optional.of("wrong"));

private static final String WEAK_CRED_AUTH_1 = "basic dXNlcm5hbWU6cGFzc3dvcmQ=";
private static final String WEAK_CRED_AUTH_2 = "basic YWRtaW46cGFzc3dvcmQ=";
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();

@Before
public void setup() {
mockWebServer = new MockWebServer();
Guice.createInjector(new HttpClientModule.Builder().build()).injectMembers(this);
}

@Test
public void detect_weakCredentialsExists_returnsWeakCredentials() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1)))
.containsExactly(WEAK_CRED_1);
mockWebServer.shutdown();
}

@Test
public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(
tester.testValidCredentials(
targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2)))
.containsExactly(WEAK_CRED_1);
}

@Test
public void detect_canAccept() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(tester.canAccept(targetNetworkService)).isTrue();
}

@Test
public void detect_weakCredentialsExistAndMlflowInForeignLanguage_returnsFirstWeakCredentials()
throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(
tester.testValidCredentials(
targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2)))
.containsExactly(WEAK_CRED_1);
}

@Test
public void detect_noWeakCredentials_returnsNoCredentials() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();
assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WRONG_CRED_1)))
.isEmpty();
}

@Test
public void detect_nonMlflowService_skips() throws Exception {
when(mockConnectionProvider.getConnection(any(), any(), any())).thenReturn(mockConnection);
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(forHostnameAndPort("example.com", 8080))
.setServiceName("http")
.build();

assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1)))
.isEmpty();
verifyNoInteractions(mockConnectionProvider);
}

private void startMockWebServer() throws IOException {
final Dispatcher dispatcher =
new Dispatcher() {
final MockResponse unauthorizedResponse =
new MockResponse()
.setResponseCode(401)
.setBody(
"You are not authenticated. "
+ "Please see https://www.mlflow.org/docs/latest/auth/index.html#authenticating-to-mlflow "
+ "on how to authenticate");

@Override
public MockResponse dispatch(RecordedRequest request) {
String authorizationHeader = request.getHeaders().get("Authorization");
if (authorizationHeader == null) {
return unauthorizedResponse;
}
if (request.getPath().matches("/api/2.0/mlflow/users/get\\?.*")
&& Objects.equals(request.getMethod(), "GET")) {
boolean isDefaultCredentials =
authorizationHeader.equals(WEAK_CRED_AUTH_1)
|| authorizationHeader.equals(WEAK_CRED_AUTH_2);
if (isDefaultCredentials) {
return new MockResponse()
.setResponseCode(200)
.setBody(
"{\"user\":{\"experiment_permissions\":[],\"id\":1,\"is_admin\":true,\"registered_model_permissions\":[],"
+ "\"username\":\"admin\"}}");
} else {
return unauthorizedResponse;
}
}
return new MockResponse().setResponseCode(404);
}
};
mockWebServer.setDispatcher(dispatcher);
mockWebServer.start();
mockWebServer.url("/");
}
@Rule public MockitoRule rule = MockitoJUnit.rule();
@Mock private ConnectionProviderInterface mockConnectionProvider;
@Mock private Connection mockConnection;
@Inject private MlFlowCredentialTester tester;
private MockWebServer mockWebServer;
private static final TestCredential WEAK_CRED_1 =
TestCredential.create("admin", Optional.of("password"));
private static final TestCredential WEAK_CRED_2 =
TestCredential.create("username", Optional.of("password"));
private static final TestCredential WRONG_CRED_1 =
TestCredential.create("wrong", Optional.of("wrong"));

// The base64 encoding of default authentication username:password pairs which the tester will
// send these headers to our mock webserver
private static final String WEAK_CRED_AUTH_1 = "basic dXNlcm5hbWU6cGFzc3dvcmQ=";
private static final String WEAK_CRED_AUTH_2 = "basic YWRtaW46cGFzc3dvcmQ=";

@Before
public void setup() {
mockWebServer = new MockWebServer();
Guice.createInjector(new HttpClientModule.Builder().build()).injectMembers(this);
}

@Test
public void detect_weakCredentialsExists_returnsWeakCredentials() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1)))
.containsExactly(WEAK_CRED_1);
mockWebServer.shutdown();
}

@Test
public void detect_weakCredentialsExist_returnsFirstWeakCredentials() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(
tester.testValidCredentials(
targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2)))
.containsExactly(WEAK_CRED_1);
}

@Test
public void detect_canAccept() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(tester.canAccept(targetNetworkService)).isTrue();
}

@Test
public void detect_weakCredentialsExistAndMlflowInForeignLanguage_returnsFirstWeakCredentials()
throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();

assertThat(
tester.testValidCredentials(
targetNetworkService, ImmutableList.of(WEAK_CRED_1, WEAK_CRED_2)))
.containsExactly(WEAK_CRED_1);
}

@Test
public void detect_noWeakCredentials_returnsNoCredentials() throws Exception {
startMockWebServer();
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(
forHostnameAndPort(mockWebServer.getHostName(), mockWebServer.getPort()))
.setServiceName("http")
.build();
assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WRONG_CRED_1)))
.isEmpty();
}

@Test
public void detect_nonMlflowService_skips() throws Exception {
when(mockConnectionProvider.getConnection(any(), any(), any())).thenReturn(mockConnection);
NetworkService targetNetworkService =
NetworkService.newBuilder()
.setNetworkEndpoint(forHostnameAndPort("example.com", 8080))
.setServiceName("http")
.build();

assertThat(tester.testValidCredentials(targetNetworkService, ImmutableList.of(WEAK_CRED_1)))
.isEmpty();
verifyNoInteractions(mockConnectionProvider);
}

private void startMockWebServer() throws IOException {
final Dispatcher dispatcher =
new Dispatcher() {
final MockResponse unauthorizedResponse =
new MockResponse()
.setResponseCode(401)
.setBody(
"You are not authenticated. "
+ "Please see https://www.mlflow.org/docs/latest/auth/index.html"
+ "#authenticating-to-mlflow "
+ "on how to authenticate");

@Override
public MockResponse dispatch(RecordedRequest request) {
String authorizationHeader = request.getHeaders().get("Authorization");
if (authorizationHeader == null) {
return unauthorizedResponse;
}
if (request.getPath().matches("/api/2.0/mlflow/users/get\\?.*")
&& Objects.equals(request.getMethod(), "GET")) {
boolean isDefaultCredentials =
authorizationHeader.equals(WEAK_CRED_AUTH_1)
|| authorizationHeader.equals(WEAK_CRED_AUTH_2);
if (isDefaultCredentials) {
return new MockResponse()
.setResponseCode(200)
.setBody(
"{\"user\":{\"experiment_permissions\":[],\"id\":1,\"is_admin\":true,"
+ "\"registered_model_permissions\":[],"
+ "\"username\":\"admin\"}}");
} else {
return unauthorizedResponse;
}
}
return new MockResponse().setResponseCode(404);
}
};
mockWebServer.setDispatcher(dispatcher);
mockWebServer.start();
mockWebServer.url("/");
}
}