diff --git a/README.md b/README.md
index 126c265..bec23ad 100644
--- a/README.md
+++ b/README.md
@@ -140,13 +140,21 @@ optional arguments:
 - GCS: Use `gcloud` CLI.
   - Using end-user credentials: You will be asked to enter credentials of your Google account.
     ```
-    $ gcloud init
+    $ gcloud auth application-default login --no-launch-browser
     ```
   - Using service account credentials: Use this if you have a service account and a JSON key file associated with it.
     ```
     $ gcloud auth activate-service-account --key-file=[YOUR_JSON_KEY.json]
     $ export GOOGLE_APPLICATION_CREDENTIALS="PATH/FOR/YOUR_JSON_KEY.json"
     ```
+
+    Alternatively, import and call `add_google_app_creds_to_env()` in Python.
+    ```python
+    import autouri
+    from autouri.gcsuri import add_google_app_creds_to_env
+
+    add_google_app_creds_to_env('YOUR_JSON_KEY.json')
+    ```
 Then set your default project.
 ```
 $ gcloud config set project [YOUR_GCP_PROJECT_ID]
diff --git a/autouri/__init__.py b/autouri/__init__.py
index 167cb40..88ea2b7 100644
--- a/autouri/__init__.py
+++ b/autouri/__init__.py
@@ -5,4 +5,4 @@ from .s3uri import S3URI
 
 __all__ = ["AbsPath", "AutoURI", "URIBase", "GCSURI", "HTTPURL", "S3URI"]
 
-__version__ = "0.2.0"
+__version__ = "0.2.1"
diff --git a/autouri/abspath.py b/autouri/abspath.py
index 7f1c091..e7eb068 100644
--- a/autouri/abspath.py
+++ b/autouri/abspath.py
@@ -75,9 +75,9 @@ def _get_lock(self, timeout=None, poll_interval=None):
     def get_metadata(self, skip_md5=False, make_md5_file=False):
         """If md5 file doesn't exist, use hashlib.md5() to calculate md5 hash.
         """
-        ex = os.path.exists(self._uri)
+        exists = os.path.exists(self._uri)
         mt, sz, md5 = None, None, None
-        if ex:
+        if exists:
             mt = os.path.getmtime(self._uri)
             sz = os.path.getsize(self._uri)
             if not skip_md5:
@@ -87,7 +87,7 @@ def get_metadata(self, skip_md5=False, make_md5_file=False):
             if make_md5_file:
                 self.md5_file_uri.write(md5)
 
-        return URIMetadata(exists=ex, mtime=mt, size=sz, md5=md5)
+        return URIMetadata(exists=exists, mtime=mt, size=sz, md5=md5)
 
     def read(self, byte=False):
         if byte:
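For reference, a minimal sketch of how the renamed `exists` field surfaces through the public API (the local path is hypothetical): `AutoURI` dispatches a local path to `AbsPath`, and `get_metadata()` returns a `URIMetadata` named tuple.

```python
from autouri import AutoURI

# make_md5_file=True also writes the md5 sidecar file next to the source file
meta = AutoURI('/tmp/example.txt').get_metadata(make_md5_file=True)
if meta.exists:
    print(meta.mtime, meta.size, meta.md5)
```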
".format( + env_var=ENV_VAR_GOOGLE_APPLICATION_CREDENTIALS + ) + ) + logger.debug( + "Adding service account key JSON {key_file} to " + "environment variable {env_var}".format( + key_file=service_account_key_file, + env_var=ENV_VAR_GOOGLE_APPLICATION_CREDENTIALS, + ) + ) + os.environ[ENV_VAR_GOOGLE_APPLICATION_CREDENTIALS] = service_account_key_file + + class GCSURILock(BaseFileLock): """Slow but stable locking with using GCS temporary_hold + Hold the lock file instead of the target file that lock file protects. + + Class constants: + DEFAULT_RETRY_RELEASE: + Retry if release (deletion) of a lock file fails. + DEFAULT_RETRY_RELEASE_INTERVAL: + Interval for retrial in seconds. """ + DEFAULT_RETRY_RELEASE = 3 + DEFAULT_RETRY_RELEASE_INTERVAL = 3 + def __init__( - self, lock_file, thread_id=-1, timeout=900, poll_interval=10.0, no_lock=False + self, + lock_file, + thread_id=-1, + timeout=900, + poll_interval=10.0, + retry_release=DEFAULT_RETRY_RELEASE, + retry_release_interval=DEFAULT_RETRY_RELEASE_INTERVAL, + no_lock=False, ): super().__init__(lock_file, timeout=timeout) self._poll_interval = poll_interval self._thread_id = thread_id + self._retry_release = retry_release + self._retry_release_interval = retry_release_interval def acquire(self, timeout=None, poll_intervall=5.0): """Use self._poll_interval instead of poll_intervall in args @@ -48,28 +100,42 @@ def acquire(self, timeout=None, poll_intervall=5.0): def _acquire(self): u = GCSURI(self._lock_file, thread_id=self._thread_id) - blob, bucket_obj = u.get_blob(new=True) - if blob is not None: - try: - blob.upload_from_string("") - blob.temporary_hold = True - blob.patch() - self._lock_file_fd = id(self) - except (Forbidden, GatewayTimeout, NotFound, ServiceUnavailable): - pass + try: + blob, bucket_obj = u.get_blob(new=True) + blob.upload_from_string("") + blob.temporary_hold = True + blob.patch() + self._lock_file_fd = id(self) + except (Forbidden, GatewayTimeout, NotFound, ServiceUnavailable): + pass + return None def _release(self): u = GCSURI(self._lock_file, thread_id=self._thread_id) - blob, _ = u.get_blob() - if blob is not None: - blob.temporary_hold = False + for retry in range(self._retry_release): try: + blob, _ = u.get_blob() + blob.temporary_hold = False blob.patch() blob.delete() self._lock_file_fd = None - except (NotFound,): - pass + break + except Exception as e: + error_msg = "{err}. Failed to delete a lock file: file={file}. " + if retry == self._retry_release - 1: + error_msg += ( + "You may need to manually delete a lock file. " + 'Use "gsutil retention temp release {file}" to unlock it first. ' + 'Then use "gsutil rm -f {file}" to delete it. ' + "Deleting a lock file itself does not affect " + "the file protected by it." + ) + error_msg = error_msg.format(err=e, file=self._lock_file) + + logger.error(error_msg) + + time.sleep(self._retry_release_interval) return None @@ -95,7 +161,7 @@ class GCSURI(URIBase): run "gsutil config" to generate corrensponding ~/.boto file. Protected class constants: - _CACHED_GCS_CLIENT_PER_THREAD: + _CACHED_GCS_CLIENTS: Per-thread GCS client object is required since GCS client is not thread-safe. 
@@ -95,7 +161,7 @@ class GCSURI(URIBase):
     run "gsutil config" to generate corresponding ~/.boto file.
 
     Protected class constants:
-        _CACHED_GCS_CLIENT_PER_THREAD:
+        _CACHED_GCS_CLIENTS:
             Per-thread GCS client object is required since
             GCS client is not thread-safe.
         _CACHED_PRESIGNED_URLS:
@@ -111,7 +177,8 @@ class GCSURI(URIBase):
     RETRY_BUCKET_DELAY: int = 1
     USE_GSUTIL_FOR_S3: bool = False
 
-    _CACHED_GCS_CLIENT_PER_THREAD = {}
+    _CACHED_GCS_CLIENTS = {}
+    _CACHED_GCS_ANONYMOUS_CLIENTS = {}
     _CACHED_PRESIGNED_URLS = {}
 
     _GCS_PUBLIC_URL_FORMAT = "http://storage.googleapis.com/{bucket}/{path}"
@@ -134,36 +201,35 @@ def _get_lock(self, timeout=None, poll_interval=None):
         )
 
     def get_metadata(self, skip_md5=False, make_md5_file=False):
-        ex, mt, sz, md5 = False, None, None, None
+        exists, mt, sz, md5 = False, None, None, None
         try:
             b, _ = self.get_blob()
-            if b is not None:
-                # make keys lower-case
-                h = {k.lower(): v for k, v in b._properties.items()}
-                ex = True
-
-                if not skip_md5:
-                    if "md5hash" in h:
-                        md5 = parse_md5_str(h["md5hash"])
-                    elif "etag" in h:
-                        md5 = parse_md5_str(h["etag"])
-                    if md5 is None:
-                        # make_md5_file is ignored for GCSURI
-                        md5 = self.md5_from_file
-
-                if "size" in h:
-                    sz = int(h["size"])
-
-                if "updated" in h:
-                    mt = get_seconds_from_epoch(h["updated"])
-                elif "timecreated" in h:
-                    mt = get_seconds_from_epoch(h["timecreated"])
+            # make keys lower-case
+            headers = {k.lower(): v for k, v in b._properties.items()}
+            exists = True
+
+            if not skip_md5:
+                if "md5hash" in headers:
+                    md5 = parse_md5_str(headers["md5hash"])
+                elif "etag" in headers:
+                    md5 = parse_md5_str(headers["etag"])
+                if md5 is None:
+                    # make_md5_file is ignored for GCSURI
+                    md5 = self.md5_from_file
+
+            if "size" in headers:
+                sz = int(headers["size"])
+
+            if "updated" in headers:
+                mt = get_seconds_from_epoch(headers["updated"])
+            elif "timecreated" in headers:
+                mt = get_seconds_from_epoch(headers["timecreated"])
         except Exception:
-            pass
+            logger.debug("Failed to get metadata from {uri}".format(uri=self._uri))
 
-        return URIMetadata(exists=ex, mtime=mt, size=sz, md5=md5)
+        return URIMetadata(exists=exists, mtime=mt, size=sz, md5=md5)
 
     def read(self, byte=False):
         blob, _ = self.get_blob()
@@ -218,7 +284,7 @@ def _cp(self, dest_uri):
 
         if isinstance(dest_uri, GCSURI):
             _, dest_path = dest_uri.get_bucket_path()
-            _, dest_bucket = dest_uri.get_blob()
+            _, dest_bucket = dest_uri.get_blob(new=True)
             src_bucket.copy_blob(src_blob, dest_bucket, dest_path)
             return True
@@ -291,17 +357,25 @@ def _cp_from(self, src_uri):
         return False
 
     def get_blob(self, new=False) -> Blob:
-        """GCS Client() has a bug that shows an outdated version of a file
+        """GCS Client() has a bug that shows an outdated version of a file
         when using Blob() without update().
         For read-only functions (e.g. read()), need to directly call
         cl.get_bucket(bucket).get_blob(path) instead of using Blob() class.
 
-        Also, GCS Client() is not thread-safe and it fails for a variety of reasons.
+        Also, GCS Client() is not thread-safe and it fails for a variety of reasons.
         Retry several times for whatever reasons.
 
+        Client.get_bucket() and Client.get_bucket().get_blob() can fail
+        even if the bucket is public (Storage Reader permission for allUsers
+        or allAuthenticatedUsers).
+        An anonymous client (Client.create_anonymous_client()) is needed for
+        public buckets. If this error occurs, then retry with an anonymous client.
+
         Returns:
-            blob: Blob object or None
-            bucket_obj: Bucket object or None
+            blob:
+                Blob object
+            bucket_obj:
+                Bucket object
         """
         bucket, path = self.get_bucket_path()
         cl = GCSURI.get_gcs_client(self._thread_id)
@@ -315,12 +389,24 @@ def get_blob(self, new=False) -> Blob:
                 if new and blob is None:
                     blob = Blob(name=path, bucket=bucket_obj)
                 break
+            except Forbidden:
+                logger.debug(
+                    "Bucket/blob is forbidden. Trying again with an anonymous client."
+                )
+                cl = GCSURI.get_gcs_anonymous_client(self._thread_id)
             except NotFound:
                 raise
             except PermissionDenied:
                 raise
             except Exception:
                 time.sleep(GCSURI.RETRY_BUCKET_DELAY)
 
+        if blob is None:
+            raise ValueError(
+                "GCS blob does not exist. Lack of {access_type} permission? {uri}".format(
+                    access_type="write" if new else "read", uri=self._uri
+                )
+            )
+
         return blob, bucket_obj
 
     def get_bucket_path(self) -> Tuple[str, str]:
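A sketch of the new fallback and error paths, with a hypothetical public bucket: if default credentials are missing or the authenticated client gets `Forbidden`, `get_blob()` retries with an anonymous client, and it now raises `ValueError` instead of returning `None` when the blob cannot be fetched.

```python
from autouri.gcsuri import GCSURI

try:
    # read() calls get_blob() internally; an anonymous client is used
    # automatically if the bucket is public and credentials are absent
    text = GCSURI('gs://some-public-bucket/hello.txt').read()
except ValueError:
    print('Blob does not exist or read permission is missing.')
```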
@@ -373,12 +459,47 @@ def get_public_url(self) -> str:
 
     @staticmethod
     def get_gcs_client(thread_id) -> storage.Client:
-        if thread_id in GCSURI._CACHED_GCS_CLIENT_PER_THREAD:
-            return GCSURI._CACHED_GCS_CLIENT_PER_THREAD[thread_id]
-        else:
-            cl = storage.Client()
-            GCSURI._CACHED_GCS_CLIENT_PER_THREAD[thread_id] = cl
-            return cl
+        """Get GCS client per thread_id.
+
+        Gets default credentials (internally calling google.auth.default()) from:
+        - Environment variable GOOGLE_APPLICATION_CREDENTIALS
+            - Set a service account key JSON file path as this environment variable.
+        - JSON file ~/.config/gcloud/application_default_credentials.json
+            - To use end-user's credentials.
+            - This file is created by `gcloud auth application-default login`.
+
+        If default credentials are not found, then:
+        - Make/return an anonymous client instead.
+        - For this thread_id, the anonymous client is cached instead of a
+          failed client with credentials.
+        """
+        cl = GCSURI._CACHED_GCS_CLIENTS.get(thread_id)
+
+        if cl is None:
+            try:
+                logger.debug("New GCS client for thread {id}.".format(id=thread_id))
+                cl = storage.Client()
+            except DefaultCredentialsError:
+                cl = GCSURI.get_gcs_anonymous_client(thread_id)
+            # an anonymous client can also be cached here
+            GCSURI._CACHED_GCS_CLIENTS[thread_id] = cl
+
+        return cl
+
+    @staticmethod
+    def get_gcs_anonymous_client(thread_id) -> storage.Client:
+        """Get GCS anonymous client per thread_id.
+        """
+        cl = GCSURI._CACHED_GCS_ANONYMOUS_CLIENTS.get(thread_id)
+
+        if cl is None:
+            logger.debug(
+                "New GCS anonymous client for thread {id}.".format(id=thread_id)
+            )
+            cl = storage.Client.create_anonymous_client()
+            GCSURI._CACHED_GCS_ANONYMOUS_CLIENTS[thread_id] = cl
+
+        return cl
 
     @staticmethod
     def init_gcsuri(
diff --git a/autouri/httpurl.py b/autouri/httpurl.py
index 8166cde..c238669 100644
--- a/autouri/httpurl.py
+++ b/autouri/httpurl.py
@@ -58,7 +58,7 @@ def get_metadata(self, skip_md5=False, make_md5_file=False):
         but corresponding URL on a public bucket will still have
         "Last-modified" property which points to creation time.
         """
-        ex, mt, sz, md5 = False, None, None, None
+        exists, mt, sz, md5 = False, None, None, None
         try:
             # get header only
             r = requests.get(
@@ -69,30 +69,30 @@ def get_metadata(self, skip_md5=False, make_md5_file=False):
             )
             r.raise_for_status()
             # make keys lower-case
-            h = {k.lower(): v for k, v in r.headers.items()}
-            ex = True
+            headers = {k.lower(): v for k, v in r.headers.items()}
+            exists = True
 
             if not skip_md5:
-                if "content-md5" in h:
-                    md5 = parse_md5_str(h["content-md5"])
-                elif "x-goog-hash" in h:
-                    hashes = h["x-goog-hash"].strip().split(",")
+                if "content-md5" in headers:
+                    md5 = parse_md5_str(headers["content-md5"])
+                elif "x-goog-hash" in headers:
+                    hashes = headers["x-goog-hash"].strip().split(",")
                     for hs in hashes:
                         if hs.strip().startswith("md5="):
                             raw = hs.strip().replace("md5=", "", 1)
                             md5 = parse_md5_str(raw)
-                if md5 is None and "etag" in h:
-                    md5 = parse_md5_str(h["etag"])
+                if md5 is None and "etag" in headers:
+                    md5 = parse_md5_str(headers["etag"])
                 if md5 is None:
                     md5 = self.md5_from_file
 
-            if "content-length" in h:
-                sz = int(h["content-length"])
-            elif "x-goog-stored-content-length" in h:
-                sz = int(h["x-goog-stored-content-length"])
+            if "content-length" in headers:
+                sz = int(headers["content-length"])
+            elif "x-goog-stored-content-length" in headers:
+                sz = int(headers["x-goog-stored-content-length"])
 
-            if "last-modified" in h:
-                mt = get_seconds_from_epoch(h["last-modified"])
+            if "last-modified" in headers:
+                mt = get_seconds_from_epoch(headers["last-modified"])
 
         except requests.exceptions.ConnectionError:
             pass
@@ -101,7 +101,7 @@ def get_metadata(self, skip_md5=False, make_md5_file=False):
         if status_code == 403:
             raise
 
-        return URIMetadata(exists=ex, mtime=mt, size=sz, md5=md5)
+        return URIMetadata(exists=exists, mtime=mt, size=sz, md5=md5)
 
     def read(self, byte=False):
         r = requests.get(
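The same header-parsing pattern, seen from the HTTP side (URL hypothetical): size comes from `Content-Length`, mtime from `Last-Modified`, and md5 from `Content-MD5`, `x-goog-hash`, or `ETag` when one of them carries a valid md5 string.

```python
from autouri import AutoURI

# skip_md5=True avoids downloading the file when no header carries a usable md5
meta = AutoURI('https://example.com/data/hello.txt').get_metadata(skip_md5=True)
print(meta.exists, meta.size, meta.mtime)
```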
""" - ex, mt, sz, md5 = False, None, None, None + exists, mt, sz, md5 = False, None, None, None try: # get header only r = requests.get( @@ -69,30 +69,30 @@ def get_metadata(self, skip_md5=False, make_md5_file=False): ) r.raise_for_status() # make keys lower-case - h = {k.lower(): v for k, v in r.headers.items()} - ex = True + headers = {k.lower(): v for k, v in r.headers.items()} + exists = True if not skip_md5: - if "content-md5" in h: - md5 = parse_md5_str(h["content-md5"]) - elif "x-goog-hash" in h: - hashes = h["x-goog-hash"].strip().split(",") + if "content-md5" in headers: + md5 = parse_md5_str(headers["content-md5"]) + elif "x-goog-hash" in headers: + hashes = headers["x-goog-hash"].strip().split(",") for hs in hashes: if hs.strip().startswith("md5="): raw = hs.strip().replace("md5=", "", 1) md5 = parse_md5_str(raw) - if md5 is None and "etag" in h: - md5 = parse_md5_str(h["etag"]) + if md5 is None and "etag" in headers: + md5 = parse_md5_str(headers["etag"]) if md5 is None: md5 = self.md5_from_file - if "content-length" in h: - sz = int(h["content-length"]) - elif "x-goog-stored-content-length" in h: - sz = int(h["x-goog-stored-content-length"]) + if "content-length" in headers: + sz = int(headers["content-length"]) + elif "x-goog-stored-content-length" in headers: + sz = int(headers["x-goog-stored-content-length"]) - if "last-modified" in h: - mt = get_seconds_from_epoch(h["last-modified"]) + if "last-modified" in headers: + mt = get_seconds_from_epoch(headers["last-modified"]) except requests.exceptions.ConnectionError: pass @@ -101,7 +101,7 @@ def get_metadata(self, skip_md5=False, make_md5_file=False): if status_code == 403: raise - return URIMetadata(exists=ex, mtime=mt, size=sz, md5=md5) + return URIMetadata(exists=exists, mtime=mt, size=sz, md5=md5) def read(self, byte=False): r = requests.get( diff --git a/autouri/s3uri.py b/autouri/s3uri.py index 7fb0488..0e48c3c 100644 --- a/autouri/s3uri.py +++ b/autouri/s3uri.py @@ -77,7 +77,7 @@ class S3URI(URIBase): Duration for presigned URLs in seconds. 
@@ -85,7 +85,7 @@ class S3URI(URIBase):
 
     DURATION_PRESIGNED_URL: int = 4233600
 
-    _CACHED_BOTO3_CLIENT_PER_THREAD = {}
+    _CACHED_BOTO3_CLIENTS = {}
     _CACHED_PRESIGNED_URLS = {}
 
     _S3_PUBLIC_URL_FORMAT = "http://{bucket}.s3.amazonaws.com/{path}"
@@ -107,7 +107,7 @@ def _get_lock(self, timeout=None, poll_interval=None):
         )
 
     def get_metadata(self, skip_md5=False, make_md5_file=False):
-        ex, mt, sz, md5 = False, None, None, None
+        exists, mt, sz, md5 = False, None, None, None
         cl = S3URI.get_boto3_client(self._thread_id)
         bucket, path = self.get_bucket_path()
 
@@ -117,28 +117,28 @@ def get_metadata(self, skip_md5=False, make_md5_file=False):
                 "HTTPHeaders"
             ]
             # make keys lower-case
-            h = {k.lower(): v for k, v in m.items()}
-            ex = True
+            headers = {k.lower(): v for k, v in m.items()}
+            exists = True
 
             if not skip_md5:
-                if "content-md5" in h:
-                    md5 = parse_md5_str(h["content-md5"])
-                elif "etag" in h:
-                    md5 = parse_md5_str(h["etag"])
+                if "content-md5" in headers:
+                    md5 = parse_md5_str(headers["content-md5"])
+                elif "etag" in headers:
+                    md5 = parse_md5_str(headers["etag"])
                 if md5 is None:
                     # make_md5_file is ignored for S3URI
                     md5 = self.md5_from_file
 
-            if "content-length" in h:
-                sz = int(h["content-length"])
+            if "content-length" in headers:
+                sz = int(headers["content-length"])
 
-            if "last-modified" in h:
-                mt = get_seconds_from_epoch(h["last-modified"])
+            if "last-modified" in headers:
+                mt = get_seconds_from_epoch(headers["last-modified"])
 
         except Exception:
             pass
 
-        return URIMetadata(exists=ex, mtime=mt, size=sz, md5=md5)
+        return URIMetadata(exists=exists, mtime=mt, size=sz, md5=md5)
 
     def read(self, byte=False):
         cl = S3URI.get_boto3_client(self._thread_id)
@@ -278,11 +278,11 @@ def get_public_url(self) -> str:
 
     @staticmethod
    def get_boto3_client(thread_id=-1) -> client:
-        if thread_id in S3URI._CACHED_BOTO3_CLIENT_PER_THREAD:
-            return S3URI._CACHED_BOTO3_CLIENT_PER_THREAD[thread_id]
+        if thread_id in S3URI._CACHED_BOTO3_CLIENTS:
+            return S3URI._CACHED_BOTO3_CLIENTS[thread_id]
         else:
             cl = client("s3")
-            S3URI._CACHED_BOTO3_CLIENT_PER_THREAD[thread_id] = cl
+            S3URI._CACHED_BOTO3_CLIENTS[thread_id] = cl
             return cl
 
     @staticmethod
diff --git a/setup.py b/setup.py
index a9867fc..53fb634 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 setuptools.setup(
     name="autouri",
-    version="0.2.0",
+    version="0.2.1",
     python_requires=">=3.6",
     scripts=["bin/autouri"],
     author="Jin wook Lee",
diff --git a/tests/test_gcsuri.py b/tests/test_gcsuri.py
index 89d0d31..62d3df5 100644
--- a/tests/test_gcsuri.py
+++ b/tests/test_gcsuri.py
@@ -315,8 +315,9 @@ def test_gcsuri_get_blob(gcs_v6_txt):
     b_new, _ = u_non_existing.get_blob(new=True)
     assert b_new is not None
 
-    b, _ = u_non_existing.get_blob(new=False)
-    assert b is None
+
+    with pytest.raises(ValueError):
+        u_non_existing.get_blob(new=False)
 
 
 def test_gcsuri_get_bucket_path():
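Finally, a sketch of the per-thread client caching behind the renamed `_CACHED_GCS_CLIENTS` / `_CACHED_BOTO3_CLIENTS` dicts (URIs hypothetical): passing a distinct `thread_id` per worker gives each thread its own cached client, since the underlying GCS and boto3 clients are not thread-safe.

```python
from concurrent.futures import ThreadPoolExecutor

from autouri.gcsuri import GCSURI

def fetch(i):
    # each worker gets its own cached storage.Client keyed by thread_id
    return GCSURI('gs://my-bucket/file-{}.txt'.format(i), thread_id=i).read()

with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(fetch, range(4)))
```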