Skip to content

Commit

Permalink
Use Session object for requests in client methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
fkdosilovic committed Nov 30, 2024
1 parent 4d5e24b commit a63bae6
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions src/judge0/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
self.endpoint = endpoint
self.auth_headers = auth_headers
self.retry_strategy = retry_strategy
self.session = requests.Session()

# TODO: Should be handled differently.
try:
Expand All @@ -33,46 +34,49 @@ def __init__(
"review your authentication credentials."
) from e

def __del__(self):
self.session.close()

@handle_too_many_requests_error_for_preview_client
def get_about(self) -> dict:
r = requests.get(
response = self.session.get(
f"{self.endpoint}/about",
headers=self.auth_headers,
)
r.raise_for_status()
return r.json()
response.raise_for_status()
return response.json()

@handle_too_many_requests_error_for_preview_client
def get_config_info(self) -> dict:
r = requests.get(
response = self.session.get(
f"{self.endpoint}/config_info",
headers=self.auth_headers,
)
r.raise_for_status()
return r.json()
response.raise_for_status()
return response.json()

@handle_too_many_requests_error_for_preview_client
def get_language(self, language_id) -> dict:
request_url = f"{self.endpoint}/languages/{language_id}"
r = requests.get(request_url, headers=self.auth_headers)
r.raise_for_status()
return r.json()
response = self.session.get(request_url, headers=self.auth_headers)
response.raise_for_status()
return response.json()

@handle_too_many_requests_error_for_preview_client
def get_languages(self) -> list[dict]:
request_url = f"{self.endpoint}/languages"
r = requests.get(request_url, headers=self.auth_headers)
r.raise_for_status()
return r.json()
response = self.session.get(request_url, headers=self.auth_headers)
response.raise_for_status()
return response.json()

@handle_too_many_requests_error_for_preview_client
def get_statuses(self) -> list[dict]:
r = requests.get(
response = self.session.get(
f"{self.endpoint}/statuses",
headers=self.auth_headers,
)
r.raise_for_status()
return r.json()
response.raise_for_status()
return response.json()

@property
def version(self):
Expand Down Expand Up @@ -123,15 +127,15 @@ def create_submission(self, submission: Submission) -> Submission:

body = submission.as_body(self)

resp = requests.post(
response = self.session.post(
f"{self.endpoint}/submissions",
json=body,
params=params,
headers=self.auth_headers,
)
resp.raise_for_status()
response.raise_for_status()

submission.set_attributes(resp.json())
submission.set_attributes(response.json())

return submission

Expand Down Expand Up @@ -169,14 +173,14 @@ def get_submission(
else:
params["fields"] = "*"

resp = requests.get(
response = self.session.get(
f"{self.endpoint}/submissions/{submission.token}",
params=params,
headers=self.auth_headers,
)
resp.raise_for_status()
response.raise_for_status()

submission.set_attributes(resp.json())
submission.set_attributes(response.json())

return submission

Expand Down Expand Up @@ -209,15 +213,15 @@ def create_submissions(self, submissions: Submissions) -> Submissions:

submissions_body = [submission.as_body(self) for submission in submissions]

resp = requests.post(
response = self.session.post(
f"{self.endpoint}/submissions/batch",
headers=self.auth_headers,
params={"base64_encoded": "true"},
json={"submissions": submissions_body},
)
resp.raise_for_status()
response.raise_for_status()

for submission, attrs in zip(submissions, resp.json()):
for submission, attrs in zip(submissions, response.json()):
submission.set_attributes(attrs)

return submissions
Expand Down Expand Up @@ -262,20 +266,22 @@ def get_submissions(
tokens = ",".join(submission.token for submission in submissions)
params["tokens"] = tokens

resp = requests.get(
response = self.session.get(
f"{self.endpoint}/submissions/batch",
params=params,
headers=self.auth_headers,
)
resp.raise_for_status()
response.raise_for_status()

for submission, attrs in zip(submissions, resp.json()["submissions"]):
for submission, attrs in zip(submissions, response.json()["submissions"]):
submission.set_attributes(attrs)

return submissions


class ATD(Client):
"""Base class for all AllThingsDev clients."""

API_KEY_ENV = "JUDGE0_ATD_API_KEY"

def __init__(self, endpoint, host_header_value, api_key, **kwargs):
Expand Down Expand Up @@ -439,6 +445,8 @@ def get_submissions(


class Rapid(Client):
"""Base class for all RapidAPI clients."""

API_KEY_ENV = "JUDGE0_RAPID_API_KEY"

def __init__(self, endpoint, host_header_value, api_key, **kwargs):
Expand Down Expand Up @@ -482,6 +490,8 @@ def __init__(self, api_key, **kwargs):


class Sulu(Client):
"""Base class for all Sulu clients."""

API_KEY_ENV = "JUDGE0_SULU_API_KEY"

def __init__(self, endpoint, api_key=None, **kwargs):
Expand Down

0 comments on commit a63bae6

Please sign in to comment.