Skip to content

Commit

Permalink
feat(Dataset): add dataset file snapshot model
Browse files Browse the repository at this point in the history
  • Loading branch information
nazarfil committed Jul 10, 2024
1 parent 06bc234 commit 14857c0
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 15 deletions.
14 changes: 13 additions & 1 deletion hexa/datasets/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from django.contrib import admin

from .models import Dataset, DatasetLink, DatasetVersion, DatasetVersionFile
from .models import (
Dataset,
DatasetFileSnapshot,
DatasetLink,
DatasetVersion,
DatasetVersionFile,
)


@admin.register(Dataset)
Expand All @@ -26,6 +32,12 @@ class DatasetVersionObjectAdmin(admin.ModelAdmin):
list_filter = ("dataset_version__dataset", "created_by")


@admin.register(DatasetFileSnapshot)
class DatasetFileSnapshotAdmin(admin.ModelAdmin):
    # Admin list view for stored file snapshots.
    # "filename" is the DatasetFileSnapshot.filename property (last URI segment).
    list_display = ("filename", "dataset_version_file")
    # Filter by the owning dataset (via version file -> version -> dataset) and creator.
    list_filter = ("dataset_version_file__dataset_version__dataset", "created_by")

@admin.register(DatasetLink)
class DatasetLinkAdmin(admin.ModelAdmin):
list_display = ("dataset", "workspace", "created_at", "created_by")
Expand Down
25 changes: 24 additions & 1 deletion hexa/datasets/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,30 @@ type PinDatasetResult {
errors: [PinDatasetError!]!
}

"Input for the createDatasetVersionFileSnapshot mutation."
input CreateDatasetFileSnapshotInput {
  "Id of the dataset version file to snapshot."
  fileId: String!
}

"A stored preview extract (snapshot) of a dataset version file."
type DatasetFileSnapshot {
  "Storage URI of the snapshot object."
  uri: String!
  # NOTE(review): snake_case field names are inconsistent with the camelCase
  # used elsewhere in this schema (e.g. fileId) — confirm before clients rely on them.
  created_by: String!
  dataset_version_file: DatasetVersionFile
}

"Result payload of the createDatasetVersionFileSnapshot mutation."
type CreateDatasetFileSnapshotResult {
  dataset_file_snapshot : DatasetFileSnapshot
  success: Boolean!
  # NOTE(review): reuses PrepareVersionFileDownloadError — confirm a dedicated
  # error enum is not needed for snapshot creation.
  errors: [PrepareVersionFileDownloadError!]!
}


extend type Query {
"Get a dataset by its ID."
dataset(id: ID!): Dataset
"Get a dataset by its slug."
datasetVersion(id: ID!): DatasetVersion
"Get a dataset file snapshot by fileId"
datasetFileSnapshot(id: ID, fileId: ID): DatasetFileSnapshot
"Get a dataset link by its id."
datasetLink(id: ID!): DatasetLink
"Get a dataset link by its slug."
Expand All @@ -418,6 +437,7 @@ extend type Query {
datasets(query: String, page: Int = 1, perPage: Int = 15): DatasetPage!
}


extend type Mutation {
"Create a new dataset."
createDataset(input: CreateDatasetInput!): CreateDatasetResult! @loginRequired
Expand All @@ -431,6 +451,8 @@ extend type Mutation {
deleteDatasetVersion(input: DeleteDatasetVersionInput!): DeleteDatasetVersionResult! @loginRequired
"Create a new file in a dataset version."
createDatasetVersionFile(input: CreateDatasetVersionFileInput!): CreateDatasetVersionFileResult! @loginRequired
"Create dataset version snapshot."
createDatasetVersionFileSnapshot(input: CreateDatasetFileSnapshotInput!): CreateDatasetFileSnapshotResult! @loginRequired
"Prepare to download a file in a dataset version."
prepareVersionFileDownload(input: PrepareVersionFileDownloadInput!): PrepareVersionFileDownloadResult! @loginRequired
"Link a dataset with a workspace."
Expand All @@ -439,4 +461,5 @@ extend type Mutation {
deleteDatasetLink(input: DeleteDatasetLinkInput!): DeleteDatasetLinkResult! @loginRequired
"Pin or unpin a dataset for a workspace."
pinDataset(input: PinDatasetInput!): PinDatasetResult! @loginRequired
}
}

60 changes: 60 additions & 0 deletions hexa/datasets/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,66 @@ class Meta:
ordering = ["uri"]


class DatasetFileSnapshotQuerySet(BaseQuerySet):
    def filter_for_user(self, user: AnonymousUser | User):
        """Limit snapshots to those whose parent dataset is visible to *user*.

        Superusers are deliberately not short-circuited
        (return_all_if_superuser=False): visibility still follows the dataset.
        """
        return self._filter_for_user_and_query_object(
            user,
            models.Q(
                dataset_version_file__dataset_version__dataset__in=Dataset.objects.filter_for_user(
                    user
                )
            ),
            # Bug fix: this flag was previously passed INSIDE models.Q(...),
            # where Django interprets it as a field lookup and raises
            # FieldError at query time; it is a keyword of
            # _filter_for_user_and_query_object.
            return_all_if_superuser=False,
        )


class DatasetFileSnapshotManager(models.Manager):
    def create_if_has_perm(
        self,
        principal: User,
        dataset_version_file: DatasetVersionFile,
        *,
        uri: str,
    ):
        """Create a snapshot for *dataset_version_file* after authorizing *principal*.

        A pipeline-run principal is authorized when its pipeline belongs to the
        same workspace as the file's dataset; a regular user must hold the
        "datasets.create_dataset_version_file_snapshot" permission on the file.
        Raises PermissionDenied otherwise. Pipeline-run snapshots are recorded
        with created_by=None.
        """
        from hexa.pipelines.authentication import PipelineRunUser

        is_pipeline_run = isinstance(principal, PipelineRunUser)
        if is_pipeline_run:
            file_workspace = dataset_version_file.dataset_version.dataset.workspace
            if principal.pipeline_run.pipeline.workspace != file_workspace:
                raise PermissionDenied
        elif not principal.has_perm(
            "datasets.create_dataset_version_file_snapshot", dataset_version_file
        ):
            raise PermissionDenied

        return self.create(
            dataset_version_file=dataset_version_file,
            uri=uri,
            created_by=None if is_pipeline_run else principal,
        )


class DatasetFileSnapshot(Base):
    """A stored preview extract ("snapshot") of a dataset version file."""

    # Storage location of the snapshot object; unique per snapshot.
    uri = models.TextField(null=False, blank=False, unique=True)
    # Null when the snapshot was created by a pipeline run rather than a user
    # (see DatasetFileSnapshotManager.create_if_has_perm).
    created_by = models.ForeignKey(User, null=True, on_delete=models.SET_NULL)
    # Snapshots are removed together with their parent version file (CASCADE).
    dataset_version_file = models.ForeignKey(
        DatasetVersionFile,
        null=False,
        blank=False,
        on_delete=models.CASCADE,
        related_name="snapshots",
    )

    # Manager combining permission-aware creation with user-scoped filtering.
    objects = DatasetFileSnapshotManager.from_queryset(DatasetFileSnapshotQuerySet)()

    @property
    def filename(self):
        """Last path segment of the snapshot URI."""
        return self.uri.split("/")[-1]


class DatasetLinkQuerySet(BaseQuerySet):
def filter_for_user(self, user: AnonymousUser | User):
# FIXME: Use a generic permission system instead of differencing between User and PipelineRunUser
Expand Down
15 changes: 14 additions & 1 deletion hexa/datasets/permissions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from hexa.datasets.models import Dataset, DatasetLink, DatasetVersion
from hexa.datasets.models import (
Dataset,
DatasetLink,
DatasetVersion,
DatasetVersionFile,
)
from hexa.user_management.models import User
from hexa.workspaces.models import (
Workspace,
Expand Down Expand Up @@ -105,3 +110,11 @@ def create_dataset_version_file(principal: User, dataset_version: DatasetVersion
return False

return create_dataset_version(principal, dataset_version.dataset)


def create_dataset_version_file_snapshot(
    principal: User, dataset_version_file: DatasetVersionFile
):
    """Return True when *principal* may create a snapshot of *dataset_version_file*."""
    # NOTE(review): this compares a DatasetVersionFile instance to its
    # `latest_version` attribute — presumably "only the latest revision of a
    # file may be snapshotted". Confirm `latest_version` yields a
    # DatasetVersionFile; otherwise the inequality always holds and this
    # function can never return True.
    if dataset_version_file != dataset_version_file.latest_version:
        return False
    # Otherwise defer to the permission for adding files to this version.
    return create_dataset_version_file(principal, dataset_version_file.dataset_version)
49 changes: 42 additions & 7 deletions hexa/datasets/queue.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,51 @@
import os.path
from logging import getLogger

from dpq.queue import AtLeastOnceQueue

from hexa.datasets.models import DatasetSnapshotJob
from hexa.datasets.models import (
DatasetFileSnapshot,
DatasetSnapshotJob,
DatasetVersionFile,
)
from hexa.files.api import get_storage
from hexa.user_management.models import User

logger = getLogger(__name__)


def create_dataset_snnapshot_task(queue: AtLeastOnceQueue, job: DatasetSnapshotJob):
    # NOTE(review): "snnapshot" in the function name is a typo; rename when the
    # task registration that references it is updated as well.
    # TODO: implement ticket PATHWAYS-98 - extract data in background task
    dataset_version_file_id = job.args["fileId"]
    logger.info(f"Creating dataset version file {dataset_version_file_id}")
DEFAULT_SNAPSHOT_LINES = 500


def create_dataset_snapshot_task(queue: AtLeastOnceQueue, job: DatasetSnapshotJob):
    """Extract the first DEFAULT_SNAPSHOT_LINES lines of a dataset version file
    and store them as a DatasetFileSnapshot next to the original object.

    Expects job.args to carry "file_id" (DatasetVersionFile pk) and "user_id"
    (User pk). Any failure is logged and swallowed so the queue worker
    keeps processing subsequent jobs.
    """
    try:
        dataset_version_file_id = job.args["file_id"]
        user_id = job.args["user_id"]
        logger.info(
            f"Creating dataset snapshot for version file {dataset_version_file_id}"
        )
        dataset_version_file = DatasetVersionFile.objects.get(
            id=dataset_version_file_id
        )
        user = User.objects.get(id=user_id)

        # File URIs follow the "<bucket>/<object key>" convention (the model's
        # filename property splits on "/" as well) — TODO confirm for all backends.
        bucket_name, object_key = dataset_version_file.uri.split("/", 1)

        storage = get_storage()
        # read_object_lines(bucket_name, filename, lines_number) — the previous
        # call passed the model instance and omitted the bucket name entirely.
        snapshot_lines = storage.read_object_lines(
            bucket_name, object_key, DEFAULT_SNAPSHOT_LINES
        )

        base_key, extension = os.path.splitext(object_key)
        snapshot_key = f"{base_key}-snapshot{extension}"
        # read_object_lines returns a list of decoded lines; the storage
        # backends expect a single string body.
        storage.upload_object_from_string(
            bucket_name, snapshot_key, "\n".join(snapshot_lines)
        )

        snapshot_uri = f"{bucket_name}/{snapshot_key}"
        logger.info(
            f"Uploaded dataset snapshot to {snapshot_uri} for file {dataset_version_file_id}"
        )
        DatasetFileSnapshot.objects.create_if_has_perm(
            principal=user, dataset_version_file=dataset_version_file, uri=snapshot_uri
        )
        # Previous message lacked the f-prefix and logged the literal braces.
        logger.info(f"Dataset snapshot created for file {dataset_version_file_id}")
    except Exception as e:
        logger.exception(f"Failed to create dataset snapshot: \n {e}")


class DatasetSnapshotQueue(AtLeastOnceQueue):
Expand All @@ -19,7 +54,7 @@ class DatasetSnapshotQueue(AtLeastOnceQueue):

dataset_snapshot_queue = DatasetSnapshotQueue(
tasks={
"create_snapshot": create_dataset_snnapshot_task,
"create_snapshot": create_dataset_snapshot_task,
},
notify_channel="dataset_snapshot_queue",
)
4 changes: 1 addition & 3 deletions hexa/datasets/schema/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ def resolve_create_version_file(_, info, **kwargs):
# NOTE(review): dpq's enqueue takes the task name and an args mapping as two
# separate arguments; the previous set literal {task, args} raises TypeError
# at runtime because a dict is unhashable — confirm against dpq's API.
dataset_snapshot_queue.enqueue(
    "create_snapshot",
    {"file_id": str(file.id), "user_id": str(request.user.id)},
)
return {
Expand Down
23 changes: 22 additions & 1 deletion hexa/datasets/schema/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from hexa.core.graphql import result_page

from ..models import Dataset, DatasetLink, DatasetVersion
from ..models import (
Dataset,
DatasetFileSnapshot,
DatasetLink,
DatasetVersion,
)

datasets_queries = QueryType()

Expand Down Expand Up @@ -37,6 +42,22 @@ def resolve_dataset_version(_, info, **kwargs):
return None


@datasets_queries.field("datasetFileSnapshot")
def resolve_dataset_file_snapshot(_, info, **kwargs):
request = info.context["request"]
try:
if kwargs.get("file_id"):
return DatasetFileSnapshot.objects.filter_for_user(request.user).get(
dataset_version_file=kwargs["file_id"]
)
else:
return DatasetFileSnapshot.objects.filter_for_user(request.user).get(
id=kwargs["id"]
)
except DatasetFileSnapshot.DoesNotExist:
return None


@datasets_queries.field("datasetLink")
def resolve_dataset_link(_, info, **kwargs):
request = info.context["request"]
Expand Down
4 changes: 4 additions & 0 deletions hexa/files/basefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def delete_bucket(self, bucket_name: str, fully: bool = False):
def upload_object(self, bucket_name: str, file_name: str, source: str):
pass

@abstractmethod
def upload_object_from_string(self, bucket_name: str, file_name: str, content: str):
    """Store *content* as the body of the object *file_name* in *bucket_name*."""
    pass

@abstractmethod
def create_bucket_folder(self, bucket_name: str, folder_key: str):
    """Create the folder *folder_key* in *bucket_name* (backend-specific)."""
    pass
Expand Down
6 changes: 6 additions & 0 deletions hexa/files/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def upload_object(self, bucket_name: str, file_name: str, source: str):
blob = bucket.blob(file_name)
blob.upload_from_filename(source)

def upload_object_from_string(self, bucket_name: str, file_name: str, content: str):
    """Write *content* as the body of a new GCS blob named *file_name*."""
    storage_client = get_storage_client()
    target_blob = storage_client.bucket(bucket_name).blob(file_name)
    target_blob.upload_from_string(content)

def create_bucket_folder(self, bucket_name: str, folder_key: str):
client = get_storage_client()
bucket = client.get_bucket(bucket_name)
Expand Down
7 changes: 6 additions & 1 deletion hexa/files/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,17 @@ def get_token_as_env_variables(self, token):
).decode(),
}

def upload_object_from_string(self, bucket_name: str, file_name: str, content: str):
    """Store *content* as the body of s3://<bucket_name>/<file_name>."""
    s3 = get_storage_client()
    # boto3 client operations only accept keyword arguments (cf. the
    # get_object call in read_object_lines); the positional call raised
    # TypeError at runtime.
    s3.put_object(Bucket=bucket_name, Key=file_name, Body=content)

def read_object_lines(self, bucket_name: str, filename: str, lines_number: int):
    """Return up to *lines_number* lines from the start of an S3 object,
    decoded as UTF-8 and stripped of surrounding whitespace.

    Note: the whole object body is read into memory; intended for small files.
    """
    s3 = get_storage_client()
    # Renamed from `object`, which shadowed the builtin; the seek(0) was
    # redundant since BytesIO(data) starts positioned at 0.
    response = s3.get_object(Bucket=bucket_name, Key=filename)
    lines = io.BytesIO(response["Body"].read()).readlines()
    # Slicing handles objects shorter than lines_number.
    return [line.decode("utf-8").strip() for line in lines[:lines_number]]

0 comments on commit 14857c0

Please sign in to comment.