Skip to content

Commit

Permalink
🦄 refactor: Refactor code and add train_test_split function
Browse files Browse the repository at this point in the history
  • Loading branch information
AndPuQing committed Feb 25, 2024
1 parent 4e2969d commit 93387b9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
1 change: 0 additions & 1 deletion backend/app/app/db/init_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def init_db(session: Session) -> None:
is_active=False,
)
user = User.create(session, user_in)
logging.debug(f"User {user.email} created") # type: ignore

with open("/app/app/db/item.csv", "r") as file:
reader = csv.reader(file)
Expand Down
24 changes: 23 additions & 1 deletion backend/app/app/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Union

import gensim
import numpy as np
import pandas as pd
from celery import Task
from celery.utils.log import get_task_logger
Expand All @@ -17,7 +18,6 @@
from rectools.dataset import Dataset
from rectools.metrics import MAP, MeanInvUserFreq, Serendipity, calc_metrics
from rectools.models import ImplicitALSWrapperModel
from sklearn.model_selection import train_test_split
from sqlmodel import Session, col, select

from app import source
Expand Down Expand Up @@ -193,6 +193,28 @@ def byte_to_list_float(byte: bytes):
return list(struct.unpack("f" * (len(byte) // 4), byte))


def train_test_split(
    df: DataFrame, test_size: float = 0.2, random_state: int = 32
):
    """
    Split the dataset into random train and test sets.

    Rows are shuffled with a private, seeded random generator, so the
    split is reproducible for a given ``random_state`` and the
    process-wide ``np.random`` state is left untouched.

    Args:
        df (DataFrame): The dataset to split.
        test_size (float): Fraction of rows (between 0 and 1 inclusive)
            assigned to the test set. The row count is truncated
            towards zero.
        random_state (int): Seed for the shuffle, for reproducibility.
    Returns:
        tuple: ``(train, test)`` subsets of ``df``; together they contain
        every row of ``df`` exactly once.
    Raises:
        ValueError: If ``test_size`` is not within [0, 1].
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(
            f"test_size must be between 0 and 1, got {test_size!r}"
        )

    # A local RandomState produces the same permutation as seeding the
    # global generator would, but without reseeding np.random for every
    # other consumer in the process.
    rng = np.random.RandomState(random_state)
    shuffled_indices = rng.permutation(len(df))

    test_set_size = int(len(df) * test_size)
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    return df.iloc[train_indices], df.iloc[test_indices]


@celery_app.task(
acks_late=True,
base=DatabaseTask,
Expand Down
3 changes: 1 addition & 2 deletions backend/backend.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ COPY ./app/pyproject.toml ./app/poetry.lock* /app/

# Allow installing dev dependencies to run tests
ARG INSTALL_DEV=false
RUN if [ $INSTALL_DEV == 'true' ] ; then poetry install --no-root ; else poetry install --no-root --only main ; fi
RUN poetry run python -m pip install --no-use-pep517 rectools[lightfm]
RUN sh -c "if [ '$INSTALL_DEV' == 'true' ] ; then poetry install --no-root ; else poetry install --no-root ; fi"

COPY ./app /app

Expand Down
3 changes: 1 addition & 2 deletions backend/celeryworker.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ COPY ./app/pyproject.toml ./app/poetry.lock* /app/

# Allow installing dev dependencies to run tests
ARG INSTALL_DEV=false
RUN if [ $INSTALL_DEV == 'true' ] ; then poetry install --no-root ; else poetry install --no-root --only main ; fi
RUN poetry run python -m pip install --no-use-pep517 rectools[lightfm]
RUN sh -c "if [ '$INSTALL_DEV' == 'true' ] ; then poetry install --no-root ; else poetry install --no-root ; fi"

ENV C_FORCE_ROOT=1

Expand Down

0 comments on commit 93387b9

Please sign in to comment.