Skip to content

Commit

Permalink
feat: set and get prompts for metrics (#1259)
Browse files Browse the repository at this point in the history
```python
from ragas.experimental.metrics._faithfulness import FaithfulnessExperimental, LongFormAnswerPrompt

faithfulness = FaithfulnessExperimental() 
faithfulness.get_prompts()

#{'long_form_answer_prompt': <ragas.experimental.metrics._faithfulness.LongFormAnswerPrompt at 0x7fd7baa8efb0>,
#'nli_statement_prompt': <ragas.experimental.metrics._faithfulness.NLIStatementPrompt at 0x7fd7baa8f010>}

long_form_prompt = LongFormAnswerPrompt()
long_form_prompt.instruction = "my new instruction"

prompts = {"long_form_answer_prompt":long_form_prompt}
faithfulness.set_prompts(**prompts)
```

---------

Co-authored-by: Jithin James <[email protected]>
  • Loading branch information
shahules786 and jjmachan authored Sep 10, 2024
1 parent 054c0e9 commit 4e6a96a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 14 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@ jobs:
pip install .
pip install -r requirements/dev.txt
- name: Format check
run: |
make format
- name: Lint check
run: make lint
- name: Type check
Expand Down
19 changes: 14 additions & 5 deletions docs/howtos/integrations/opik.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
"import getpass\n",
"\n",
"os.environ[\"OPIK_API_KEY\"] = getpass.getpass(\"Opik API Key: \")\n",
"os.environ[\"OPIK_WORKSPACE\"] = input(\"Comet workspace (often the same as your username): \")"
"os.environ[\"OPIK_WORKSPACE\"] = input(\n",
" \"Comet workspace (often the same as your username): \"\n",
")"
]
},
{
Expand Down Expand Up @@ -151,6 +153,7 @@
"import asyncio\n",
"from ragas.integrations.opik import OpikTracer\n",
"\n",
"\n",
"# Define the scoring function\n",
"def compute_metric(opik_tracer, metric, row):\n",
" async def get_score(opik_tracer, metric, row):\n",
Expand All @@ -162,16 +165,17 @@
" result = loop.run_until_complete(get_score(opik_tracer, metric, row))\n",
" return result\n",
"\n",
"\n",
"# Score a simple example\n",
"row = {\n",
" \"question\": \"What is the capital of France?\",\n",
" \"answer\": \"Paris\",\n",
" \"contexts\": [\"Paris is the capital of France.\", \"Paris is in France.\"]\n",
" \"contexts\": [\"Paris is the capital of France.\", \"Paris is in France.\"],\n",
"}\n",
"\n",
"opik_tracer = OpikTracer()\n",
"score = compute_metric(opik_tracer, answer_relevancy_metric, row)\n",
"print(\"Answer Relevancy score:\", score)\n"
"print(\"Answer Relevancy score:\", score)"
]
},
{
Expand Down Expand Up @@ -207,23 +211,27 @@
"from opik import track\n",
"from opik.opik_context import get_current_trace\n",
"\n",
"\n",
"@track\n",
"def retrieve_contexts(question):\n",
" # Define the retrieval function, in this case we will hard code the contexts\n",
" return [\"Paris is the capital of France.\", \"Paris is in France.\"]\n",
"\n",
"\n",
"@track\n",
"def answer_question(question, contexts):\n",
" # Define the answer function, in this case we will hard code the answer\n",
" return \"Paris\"\n",
"\n",
"\n",
"@track(name=\"Compute Ragas metric score\", capture_input=False)\n",
"def compute_rag_score(answer_relevancy_metric, question, answer, contexts):\n",
" # Define the score function\n",
" row = {\"question\": question, \"answer\": answer, \"contexts\": contexts}\n",
" score = compute_metric(answer_relevancy_metric, row)\n",
" return score\n",
"\n",
"\n",
"@track\n",
"def rag_pipeline(question):\n",
" # Define the pipeline\n",
Expand All @@ -233,9 +241,10 @@
" trace = get_current_trace()\n",
" score = compute_rag_score(answer_relevancy_metric, question, answer, contexts)\n",
" trace.log_feedback_score(\"answer_relevancy\", round(score, 4), category_name=\"ragas\")\n",
" \n",
"\n",
" return answer\n",
"\n",
"\n",
"rag_pipeline(\"What is the capital of France?\")"
]
},
Expand Down Expand Up @@ -297,7 +306,7 @@
"result = evaluate(\n",
" fiqa_eval[\"baseline\"].select(range(3)),\n",
" metrics=[context_precision, faithfulness, answer_relevancy],\n",
" callbacks=[opik_tracer_eval]\n",
" callbacks=[opik_tracer_eval],\n",
")\n",
"\n",
"print(result)"
Expand Down
22 changes: 16 additions & 6 deletions src/ragas/integrations/opik.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import typing as t

try:
from opik.integrations.langchain import OpikTracer as LangchainOpikTracer # type: ignore
from opik.integrations.langchain import ( # type: ignore
OpikTracer as LangchainOpikTracer,
)

from ragas.evaluation import RAGAS_EVALUATION_CHAIN_NAME
except ImportError:
raise ImportError("Opik is not installed. Please install it using `pip install opik` to use the Opik tracer.")
raise ImportError(
"Opik is not installed. Please install it using `pip install opik` to use the Opik tracer."
)

if t.TYPE_CHECKING:
from langchain_core.tracers.schemas import Run


class OpikTracer(LangchainOpikTracer):
"""
Callback for Opik that can be used to log traces and evaluation scores to the Opik platform.
Expand All @@ -20,6 +26,7 @@ class OpikTracer(LangchainOpikTracer):
metadata: dict
Additional metadata to log for each trace.
"""

_evaluation_run_id: t.Optional[str] = None

def _persist_run(self, run: "Run"):
Expand Down Expand Up @@ -48,8 +55,11 @@ def _on_chain_end(self, run: "Run"):
span = self._span_map[run.id]
trace_id = span.trace_id
if run.outputs:
self._opik_client.log_traces_feedback_scores([
{"id": trace_id, "name": name, "value": round(value, 4)} for name, value in run.outputs.items()
])

self._opik_client.log_traces_feedback_scores(
[
{"id": trace_id, "name": name, "value": round(value, 4)}
for name, value in run.outputs.items()
]
)

self._persist_run(run)
24 changes: 24 additions & 0 deletions src/ragas/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
from ragas.embeddings import BaseRagasEmbeddings
from ragas.llms import BaseRagasLLM

import inspect

from pysbd import Segmenter
from pysbd.languages import LANGUAGE_CODES

from ragas.experimental.llms.prompt import PydanticPrompt as Prompt

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -154,6 +158,26 @@ def init(self, run_config: RunConfig):
)
self.llm.set_run_config(run_config)

def get_prompts(self) -> t.Dict[str, Prompt]:
prompts = {}
for name, value in inspect.getmembers(self):
if isinstance(value, Prompt):
prompts.update({name: value})
return prompts

def set_prompts(self, **prompts):
available_prompts = self.get_prompts()
for key, value in prompts.items():
if key not in available_prompts:
raise ValueError(
f"Prompt with name '{key}' does not exist in the metric {self.name}. Use get_prompts() to see available prompts."
)
if not isinstance(value, Prompt):
raise ValueError(
f"Prompt with name '{key}' must be an instance of 'Prompt'"
)
setattr(self, key, value)


@dataclass
class MetricWithEmbeddings(Metric):
Expand Down

0 comments on commit 4e6a96a

Please sign in to comment.