From 150568af7877e393d292a5272540af9be2bc4fbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?= Date: Tue, 8 Oct 2024 19:21:49 +0200 Subject: [PATCH] Add check in create_submission(s) for supported language. Update tests to use fixtures. --- src/judge0/clients.py | 29 ++++++++++++---- tests/conftest.py | 50 +++++++++++++++++++++++++++ tests/test_clients.py | 79 +++++++++++++++---------------------------- 3 files changed, 101 insertions(+), 57 deletions(-) create mode 100644 tests/conftest.py diff --git a/src/judge0/clients.py b/src/judge0/clients.py index 024d2ae..20df389 100644 --- a/src/judge0/clients.py +++ b/src/judge0/clients.py @@ -11,7 +11,7 @@ def __init__(self, endpoint, auth_headers, *, wait=False) -> None: self.wait = wait try: - self.languages = self.get_languages() + self.languages = {lang["id"]: lang for lang in self.get_languages()} except Exception: raise RuntimeError("Client authentication failed.") @@ -51,8 +51,17 @@ def get_statuses(self) -> list[dict]: r.raise_for_status() return r.json() + def is_language_supported(self, language_id: int) -> bool: + return language_id in self.languages + def create_submission(self, submission: Submission) -> None: - # TODO: check if client supports specified language_id + # Check if submission contains supported language. + if not self.is_language_supported(language_id=submission.language_id): + raise RuntimeError( + f"Client {type(self).__name__} does not support language with " + f"id {submission.language_id}!" + ) + params = { "base64_encoded": "true", "wait": str(self.wait).lower(), @@ -90,6 +99,14 @@ def get_submission(self, submission: Submission, *, fields=None) -> None: submission.set_attributes(resp.json()) def create_submissions(self, submissions: list[Submission]) -> None: + # Check if all submissions contain supported language. + for submission in submissions: + if not self.is_language_supported(language_id=submission.language_id): + raise RuntimeError( + f"Client {type(self).__name__} does not support language with " + f"id {submission.language_id}!" + ) + params = { "base64_encoded": "true", "wait": str(self.wait).lower(), @@ -195,11 +212,11 @@ def get_submission(self, submission: Submission, *, fields=None) -> None: self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT) return super().get_submission(submission, fields=fields) - def create_submissions(self, submissions: Submission) -> None: + def create_submissions(self, submissions: list[Submission]) -> None: self._update_endpoint_header(self.DEFAULT_CREATE_SUBMISSIONS_ENDPOINT) return super().create_submissions(submissions) - def get_submissions(self, submissions: Submission, *, fields=None) -> None: + def get_submissions(self, submissions: list[Submission], *, fields=None) -> None: self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT) return super().get_submissions(submissions, fields=fields) @@ -253,11 +270,11 @@ def get_submission(self, submission: Submission, *, fields=None): self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT) return super().get_submission(submission, fields=fields) - def create_submissions(self, submission: Submission) -> None: + def create_submissions(self, submission: list[Submission]) -> None: self._update_endpoint_header(self.DEFAULT_CREATE_SUBMISSIONS_ENDPOINT) return super().create_submissions(submission) - def get_submissions(self, submission: Submission, *, fields=None) -> None: + def get_submissions(self, submission: list[Submission], *, fields=None) -> None: self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT) return super().get_submissions(submission, fields=fields) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7d97a3d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,50 @@ +import os + +import pytest +from dotenv import load_dotenv + +from judge0 import clients + +load_dotenv() + + +@pytest.fixture(scope="session") +def atd_ce_client(): + api_key = os.getenv("ATD_API_KEY") + client = clients.ATDJudge0CE(api_key) + return client + + +@pytest.fixture(scope="session") +def atd_extra_ce_client(): + api_key = os.getenv("ATD_API_KEY") + client = clients.ATDJudge0ExtraCE(api_key) + return client + + +@pytest.fixture(scope="session") +def rapid_ce_client(): + api_key = os.getenv("RAPID_API_KEY") + client = clients.RapidJudge0CE(api_key) + return client + + +@pytest.fixture(scope="session") +def rapid_extra_ce_client(): + api_key = os.getenv("RAPID_API_KEY") + client = clients.RapidJudge0ExtraCE(api_key) + return client + + +@pytest.fixture(scope="session") +def sulu_ce_client(): + api_key = os.getenv("SULU_API_KEY") + client = clients.SuluJudge0CE(api_key) + return client + + +@pytest.fixture(scope="session") +def sulu_extra_ce_client(): + api_key = os.getenv("SULU_API_KEY") + client = clients.SuluJudge0ExtraCE(api_key) + return client diff --git a/tests/test_clients.py b/tests/test_clients.py index c592aac..5368df6 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,69 +1,46 @@ -import os - import pytest -from dotenv import load_dotenv - -from judge0 import clients - -load_dotenv() - -def test_atd_ce_client(): - api_key = os.getenv("ATD_API_KEY") +DEFAULT_CLIENTS = [ + "atd_ce_client", + "atd_extra_ce_client", + "rapid_ce_client", + "rapid_extra_ce_client", + "sulu_ce_client", + "sulu_extra_ce_client", +] - client = clients.ATDJudge0CE(api_key) +@pytest.mark.parametrize("client", DEFAULT_CLIENTS) +def test_get_about(client, request): + client = request.getfixturevalue(client) client.get_about() - client.get_config_info() - client.get_languages() - client.get_statuses() - -def test_atd_extra_ce_client(): - api_key = os.getenv("ATD_API_KEY") - client = clients.ATDJudge0ExtraCE(api_key) - client.get_about() +@pytest.mark.parametrize("client", DEFAULT_CLIENTS) +def test_get_config_info(client, request): + client = request.getfixturevalue(client) client.get_config_info() - client.get_languages() - client.get_statuses() - -def test_rapid_ce_client(): - api_key = os.getenv("RAPID_API_KEY") - client = clients.RapidJudge0CE(api_key) - client.get_about() - client.get_config_info() +@pytest.mark.parametrize("client", DEFAULT_CLIENTS) +def test_get_languages(client, request): + client = request.getfixturevalue(client) client.get_languages() - client.get_statuses() -def test_rapid_extra_ce_client(): - api_key = os.getenv("RAPID_API_KEY") - client = clients.RapidJudge0ExtraCE(api_key) - - client.get_about() - client.get_config_info() - client.get_languages() +@pytest.mark.parametrize("client", DEFAULT_CLIENTS) +def test_get_statuses(client, request): + client = request.getfixturevalue(client) client.get_statuses() -def test_sulu_ce_client(): - api_key = os.getenv("SULU_API_KEY") - client = clients.SuluJudge0CE(api_key) - - client.get_about() - client.get_config_info() - client.get_languages() - client.get_statuses() - +@pytest.mark.parametrize("client", DEFAULT_CLIENTS) +def test_is_language_supported_multi_file_submission(client, request): + client = request.getfixturevalue(client) + assert client.is_language_supported(89) -def test_sulu_extra_ce_client(): - api_key = os.getenv("SULU_API_KEY") - client = clients.SuluJudge0ExtraCE(api_key) - client.get_about() - client.get_config_info() - client.get_languages() - client.get_statuses() +@pytest.mark.parametrize("client", DEFAULT_CLIENTS) +def test_is_language_supported_non_valid_lang_id(client, request): + client = request.getfixturevalue(client) + assert not client.is_language_supported(-1)