From f80b6b243775ebad1a596a06c76d11024395546e Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Thu, 6 Jun 2024 14:42:15 -0700 Subject: [PATCH] adding pretrained model templates --- build.gradle | 2 + .../flowframework/common/DefaultUseCases.java | 15 +++ .../defaults/hybrid-search-defaults.json | 3 +- ...brid-search-with-local-model-defaults.json | 23 ++++ .../defaults/multi-modal-search-defaults.json | 4 +- ...timodal-search-bedrock-titan-defaults.json | 4 +- ...ntic-search-with-local-model-defaults.json | 20 ++++ .../hybrid-search-template.json | 8 +- ...brid-search-with-local-model-template.json | 109 ++++++++++++++++++ .../multi-modal-search-template.json | 7 +- ...al-search-with-bedrock-titan-template.json | 7 +- ...eural-sparse-local-biencoder-template.json | 3 - .../semantic-search-template.json | 3 - ...ntic-search-with-local-model-template.json | 86 ++++++++++++++ ...ith-model-and-query-enricher-template.json | 3 - .../semantic-search-with-model-template.json | 3 - ...c-search-with-query-enricher-template.json | 3 - .../FlowFrameworkRestTestCase.java | 71 +++++++++++- .../rest/FlowFrameworkRestApiIT.java | 80 ++++++++++++- 19 files changed, 413 insertions(+), 41 deletions(-) create mode 100644 src/main/resources/defaults/hybrid-search-with-local-model-defaults.json create mode 100644 src/main/resources/defaults/semantic-search-with-local-model-defaults.json create mode 100644 src/main/resources/substitutionTemplates/hybrid-search-with-local-model-template.json create mode 100644 src/main/resources/substitutionTemplates/semantic-search-with-local-model-template.json diff --git a/build.gradle b/build.gradle index 4f5f30a59..f19f52318 100644 --- a/build.gradle +++ b/build.gradle @@ -181,6 +181,8 @@ dependencies { // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${opensearch_build}" secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" configurations.all { diff --git a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java index bc88f2b4d..c2c3abdb7 100644 --- a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java +++ b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java @@ -132,6 +132,21 @@ public enum DefaultUseCases { "defaults/conversational-search-defaults.json", "substitutionTemplates/conversational-search-with-cohere-model-template.json", List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) + ), + /** defaults file and substitution ready template for semantic search with a local pretrained model*/ + SEMANTIC_SEARCH_WITH_LOCAL_MODEL( + "semantic_search_with_local_model", + "defaults/semantic-search-with-local-model-defaults.json", + "substitutionTemplates/semantic-search-with-local-model-template.json", + Collections.emptyList() + + ), + /** defaults file and substitution ready template for hybrid search with a local pretrained model*/ + HYBRID_SEARCH_WITH_LOCAL_MODEL( + "hybrid_search_with_local_model", + "defaults/hybrid-search-with-local-model-defaults.json", + "substitutionTemplates/hybrid-search-with-local-model-template.json", + Collections.emptyList() ); private final String useCaseName; diff --git a/src/main/resources/defaults/hybrid-search-defaults.json b/src/main/resources/defaults/hybrid-search-defaults.json index cf9fb584b..b64bce6ae 100644 --- a/src/main/resources/defaults/hybrid-search-defaults.json +++ b/src/main/resources/defaults/hybrid-search-defaults.json @@ -14,6 +14,5 @@ "text_embedding.field_map.output.dimension": "1024", "create_search_pipeline.pipeline_id": "nlp-search-pipeline", "normalization-processor.normalization.technique": "min_max", - "normalization-processor.combination.technique": "arithmetic_mean", - "normalization-processor.combination.parameters.weights": "[0.3, 0.7]" + "normalization-processor.combination.technique": "arithmetic_mean" } diff --git a/src/main/resources/defaults/hybrid-search-with-local-model-defaults.json b/src/main/resources/defaults/hybrid-search-with-local-model-defaults.json new file mode 100644 index 000000000..26b389a29 --- /dev/null +++ b/src/main/resources/defaults/hybrid-search-with-local-model-defaults.json @@ -0,0 +1,23 @@ +{ + "template.name": "hybrid-search", + "template.description": "Setting up hybrid search, ingest pipeline and index", + "register_local_pretrained_model.name": "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b", + "register_local_pretrained_model.description": "This is a sentence transformer model", + "register_local_pretrained_model.model_format": "TORCH_SCRIPT", + "register_local_pretrained_model.deploy": "true", + "register_local_pretrained_model.version": "1.0.2", + "create_ingest_pipeline.pipeline_id": "nlp-ingest-pipeline", + "create_ingest_pipeline.description": "A text embedding pipeline", + "create_ingest_pipeline.model_id": "123", + "text_embedding.field_map.input": "passage_text", + "text_embedding.field_map.output": "passage_embedding", + "create_index.name": "my-nlp-index", + "create_index.settings.number_of_shards": "2", + "create_index.mappings.method.engine": "lucene", + "create_index.mappings.method.space_type": "l2", + "create_index.mappings.method.name": "hnsw", + "text_embedding.field_map.output.dimension": "768", + "create_search_pipeline.pipeline_id": "nlp-search-pipeline", + "normalization-processor.normalization.technique": "min_max", + "normalization-processor.combination.technique": "arithmetic_mean" +} diff --git a/src/main/resources/defaults/multi-modal-search-defaults.json b/src/main/resources/defaults/multi-modal-search-defaults.json index 0588e7182..4e0f86449 100644 --- a/src/main/resources/defaults/multi-modal-search-defaults.json +++ b/src/main/resources/defaults/multi-modal-search-defaults.json @@ -11,5 +11,7 @@ "create_index.settings.number_of_shards": "2", "text_image_embedding.field_map.output.dimension": "1024", "create_index.mappings.method.engine": "lucene", - "create_index.mappings.method.name": "hnsw" + "create_index.mappings.method.name": "hnsw", + "text_image_embedding.field_map.image.type": "text", + "text_image_embedding.field_map.text.type": "text" } diff --git a/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json b/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json index b6d6a0ff9..3a6a09b21 100644 --- a/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json +++ b/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json @@ -24,5 +24,7 @@ "create_index.settings.number_of_shards": "2", "text_image_embedding.field_map.output.dimension": "1024", "create_index.mappings.method.engine": "lucene", - "create_index.mappings.method.name": "hnsw" + "create_index.mappings.method.name": "hnsw", + "text_image_embedding.field_map.image.type": "text", + "text_image_embedding.field_map.text.type": "text" } diff --git a/src/main/resources/defaults/semantic-search-with-local-model-defaults.json b/src/main/resources/defaults/semantic-search-with-local-model-defaults.json new file mode 100644 index 000000000..5330d04a5 --- /dev/null +++ b/src/main/resources/defaults/semantic-search-with-local-model-defaults.json @@ -0,0 +1,20 @@ +{ + "template.name": "semantic search with local pretrained model", + "template.description": "Setting up semantic search, with a local pretrained embedding model", + "register_local_pretrained_model.name": "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b", + "register_local_pretrained_model.description": "This is a sentence transformer model", + "register_local_pretrained_model.model_format": "TORCH_SCRIPT", + "register_local_pretrained_model.deploy": "true", + "register_local_pretrained_model.version": "1.0.2", + "create_ingest_pipeline.pipeline_id": "nlp-ingest-pipeline", + "create_ingest_pipeline.description": "A text embedding pipeline", + "text_embedding.field_map.input": "passage_text", + "text_embedding.field_map.output": "passage_embedding", + "create_index.name": "my-nlp-index", + "create_index.settings.number_of_shards": "2", + "create_index.mappings.method.engine": "lucene", + "create_index.mappings.method.space_type": "l2", + "create_index.mappings.method.name": "hnsw", + "text_embedding.field_map.output.dimension": "768", + "create_search_pipeline.pipeline_id": "default_model_pipeline" +} diff --git a/src/main/resources/substitutionTemplates/hybrid-search-template.json b/src/main/resources/substitutionTemplates/hybrid-search-template.json index 9e16f1d09..1669ba7a7 100644 --- a/src/main/resources/substitutionTemplates/hybrid-search-template.json +++ b/src/main/resources/substitutionTemplates/hybrid-search-template.json @@ -50,9 +50,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_embedding.field_map.output}}": { "type": "knn_vector", "dimension": "${{text_embedding.field_map.output.dimension}}", @@ -86,10 +83,7 @@ "technique": "${{normalization-processor.normalization.technique}}" }, "combination": { - "technique": "${{normalization-processor.combination.technique}}", - "parameters": { - "weights": "${{normalization-processor.combination.parameters.weights}}" - } + "technique": "${{normalization-processor.combination.technique}}" } } } diff --git a/src/main/resources/substitutionTemplates/hybrid-search-with-local-model-template.json b/src/main/resources/substitutionTemplates/hybrid-search-with-local-model-template.json new file mode 100644 index 000000000..457746ab4 --- /dev/null +++ b/src/main/resources/substitutionTemplates/hybrid-search-with-local-model-template.json @@ -0,0 +1,109 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "HYBRID_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "register_local_pretrained_model", + "type": "register_local_pretrained_model", + "user_inputs": { + "name": "${{register_local_pretrained_model.name}}", + "version": "${{register_local_pretrained_model.version}}", + "description": "${{register_local_pretrained_model.description}}", + "model_format": "${{register_local_pretrained_model.model_format}}", + "deploy": true + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "register_local_pretrained_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "${{create_ingest_pipeline.pipeline_id}}", + "configurations": { + "description": "${{create_ingest_pipeline.description}}", + "processors": [ + { + "text_embedding": { + "model_id": "${{register_local_pretrained_model.model_id}}", + "field_map": { + "${{text_embedding.field_map.input}}": "${{text_embedding.field_map.output}}" + } + } + } + ] + } + } + }, + { + "id": "create_index", + "type": "create_index", + "previous_node_inputs": { + "create_ingest_pipeline": "pipeline_id" + }, + "user_inputs": { + "index_name": "${{create_index.name}}", + "configurations": { + "settings": { + "index.knn": true, + "default_pipeline": "${{create_ingest_pipeline.pipeline_id}}", + "number_of_shards": "${{create_index.settings.number_of_shards}}", + "index.search.default_pipeline": "${{create_search_pipeline.pipeline_id}}" + }, + "mappings": { + "properties": { + "${{text_embedding.field_map.output}}": { + "type": "knn_vector", + "dimension": "${{text_embedding.field_map.output.dimension}}", + "method": { + "engine": "${{create_index.mappings.method.engine}}", + "space_type": "${{create_index.mappings.method.space_type}}", + "name": "${{create_index.mappings.method.name}}", + "parameters": {} + } + }, + "${{text_embedding.field_map.input}}": { + "type": "text" + } + } + } + } + } + }, + { + "id": "create_search_pipeline", + "type": "create_search_pipeline", + "user_inputs": { + "pipeline_id": "${{create_search_pipeline.pipeline_id}}", + "configurations": { + "description": "Post processor for hybrid search", + "phase_results_processors": [ + { + "normalization-processor": { + "normalization": { + "technique": "${{normalization-processor.normalization.technique}}" + }, + "combination": { + "technique": "${{normalization-processor.combination.technique}}" + } + } + } + ] + } + } + } + ] + } + } +} diff --git a/src/main/resources/substitutionTemplates/multi-modal-search-template.json b/src/main/resources/substitutionTemplates/multi-modal-search-template.json index f6a14dc75..bad7f4a52 100644 --- a/src/main/resources/substitutionTemplates/multi-modal-search-template.json +++ b/src/main/resources/substitutionTemplates/multi-modal-search-template.json @@ -51,9 +51,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_image_embedding.embedding}}": { "type": "knn_vector", "dimension": "${{text_image_embedding.field_map.output.dimension}}", @@ -64,10 +61,10 @@ } }, "${{text_image_embedding.field_map.text}}": { - "type": "text" + "type": "${{text_image_embedding.field_map.text.type}}" }, "${{text_image_embedding.field_map.image}}": { - "type": "binary" + "type": "${{text_image_embedding.field_map.image.type}}" } } } diff --git a/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json b/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json index da85a9387..e36370a73 100644 --- a/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json +++ b/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json @@ -101,9 +101,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_image_embedding.embedding}}": { "type": "knn_vector", "dimension": "${{text_image_embedding.field_map.output.dimension}}", @@ -114,10 +111,10 @@ } }, "${{text_image_embedding.field_map.text}}": { - "type": "text" + "type": "${{text_image_embedding.field_map.text.type}}" }, "${{text_image_embedding.field_map.image}}": { - "type": "binary" + "type": "${{text_image_embedding.field_map.image.type}}" } } } diff --git a/src/main/resources/substitutionTemplates/neural-sparse-local-biencoder-template.json b/src/main/resources/substitutionTemplates/neural-sparse-local-biencoder-template.json index 603e462ee..737d2f438 100644 --- a/src/main/resources/substitutionTemplates/neural-sparse-local-biencoder-template.json +++ b/src/main/resources/substitutionTemplates/neural-sparse-local-biencoder-template.json @@ -61,9 +61,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{create_ingest_pipeline.text_embedding.field_map.output}}": { "type": "rank_features" }, diff --git a/src/main/resources/substitutionTemplates/semantic-search-template.json b/src/main/resources/substitutionTemplates/semantic-search-template.json index 3aa7095e1..d592f1ec1 100644 --- a/src/main/resources/substitutionTemplates/semantic-search-template.json +++ b/src/main/resources/substitutionTemplates/semantic-search-template.json @@ -49,9 +49,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_embedding.field_map.output}}": { "type": "knn_vector", "dimension": "${{text_embedding.field_map.output.dimension}}", diff --git a/src/main/resources/substitutionTemplates/semantic-search-with-local-model-template.json b/src/main/resources/substitutionTemplates/semantic-search-with-local-model-template.json new file mode 100644 index 000000000..125554b78 --- /dev/null +++ b/src/main/resources/substitutionTemplates/semantic-search-with-local-model-template.json @@ -0,0 +1,86 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "SEMANTIC_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "register_local_pretrained_model", + "type": "register_local_pretrained_model", + "user_inputs": { + "name": "${{register_local_pretrained_model.name}}", + "version": "${{register_local_pretrained_model.version}}", + "description": "${{register_local_pretrained_model.description}}", + "model_format": "${{register_local_pretrained_model.model_format}}", + "deploy": true + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "previous_node_inputs": { + "register_local_pretrained_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "${{create_ingest_pipeline.pipeline_id}}", + "configurations": { + "description": "${{create_ingest_pipeline.description}}", + "processors": [ + { + "text_embedding": { + "model_id": "${{register_local_pretrained_model.model_id}}", + "field_map": { + "${{text_embedding.field_map.input}}": "${{text_embedding.field_map.output}}" + } + } + } + ] + } + } + }, + { + "id": "create_index", + "type": "create_index", + "previous_node_inputs": { + "create_ingest_pipeline": "pipeline_id" + }, + "user_inputs": { + "index_name": "${{create_index.name}}", + "configurations": { + "settings": { + "index.knn": true, + "default_pipeline": "${{create_ingest_pipeline.pipeline_id}}", + "number_of_shards": "${{create_index.settings.number_of_shards}}" + }, + "mappings": { + "properties": { + "${{text_embedding.field_map.output}}": { + "type": "knn_vector", + "dimension": "${{text_embedding.field_map.output.dimension}}", + "method": { + "engine": "${{create_index.mappings.method.engine}}", + "space_type": "${{create_index.mappings.method.space_type}}", + "name": "${{create_index.mappings.method.name}}", + "parameters": {} + } + }, + "${{text_embedding.field_map.input}}": { + "type": "text" + } + } + } + } + } + } + ] + } + } +} diff --git a/src/main/resources/substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json b/src/main/resources/substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json index f75b58e06..71f8286cd 100644 --- a/src/main/resources/substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json +++ b/src/main/resources/substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json @@ -99,9 +99,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_embedding.field_map.output}}": { "type": "knn_vector", "dimension": "${{text_embedding.field_map.output.dimension}}", diff --git a/src/main/resources/substitutionTemplates/semantic-search-with-model-template.json b/src/main/resources/substitutionTemplates/semantic-search-with-model-template.json index f98c68659..c2261c475 100644 --- a/src/main/resources/substitutionTemplates/semantic-search-with-model-template.json +++ b/src/main/resources/substitutionTemplates/semantic-search-with-model-template.json @@ -98,9 +98,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_embedding.field_map.output}}": { "type": "knn_vector", "dimension": "${{text_embedding.field_map.output.dimension}}", diff --git a/src/main/resources/substitutionTemplates/semantic-search-with-query-enricher-template.json b/src/main/resources/substitutionTemplates/semantic-search-with-query-enricher-template.json index 4244cd791..6e33d04c5 100644 --- a/src/main/resources/substitutionTemplates/semantic-search-with-query-enricher-template.json +++ b/src/main/resources/substitutionTemplates/semantic-search-with-query-enricher-template.json @@ -67,9 +67,6 @@ "mappings": { "_doc": { "properties": { - "id": { - "type": "text" - }, "${{text_embedding.field_map.output}}": { "type": "knn_vector", "dimension": "${{text_embedding.field_map.output.dimension}}", diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 9a1d89c2e..922c26b0f 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -48,6 +48,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.junit.After; @@ -350,7 +351,7 @@ protected Response createWorkflow(RestClient client, Template template) throws E * @throws Exception if the request fails * @return a rest response */ - protected Response createWorkflowWithUseCase(RestClient client, String useCase, List params) throws Exception { + protected Response createWorkflowWithUseCaseWithNoValidation(RestClient client, String useCase, List params) throws Exception { StringBuilder sb = new StringBuilder(); for (String param : params) { @@ -370,6 +371,28 @@ protected Response createWorkflowWithUseCase(RestClient client, String useCase, ); } + /** + * Helper method to invoke the create workflow API with a use case and also the provision param as true + * @param client the rest client + * @param useCase the usecase to create + * @param defaults the defaults to override given through the request payload + * @throws Exception if the request fails + * @return a rest response + */ + protected Response createAndProvisionWorkflowWithUseCaseWithContent(RestClient client, String useCase, Map defaults) + throws Exception { + String payload = ParseUtils.parseArbitraryStringToObjectMapToString(defaults); + + return TestHelpers.makeRequest( + client, + "POST", + WORKFLOW_URI + "?provision=true&use_case=" + useCase, + Collections.emptyMap(), + payload, + null + ); + } + /** * Helper method to invoke the Create Workflow Rest Action with provision * @param client the rest client @@ -742,6 +765,52 @@ protected GetPipelineResponse getPipelines(String pipelineId) throws IOException } } + protected void ingestSingleDoc(String payload, String indexName) throws IOException { + try { + TestHelpers.makeRequest( + client(), + "PUT", + indexName + "/_doc/1", + null, + payload, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected SearchResponse neuralSearchRequest(String indexName, String modelId) throws IOException { + String searchRequest = + "{\"_source\":{\"excludes\":[\"passage_embedding\"]},\"query\":{\"neural\":{\"passage_embedding\":{\"query_text\":\"world\",\"k\":5,\"model_id\":\"" + + modelId + + "\"}}}}"; + try { + Response restSearchResponse = TestHelpers.makeRequest( + client(), + "POST", + indexName + "/_search", + null, + searchRequest, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + // Parse entity content into SearchResponse + MediaType mediaType = MediaType.fromMediaType(restSearchResponse.getEntity().getContentType()); + try ( + XContentParser parser = mediaType.xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + restSearchResponse.getEntity().getContent() + ) + ) { + return SearchResponse.fromXContent(parser); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("unchecked") protected List catPlugins() throws IOException { Response response = TestHelpers.makeRequest( diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 7f51cd276..1ddf67c2a 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -34,6 +34,7 @@ import java.time.Instant; import java.util.Collections; import java.util.EnumSet; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -429,7 +430,11 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception { public void testDefaultCohereUseCase() throws Exception { // Hit Create Workflow API with original template - Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy", List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)); + Response response = createWorkflowWithUseCaseWithNoValidation( + client(), + "cohere_embedding_model_deploy", + List.of(CREATE_CONNECTOR_CREDENTIAL_KEY) + ); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); @@ -468,7 +473,7 @@ public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Excepti // Hit Create Workflow API with original template without required params ResponseException exception = expectThrows( ResponseException.class, - () -> createWorkflowWithUseCase(client(), "semantic_search", Collections.emptyList()) + () -> createWorkflowWithUseCaseWithNoValidation(client(), "semantic_search", Collections.emptyList()) ); assertTrue( exception.getMessage() @@ -476,7 +481,11 @@ public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Excepti ); // Pass in required params - Response response = createWorkflowWithUseCase(client(), "semantic_search", List.of(CREATE_INGEST_PIPELINE_MODEL_ID)); + Response response = createWorkflowWithUseCaseWithNoValidation( + client(), + "semantic_search", + List.of(CREATE_INGEST_PIPELINE_MODEL_ID) + ); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); @@ -502,7 +511,7 @@ public void testAllDefaultUseCasesCreation() throws Exception { .collect(Collectors.toSet()); for (String useCaseName : allUseCaseNames) { - Response response = createWorkflowWithUseCase( + Response response = createWorkflowWithUseCaseWithNoValidation( client(), useCaseName, DefaultUseCases.getRequiredParamsByUseCaseName(useCaseName) @@ -514,4 +523,67 @@ public void testAllDefaultUseCasesCreation() throws Exception { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); } } + + public void testSemanticSearchWithLocalModelEndToEnd() throws Exception { + + Map defaults = new HashMap<>(); + defaults.put("register_local_pretrained_model.name", "huggingface/sentence-transformers/all-MiniLM-L6-v2"); + defaults.put("register_local_pretrained_model.version", "1.0.1"); + defaults.put("text_embedding.field_map.output.dimension", 384); + + Response response = createAndProvisionWorkflowWithUseCaseWithContent(client(), "semantic_search_with_local_model", defaults); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(client(), workflowId, 45); + + // This template should create 4 resources, registered model_id, deployed model_id, ingest pipeline, and index name + assertEquals(4, resourcesCreated.size()); + String modelId = resourcesCreated.get(1).resourceId(); + String indexName = resourcesCreated.get(3).resourceId(); + + // Short wait before ingesting data + Thread.sleep(30000); + + String docContent = "{\"passage_text\": \"Hello planet\"\n}"; + ingestSingleDoc(docContent, indexName); + // Short wait before neural search + Thread.sleep(500); + SearchResponse neuralSearchResponse = neuralSearchRequest(indexName, modelId); + assertEquals(neuralSearchResponse.getHits().getHits().length, 1); + Thread.sleep(500); + deleteIndex(indexName); + + // Hit Deprovision API + // By design, this may not completely deprovision the first time if it takes >2s to process removals + Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); + try { + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 30, + TimeUnit.SECONDS + ); + } catch (ComparisonFailure e) { + // 202 return if still processing + assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse)); + } + if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) { + // Short wait before we try again + Thread.sleep(10000); + deprovisionResponse = deprovisionWorkflow(client(), workflowId); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 30, + TimeUnit.SECONDS + ); + } + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + // Hit Delete API + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + } }