Skip to content

Commit

Permalink
Merge pull request #6 from Kranium2002/add_improved_context
Browse files Browse the repository at this point in the history
Add improved context
  • Loading branch information
Kranium2002 authored Jul 7, 2024
2 parents b88e081 + 0c8d8f4 commit 8b1086f
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 18 deletions.
4 changes: 2 additions & 2 deletions optimizeai/cot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, signature: dspy.Signature):
super().__init__()
self.prog = dspy.ChainOfThought(signature)

def forward(self, code, context, perf_metrics):
    """Pass the code, its helper-function context, and performance metrics
    to the Chain of Thought model and return the model's answer.

    Args:
        code (str): The code to be optimized.
        context (str): Source code of user-defined functions called by *code*.
        perf_metrics (str): Profiling output captured while running the code.

    Returns:
        The prediction object produced by the ChainOfThought program
        (``self.prog``), including its ``optimization`` field.
    """
    answer = self.prog(code=code, context=context, perf_metrics=perf_metrics)
    return answer
48 changes: 43 additions & 5 deletions optimizeai/decorators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,29 @@
import functools
from io import StringIO
from contextlib import redirect_stdout
import sys
import os
from perfwatch import watch
from optimizeai.llm_wrapper import LLMWrapper
from optimizeai.config import Config
import types

def get_function_code(func):
    """Return the source code of *func* as a single string.

    Falls back to a placeholder message when the source is unavailable
    (e.g. built-ins, C extensions, or functions defined interactively).

    Args:
        func: The function object whose source should be retrieved.

    Returns:
        str: The function's source text, or an explanatory message when
        the source cannot be read.
    """
    try:
        # inspect.getsource already joins the lines for us; it raises
        # OSError (IOError is just an alias in Python 3) when the source
        # file cannot be read, and TypeError for built-in objects.
        return inspect.getsource(func)
    except (OSError, TypeError):
        return f"Source code not available for {func.__name__}"

def is_user_defined_function(func, base_folder):
    """Check whether *func* is a user-defined function located under *base_folder*.

    Args:
        func: Candidate object; anything that is not a plain Python
            function (built-ins, classes, bound methods of C types) is
            rejected outright.
        base_folder (str): Directory containing the user's project code.

    Returns:
        bool: True iff *func* is a pure-Python function whose defining
        file lives inside *base_folder*.
    """
    if not isinstance(func, types.FunctionType):
        return False
    try:
        func_file = inspect.getfile(func)
    except TypeError:
        # Defensive: getfile() raises TypeError for objects with no source file.
        return False
    # Normalize both paths and require a path-separator boundary so that a
    # sibling folder such as "/proj/app2" is not treated as inside "/proj/app"
    # (a plain startswith() prefix test would match it).
    func_file = os.path.abspath(func_file)
    base_folder = os.path.abspath(base_folder)
    return func_file.startswith(base_folder + os.sep)

# Custom optimize decorator
def optimize(profiler_types, config: Config):
"""Decorator to optimize a Python function using LLMs and performance profiling."""
def decorator(func):
Expand All @@ -22,13 +40,33 @@ def wrapper(*args, **kwargs):

# Profile the function and capture the output
with StringIO() as buf, redirect_stdout(buf):
watch(profiler_types)(func)(*args, **kwargs)
captured_output = buf.getvalue()
# Create a dictionary to store called functions' source codes
called_funcs_code = {}

def trace_calls(frame, event, arg):
    # sys.settrace callback: while the profiled function runs, capture the
    # source of every user-defined function that gets called so it can be
    # sent to the LLM as context.
    if event == 'call':
        # NOTE(review): resolves the callee by name in the frame's globals —
        # this misses methods, closures, and lambdas whose co_name is not a
        # module-level binding; confirm that limitation is acceptable.
        called_func = frame.f_globals.get(frame.f_code.co_name)
        # Folder of the decorated function; used to filter out stdlib and
        # third-party calls. Recomputed on every call event (could be hoisted).
        base_folder = os.path.dirname(inspect.getfile(func))
        if is_user_defined_function(called_func, base_folder):
            called_funcs_code[frame.f_code.co_name] = get_function_code(called_func)
    # Returning the tracer itself keeps tracing active for nested calls.
    return trace_calls

# Set the trace function
sys.settrace(trace_calls)

try:
watch(profiler_types)(func)(*args, **kwargs)
finally:
# Remove the trace function
sys.settrace(None)

captured_output = buf.getvalue()
# print(str(list(called_funcs_code.values())[1:]))
# print(code)
# print(captured_output)
# Initialize the LLMWrapper with the provided config
llm_wrapper = LLMWrapper(config)
response= llm_wrapper.send_request(code = code, perf_metrics=captured_output)

response = llm_wrapper.send_request(code=str(code), context=str(list(called_funcs_code.values())[1:]), perf_metrics=captured_output)
print(response)
return response

Expand Down
5 changes: 3 additions & 2 deletions optimizeai/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ZeroShotQAWithCoT(dspy.Signature):
Provide short detailed technical tips and refactor suggestions, such as algorithm improvements,
data structure optimizations, or parallelization strategies if needed."""
code = dspy.InputField(desc="Code to be optimized")
context = dspy.InputField(desc="Code for all user defined functions in the original code provided")
perf_metrics = dspy.InputField(desc="""Performance metrics of the code along with the
output of the code execution""")
optimization = dspy.OutputField(desc="""Detailed solution to how the code can be optimized with insights on performance metrics
Expand Down Expand Up @@ -43,10 +44,10 @@ def __setup_llm(self):
dspy.settings.configure(lm=self.llm)
self.chain = CoT(ZeroShotQAWithCoT)

def send_request(self, code, context, perf_metrics):
    """Send a request to the LLM model with the given code and metrics.

    Args:
        code (str): The code to send to the LLM model.
        context (str): Source code of user-defined functions called by *code*.
        perf_metrics (str): The performance metrics to send to the LLM model.

    Returns:
        str: The optimization suggestions generated by the LLM model.
    """
    answer = self.chain.forward(code=code, context=context, perf_metrics=perf_metrics)
    return answer.optimization
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "optimizeai"
version = "0.1.4"
version = "0.2.0"
description = "OptimAI is a powerful Python module designed to optimize your code by analyzing its performance and providing actionable suggestions. It leverages a large language model (LLM) to give you detailed insights and recommendations based on the profiling data collected during the execution of your code."
authors = ["Vidhu Mathur <[email protected]>"]
readme = "README.md"
Expand All @@ -17,7 +17,7 @@ perfwatch = "^1.3.2"
py-context = "^0.3.1"
requests = "^2.32.3"
transformers = "^4.41.2"
google-generativeai = "^0.6.0"
google-generativeai = "^0.7.1"
dspy = "^0.1.5"
anthropic = "^0.28.1"
torch = "^2.3.1"
Expand Down

0 comments on commit 8b1086f

Please sign in to comment.