From 482c0bf315bc0050daad4b2b80a8b627157346bb Mon Sep 17 00:00:00 2001 From: John Strunk Date: Mon, 20 May 2024 15:51:30 +0000 Subject: [PATCH] Script to estimate prompt sizes and change rate Signed-off-by: John Strunk --- estimator.py | 31 +++-- genai.ipynb | 52 +++++++-- jira_howto.ipynb | 290 +++++++---------------------------------------- jiraissues.py | 45 +++++--- summarizer.py | 5 + 5 files changed, 140 insertions(+), 283 deletions(-) diff --git a/estimator.py b/estimator.py index 02dbda4..b416305 100755 --- a/estimator.py +++ b/estimator.py @@ -12,7 +12,8 @@ import requests from atlassian import Jira # type: ignore -from jiraissues import Issue, get_self, issue_cache +from jiraissues import Issue, check_response, get_self, issue_cache +from summarizer import count_tokens, summarize_issue @dataclass @@ -50,13 +51,27 @@ def as_csv(self) -> str: def estimate_issue(issue: Issue) -> IssueEstimate: """Estimate the number of tokens needed to summarize the issue""" + prompt = summarize_issue( + issue, + max_depth=0, + send_updates=False, + regenerate=False, + return_prompt_only=True, + ) + tokens = -1 + try: + tokens = count_tokens(prompt) + except ValueError: + # If the prompt is too large, we can't count the tokens + pass + return IssueEstimate( key=issue.key, issue_type=issue.issue_type, updated=issue.updated, child_count=len(issue.children), comment_count=len(issue.comments), - tokens=0, # Placeholder for now + tokens=tokens, ) @@ -65,13 +80,13 @@ def get_modified_issues(client: Jira, since: datetime) -> list[Issue]: user_zi = get_self(client).tzinfo since_string = since.astimezone(user_zi).strftime("%Y-%m-%d %H:%M") - issues = client.jql( - f"updated >= '{since_string}' ORDER BY updated DESC", - limit=1000, - fields="key", + issues = check_response( + client.jql( + f"updated >= '{since_string}' ORDER BY updated DESC", + limit=1000, + fields="key", + ) ) - if not isinstance(issues, dict): - return [] issue_cache.clear() return [issue_cache.get_issue(client, issue["key"]) for issue in issues["issues"]] diff --git a/genai.ipynb b/genai.ipynb index fdcd58e..16c57c9 100644 --- a/genai.ipynb +++ b/genai.ipynb @@ -9,8 +9,8 @@ "from os import environ\n", "from genai import Client, Credentials\n", "\n", - "genai_key = environ.get('GENAI_KEY', '')\n", - "genai_url = environ.get('GENAI_API', '')\n", + "genai_key = environ.get(\"GENAI_KEY\", \"\")\n", + "genai_url = environ.get(\"GENAI_API\", \"\")\n", "credentials = Credentials(api_key=genai_key, api_endpoint=genai_url)\n", "client = Client(credentials=credentials)" ] @@ -37,7 +37,7 @@ ")\n", "\n", "for response in client.text.generation.create(\n", - " model_id=\"mistralai/mixtral-8x7b-instruct-v0-1\",\n", + " model_id=\"mistralai/mixtral-8x7b-instruct-v01\",\n", " inputs=[\"What is a molecule?\", \"What is NLP?\"],\n", " parameters=TextGenerationParameters(\n", " max_new_tokens=150,\n", @@ -45,9 +45,9 @@ " return_options=TextGenerationReturnOptions(input_text=True),\n", " ),\n", "):\n", - " result = response.results[0]\n", - " print(f\"Input Text: {result.input_text}\")\n", - " print(f\"Generated Text: {result.generated_text}\")\n", + " resp = response.results[0]\n", + " print(f\"Input Text: {resp.input_text}\")\n", + " print(f\"Generated Text: {resp.generated_text}\")\n", " print(\"\")" ] }, @@ -60,7 +60,12 @@ "# https://ibm.github.io/ibm-generative-ai/v2.3.0/rst_source/examples.extensions.langchain.langchain_chat_stream.html\n", "\n", "import pprint\n", - "from langchain_core.messages import HumanMessage, SystemMessage, BaseMessageChunk, BaseMessage\n", + "from langchain_core.messages import (\n", + " HumanMessage,\n", + " SystemMessage,\n", + " BaseMessageChunk,\n", + " BaseMessage,\n", + ")\n", "from genai.extensions.langchain import LangChainChatInterface\n", "from genai.schema import DecodingMethod, TextGenerationParameters\n", "\n", @@ -79,7 +84,7 @@ "\n", "prompt = \"Describe what is Python in one sentence.\"\n", "print(f\"Request: {prompt}\")\n", - "final: BaseMessage = BaseMessageChunk(content=\"\", type=\"\") # dummy chunk\n", + "final: BaseMessage = BaseMessageChunk(content=\"\", type=\"\") # dummy chunk\n", "first = True\n", "for chunk in llm.stream(\n", " input=[\n", @@ -99,7 +104,7 @@ " first = False\n", " pprint.pprint(chunk.response_metadata)\n", " print(chunk.content, end=\"\", flush=True)\n", - " #info = chunk.generation_info\n", + " # info = chunk.generation_info\n", " final = chunk\n", "\n", "print(\"\\n\\n\")\n", @@ -136,6 +141,33 @@ "for chunk in llm.stream(prompt):\n", " print(chunk, end=\"\", flush=True)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from genai.schema import TextTokenizationParameters, TextTokenizationReturnOptions\n", + "\n", + "\n", + "response = client.text.tokenization.create(\n", + " model_id=\"mistralai/mistral-7b-instruct-v0-2\",\n", + " input=\"Tell me a knock knock joke.\", # str | list[str]\n", + " parameters=TextTokenizationParameters(\n", + " return_options=TextTokenizationReturnOptions(\n", + " input_text=False,\n", + " tokens=False,\n", + " ),\n", + " ),\n", + ")\n", + "\n", + "total_tokens = 0\n", + "for resp in response:\n", + " for result in resp.results:\n", + " total_tokens += result.token_count\n", + "print(f\"Total tokens: {total_tokens}\")" + ] } ], "metadata": { @@ -154,7 +186,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/jira_howto.ipynb b/jira_howto.ipynb index 1fc4c27..372e015 100644 --- a/jira_howto.ipynb +++ b/jira_howto.ipynb @@ -57,51 +57,10 @@ "# Create a JIRA client\n", "from os import environ\n", "from atlassian import Jira\n", - "jira_api_token = environ.get('JIRA_TOKEN', '')\n", - "jira_url = environ.get('JIRA_URL', '')\n", - "jira = Jira(url=jira_url, token=jira_api_token)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "issue_key = 'RHSTOR-919'\n", - "issue = jira.issue(issue_key)\n", - "print(f\"{type(issue)}\\n\")\n", "\n", - "from pprint import pprint\n", - "pprint(issue)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "# Let's get the properties as well...\n", - "issue_with_properties = jira.get_issue(issue_key, properties=\"*all\")\n", - "pprint(issue_with_properties)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "rlinks = jira.get_issue_remote_links(issue_key)\n", - "for rlink in rlinks:\n", - " print(f\"{rlink['object']['title']} -> {rlink['object']['url']}\")" + "jira_api_token = environ.get(\"JIRA_TOKEN\", \"\")\n", + "jira_url = environ.get(\"JIRA_URL\", \"\")\n", + "jira = Jira(url=jira_url, token=jira_api_token)" ] }, { @@ -114,8 +73,8 @@ "source": [ "# Lets look at all the possible Jira fields and their types:\n", "fields = jira.get_all_fields()\n", - "for field in sorted(fields, key=lambda x: x['id']):\n", - " ftype = field['schema']['type'] if 'schema' in field else 'unknown'\n", + "for field in sorted(fields, key=lambda x: x[\"id\"]):\n", + " ftype = field[\"schema\"][\"type\"] if \"schema\" in field else \"unknown\"\n", " print(f\"{field['id']} -> {field['name']} -- {ftype}\")" ] }, @@ -127,9 +86,16 @@ }, "outputs": [], "source": [ + "from pprint import pprint\n", + "\n", "for field in fields:\n", - " if field['id'] in ['customfield_12311140', 'customfield_12311141', 'customfield_12313140', 'customfield_12318341']:\n", - " pprint(field)\n" + " if field[\"id\"] in [\n", + " \"customfield_12311140\",\n", + " \"customfield_12311141\",\n", + " \"customfield_12313140\",\n", + " \"customfield_12318341\",\n", + " ]:\n", + " pprint(field)" ] }, { @@ -141,11 +107,11 @@ "outputs": [], "source": [ "# Accessing \"Parent Link\" custom field\n", - "et85 = jira.get_issue('OCTOET-85')\n", - "pprint(et85['fields']['customfield_12313140']) # Has parent\n", + "et85 = jira.get_issue(\"OCTOET-85\")\n", + "pprint(et85[\"fields\"][\"customfield_12313140\"]) # Has parent\n", "\n", - "stor919 = jira.get_issue('RHSTOR-919')\n", - "pprint(stor919['fields']['customfield_12313140']) # No parent" + "stor919 = jira.get_issue(\"RHSTOR-919\")\n", + "pprint(stor919[\"fields\"][\"customfield_12313140\"]) # No parent" ] }, { @@ -156,13 +122,17 @@ }, "outputs": [], "source": [ - "interesting_fields = ['customfield_12311140', 'customfield_12311141', 'customfield_12313140']\n", - "issue = jira.get_issue('OPRUN-3254')\n", + "interesting_fields = [\n", + " \"customfield_12311140\",\n", + " \"customfield_12311141\",\n", + " \"customfield_12313140\",\n", + "]\n", + "issue = jira.get_issue(\"OPRUN-3254\")\n", "for field in interesting_fields:\n", " field_name = \"unknown\"\n", " for idx in fields:\n", - " if idx['id'] == field:\n", - " field_name = idx['name']\n", + " if idx[\"id\"] == field:\n", + " field_name = idx[\"name\"]\n", " break\n", " print(f\"{field_name} -> {issue['fields'].get(field, None)}\")" ] @@ -176,96 +146,18 @@ "outputs": [], "source": [ "# Get a list of the custom fields in this issue by finding all fields that start with \"customfield_\"\n", - "custom_fields = [k for k in issue['fields'].keys() if k.startswith('customfield_')]\n", + "custom_fields = [k for k in issue[\"fields\"].keys() if k.startswith(\"customfield_\")]\n", "for field in custom_fields:\n", - " if issue['fields'].get(field, None) is None:\n", + " if issue[\"fields\"].get(field, None) is None:\n", " continue\n", " field_name = \"unknown\"\n", " id = \"?\"\n", " for idx in fields:\n", - " if idx['id'] == field:\n", - " id = idx['id']\n", - " field_name = idx['name']\n", + " if idx[\"id\"] == field:\n", + " id = idx[\"id\"]\n", + " field_name = idx[\"name\"]\n", " break\n", - " print(f\"{field_name} ({id}) -> {issue['fields'].get(field, None)}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "import sys\n", - "\n", - "\n", - "if 'jhelper' in sys.modules:\n", - " del sys.modules['jhelper']\n", - "import jiraissues\n", - "\n", - "for i in ['OPRUN-1858', 'OPRUN-3254', 'RHSTOR-919']:\n", - " print(f\"Related issues for {i}:\")\n", - " issue = jiraissues.Issue(jira, i)\n", - " related = issue.related\n", - " for r in related:\n", - " print(f\" {r.how} -> {r.key}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "pprint(jira.jql(\"labels = 'AISummary' ORDER BY created DESC\", limit=5, fields=\"key,summary,updated\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "jira.get_issue_changelog('RHSTOR-5635')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "pprint(jira.myself())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "pprint(jira.issue('OCTO-2', fields='key,updated'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "pprint(jira.jql(\"'Parent Link' = 'OCTO-2' ORDER BY created DESC\", limit=5, fields=\"key,summary,updated\"))" + " print(f\"{field_name} ({id}) -> {issue['fields'].get(field, None)}\")" ] }, { @@ -276,12 +168,13 @@ }, "outputs": [], "source": [ - "from datetime import datetime\n", - "dt = datetime.fromisoformat('2024-04-20T13:02:00.000+0000')\n", - "result = jira.jql(f\"labels = 'AISummary' and updated > \\\"{dt.strftime(\"%Y-%m-%d %H:%M\")}\\\" ORDER BY updated DESC\", limit=50, fields=\"key,updated\")\n", - "keys = [(x['key'], datetime.fromisoformat(x['fields']['updated'])) for x in result['issues']]\n", - "pprint(keys)\n", - "print(type(result))" + "pprint(\n", + " jira.jql(\n", + " \"labels = 'AISummary' ORDER BY created DESC\",\n", + " limit=5,\n", + " fields=\"key,summary,updated\",\n", + " )\n", + ")" ] }, { @@ -293,7 +186,7 @@ "outputs": [], "source": [ "# With a private comment\n", - "pprint(jira.issue('OHSS-34055'))" + "pprint(jira.issue(\"OHSS-34055\"))" ] }, { @@ -305,39 +198,6 @@ "Otherwise, they do not have the visibility key." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "jira.get_project('RHSTOR')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "jira.get_project('STOR')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "metadata": {} - }, - "outputs": [], - "source": [ - "jira.get_project('OCTOET', expand='all')" - ] - }, { "cell_type": "code", "execution_count": null, @@ -350,64 +210,10 @@ "import os\n", "\n", "\n", - "i = Issue(jira, 'OCTOET-85')\n", + "i = Issue(jira, \"OCTOET-85\")\n", "print(i.project_key)\n", "\n", - "print(i.project_key in os.environ.get(\n", - " \"ALLOWED_PROJECTS\", \"\"\n", - " ).split(\",\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "if 'jiraissues' in sys.modules:\n", - " del sys.modules['jiraissues']\n", - "import jiraissues\n", - "\n", - "key = 'COS-2692'\n", - "parents = []\n", - "while key:\n", - " issue = jiraissues.Issue(jira, key)\n", - " parents.append((issue.key, issue.level, issue.issue_type))\n", - " key = issue.parent\n", - "pprint(parents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "from zoneinfo import ZoneInfo\n", - "\n", - "i = jiraissues.Issue(jira, 'OCTOET-236')\n", - "print(i.updated)\n", - "print(i.last_change)\n", - "datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M\")\n", - "\n", - "print(\"changelog\")\n", - "for c in i.changelog:\n", - " print(f\"{c.created}\")\n", - "\n", - "print(\"comments\")\n", - "for co in i.comments:\n", - " print(f\"{co.created}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "jira.jql(\"labels = 'AISummary' and updated >= '2024-05-08 13:19' ORDER BY updated DESC\", fields=\"key,updated\")" + "print(i.project_key in os.environ.get(\"ALLOWED_PROJECTS\", \"\").split(\",\"))" ] }, { @@ -415,17 +221,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from time import timezone\n", - "\n", - "pprint(jira.myself())\n", - "\n", - "tzname = jira.myself()['timeZone']\n", - "timezone = ZoneInfo(tzname)\n", - "now = datetime.datetime.now()\n", - "print(now.astimezone(timezone).strftime(\"%Y-%m-%d %H:%M\"))\n", - "print(now.strftime(\"%Y-%m-%d %H:%M\"))\n" - ] + "source": [] } ], "metadata": { diff --git a/jiraissues.py b/jiraissues.py index fea3c3b..d2f93a2 100644 --- a/jiraissues.py +++ b/jiraissues.py @@ -106,7 +106,7 @@ def __init__(self, client: Jira, issue_key: str) -> None: "updated", CF_STATUS_SUMMARY, ] - data = _check(client.issue(issue_key, fields=",".join(fields))) + data = check_response(client.issue(issue_key, fields=",".join(fields))) # Populate the fields self.summary: str = data["fields"]["summary"] @@ -136,7 +136,9 @@ def __str__(self) -> str: def _fetch_changelog(self) -> List[ChangelogEntry]: """Fetch the changelog from the API.""" _logger.debug("Retrieving changelog for %s", self.key) - log = _check(self.client.get_issue_changelog(self.key, start=0, limit=1000)) + log = check_response( + self.client.get_issue_changelog(self.key, start=0, limit=1000) + ) items: List[ChangelogEntry] = [] for entry in log["histories"]: changes: List[Change] = [] @@ -169,9 +171,9 @@ def changelog(self) -> List[ChangelogEntry]: def _fetch_comments(self) -> List[Comment]: """Fetch the comments from the API.""" _logger.debug("Retrieving comments for %s", self.key) - comments = _check(self.client.issue(self.key, fields="comment"))["fields"][ - "comment" - ]["comments"] + comments = check_response(self.client.issue(self.key, fields="comment"))[ + "fields" + ]["comment"]["comments"] items: List[Comment] = [] for comment in comments: items.append( @@ -201,7 +203,7 @@ def _fetch_related(self) -> List[RelatedIssue]: # pylint: disable=too-many-bran ] found_issues: set[str] = set() _logger.debug("Retrieving related links for %s", self.key) - data = _check(self.client.issue(self.key, fields=",".join(fields))) + data = check_response(self.client.issue(self.key, fields=",".join(fields))) # Get the related issues related: List[RelatedIssue] = [] for link in data["fields"]["issuelinks"]: @@ -258,14 +260,16 @@ def _fetch_related(self) -> List[RelatedIssue]: # pylint: disable=too-many-bran # issue to it's children. epic_issues returns an error if the issue is not # an Epic. These are downward links to children if self.issue_type == "Epic": - issues_in_epic = _check(self.client.epic_issues(self.key, fields="key")) + issues_in_epic = check_response( + self.client.epic_issues(self.key, fields="key") + ) for i in issues_in_epic["issues"]: if i["key"] not in found_issues: related.append(RelatedIssue(key=i["key"], how=_HOW_INEPIC)) found_issues.add(i["key"]) else: # Non-epic issues use the parent link - issues_with_parent = _check( + issues_with_parent = check_response( self.client.jql(f"'Parent Link' = '{self.key}'", limit=50, fields="key") ) for i in issues_with_parent["issues"]: @@ -362,7 +366,7 @@ def last_comment(self) -> Optional[Comment]: @property def is_last_change_mine(self) -> bool: """Check if the last change in the changelog was made by me.""" - me = _check(self.client.myself()) + me = check_response(self.client.myself()) return ( self.last_change is not None and self.last_change.author == me["displayName"] @@ -385,7 +389,18 @@ def update_status_summary(self, contents: str) -> None: _last_call_time = datetime.now() -def _check(response: Any) -> dict: +def _rate_limit() -> None: + """Rate limit the API calls to avoid hitting the rate limit of the Jira server""" + global _last_call_time # pylint: disable=global-statement + now = datetime.now() + delta = now - _last_call_time + required_delay = MIN_CALL_DELAY - delta.total_seconds() + if required_delay > 0: + sleep(required_delay) + _last_call_time = now + + +def check_response(response: Any) -> dict: """ Check the response from the Jira API and raise an exception if it's an error. @@ -396,13 +411,7 @@ def _check(response: Any) -> dict: anything. """ # Here, we throttle the API calls to avoid hitting the rate limit of the Jira server - global _last_call_time # pylint: disable=global-statement - now = datetime.now() - delta = now - _last_call_time - required_delay = MIN_CALL_DELAY - delta.total_seconds() - if required_delay > 0: - sleep(required_delay) - _last_call_time = now + _rate_limit() if isinstance(response, dict): return response @@ -416,7 +425,7 @@ class Myself: # pylint: disable=too-few-public-methods def __init__(self, client: Jira) -> None: self.client = client - self._data = _check(client.myself()) + self._data = check_response(client.myself()) # Break out the fields we care about self.display_name: str = self._data["displayName"] self.key: str = self._data["key"] diff --git a/summarizer.py b/summarizer.py index b0424a9..84dda64 100644 --- a/summarizer.py +++ b/summarizer.py @@ -51,6 +51,7 @@ def summarize_issue( max_depth: int = 0, send_updates: bool = False, regenerate: bool = False, + return_prompt_only: bool = False, ) -> str: """ Summarize a Jira issue. @@ -65,6 +66,8 @@ def summarize_issue( - send_updates: If True, update the issue summaries on the server - regenerate: If True, regenerate the summary even if it is already up-to-date + - return_prompt_only: If True, return the prompt only and don't actually + summarize the issue Returns: A string containing the summary @@ -162,6 +165,8 @@ def summarize_issue( {full_description} ``` """ + if return_prompt_only: + return llm_prompt _logger.info( "Summarizing %s (%d tokens) via LLM", issue.key, count_tokens(llm_prompt)