Skip to content

Commit

Permalink
creating unit test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelgreca committed Nov 12, 2024
1 parent 82aab67 commit 928b826
Show file tree
Hide file tree
Showing 9 changed files with 611 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
logs/
ipynb_checkpoints/
mlruns
mlartifacts
Expand Down
36 changes: 28 additions & 8 deletions src/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,44 @@ def load_feature(
@logger.catch
def download_dataset(
name: str,
new_name: str,
path: pathlib.Path,
send_to_aws: bool,
) -> None:
"""Download the dataset using Kaggle's API.
Args:
name (str): the dataset's name.
new_name (str): the dataset file's new name.
path (pathlib.Path): the path where the dataset will be stored locally.
send_to_aws (bool): whether the dataset will be sent to an AWS S3 bucket or not.
"""
kaggle_user = kaggle_credentials.KAGGLE_USERNAME
kaggle_key = kaggle_credentials.KAGGLE_KEY
path = '../data/'
os.environ["KAGGLE_USERNAME"] = kaggle_credentials.KAGGLE_USERNAME
os.environ["KAGGLE_KEY"] = kaggle_credentials.KAGGLE_KEY

logger.info(f"Downloading dataset {name} and saving into the folder {path}.")

# Downloading data using the Kaggle API through the terminal
os.system(f'export KAGGLE_USERNAME={kaggle_user}; export KAGGLE_KEY={kaggle_key};')
os.system(f'kaggle datasets download -d {name} -p {path} --unzip')
# os.system(f'export KAGGLE_USERNAME={kaggle_user}; export KAGGLE_KEY={kaggle_key};')
os.system(f'kaggle datasets download -d {name} --unzip')
os.system(
f'mv ObesityDataSet.csv {pathlib.Path.joinpath(path, new_name)}'
)

# Sending the dataset to the AWS S3 bucket
if aws_credentials.S3 != "YOUR_S3_BUCKET_URL":
send_dataset_to_s3()

if send_to_aws:
if aws_credentials.S3 != "YOUR_S3_BUCKET_URL":
send_dataset_to_s3(
file_path=path,
file_name=new_name,
)
else:
logger.warning(
"The S3 Bucket url was not specified in the 'credentials.yaml' file. " +
"Therefore, the dataset will not be send to S3 and it will be kept saved locally."
)

@logger.catch
def send_dataset_to_s3(
file_path: pathlib.Path,
file_name: str,
Expand All @@ -71,3 +89,5 @@ def send_dataset_to_s3(
aws_credentials.S3,
file_name,
)

os.remove(pathlib.Path.joinpath(file_path, file_name))
9 changes: 7 additions & 2 deletions src/model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,21 @@ def load(self) -> None:
logger.critical(f"Couldn't load the model using the flavor {model_settings.MODEL_FLAVOR}.")
raise NotImplementedError()

def predict(self, x: np.ndarray) -> np.ndarray:
def predict(self, x: np.ndarray, transform_to_str: bool = True) -> np.ndarray:
"""Uses the trained model to make a prediction on a given feature array.
Args:
x (np.ndarray): the features array.
transform_to_str (bool): whether to transform the prediction integer to
string or not. Defaults to True.
Returns:
np.ndarray: the predictions array.
"""
prediction = self.model.predict(x)
prediction = label_encoder.inverse_transform(prediction)

if transform_to_str:
prediction = label_encoder.inverse_transform(prediction)

logger.info(f"Prediction: {prediction}.")
return prediction
Empty file added tests/__init__.py
Empty file.
Empty file added tests/unit/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions tests/unit/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import json
from pathlib import Path
from typing import Dict

import requests

from src.config.model import model_settings
from src.config.settings import general_settings

# Read the project's code version from the VERSION file kept inside the
# research environment directory (one line, e.g. "1.0.0").
_version_file = Path.joinpath(general_settings.RESEARCH_ENVIRONMENT_PATH, "VERSION")

with open(_version_file, "r", encoding="utf-8") as version_fp:
    CODE_VERSION = version_fp.readline().strip()

def test_version_endpoint() -> None:
    """
    Unit case to test the API's version endpoint.

    Hits the locally running API and checks that the response carries both
    the model version and the code version, matching the configured values.
    """
    response = requests.get("http://127.0.0.1:8000/version", timeout=100)
    payload = json.loads(response.text)

    assert response.status_code == 200
    assert isinstance(payload, Dict)
    # Both version fields must be present and agree with the local config.
    for expected_key in ("model_version", "code_version"):
        assert expected_key in payload
    assert payload["model_version"] == model_settings.VERSION
    assert payload["code_version"] == CODE_VERSION

def test_inference_endpoint() -> None:
    """
    Unit case to test the API's inference endpoint.

    Sends one representative feature record to the locally running API and
    checks the returned prediction matches the expected class label.
    """
    # One sample record covering the obesity dataset's feature schema.
    sample = {
        "Age": 21,
        "CAEC": "Sometimes",
        "CALC": "no",
        "FAF": 0,
        "FCVC": 2,
        "Gender": "Female",
        "Height": 1.62,
        "MTRANS": "Public_Transportation",
        "SCC": "no",
        "SMOKE": "False",
        "TUE": 1,
        "Weight": 64
    }

    # NOTE(review): the prediction is requested via GET with a JSON body —
    # confirm the endpoint isn't meant to be POST (GET bodies are unusual).
    response = requests.get("http://127.0.0.1:8000/predict", json=sample, timeout=100)
    payload = json.loads(response.text)

    assert response.status_code == 200
    assert isinstance(payload, Dict)
    assert "predictions" in payload
    # A single record should yield a single, known class label.
    assert payload["predictions"] == [["Normal_Weight"]]
Loading

0 comments on commit 928b826

Please sign in to comment.