feat/refactor: Allow pipelines without generators to be used with the RAG eval harness (#31)
shadeMe authored Jul 4, 2024
1 parent 9973f3b commit 8401384
Showing 4 changed files with 217 additions and 98 deletions.
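
In short: `rag_components` can now be a `DefaultRAGArchitecture` enum value instead of an explicit component mapping, and a generator component is only required when a selected metric depends on one. A minimal usage sketch (the pipeline construction is elided and hypothetical; component names must match the chosen architecture's expectations):

from haystack import Pipeline
from haystack_experimental.evaluation.harness.rag import (
    DefaultRAGArchitecture,
    RAGEvaluationHarness,
    RAGEvaluationMetric,
)

# Hypothetical retrieval-only pipeline: no generator component.
# DefaultRAGArchitecture.KEYWORD_RETRIEVAL expects a component named
# 'retriever' with a 'query' input and a 'documents' output.
rag_pipeline = Pipeline()
# ... add and connect a retriever component named "retriever" ...

harness = RAGEvaluationHarness(
    rag_pipeline,
    rag_components=DefaultRAGArchitecture.KEYWORD_RETRIEVAL,
    metrics={
        RAGEvaluationMetric.DOCUMENT_MAP,
        RAGEvaluationMetric.DOCUMENT_MRR,
    },
)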
3 changes: 2 additions & 1 deletion haystack_experimental/evaluation/harness/rag/__init__.py
@@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from .harness import RAGEvaluationHarness
from .harness import DefaultRAGArchitecture, RAGEvaluationHarness
from .parameters import (
RAGEvaluationInput,
RAGEvaluationMetric,
@@ -13,6 +13,7 @@
)

__all__ = [
"DefaultRAGArchitecture",
"RAGEvaluationHarness",
"RAGExpectedComponent",
"RAGExpectedComponentMetadata",
237 changes: 154 additions & 83 deletions haystack_experimental/evaluation/harness/rag/harness.py
@@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Dict, List, Optional, Set
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union

from haystack import Pipeline
from haystack.evaluation.eval_run_result import EvaluationRunResult
@@ -25,6 +26,83 @@
)


class DefaultRAGArchitecture(Enum):
"""
Represents default RAG pipeline architectures that can be used with the evaluation harness.
"""

#: A RAG pipeline with:
#: - A query embedder component named 'query_embedder' with a 'text' input.
#: - A document retriever component named 'retriever' with a 'documents' output.
EMBEDDING_RETRIEVAL = "embedding_retrieval"

#: A RAG pipeline with:
#: - A document retriever component named 'retriever' with a 'query' input and a 'documents' output.
KEYWORD_RETRIEVAL = "keyword_retrieval"

#: A RAG pipeline with:
#: - A query embedder component named 'query_embedder' with a 'text' input.
#: - A document retriever component named 'retriever' with a 'documents' output.
#: - A response generator component named 'generator' with a 'replies' output.
GENERATION_WITH_EMBEDDING_RETRIEVAL = "generation_with_embedding_retrieval"

#: A RAG pipeline with:
#: - A document retriever component named 'retriever' with a 'query' input and a 'documents' output.
#: - A response generator component named 'generator' with a 'replies' output.
GENERATION_WITH_KEYWORD_RETRIEVAL = "generation_with_keyword_retrieval"

@property
def expected_components(
self,
) -> Dict[RAGExpectedComponent, RAGExpectedComponentMetadata]:
"""
Returns the expected components for the architecture.
:returns:
The expected components.
"""
if self in (
DefaultRAGArchitecture.EMBEDDING_RETRIEVAL,
DefaultRAGArchitecture.GENERATION_WITH_EMBEDDING_RETRIEVAL,
):
expected = {
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
name="query_embedder", input_mapping={"query": "text"}
),
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
name="retriever",
output_mapping={"retrieved_documents": "documents"},
),
}
elif self in (
DefaultRAGArchitecture.KEYWORD_RETRIEVAL,
DefaultRAGArchitecture.GENERATION_WITH_KEYWORD_RETRIEVAL,
):
expected = {
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
name="retriever", input_mapping={"query": "query"}
),
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
name="retriever",
output_mapping={"retrieved_documents": "documents"},
),
}
else:
raise NotImplementedError(f"Unexpected default RAG architecture: {self}")

if self in (
DefaultRAGArchitecture.GENERATION_WITH_EMBEDDING_RETRIEVAL,
DefaultRAGArchitecture.GENERATION_WITH_KEYWORD_RETRIEVAL,
):
expected[RAGExpectedComponent.RESPONSE_GENERATOR] = (
RAGExpectedComponentMetadata(
name="generator", output_mapping={"replies": "replies"}
)
)

return expected
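
# For reference, a quick sketch of what the property above yields for the
# retrieval-only embedding architecture (illustrative values, taken from the
# mappings in the property body):
#
#   arch = DefaultRAGArchitecture.EMBEDDING_RETRIEVAL
#   components = arch.expected_components
#   # components == {
#   #     RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
#   #         name="query_embedder", input_mapping={"query": "text"}),
#   #     RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
#   #         name="retriever", output_mapping={"retrieved_documents": "documents"}),
#   # }
#   # No RESPONSE_GENERATOR entry: a generator is only expected for the
#   # GENERATION_* architectures.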


class RAGEvaluationHarness(
EvaluationHarness[RAGEvaluationInput, RAGEvaluationOverrides, RAGEvaluationOutput]
):
@@ -35,7 +113,10 @@ class RAGEvaluationHarness(
def __init__(
self,
rag_pipeline: Pipeline,
rag_components: Dict[RAGExpectedComponent, RAGExpectedComponentMetadata],
rag_components: Union[
DefaultRAGArchitecture,
Dict[RAGExpectedComponent, RAGExpectedComponentMetadata],
],
metrics: Set[RAGEvaluationMetric],
):
"""
@@ -44,76 +125,23 @@ def __init__(
:param rag_pipeline:
The RAG pipeline to evaluate.
:param rag_components:
A mapping of expected components to their metadata.
Either a default RAG architecture or a mapping
of expected components to their metadata.
:param metrics:
The metrics to use during evaluation.
"""
super().__init__()

self._validate_rag_components(rag_pipeline, rag_components)
if isinstance(rag_components, DefaultRAGArchitecture):
rag_components = rag_components.expected_components

self._validate_rag_components(rag_pipeline, rag_components, metrics)

self.rag_pipeline = rag_pipeline
self.rag_components = rag_components
self.metrics = metrics
self.rag_components = deepcopy(rag_components)
self.metrics = deepcopy(metrics)
self.evaluation_pipeline = default_rag_evaluation_pipeline(metrics)

@classmethod
def default_with_embedding_retriever(
cls, rag_pipeline: Pipeline, metrics: Set[RAGEvaluationMetric]
) -> "RAGEvaluationHarness":
"""
Create a default evaluation harness for evaluating RAG pipelines with a query embedder.
:param rag_pipeline:
The RAG pipeline to evaluate. The following assumptions are made:
- The query embedder component is named 'query_embedder' and has a 'text' input.
- The document retriever component is named 'retriever' and has a 'documents' output.
- The response generator component is named 'generator' and has a 'replies' output.
:param metrics:
The metrics to use during evaluation.
"""
rag_components = {
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
name="query_embedder", input_mapping={"query": "text"}
),
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
name="retriever", output_mapping={"retrieved_documents": "documents"}
),
RAGExpectedComponent.RESPONSE_GENERATOR: RAGExpectedComponentMetadata(
name="generator", output_mapping={"replies": "replies"}
),
}

return cls(rag_pipeline, rag_components, deepcopy(metrics))

@classmethod
def default_with_keyword_retriever(
cls, rag_pipeline: Pipeline, metrics: Set[RAGEvaluationMetric]
) -> "RAGEvaluationHarness":
"""
Create a default evaluation harness for evaluating RAG pipelines with a keyword retriever.
:param rag_pipeline:
The RAG pipeline to evaluate. The following assumptions are made:
- The document retriever component is named 'retriever' and has a 'query' input and a 'documents' output.
- The response generator component is named 'generator' and has a 'replies' output.
:param metrics:
The metrics to use during evaluation.
"""
rag_components = {
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
name="retriever", input_mapping={"query": "query"}
),
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
name="retriever", output_mapping={"retrieved_documents": "documents"}
),
RAGExpectedComponent.RESPONSE_GENERATOR: RAGExpectedComponentMetadata(
name="generator", output_mapping={"replies": "replies"}
),
}

return cls(rag_pipeline, rag_components, deepcopy(metrics))
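
# The two default_with_* classmethods above are removed; their behaviour maps
# one-to-one onto the new enum. A migration sketch, assuming rag_pipeline and
# metrics are defined as before:
#
#   # Before (removed in this commit):
#   harness = RAGEvaluationHarness.default_with_keyword_retriever(rag_pipeline, metrics)
#
#   # After (equivalent):
#   harness = RAGEvaluationHarness(
#       rag_pipeline,
#       DefaultRAGArchitecture.GENERATION_WITH_KEYWORD_RETRIEVAL,
#       metrics,
#   )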

def run( # noqa: D102
self,
inputs: RAGEvaluationInput,
@@ -141,10 +169,12 @@ def run( # noqa: D102
"retrieved_documents",
)
],
"responses": self._lookup_component_output(
RAGExpectedComponent.RESPONSE_GENERATOR, rag_outputs, "replies"
),
}
if RAGExpectedComponent.RESPONSE_GENERATOR in self.rag_components:
result_inputs["responses"] = self._lookup_component_output(
RAGExpectedComponent.RESPONSE_GENERATOR, rag_outputs, "replies"
)

if inputs.ground_truth_answers is not None:
result_inputs["ground_truth_answers"] = inputs.ground_truth_answers
if inputs.ground_truth_documents is not None:
@@ -199,34 +229,40 @@ def _generate_eval_run_pipelines(
rag_pipeline = self._override_pipeline(self.rag_pipeline, rag_overrides)
eval_pipeline = self._override_pipeline(self.evaluation_pipeline, eval_overrides) # type: ignore

included_first_outputs = {
self.rag_components[RAGExpectedComponent.DOCUMENT_RETRIEVER].name
}
if RAGExpectedComponent.RESPONSE_GENERATOR in self.rag_components:
included_first_outputs.add(
self.rag_components[RAGExpectedComponent.RESPONSE_GENERATOR].name
)

return PipelinePair(
first=rag_pipeline,
second=eval_pipeline,
outputs_to_inputs=self._map_rag_eval_pipeline_io(),
map_first_outputs=lambda x: self._aggregate_rag_outputs( # pylint: disable=unnecessary-lambda
x
),
included_first_outputs={
self.rag_components[RAGExpectedComponent.DOCUMENT_RETRIEVER].name,
self.rag_components[RAGExpectedComponent.RESPONSE_GENERATOR].name,
},
included_first_outputs=included_first_outputs,
)

def _aggregate_rag_outputs(
self, outputs: List[Dict[str, Dict[str, Any]]]
) -> Dict[str, Dict[str, Any]]:
aggregate = aggregate_batched_pipeline_outputs(outputs)

# We only care about the first response from the generator.
generator_name = self.rag_components[
RAGExpectedComponent.RESPONSE_GENERATOR
].name
replies_output_name = self.rag_components[
RAGExpectedComponent.RESPONSE_GENERATOR
].output_mapping["replies"]
aggregate[generator_name][replies_output_name] = [
r[0] for r in aggregate[generator_name][replies_output_name]
]
if RAGExpectedComponent.RESPONSE_GENERATOR in self.rag_components:
# We only care about the first response from the generator.
generator_name = self.rag_components[
RAGExpectedComponent.RESPONSE_GENERATOR
].name
replies_output_name = self.rag_components[
RAGExpectedComponent.RESPONSE_GENERATOR
].output_mapping["replies"]
aggregate[generator_name][replies_output_name] = [
r[0] for r in aggregate[generator_name][replies_output_name]
]

return aggregate
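
# Illustration of the reply aggregation above (hypothetical values):
#
#   aggregate = {"generator": {"replies": [["answer 1a", "answer 1b"], ["answer 2a"]]}}
#   aggregate["generator"]["replies"] = [r[0] for r in aggregate["generator"]["replies"]]
#   # -> {"generator": {"replies": ["answer 1a", "answer 2a"]}}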

@@ -383,11 +419,46 @@ def _prepare_eval_pipeline_additional_inputs(
def _validate_rag_components(
pipeline: Pipeline,
components: Dict[RAGExpectedComponent, RAGExpectedComponentMetadata],
metrics: Set[RAGEvaluationMetric],
):
for e in RAGExpectedComponent:
if e not in components:
metric_specific_required_components = {
RAGEvaluationMetric.DOCUMENT_MAP: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.DOCUMENT_RETRIEVER,
],
RAGEvaluationMetric.DOCUMENT_MRR: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.DOCUMENT_RETRIEVER,
],
RAGEvaluationMetric.DOCUMENT_RECALL_SINGLE_HIT: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.DOCUMENT_RETRIEVER,
],
RAGEvaluationMetric.DOCUMENT_RECALL_MULTI_HIT: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.DOCUMENT_RETRIEVER,
],
RAGEvaluationMetric.SEMANTIC_ANSWER_SIMILARITY: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.RESPONSE_GENERATOR,
],
RAGEvaluationMetric.FAITHFULNESS: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.DOCUMENT_RETRIEVER,
RAGExpectedComponent.RESPONSE_GENERATOR,
],
RAGEvaluationMetric.CONTEXT_RELEVANCE: [
RAGExpectedComponent.QUERY_PROCESSOR,
RAGExpectedComponent.DOCUMENT_RETRIEVER,
],
}

for m in metrics:
required_components = metric_specific_required_components[m]
if not all(c in components for c in required_components):
raise ValueError(
f"RAG evaluation harness requires metadata for the '{e.value}' component."
f"In order to use the metric '{m}', the RAG evaluation harness requires metadata "
f"for the following components: {required_components}"
)

pipeline_outputs = pipeline.outputs(
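
With the new metric-aware validation, a mismatch between the component mapping and the requested metrics now fails fast at construction time. A sketch of the failure mode (the pipeline name is hypothetical):

harness = RAGEvaluationHarness(
    retrieval_only_pipeline,  # hypothetical pipeline without a generator
    DefaultRAGArchitecture.KEYWORD_RETRIEVAL,
    metrics={RAGEvaluationMetric.SEMANTIC_ANSWER_SIMILARITY},
)
# Raises ValueError: Semantic Answer Similarity requires a RESPONSE_GENERATOR
# component, which KEYWORD_RETRIEVAL does not provide.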
10 changes: 9 additions & 1 deletion haystack_experimental/evaluation/harness/rag/parameters.py
@@ -12,7 +12,7 @@

class RAGExpectedComponent(Enum):
"""
Represents the basic components in a RAG pipeline that needs to be present for evaluation.
Represents the basic components in a RAG pipeline that are, by default, required to be present for evaluation.
Each of these can be separate components in the pipeline or a single component that performs
multiple tasks.
@@ -27,6 +27,7 @@ class RAGExpectedComponent(Enum):
DOCUMENT_RETRIEVER = "document_retriever"

#: The component in a RAG pipeline that generates responses based on the query and the retrieved documents.
#: Can be optional if the harness is only evaluating retrieval.
#: Expected outputs: `replies` - Name of output containing the LLM responses. Only the first response is used.
RESPONSE_GENERATOR = "response_generator"

@@ -57,24 +58,31 @@ class RAGEvaluationMetric(Enum):
"""

#: Document Mean Average Precision.
#: Required RAG components: Query Processor, Document Retriever.
DOCUMENT_MAP = "metric_doc_map"

#: Document Mean Reciprocal Rank.
#: Required RAG components: Query Processor, Document Retriever.
DOCUMENT_MRR = "metric_doc_mrr"

#: Document Recall with a single hit.
#: Required RAG components: Query Processor, Document Retriever.
DOCUMENT_RECALL_SINGLE_HIT = "metric_doc_recall_single"

#: Document Recall with multiple hits.
#: Required RAG components: Query Processor, Document Retriever.
DOCUMENT_RECALL_MULTI_HIT = "metric_doc_recall_multi"

#: Semantic Answer Similarity.
#: Required RAG components: Query Processor, Response Generator.
SEMANTIC_ANSWER_SIMILARITY = "metric_sas"

#: Faithfulness.
#: Required RAG components: Query Processor, Document Retriever, Response Generator.
FAITHFULNESS = "metric_faithfulness"

#: Context Relevance.
#: Required RAG components: Query Processor, Document Retriever.
CONTEXT_RELEVANCE = "metric_context_relevance"
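
# Taken together, the metrics usable without a generator in the pipeline are
# the document-level metrics plus Context Relevance. A sketch of a
# retrieval-only metric set (the variable name is illustrative):
#
#   retrieval_only_metrics = {
#       RAGEvaluationMetric.DOCUMENT_MAP,
#       RAGEvaluationMetric.DOCUMENT_MRR,
#       RAGEvaluationMetric.DOCUMENT_RECALL_SINGLE_HIT,
#       RAGEvaluationMetric.DOCUMENT_RECALL_MULTI_HIT,
#       RAGEvaluationMetric.CONTEXT_RELEVANCE,
#   }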

