Skip to content

Commit

Permalink
[FEATURE] Capture metrics (#136)
Browse files Browse the repository at this point in the history
* Added metrics_mixin

Signed-off-by: Deepak <[email protected]>

* Added capture_metrics decorator func

Signed-off-by: Deepak <[email protected]>

* Used capture_metrics decorator wherever needed

Signed-off-by: Deepak <[email protected]>

* Added clear_metrics function

Signed-off-by: Deepak <[email protected]>

* Added redis dependency

Signed-off-by: Deepak <[email protected]>

* Addressed review comments

Signed-off-by: Deepak <[email protected]>

* Addressed review comments

Signed-off-by: Deepak <[email protected]>

* Version bump

Signed-off-by: Deepak <[email protected]>

* Addressed review comments

Signed-off-by: Deepak <[email protected]>

---------

Signed-off-by: Deepak <[email protected]>
  • Loading branch information
Deepak-Kesavan authored Dec 18, 2024
1 parent 5da46eb commit 8156743
Show file tree
Hide file tree
Showing 7 changed files with 816 additions and 626 deletions.
1,261 changes: 649 additions & 612 deletions pdm.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ dependencies = [
#Unable to import llm adapters : No module named 'mistralai.models.chat_completion'
#Looks like mistralai>0.4.2 is not backward compatible
"mistralai==0.4.2",

"llama-index-llms-anyscale==0.1.4",
"llama-index-llms-anthropic==0.1.16",
"llama-index-llms-azure-openai==0.1.10",
Expand All @@ -58,6 +57,7 @@ dependencies = [
"singleton-decorator~=1.0.0",
"httpx>=0.25.2",
"pdfplumber>=0.11.2",
"redis>=5.2.1",
]
readme = "README.md"
urls = { Homepage = "https://unstract.com", "Release notes" = "https://github.com/Zipstack/unstract-sdk/releases", Source = "https://github.com/Zipstack/unstract-sdk" }
Expand Down Expand Up @@ -120,4 +120,6 @@ path = "src/unstract/sdk/__init__.py"
# Adding the following override to resolve dependency version
# for environs. Otherwise, it stays stuck while resolving pins
[tool.pdm.resolution.overrides]
grpcio = ">=1.62.1"
grpcio = "1.62.3"
grpcio-tools = "1.62.3"
grpcio-health-checking = "1.62.3"
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.54.0rc7"
__version__ = "0.54.0rc8"


def get_sdk_version():
Expand Down
20 changes: 18 additions & 2 deletions src/unstract/sdk/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from unstract.sdk.file_storage import FileStorage, FileStorageProvider
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils import ToolUtils
from unstract.sdk.utils.common_utils import log_elapsed
from unstract.sdk.utils.common_utils import capture_metrics, log_elapsed
from unstract.sdk.vector_db import VectorDB
from unstract.sdk.x2txt import X2Text

Expand All @@ -39,10 +39,19 @@ class Constants:


class Index:
def __init__(self, tool: BaseTool):
def __init__(
    self,
    tool: BaseTool,
    run_id: Optional[str] = None,
    capture_metrics: bool = False,
):
    """Store the tool and the metrics-capture settings for this instance.

    Args:
        tool: Tool instance kept on ``self.tool``.
        run_id: Optional run identifier kept on ``self._run_id``.
        capture_metrics: Flag kept on ``self._capture_metrics``. Within this
            method the parameter name shadows the imported
            ``capture_metrics`` decorator; the decorator is still the one
            applied to the methods of this class.
    """
    # TODO: Inherit from StreamMixin and avoid using BaseTool
    self.tool = tool
    self._run_id = run_id
    self._capture_metrics = capture_metrics
    # Starts empty; read back through get_metrics().
    self._metrics = {}

@capture_metrics
def query_index(
self,
embedding_instance_id: str,
Expand Down Expand Up @@ -180,6 +189,7 @@ def extract_text(
return extracted_text

@log_elapsed(operation="CHECK_AND_INDEX(overall)")
@capture_metrics
def index(
self,
tool_id: str,
Expand Down Expand Up @@ -449,6 +459,12 @@ def generate_index_key(
hashed_index_key = ToolUtils.hash_str(json.dumps(index_key, sort_keys=True))
return hashed_index_key

def get_metrics(self):
    """Return the metrics dict currently held on this instance.

    The dict object itself is returned, not a copy.
    """
    return self._metrics

def clear_metrics(self):
    """Discard the held metrics by rebinding ``_metrics`` to a new empty dict."""
    self._metrics = {}

@deprecated(version="0.45.0", reason="Use generate_index_key() instead")
def generate_file_id(
self,
Expand Down
29 changes: 20 additions & 9 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unstract.sdk.helper import SdkHelper
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils.callback_manager import CallbackManager
from unstract.sdk.utils.common_utils import capture_metrics

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +37,7 @@ def __init__(
tool: BaseTool,
adapter_instance_id: Optional[str] = None,
usage_kwargs: dict[Any, Any] = {},
capture_metrics: bool = False,
):
"""Creates an instance of this LLM class.
Expand All @@ -50,6 +52,10 @@ def __init__(
self._adapter_instance_id = adapter_instance_id
self._llm_instance: LlamaIndexLLM = None
self._usage_kwargs = usage_kwargs
self._capture_metrics = capture_metrics
self._run_id = usage_kwargs.get("run_id")
self._usage_reason = usage_kwargs.get("llm_usage_reason")
self._metrics = {}
self._initialise()

def _initialise(self):
Expand All @@ -65,14 +71,16 @@ def _initialise(self):
kwargs=self._usage_kwargs,
)

@capture_metrics
def complete(
self,
prompt: str,
extract_json: bool = True,
process_text: Optional[Callable[[str], str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Generates a completion response for the given prompt.
"""Generates a completion response for the given prompt and captures
metrics if run_id is provided.
Args:
prompt (str): The input text prompt for generating the completion.
Expand All @@ -85,12 +93,8 @@ def complete(
**kwargs (Any): Additional arguments passed to the completion function.
Returns:
dict[str, Any]: A dictionary containing the result of the completion
and any processed output.
Raises:
LLMError: If an error occurs during the completion process, it will be
raised after being processed by `parse_llm_err`.
dict[str, Any]: A dictionary containing the result of the completion,
any processed output, and the captured metrics (if applicable).
"""
try:
response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs)
Expand All @@ -105,12 +109,19 @@ def complete(
if not isinstance(process_text_output, dict):
process_text_output = {}
except Exception as e:
logger.error(f"Error occured inside function 'process_text': {e}")
logger.error(f"Error occurred inside function 'process_text': {e}")
process_text_output = {}
return {LLM.RESPONSE: response, **process_text_output}
response_data = {LLM.RESPONSE: response, **process_text_output}
return response_data
except Exception as e:
raise parse_llm_err(e, self._llm_adapter_class) from e

def get_metrics(self):
    """Return the metrics dict currently held on this instance.

    The dict object itself is returned, not a copy.
    """
    return self._metrics

def get_usage_reason(self):
    """Return the usage reason stored on this instance as ``_usage_reason``."""
    return self._usage_reason

def stream_complete(
self,
prompt: str,
Expand Down
73 changes: 73 additions & 0 deletions src/unstract/sdk/metrics_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
import os
import time
import uuid
from typing import Any

from redis import StrictRedis

logger = logging.getLogger(__name__)


class MetricsMixin:
    """Measure the elapsed time of one operation using a Redis timestamp.

    Creating an instance writes the current time to a per-operation Redis key;
    ``collect_metrics`` reads it back, computes the elapsed seconds and removes
    the key. All Redis interaction is best-effort: a failure is logged and
    reported as a ``None`` measurement instead of being raised, so metrics
    collection cannot break the operation being measured.
    """

    TIME_TAKEN_KEY = "time_taken(s)"

    def __init__(self, run_id):
        """Initialize the MetricsMixin class and record the start time.

        Connection settings are read from ``REDIS_HOST``, ``REDIS_PORT``,
        ``REDIS_USER`` and ``REDIS_PASSWORD``; database 1 is always used.

        Args:
            run_id (str): Unique identifier for the run.
        """
        self.run_id = run_id
        self.op_id = str(uuid.uuid4())  # Unique identifier for this instance
        self.redis_key = f"metrics:{self.run_id}:{self.op_id}"
        self.redis_client = None
        try:
            # Initialize Redis client
            self.redis_client = StrictRedis(
                host=os.getenv("REDIS_HOST", "unstract-redis"),
                port=int(os.getenv("REDIS_PORT", 6379)),
                username=os.getenv("REDIS_USER", "default"),
                password=os.getenv("REDIS_PASSWORD", ""),
                db=1,
                decode_responses=True,
            )
        except Exception as e:
            logger.error(
                "Failed to initialize Redis client" f" for run_id={run_id}: {e}"
            )

        # Set the start time immediately upon initialization
        self.set_start_time()

    def set_start_time(self, ttl=86400):
        """Store the current timestamp in Redis under this operation's key.

        Args:
            ttl (int): Expiry of the key in seconds (default: one day), so an
                operation that never collects its metrics does not leak keys.
        """
        if self.redis_client is None:
            logger.error("Redis client is not initialized. Cannot set start time.")
            return
        try:
            self.redis_client.set(self.redis_key, time.time(), ex=ttl)
        except Exception as e:
            # The client usually connects lazily, so an unreachable server
            # typically surfaces here rather than in the constructor.
            logger.error(
                "Failed to set start time for run_id=%s: %s", self.run_id, e
            )

    def collect_metrics(self) -> dict[str, Any]:
        """Calculate the time taken since the timestamp was set and delete the
        Redis key.

        Returns:
            dict: ``{TIME_TAKEN_KEY: seconds}`` rounded to 3 decimals, or
            ``{TIME_TAKEN_KEY: None}`` when the client is missing, the key no
            longer exists, or Redis could not be read.
        """
        if self.redis_client is None:
            logger.error("Redis client is not initialized. Cannot collect metrics.")
            return {self.TIME_TAKEN_KEY: None}

        try:
            # A single GET avoids the exists()/get() race where the key
            # expires between the two calls and float(None) raises.
            raw_start_time = self.redis_client.get(self.redis_key)
            if raw_start_time is None:
                return {self.TIME_TAKEN_KEY: None}
            time_taken = round(time.time() - float(raw_start_time), 3)
        except Exception as e:
            logger.error(
                "Failed to read start time for run_id=%s: %s", self.run_id, e
            )
            return {self.TIME_TAKEN_KEY: None}

        try:
            # Delete the Redis key after use
            self.redis_client.delete(self.redis_key)
        except Exception as e:
            # The measurement is already computed and the key carries a TTL,
            # so a failed cleanup is only worth a warning.
            logger.warning(
                "Failed to delete metrics key for run_id=%s: %s", self.run_id, e
            )

        return {self.TIME_TAKEN_KEY: time_taken}
51 changes: 51 additions & 0 deletions src/unstract/sdk/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid

from unstract.sdk.constants import LogLevel
from unstract.sdk.metrics_mixin import MetricsMixin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,3 +55,53 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def capture_metrics(func):
    """Decorator to capture metrics at the start and end of a method.

    The decorated method's instance must expose ``_run_id``,
    ``_capture_metrics`` and ``_metrics``; otherwise the method runs
    undecorated. When ``_run_id`` and ``_capture_metrics`` are both truthy, the
    elapsed time of the call is measured with ``MetricsMixin`` and accumulated
    into ``self._metrics`` under ``MetricsMixin.TIME_TAKEN_KEY``.

    Metrics capture is best-effort: a failure while starting or collecting the
    measurement is logged and never replaces the wrapped method's result or
    exception.

    Args:
        func: The method to wrap.

    Returns:
        The wrapped method, with ``func``'s metadata preserved.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Ensure the required attributes exist; if not,
        # execute the function and return its result
        required_attrs = ("_run_id", "_capture_metrics", "_metrics")
        if not all(hasattr(self, attr) for attr in required_attrs):
            return func(self, *args, **kwargs)

        # Check if run_id exists and if metrics should be captured
        metrics_mixin = None
        if self._run_id and self._capture_metrics:
            try:
                metrics_mixin = MetricsMixin(run_id=self._run_id)
            except Exception as e:
                logger.warning(
                    "Unable to start metrics capture for run_id=%s: %s",
                    self._run_id,
                    e,
                )

        try:
            return func(self, *args, **kwargs)
        finally:
            # If metrics are being captured, collect and assign them at the end
            if metrics_mixin is not None:
                try:
                    time_taken_key = MetricsMixin.TIME_TAKEN_KEY
                    new_metrics = metrics_mixin.collect_metrics()

                    # If time_taken(s) exists in both, accumulate it
                    if (
                        self._metrics
                        and time_taken_key in self._metrics
                        and time_taken_key in new_metrics
                    ):
                        previously_measured_time = self._metrics.get(time_taken_key)
                        newly_measured_time = new_metrics.get(time_taken_key)

                        # Only sum if both are valid. Compare against None
                        # explicitly: 0.0 is a legitimate measurement and must
                        # not invalidate the running total.
                        if (
                            previously_measured_time is not None
                            and newly_measured_time is not None
                        ):
                            self._metrics[time_taken_key] = (
                                previously_measured_time + newly_measured_time
                            )
                        else:
                            self._metrics[time_taken_key] = None
                    else:
                        # If the key isn't in self._metrics, set it to new_metrics
                        self._metrics = new_metrics
                except Exception as e:
                    # Raising from ``finally`` would mask the wrapped call's
                    # own exception or discard its return value.
                    logger.warning(
                        "Unable to collect metrics for run_id=%s: %s",
                        self._run_id,
                        e,
                    )

    return wrapper

0 comments on commit 8156743

Please sign in to comment.