From d840b164472cc58f0b509d66499aa76b3b176058 Mon Sep 17 00:00:00 2001 From: ikka Date: Thu, 7 Nov 2024 10:39:06 +0530 Subject: [PATCH] feat: improvements in test synthesization (#1621) PR 2 of improvements in test generation --------- Co-authored-by: Jin Lin Tham --- docs/getstarted/rag_testset_generation.md | 6 +- docs/references/testset_schema.md | 4 +- src/ragas/metrics/_string.py | 2 + src/ragas/testset/graph.py | 100 ++++- src/ragas/testset/graph_queries.py | 38 ++ src/ragas/testset/synthesizers/__init__.py | 31 +- .../testset/synthesizers/abstract_query.py | 349 --------------- src/ragas/testset/synthesizers/base.py | 13 +- src/ragas/testset/synthesizers/base_query.py | 95 ---- src/ragas/testset/synthesizers/generate.py | 17 + .../synthesizers/multi_hop/__init__.py | 10 + .../synthesizers/multi_hop/abstract.py | 118 +++++ .../testset/synthesizers/multi_hop/base.py | 181 ++++++++ .../testset/synthesizers/multi_hop/prompts.py | 86 ++++ .../synthesizers/multi_hop/specific.py | 116 +++++ src/ragas/testset/synthesizers/prompts.py | 419 ++---------------- .../synthesizers/single_hop/__init__.py | 3 + .../testset/synthesizers/single_hop/base.py | 145 ++++++ .../synthesizers/single_hop/prompts.py | 35 ++ .../synthesizers/single_hop/specific.py | 90 ++++ .../testset/synthesizers/specific_query.py | 115 ----- src/ragas/testset/transforms/base.py | 28 +- src/ragas/testset/transforms/default.py | 81 ++-- .../transforms/extractors/llm_based.py | 136 ++++-- .../relationship_builders/__init__.py | 3 +- .../relationship_builders/cosine.py | 46 -- .../relationship_builders/traditional.py | 155 +++++++ .../testset/transforms/splitters/headline.py | 26 +- src/ragas/utils.py | 7 + tests/unit/prompt/test_prompt_mixin.py | 6 +- tests/unit/test_analytics.py | 10 +- tests/unit/test_prompt.py | 16 +- 32 files changed, 1336 insertions(+), 1151 deletions(-) create mode 100644 src/ragas/testset/graph_queries.py delete mode 100644 src/ragas/testset/synthesizers/abstract_query.py delete mode 100644 src/ragas/testset/synthesizers/base_query.py create mode 100644 src/ragas/testset/synthesizers/multi_hop/__init__.py create mode 100644 src/ragas/testset/synthesizers/multi_hop/abstract.py create mode 100644 src/ragas/testset/synthesizers/multi_hop/base.py create mode 100644 src/ragas/testset/synthesizers/multi_hop/prompts.py create mode 100644 src/ragas/testset/synthesizers/multi_hop/specific.py create mode 100644 src/ragas/testset/synthesizers/single_hop/__init__.py create mode 100644 src/ragas/testset/synthesizers/single_hop/base.py create mode 100644 src/ragas/testset/synthesizers/single_hop/prompts.py create mode 100644 src/ragas/testset/synthesizers/single_hop/specific.py delete mode 100644 src/ragas/testset/synthesizers/specific_query.py create mode 100644 src/ragas/testset/transforms/relationship_builders/traditional.py diff --git a/docs/getstarted/rag_testset_generation.md b/docs/getstarted/rag_testset_generation.md index 8fefeb382..3b9e0e1a2 100644 --- a/docs/getstarted/rag_testset_generation.md +++ b/docs/getstarted/rag_testset_generation.md @@ -141,9 +141,9 @@ query_distribution = default_query_distribution(generator_llm) ``` ``` [ - (AbstractQuerySynthesizer(llm=generator_llm), 0.25), - (ComparativeAbstractQuerySynthesizer(llm=generator_llm), 0.25), - (SpecificQuerySynthesizer(llm=generator_llm), 0.5), + (SingleHopSpecificQuerySynthesizer(llm=llm), 0.5), + (MultiHopAbstractQuerySynthesizer(llm=llm), 0.25), + (MultiHopSpecificQuerySynthesizer(llm=llm), 0.25), ] ``` diff --git a/docs/references/testset_schema.md b/docs/references/testset_schema.md index fed34111a..740227dd1 100644 --- a/docs/references/testset_schema.md +++ b/docs/references/testset_schema.md @@ -15,12 +15,12 @@ members: - BaseScenario -::: ragas.testset.synthesizers.specific_query.SpecificQueryScenario +::: ragas.testset.synthesizers.single_hop.specific.SingleHopSpecificQuerySynthesizer options: show_root_heading: True show_root_full_path: False -::: ragas.testset.synthesizers.abstract_query.AbstractQueryScenario +::: ragas.testset.synthesizers.multi_hop.specific.MultiHopSpecificQuerySynthesizer options: show_root_heading: True show_root_full_path: False diff --git a/src/ragas/metrics/_string.py b/src/ragas/metrics/_string.py index 9c5fae2d7..fb7de5083 100644 --- a/src/ragas/metrics/_string.py +++ b/src/ragas/metrics/_string.py @@ -13,6 +13,7 @@ class DistanceMeasure(Enum): LEVENSHTEIN = "levenshtein" HAMMING = "hamming" JARO = "jaro" + JARO_WINKLER = "jaro_winkler" @dataclass @@ -77,6 +78,7 @@ def __post_init__(self): DistanceMeasure.LEVENSHTEIN: distance.Levenshtein, DistanceMeasure.HAMMING: distance.Hamming, DistanceMeasure.JARO: distance.Jaro, + DistanceMeasure.JARO_WINKLER: distance.JaroWinkler, } def init(self, run_config: RunConfig): diff --git a/src/ragas/testset/graph.py b/src/ragas/testset/graph.py index bdc4e5d86..94f10aeb2 100644 --- a/src/ragas/testset/graph.py +++ b/src/ragas/testset/graph.py @@ -206,11 +206,15 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def find_clusters( - self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True + def find_indirect_clusters( + self, + relationship_condition: t.Callable[[Relationship], bool] = lambda _: True, + depth_limit: int = 3, ) -> t.List[t.Set[Node]]: """ - Finds clusters of nodes in the knowledge graph based on a relationship condition. + Finds indirect clusters of nodes in the knowledge graph based on a relationship condition. + Here if A -> B -> C -> D, then A, B, C, and D form a cluster. If there's also a path A -> B -> C -> E, + it will form a separate cluster. Parameters ---------- @@ -223,31 +227,95 @@ def find_clusters( A list of sets, where each set contains nodes that form a cluster. """ clusters = [] - visited = set() + visited_paths = set() relationships = [ rel for rel in self.relationships if relationship_condition(rel) ] - def dfs(node: Node, cluster: t.Set[Node]): - visited.add(node) + def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]): + if depth >= depth_limit or path in visited_paths: + return + visited_paths.add(path) cluster.add(node) + for rel in relationships: - if rel.source == node and rel.target not in visited: - dfs(rel.target, cluster) - # if the relationship is bidirectional, we need to check the reverse + neighbor = None + if rel.source == node and rel.target not in cluster: + neighbor = rel.target elif ( rel.bidirectional and rel.target == node - and rel.source not in visited + and rel.source not in cluster ): - dfs(rel.source, cluster) + neighbor = rel.source + + if neighbor is not None: + dfs(neighbor, cluster.copy(), depth + 1, path + (neighbor,)) + + # Add completed path-based cluster + if len(cluster) > 1: + clusters.append(cluster) for node in self.nodes: - if node not in visited: - cluster = set() - dfs(node, cluster) - if len(cluster) > 1: + initial_cluster = set() + dfs(node, initial_cluster, 0, (node,)) + + # Remove duplicates by converting clusters to frozensets + unique_clusters = [ + set(cluster) for cluster in set(frozenset(c) for c in clusters) + ] + + return unique_clusters + + def find_direct_clusters( + self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True + ) -> t.Dict[Node, t.List[t.Set[Node]]]: + """ + Finds direct clusters of nodes in the knowledge graph based on a relationship condition. + Here if A->B, and A->C, then A, B, and C form a cluster. + + Parameters + ---------- + relationship_condition : Callable[[Relationship], bool], optional + A function that takes a Relationship and returns a boolean, by default lambda _: True + + Returns + ------- + List[Set[Node]] + A list of sets, where each set contains nodes that form a cluster. + """ + + clusters = [] + relationships = [ + rel for rel in self.relationships if relationship_condition(rel) + ] + for node in self.nodes: + cluster = set() + cluster.add(node) + for rel in relationships: + if rel.bidirectional: + if rel.source == node: + cluster.add(rel.target) + elif rel.target == node: + cluster.add(rel.source) + else: + if rel.source == node: + cluster.add(rel.target) + + if len(cluster) > 1: + if cluster not in clusters: clusters.append(cluster) - return clusters + # Remove subsets from clusters + unique_clusters = [] + for cluster in clusters: + if not any(cluster < other for other in clusters): + unique_clusters.append(cluster) + clusters = unique_clusters + + cluster_dict = {} + for cluster in clusters: + cluster_dict.update({cluster.pop(): cluster}) + + return cluster_dict diff --git a/src/ragas/testset/graph_queries.py b/src/ragas/testset/graph_queries.py new file mode 100644 index 000000000..23397d803 --- /dev/null +++ b/src/ragas/testset/graph_queries.py @@ -0,0 +1,38 @@ +import typing as t + +from ragas.testset.graph import KnowledgeGraph, Node + + +def get_child_nodes(node: Node, graph: KnowledgeGraph, level: int = 1) -> t.List[Node]: + """ + Get the child nodes of a given node up to a specified level. + + Parameters + ---------- + node : Node + The node to get the children of. + graph : KnowledgeGraph + The knowledge graph containing the node. + level : int + The maximum level to which child nodes are searched. + + Returns + ------- + List[Node] + The list of child nodes up to the specified level. + """ + children = [] + + # Helper function to perform depth-limited search for child nodes + def dfs(current_node: Node, current_level: int): + if current_level > level: + return + for rel in graph.relationships: + if rel.source == current_node and rel.type == "child": + children.append(rel.target) + dfs(rel.target, current_level + 1) + + # Start DFS from the initial node at level 0 + dfs(node, 1) + + return children diff --git a/src/ragas/testset/synthesizers/__init__.py b/src/ragas/testset/synthesizers/__init__.py index c0d23e4d6..48679a179 100644 --- a/src/ragas/testset/synthesizers/__init__.py +++ b/src/ragas/testset/synthesizers/__init__.py @@ -1,38 +1,29 @@ import typing as t from ragas.llms import BaseRagasLLM - -from .abstract_query import ( - AbstractQuerySynthesizer, - ComparativeAbstractQuerySynthesizer, +from ragas.testset.synthesizers.multi_hop import ( + MultiHopAbstractQuerySynthesizer, + MultiHopSpecificQuerySynthesizer, +) +from ragas.testset.synthesizers.single_hop.specific import ( + SingleHopSpecificQuerySynthesizer, ) + from .base import BaseSynthesizer -from .base_query import QuerySynthesizer -from .specific_query import SpecificQuerySynthesizer QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]] def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution: - """ - Default query distribution for the test set. - - By default, 25% of the queries are generated using `AbstractQuerySynthesizer`, - 25% are generated using `ComparativeAbstractQuerySynthesizer`, and 50% are - generated using `SpecificQuerySynthesizer`. - """ + """ """ return [ - (AbstractQuerySynthesizer(llm=llm), 0.25), - (ComparativeAbstractQuerySynthesizer(llm=llm), 0.25), - (SpecificQuerySynthesizer(llm=llm), 0.5), + (SingleHopSpecificQuerySynthesizer(llm=llm), 0.5), + (MultiHopAbstractQuerySynthesizer(llm=llm), 0.25), + (MultiHopSpecificQuerySynthesizer(llm=llm), 0.25), ] __all__ = [ "BaseSynthesizer", - "QuerySynthesizer", - "AbstractQuerySynthesizer", - "ComparativeAbstractQuerySynthesizer", - "SpecificQuerySynthesizer", "default_query_distribution", ] diff --git a/src/ragas/testset/synthesizers/abstract_query.py b/src/ragas/testset/synthesizers/abstract_query.py deleted file mode 100644 index 41df9cf32..000000000 --- a/src/ragas/testset/synthesizers/abstract_query.py +++ /dev/null @@ -1,349 +0,0 @@ -from __future__ import annotations - -import logging -import math -import random -import typing as t -from dataclasses import dataclass, field - -from ragas.dataset_schema import SingleTurnSample -from ragas.executor import run_async_batch -from ragas.prompt import PydanticPrompt -from ragas.testset.graph import KnowledgeGraph, NodeType - -from .base import BaseScenario, QueryLength, QueryStyle -from .base_query import QuerySynthesizer -from .prompts import ( - AbstractQueryFromTheme, - CAQInput, - CommonConceptsFromKeyphrases, - CommonThemeFromSummariesPrompt, - ComparativeAbstractQuery, - Concepts, - KeyphrasesAndNumConcepts, - Summaries, - ThemeAndContext, - Themes, -) - -if t.TYPE_CHECKING: - from langchain_core.callbacks import Callbacks - -logger = logging.getLogger(__name__) - - -class AbstractQueryScenario(BaseScenario): - """ - Represents a scenario for generating abstract queries. - Also inherits attributes from [BaseScenario][ragas.testset.synthesizers.base.BaseScenario]. - - Attributes - ---------- - theme : str - The theme of the abstract query scenario. - """ - - theme: str - - -@dataclass -class AbstractQuerySynthesizer(QuerySynthesizer): - """ - Synthesizes abstract queries which generate a theme and a set of summaries from a - cluster of chunks and then generate queries based on that. - - Attributes - ---------- - generate_user_input_prompt : PydanticPrompt - The prompt used for generating the user input. - """ - - generate_user_input_prompt: PydanticPrompt = field( - default_factory=AbstractQueryFromTheme - ) - - def __post_init__(self): - super().__post_init__() - self.common_theme_prompt = CommonThemeFromSummariesPrompt() - - async def _generate_scenarios( - self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks - ) -> t.List[AbstractQueryScenario]: - node_clusters = knowledge_graph.find_clusters( - relationship_condition=lambda rel: ( - True if rel.get_property("cosine_similarity") else False - ) - ) - logger.info("found %d clusters", len(node_clusters)) - - # filter out nodes that are not chunks - node_clusters = [ - cluster - for cluster in node_clusters - if all(node.type == "chunk" for node in cluster) - ] - - # find the number of themes to generation for given n and the num of clusters - # will generate more themes just in case - if len(node_clusters) == 0: - node_clusters_new = [] - # if no clusters, use the nodes directly - for node in knowledge_graph.nodes: - if node.type == NodeType.CHUNK: - node_clusters_new.append([node]) - - if len(node_clusters_new) == 0: - raise ValueError( - "no clusters found. Try running a few transforms to populate the dataset" - ) - node_clusters = node_clusters_new[:n] - - num_clusters = len(node_clusters) - num_themes = math.ceil(n / num_clusters) - logger.info("generating %d themes", num_clusters) - - kw_list = [] - for cluster in node_clusters: - summaries = [] - for node in cluster: - summary = node.get_property("summary") - if summary is not None: - summaries.append(summary) - - summaries = Summaries( - summaries=summaries, - num_themes=num_themes, - ) - kw_list.append({"data": summaries, "llm": self.llm, "callbacks": callbacks}) - - themes: t.List[Themes] = run_async_batch( - desc="Generating common themes", - func=self.common_theme_prompt.generate, - kwargs_list=kw_list, - ) - - # sample clusters and themes to get num_clusters * num_themes - clusters_sampled = [] - themes_sampled = [] - themes_list = [theme.themes for theme in themes] - for cluster, ts in zip(node_clusters, themes_list): - for theme in ts: - themes_sampled.append(theme) - clusters_sampled.append(cluster) - - # sample query styles and query lengths - query_styles = random.choices(list(QueryStyle), k=num_clusters * num_themes) - query_lengths = random.choices(list(QueryLength), k=num_clusters * num_themes) - - # create distributions - distributions = [] - for cluster, theme, style, length in zip( - clusters_sampled, themes_sampled, query_styles, query_lengths - ): - distributions.append( - AbstractQueryScenario( - theme=theme.theme, - nodes=cluster, - style=style, - length=length, - ) - ) - return distributions - - async def _generate_sample( - self, scenario: AbstractQueryScenario, callbacks: Callbacks - ) -> SingleTurnSample: - user_input = await self.generate_query(scenario, callbacks) - if await self.critic_query(user_input): - user_input = await self.modify_query(user_input, scenario, callbacks) - - reference = await self.generate_reference(user_input, scenario) - - reference_contexts = [] - for node in scenario.nodes: - if node.get_property("page_content") is not None: - reference_contexts.append(node.get_property("page_content")) - - return SingleTurnSample( - user_input=user_input, - reference=reference, - reference_contexts=reference_contexts, - ) - - async def generate_query( - self, scenario: AbstractQueryScenario, callbacks: Callbacks - ) -> str: - query = await self.generate_user_input_prompt.generate( - data=ThemeAndContext( - theme=scenario.theme, - context=self.make_reference_contexts(scenario), - ), - llm=self.llm, - callbacks=callbacks, - ) - return query.text - - -class ComparativeAbstractQueryScenario(BaseScenario): - common_concept: str - - -@dataclass -class ComparativeAbstractQuerySynthesizer(QuerySynthesizer): - """ - Synthesizes comparative abstract queries which generate a common concept and - a set of keyphrases and summaries and then generate queries based on that. - - Attributes - ---------- - common_concepts_prompt : PydanticPrompt - The prompt used for generating common concepts. - generate_query_prompt : PydanticPrompt - The prompt used for generating the query. - """ - - common_concepts_prompt: PydanticPrompt = field( - default_factory=CommonConceptsFromKeyphrases - ) - generate_query_prompt: PydanticPrompt = field( - default_factory=ComparativeAbstractQuery - ) - - async def _generate_scenarios( - self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks - ) -> t.List[ComparativeAbstractQueryScenario]: - node_clusters = knowledge_graph.find_clusters( - relationship_condition=lambda rel: ( - True if rel.get_property("summary_cosine_similarity") else False - ) - ) - logger.info("found %d clusters", len(node_clusters)) - - # find the number of themes to generation for given n and the num of clusters - # will generate more themes just in case - if len(node_clusters) == 0: - node_clusters_new = [] - - # if no clusters, use the nodes directly - for node in knowledge_graph.nodes: - if node.type == NodeType.DOCUMENT: - node_clusters_new.append([node]) - - if len(node_clusters_new) == 0: - raise ValueError( - "no clusters found. Try running a few transforms to populate the dataset" - ) - node_clusters = node_clusters_new[:n] - - num_clusters = len(node_clusters) - num_concepts = math.ceil(n / num_clusters) - logger.info("generating %d common_themes", num_concepts) - - # generate common themes - cluster_concepts = [] - kw_list: t.List[t.Dict] = [] - for cluster in node_clusters: - keyphrases = [] - for node in cluster: - keyphrases_node = node.get_property("keyphrases") - if keyphrases_node is not None: - keyphrases.extend(keyphrases_node) - - kw_list.append( - { - "data": KeyphrasesAndNumConcepts( - keyphrases=keyphrases, - num_concepts=num_concepts, - ), - "llm": self.llm, - "callbacks": callbacks, - } - ) - - common_concepts: t.List[Concepts] = run_async_batch( - desc="Generating common_concepts", - func=self.common_concepts_prompt.generate, - kwargs_list=kw_list, - ) - - # sample everything n times - for cluster, common_concept in zip(node_clusters, common_concepts): - for concept in common_concept.concepts: - cluster_concepts.append((cluster, concept)) - - query_lengths_sampled = random.choices( - list(QueryLength), k=num_clusters * num_concepts - ) - query_styles_sampled = random.choices( - list(QueryStyle), k=num_clusters * num_concepts - ) - logger.info( - "len(query_lengths_sampled) = %d, len(query_styles_sampled) = %d, len(cluster_concepts) = %d", - len(query_lengths_sampled), - len(query_styles_sampled), - len(cluster_concepts), - ) - - # make the scenarios - scenarios = [] - for (cluster, concept), length, style in zip( - cluster_concepts, - query_lengths_sampled, - query_styles_sampled, - ): - scenarios.append( - ComparativeAbstractQueryScenario( - common_concept=concept, - nodes=cluster, - length=length, - style=style, - ) - ) - return scenarios - - async def _generate_sample( - self, scenario: ComparativeAbstractQueryScenario, callbacks: Callbacks - ) -> SingleTurnSample: - # generate the user input - keyphrases = [] - summaries = [] - for n in scenario.nodes: - keyphrases_node = n.get_property("keyphrases") - if keyphrases_node is not None: - keyphrases.extend(keyphrases_node) - summary_node = n.get_property("summary") - if summary_node is not None: - summaries.append(summary_node) - - query = await self.generate_query_prompt.generate( - data=CAQInput( - concept=scenario.common_concept, - keyphrases=keyphrases, - summaries=summaries, - ), - llm=self.llm, - callbacks=callbacks, - ) - query = query.text - - # critic the query - if not await self.critic_query(query): - query = await self.modify_query(query, scenario, callbacks) - - # generate the answer - answer = await self.generate_reference( - query, scenario, callbacks, reference_property_name="summary" - ) - - # make the reference contexts - # TODO: make this more efficient. Right now we are taking only the summary - reference_contexts = [] - for node in scenario.nodes: - if node.get_property("summary") is not None: - reference_contexts.append(node.get_property("summary")) - - return SingleTurnSample( - user_input=query, - reference=answer, - reference_contexts=reference_contexts, - ) diff --git a/src/ragas/testset/synthesizers/base.py b/src/ragas/testset/synthesizers/base.py index 4a6fba61e..2c835702f 100644 --- a/src/ragas/testset/synthesizers/base.py +++ b/src/ragas/testset/synthesizers/base.py @@ -11,6 +11,7 @@ from ragas.llms import BaseRagasLLM, llm_factory from ragas.prompt import PromptMixin from ragas.testset.graph import KnowledgeGraph, Node +from ragas.testset.persona import Persona if t.TYPE_CHECKING: from langchain_core.callbacks import Callbacks @@ -51,11 +52,14 @@ class BaseScenario(BaseModel): The style of the query. length : QueryLength The length of the query. + persona : Persona + A persona associated with the scenario. """ nodes: t.List[Node] style: QueryStyle length: QueryLength + persona: Persona Scenario = t.TypeVar("Scenario", bound=BaseScenario) @@ -78,6 +82,7 @@ async def generate_scenarios( self, n: int, knowledge_graph: KnowledgeGraph, + persona_list: t.List[Persona], callbacks: t.Optional[Callbacks] = None, ) -> t.List[Scenario]: callbacks = callbacks or [] @@ -87,14 +92,18 @@ async def generate_scenarios( callbacks=callbacks, ) scenarios = await self._generate_scenarios( - n, knowledge_graph, scenario_generation_group + n, knowledge_graph, persona_list, scenario_generation_group ) scenario_generation_rm.on_chain_end(outputs={"scenarios": scenarios}) return scenarios @abstractmethod async def _generate_scenarios( - self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks + self, + n: int, + knowledge_graph: KnowledgeGraph, + persona_list: t.List[Persona], + callbacks: Callbacks, ) -> t.List[Scenario]: pass diff --git a/src/ragas/testset/synthesizers/base_query.py b/src/ragas/testset/synthesizers/base_query.py deleted file mode 100644 index 49e8dd953..000000000 --- a/src/ragas/testset/synthesizers/base_query.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations - -import typing as t -from dataclasses import dataclass, field - -from ragas.prompt import StringIO - -from .base import BaseSynthesizer, Scenario -from .prompts import ( - CriticUserInput, - GenerateReference, - ModifyUserInput, - PydanticPrompt, - QueryAndContext, - QueryWithStyleAndLength, - extend_modify_input_prompt, -) - -if t.TYPE_CHECKING: - from langchain_core.callbacks import Callbacks - - -@dataclass -class QuerySynthesizer(BaseSynthesizer[Scenario]): - """ - Synthesizes Question-Answer pairs. Used as a base class for other query synthesizers. - - Attributes - ---------- - critic_query_prompt : PydanticPrompt - The prompt used for criticizing the query. - query_modification_prompt : PydanticPrompt - The prompt used for modifying the query. - generate_reference_prompt : PydanticPrompt - The prompt used for generating the reference. - """ - - critic_query_prompt: PydanticPrompt = field(default_factory=CriticUserInput) - query_modification_prompt: PydanticPrompt = field(default_factory=ModifyUserInput) - generate_reference_prompt: PydanticPrompt = field(default_factory=GenerateReference) - - async def critic_query( - self, query: str, callbacks: t.Optional[Callbacks] = None - ) -> bool: - callbacks = callbacks or [] - critic = await self.critic_query_prompt.generate( - data=StringIO(text=query), llm=self.llm, callbacks=callbacks - ) - return critic.independence > 1 and critic.clear_intent > 1 - - async def modify_query( - self, query: str, scenario: Scenario, callbacks: Callbacks - ) -> str: - prompt = extend_modify_input_prompt( - query_modification_prompt=self.query_modification_prompt, - style=scenario.style, - length=scenario.length, - ) - modified_query = await prompt.generate( - data=QueryWithStyleAndLength( - query=query, - style=scenario.style, - length=scenario.length, - ), - llm=self.llm, - callbacks=callbacks, - ) - return modified_query.text - - async def generate_reference( - self, - question: str, - scenario: Scenario, - callbacks: t.Optional[Callbacks] = None, - reference_property_name: str = "page_content", - ) -> str: - callbacks = callbacks or [] - reference = await self.generate_reference_prompt.generate( - data=QueryAndContext( - query=question, - context=self.make_reference_contexts(scenario, reference_property_name), - ), - llm=self.llm, - callbacks=callbacks, - ) - return reference.text - - @staticmethod - def make_reference_contexts( - scenario: Scenario, property_name: str = "page_content" - ) -> str: - page_contents = [] - for node in scenario.nodes: - page_contents.append(node.get_property(property_name)) - return "\n\n".join(page_contents) diff --git a/src/ragas/testset/synthesizers/generate.py b/src/ragas/testset/synthesizers/generate.py index 67fae1132..6db8039a1 100644 --- a/src/ragas/testset/synthesizers/generate.py +++ b/src/ragas/testset/synthesizers/generate.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import random import typing as t from dataclasses import dataclass, field @@ -18,6 +19,7 @@ from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper from ragas.run_config import RunConfig from ragas.testset.graph import KnowledgeGraph, Node, NodeType +from ragas.testset.persona import Persona, generate_personas_from_kg from ragas.testset.synthesizers import default_query_distribution from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample from ragas.testset.synthesizers.utils import calculate_split_values @@ -62,6 +64,7 @@ class TestsetGenerator: llm: BaseRagasLLM embedding_model: BaseRagasEmbeddings knowledge_graph: KnowledgeGraph = field(default_factory=KnowledgeGraph) + persona_list: t.Optional[t.List[Persona]] = None @classmethod def from_langchain( @@ -271,6 +274,7 @@ def generate( self, testset_size: int, query_distribution: t.Optional[QueryDistribution] = None, + num_personas: int = 3, run_config: t.Optional[RunConfig] = None, batch_size: t.Optional[int] = None, callbacks: t.Optional[Callbacks] = None, @@ -288,6 +292,8 @@ def generate( query_distribution : Optional[QueryDistribution], optional A list of tuples containing scenario simulators and their probabilities. If None, default simulators will be used. + num_personas : int, default 3 + The number of personas to generate or use from the persona_list. run_config : Optional[RunConfig], optional Configuration for running the generation process. batch_size: int, optional @@ -356,6 +362,16 @@ def generate( patch_logger("ragas.experimental.testset.graph", logging.DEBUG) patch_logger("ragas.experimental.testset.transforms", logging.DEBUG) + if self.persona_list is None: + self.persona_list = generate_personas_from_kg( + llm=self.llm, + kg=self.knowledge_graph, + num_personas=num_personas, + callbacks=callbacks, + ) + else: + random.shuffle(self.persona_list) + splits, _ = calculate_split_values( [prob for _, prob in query_distribution], testset_size ) @@ -383,6 +399,7 @@ def generate( scenario.generate_scenarios, n=splits[i], knowledge_graph=self.knowledge_graph, + persona_list=self.persona_list[:num_personas], callbacks=scenario_generation_grp, ) diff --git a/src/ragas/testset/synthesizers/multi_hop/__init__.py b/src/ragas/testset/synthesizers/multi_hop/__init__.py new file mode 100644 index 000000000..7dbc2a401 --- /dev/null +++ b/src/ragas/testset/synthesizers/multi_hop/__init__.py @@ -0,0 +1,10 @@ +from .abstract import MultiHopAbstractQuerySynthesizer +from .base import MultiHopQuerySynthesizer, MultiHopScenario +from .specific import MultiHopSpecificQuerySynthesizer + +__all__ = [ + "MultiHopAbstractQuerySynthesizer", + "MultiHopSpecificQuerySynthesizer", + "MultiHopQuerySynthesizer", + "MultiHopScenario", +] diff --git a/src/ragas/testset/synthesizers/multi_hop/abstract.py b/src/ragas/testset/synthesizers/multi_hop/abstract.py new file mode 100644 index 000000000..20162ff4e --- /dev/null +++ b/src/ragas/testset/synthesizers/multi_hop/abstract.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import logging +import typing as t +from dataclasses import dataclass + +import numpy as np + +from ragas.prompt import PydanticPrompt +from ragas.testset.graph import KnowledgeGraph +from ragas.testset.graph_queries import get_child_nodes +from ragas.testset.persona import Persona, PersonaList +from ragas.testset.synthesizers.multi_hop.base import ( + MultiHopQuerySynthesizer, + MultiHopScenario, +) +from ragas.testset.synthesizers.multi_hop.prompts import ( + ConceptCombinationPrompt, + ConceptsList, +) +from ragas.testset.synthesizers.prompts import ( + ThemesPersonasInput, + ThemesPersonasMatchingPrompt, +) + +if t.TYPE_CHECKING: + from langchain_core.callbacks import Callbacks + +logger = logging.getLogger(__name__) + + +@dataclass +class MultiHopAbstractQuerySynthesizer(MultiHopQuerySynthesizer): + """ + Synthesizes abstract multi-hop queries from given knowledge graph. + + Attributes + ---------- + """ + + name: str = "multi_hop_abstract_query_synthesizer" + concept_combination_prompt: PydanticPrompt = ConceptCombinationPrompt() + theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt() + + async def _generate_scenarios( + self, + n: int, + knowledge_graph: KnowledgeGraph, + persona_list: t.List[Persona], + callbacks: Callbacks, + ) -> t.List[MultiHopScenario]: + """ + Generates a list of scenarios on type MultiHopAbstractQuerySynthesizer + Steps to generate scenarios: + 1. Find indirect clusters of nodes based on relationship condition + 2. Calculate the number of samples that should be created per cluster to get n samples in total + 3. For each cluster of nodes + a. Find the child nodes of the cluster nodes + b. Find list of personas that can be associated with the entities to create query + c. Create all possible combinations of (nodes, entities, personas, style, length) as scenarios + 4. Sample diverse combinations of scenarios to get n samples + """ + + node_clusters = knowledge_graph.find_indirect_clusters( + relationship_condition=lambda rel: ( + True if rel.get_property("summary_similarity") else False + ), + depth_limit=3, + ) + logger.info("found %d clusters", len(node_clusters)) + scenarios = [] + + num_sample_per_cluster = int(np.ceil(n / len(node_clusters))) + + for cluster in node_clusters: + if len(scenarios) >= n: + break + nodes = [] + for node in cluster: + child_nodes = get_child_nodes(node, knowledge_graph, level=1) + if child_nodes: + nodes.extend(child_nodes) + else: + nodes.append(node) + + base_scenarios = [] + node_themes = [node.properties.get("themes", []) for node in nodes] + prompt_input = ConceptsList( + lists_of_concepts=node_themes, max_combinations=num_sample_per_cluster + ) + concept_combination = await self.concept_combination_prompt.generate( + data=prompt_input, llm=self.llm, callbacks=callbacks + ) + flattened_themes = [ + theme + for sublist in concept_combination.combinations + for theme in sublist + ] + prompt_input = ThemesPersonasInput( + themes=flattened_themes, personas=persona_list + ) + persona_concepts = await self.theme_persona_matching_prompt.generate( + data=prompt_input, llm=self.llm, callbacks=callbacks + ) + + base_scenarios = self.prepare_combinations( + nodes, + concept_combination.combinations, + PersonaList(personas=persona_list), + persona_concepts, + property_name="themes", + ) + base_scenarios = self.sample_diverse_combinations( + base_scenarios, num_sample_per_cluster + ) + scenarios.extend(base_scenarios) + + return scenarios diff --git a/src/ragas/testset/synthesizers/multi_hop/base.py b/src/ragas/testset/synthesizers/multi_hop/base.py new file mode 100644 index 000000000..3b2e3010c --- /dev/null +++ b/src/ragas/testset/synthesizers/multi_hop/base.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import logging +import random +import typing as t +from collections import defaultdict +from dataclasses import dataclass + +from ragas import SingleTurnSample +from ragas.prompt import PydanticPrompt +from ragas.testset.persona import PersonaList +from ragas.testset.synthesizers.base import ( + BaseScenario, + BaseSynthesizer, + QueryLength, + QueryStyle, + Scenario, +) +from ragas.testset.synthesizers.multi_hop.prompts import ( + QueryAnswerGenerationPrompt, + QueryConditions, +) + +if t.TYPE_CHECKING: + from langchain_core.callbacks import Callbacks + +logger = logging.getLogger(__name__) + + +class MultiHopScenario(BaseScenario): + """ + Scenario for multi-hop queries. + + Attributes + ---------- + combinations: str + The theme of the query. + style: QueryStyle + The style of the query. + length: QueryLength + The length of the query. + """ + + combinations: t.List[str] + + +@dataclass +class MultiHopQuerySynthesizer(BaseSynthesizer[Scenario]): + + generate_query_reference_prompt: PydanticPrompt = QueryAnswerGenerationPrompt() + + def prepare_combinations( + self, + nodes, + combinations: t.List[t.List[str]], + persona_list: PersonaList, + persona_concepts, + property_name: str, + ) -> t.List[t.Dict[str, t.Any]]: + + possible_combinations = [] + for combination in combinations: + dict = {"combination": combination} + valid_personas = [] + for persona, concept_list in persona_concepts.mapping.items(): + concept_list = [c.lower() for c in concept_list] + if ( + any(concept.lower() in concept_list for concept in combination) + and persona_list[persona] + ): + valid_personas.append(persona_list[persona]) + dict["personas"] = valid_personas + valid_nodes = [] + for node in nodes: + node_themes = [ + theme.lower() for theme in node.get_property(property_name) + ] + if node.get_property(property_name) and any( + concept.lower() in node_themes for concept in combination + ): + valid_nodes.append(node) + + dict["nodes"] = valid_nodes + dict["styles"] = list(QueryStyle) + dict["lengths"] = list(QueryLength) + + possible_combinations.append(dict) + return possible_combinations + + def sample_diverse_combinations( + self, data: t.List[t.Dict[str, t.Any]], num_samples: int + ) -> t.List[MultiHopScenario]: + + selected_samples = [] + combination_persona_count = defaultdict(set) + style_count = defaultdict(int) + length_count = defaultdict(int) + + all_possible_samples = [] + + for entry in data: + combination = tuple(entry["combination"]) + nodes = entry["nodes"] + + for persona in entry["personas"]: + for style in entry["styles"]: + for length in entry["lengths"]: + all_possible_samples.append( + { + "combination": combination, + "persona": persona, + "nodes": nodes, + "style": style, + "length": length, + } + ) + + random.shuffle(all_possible_samples) + + for sample in all_possible_samples: + if len(selected_samples) >= num_samples: + break + + combination = sample["combination"] + persona = sample["persona"] + style = sample["style"] + length = sample["length"] + + if persona.name not in combination_persona_count[combination]: + selected_samples.append(sample) + combination_persona_count[combination].add(persona.name) + + elif style_count[style] < max(style_count.values(), default=0) + 1: + selected_samples.append(sample) + style_count[style] += 1 + + elif length_count[length] < max(length_count.values(), default=0) + 1: + selected_samples.append(sample) + length_count[length] += 1 + + return [self.convert_to_scenario(sample) for sample in selected_samples] + + def convert_to_scenario(self, data: t.Dict[str, t.Any]) -> MultiHopScenario: + + return MultiHopScenario( + nodes=data["nodes"], + combinations=data["combination"], + style=data["style"], + length=data["length"], + persona=data["persona"], + ) + + async def _generate_sample( + self, scenario: MultiHopScenario, callbacks: Callbacks + ) -> SingleTurnSample: + + reference_context = self.make_contexts(scenario) + prompt_input = QueryConditions( + persona=scenario.persona, + themes=scenario.combinations, + context=reference_context, + query_length=scenario.length.name, + query_style=scenario.style.name, + ) + response = await self.generate_query_reference_prompt.generate( + data=prompt_input, llm=self.llm, callbacks=callbacks + ) + return SingleTurnSample( + user_input=response.query, + reference=response.answer, + reference_contexts=reference_context, + ) + + def make_contexts(self, scenario: MultiHopScenario) -> t.List[str]: + + contexts = [] + for node in scenario.nodes: + context = f"{node.id}" + "\n\n" + node.properties.get("page_content", "") + contexts.append(context) + + return contexts diff --git a/src/ragas/testset/synthesizers/multi_hop/prompts.py b/src/ragas/testset/synthesizers/multi_hop/prompts.py new file mode 100644 index 000000000..a701eb6c6 --- /dev/null +++ b/src/ragas/testset/synthesizers/multi_hop/prompts.py @@ -0,0 +1,86 @@ +import typing as t + +from pydantic import BaseModel, Field + +from ragas.prompt import PydanticPrompt +from ragas.testset.persona import Persona + + +class ConceptsList(BaseModel): + lists_of_concepts: t.List[t.List[str]] = Field( + description="A list containing lists of concepts from each node" + ) + max_combinations: int = Field( + description="The maximum number of concept combinations to generate", default=5 + ) + + +class ConceptCombinations(BaseModel): + combinations: t.List[t.List[str]] + + +class ConceptCombinationPrompt(PydanticPrompt[ConceptsList, ConceptCombinations]): + instruction: str = ( + "Form combinations by pairing concepts from at least two different lists.\n" + "**Instructions:**\n" + "- Review the concepts from each node.\n" + "- Identify concepts that can logically be connected or contrasted.\n" + "- Form combinations that involve concepts from different nodes.\n" + "- Each combination should include at least one concept from two or more nodes.\n" + "- List the combinations clearly and concisely.\n" + "- Do not repeat the same combination more than once." + ) + input_model: t.Type[ConceptsList] = ( + ConceptsList # Contains lists of concepts from each node + ) + output_model: t.Type[ConceptCombinations] = ( + ConceptCombinations # Contains list of concept combinations + ) + examples: t.List[t.Tuple[ConceptsList, ConceptCombinations]] = [ + ( + ConceptsList( + lists_of_concepts=[ + ["Artificial intelligence", "Automation"], # Concepts from Node 1 + ["Healthcare", "Data privacy"], # Concepts from Node 2 + ], + max_combinations=2, + ), + ConceptCombinations( + combinations=[ + ["Artificial intelligence", "Healthcare"], + ["Automation", "Data privacy"], + ] + ), + ) + ] + + +class QueryConditions(BaseModel): + persona: Persona + themes: t.List[str] + query_style: str + query_length: str + context: t.List[str] + + +class GeneratedQueryAnswer(BaseModel): + query: str + answer: str + + +class QueryAnswerGenerationPrompt( + PydanticPrompt[QueryConditions, GeneratedQueryAnswer] +): + instruction: str = ( + "Generate a query and answer based on the specified conditions (persona, themes, style, length) " + "and the provided context. Ensure the answer is fully faithful to the context, only using information " + "directly from the nodes provided." + "### Instructions:\n" + "1. **Generate a Query**: Based on the context, persona, themes, style, and length, create a question " + "that aligns with the persona’s perspective and reflects the themes.\n" + "2. **Generate an Answer**: Using only the content from the provided context, create a faithful and detailed answer to " + "the query. Do not include any information that not in or cannot be inferred from the given context.\n" + "### Example Outputs:\n\n" + ) + input_model: t.Type[QueryConditions] = QueryConditions + output_model: t.Type[GeneratedQueryAnswer] = GeneratedQueryAnswer diff --git a/src/ragas/testset/synthesizers/multi_hop/specific.py b/src/ragas/testset/synthesizers/multi_hop/specific.py new file mode 100644 index 000000000..b71af16c3 --- /dev/null +++ b/src/ragas/testset/synthesizers/multi_hop/specific.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import logging +import typing as t +from dataclasses import dataclass + +import numpy as np + +from ragas.prompt import PydanticPrompt +from ragas.testset.graph import KnowledgeGraph +from ragas.testset.persona import Persona, PersonaList +from ragas.testset.synthesizers.multi_hop.base import ( + MultiHopQuerySynthesizer, + MultiHopScenario, +) +from ragas.testset.synthesizers.multi_hop.prompts import QueryAnswerGenerationPrompt +from ragas.testset.synthesizers.prompts import ( + ThemesPersonasInput, + ThemesPersonasMatchingPrompt, +) + +if t.TYPE_CHECKING: + from langchain_core.callbacks import Callbacks + +logger = logging.getLogger(__name__) + + +@dataclass +class MultiHopSpecificQuerySynthesizer(MultiHopQuerySynthesizer): + """ + Synthesizes overlap based queries by choosing specific chunks and generating a + keyphrase from them and then generating queries based on that. + + Attributes + ---------- + generate_query_prompt : PydanticPrompt + The prompt used for generating the query. + """ + + name: str = "multi_hop_specific_query_synthesizer" + theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt() + generate_query_reference_prompt: PydanticPrompt = QueryAnswerGenerationPrompt() + + async def _generate_scenarios( + self, + n: int, + knowledge_graph: KnowledgeGraph, + persona_list: t.List[Persona], + callbacks: Callbacks, + ) -> t.List[MultiHopScenario]: + """ + Generates a list of scenarios on type MultiHopSpecificQuerySynthesizer + Steps to generate scenarios: + 1. Filter the knowledge graph to find cluster of nodes or defined relation type. Here entities_overlap + 2. Calculate the number of samples that should be created per cluster to get n samples in total + 3. For each cluster of nodes + a. Find the entities that are common between the nodes + b. Find list of personas that can be associated with the entities to create query + c. Create all possible combinations of (nodes, entities, personas, style, length) as scenarios + 3. Sample num_sample_per_cluster scenarios from the list of scenarios + 4. Return the list of scenarios of length n + """ + + cluster_dict = knowledge_graph.find_direct_clusters( + relationship_condition=lambda rel: ( + True if rel.type == "entities_overlap" else False + ) + ) + + valid_relationships = [ + rel + for rel in knowledge_graph.relationships + if rel.type == "entities_overlap" + ] + + node_clusters = [] + for key_node, list_of_nodes in cluster_dict.items(): + for node in list_of_nodes: + node_clusters.append((key_node, node)) + + logger.info("found %d clusters", len(cluster_dict)) + scenarios = [] + num_sample_per_cluster = int(np.ceil(n / len(node_clusters))) + + for cluster in node_clusters: + if len(scenarios) < n: + key_node, node = cluster + overlapped_items = [] + for rel in valid_relationships: + if rel.source == key_node and rel.target == node: + overlapped_items = rel.get_property("overlapped_items") + break + if overlapped_items: + themes = list(dict(overlapped_items).keys()) + prompt_input = ThemesPersonasInput( + themes=themes, personas=persona_list + ) + persona_concepts = ( + await self.theme_persona_matching_prompt.generate( + data=prompt_input, llm=self.llm, callbacks=callbacks + ) + ) + overlapped_items = [list(item) for item in overlapped_items] + base_scenarios = self.prepare_combinations( + [key_node, node], + overlapped_items, + PersonaList(personas=persona_list), + persona_concepts, + property_name="entities", + ) + base_scenarios = self.sample_diverse_combinations( + base_scenarios, num_sample_per_cluster + ) + scenarios.extend(base_scenarios) + + return scenarios diff --git a/src/ragas/testset/synthesizers/prompts.py b/src/ragas/testset/synthesizers/prompts.py index f8150c007..b00813613 100644 --- a/src/ragas/testset/synthesizers/prompts.py +++ b/src/ragas/testset/synthesizers/prompts.py @@ -2,407 +2,48 @@ from pydantic import BaseModel -from ragas.prompt import PydanticPrompt, StringIO -from ragas.testset.synthesizers.base import QueryLength, QueryStyle +from ragas.prompt import PydanticPrompt +from ragas.testset.persona import Persona -class Summaries(BaseModel): - summaries: t.List[str] - num_themes: int +class ThemesPersonasInput(BaseModel): + themes: t.List[str] + personas: t.List[Persona] -class Theme(BaseModel): - theme: str - description: str +class PersonaThemesMapping(BaseModel): + mapping: t.Dict[str, t.List[str]] -class Themes(BaseModel): - themes: t.List[Theme] - - -class CommonThemeFromSummariesPrompt(PydanticPrompt[Summaries, Themes]): - input_model = Summaries - output_model = Themes - instruction = "Analyze the following summaries and identify given number of common themes. The themes should be concise, descriptive, and highlight a key aspect shared across the summaries." - examples = [ +class ThemesPersonasMatchingPrompt( + PydanticPrompt[ThemesPersonasInput, PersonaThemesMapping] +): + instruction: str = ( + "Given a list of themes and personas with their roles, " + "associate each persona with relevant themes based on their role description." + ) + input_model: t.Type[ThemesPersonasInput] = ThemesPersonasInput + output_model: t.Type[PersonaThemesMapping] = PersonaThemesMapping + examples: t.List[t.Tuple[ThemesPersonasInput, PersonaThemesMapping]] = [ ( - Summaries( - summaries=[ - "Advances in artificial intelligence have revolutionized many industries. From healthcare to finance, AI algorithms are making processes more efficient and accurate. Machine learning models are being used to predict diseases, optimize investment strategies, and even recommend personalized content to users. The integration of AI into daily operations is becoming increasingly indispensable for modern businesses.", - "The healthcare industry is witnessing a significant transformation due to AI advancements. AI-powered diagnostic tools are improving the accuracy of medical diagnoses, reducing human error, and enabling early detection of diseases. Additionally, AI is streamlining administrative tasks, allowing healthcare professionals to focus more on patient care. Personalized treatment plans driven by AI analytics are enhancing patient outcomes.", - "Financial technology, or fintech, has seen a surge in AI applications. Algorithms for fraud detection, risk management, and automated trading are some of the key innovations in this sector. AI-driven analytics are helping companies to understand market trends better and make informed decisions. The use of AI in fintech is not only enhancing security but also increasing efficiency and profitability.", - ], - num_themes=2, - ), - Themes( - themes=[ - Theme( - theme="AI enhances efficiency and accuracy in various industries", - description="AI algorithms are improving processes across healthcare, finance, and more by increasing efficiency and accuracy.", + ThemesPersonasInput( + themes=["Empathy", "Inclusivity", "Remote work"], + personas=[ + Persona( + name="HR Manager", + role_description="Focuses on inclusivity and employee support.", ), - Theme( - theme="AI-powered tools improve decision-making and outcomes", - description="AI applications in diagnostic tools, personalized treatment plans, and fintech analytics are enhancing decision-making and outcomes.", + Persona( + name="Remote Team Lead", + role_description="Manages remote team communication.", ), - ] - ), - ) - ] - - def process_output(self, output: Themes, input: Summaries) -> Themes: - if len(output.themes) < input.num_themes: - # fill the rest with empty strings - output.themes.extend( - [Theme(theme="none", description="")] - * (input.num_themes - len(output.themes)) - ) - return output - - -class ThemeAndContext(BaseModel): - theme: str - context: str - - -class AbstractQueryFromTheme(PydanticPrompt[ThemeAndContext, StringIO]): - input_model = ThemeAndContext - output_model = StringIO - instruction = "Generate an abstract conceptual question using the given theme that can be answered from the information in the provided context." - examples = [ - ( - ThemeAndContext( - theme="AI enhances efficiency and accuracy in various industries", - context="AI is transforming various industries by improving efficiency and accuracy. For instance, in manufacturing, AI-powered robots automate repetitive tasks with high precision, reducing errors and increasing productivity. In healthcare, AI algorithms analyze medical images and patient data to provide accurate diagnoses and personalized treatment plans. Financial services leverage AI for fraud detection and risk management, ensuring quicker and more reliable decision-making. Overall, AI's ability to process vast amounts of data and learn from it enables industries to optimize operations, reduce costs, and deliver better products and services.", - ), - StringIO( - text="How does AI enhance efficiency and accuracy in various industries?" - ), - ) - ] - - -class Feedback(BaseModel): - independence: int - clear_intent: int - - -class CriticUserInput(PydanticPrompt[StringIO, Feedback]): - input_model = StringIO - output_model = Feedback - instruction = "Critique the synthetically generated question based on the following rubrics. Provide a score for each rubric: Independence and Clear Intent. Scores are given as low (0), medium (1), or high (2)." - examples = [ - ( - StringIO( - text="How does AI enhance efficiency and accuracy in various industries?" - ), - Feedback(independence=2, clear_intent=2), - ), - ( - StringIO(text="Explain the benefits of AI."), - Feedback(independence=1, clear_intent=1), - ), - ( - StringIO(text="How does AI?"), - Feedback(independence=0, clear_intent=0), - ), - ] - - -class QueryWithStyleAndLength(BaseModel): - query: str - style: QueryStyle - length: QueryLength - - -EXAMPLES_FOR_USER_INPUT_MODIFICATION = [ - # Short Length Examples - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.MISSPELLED, - length=QueryLength.SHORT, - ), - StringIO(text="How do enrgy storag solutions compare on efficincy?"), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.PERFECT_GRAMMAR, - length=QueryLength.SHORT, - ), - StringIO(text="How do energy storage solutions compare?"), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.POOR_GRAMMAR, - length=QueryLength.SHORT, - ), - StringIO(text="How do storag solutions compare on efficiency?"), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.WEB_SEARCH_LIKE, - length=QueryLength.SHORT, - ), - StringIO( - text="compare energy storage solutions efficiency cost sustainability" - ), - ), - # Medium Length Examples - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.MISSPELLED, - length=QueryLength.MEDIUM, - ), - StringIO(text="How do enrgy storag solutions compare on efficincy n cost?"), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.PERFECT_GRAMMAR, - length=QueryLength.MEDIUM, - ), - StringIO( - text="How do energy storage solutions compare in efficiency and cost?" - ), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.POOR_GRAMMAR, - length=QueryLength.MEDIUM, - ), - StringIO(text="How energy storag solutions compare on efficiency and cost?"), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.WEB_SEARCH_LIKE, - length=QueryLength.MEDIUM, - ), - StringIO( - text="comparison of energy storage solutions efficiency cost sustainability" - ), - ), - # Long Length Examples - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.MISSPELLED, - length=QueryLength.LONG, - ), - StringIO( - text="How do various enrgy storag solutions compare in terms of efficincy, cost, and sustanbility in rnewable energy systems?" - ), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.PERFECT_GRAMMAR, - length=QueryLength.LONG, - ), - StringIO( - text="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?" - ), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.POOR_GRAMMAR, - length=QueryLength.LONG, - ), - StringIO( - text="How various energy storag solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?" - ), - ), - ( - QueryWithStyleAndLength( - query="How do various energy storage solutions compare in terms of efficiency, cost, and sustainability in renewable energy systems?", - style=QueryStyle.WEB_SEARCH_LIKE, - length=QueryLength.LONG, - ), - StringIO( - text="How do various energy storage solutions compare efficiency cost sustainability renewable energy systems?" - ), - ), -] - - -class ModifyUserInput(PydanticPrompt[QueryWithStyleAndLength, StringIO]): - input_model = QueryWithStyleAndLength - output_model = StringIO - instruction = "Modify the given question in order to fit the given style and length" - examples: t.List[t.Tuple[QueryWithStyleAndLength, StringIO]] = [] - - -def extend_modify_input_prompt( - query_modification_prompt: PydanticPrompt, - style: QueryStyle, - length: QueryLength, -) -> PydanticPrompt: - examples = [ - example - for example in EXAMPLES_FOR_USER_INPUT_MODIFICATION - if example[0].style == style and example[0].length == length - ] - if not examples: - raise ValueError(f"No examples found for style {style} and length {length}") - query_modification_prompt.examples = examples - query_modification_prompt.examples = examples - return query_modification_prompt - - -class QueryAndContext(BaseModel): - query: str - context: str - - -class GenerateReference(PydanticPrompt[QueryAndContext, StringIO]): - input_model = QueryAndContext - output_model = StringIO - instruction = "Answer the following question based on the information provided in the given text." - examples = [ - ( - QueryAndContext( - query="How does AI enhance efficiency and accuracy across different industries?", - context="Advances in artificial intelligence have revolutionized many industries. From healthcare to finance, AI algorithms are making processes more efficient and accurate. Machine learning models are being used to predict diseases, optimize investment strategies, and even recommend personalized content to users. The integration of AI into daily operations is becoming increasingly indispensable for modern businesses.", - ), - StringIO( - text="AI improves efficiency and accuracy across different industries by making processes more efficient and accurate." - ), - ) - ] - - -class KeyphrasesAndNumConcepts(BaseModel): - keyphrases: t.List[str] - num_concepts: int - - -class Concepts(BaseModel): - concepts: t.Dict[str, t.List[str]] - - -class CommonConceptsFromKeyphrases(PydanticPrompt[KeyphrasesAndNumConcepts, Concepts]): - input_model = KeyphrasesAndNumConcepts - output_model = Concepts - instruction = "Identify a list of common concepts from the given list of key phrases for comparing the given theme across reports." - examples = [ - ( - KeyphrasesAndNumConcepts( - keyphrases=[ - "fast charging", - "long battery life", - "OLED display", - "waterproof", ], - num_concepts=4, ), - Concepts( - concepts={ - "Charging": [ - "fast charging", - "long battery life", - "OLED display", - "waterproof", - ], - "Battery Life": [ - "long battery life", - "extended battery", - "durable battery", - "prolonged battery", - ], - "Display": [ - "OLED display", - "HD display", - "AMOLED display", - "retina display", - ], - "Water/Dust Resistance": [ - "waterproof", - "dust resistant", - "splash proof", - "water resistant", - ], - } - ), - ) - ] - - def process_output( - self, output: Concepts, input: KeyphrasesAndNumConcepts - ) -> Concepts: - if len(output.concepts) < input.num_concepts: - # fill the rest with empty strings - output.concepts.update( - { - "Concept" + str(i): [] - for i in range(input.num_concepts - len(output.concepts)) + PersonaThemesMapping( + mapping={ + "HR Manager": ["Inclusivity", "Empathy"], + "Remote Team Lead": ["Remote work", "Empathy"], } - ) - return output - - -class CAQInput(BaseModel): - concept: str - keyphrases: t.List[str] - summaries: t.List[str] - - -class ComparativeAbstractQuery(PydanticPrompt[CAQInput, StringIO]): - input_model = CAQInput - output_model = StringIO - instruction = "Generate an abstract comparative question based on the given concept, keyphrases belonging to that concept, and summaries of reports." - examples = [ - ( - CAQInput( - concept="Battery Life", - keyphrases=[ - "long battery life", - "extended battery", - "durable battery", - "prolonged battery", - ], - summaries=[ - "The device offers a long battery life, capable of lasting up to 24 hours on a single charge.", - "Featuring an extended battery, the product can function for 20 hours with heavy usage.", - "With a durable battery, this model ensures 22 hours of operation under normal conditions.", - "The battery life is prolonged, allowing the gadget to be used for up to 18 hours on one charge.", - ], - ), - StringIO( - text="How do the battery life claims and performance metrics compare across different reports for devices featuring long battery life, extended battery, durable battery, and prolonged battery?" - ), - ) - ] - - -class SpecificQuestionInput(BaseModel): - title: str - keyphrase: str - text: str - - -class SpecificQuery(PydanticPrompt[SpecificQuestionInput, StringIO]): - input_model = SpecificQuestionInput - output_model = StringIO - instruction = "Given the title of a text and a text chunk, along with a keyphrase from the chunk, generate a specific question related to the keyphrase.\n\n" - "1. Read the title and the text chunk.\n" - "2. Identify the context of the keyphrase within the text chunk.\n" - "3. Formulate a question that directly relates to the keyphrase and its context within the chunk.\n" - "4. Ensure the question is clear, specific, and relevant to the keyphrase." - examples = [ - ( - SpecificQuestionInput( - title="The Impact of Artificial Intelligence on Modern Healthcare", - keyphrase="personalized treatment plans", - text="Artificial intelligence (AI) is revolutionizing healthcare by improving diagnostic accuracy and enabling personalized treatment plans. AI algorithms analyze vast amounts of medical data to identify patterns and predict patient outcomes, which enhances the decision-making process for healthcare professionals.", - ), - StringIO( - text="How does AI contribute to the development of personalized treatment plans in healthcare?" ), ) ] diff --git a/src/ragas/testset/synthesizers/single_hop/__init__.py b/src/ragas/testset/synthesizers/single_hop/__init__.py new file mode 100644 index 000000000..36803ad06 --- /dev/null +++ b/src/ragas/testset/synthesizers/single_hop/__init__.py @@ -0,0 +1,3 @@ +from .specific import SingleHopQuerySynthesizer, SingleHopScenario + +__all__ = ["SingleHopQuerySynthesizer", "SingleHopScenario"] diff --git a/src/ragas/testset/synthesizers/single_hop/base.py b/src/ragas/testset/synthesizers/single_hop/base.py new file mode 100644 index 000000000..da8ecf368 --- /dev/null +++ b/src/ragas/testset/synthesizers/single_hop/base.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import logging +import random +import typing as t +from dataclasses import dataclass + +from ragas.dataset_schema import SingleTurnSample +from ragas.prompt import PydanticPrompt +from ragas.testset.graph import Node +from ragas.testset.persona import PersonaList +from ragas.testset.synthesizers.base import ( + BaseScenario, + BaseSynthesizer, + QueryLength, + QueryStyle, + Scenario, +) +from ragas.testset.synthesizers.single_hop.prompts import ( + QueryAnswerGenerationPrompt, + QueryCondition, +) + +if t.TYPE_CHECKING: + from langchain_core.callbacks import Callbacks + +logger = logging.getLogger(__name__) + + +class SingleHopScenario(BaseScenario): + """ + Scenario for multi-hop queries. + + Attributes + ---------- + term: str + The theme of the query. + """ + + term: str + + +@dataclass +class SingleHopQuerySynthesizer(BaseSynthesizer[Scenario]): + + generate_query_reference_prompt: PydanticPrompt = QueryAnswerGenerationPrompt() + + def prepare_combinations( + self, + node: Node, + terms: t.List[str], + persona_list: PersonaList, + persona_concepts, + ) -> t.List[t.Dict[str, t.Any]]: + + sample = {"terms": terms, "node": node} + valid_personas = [] + for persona, concepts in persona_concepts.mapping.items(): + concepts = [concept.lower() for concept in concepts] + if any(term.lower() in concepts for term in terms): + if persona_list[persona]: + valid_personas.append(persona_list[persona]) + sample["personas"] = valid_personas + sample["styles"] = list(QueryStyle) + sample["lengths"] = list(QueryLength) + + return [sample] + + def sample_combinations(self, data: t.List[t.Dict[str, t.Any]], num_samples): + + selected_samples = [] + node_term_set = set() + + all_combinations = [] + for entry in data: + node = entry["node"] + for term in entry["terms"]: + for persona in entry["personas"]: + for style in entry["styles"]: + for length in entry["lengths"]: + all_combinations.append( + { + "term": term, + "node": node, + "persona": persona, + "style": style, + "length": length, + } + ) + + random.shuffle(all_combinations) + for sample in all_combinations: + if len(selected_samples) >= num_samples: + break + + term = sample["term"] + node = sample["node"] + + if (node, term) not in node_term_set: + selected_samples.append(sample) + node_term_set.add((node, term)) + elif len(selected_samples) < num_samples: + selected_samples.append(sample) + + return [self.convert_to_scenario(sample) for sample in selected_samples] + + def convert_to_scenario(self, data: t.Dict[str, t.Any]) -> SingleHopScenario: + + return SingleHopScenario( + term=data["term"], + nodes=[data["node"]], + persona=data["persona"], + style=data["style"], + length=data["length"], + ) + + async def _generate_sample( + self, scenario: SingleHopScenario, callbacks: Callbacks + ) -> SingleTurnSample: + + reference_context = self.make_contexts(scenario) + prompt_input = QueryCondition( + persona=scenario.persona, + term=scenario.term, + context=reference_context, + query_length=scenario.length.name, + query_style=scenario.style.name, + ) + response = await self.generate_query_reference_prompt.generate( + data=prompt_input, llm=self.llm, callbacks=callbacks + ) + return SingleTurnSample( + user_input=response.query, + reference=response.answer, + reference_contexts=reference_context, + ) + + def make_contexts(self, scenario: SingleHopScenario) -> t.List[str]: + + contexts = [] + for node in scenario.nodes: + context = f"{node.id}" + "\n\n" + node.properties.get("page_content", "") + contexts.append(context) + + return contexts diff --git a/src/ragas/testset/synthesizers/single_hop/prompts.py b/src/ragas/testset/synthesizers/single_hop/prompts.py new file mode 100644 index 000000000..e01e2b65e --- /dev/null +++ b/src/ragas/testset/synthesizers/single_hop/prompts.py @@ -0,0 +1,35 @@ +import typing as t + +from pydantic import BaseModel + +from ragas.prompt import PydanticPrompt +from ragas.testset.persona import Persona + + +class QueryCondition(BaseModel): + persona: Persona + term: str + query_style: str + query_length: str + context: t.List[str] + + +class GeneratedQueryAnswer(BaseModel): + query: str + answer: str + + +class QueryAnswerGenerationPrompt(PydanticPrompt[QueryCondition, GeneratedQueryAnswer]): + instruction: str = ( + "Generate a query and answer based on the specified conditions (persona, term, style, length) " + "and the provided context. Ensure the answer is entirely faithful to the context, using only the information " + "directly from the provided context." + "### Instructions:\n" + "1. **Generate a Query**: Based on the context, persona, term, style, and length, create a question " + "that aligns with the persona's perspective and incorporates the term.\n" + "2. **Generate an Answer**: Using only the content from the provided context, construct a detailed answer " + "to the query. Do not add any information not included in or inferable from the context.\n" + "### Example Outputs:\n\n" + ) + input_model: t.Type[QueryCondition] = QueryCondition + output_model: t.Type[GeneratedQueryAnswer] = GeneratedQueryAnswer diff --git a/src/ragas/testset/synthesizers/single_hop/specific.py b/src/ragas/testset/synthesizers/single_hop/specific.py new file mode 100644 index 000000000..e3a795501 --- /dev/null +++ b/src/ragas/testset/synthesizers/single_hop/specific.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import logging +import typing as t +from dataclasses import dataclass + +import numpy as np + +from ragas.prompt import PydanticPrompt +from ragas.testset.graph import KnowledgeGraph +from ragas.testset.persona import Persona, PersonaList +from ragas.testset.synthesizers.base import BaseScenario +from ragas.testset.synthesizers.prompts import ( + ThemesPersonasInput, + ThemesPersonasMatchingPrompt, +) + +from .base import SingleHopQuerySynthesizer + +if t.TYPE_CHECKING: + from langchain_core.callbacks import Callbacks + +logger = logging.getLogger(__name__) + + +class SingleHopScenario(BaseScenario): + """ + Scenario for multi-hop queries. + + Attributes + ---------- + term: str + The theme of the query. + """ + + term: str + + +@dataclass +class SingleHopSpecificQuerySynthesizer(SingleHopQuerySynthesizer): + + name: str = "single_hop_specifc_query_synthesizer" + theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt() + + async def _generate_scenarios( + self, + n: int, + knowledge_graph: KnowledgeGraph, + persona_list: t.List[Persona], + callbacks: Callbacks, + ) -> t.List[SingleHopScenario]: + """ + Generates a list of scenarios on type SingleHopSpecificQuerySynthesizer + Steps to generate scenarios: + 1. Find nodes with CHUNK type and entities property + 2. Calculate the number of samples that should be created per node to get n samples in total + 3. For each node + a. Find the entities associated with the node + b. Map personas to the entities to create query + c. Prepare all possible combinations of (node, entities, personas, style, length) as base scenarios + d. Sample num_sample_per_node (step 2) scenarios from base scenarios + 4. Return the list of scenarios + """ + + property_name = "entities" + nodes = [] + for node in knowledge_graph.nodes: + if ( + node.type.name == "CHUNK" + and node.get_property(property_name) is not None + ): + nodes.append(node) + + samples_per_node = int(np.ceil(n / len(nodes))) + + scenarios = [] + for node in nodes: + if len(scenarios) >= n: + break + themes = node.get_property(property_name) + prompt_input = ThemesPersonasInput(themes=themes, personas=persona_list) + persona_concepts = await self.theme_persona_matching_prompt.generate( + data=prompt_input, llm=self.llm, callbacks=callbacks + ) + base_scenarios = self.prepare_combinations( + node, themes, PersonaList(personas=persona_list), persona_concepts + ) + scenarios.extend(self.sample_combinations(base_scenarios, samples_per_node)) + + return scenarios diff --git a/src/ragas/testset/synthesizers/specific_query.py b/src/ragas/testset/synthesizers/specific_query.py deleted file mode 100644 index 940b3ec02..000000000 --- a/src/ragas/testset/synthesizers/specific_query.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -import random -import typing as t -from dataclasses import dataclass, field - -from ragas.dataset_schema import SingleTurnSample -from ragas.prompt import PydanticPrompt -from ragas.testset.graph import KnowledgeGraph, NodeType - -from .base import BaseScenario, QueryLength, QueryStyle -from .base_query import QuerySynthesizer -from .prompts import SpecificQuery, SpecificQuestionInput - -if t.TYPE_CHECKING: - from langchain_core.callbacks import Callbacks - - -class SpecificQueryScenario(BaseScenario): - """ - Represents a scenario for generating specific queries. - Also inherits attributes from [BaseScenario][ragas.testset.synthesizers.base.BaseScenario]. - - Attributes - ---------- - keyphrase : str - The keyphrase of the specific query scenario. - """ - - keyphrase: str - - -@dataclass -class SpecificQuerySynthesizer(QuerySynthesizer): - """ - Synthesizes specific queries by choosing specific chunks and generating a - keyphrase from them and then generating queries based on that. - - Attributes - ---------- - generate_query_prompt : PydanticPrompt - The prompt used for generating the query. - """ - - generate_query_prompt: PydanticPrompt = field(default_factory=SpecificQuery) - - async def _generate_scenarios( - self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks - ) -> t.List[SpecificQueryScenario]: - # filter out nodes that have keyphrases - nodes = [] - for node in knowledge_graph.nodes: - if ( - node.type == NodeType.CHUNK - and node.get_property("keyphrases") is not None - and node.get_property("keyphrases") != [] - ): - nodes.append(node) - - # sample nodes and keyphrases - sampled_nodes = random.choices(nodes, k=n) - sampled_keyphrases = [] - for node in sampled_nodes: - sampled_keyphrases_per_node = set() - keyphrases = node.get_property("keyphrases") - unused_keyphrases = list(set(keyphrases) - sampled_keyphrases_per_node) - if unused_keyphrases: - sampled_keyphrases.append(random.choice(unused_keyphrases)) - else: - sampled_keyphrases.append(random.choice(keyphrases)) - - # sample query styles and lengths - query_styles = random.choices(list(QueryStyle), k=n) - query_lengths = random.choices(list(QueryLength), k=n) - - scenarios = [] - for node, keyphrase, style, length in zip( - sampled_nodes, sampled_keyphrases, query_styles, query_lengths - ): - scenarios.append( - SpecificQueryScenario( - nodes=[node], keyphrase=keyphrase, style=style, length=length - ) - ) - return scenarios - - async def _generate_sample( - self, scenario: SpecificQueryScenario, callbacks: t.Optional[Callbacks] = None - ) -> SingleTurnSample: - query = await self.generate_query_prompt.generate( - data=SpecificQuestionInput( - title=scenario.nodes[0].get_property("title") or "", - keyphrase=scenario.keyphrase, - text=scenario.nodes[0].get_property("page_content") or "", - ), - llm=self.llm, - callbacks=callbacks, - ) - - query_text = query.text - if not await self.critic_query(query_text, callbacks): - query_text = await self.modify_query(query_text, scenario, callbacks) - - reference = await self.generate_reference(query_text, scenario, callbacks) - - reference_contexts = [] - for node in scenario.nodes: - if node.get_property("page_content") is not None: - reference_contexts.append(node.get_property("page_content")) - - return SingleTurnSample( - user_input=query_text, - reference=reference, - reference_contexts=reference_contexts, - ) diff --git a/src/ragas/testset/transforms/base.py b/src/ragas/testset/transforms/base.py index 13ce7249a..1e95fd8e1 100644 --- a/src/ragas/testset/transforms/base.py +++ b/src/ragas/testset/transforms/base.py @@ -22,6 +22,10 @@ class BaseGraphTransformation(ABC): name: str = "" + filter_nodes: t.Callable[[Node], bool] = field( + default_factory=lambda: default_filter + ) + def __post_init__(self): if not self.name: self.name = self.__class__.__name__ @@ -59,7 +63,15 @@ def filter(self, kg: KnowledgeGraph) -> KnowledgeGraph: KnowledgeGraph The filtered knowledge graph. """ - return kg + + return KnowledgeGraph( + nodes=[node for node in kg.nodes if self.filter_nodes(node)], + relationships=[ + rel + for rel in kg.relationships + if rel.source in kg.nodes and rel.target in kg.nodes + ], + ) @abstractmethod def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]: @@ -95,10 +107,6 @@ class Extractor(BaseGraphTransformation): Abstract method to extract a specific property from a node. """ - filter_nodes: t.Callable[[Node], bool] = field( - default_factory=lambda: default_filter - ) - async def transform( self, kg: KnowledgeGraph ) -> t.List[t.Tuple[Node, t.Tuple[str, t.Any]]]: @@ -175,16 +183,6 @@ async def apply_extract(node: Node): filtered = self.filter(kg) return [apply_extract(node) for node in filtered.nodes] - def filter(self, kg: KnowledgeGraph) -> KnowledgeGraph: - return KnowledgeGraph( - nodes=[node for node in kg.nodes if self.filter_nodes(node)], - relationships=[ - rel - for rel in kg.relationships - if rel.source in kg.nodes and rel.target in kg.nodes - ], - ) - @dataclass class LLMBasedExtractor(Extractor, PromptMixin): diff --git a/src/ragas/testset/transforms/default.py b/src/ragas/testset/transforms/default.py index e94790c36..db58045b5 100644 --- a/src/ragas/testset/transforms/default.py +++ b/src/ragas/testset/transforms/default.py @@ -2,19 +2,21 @@ import typing as t -from .engine import Parallel -from .extractors import ( +from ragas.testset.graph import NodeType +from ragas.testset.transforms.extractors import ( EmbeddingExtractor, HeadlinesExtractor, - KeyphrasesExtractor, SummaryExtractor, - TitleExtractor, ) -from .relationship_builders.cosine import ( +from ragas.testset.transforms.extractors.llm_based import NERExtractor, ThemesExtractor +from ragas.testset.transforms.relationship_builders import ( CosineSimilarityBuilder, - SummaryCosineSimilarityBuilder, + OverlapScoreBuilder, ) -from .splitters import HeadlineSplitter +from ragas.testset.transforms.splitters import HeadlineSplitter +from ragas.utils import num_tokens_from_string + +from .engine import Parallel if t.TYPE_CHECKING: from ragas.embeddings.base import BaseRagasEmbeddings @@ -35,13 +37,7 @@ def default_transforms( headlines, and embeddings, as well as building similarity relationships between nodes. - The transforms are applied in the following order: - 1. Parallel extraction of summaries and headlines - 2. Embedding of summaries for document nodes - 3. Splitting of headlines - 4. Parallel extraction of embeddings, keyphrases, and titles - 5. Building cosine similarity relationships between nodes - 6. Building cosine similarity relationships between summaries + Returns ------- @@ -49,32 +45,49 @@ def default_transforms( A list of transformation steps to be applied to the knowledge graph. """ - from ragas.testset.graph import NodeType - # define the transforms - summary_extractor = SummaryExtractor(llm=llm) - keyphrase_extractor = KeyphrasesExtractor(llm=llm) - title_extractor = TitleExtractor(llm=llm) headline_extractor = HeadlinesExtractor(llm=llm) - embedding_extractor = EmbeddingExtractor(embedding_model=embedding_model) - headline_splitter = HeadlineSplitter() - cosine_sim_builder = CosineSimilarityBuilder(threshold=0.8) - summary_embedder = EmbeddingExtractor( - name="summary_embedder", + splitter = HeadlineSplitter(min_tokens=500) + + def summary_filter(node): + return ( + node.type == NodeType.DOCUMENT + and num_tokens_from_string(node.properties["page_content"]) > 500 + ) + + summary_extractor = SummaryExtractor( + llm=llm, filter_nodes=lambda node: summary_filter(node) + ) + + theme_extractor = ThemesExtractor(llm=llm) + ner_extractor = NERExtractor( + llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK + ) + + summary_emb_extractor = EmbeddingExtractor( + embedding_model=embedding_model, property_name="summary_embedding", embed_property_name="summary", - filter_nodes=lambda node: True if node.type == NodeType.DOCUMENT else False, - embedding_model=embedding_model, + filter_nodes=lambda node: summary_filter(node), + ) + + cosine_sim_builder = CosineSimilarityBuilder( + property_name="summary_embedding", + new_property_name="summary_similarity", + threshold=0.7, + filter_nodes=lambda node: summary_filter(node), + ) + + ner_overlap_sim = OverlapScoreBuilder( + threshold=0.01, filter_nodes=lambda node: node.type == NodeType.CHUNK ) - summary_cosine_sim_builder = SummaryCosineSimilarityBuilder(threshold=0.6) - # specify the transforms and their order to be applied transforms = [ - Parallel(summary_extractor, headline_extractor), - summary_embedder, - headline_splitter, - Parallel(embedding_extractor, keyphrase_extractor, title_extractor), - cosine_sim_builder, - summary_cosine_sim_builder, + headline_extractor, + splitter, + Parallel(summary_extractor, theme_extractor, ner_extractor), + summary_emb_extractor, + Parallel(cosine_sim_builder, ner_overlap_sim), ] + return transforms diff --git a/src/ragas/testset/transforms/extractors/llm_based.py b/src/ragas/testset/transforms/extractors/llm_based.py index 613c5b4d8..78b7c5f98 100644 --- a/src/ragas/testset/transforms/extractors/llm_based.py +++ b/src/ragas/testset/transforms/extractors/llm_based.py @@ -66,57 +66,46 @@ class TitleExtractorPrompt(PydanticPrompt[StringIO, StringIO]): class Headlines(BaseModel): - headlines: t.Dict[str, t.List[str]] + headlines: t.List[str] class HeadlinesExtractorPrompt(PydanticPrompt[StringIO, Headlines]): - instruction: str = "Extract the headlines from the given text." + instruction: str = "Extract only level 2 headings from the given text." + input_model: t.Type[StringIO] = StringIO output_model: t.Type[Headlines] = Headlines examples: t.List[t.Tuple[StringIO, Headlines]] = [ ( StringIO( text="""\ -Some Title -1. Introduction and Related Work - -1.1 Conditional Computation -Exploiting scale in both training data and model size has been central to the success of deep learning... -1.2 Our Approach: The Sparsely-Gated Mixture-of-Experts Layer -Our approach to conditional computation is to introduce a new type of general purpose neural network component... -1.3 Related Work on Mixtures of Experts -Since its introduction more than two decades ago (Jacobs et al., 1991; Jordan & Jacobs, 1994), the mixture-of-experts approach.. - -2. The Sparsely-Gated Mixture-of-Experts Layer -2.1 Architecture -The sparsely-gated mixture-of-experts layer is a feedforward neural network layer that consists of a number of expert networks and a single gating network... -""", + Introduction + Overview of the topic... + + Main Concepts + Explanation of core ideas... + + Detailed Analysis + Techniques and methods for analysis... + + Subsection: Specialized Techniques + Further details on specialized techniques... + + Future Directions + Insights into upcoming trends... + + Conclusion + Final remarks and summary. + """, ), Headlines( - headlines={ - "1. Introduction and Related Work": [ - "1.1 Conditional Computation", - "1.2 Our Approach: The Sparsely-Gated Mixture-of-Experts Layer", - "1.3 Related Work on Mixtures of Experts", - ], - "2. The Sparsely-Gated Mixture-of-Experts Layer": [ - "2.1 Architecture" - ], - }, + headlines=["Main Concepts", "Detailed Analysis", "Future Directions"] ), ), ] -class NamedEntities(BaseModel): - ORG: t.List[str] - LOC: t.List[str] - PER: t.List[str] - MISC: t.List[str] - - class NEROutput(BaseModel): - entities: NamedEntities + entities: t.List[str] class NERPrompt(PydanticPrompt[StringIO, NEROutput]): @@ -126,17 +115,21 @@ class NERPrompt(PydanticPrompt[StringIO, NEROutput]): examples: t.List[t.Tuple[StringIO, NEROutput]] = [ ( StringIO( - text="Artificial intelligence\n\nArtificial intelligence is transforming various industries by automating tasks that previously required human intelligence. From healthcare to finance, AI is being used to analyze vast amounts of data quickly and accurately. This technology is also driving innovations in areas like self-driving cars and personalized recommendations." + text="""Elon Musk, the CEO of Tesla and SpaceX, announced plans to expand operations to new locations in Europe and Asia. + This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai.""" ), NEROutput( - entities=NamedEntities( - ORG=["Artificial intelligence"], - LOC=["healthcare", "finance"], - PER=[], - MISC=["self-driving cars", "personalized recommendations"], - ) + entities=[ + "Elon Musk", + "Tesla", + "SpaceX", + "Europe", + "Asia", + "Berlin", + "Shanghai", + ] ), - ) + ), ] @@ -254,12 +247,12 @@ class NERExtractor(LLMBasedExtractor): property_name: str = "entities" prompt: NERPrompt = NERPrompt() - async def extract(self, node: Node) -> t.Tuple[str, t.Dict[str, t.List[str]]]: + async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]: node_text = node.get_property("page_content") if node_text is None: - return self.property_name, {} + return self.property_name, [] result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) - return self.property_name, result.entities.model_dump() + return self.property_name, result.entities class TopicDescription(BaseModel): @@ -298,7 +291,7 @@ class TopicDescriptionExtractor(LLMBasedExtractor): """ property_name: str = "topic_description" - prompt: TopicDescriptionPrompt = TopicDescriptionPrompt() + prompt: PydanticPrompt = TopicDescriptionPrompt() async def extract(self, node: Node) -> t.Tuple[str, t.Any]: node_text = node.get_property("page_content") @@ -306,3 +299,54 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]: return self.property_name, None result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) return self.property_name, result.description + + +class ThemesAndConcepts(BaseModel): + output: t.List[str] + + +class ThemesAndConceptsExtractorPrompt(PydanticPrompt[StringIO, ThemesAndConcepts]): + instruction: str = "Extract the main themes and concepts from the given text." + input_model: t.Type[StringIO] = StringIO + output_model: t.Type[ThemesAndConcepts] = ThemesAndConcepts + examples: t.List[t.Tuple[StringIO, ThemesAndConcepts]] = [ + ( + StringIO( + text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations." + ), + ThemesAndConcepts( + output=[ + "Artificial intelligence", + "Automation", + "Data analysis", + "Innovation", + "Self-driving cars", + "Personalized recommendations", + ] + ), + ) + ] + + +@dataclass +class ThemesExtractor(LLMBasedExtractor): + """ + Extracts themes from the given text. + + Attributes + ---------- + property_name : str + The name of the property to extract. Defaults to "themes". + prompt : ThemesExtractorPrompt + The prompt used for extraction. + """ + + property_name: str = "themes" + prompt: ThemesAndConceptsExtractorPrompt = ThemesAndConceptsExtractorPrompt() + + async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]: + node_text = node.get_property("page_content") + if node_text is None: + return self.property_name, [] + result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) + return self.property_name, result.output diff --git a/src/ragas/testset/transforms/relationship_builders/__init__.py b/src/ragas/testset/transforms/relationship_builders/__init__.py index 40303c2cf..9eff90a84 100644 --- a/src/ragas/testset/transforms/relationship_builders/__init__.py +++ b/src/ragas/testset/transforms/relationship_builders/__init__.py @@ -1,3 +1,4 @@ from .cosine import CosineSimilarityBuilder +from .traditional import JaccardSimilarityBuilder, OverlapScoreBuilder -__all__ = ["CosineSimilarityBuilder"] +__all__ = ["CosineSimilarityBuilder", "OverlapScoreBuilder", "JaccardSimilarityBuilder"] diff --git a/src/ragas/testset/transforms/relationship_builders/cosine.py b/src/ragas/testset/transforms/relationship_builders/cosine.py index 3aade9635..8a37081bb 100644 --- a/src/ragas/testset/transforms/relationship_builders/cosine.py +++ b/src/ragas/testset/transforms/relationship_builders/cosine.py @@ -7,52 +7,6 @@ from ragas.testset.transforms.base import RelationshipBuilder -@dataclass -class JaccardSimilarityBuilder(RelationshipBuilder): - property_name: str = "entities" - key_name: t.Optional[str] = None - new_property_name: str = "jaccard_similarity" - threshold: float = 0.5 - - def _jaccard_similarity(self, set1: t.Set[str], set2: t.Set[str]) -> float: - intersection = len(set1.intersection(set2)) - union = len(set1.union(set2)) - return intersection / union if union > 0 else 0.0 - - async def transform(self, kg: KnowledgeGraph) -> t.List[Relationship]: - if self.property_name is None: - self.property_name - - similar_pairs = [] - for i, node1 in enumerate(kg.nodes): - for j, node2 in enumerate(kg.nodes): - if i >= j: - continue - items1 = node1.get_property(self.property_name) - items2 = node2.get_property(self.property_name) - if items1 is None or items2 is None: - raise ValueError( - f"Node {node1.id} or {node2.id} has no {self.property_name}" - ) - if self.key_name is not None: - items1 = items1.get(self.key_name, []) - items2 = items2.get(self.key_name, []) - similarity = self._jaccard_similarity(set(items1), set(items2)) - if similarity >= self.threshold: - similar_pairs.append((i, j, similarity)) - - return [ - Relationship( - source=kg.nodes[i], - target=kg.nodes[j], - type="jaccard_similarity", - properties={self.new_property_name: similarity_float}, - bidirectional=True, - ) - for i, j, similarity_float in similar_pairs - ] - - @dataclass class CosineSimilarityBuilder(RelationshipBuilder): property_name: str = "embedding" diff --git a/src/ragas/testset/transforms/relationship_builders/traditional.py b/src/ragas/testset/transforms/relationship_builders/traditional.py new file mode 100644 index 000000000..d7ae6ebb0 --- /dev/null +++ b/src/ragas/testset/transforms/relationship_builders/traditional.py @@ -0,0 +1,155 @@ +import typing as t +from collections import Counter +from dataclasses import dataclass + +from ragas.metrics._string import DistanceMeasure +from ragas.testset.graph import KnowledgeGraph, Node, Relationship +from ragas.testset.transforms.base import RelationshipBuilder + + +@dataclass +class JaccardSimilarityBuilder(RelationshipBuilder): + property_name: str = "entities" + key_name: t.Optional[str] = None + new_property_name: str = "jaccard_similarity" + threshold: float = 0.5 + + def _jaccard_similarity(self, set1: t.Set[str], set2: t.Set[str]) -> float: + intersection = len(set1.intersection(set2)) + union = len(set1.union(set2)) + return intersection / union if union > 0 else 0.0 + + async def transform(self, kg: KnowledgeGraph) -> t.List[Relationship]: + if self.property_name is None: + self.property_name + + similar_pairs = [] + for i, node1 in enumerate(kg.nodes): + for j, node2 in enumerate(kg.nodes): + if i >= j: + continue + items1 = node1.get_property(self.property_name) + items2 = node2.get_property(self.property_name) + if items1 is None or items2 is None: + raise ValueError( + f"Node {node1.id} or {node2.id} has no {self.property_name}" + ) + if self.key_name is not None: + items1 = items1.get(self.key_name, []) + items2 = items2.get(self.key_name, []) + similarity = self._jaccard_similarity(set(items1), set(items2)) + if similarity >= self.threshold: + similar_pairs.append((i, j, similarity)) + + return [ + Relationship( + source=kg.nodes[i], + target=kg.nodes[j], + type="jaccard_similarity", + properties={self.new_property_name: similarity_float}, + bidirectional=True, + ) + for i, j, similarity_float in similar_pairs + ] + + +@dataclass +class OverlapScoreBuilder(RelationshipBuilder): + property_name: str = "entities" + key_name: t.Optional[str] = None + new_property_name: str = "overlap_score" + distance_measure: DistanceMeasure = DistanceMeasure.JARO_WINKLER + distance_threshold: float = 0.9 + threshold: float = 0.5 + + def __post_init__(self): + try: + from rapidfuzz import distance + except ImportError: + raise ImportError( + "rapidfuzz is required for string distance. Please install it using `pip install rapidfuzz`" + ) + + self.distance_measure_map = { + DistanceMeasure.LEVENSHTEIN: distance.Levenshtein, + DistanceMeasure.HAMMING: distance.Hamming, + DistanceMeasure.JARO: distance.Jaro, + DistanceMeasure.JARO_WINKLER: distance.JaroWinkler, + } + + def _overlap_score(self, overlaps: t.List[bool]) -> float: + + return sum(overlaps) / len(overlaps) if len(overlaps) > 0 else 0.0 + + def _get_noisy_items( + self, nodes: t.List[Node], property_name: str, percent_cut_off: float = 0.05 + ) -> t.List[str]: + + all_items = [] + for node in nodes: + items = node.get_property(property_name) + if items is not None: + if isinstance(items, str): + all_items.append(items) + elif isinstance(items, list): + all_items.extend(items) + else: + pass + + num_unique_items = len(set(all_items)) + num_noisy_items = max(1, int(num_unique_items * percent_cut_off)) + noisy_list = list(dict(Counter(all_items).most_common()).keys())[ + :num_noisy_items + ] + return noisy_list + + async def transform(self, kg: KnowledgeGraph) -> t.List[Relationship]: + if self.property_name is None: + self.property_name + + distance_measure = self.distance_measure_map[self.distance_measure] + noisy_items = self._get_noisy_items(kg.nodes, self.property_name) + relationships = [] + for i, node_x in enumerate(kg.nodes): + for j, node_y in enumerate(kg.nodes): + if i >= j: + continue + node_x_items = node_x.get_property(self.property_name) + node_y_items = node_y.get_property(self.property_name) + if node_x_items is None or node_y_items is None: + raise ValueError( + f"Node {node_x.id} or {node_y.id} has no {self.property_name}" + ) + if self.key_name is not None: + node_x_items = node_x_items.get(self.key_name, []) + node_y_items = node_y_items.get(self.key_name, []) + + overlaps = [] + overlapped_items = [] + for x in node_x_items: + if x not in noisy_items: + for y in node_y_items: + if y not in noisy_items: + similarity = 1 - distance_measure.distance( + x.lower(), y.lower() + ) + verdict = similarity >= self.distance_threshold + overlaps.append(verdict) + if verdict: + overlapped_items.append((x, y)) + + similarity = self._overlap_score(overlaps) + if similarity >= self.threshold: + relationships.append( + Relationship( + source=node_x, + target=node_y, + type=f"{self.property_name}_overlap", + properties={ + f"{self.property_name}_{self.new_property_name}": similarity, + "overlapped_items": overlapped_items, + }, + ) + ) + + return relationships diff --git a/src/ragas/testset/transforms/splitters/headline.py b/src/ragas/testset/transforms/splitters/headline.py index 4a8552679..3dae5e763 100644 --- a/src/ragas/testset/transforms/splitters/headline.py +++ b/src/ragas/testset/transforms/splitters/headline.py @@ -3,10 +3,13 @@ from ragas.testset.graph import Node, NodeType, Relationship from ragas.testset.transforms.base import Splitter +from ragas.utils import num_tokens_from_string @dataclass class HeadlineSplitter(Splitter): + min_tokens: int = 300 + async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]: text = node.get_property("page_content") if text is None: @@ -17,11 +20,30 @@ async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]] raise ValueError("'headlines' property not found in this node") # create the chunks for the different sections - indices = [] + indices = [0] for headline in headlines: - indices.append(text.find(headline)) + index = text.find(headline) + if index != -1: + indices.append(index) indices.append(len(text)) chunks = [text[indices[i] : indices[i + 1]] for i in range(len(indices) - 1)] + # merge chunks if their length is less than 300 tokens + merged_chunks = [] + current_chunk = chunks[0] + + for next_chunk in chunks[1:]: + if num_tokens_from_string(current_chunk) < self.min_tokens: + current_chunk = "\n\n".join([current_chunk, next_chunk]) + else: + merged_chunks.append(current_chunk) + current_chunk = next_chunk + + merged_chunks.append(current_chunk) + chunks = merged_chunks + + # if there was no headline, return the original node + if len(chunks) == 1: + return [node], [] # create the nodes nodes = [ diff --git a/src/ragas/utils.py b/src/ragas/utils.py index 7f1a42037..33184fb90 100644 --- a/src/ragas/utils.py +++ b/src/ragas/utils.py @@ -9,6 +9,7 @@ from functools import lru_cache import numpy as np +import tiktoken from datasets import Dataset from pysbd.languages import LANGUAGE_CODES @@ -226,6 +227,12 @@ def camel_to_snake(name): return pattern.sub("_", name).lower() +def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int: + """Returns the number of tokens in a text string.""" + encoding = tiktoken.get_encoding(encoding_name) + num_tokens = len(encoding.encode(string)) + return num_tokens + def batched(iterable: t.Iterable, n: int) -> t.Iterator[t.Tuple]: """Batch data from the iterable into tuples of length n. The last batch may be shorter than n.""" # batched('ABCDEFG', 3) → ABC DEF G diff --git a/tests/unit/prompt/test_prompt_mixin.py b/tests/unit/prompt/test_prompt_mixin.py index c990e16a4..c4d53d8f5 100644 --- a/tests/unit/prompt/test_prompt_mixin.py +++ b/tests/unit/prompt/test_prompt_mixin.py @@ -1,10 +1,10 @@ import pytest -from ragas.testset.synthesizers import AbstractQuerySynthesizer +from ragas.testset.synthesizers.multi_hop import MultiHopAbstractQuerySynthesizer def test_prompt_save_load(tmp_path, fake_llm): - synth = AbstractQuerySynthesizer(llm=fake_llm) + synth = MultiHopAbstractQuerySynthesizer(llm=fake_llm) synth_prompts = synth.get_prompts() synth.save_prompts(tmp_path) loaded_prompts = synth.load_prompts(tmp_path) @@ -16,7 +16,7 @@ def test_prompt_save_load(tmp_path, fake_llm): @pytest.mark.asyncio async def test_prompt_save_adapt_load(tmp_path, fake_llm): - synth = AbstractQuerySynthesizer(llm=fake_llm) + synth = MultiHopAbstractQuerySynthesizer(llm=fake_llm) # patch adapt_prompts async def adapt_prompts_patched(self, language, llm): diff --git a/tests/unit/test_analytics.py b/tests/unit/test_analytics.py index 98cc44d13..4233ea7d3 100644 --- a/tests/unit/test_analytics.py +++ b/tests/unit/test_analytics.py @@ -127,22 +127,22 @@ def test_testset_generation_tracking(monkeypatch): testset_event_payload = TestsetGenerationEvent( event_type="testset_generation", - evolution_names=[e.__class__.__name__.lower() for e, _ in distributions], + evolution_names=[e.name for e, _ in distributions], evolution_percentages=[p for _, p in distributions], num_rows=10, language="english", ) assert testset_event_payload.model_dump()["evolution_names"] == [ - "abstractquerysynthesizer", - "comparativeabstractquerysynthesizer", - "specificquerysynthesizer", + "single_hop_specifc_query_synthesizer", + "multi_hop_abstract_query_synthesizer", + "multi_hop_specific_query_synthesizer", ] assert testset_event_payload.model_dump()["evolution_percentages"] == [ + 0.5, 0.25, 0.25, - 0.5, ] # just in the case you actually want to check if tracking is working in the diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py index 8be1dc867..3d550a628 100644 --- a/tests/unit/test_prompt.py +++ b/tests/unit/test_prompt.py @@ -121,9 +121,9 @@ class Prompt(PydanticPrompt[StringIO, StringIO]): def test_prompt_hash_in_ragas(fake_llm): # check with a prompt inside ragas - from ragas.testset.synthesizers import AbstractQuerySynthesizer + from ragas.testset.synthesizers.multi_hop import MultiHopAbstractQuerySynthesizer - synthesizer = AbstractQuerySynthesizer(llm=fake_llm) + synthesizer = MultiHopAbstractQuerySynthesizer(llm=fake_llm) prompts = synthesizer.get_prompts() for prompt in prompts.values(): assert hash(prompt) == hash(prompt) @@ -179,12 +179,12 @@ class Prompt(PydanticPrompt[StringIO, StringIO]): def test_save_existing_prompt(tmp_path): - from ragas.testset.synthesizers.prompts import CommonThemeFromSummariesPrompt + from ragas.testset.synthesizers.prompts import ThemesPersonasMatchingPrompt - p = CommonThemeFromSummariesPrompt() + p = ThemesPersonasMatchingPrompt() file_path = tmp_path / "test_prompt.json" p.save(file_path) - p2 = CommonThemeFromSummariesPrompt.load(file_path) + p2 = ThemesPersonasMatchingPrompt.load(file_path) assert p == p2 @@ -194,10 +194,10 @@ def test_prompt_class_attributes(): We want to make sure there is no relationship between the class attributes and instance. """ - from ragas.testset.synthesizers.prompts import CommonThemeFromSummariesPrompt + from ragas.testset.synthesizers.prompts import ThemesPersonasMatchingPrompt - p = CommonThemeFromSummariesPrompt() - p_another_instance = CommonThemeFromSummariesPrompt() + p = ThemesPersonasMatchingPrompt() + p_another_instance = ThemesPersonasMatchingPrompt() assert p.instruction == p_another_instance.instruction assert p.examples == p_another_instance.examples p.instruction = "You are a helpful assistant."