create datasetBlockages collection + block datasets #2933

Merged · 6 commits · Jun 20, 2024
Changes from all commits
1 change: 1 addition & 0 deletions libs/libcommon/src/libcommon/constants.py
@@ -14,6 +14,7 @@
QUEUE_COLLECTION_JOBS = "jobsBlue"
QUEUE_COLLECTION_PAST_JOBS = "pastJobs"
QUEUE_COLLECTION_LOCKS = "locks"
QUEUE_COLLECTION_DATASET_BLOCKAGES = "datasetBlockages"
QUEUE_MONGOENGINE_ALIAS = "queue"
QUEUE_TTL_SECONDS = 600 # 10 minutes
LOCK_TTL_SECONDS_NO_OWNER = 600 # 10 minutes
83 changes: 83 additions & 0 deletions libs/libcommon/src/libcommon/queue/dataset_blockages.py
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.

import types
from typing import Generic, TypeVar

from mongoengine import Document
from mongoengine.fields import DateTimeField, StringField
from mongoengine.queryset.queryset import QuerySet

from libcommon.constants import (
QUEUE_COLLECTION_DATASET_BLOCKAGES,
QUEUE_MONGOENGINE_ALIAS,
)
from libcommon.utils import get_datetime

# START monkey patching ### hack ###
# see https://github.com/sbdchd/mongo-types#install
U = TypeVar("U", bound=Document)


def no_op(self, _): # type: ignore
return self


QuerySet.__class_getitem__ = types.MethodType(no_op, QuerySet)


class QuerySetManager(Generic[U]):
def __get__(self, instance: object, cls: type[U]) -> QuerySet[U]:
return QuerySet(cls, cls._get_collection())


# END monkey patching ### hack ###

# delete the dataset blockage (i.e. potentially unblock it) after 1 hour
DATASET_BLOCKAGE_EXPIRE_AFTER_SECONDS = 1 * 60 * 60


class DatasetBlockageDocument(Document):
"""A decision to block (rate-limit) a dataset. The same dataset can be blocked several times.
It is released automatically when the blockage expires (DATASET_BLOCKAGE_EXPIRE_AFTER_SECONDS).

Args:
dataset (`str`): The blocked dataset.
blocked_at (`datetime`): The date the dataset has been blocked.
"""

meta = {
"collection": QUEUE_COLLECTION_DATASET_BLOCKAGES,
"db_alias": QUEUE_MONGOENGINE_ALIAS,
"indexes": [
("dataset"),
{
"name": "DATASET_BLOCKAGE_EXPIRE_AFTER_SECONDS",
"fields": ["blocked_at"],
"expireAfterSeconds": DATASET_BLOCKAGE_EXPIRE_AFTER_SECONDS,
},
],
}
dataset = StringField(required=True)
blocked_at = DateTimeField(required=True)

objects = QuerySetManager["DatasetBlockageDocument"]()


def block_dataset(dataset: str) -> None:
"""Create a dataset blockage in the mongoDB database, at the current time.

Args:
dataset (`str`): The dataset to block.
"""
DatasetBlockageDocument(dataset=dataset, blocked_at=get_datetime()).save()


def get_blocked_datasets() -> list[str]:
"""Return the list of blocked datasets."""
return DatasetBlockageDocument.objects().distinct("dataset")


def is_blocked(dataset: str) -> bool:
"""Return True if the dataset is blocked."""
return DatasetBlockageDocument.objects(dataset=dataset).count() > 0
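As a quick illustration (not part of the diff), the new helpers could be used as follows, assuming a queue MongoDB connection has already been set up (for example through `QueueMongoResource`, as in the tests at the bottom of this PR):

```python
# Hypothetical usage sketch, not included in this PR.
from libcommon.queue.dataset_blockages import (
    block_dataset,
    get_blocked_datasets,
    is_blocked,
)

block_dataset(dataset="user/heavy-dataset")   # record a blockage at the current time
assert is_blocked(dataset="user/heavy-dataset")
print(get_blocked_datasets())                 # ['user/heavy-dataset']
# The TTL index on blocked_at removes the document after
# DATASET_BLOCKAGE_EXPIRE_AFTER_SECONDS (1 hour), which unblocks the dataset.
```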
34 changes: 21 additions & 13 deletions libs/libcommon/src/libcommon/queue/jobs.py
@@ -29,13 +29,14 @@
QUEUE_MONGOENGINE_ALIAS,
)
from libcommon.dtos import FlatJobInfo, JobInfo, Priority, Status, WorkerSize
from libcommon.queue.dataset_blockages import get_blocked_datasets
from libcommon.queue.lock import lock, release_lock, release_locks
from libcommon.queue.metrics import (
decrease_metric,
increase_metric,
update_metrics_for_type,
)
from libcommon.queue.past_jobs import NegativeDurationError, create_past_job
from libcommon.queue.past_jobs import create_past_job
from libcommon.utils import get_datetime, inputs_to_string

# START monkey patching ### hack ###
@@ -113,6 +114,7 @@ class JobQueryFilters(TypedDict, total=False):
type__in: list[str]
difficulty__gt: int
difficulty__lte: int
dataset__nin: list[str]


PA_SCHEMA = Schema(
Expand Down Expand Up @@ -383,7 +385,8 @@ def _get_next_waiting_job_for_priority(
"""Get the next job in the queue for a given priority.

For a given priority, get the waiting job with the oldest creation date:
- among the datasets that still have no started job.
- among the datasets that are not rate-limited
- among the datasets that still have no started job
- if none, among the datasets that have the least started jobs:
- ensuring that the unicity_id field is unique among the started jobs.

@@ -404,6 +407,9 @@
f"Getting next waiting job for priority {priority}, blocked types: {job_types_blocked}, only types:"
f" {job_types_only}"
)
blocked_datasets = get_blocked_datasets()
logging.debug(f"Blocked datasets: {blocked_datasets}")

filters: JobQueryFilters = {}
if job_types_blocked:
filters["type__nin"] = job_types_blocked
@@ -413,14 +419,19 @@
filters["difficulty__gt"] = difficulty_min
if difficulty_max is not None and difficulty_max < DEFAULT_DIFFICULTY_MAX:
filters["difficulty__lte"] = difficulty_max
if blocked_datasets:
filters["dataset__nin"] = blocked_datasets
started_jobs = JobDocument.objects(status=Status.STARTED, **filters)
logging.debug(f"Number of started jobs: {started_jobs.count()}")
started_job_namespaces = [job.namespace for job in started_jobs.only("namespace")]
logging.debug(f"Started job namespaces: {started_job_namespaces}")

next_waiting_job = (
JobDocument.objects(
status=Status.WAITING, namespace__nin=set(started_job_namespaces), priority=priority, **filters
status=Status.WAITING,
namespace__nin=set(started_job_namespaces),
priority=priority,
**filters,
)
.order_by("+created_at")
.only("type", "dataset", "revision", "config", "split", "priority", "unicity_id")
@@ -436,6 +447,7 @@
# all the waiting jobs, if any, are for namespaces that already have started jobs.
#
# Let's:
# - exclude the blocked datasets
# - exclude the waiting jobs which unicity_id is already in a started job
# and, among the remaining waiting jobs, let's:
# - select the oldest waiting job for the namespace with the least number of started jobs
@@ -478,6 +490,7 @@ def get_next_waiting_job(
"""Get the next job in the queue.

Get the waiting job with the oldest creation date with the following criteria:
- among the datasets that are not rate-limited,
- among the highest priority jobs,
- among the datasets that still have no started job.
- if none, among the datasets that have the least started jobs:
@@ -704,16 +717,11 @@ def finish_job(self, job_id: str) -> Optional[Priority]:
return None
decrease_metric(job_type=job.type, status=job.status, difficulty=job.difficulty)
if job.started_at is not None:
try:
create_past_job(
dataset=job.dataset,
started_at=pytz.UTC.localize(job.started_at),
finished_at=get_datetime(),
)
except NegativeDurationError:
logging.warning(
f"job {job_id} has a negative duration. The duration is not saved in the past jobs collection."
)
create_past_job(
dataset=job.dataset,
started_at=pytz.UTC.localize(job.started_at),
finished_at=get_datetime(),
)
job_priority = job.priority
job.delete()
release_locks(owner=job_id)
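In short, job selection now skips rate-limited datasets before the existing priority and namespace logic. A condensed sketch of the selection path (using the `JobDocument`, `Status`, and `Priority` names from `libcommon.queue.jobs` and `libcommon.dtos`; the real `_get_next_waiting_job_for_priority` also filters on job types and difficulty, and falls back to namespaces with the fewest started jobs):

```python
# Simplified sketch of the selection logic touched by this PR; not a drop-in
# replacement for the real method in libcommon.queue.jobs.
from typing import Optional

from libcommon.dtos import Priority, Status
from libcommon.queue.dataset_blockages import get_blocked_datasets
from libcommon.queue.jobs import JobDocument


def sketch_next_waiting_job(priority: Priority) -> Optional[JobDocument]:
    filters: dict = {}
    blocked_datasets = get_blocked_datasets()
    if blocked_datasets:
        # rate-limited datasets are excluded from both started and waiting jobs
        filters["dataset__nin"] = blocked_datasets
    started_namespaces = [
        job.namespace
        for job in JobDocument.objects(status=Status.STARTED, **filters).only("namespace")
    ]
    return (
        JobDocument.objects(
            status=Status.WAITING,
            namespace__nin=set(started_namespaces),
            priority=priority,
            **filters,
        )
        .order_by("+created_at")
        .first()
    )
```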
2 changes: 1 addition & 1 deletion libs/libcommon/src/libcommon/queue/lock.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2022 The HuggingFace Authors.
# Copyright 2024 The HuggingFace Authors.

import contextlib
import json
2 changes: 1 addition & 1 deletion libs/libcommon/src/libcommon/queue/metrics.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2022 The HuggingFace Authors.
# Copyright 2024 The HuggingFace Authors.

import types
from typing import Generic, TypeVar
50 changes: 28 additions & 22 deletions libs/libcommon/src/libcommon/queue/past_jobs.py
@@ -6,14 +6,14 @@
from typing import Generic, TypeVar

from mongoengine import Document
from mongoengine.errors import ValidationError
from mongoengine.fields import DateTimeField, FloatField, StringField
from mongoengine.fields import DateTimeField, IntField, StringField
from mongoengine.queryset.queryset import QuerySet

from libcommon.constants import (
QUEUE_COLLECTION_PAST_JOBS,
QUEUE_MONGOENGINE_ALIAS,
)
from libcommon.queue.dataset_blockages import block_dataset, is_blocked

# START monkey patching ### hack ###
# see https://github.com/sbdchd/mongo-types#install
@@ -34,55 +34,61 @@ def __get__(self, instance: object, cls: type[U]) -> QuerySet[U]:

# END monkey patching ### hack ###

# delete the past jobs after 6 hours
PAST_JOB_EXPIRE_AFTER_SECONDS = 6 * 60 * 60
# we allow 10 hours of compute (parallel jobs) every hour, i.e. 10 dedicated machines
MAX_MACHINES = 10
# we look at the last 6 hours to decide to rate-limit a dataset
RATE_LIMIT_WINDOW_SECONDS = 6 * 60 * 60
# total jobs duration that triggers rate-limiting a dataset
DATASET_BLOCKAGE_THRESHOLD_SECONDS = MAX_MACHINES * RATE_LIMIT_WINDOW_SECONDS
# don't check for rate-limiting if the duration is super short
JOB_DURATION_CHECK_MIN_SECONDS = 5 * 60
# don't record short durations, because they will not have impact, but can clutter the collection
JOB_DURATION_MIN_SECONDS = 30


class PastJobDocument(Document):
"""A job in the mongoDB database
"""The duration of a job that has been completed.

Args:
dataset (`str`): The dataset on which to apply the job.
duration (`float`): The duration of the job, in seconds.
duration (`int`): The duration of the job, in seconds.
finished_at (`datetime`): The date the job has finished.
"""

meta = {
"collection": QUEUE_COLLECTION_PAST_JOBS,
"db_alias": QUEUE_MONGOENGINE_ALIAS,
"indexes": [
("dataset", "duration"),
{
"name": "PAST_JOB_EXPIRE_AFTER_SECONDS",
"fields": ["finished_at"],
"expireAfterSeconds": PAST_JOB_EXPIRE_AFTER_SECONDS,
"expireAfterSeconds": RATE_LIMIT_WINDOW_SECONDS,
},
],
}
dataset = StringField(required=True)
duration = FloatField(required=True, min_value=0.0)
duration = IntField(required=True, min_value=0)
finished_at = DateTimeField(required=True)

objects = QuerySetManager["PastJobDocument"]()


class NegativeDurationError(ValidationError):
pass


def create_past_job(dataset: str, started_at: datetime, finished_at: datetime) -> None:
"""Create a past job in the mongoDB database
"""Create a past job in the mongoDB database.

After creating the entry, we check if it should be rate-limited (if it isn't yet), and if so, we block
the dataset.

Args:
dataset (`str`): The dataset on which to apply the job.
started_at (`datetime`): The date the job has started.
finished_at (`datetime`): The date the job has finished.

Raises:
ValidationError: If the duration is negative.
"""
duration = (finished_at - started_at).total_seconds()
try:
PastJobDocument(dataset=dataset, duration=duration, finished_at=finished_at).save()
except ValidationError as e:
raise NegativeDurationError("The duration of the job cannot be negative.") from e
duration = int((finished_at - started_at).total_seconds())
if duration < JOB_DURATION_MIN_SECONDS:
return
PastJobDocument(dataset=dataset, duration=duration, finished_at=finished_at).save()

if not is_blocked(dataset) and duration > JOB_DURATION_CHECK_MIN_SECONDS:
if PastJobDocument.objects(dataset=dataset).sum("duration") > DATASET_BLOCKAGE_THRESHOLD_SECONDS:
block_dataset(dataset)
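For reference, the arithmetic behind these constants: `DATASET_BLOCKAGE_THRESHOLD_SECONDS = MAX_MACHINES * RATE_LIMIT_WINDOW_SECONDS = 10 * 21,600 = 216,000` seconds, so a dataset is blocked once it has accumulated more than 60 hours of job time within the 6-hour window kept by the TTL index, i.e. the equivalent of 10 machines working on it non-stop. A small check using the values from the diff:

```python
# Worked arithmetic for the rate-limiting constants introduced above.
MAX_MACHINES = 10                              # "10 dedicated machines"
RATE_LIMIT_WINDOW_SECONDS = 6 * 60 * 60        # 21_600 s, also the past-jobs TTL
DATASET_BLOCKAGE_THRESHOLD_SECONDS = MAX_MACHINES * RATE_LIMIT_WINDOW_SECONDS

assert DATASET_BLOCKAGE_THRESHOLD_SECONDS == 216_000  # 60 hours of compute per 6-hour window
# Durations under JOB_DURATION_MIN_SECONDS (30 s) are never recorded, and the
# threshold is only checked after jobs longer than JOB_DURATION_CHECK_MIN_SECONDS (5 min)
# on datasets that are not already blocked.
```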
2 changes: 2 additions & 0 deletions libs/libcommon/src/libcommon/queue/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.

from .dataset_blockages import DatasetBlockageDocument
from .jobs import JobDocument
from .lock import Lock
from .metrics import JobTotalMetricDocument, WorkerSizeJobsCountDocument
@@ -15,3 +16,4 @@ def _clean_queue_database() -> None:
WorkerSizeJobsCountDocument.drop_collection() # type: ignore
Lock.drop_collection() # type: ignore
PastJobDocument.drop_collection() # type: ignore
DatasetBlockageDocument.drop_collection() # type: ignore
31 changes: 31 additions & 0 deletions libs/libcommon/tests/queue/test_dataset_blockages.py
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The HuggingFace Authors.

import pytest

from libcommon.queue.dataset_blockages import block_dataset, get_blocked_datasets, is_blocked
from libcommon.resources import QueueMongoResource


@pytest.fixture(autouse=True)
def queue_mongo_resource_autouse(queue_mongo_resource: QueueMongoResource) -> QueueMongoResource:
return queue_mongo_resource


@pytest.mark.parametrize(
"datasets,expected_datasets",
[
([], []),
(["dataset"], ["dataset"]),
(["dataset", "dataset"], ["dataset"]),
(["dataset1", "dataset2"], ["dataset1", "dataset2"]),
],
)
def test_dataset_blockage(datasets: list[str], expected_datasets: list[str]) -> None:
for dataset in datasets:
block_dataset(dataset=dataset)

assert sorted(get_blocked_datasets()) == sorted(expected_datasets)
for dataset in expected_datasets:
assert is_blocked(dataset=dataset)
assert not is_blocked(dataset="not_blocked_dataset")
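A complementary sketch (not part of this diff) of how the blocking path in `create_past_job` could be exercised in the same style; if added to this module it would reuse the autouse `queue_mongo_resource` fixture above, and the dataset name and duration are made up:

```python
# Hypothetical test sketch, not included in this PR: a single job long enough
# to exceed the threshold should get the dataset blocked.
from datetime import timedelta

from libcommon.queue.dataset_blockages import is_blocked
from libcommon.queue.past_jobs import DATASET_BLOCKAGE_THRESHOLD_SECONDS, create_past_job
from libcommon.utils import get_datetime


def test_long_job_blocks_dataset() -> None:
    finished_at = get_datetime()
    started_at = finished_at - timedelta(seconds=DATASET_BLOCKAGE_THRESHOLD_SECONDS + 1)
    create_past_job(dataset="dataset", started_at=started_at, finished_at=finished_at)
    assert is_blocked(dataset="dataset")
```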