Skip to content

Commit

Permalink
Add check in create_submission(s) for supported language. Update test…
Browse files Browse the repository at this point in the history
…s to use fixtures.
  • Loading branch information
fkdosilovic committed Oct 8, 2024
1 parent 5efa149 commit 150568a
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 57 deletions.
29 changes: 23 additions & 6 deletions src/judge0/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, endpoint, auth_headers, *, wait=False) -> None:
self.wait = wait

try:
self.languages = self.get_languages()
self.languages = {lang["id"]: lang for lang in self.get_languages()}
except Exception:
raise RuntimeError("Client authentication failed.")

Expand Down Expand Up @@ -51,8 +51,17 @@ def get_statuses(self) -> list[dict]:
r.raise_for_status()
return r.json()

def is_language_supported(self, language_id: int) -> bool:
return language_id in self.languages

def create_submission(self, submission: Submission) -> None:
# TODO: check if client supports specified language_id
# Check if submission contains supported language.
if not self.is_language_supported(language_id=submission.language_id):
raise RuntimeError(
f"Client {type(self).__name__} does not support language with "
f"id {submission.language_id}!"
)

params = {
"base64_encoded": "true",
"wait": str(self.wait).lower(),
Expand Down Expand Up @@ -90,6 +99,14 @@ def get_submission(self, submission: Submission, *, fields=None) -> None:
submission.set_attributes(resp.json())

def create_submissions(self, submissions: list[Submission]) -> None:
# Check if all submissions contain supported language.
for submission in submissions:
if not self.is_language_supported(language_id=submission.language_id):
raise RuntimeError(
f"Client {type(self).__name__} does not support language with "
f"id {submission.language_id}!"
)

params = {
"base64_encoded": "true",
"wait": str(self.wait).lower(),
Expand Down Expand Up @@ -195,11 +212,11 @@ def get_submission(self, submission: Submission, *, fields=None) -> None:
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT)
return super().get_submission(submission, fields=fields)

def create_submissions(self, submissions: Submission) -> None:
def create_submissions(self, submissions: list[Submission]) -> None:
self._update_endpoint_header(self.DEFAULT_CREATE_SUBMISSIONS_ENDPOINT)
return super().create_submissions(submissions)

def get_submissions(self, submissions: Submission, *, fields=None) -> None:
def get_submissions(self, submissions: list[Submission], *, fields=None) -> None:
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT)
return super().get_submissions(submissions, fields=fields)

Expand Down Expand Up @@ -253,11 +270,11 @@ def get_submission(self, submission: Submission, *, fields=None):
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT)
return super().get_submission(submission, fields=fields)

def create_submissions(self, submission: Submission) -> None:
def create_submissions(self, submission: list[Submission]) -> None:
self._update_endpoint_header(self.DEFAULT_CREATE_SUBMISSIONS_ENDPOINT)
return super().create_submissions(submission)

def get_submissions(self, submission: Submission, *, fields=None) -> None:
def get_submissions(self, submission: list[Submission], *, fields=None) -> None:
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT)
return super().get_submissions(submission, fields=fields)

Expand Down
50 changes: 50 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os

import pytest
from dotenv import load_dotenv

from judge0 import clients

load_dotenv()


@pytest.fixture(scope="session")
def atd_ce_client():
api_key = os.getenv("ATD_API_KEY")
client = clients.ATDJudge0CE(api_key)
return client


@pytest.fixture(scope="session")
def atd_extra_ce_client():
api_key = os.getenv("ATD_API_KEY")
client = clients.ATDJudge0ExtraCE(api_key)
return client


@pytest.fixture(scope="session")
def rapid_ce_client():
api_key = os.getenv("RAPID_API_KEY")
client = clients.RapidJudge0CE(api_key)
return client


@pytest.fixture(scope="session")
def rapid_extra_ce_client():
api_key = os.getenv("RAPID_API_KEY")
client = clients.RapidJudge0ExtraCE(api_key)
return client


@pytest.fixture(scope="session")
def sulu_ce_client():
api_key = os.getenv("SULU_API_KEY")
client = clients.SuluJudge0CE(api_key)
return client


@pytest.fixture(scope="session")
def sulu_extra_ce_client():
api_key = os.getenv("SULU_API_KEY")
client = clients.SuluJudge0ExtraCE(api_key)
return client
79 changes: 28 additions & 51 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,46 @@
import os

import pytest
from dotenv import load_dotenv

from judge0 import clients

load_dotenv()


def test_atd_ce_client():
api_key = os.getenv("ATD_API_KEY")
DEFAULT_CLIENTS = [
"atd_ce_client",
"atd_extra_ce_client",
"rapid_ce_client",
"rapid_extra_ce_client",
"sulu_ce_client",
"sulu_extra_ce_client",
]

client = clients.ATDJudge0CE(api_key)

@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
def test_get_about(client, request):
client = request.getfixturevalue(client)
client.get_about()
client.get_config_info()
client.get_languages()
client.get_statuses()


def test_atd_extra_ce_client():
api_key = os.getenv("ATD_API_KEY")
client = clients.ATDJudge0ExtraCE(api_key)

client.get_about()
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
def test_get_config_info(client, request):
client = request.getfixturevalue(client)
client.get_config_info()
client.get_languages()
client.get_statuses()


def test_rapid_ce_client():
api_key = os.getenv("RAPID_API_KEY")
client = clients.RapidJudge0CE(api_key)

client.get_about()
client.get_config_info()
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
def test_get_languages(client, request):
client = request.getfixturevalue(client)
client.get_languages()
client.get_statuses()


def test_rapid_extra_ce_client():
api_key = os.getenv("RAPID_API_KEY")
client = clients.RapidJudge0ExtraCE(api_key)

client.get_about()
client.get_config_info()
client.get_languages()
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
def test_get_statuses(client, request):
client = request.getfixturevalue(client)
client.get_statuses()


def test_sulu_ce_client():
api_key = os.getenv("SULU_API_KEY")
client = clients.SuluJudge0CE(api_key)

client.get_about()
client.get_config_info()
client.get_languages()
client.get_statuses()

@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
def test_is_language_supported_multi_file_submission(client, request):
client = request.getfixturevalue(client)
assert client.is_language_supported(89)

def test_sulu_extra_ce_client():
api_key = os.getenv("SULU_API_KEY")
client = clients.SuluJudge0ExtraCE(api_key)

client.get_about()
client.get_config_info()
client.get_languages()
client.get_statuses()
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
def test_is_language_supported_non_valid_lang_id(client, request):
client = request.getfixturevalue(client)
assert not client.is_language_supported(-1)

0 comments on commit 150568a

Please sign in to comment.