From 14857c010f9f25a986f9dc9b7e0a12f62b2792f6 Mon Sep 17 00:00:00 2001
From: nazarfil
Date: Tue, 9 Jul 2024 08:50:23 +0200
Subject: [PATCH] feat(Dataset): add dataset file snapshot model

---
 hexa/datasets/admin.py               | 14 ++++++-
 hexa/datasets/graphql/schema.graphql | 25 +++++++++++-
 hexa/datasets/models.py              | 60 ++++++++++++++++++++++++++++
 hexa/datasets/permissions.py         | 16 +++++++-
 hexa/datasets/queue.py               | 49 +++++++++++++++++++----
 hexa/datasets/schema/mutations.py    |  8 ++------
 hexa/datasets/schema/queries.py      | 23 ++++++++++-
 hexa/files/basefs.py                 |  4 ++
 hexa/files/gcp.py                    |  6 +++
 hexa/files/s3.py                     |  7 +++-
 10 files changed, 194 insertions(+), 18 deletions(-)

diff --git a/hexa/datasets/admin.py b/hexa/datasets/admin.py
index 7eb704a1e..43343064b 100644
--- a/hexa/datasets/admin.py
+++ b/hexa/datasets/admin.py
@@ -1,6 +1,12 @@
 from django.contrib import admin
 
-from .models import Dataset, DatasetLink, DatasetVersion, DatasetVersionFile
+from .models import (
+    Dataset,
+    DatasetFileSnapshot,
+    DatasetLink,
+    DatasetVersion,
+    DatasetVersionFile,
+)
 
 
 @admin.register(Dataset)
@@ -26,6 +32,12 @@ class DatasetVersionObjectAdmin(admin.ModelAdmin):
     list_filter = ("dataset_version__dataset", "created_by")
 
 
+@admin.register(DatasetFileSnapshot)
+class DatasetFileSnapshotAdmin(admin.ModelAdmin):
+    list_display = ("filename", "dataset_version_file")
+    list_filter = ("dataset_version_file__dataset_version__dataset", "created_by")
+
+
 @admin.register(DatasetLink)
 class DatasetLinkAdmin(admin.ModelAdmin):
     list_display = ("dataset", "workspace", "created_at", "created_by")
diff --git a/hexa/datasets/graphql/schema.graphql b/hexa/datasets/graphql/schema.graphql
index 6679ec85b..f4941fb29 100644
--- a/hexa/datasets/graphql/schema.graphql
+++ b/hexa/datasets/graphql/schema.graphql
@@ -405,11 +405,30 @@ type PinDatasetResult {
   errors: [PinDatasetError!]!
 }
 
+input CreateDatasetFileSnapshotInput {
+  fileId: String!
+}
+
+type DatasetFileSnapshot {
+  uri: String!
+  createdBy: String!
+  datasetVersionFile: DatasetVersionFile
+}
+
+type CreateDatasetFileSnapshotResult {
+  datasetFileSnapshot: DatasetFileSnapshot
+  success: Boolean!
+  errors: [PrepareVersionFileDownloadError!]!
+}
+
+
 extend type Query {
   "Get a dataset by its ID."
   dataset(id: ID!): Dataset
   "Get a dataset by its slug."
   datasetVersion(id: ID!): DatasetVersion
+  "Get a dataset file snapshot by its ID or by file ID."
+  datasetFileSnapshot(id: ID, fileId: ID): DatasetFileSnapshot
   "Get a dataset link by its id."
   datasetLink(id: ID!): DatasetLink
   "Get a dataset link by its slug."
@@ -418,6 +437,7 @@ extend type Query {
   datasets(query: String, page: Int = 1, perPage: Int = 15): DatasetPage!
 }
 
+
 extend type Mutation {
   "Create a new dataset."
   createDataset(input: CreateDatasetInput!): CreateDatasetResult! @loginRequired
@@ -431,6 +451,8 @@ extend type Mutation {
   deleteDatasetVersion(input: DeleteDatasetVersionInput!): DeleteDatasetVersionResult! @loginRequired
   "Create a new file in a dataset version."
   createDatasetVersionFile(input: CreateDatasetVersionFileInput!): CreateDatasetVersionFileResult! @loginRequired
+  "Create a snapshot of a dataset version file."
+  createDatasetVersionFileSnapshot(input: CreateDatasetFileSnapshotInput!): CreateDatasetFileSnapshotResult! @loginRequired
   "Prepare to download a file in a dataset version."
   prepareVersionFileDownload(input: PrepareVersionFileDownloadInput!): PrepareVersionFileDownloadResult! @loginRequired
   "Link a dataset with a workspace."
@@ -439,4 +461,5 @@ extend type Mutation {
   deleteDatasetLink(input: DeleteDatasetLinkInput!): DeleteDatasetLinkResult! @loginRequired
   "Pin or unpin a dataset for a workspace."
   pinDataset(input: PinDatasetInput!): PinDatasetResult! @loginRequired
-}
\ No newline at end of file
+}
+
diff --git a/hexa/datasets/models.py b/hexa/datasets/models.py
index 7dd91b2cb..5742d9efd 100644
--- a/hexa/datasets/models.py
+++ b/hexa/datasets/models.py
@@ -255,6 +255,66 @@ class Meta:
         ordering = ["uri"]
 
 
+class DatasetFileSnapshotQuerySet(BaseQuerySet):
+    def filter_for_user(self, user: AnonymousUser | User):
+        return self._filter_for_user_and_query_object(
+            user,
+            models.Q(
+                dataset_version_file__dataset_version__dataset__in=Dataset.objects.filter_for_user(
+                    user
+                )
+            ),
+            return_all_if_superuser=False,
+        )
+
+
+class DatasetFileSnapshotManager(models.Manager):
+    def create_if_has_perm(
+        self,
+        principal: User,
+        dataset_version_file: DatasetVersionFile,
+        *,
+        uri: str,
+    ):
+        from hexa.pipelines.authentication import PipelineRunUser
+
+        if isinstance(principal, PipelineRunUser):
+            if (
+                principal.pipeline_run.pipeline.workspace
+                != dataset_version_file.dataset_version.dataset.workspace
+            ):
+                raise PermissionDenied
+        elif not principal.has_perm(
+            "datasets.create_dataset_version_file_snapshot", dataset_version_file
+        ):
+            raise PermissionDenied
+
+        created_by = principal if not isinstance(principal, PipelineRunUser) else None
+        return self.create(
+            dataset_version_file=dataset_version_file,
+            uri=uri,
+            created_by=created_by,
+        )
+
+
+class DatasetFileSnapshot(Base):
+    uri = models.TextField(null=False, blank=False, unique=True)
+    created_by = models.ForeignKey(User, null=True, on_delete=models.SET_NULL)
+    dataset_version_file = models.ForeignKey(
+        DatasetVersionFile,
+        null=False,
+        blank=False,
+        on_delete=models.CASCADE,
+        related_name="snapshots",
+    )
+
+    objects = DatasetFileSnapshotManager.from_queryset(DatasetFileSnapshotQuerySet)()
+
+    @property
+    def filename(self):
+        return self.uri.split("/")[-1]
+
+
 class DatasetLinkQuerySet(BaseQuerySet):
     def filter_for_user(self, user: AnonymousUser | User):
         # FIXME: Use a generic permission system instead of differencing between User and PipelineRunUser
diff --git a/hexa/datasets/permissions.py b/hexa/datasets/permissions.py
index 0a6628f42..dbdc60880 100644
--- a/hexa/datasets/permissions.py
+++ b/hexa/datasets/permissions.py
@@ -1,4 +1,9 @@
-from hexa.datasets.models import Dataset, DatasetLink, DatasetVersion
+from hexa.datasets.models import (
+    Dataset,
+    DatasetLink,
+    DatasetVersion,
+    DatasetVersionFile,
+)
 from hexa.user_management.models import User
 from hexa.workspaces.models import (
     Workspace,
@@ -105,3 +110,12 @@ def create_dataset_version_file(principal: User, dataset_version: DatasetVersion
         return False
 
     return create_dataset_version(principal, dataset_version.dataset)
+
+
+def create_dataset_version_file_snapshot(
+    principal: User, dataset_version_file: DatasetVersionFile
+):
+    # Snapshots are only allowed for files in the dataset's latest version
+    if dataset_version_file.dataset_version != dataset_version_file.dataset_version.dataset.latest_version:
+        return False
+    return create_dataset_version_file(principal, dataset_version_file.dataset_version)
diff --git a/hexa/datasets/queue.py b/hexa/datasets/queue.py
index cb1b0fe8e..1eecde278 100644
--- a/hexa/datasets/queue.py
+++ b/hexa/datasets/queue.py
@@ -1,16 +1,51 @@
+import os.path
 from logging import getLogger
 
 from dpq.queue import AtLeastOnceQueue
 
-from hexa.datasets.models import DatasetSnapshotJob
+from hexa.datasets.models import (
+    DatasetFileSnapshot,
+    DatasetSnapshotJob,
+    DatasetVersionFile,
+)
+from hexa.files.api import get_storage
+from hexa.user_management.models import User
 
 logger = getLogger(__name__)
 
-
-def create_dataset_snnapshot_task(queue: AtLeastOnceQueue, job: DatasetSnapshotJob):
-    # TODO: imlpement ticket PATHWAYS-98 - extract data in background task
-    dataset_version_file_id = job.args["fileId"]
-    logger.info(f"Creating dataset version file {dataset_version_file_id}")
+DEFAULT_SNAPSHOT_LINES = 500
+
+
+def create_dataset_snapshot_task(queue: AtLeastOnceQueue, job: DatasetSnapshotJob):
+    try:
+        dataset_version_file_id = job.args["file_id"]
+        user_id = job.args["user_id"]
+        logger.info(
+            f"Creating dataset snapshot for version file {dataset_version_file_id}"
+        )
+        dataset_version_file = DatasetVersionFile.objects.get(
+            id=dataset_version_file_id
+        )
+        user = User.objects.get(id=user_id)
+
+        storage = get_storage()
+        # dataset_version_file.uri is "<bucket_name>/<object_key>"
+        bucket_name, _, object_key = dataset_version_file.uri.partition("/")
+        dataset_snapshot = storage.read_object_lines(bucket_name, object_key, DEFAULT_SNAPSHOT_LINES)
+        filename, extension = os.path.splitext(object_key)
+        snapshot_key = f"{filename}-snapshot{extension}"
+        upload_uri = f"{bucket_name}/{snapshot_key}"
+        storage.upload_object_from_string(bucket_name, snapshot_key, "\n".join(dataset_snapshot))
+
+        logger.info(
+            f"Uploaded dataset snapshot to {upload_uri} for file {dataset_version_file_id}"
+        )
+        DatasetFileSnapshot.objects.create_if_has_perm(
+            principal=user, dataset_version_file=dataset_version_file, uri=upload_uri
+        )
+        logger.info(f"Dataset snapshot created for file {dataset_version_file_id}")
+    except Exception:
+        logger.exception("Failed to create dataset snapshot")
 
 
 class DatasetSnapshotQueue(AtLeastOnceQueue):
@@ -19,7 +54,7 @@ class DatasetSnapshotQueue(AtLeastOnceQueue):
 
 dataset_snapshot_queue = DatasetSnapshotQueue(
     tasks={
-        "create_snapshot": create_dataset_snnapshot_task,
+        "create_snapshot": create_dataset_snapshot_task,
     },
     notify_channel="dataset_snapshot_queue",
 )
diff --git a/hexa/datasets/schema/mutations.py b/hexa/datasets/schema/mutations.py
index 843d39394..9831b4bd7 100644
--- a/hexa/datasets/schema/mutations.py
+++ b/hexa/datasets/schema/mutations.py
@@ -209,9 +209,5 @@ def resolve_create_version_file(_, info, **kwargs):
     dataset_snapshot_queue.enqueue(
-        {
-            "create_snapshot",
-            {
-                "file_id": str(file.id),
-            },
-        }
+        "create_snapshot",
+        {"file_id": str(file.id), "user_id": str(request.user.id)},
     )
     return {
diff --git a/hexa/datasets/schema/queries.py b/hexa/datasets/schema/queries.py
index 427c3e1fc..06b9657ce 100644
--- a/hexa/datasets/schema/queries.py
+++ b/hexa/datasets/schema/queries.py
@@ -2,7 +2,12 @@
 
 from hexa.core.graphql import result_page
 
-from ..models import Dataset, DatasetLink, DatasetVersion
+from ..models import (
+    Dataset,
+    DatasetFileSnapshot,
+    DatasetLink,
+    DatasetVersion,
+)
 
 datasets_queries = QueryType()
 
@@ -37,6 +42,22 @@ def resolve_dataset_version(_, info, **kwargs):
     return None
 
 
+@datasets_queries.field("datasetFileSnapshot")
+def resolve_dataset_file_snapshot(_, info, **kwargs):
+    request = info.context["request"]
+    try:
+        if kwargs.get("fileId"):
+            return DatasetFileSnapshot.objects.filter_for_user(request.user).get(
+                dataset_version_file=kwargs["fileId"]
+            )
+        else:
+            return DatasetFileSnapshot.objects.filter_for_user(request.user).get(
+                id=kwargs["id"]
+            )
+    except DatasetFileSnapshot.DoesNotExist:
+        return None
+
+
 @datasets_queries.field("datasetLink")
 def resolve_dataset_link(_, info, **kwargs):
     request = info.context["request"]
diff --git a/hexa/files/basefs.py b/hexa/files/basefs.py
index 890878ac8..b64aac483 100644
--- a/hexa/files/basefs.py
+++ b/hexa/files/basefs.py
@@ -46,6 +46,10 @@ def delete_bucket(self, bucket_name: str, fully: bool = False):
     def upload_object(self, bucket_name: str, file_name: str, source: str):
         pass
 
+    @abstractmethod
+    def upload_object_from_string(self, bucket_name: str, file_name: str, content: str):
+        pass
+
     @abstractmethod
     def create_bucket_folder(self, bucket_name: str, folder_key: str):
         pass
diff --git a/hexa/files/gcp.py b/hexa/files/gcp.py
index 1206f0550..de096845d 100644
--- a/hexa/files/gcp.py
+++ b/hexa/files/gcp.py
@@ -159,6 +159,12 @@ def upload_object(self, bucket_name: str, file_name: str, source: str):
         blob = bucket.blob(file_name)
         blob.upload_from_filename(source)
 
+    def upload_object_from_string(self, bucket_name: str, file_name: str, content: str):
+        client = get_storage_client()
+        bucket = client.bucket(bucket_name)
+        blob = bucket.blob(file_name)
+        blob.upload_from_string(content)
+
     def create_bucket_folder(self, bucket_name: str, folder_key: str):
         client = get_storage_client()
         bucket = client.get_bucket(bucket_name)
diff --git a/hexa/files/s3.py b/hexa/files/s3.py
index 0d68eb73d..66dc77a8c 100644
--- a/hexa/files/s3.py
+++ b/hexa/files/s3.py
@@ -451,6 +451,10 @@ def get_token_as_env_variables(self, token):
             ).decode(),
         }
 
+    def upload_object_from_string(self, bucket_name: str, file_name: str, content: str):
+        s3 = get_storage_client()
+        s3.put_object(Bucket=bucket_name, Key=file_name, Body=content)
+
     def read_object_lines(self, bucket_name: str, filename: str, lines_number: int):
         s3 = get_storage_client()
         object = s3.get_object(Bucket=bucket_name, Key=filename)
@@ -458,5 +462,6 @@ def read_object_lines(self, bucket_name: str, filename: str, lines_number: int):
         file_stream.seek(0)
         lines = file_stream.readlines()
 
-        specific_lines = [lines[i].decode("utf-8").strip() for i in range(lines_number)]
+        max_lines = min(lines_number, len(lines))
+        specific_lines = [lines[i].decode("utf-8").strip() for i in range(max_lines)]
        return specific_lines
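---

Usage sketch (the IDs below are placeholders): with this patch applied, resolve_create_version_file enqueues a "create_snapshot" job for every new version file, and the stored snapshot can be read back through the new query. The mutation is declared in the schema but its resolver is not part of this diff, so the first operation assumes a resolver wired to the same snapshot queue:

    mutation {
      createDatasetVersionFileSnapshot(input: { fileId: "<dataset-version-file-id>" }) {
        success
        errors
      }
    }

    query {
      datasetFileSnapshot(fileId: "<dataset-version-file-id>") {
        uri
        createdBy
      }
    }

Both fields selected in the query are defined by the DatasetFileSnapshot type added in schema.graphql; uri resolves to the bucket-prefixed object path written by the queue task.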