From 7253a7f34760ffccabfe3a3e08ffb37c6043d418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Mon, 4 Nov 2024 22:51:30 +0100 Subject: [PATCH] Minor refactor of _create_default_ce_client. Use raise from when authentication fails in Client.__init__. --- src/judge0/__init__.py | 38 ++++++++++---------------------------- src/judge0/clients.py | 24 +++++++++++++++++++----- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/src/judge0/__init__.py b/src/judge0/__init__.py index d31de23..a6e6f60 100644 --- a/src/judge0/__init__.py +++ b/src/judge0/__init__.py @@ -1,26 +1,14 @@ import os from .api import async_execute, execute, sync_execute, wait -from .clients import ( - ATDJudge0CE, - ATDJudge0ExtraCE, - Client, - RapidJudge0CE, - RapidJudge0ExtraCE, - SuluJudge0CE, - SuluJudge0ExtraCE, -) +from .clients import CE, Client, EXTRA_CE from .retry import MaxRetries, MaxWaitTime, RegularPeriodRetry from .submission import Submission __all__ = [ - ATDJudge0CE, - ATDJudge0ExtraCE, Client, - RapidJudge0CE, - RapidJudge0ExtraCE, - SuluJudge0CE, - SuluJudge0ExtraCE, + *CE, + *EXTRA_CE, Submission, RegularPeriodRetry, MaxRetries, @@ -40,19 +28,13 @@ def _create_default_ce_client(): except: # noqa: E722 pass - rapid_api_key = os.getenv("JUDGE0_RAPID_API_KEY") - sulu_api_key = os.getenv("JUDGE0_SULU_API_KEY") - atd_api_key = os.getenv("JUDGE0_ATD_API_KEY") - - if rapid_api_key is not None: - client = RapidJudge0CE(api_key=rapid_api_key) - elif sulu_api_key is not None: - client = SuluJudge0CE(api_key=sulu_api_key) - elif atd_api_key is not None: - client = ATDJudge0CE(api_key=atd_api_key) - else: - # TODO: Create SuluJudge0CE with default client. - client = None + client = None + for client_class in CE: + api_key = os.getenv(client_class.API_KEY_ENV) + if api_key is not None: + client = client_class(api_key) + + # TODO: If client is none, instantiate the default client. globals()["judge0_default_client"] = client diff --git a/src/judge0/clients.py b/src/judge0/clients.py index 310b0b3..81a2d06 100644 --- a/src/judge0/clients.py +++ b/src/judge0/clients.py @@ -1,20 +1,21 @@ -from typing import Iterable, Optional, Union +from typing import Iterable, Union import requests -from .retry import RegularPeriodRetry, RetryMechanism - from .submission import Submission class Client: + API_KEY_ENV = "JUDGE0_API_KEY" + def __init__(self, endpoint, auth_headers) -> None: self.endpoint = endpoint self.auth_headers = auth_headers + try: self.languages = {lang["id"]: lang for lang in self.get_languages()} - except Exception: - raise RuntimeError("Client authentication failed.") + except Exception as e: + raise RuntimeError("Client authentication failed.") from e def get_about(self) -> dict: r = requests.get( @@ -172,6 +173,8 @@ def get_submissions( class ATD(Client): + API_KEY_ENV = "JUDGE0_ATD_API_KEY" + def __init__(self, endpoint, host_header_value, api_key): self.api_key = api_key super().__init__( @@ -323,6 +326,8 @@ def get_submissions( class Rapid(Client): + API_KEY_ENV = "JUDGE0_RAPID_API_KEY" + def __init__(self, endpoint, host_header_value, api_key): self.api_key = api_key super().__init__( @@ -359,6 +364,8 @@ def __init__(self, api_key): class Sulu(Client): + API_KEY_ENV = "JUDGE0_SULU_API_KEY" + def __init__(self, endpoint, api_key): self.api_key = api_key super().__init__(endpoint, {"Authorization": f"Bearer {api_key}"}) @@ -376,3 +383,10 @@ class SuluJudge0ExtraCE(Sulu): def __init__(self, api_key): super().__init__(self.DEFAULT_ENDPOINT, api_key=api_key) + + +DEFAULT_CE_CLIENT_CLASS = SuluJudge0CE +DEFAULT_EXTRA_CE_CLASS = SuluJudge0ExtraCE + +CE = [RapidJudge0CE, SuluJudge0CE, ATDJudge0CE] +EXTRA_CE = [RapidJudge0ExtraCE, SuluJudge0ExtraCE, ATDJudge0ExtraCE]