Script to estimate prompt sizes and change rate
Signed-off-by: John Strunk <[email protected]>
JohnStrunk committed May 20, 2024
1 parent 46bb2e3 commit 482c0bf
Showing 5 changed files with 140 additions and 283 deletions.
31 changes: 23 additions & 8 deletions estimator.py
@@ -12,7 +12,8 @@
 import requests
 from atlassian import Jira # type: ignore
 
-from jiraissues import Issue, get_self, issue_cache
+from jiraissues import Issue, check_response, get_self, issue_cache
+from summarizer import count_tokens, summarize_issue
 
 
 @dataclass
@@ -50,13 +51,27 @@ def as_csv(self) -> str:
 
 def estimate_issue(issue: Issue) -> IssueEstimate:
     """Estimate the number of tokens needed to summarize the issue"""
+    prompt = summarize_issue(
+        issue,
+        max_depth=0,
+        send_updates=False,
+        regenerate=False,
+        return_prompt_only=True,
+    )
+    tokens = -1
+    try:
+        tokens = count_tokens(prompt)
+    except ValueError:
+        # If the prompt is too large, we can't count the tokens
+        pass
+
     return IssueEstimate(
         key=issue.key,
         issue_type=issue.issue_type,
         updated=issue.updated,
         child_count=len(issue.children),
         comment_count=len(issue.comments),
-        tokens=0, # Placeholder for now
+        tokens=tokens,
     )


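For orientation, a minimal sketch of how the updated estimate_issue might be driven (hypothetical usage, not part of this commit; the authenticated Jira client and the issue key "PROJ-123" are illustrative assumptions):

# Hypothetical usage sketch -- assumes an authenticated Jira `client`.
issue = issue_cache.get_issue(client, "PROJ-123")  # "PROJ-123" is a made-up key
estimate = estimate_issue(issue)  # builds the summary prompt and counts its tokens
print(estimate.as_csv())  # as_csv() renders the IssueEstimate as one CSV row
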
@@ -65,13 +80,13 @@ def get_modified_issues(client: Jira, since: datetime) -> list[Issue]:
     user_zi = get_self(client).tzinfo
     since_string = since.astimezone(user_zi).strftime("%Y-%m-%d %H:%M")
 
-    issues = client.jql(
-        f"updated >= '{since_string}' ORDER BY updated DESC",
-        limit=1000,
-        fields="key",
+    issues = check_response(
+        client.jql(
+            f"updated >= '{since_string}' ORDER BY updated DESC",
+            limit=1000,
+            fields="key",
+        )
     )
-    if not isinstance(issues, dict):
-        return []
     issue_cache.clear()
     return [issue_cache.get_issue(client, issue["key"]) for issue in issues["issues"]]
 
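The check_response helper is imported from jiraissues, whose body is outside this diff. A minimal sketch of its assumed contract, inferred from the inline isinstance check it replaces (an assumption, not the actual implementation):

from typing import Any

def check_response(response: Any) -> dict:
    """Assumed contract: return the response if it is a dict, else raise."""
    # The old call site silently returned [] on a non-dict response; a
    # centralized checker can instead fail loudly before issues["issues"]
    # is ever indexed.
    if isinstance(response, dict):
        return response
    raise ValueError(f"Unexpected Jira response type: {type(response)}")
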
52 changes: 42 additions & 10 deletions genai.ipynb
@@ -9,8 +9,8 @@
     "from os import environ\n",
     "from genai import Client, Credentials\n",
     "\n",
-    "genai_key = environ.get('GENAI_KEY', '')\n",
-    "genai_url = environ.get('GENAI_API', '')\n",
+    "genai_key = environ.get(\"GENAI_KEY\", \"\")\n",
+    "genai_url = environ.get(\"GENAI_API\", \"\")\n",
     "credentials = Credentials(api_key=genai_key, api_endpoint=genai_url)\n",
     "client = Client(credentials=credentials)"
    ]
@@ -37,17 +37,17 @@
     ")\n",
     "\n",
     "for response in client.text.generation.create(\n",
-    "    model_id=\"mistralai/mixtral-8x7b-instruct-v0-1\",\n",
+    "    model_id=\"mistralai/mixtral-8x7b-instruct-v01\",\n",
     "    inputs=[\"What is a molecule?\", \"What is NLP?\"],\n",
     "    parameters=TextGenerationParameters(\n",
     "        max_new_tokens=150,\n",
     "        min_new_tokens=20,\n",
     "        return_options=TextGenerationReturnOptions(input_text=True),\n",
     "    ),\n",
     "):\n",
-    "    result = response.results[0]\n",
-    "    print(f\"Input Text: {result.input_text}\")\n",
-    "    print(f\"Generated Text: {result.generated_text}\")\n",
+    "    resp = response.results[0]\n",
+    "    print(f\"Input Text: {resp.input_text}\")\n",
+    "    print(f\"Generated Text: {resp.generated_text}\")\n",
     "    print(\"\")"
    ]
   },
@@ -60,7 +60,12 @@
     "# https://ibm.github.io/ibm-generative-ai/v2.3.0/rst_source/examples.extensions.langchain.langchain_chat_stream.html\n",
     "\n",
     "import pprint\n",
-    "from langchain_core.messages import HumanMessage, SystemMessage, BaseMessageChunk, BaseMessage\n",
+    "from langchain_core.messages import (\n",
+    "    HumanMessage,\n",
+    "    SystemMessage,\n",
+    "    BaseMessageChunk,\n",
+    "    BaseMessage,\n",
+    ")\n",
     "from genai.extensions.langchain import LangChainChatInterface\n",
     "from genai.schema import DecodingMethod, TextGenerationParameters\n",
     "\n",
@@ -79,7 +84,7 @@
     "\n",
     "prompt = \"Describe what is Python in one sentence.\"\n",
     "print(f\"Request: {prompt}\")\n",
-    "final: BaseMessage = BaseMessageChunk(content=\"\", type=\"\") # dummy chunk\n",
+    "final: BaseMessage = BaseMessageChunk(content=\"\", type=\"\")  # dummy chunk\n",
     "first = True\n",
     "for chunk in llm.stream(\n",
     "    input=[\n",
@@ -99,7 +104,7 @@
     "        first = False\n",
     "        pprint.pprint(chunk.response_metadata)\n",
     "    print(chunk.content, end=\"\", flush=True)\n",
-    "    #info = chunk.generation_info\n",
+    "    # info = chunk.generation_info\n",
     "    final = chunk\n",
     "\n",
     "print(\"\\n\\n\")\n",
@@ -136,6 +141,33 @@
     "for chunk in llm.stream(prompt):\n",
     "    print(chunk, end=\"\", flush=True)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from genai.schema import TextTokenizationParameters, TextTokenizationReturnOptions\n",
+    "\n",
+    "\n",
+    "response = client.text.tokenization.create(\n",
+    "    model_id=\"mistralai/mistral-7b-instruct-v0-2\",\n",
+    "    input=\"Tell me a knock knock joke.\", # str | list[str]\n",
+    "    parameters=TextTokenizationParameters(\n",
+    "        return_options=TextTokenizationReturnOptions(\n",
+    "            input_text=False,\n",
+    "            tokens=False,\n",
+    "        ),\n",
+    "    ),\n",
+    ")\n",
+    "\n",
+    "total_tokens = 0\n",
+    "for resp in response:\n",
+    "    for result in resp.results:\n",
+    "        total_tokens += result.token_count\n",
+    "print(f\"Total tokens: {total_tokens}\")"
+   ]
+  }
  ],
  "metadata": {
@@ -154,7 +186,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.2"
+   "version": "3.12.3"
   }
  },
  "nbformat": 4,
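The new tokenization cell above exercises the same endpoint that estimator.py relies on through summarizer.count_tokens. A minimal sketch of such a helper, assuming the notebook's client and model (illustrative only; the real summarizer.count_tokens is outside this diff):

from genai.schema import TextTokenizationParameters, TextTokenizationReturnOptions

def count_tokens(text: str) -> int:
    """Count tokens for `text` using the service-side tokenizer."""
    # The tokenization endpoint accepts str | list[str]; a single string
    # yields one result per input. Oversized inputs may raise, which
    # estimator.py's estimate_issue() handles as a ValueError.
    response = client.text.tokenization.create(
        model_id="mistralai/mistral-7b-instruct-v0-2",
        input=text,
        parameters=TextTokenizationParameters(
            return_options=TextTokenizationReturnOptions(
                input_text=False,
                tokens=False,
            ),
        ),
    )
    return sum(r.token_count for resp in response for r in resp.results)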
[Diffs for the remaining changed files are not shown.]
