diff --git a/python/sdk/merlin/model.py b/python/sdk/merlin/model.py
index 175cc4e70..2150305a0 100644
--- a/python/sdk/merlin/model.py
+++ b/python/sdk/merlin/model.py
@@ -58,7 +58,7 @@
 from merlin.transformer import Transformer
 from merlin.util import (
     autostr,
-    download_files_from_gcs,
+    download_files_from_blob_storage,
     extract_optional_value_with_default,
     guess_mlp_ui_url,
     valid_name_check,
@@ -956,7 +956,7 @@ def download_artifact(self, destination_path):
         if artifact_uri is None or artifact_uri == "":
             raise Exception("There is no artifact uri for this model version")
 
-        download_files_from_gcs(artifact_uri, destination_path)
+        download_files_from_blob_storage(artifact_uri, destination_path)
 
     def log_artifacts(self, local_dir, artifact_path=None):
         """
diff --git a/python/sdk/merlin/util.py b/python/sdk/merlin/util.py
index 2b01e36be..399570f16 100644
--- a/python/sdk/merlin/util.py
+++ b/python/sdk/merlin/util.py
@@ -14,6 +14,7 @@
 
 import re
 import os
+import boto3
 from urllib.parse import urlparse
 from google.cloud import storage
 from os.path import dirname
@@ -66,6 +67,11 @@ def valid_name_check(input_name: str) -> bool:
     return matching_group == input_name
 
 
+def get_blob_storage_schema(artifact_uri: str) -> str:
+    parsed_result = urlparse(artifact_uri)
+    return parsed_result.scheme
+
+
 def get_bucket_name(gcs_uri: str) -> str:
     parsed_result = urlparse(gcs_uri)
     return parsed_result.netloc
@@ -76,24 +82,43 @@ def get_gcs_path(gcs_uri: str) -> str:
     return parsed_result.path.strip("/")
 
 
-def download_files_from_gcs(gcs_uri: str, destination_path: str):
+def download_files_from_blob_storage(artifact_uri: str, destination_path: str):
     makedirs(destination_path, exist_ok=True)
-    client = storage.Client()
-    bucket_name = get_bucket_name(gcs_uri)
-    path = get_gcs_path(gcs_uri)
-
-    bucket = client.get_bucket(bucket_name)
-    blobs = bucket.list_blobs(prefix=path)
-    for blob in blobs:
-        # Get only the path after .../artifacts/model
-        # E.g.
-        # Some blob looks like this mlflow/3/ad8f15a4023f461796955f71e1152bac/artifacts/model/1/saved_model.pb
-        # we only want to extract 1/saved_model.pb
-        artifact_path = os.path.join(*blob.name.split("/")[5:])
-        dir = os.path.join(destination_path, dirname(artifact_path))
-        makedirs(dir, exist_ok=True)
-        blob.download_to_filename(os.path.join(destination_path, artifact_path))
+    storage_schema = get_blob_storage_schema(artifact_uri)
+    bucket_name = get_bucket_name(artifact_uri)
+    path = get_gcs_path(artifact_uri)
+
+    if storage_schema == "gs":
+        client = storage.Client()
+        bucket = client.get_bucket(bucket_name)
+        blobs = bucket.list_blobs(prefix=path)
+        for blob in blobs:
+            # Get only the path after .../artifacts/model
+            # E.g.
+            # Some blob looks like this mlflow/3/ad8f15a4023f461796955f71e1152bac/artifacts/model/1/saved_model.pb
+            # we only want to extract 1/saved_model.pb
+            artifact_path = os.path.join(*blob.name.split("/")[5:])
+            dir = os.path.join(destination_path, dirname(artifact_path))
+            makedirs(dir, exist_ok=True)
+            blob.download_to_filename(os.path.join(destination_path, artifact_path))
+    elif storage_schema == "s3":
+        client = boto3.client("s3")
+        # S3 omits the "Contents" key entirely when the prefix matches nothing,
+        # so use .get() to avoid a KeyError and simply download zero files.
+        bucket = client.list_objects_v2(Prefix=path, Bucket=bucket_name).get("Contents", [])
+        for s3_object in bucket:
+            # we do this because the list_objects_v2 method lists all subdirectories in addition to files
+            if not s3_object['Key'].endswith('/'):
+                # Get only the path after .../artifacts/model
+                # E.g.
+                # Some blob looks like this mlflow/3/ad8f15a4023f461796955f71e1152bac/artifacts/model/1/saved_model.pb
+                # we only want to extract 1/saved_model.pb
+                object_paths = s3_object['Key'].split("/")[5:]
+                if len(object_paths) != 0:
+                    artifact_path = os.path.join(*object_paths)
+                    os.makedirs(os.path.join(destination_path, dirname(artifact_path)), exist_ok=True)
+                    client.download_file(bucket_name, s3_object['Key'], os.path.join(destination_path, artifact_path))
+
 
 
 def extract_optional_value_with_default(opt: Optional[Any], default: Any) -> Any:
     if opt is not None: