
Merge pull request #71 from JohnStrunk/estimator
Initial draft of token estimator script
mergify[bot] authored May 17, 2024
2 parents 3168a10 + a48a685 commit 1abd68c
Showing 6 changed files with 171 additions and 27 deletions.
4 changes: 2 additions & 2 deletions bot.py
@@ -77,9 +77,9 @@ def main():
print(f"Summarized {issue_key} ({elapsed}s):\n{summary}\n")
since = start_time # Only update if we succeeded
except requests.exceptions.HTTPError as error:
logging.error("HTTPError exception: %s", error, stack_info=True)
logging.error("HTTPError exception: %s", error.response.reason)
except requests.exceptions.ReadTimeout as error:
logging.error("ReadTimeout exception: %s", error, stack_info=True)
logging.error("ReadTimeout exception: %s", error, exc_info=True)
logging.info(
"Cache stats: %d hits, %d total", issue_cache.hits, issue_cache.tries
)
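A note on the two handlers above: HTTPError carries the server's Response, so logging its reason phrase is a compact replacement for the full stack trace, while ReadTimeout has no response at all, so the traceback is preserved via exc_info=True. A minimal, self-contained sketch of the same pattern (the URL is a placeholder, not one used by the bot):

import logging

import requests

try:
    response = requests.get("https://jira.example.com/rest/api/2/myself", timeout=5)
    response.raise_for_status()  # raises HTTPError on 4xx/5xx responses
except requests.exceptions.HTTPError as error:
    # The server answered; its reason phrase is usually enough context
    logging.error("HTTPError exception: %s", error.response.reason)
except requests.exceptions.ReadTimeout as error:
    # No response arrived; keep the traceback to see where the wait happened
    logging.error("ReadTimeout exception: %s", error, exc_info=True)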
126 changes: 126 additions & 0 deletions estimator.py
@@ -0,0 +1,126 @@
#! /usr/bin/env python

"""Estimate the issue change rate and necessary token throughput"""

import argparse
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime, timedelta

import requests
from atlassian import Jira  # type: ignore

from jiraissues import Issue, get_self, issue_cache


@dataclass
class IssueEstimate:
    """Data class to hold the estimate information"""

    key: str
    issue_type: str
    updated: datetime
    child_count: int
    comment_count: int
    tokens: int

    def __str__(self) -> str:
        return f"{self.key} ({self.issue_type}): {self.tokens} tokens"

    @classmethod
    def csv_header(cls) -> str:
        """Return the CSV header line"""
        return "key,issue_type,updated,child_count,comment_count,tokens"

    def as_csv(self) -> str:
        """Return the CSV representation of the data"""
        return ",".join(
            [
                self.key,
                self.issue_type,
                self.updated.isoformat(),
                str(self.child_count),
                str(self.comment_count),
                str(self.tokens),
            ]
        )


def estimate_issue(issue: Issue) -> IssueEstimate:
    """Estimate the number of tokens needed to summarize the issue"""
    return IssueEstimate(
        key=issue.key,
        issue_type=issue.issue_type,
        updated=issue.updated,
        child_count=len(issue.children),
        comment_count=len(issue.comments),
        tokens=0,  # Placeholder for now
    )


def get_modified_issues(client: Jira, since: datetime) -> list[Issue]:
    """Get issues modified since the given date/time"""
    user_zi = get_self(client).tzinfo
    since_string = since.astimezone(user_zi).strftime("%Y-%m-%d %H:%M")

    issues = client.jql(
        f"updated >= '{since_string}' ORDER BY updated DESC",
        limit=1000,
        fields="key",
    )
    if not isinstance(issues, dict):
        return []
    issue_cache.clear()
    return [issue_cache.get_issue(client, issue["key"]) for issue in issues["issues"]]


def main() -> None:
    """Main function"""
    parser = argparse.ArgumentParser(description="Estimator")
    # pylint: disable=duplicate-code
    parser.add_argument(
        "--log-level",
        default="WARNING",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level",
    )
    parser.add_argument(
        "-s",
        "--seconds",
        type=int,
        default=300,
        help="Seconds to wait between iterations",
    )

    args = parser.parse_args()
    logging.basicConfig(level=getattr(logging, args.log_level))
    delay: int = args.seconds

    jira = Jira(url=os.environ["JIRA_URL"], token=os.environ["JIRA_TOKEN"])

    print(IssueEstimate.csv_header())
    since = datetime.now() + timedelta(seconds=-delay)
    while True:
        start_time = datetime.now()
        logging.info("Starting iteration at %s", start_time.isoformat())
        try:
            issues = get_modified_issues(jira, since)
            for issue in issues:
                print(estimate_issue(issue).as_csv())
            since = start_time  # Only update if we succeeded
        except requests.exceptions.HTTPError as error:
            logging.error("HTTPError exception: %s", error.response.reason)
        except requests.exceptions.ReadTimeout as error:
            logging.error("ReadTimeout exception: %s", error, exc_info=True)
        logging.info(
            "Cache stats: %d hits, %d total", issue_cache.hits, issue_cache.tries
        )
        print(f"Iteration elapsed time: {datetime.now() - start_time}")
        print(f"{'='*20} Sleeping for {delay} seconds {'='*20}")
        time.sleep(delay)


if __name__ == "__main__":
    main()
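To illustrate the CSV that estimator.py prints, here is a hypothetical row built by hand (the key, counts, and timestamp are invented; tokens is always 0 for now, matching estimate_issue()):

from datetime import datetime

from estimator import IssueEstimate

estimate = IssueEstimate(
    key="PROJ-123",  # hypothetical issue key
    issue_type="Story",
    updated=datetime(2024, 5, 17, 12, 0),
    child_count=3,
    comment_count=5,
    tokens=0,
)
print(IssueEstimate.csv_header())  # key,issue_type,updated,child_count,comment_count,tokens
print(estimate.as_csv())           # PROJ-123,Story,2024-05-17T12:00:00,3,5,0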
2 changes: 1 addition & 1 deletion jira-summarizer.yaml
@@ -24,7 +24,7 @@ spec:
- "--log-level"
- "INFO"
- "--seconds"
- "300"
- "120"
envFrom:
- secretRef:
name: jira-summarizer-secret
28 changes: 26 additions & 2 deletions jiraissues.py
@@ -18,7 +18,7 @@
CF_STATUS_SUMMARY = "customfield_12320841"  # string

# How long to delay between API calls
-CALL_DELAY_SECONDS: float = 0.2
+MIN_CALL_DELAY: float = 0.2


@dataclass
@@ -382,6 +382,9 @@ def update_status_summary(self, contents: str) -> None:
        issue_cache.remove(self.key)  # Invalidate any cached copy


+_last_call_time = datetime.now()
+
+

def _check(response: Any) -> dict:
    """
    Check the response from the Jira API and raise an exception if it's an
@@ -392,7 +395,15 @@ def _check(response: Any) -> dict:
    general, when things go well, you get back a dict. Otherwise, you could get
    anything.
    """
-    sleep(CALL_DELAY_SECONDS)
+    # Here, we throttle the API calls to avoid hitting the rate limit of the Jira server
+    global _last_call_time  # pylint: disable=global-statement
+    now = datetime.now()
+    delta = now - _last_call_time
+    required_delay = MIN_CALL_DELAY - delta.total_seconds()
+    if required_delay > 0:
+        sleep(required_delay)
+    _last_call_time = now
+
    if isinstance(response, dict):
        return response
    raise ValueError(f"Unexpected response: {response}")
@@ -413,6 +424,19 @@ def __init__(self, client: Jira) -> None:
        self.tzinfo = ZoneInfo(self.timezone)


+_self: Optional[Myself] = None
+
+
+def get_self(client: Jira) -> Myself:
+    """
+    Caching function for the Myself object.
+    """
+    global _self  # pylint: disable=global-statement
+    if _self is None:
+        _self = Myself(client)
+    return _self
+
+

class IssueCache:
    """
    A cache of Jira issues to avoid fetching the same issue multiple times.
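The throttle above replaces an unconditional 0.2-second sleep on every call with a delay computed from the time since the previous call: bursts are spaced at least MIN_CALL_DELAY apart, while calls that are naturally infrequent proceed immediately. A self-contained sketch of the same idea (the throttled_call helper is illustrative, not part of the module):

from datetime import datetime
from time import sleep

MIN_CALL_DELAY: float = 0.2  # minimum seconds between calls
_last_call_time = datetime.now()


def throttled_call() -> None:
    """Sleep only as long as needed to keep calls MIN_CALL_DELAY apart."""
    global _last_call_time
    now = datetime.now()
    required_delay = MIN_CALL_DELAY - (now - _last_call_time).total_seconds()
    if required_delay > 0:
        sleep(required_delay)
    _last_call_time = now


for _ in range(3):
    throttled_call()  # iterations are spaced roughly 0.2s apart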
14 changes: 10 additions & 4 deletions summarize_issue.py
@@ -6,6 +6,7 @@
import logging
import os

+import requests
from atlassian import Jira  # type: ignore

from jiraissues import Issue
@@ -51,10 +52,15 @@ def main():
    jira = Jira(url=os.environ["JIRA_URL"], token=os.environ["JIRA_TOKEN"])

    issue = Issue(jira, args.jira_issue_key)
-    out = summarize_issue(
-        issue, regenerate=regenerate, max_depth=max_depth, send_updates=send_updates
-    )
-    print(out)
+    try:
+        out = summarize_issue(
+            issue, regenerate=regenerate, max_depth=max_depth, send_updates=send_updates
+        )
+        print(out)
+    except requests.exceptions.HTTPError as error:
+        logging.error("HTTPError exception: %s", error.response.reason)
+    except requests.exceptions.ReadTimeout as error:
+        logging.error("ReadTimeout exception: %s", error, exc_info=True)


if __name__ == "__main__":
24 changes: 6 additions & 18 deletions summarizer.py
@@ -5,7 +5,7 @@
import os
import textwrap
from datetime import UTC, datetime
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

from atlassian import Jira  # type: ignore
from genai import Client, Credentials
Expand All @@ -19,7 +19,7 @@
from langchain_core.language_models import LLM

import text_wrapper
-from jiraissues import Issue, Myself, RelatedIssue, issue_cache
+from jiraissues import Issue, RelatedIssue, get_self, issue_cache

_logger = logging.getLogger(__name__)

@@ -44,18 +44,6 @@

_wrapper = text_wrapper.TextWrapper(SUMMARY_START_MARKER, SUMMARY_END_MARKER)

-_self: Optional[Myself] = None
-
-
-def self(client: Jira) -> Myself:
-    """
-    Caching function for the Myself object.
-    """
-    global _self  # pylint: disable=global-statement
-    if _self is None:
-        _self = Myself(client)
-    return _self
-

# pylint: disable=too-many-locals
def summarize_issue(
@@ -209,9 +197,9 @@ def summary_last_updated(issue: Issue) -> datetime:
        return last_update

    for change in issue.changelog:
-        if change.author == self(issue.client).display_name and "Status Summary" in [
-            chg.field for chg in change.changes
-        ]:
+        if change.author == get_self(
+            issue.client
+        ).display_name and "Status Summary" in [chg.field for chg in change.changes]:
            last_update = max(last_update, change.created)

    return last_update
@@ -338,7 +326,7 @@ def get_issues_to_summarize(
"""
# The time format for the query needs to be in the local timezone of the
# user, so we need to convert
user_zi = self(client).tzinfo
user_zi = get_self(client).tzinfo
since_string = since.astimezone(user_zi).strftime("%Y-%m-%d %H:%M")
updated_issues = client.jql(
f"labels = '{SUMMARY_ALLOWED_LABEL}' and updated >= '{since_string}' ORDER BY updated DESC",
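With this change, the per-module self() cache moves to jiraissues.py as get_self(), so summarizer.py and the new estimator.py share one cached Myself lookup instead of each fetching its own. A short usage sketch (assumes JIRA_URL and JIRA_TOKEN are set, as in the scripts above):

import os

from atlassian import Jira  # type: ignore

from jiraissues import get_self

jira = Jira(url=os.environ["JIRA_URL"], token=os.environ["JIRA_TOKEN"])
me = get_self(jira)  # first call fetches and caches the Myself object
assert get_self(jira) is me  # subsequent calls reuse the cached instance
user_zi = me.tzinfo  # timezone used to localize JQL timestamps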

