Skip to content

Commit

Permalink
Doing some refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 1, 2024
1 parent 9340557 commit e8a374c
Show file tree
Hide file tree
Showing 14 changed files with 41 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
Expand Down Expand Up @@ -185,7 +185,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
ExplainResponseProcessor.TYPE,
ExplanationResponseProcessor.TYPE,
new ProcessorExplainPublisherFactory()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSo
public CompoundTopDocs(final QuerySearchResult querySearchResult) {
final TopDocs topDocs = querySearchResult.topDocs().topDocs;
final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
SearchShard searchShard = new SearchShard(
searchShardTarget.getIndex(),
searchShardTarget.getShardId().id(),
searchShardTarget.getNodeId()
);
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
boolean isSortEnabled = false;
if (topDocs instanceof TopFieldDocs) {
isSortEnabled = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
Expand All @@ -24,13 +24,13 @@
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR;
import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR;

@Getter
@AllArgsConstructor
public class ExplainResponseProcessor implements SearchResponseProcessor {
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "explain_response_processor";
public static final String TYPE = "explanation_response_processor";

private final String description;
private final String tag;
Expand All @@ -46,10 +46,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) {
return response;
}
ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ProcessorExplainDto.ExplanationType, Object> explainPayload = processorExplainDto.getExplainPayload();
ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ExplanationResponse.ExplanationType, Object> explainPayload = explanationResponse.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
Explanation processorExplanation = processorExplainDto.getExplanation();
Explanation processorExplanation = explanationResponse.getExplanation();
if (Objects.isNull(processorExplanation)) {
return response;
}
Expand All @@ -62,7 +62,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
for (int i = 0; i < searchHitsArray.length; i++) {
SearchHit searchHit = searchHitsArray[i];
SearchShardTarget searchShardTarget = searchHit.getShard();
SearchShard searchShard = SearchShard.create(searchShardTarget);
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
explainsByShardCount.putIfAbsent(searchShard, -1);
}
Expand All @@ -73,7 +73,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
List<CombinedExplainDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);

for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.create(searchHit.getShard());
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
Explanation normalizedExplanation = Explanation.match(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -146,13 +146,13 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
}
});

ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder()
ExplanationResponse explanationResponse = ExplanationResponse.builder()
.explanation(topLevelExplanationForTechniques)
.explainPayload(Map.of(ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.build();
// store explain object to pipeline context
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto);
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

public record SearchShard(String index, int shardId, String nodeId) {

public static SearchShard create(SearchShardTarget searchShardTarget) {
public static SearchShard createSearchShard(SearchShardTarget searchShardTarget) {
return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
@AllArgsConstructor
@Builder
@Getter
/**
* DTO class to hold explain details for normalization and combination
*/
public class CombinedExplainDetails {
private ExplainDetails normalizationExplain;
private ExplainDetails combinationExplain;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import org.opensearch.neuralsearch.processor.SearchShard;

/**
* Data class to store docId and search shard for a query.
* DTO class to store docId and search shard for a query.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param docId
* @param searchShard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
* @param value
* @param description
*/
public record ExplainDetails(float value, String description, int docId) {

public record ExplainDetails(int docId, float value, String description) {
public ExplainDetails(float value, String description) {
this(value, description, -1);
this(-1, value, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ public static ExplainDetails getScoreCombinationExplainDetailsForDocument(
) {
float combinedScore = combinedNormalizedScoresByDocId.get(docId);
return new ExplainDetails(
docId,
combinedScore,
String.format(
Locale.ROOT,
"normalized scores: %s combined to a final score: %s",
Arrays.toString(normalizedScoresPerDoc),
combinedScore
),
docId
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@AllArgsConstructor
@Builder
@Getter
public class ProcessorExplainDto {
public class ExplanationResponse {
Explanation explanation;
Map<ExplanationType, Object> explainPayload;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

Expand All @@ -21,6 +21,6 @@ public SearchResponseProcessor create(
Map<String, Object> config,
Processor.PipelineContext pipelineContext
) throws Exception {
return new ExplainResponseProcessor(description, tag, ignoreFailure);
return new ExplanationResponseProcessor(description, tag, ignoreFailure);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@

import static org.mockito.Mockito.mock;

public class ExplainResponseProcessorTests extends OpenSearchTestCase {
public class ExplanationResponseProcessorTests extends OpenSearchTestCase {
private static final String PROCESSOR_TAG = "mockTag";
private static final String DESCRIPTION = "mockDescription";

public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() {
ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);

assertEquals(DESCRIPTION, explainResponseProcessor.getDescription());
assertEquals(PROCESSOR_TAG, explainResponseProcessor.getTag());
assertFalse(explainResponseProcessor.isIgnoreFailure());
assertEquals(DESCRIPTION, explanationResponseProcessor.getDescription());
assertEquals(PROCESSOR_TAG, explanationResponseProcessor.getTag());
assertFalse(explanationResponseProcessor.isIgnoreFailure());
}

@SneakyThrows
public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() {
ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);
SearchResponse searchResponse = new SearchResponse(
null,
Expand All @@ -40,14 +40,14 @@ public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProc
SearchResponse.Clusters.EMPTY
);

SearchResponse processedResponse = explainResponseProcessor.processResponse(searchRequest, searchResponse);
SearchResponse processedResponse = explanationResponseProcessor.processResponse(searchRequest, searchResponse);
assertEquals(searchResponse, processedResponse);

SearchResponse processedResponse2 = explainResponseProcessor.processResponse(searchRequest, searchResponse, null);
SearchResponse processedResponse2 = explanationResponseProcessor.processResponse(searchRequest, searchResponse, null);
assertEquals(searchResponse, processedResponse2);

PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext();
SearchResponse processedResponse3 = explainResponseProcessor.processResponse(
SearchResponse processedResponse3 = explanationResponseProcessor.processResponse(
searchRequest,
searchResponse,
pipelineProcessingContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() {
assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size());

Map<String, Object> hit1DetailsForHit2 = hit1Details.get(1);
assertEquals(0.6666667, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION);
assertEquals(0.666, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION);
assertEquals("normalized scores: [1.0, 0.0, 1.0] combined to a final score: 0.6666667", hit1DetailsForHit2.get("description"));
assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size());

Expand Down Expand Up @@ -627,7 +627,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() {
assertEquals(0, ((List) hit1DetailsForHit4.get("details")).size());

Map<String, Object> hit2DetailsForHit4 = hit4Details.get(1);
assertEquals(0.6666667, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION);
assertEquals(0.666, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION);
assertEquals("normalized scores: [0.0, 1.0, 1.0] combined to a final score: 0.6666667", hit2DetailsForHit4.get("description"));
assertEquals(0, ((List) hit2DetailsForHit4.get("details")).size());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TokenWeightUtil;
import org.opensearch.search.sort.SortBuilder;
Expand Down Expand Up @@ -1203,7 +1203,7 @@ protected void createSearchPipeline(
if (addExplainResponseProcessor) {
stringBuilderForContentBody.append(", \"response_processors\": [ ")
.append("{\"")
.append(ExplainResponseProcessor.TYPE)
.append(ExplanationResponseProcessor.TYPE)
.append("\": {}}")
.append("]");
}
Expand Down

0 comments on commit e8a374c

Please sign in to comment.