Change response format, switch to hierarchical structure
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Nov 13, 2024
1 parent 72c0ac3 commit 9830ab3
Showing 21 changed files with 875 additions and 766 deletions.
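
For orientation: this change replaces the earlier flat explanation, where the normalized explanation, the combined explanation, and the original query explanation were attached side by side, with a hierarchical one. The combination-level description becomes the root of each hit's explanation, every normalized score is a child of that root, and the corresponding shard-level detail is nested under its normalized score. Below is a minimal sketch of that shape using Lucene's Explanation API; the class name, scores, and description strings are hypothetical and used only for illustration, not taken from the commit.

    import org.apache.lucene.search.Explanation;

    // Hypothetical, self-contained sketch of the hierarchical explanation shape.
    public class HierarchicalExplanationSketch {
        public static void main(String[] args) {
            // Shard-level explanations produced by the sub-queries of a hybrid query (invented values).
            Explanation bm25Detail = Explanation.match(12.4f, "hypothetical BM25 sub-query details");
            Explanation knnDetail = Explanation.match(0.87f, "hypothetical k-NN sub-query details");

            // Each normalized score wraps the matching query-level detail as its child.
            Explanation[] normalized = new Explanation[] {
                Explanation.match(0.91f, "min_max normalization of:", bm25Detail),
                Explanation.match(0.76f, "min_max normalization of:", knnDetail) };

            // The hit's combined score sits at the top, described by the combination technique,
            // with the normalized explanations nested underneath.
            Explanation finalExplanation = Explanation.match(0.835f, "arithmetic_mean combination of:", normalized);
            System.out.println(finalExplanation);
        }
    }

Printed with Explanation.toString(), this renders as an indented tree, which is the structure the reworked response processor below attaches to each search hit.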

@@ -82,7 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
     private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
     private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
     private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
-    public static final String EXPLAIN_RESPONSE_KEY = "explain_response";
+    public static final String EXPLANATION_RESPONSE_KEY = "explanation_response";
 
     @Override
     public Collection<Object> createComponents(

@@ -9,7 +9,8 @@
 import org.apache.lucene.search.Explanation;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
-import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
+import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.SearchHits;
@@ -23,7 +24,7 @@
 import java.util.Map;
 import java.util.Objects;
 
-import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
+import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY;
 import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR;
 
 /**
@@ -45,19 +46,19 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
     }
 
     @Override
-    public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) {
+    public SearchResponse processResponse(
+        final SearchRequest request,
+        final SearchResponse response,
+        final PipelineProcessingContext requestContext
+    ) {
         if (Objects.isNull(requestContext)
-            || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))
-            || requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) {
+            || (Objects.isNull(requestContext.getAttribute(EXPLANATION_RESPONSE_KEY)))
+            || requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) {
             return response;
         }
-        ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
+        ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
         Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
         if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
-            Explanation processorExplanation = explanationPayload.getExplanation();
-            if (Objects.isNull(processorExplanation)) {
-                return response;
-            }
             SearchHits searchHits = response.getHits();
             SearchHit[] searchHitsArray = searchHits.getHits();
             // create a map of searchShard and list of indexes of search hit objects in search hits array
@@ -73,29 +74,33 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
            }
            if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map<?, ?>) {
                @SuppressWarnings("unchecked")
-                Map<SearchShard, List<CombinedExplainDetails>> combinedExplainDetails = (Map<
+                Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
                    SearchShard,
-                    List<CombinedExplainDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
+                    List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
 
                for (SearchHit searchHit : searchHitsArray) {
                    SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
                    int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
-                    CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
-                    Explanation normalizedExplanation = Explanation.match(
-                        combinedExplainDetail.getNormalizationExplain().value(),
-                        combinedExplainDetail.getNormalizationExplain().description()
-                    );
-                    Explanation combinedExplanation = Explanation.match(
-                        combinedExplainDetail.getCombinationExplain().value(),
-                        combinedExplainDetail.getCombinationExplain().description()
-                    );
-
+                    CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
+                    Explanation queryLevelExplanation = searchHit.getExplanation();
+                    ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
+                    ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
+                    Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
+                    for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
+                        normalizedExplanation[i] = Explanation.match(
+                            // normalized score
+                            normalizationExplanation.scoreDetails().get(i).getKey(),
+                            // description of normalized score
+                            normalizationExplanation.scoreDetails().get(i).getValue(),
+                            // shard level details
+                            queryLevelExplanation.getDetails()[i]
+                        );
+                    }
                    Explanation finalExplanation = Explanation.match(
                        searchHit.getScore(),
-                        processorExplanation.getDescription(),
-                        normalizedExplanation,
-                        combinedExplanation,
-                        searchHit.getExplanation()
+                        // combination level explanation is always a single detail
+                        combinationExplanation.scoreDetails().get(0).getValue(),
+                        normalizedExplanation
                    );
                    searchHit.explanation(finalExplanation);
                    explainsByShardCount.put(searchShard, explanationIndexByShard);

@@ -4,7 +4,6 @@
  */
 package org.opensearch.neuralsearch.processor;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -14,7 +13,6 @@
 import java.util.Optional;
 import java.util.stream.Collectors;
 
-import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.Sort;
@@ -24,7 +22,7 @@
 import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
-import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
@@ -40,9 +38,8 @@
 import lombok.AllArgsConstructor;
 import lombok.extern.log4j.Log4j2;
 
-import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
+import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY;
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore;
 import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;
 
 /**
@@ -113,16 +110,9 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
        if (!request.isExplain()) {
            return;
        }
-        Explanation topLevelExplanationForTechniques = topLevelExpalantionForCombinedScore(
-            (ExplainableTechnique) request.getNormalizationTechnique(),
-            (ExplainableTechnique) request.getCombinationTechnique()
-        );
-
        // build final result object with all explain related information
        if (Objects.nonNull(request.getPipelineProcessingContext())) {
-
            Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs);
-
            Map<DocIdAtSearchShard, ExplanationDetails> normalizationExplain = scoreNormalizer.explain(
                queryTopDocs,
                (ExplainableTechnique) request.getNormalizationTechnique()
@@ -132,27 +122,22 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
                request.getCombinationTechnique(),
                sortForQuery
            );
-            Map<SearchShard, List<CombinedExplainDetails>> combinedExplain = new HashMap<>();
-
-            combinationExplain.forEach((searchShard, explainDetails) -> {
-                for (ExplanationDetails explainDetail : explainDetails) {
-                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard);
-                    ExplanationDetails normalizedExplanationDetails = normalizationExplain.get(docIdAtSearchShard);
-                    CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder()
-                        .normalizationExplain(normalizedExplanationDetails)
-                        .combinationExplain(explainDetail)
+            Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
+                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), entry.getKey());
+                    return CombinedExplanationDetails.builder()
+                        .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
+                        .combinationExplanations(explainDetail)
                        .build();
-                    combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails);
-                }
-            });
+                }).collect(Collectors.toList())));
 
            ExplanationPayload explanationPayload = ExplanationPayload.builder()
-                .explanation(topLevelExplanationForTechniques)
-                .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain))
+                .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations))
                .build();
            // store explain object to pipeline context
            PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
-            pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationPayload);
+            pipelineProcessingContext.setAttribute(EXPLANATION_RESPONSE_KEY, explanationPayload);
        }
 
    }

@@ -12,7 +12,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
+import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;
 
 /**
  * Abstracts combination of scores based on arithmetic mean method

@@ -12,7 +12,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
+import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;
 
 /**
  * Abstracts combination of scores based on geometrical mean method

@@ -12,7 +12,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
+import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;
 
 /**
  * Abstracts combination of scores based on harmonic mean method

@@ -8,6 +8,7 @@
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.Objects;
@@ -16,6 +17,7 @@
 
 import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.ScoreDoc;
@@ -27,10 +29,9 @@
 
 import lombok.extern.log4j.Log4j2;
 import org.opensearch.neuralsearch.processor.SearchShard;
+import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument;
-
 /**
  * Abstracts combination of scores in query search results.
 */
@@ -360,10 +361,14 @@ private List<ExplanationDetails> explainByShard(
 
        List<ExplanationDetails> listOfExplanations = sortedDocsIds.stream()
            .map(
-                docId -> getScoreCombinationExplainDetailsForDocument(
+                docId -> new ExplanationDetails(
                    docId,
-                    combinedNormalizedScoresByDocId,
-                    normalizedScoresPerDoc.get(docId)
+                    List.of(
+                        Pair.of(
+                            combinedNormalizedScoresByDocId.get(docId),
+                            String.format(Locale.ROOT, "%s combination of:", ((ExplainableTechnique) scoreCombinationTechnique).describe())
+                        )
+                    )
                )
            )
            .toList();

@@ -14,7 +14,7 @@
 @AllArgsConstructor
 @Builder
 @Getter
-public class CombinedExplainDetails {
-    private ExplanationDetails normalizationExplain;
-    private ExplanationDetails combinationExplain;
+public class CombinedExplanationDetails {
+    private ExplanationDetails normalizationExplanations;
+    private ExplanationDetails combinationExplanations;
 }

This file was deleted.
