-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #50 from cloudfactory/import-image-from-bucket
Introduce method to Import image from bucket
- Loading branch information
Showing
9 changed files
with
200 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from collections import OrderedDict | ||
from dataclasses import dataclass | ||
from typing import Union, Protocol | ||
|
||
from .constants import BucketProviders | ||
from .hasty_object import HastyObject | ||
|
||
@dataclass | ||
class Credentials(Protocol): | ||
def get_credentials(self): | ||
raise NotImplementedError | ||
|
||
def cloud_provider(self): | ||
raise NotImplementedError | ||
|
||
@dataclass | ||
class DummyCreds(Credentials): | ||
secret: str | ||
|
||
def get_credentials(self): | ||
return {"secret": self.secret, "cloud_provider": BucketProviders.DUMMY} | ||
|
||
def cloud_provider(self): | ||
return BucketProviders.DUMMY | ||
|
||
@dataclass | ||
class GCSCreds(Credentials): | ||
bucket: str | ||
key_json: str | ||
|
||
def get_credentials(self): | ||
return {"bucket_gcs": self.bucket, "key_json": self.key_json, "cloud_provider": BucketProviders.GCS} | ||
|
||
def cloud_provider(self): | ||
return BucketProviders.GCS | ||
|
||
@dataclass | ||
class S3Creds(Credentials): | ||
bucket: str | ||
role: str | ||
|
||
def get_credentials(self): | ||
return {"bucket_s3": self.bucket, "role": self.role, "cloud_provider": BucketProviders.S3} | ||
|
||
def cloud_provider(self): | ||
return BucketProviders.S3 | ||
|
||
@dataclass | ||
class AZCreds(Credentials): | ||
account_name: str | ||
secret_access_key: str | ||
container: str | ||
|
||
def get_credentials(self): | ||
return {"account_name": self.account_name, "secret_access_key": self.secret_access_key, | ||
"container": self.container, "cloud_provider": BucketProviders.AZ} | ||
|
||
def cloud_provider(self): | ||
return BucketProviders.AZ | ||
|
||
class Bucket(HastyObject): | ||
"""Class that contains some basic requests and features for bucket management""" | ||
endpoint = '/v1/buckets/{workspace_id}/credentials' | ||
|
||
def __repr__(self): | ||
return self.get__repr__(OrderedDict({"id": self._id, "name": self._name, "cloud_provider": self._cloud_provider})) | ||
|
||
@property | ||
def id(self): | ||
""" | ||
:type: string | ||
""" | ||
return self._id | ||
|
||
@property | ||
def name(self): | ||
""" | ||
:type: string | ||
""" | ||
return self._name | ||
|
||
@property | ||
def cloud_provider(self): | ||
""" | ||
:type: string | ||
""" | ||
return self._cloud_provider | ||
|
||
def _init_properties(self): | ||
self._id = None | ||
self._name = None | ||
self._cloud_provider = None | ||
|
||
def _set_prop_values(self, data): | ||
if "credential_id" in data: | ||
self._id = data["credential_id"] | ||
if "description" in data: | ||
self._name = data["description"] | ||
if "cloud_provider" in data: | ||
self._cloud_provider = data["cloud_provider"] | ||
|
||
@staticmethod | ||
def _create_bucket(requester, workspace_id, name, credentials: Union[DummyCreds, GCSCreds, S3Creds, AZCreds]): | ||
json = {"description": name, "cloud_provider": credentials.cloud_provider(), **credentials.get_credentials()} | ||
data = requester.post(Bucket.endpoint.format(workspace_id=workspace_id), json_data=json) | ||
return Bucket(requester, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import unittest | ||
|
||
from tests.utils import get_client | ||
|
||
from hasty.bucket import S3Creds | ||
|
||
|
||
class TestBucketManagement(unittest.TestCase): | ||
def setUp(self): | ||
self.h = get_client() | ||
self.workspace = self.h.get_workspaces()[0] | ||
self.project = self.h.create_project(self.workspace, "Test Project 1") | ||
|
||
def tearDown(self): | ||
self.project.delete() | ||
|
||
def test_bucket_creation(self): | ||
ws = self.h.get_workspaces()[0] | ||
res = ws.create_bucket("test_bucket", S3Creds(bucket="hasty-public-bucket-mounter", role="arn:aws:iam::045521589961:role/hasty-public-bucket-mounter")) | ||
self.assertIsNotNone(res.id) | ||
self.assertEqual("test_bucket", res.name) | ||
self.assertEqual("s3", res.cloud_provider) | ||
|
||
def test_import_image(self): | ||
# create a bucket | ||
bucket = self.workspace.create_bucket("test_bucket", S3Creds(bucket="hasty-public-bucket-mounter", role="arn:aws:iam::045521589961:role/hasty-public-bucket-mounter")) | ||
|
||
# Import an image from the bucket | ||
dataset = self.project.create_dataset("ds2") | ||
img = self.project.upload_from_bucket(dataset, "1645001880-075718046bb2fbf9b8c35d6e88571cd7f91ca1a1.png", | ||
"dummy/1645001880-075718046bb2fbf9b8c35d6e88571cd7f91ca1a1.png", bucket.id) | ||
self.assertEqual("1645001880-075718046bb2fbf9b8c35d6e88571cd7f91ca1a1.png", img.name) | ||
self.assertEqual("ds2", img.dataset_name) | ||
self.assertIsNotNone(img.id) | ||
self.assertEqual(1280, img.width) | ||
self.assertEqual(720, img.height) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |