From 8fa07f88e1dc8c8f072e92d28a69d7294cf14ee5 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Wed, 4 Oct 2023 17:52:22 -0700 Subject: [PATCH] handle case where user accidentally sets doesVersionCreateModelGroup to true Signed-off-by: Bhavana Ramaram --- .../TransportRegisterModelAction.java | 1 + .../TransportRegisterModelActionTests.java | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a6e1cebe85..c739a29faa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -262,6 +262,7 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis listener.onFailure(e); })); } else { + registerModelInput.setDoesVersionCreateModelGroup(false); registerModel(registerModelInput, listener); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 66c78d2431..5153b2845b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -495,6 +495,42 @@ public void test_ModelNameAlreadyExists() throws IOException { verify(actionListener).onResponse(argumentCaptor.capture()); } + public void test_DoesVersionCreateModelGroupFieldSetToTrueByUserByMistake() throws IOException { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId2"); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.BATCH_RCF) + .modelGroupId("model_group_ID") + .modelName("Test Model") + .modelConfig( + new TextEmbeddingModelConfig( + "CUSTOM", + 123, + TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + "all config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ) + ) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .url("http://test_url") + .doesVersionCreateModelGroup(true) + .build(); + + transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> {