diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index 54a72cf38..9fbc0f6d5 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -31,7 +31,6 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; /** * Abstract class to handle search request. @@ -85,7 +84,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); - searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index 08dcf28f7..d713e8f48 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -17,12 +17,14 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -80,7 +82,8 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java index 3ce1ceffa..afe3b85d4 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java @@ -17,10 +17,15 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; + /** * Transport Action to search workflow states */ @@ -45,8 +50,10 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { // AccessController should take care of letting the user with right permission to view the workflow + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Searching workflow states in global context"); + SearchSourceBuilder searchSourceBuilder = request.source(); + searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); client.search(request, ActionListener.runBefore(actionListener, context::restore)); } catch (Exception e) { logger.error("Failed to search workflow states in global context", e); diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java index 28dad8c78..41a8b23f9 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java @@ -17,10 +17,15 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; + /** * Transport Action to search workflows created */ @@ -45,8 +50,11 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { // AccessController should take care of letting the user with right permission to view the workflow + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { logger.info("Searching workflows in global context"); + SearchSourceBuilder searchSourceBuilder = request.source(); + searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); client.search(request, ActionListener.runBefore(actionListener, context::restore)); } catch (Exception e) { logger.error("Failed to search workflows in global context", e); diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java index 8be8a5efc..946c68908 100644 --- a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -17,6 +17,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -201,10 +202,11 @@ String decrypt(final String encryptedCredential) { // TODO : Improve redactTemplateCredentials to redact different fields /** * Removes the credential fields from a template + * @param user User * @param template the template * @return the redacted template */ - public Template redactTemplateSecuredFields(Template template) { + public Template redactTemplateSecuredFields(User user, Template template) { Map processedWorkflows = new HashMap<>(); for (Map.Entry entry : template.workflows().entrySet()) { @@ -228,6 +230,10 @@ public Template redactTemplateSecuredFields(Template template) { processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges())); } + if (ParseUtils.isAdmin(user)) { + return new Template.Builder(template).workflows(processedWorkflows).build(); + } + return new Template.Builder(template).user(null).workflows(processedWorkflows).build(); } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index f7a1da0d4..ccf9ab686 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -25,7 +25,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.ml.common.agent.LLMSpec; import java.io.FileNotFoundException; import java.io.IOException; @@ -47,8 +46,6 @@ import jakarta.json.bind.JsonbBuilder; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; -import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; /** * Utility methods for Template parsing @@ -113,6 +110,18 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map parameters = llm.getParameters(); - xContentBuilder.field(MODEL_ID, modelId); - xContentBuilder.field(PARAMETERS_FIELD); - buildStringToStringMap(xContentBuilder, parameters); - } - /** * Parses an XContent object representing a map of String keys to String values. * diff --git a/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java b/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java index 6fb685bc0..e9dfa38c1 100644 --- a/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java @@ -9,9 +9,9 @@ package org.opensearch.flowframework.util; import org.apache.commons.lang3.ArrayUtils; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.Strings; import org.opensearch.flowframework.common.CommonValue; -import org.opensearch.rest.RestRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -36,11 +36,11 @@ private RestHandlerUtils() {} /** * Creates a source context and include/exclude information to be shared based on the user * - * @param request the REST request + * @param user User * @param searchSourceBuilder the search request source builder * @return modified sources */ - public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) { + public static FetchSourceContext getSourceContext(User user, SearchSourceBuilder searchSourceBuilder) { // TODO // 1. check if the request came from dashboard and exclude UI_METADATA if (searchSourceBuilder.fetchSource() != null) { @@ -48,6 +48,9 @@ public static FetchSourceContext getSourceContext(RestRequest request, SearchSou return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray); } else { // When user does not set the _source field in search api request, searchSourceBuilder.fetchSource becomes null + if (ParseUtils.isAdmin(user)) { + return new FetchSourceContext(true, Strings.EMPTY_ARRAY, new String[] { PATH_TO_CREDENTIAL_FIELD }); + } return new FetchSourceContext(true, Strings.EMPTY_ARRAY, EXCLUDES); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java index 69d0e6bc3..f3f55c052 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -71,6 +72,8 @@ public void testSearchWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); verify(client, times(1)).search(any(SearchRequest.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java index d60316085..763ae73b5 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -34,12 +35,17 @@ public class SearchWorkflowTransportActionTests extends OpenSearchTestCase { private SearchWorkflowTransportAction searchWorkflowTransportAction; private Client client; private ThreadPool threadPool; + ThreadContext threadContext; @Override public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); - + this.threadPool = mock(ThreadPool.class); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); this.searchWorkflowTransportAction = new SearchWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), @@ -73,6 +79,8 @@ public void testSearchWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener); verify(client, times(1)).search(any(SearchRequest.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java index c6ec15a92..167dd634f 100644 --- a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java @@ -17,6 +17,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -26,6 +27,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -199,8 +201,10 @@ public void testRedactTemplateCredential() { WorkflowNode node = testTemplate.workflows().get("provision").nodes().get(0); assertNotNull(node.userInputs().get(CREDENTIAL_FIELD)); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + // Redact template with credential field - Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(testTemplate); + Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate); // Validate the credential field has been removed WorkflowNode redactedNode = redactedTemplate.workflows().get("provision").nodes().get(0); @@ -211,10 +215,25 @@ public void testRedactTemplateUserField() { // Confirm user is present in the non-redacted template assertNotNull(testTemplate.getUser()); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); // Redact template with user field - Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(testTemplate); + Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate); // Validate the user field has been removed assertNull(redactedTemplate.getUser()); } + + public void testAdminUserTemplate() { + // Confirm user is present in the non-redacted template + assertNotNull(testTemplate.getUser()); + + List roles = new ArrayList<>(); + roles.add("all_access"); + + User user = new User("admin", roles, roles, Collections.emptyList()); + + // Redact template with user field + Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate); + assertNotNull(redactedTemplate.getUser()); + } } diff --git a/src/test/java/org/opensearch/flowframework/util/RestHandlerUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/RestHandlerUtilsTests.java new file mode 100644 index 000000000..76d80feea --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/RestHandlerUtilsTests.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.commons.authuser.User; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class RestHandlerUtilsTests extends OpenSearchTestCase { + + public void testGetSourceContextFromClientWithDashboardExcludes() { + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[] { "b" }); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 4); + } + + public void testGetSourceContextFromClientWithExcludes() { + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 2); + } + + public void testGetSourceContextAdminUser() { + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + List roles = new ArrayList<>(); + roles.add("all_access"); + + User user = new User("admin", roles, roles, Collections.emptyList()); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 1); + } + +}