diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 76c5f87cd8..4449df2451 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -318,7 +318,7 @@ wdl: - pwd - apt update && apt install -y default-jre - ${MAIN_PYTHON_PKG} -m virtualenv venv && . venv/bin/activate && pip install -U pip wheel && make prepare && make develop extras=[all] - - make test threads="${TEST_THREADS}" marker="${MARKER}" tests=src/toil/test/wdl/wdltoil_test.py + - make test threads="${TEST_THREADS}" marker="${MARKER}" tests=src/toil/test/wdl/ jobstore: rules: diff --git a/contrib/admin/mypy-with-ignore.py b/contrib/admin/mypy-with-ignore.py index bf19e459e7..a8bde51a5b 100755 --- a/contrib/admin/mypy-with-ignore.py +++ b/contrib/admin/mypy-with-ignore.py @@ -38,7 +38,6 @@ def main(): 'src/toil/provisioners/__init__.py', 'src/toil/provisioners/node.py', 'src/toil/provisioners/aws/boto2Context.py', - 'src/toil/provisioners/aws/awsProvisioner.py', 'src/toil/provisioners/aws/__init__.py', 'src/toil/batchSystems/slurm.py', 'src/toil/batchSystems/gridengine.py', diff --git a/docs/Makefile b/docs/Makefile index 5117fbf5b3..ae5fb19b64 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -12,9 +12,13 @@ BUILDDIR = _build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +.PHONY: help Makefile clean # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + rm -rf autoapi diff --git a/requirements-aws.txt b/requirements-aws.txt index 1ae51c64f6..b5a76b21ee 100644 --- a/requirements-aws.txt +++ b/requirements-aws.txt @@ -1,4 +1,6 @@ boto>=2.48.0, <3 -boto3-stubs[s3,sdb,iam,sts,boto3]>=1.28.3.post2, <2 +boto3-stubs[s3,sdb,iam,sts,boto3,ec2,autoscaling]>=1.28.3.post2, <2 mypy-boto3-iam>=1.28.3.post2, <2 # Need to force .post1 to be replaced -moto>=4.1.11, <5 +mypy-boto3-s3>=1.28.3.post2, <2 +moto>=5.0.3, <6 +ec2_metadata<3 \ No newline at end of file diff --git a/src/toil/common.py b/src/toil/common.py index 30899d9776..8a21fb5ca1 100644 --- a/src/toil/common.py +++ b/src/toil/common.py @@ -64,8 +64,7 @@ from typing_extensions import Literal from toil import logProcessContext, lookupEnvVar -from toil.batchSystems.options import (add_all_batchsystem_options, - set_batchsystem_options) +from toil.batchSystems.options import set_batchsystem_options from toil.bus import (ClusterDesiredSizeMessage, ClusterSizeMessage, JobCompletedMessage, @@ -75,17 +74,15 @@ MessageBus, QueueSizeMessage) from toil.fileStores import FileID -from toil.lib.aws import zone_to_region, build_tag_dict_from_env from toil.lib.compatibility import deprecated from toil.lib.io import try_path, AtomicFileCreate from toil.lib.retry import retry from toil.provisioners import (add_provisioner_options, - cluster_factory, - parse_node_types) + cluster_factory) from toil.realtimeLogger import RealtimeLogger from toil.statsAndLogging import (add_logging_options, set_logging_from_options) -from toil.version import dockerRegistry, dockerTag, version, baseVersion +from toil.version import dockerRegistry, dockerTag, version if TYPE_CHECKING: from toil.batchSystems.abstractBatchSystem import AbstractBatchSystem @@ -1440,6 +1437,8 @@ def __init__(self, bus: MessageBus, provisioner: Optional["AbstractProvisioner"] clusterName = str(provisioner.clusterName) if provisioner._zone is not None: if 
provisioner.cloud == 'aws': + # lazy import to avoid AWS dependency if the aws extra is not installed + from toil.lib.aws import zone_to_region # Remove AZ name region = zone_to_region(provisioner._zone) else: diff --git a/src/toil/jobStores/aws/utils.py b/src/toil/jobStores/aws/utils.py index 9f4e41e264..48ef581ff7 100644 --- a/src/toil/jobStores/aws/utils.py +++ b/src/toil/jobStores/aws/utils.py @@ -24,19 +24,18 @@ from botocore.client import Config from botocore.exceptions import ClientError -from toil.lib.aws import session +from toil.lib.aws import session, AWSServerErrors from toil.lib.aws.utils import connection_reset, get_bucket_region from toil.lib.compatibility import compat_bytes from toil.lib.retry import (DEFAULT_DELAYS, DEFAULT_TIMEOUT, - ErrorCondition, get_error_code, get_error_message, get_error_status, old_retry, retry) if TYPE_CHECKING: - from mypy_boto3_s3 import S3Client, S3ServiceResource + from mypy_boto3_s3 import S3ServiceResource logger = logging.getLogger(__name__) @@ -193,10 +192,7 @@ def fileSizeAndTime(localFilePath): return file_stat.st_size, file_stat.st_mtime -@retry(errors=[ErrorCondition( - error=ClientError, - error_codes=[404, 500, 502, 503, 504] -)]) +@retry(errors=[AWSServerErrors]) def uploadFromPath(localFilePath: str, resource, bucketName: str, @@ -232,10 +228,7 @@ def uploadFromPath(localFilePath: str, return version -@retry(errors=[ErrorCondition( - error=ClientError, - error_codes=[404, 500, 502, 503, 504] -)]) +@retry(errors=[AWSServerErrors]) def uploadFile(readable, resource, bucketName: str, @@ -287,10 +280,7 @@ class ServerSideCopyProhibitedError(RuntimeError): insists that you pay to download and upload the data yourself instead. """ -@retry(errors=[ErrorCondition( - error=ClientError, - error_codes=[404, 500, 502, 503, 504] -)]) +@retry(errors=[AWSServerErrors]) def copyKeyMultipart(resource: "S3ServiceResource", srcBucketName: str, srcKeyName: str, diff --git a/src/toil/lib/aws/__init__.py b/src/toil/lib/aws/__init__.py index 0968b566fe..93a221b93b 100644 --- a/src/toil/lib/aws/__init__.py +++ b/src/toil/lib/aws/__init__.py @@ -16,11 +16,25 @@ import os import re import socket +import toil.lib.retry from http.client import HTTPException -from typing import Dict, MutableMapping, Optional +from typing import Dict, MutableMapping, Optional, Union, Literal from urllib.error import URLError from urllib.request import urlopen +from botocore.exceptions import ClientError + +from mypy_boto3_s3.literals import BucketLocationConstraintType + +AWSRegionName = Union[BucketLocationConstraintType, Literal["us-east-1"]] + +# These are errors where we think something randomly +# went wrong on the AWS side and we ought to retry. +AWSServerErrors = toil.lib.retry.ErrorCondition( + error=ClientError, + error_codes=[404, 500, 502, 503, 504] +) + logger = logging.getLogger(__name__) # This file isn't allowed to import anything that depends on Boto or Boto3, @@ -67,11 +81,10 @@ def get_aws_zone_from_metadata() -> Optional[str]: # metadata. 
try: # Use the EC2 metadata service - import boto - str(boto) # to prevent removal of the import - from boto.utils import get_instance_metadata + from ec2_metadata import ec2_metadata + logger.debug("Fetch AZ from EC2 metadata") - return get_instance_metadata()['placement']['availability-zone'] + return ec2_metadata.availability_zone except ImportError: # This is expected to happen a lot - logger.debug("No boto to fetch ECS metadata") + logger.debug("No ec2_metadata to fetch EC2 metadata") @@ -128,7 +141,7 @@ def get_current_aws_zone() -> Optional[str]: get_aws_zone_from_environment_region() or \ get_aws_zone_from_boto() -def zone_to_region(zone: str) -> str: +def zone_to_region(zone: str) -> AWSRegionName: """Get a region (e.g. us-west-2) from a zone (e.g. us-west-1c).""" # re.compile() caches the regex internally so we don't have to availability_zone = re.compile(r'^([a-z]{2}-[a-z]+-[1-9][0-9]*)([a-z])$') diff --git a/src/toil/lib/aws/iam.py b/src/toil/lib/aws/iam.py index d1740eac8a..1351ff8cf2 100644 --- a/src/toil/lib/aws/iam.py +++ b/src/toil/lib/aws/iam.py @@ -257,8 +257,8 @@ def get_policy_permissions(region: str) -> AllowedActionCollection: :param zone: AWS zone to connect to """ - iam: IAMClient = cast(IAMClient, get_client('iam', region)) - sts: STSClient = cast(STSClient, get_client('sts', region)) + iam: IAMClient = get_client('iam', region) + sts: STSClient = get_client('sts', region) #TODO Consider effect: deny at some point allowed_actions: AllowedActionCollection = defaultdict(lambda: {'Action': [], 'NotAction': []}) try: diff --git a/src/toil/lib/aws/session.py b/src/toil/lib/aws/session.py index c3c679c02f..96cacae0bb 100644 --- a/src/toil/lib/aws/session.py +++ b/src/toil/lib/aws/session.py @@ -15,16 +15,22 @@ import logging import os import threading -from typing import Dict, Optional, Tuple, cast +from typing import Dict, Optional, Tuple, cast, Union, Literal, overload, TypeVar import boto3 import boto3.resources.base -import boto.connection import botocore from boto3 import Session from botocore.client import Config from botocore.session import get_session from botocore.utils import JSONFileCache +from mypy_boto3_autoscaling import AutoScalingClient +from mypy_boto3_ec2 import EC2Client, EC2ServiceResource +from mypy_boto3_iam import IAMClient, IAMServiceResource +from mypy_boto3_s3 import S3Client, S3ServiceResource +from mypy_boto3_sdb import SimpleDBClient +from mypy_boto3_sts import STSClient logger = logging.getLogger(__name__) @@ -120,6 +126,13 @@ def session(self, region: Optional[str]) -> boto3.session.Session: storage.item = _new_boto3_session(region_name=region) return cast(boto3.session.Session, storage.item) + @overload + def resource(self, region: Optional[str], service_name: Literal["s3"], endpoint_url: Optional[str] = None) -> S3ServiceResource: ... + @overload + def resource(self, region: Optional[str], service_name: Literal["iam"], endpoint_url: Optional[str] = None) -> IAMServiceResource: ... + @overload + def resource(self, region: Optional[str], service_name: Literal["ec2"], endpoint_url: Optional[str] = None) -> EC2ServiceResource: ... + def resource(self, region: Optional[str], service_name: str, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource: """ Get the Boto3 Resource to use with the given service (like 'ec2') in the given region.
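The `Literal` overloads above are what let call sites drop their `cast()` wrappers. A minimal sketch (not part of the diff) of what mypy now infers; the region and bucket name are placeholders:

```python
from toil.lib.aws.session import AWSConnectionManager

aws = AWSConnectionManager()

# Previously typed as the generic boto3.resources.base.ServiceResource,
# which forced cast(S3ServiceResource, ...) at every call site.
s3 = aws.resource('us-west-2', 's3')  # now inferred as S3ServiceResource
bucket = s3.Bucket('example-bucket')  # placeholder bucket name
```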
@@ -146,7 +159,28 @@ def resource(self, region: Optional[str], service_name: str, endpoint_url: Optio return cast(boto3.resources.base.ServiceResource, storage.item) - def client(self, region: Optional[str], service_name: str, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient: + @overload + def client(self, region: Optional[str], service_name: Literal["ec2"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> EC2Client: ... + @overload + def client(self, region: Optional[str], service_name: Literal["iam"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> IAMClient: ... + @overload + def client(self, region: Optional[str], service_name: Literal["s3"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> S3Client: ... + @overload + def client(self, region: Optional[str], service_name: Literal["sts"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> STSClient: ... + @overload + def client(self, region: Optional[str], service_name: Literal["sdb"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> SimpleDBClient: ... + @overload + def client(self, region: Optional[str], service_name: Literal["autoscaling"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> AutoScalingClient: ... + + + def client(self, region: Optional[str], service_name: Literal["ec2", "iam", "s3", "sts", "sdb", "autoscaling"], endpoint_url: Optional[str] = None, + config: Optional[Config] = None) -> botocore.client.BaseClient: """ Get the Boto3 Client to use with the given service (like 'ec2') in the given region. @@ -159,9 +193,9 @@ def client(self, region: Optional[str], service_name: str, endpoint_url: Optiona # Don't try and memoize if a custom config is used with _init_lock: if endpoint_url is not None: - return self.session(region).client(service_name, endpoint_url=endpoint_url, config=config) # type: ignore + return self.session(region).client(service_name, endpoint_url=endpoint_url, config=config) else: - return self.session(region).client(service_name, config=config) # type: ignore + return self.session(region).client(service_name, config=config) key = (region, service_name, endpoint_url) storage = self.client_cache[key] @@ -172,25 +206,12 @@ def client(self, region: Optional[str], service_name: str, endpoint_url: Optiona if endpoint_url is not None: # The Boto3 stubs are probably missing an overload here too. See: # - storage.item = self.session(region).client(service_name, endpoint_url=endpoint_url) # type: ignore + storage.item = self.session(region).client(service_name, endpoint_url=endpoint_url) else: # We might not be able to pass None to Boto3 and have it be the same as no argument. - storage.item = self.session(region).client(service_name) # type: ignore + storage.item = self.session(region).client(service_name) return cast(botocore.client.BaseClient , storage.item) - def boto2(self, region: Optional[str], service_name: str) -> boto.connection.AWSAuthConnection: - """ - Get the connected boto2 connection for the given region and service. 
- """ - if service_name == 'iam': - # IAM connections are regionless - region = 'universal' - key = (region, service_name) - storage = self.boto2_cache[key] - if not hasattr(storage, 'item'): - with _init_lock: - storage.item = getattr(boto, service_name).connect_to_region(region, profile_name=os.environ.get("TOIL_AWS_PROFILE", None)) - return cast(boto.connection.AWSAuthConnection, storage.item) # If you don't want your own AWSConnectionManager, we have a global one and some global functions _global_manager = AWSConnectionManager() @@ -205,7 +226,20 @@ def establish_boto3_session(region_name: Optional[str] = None) -> Session: # Just use a global version of the manager. Note that we change the argument order! return _global_manager.session(region_name) -def client(service_name: str, region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient: +@overload +def client(service_name: Literal["ec2"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> EC2Client: ... +@overload +def client(service_name: Literal["iam"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> IAMClient: ... +@overload +def client(service_name: Literal["s3"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> S3Client: ... +@overload +def client(service_name: Literal["sts"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> STSClient: ... +@overload +def client(service_name: Literal["sdb"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> SimpleDBClient: ... +@overload +def client(service_name: Literal["autoscaling"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> AutoScalingClient: ... + +def client(service_name: Literal["ec2", "iam", "s3", "sts", "sdb", "autoscaling"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient: """ Get a Boto 3 client for a particular AWS service, usable by the current thread. @@ -215,7 +249,14 @@ def client(service_name: str, region_name: Optional[str] = None, endpoint_url: O # Just use a global version of the manager. Note that we change the argument order! return _global_manager.client(region_name, service_name, endpoint_url=endpoint_url, config=config) -def resource(service_name: str, region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource: +@overload +def resource(service_name: Literal["s3"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> S3ServiceResource: ... +@overload +def resource(service_name: Literal["iam"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> IAMServiceResource: ... +@overload +def resource(service_name: Literal["ec2"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> EC2ServiceResource: ... + +def resource(service_name: Literal["s3", "iam", "ec2"], region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource: """ Get a Boto 3 resource for a particular AWS service, usable by the current thread. 
diff --git a/src/toil/lib/aws/utils.py b/src/toil/lib/aws/utils.py index 3e26691358..bee4973dfd 100644 --- a/src/toil/lib/aws/utils.py +++ b/src/toil/lib/aws/utils.py @@ -15,7 +15,6 @@ import logging import os import socket -import sys from typing import (Any, Callable, ContextManager, @@ -25,11 +24,10 @@ List, Optional, Set, - Union, cast) from urllib.parse import ParseResult -from toil.lib.aws import session +from toil.lib.aws import session, AWSRegionName, AWSServerErrors from toil.lib.misc import printq from toil.lib.retry import (DEFAULT_DELAYS, DEFAULT_TIMEOUT, @@ -38,11 +36,6 @@ old_retry, retry) -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - try: from boto.exception import BotoServerError, S3ResponseError from botocore.exceptions import ClientError @@ -77,11 +70,10 @@ 'EC2ThrottledException', ] -@retry(errors=[BotoServerError]) +@retry(errors=[AWSServerErrors]) def delete_iam_role( role_name: str, region: Optional[str] = None, quiet: bool = True ) -> None: - from boto.iam.connection import IAMConnection # TODO: the Boto3 type hints are a bit overzealous here; they want hundreds # of overloads of the client-getting methods to exist based on the literal @@ -92,9 +84,8 @@ def delete_iam_role( # we wanted MyPy to be able to understand us. So at some point we should # consider revising our API here to be less annoying to explain to the type # checker. - iam_client = cast(IAMClient, session.client('iam', region_name=region)) - iam_resource = cast(IAMServiceResource, session.resource('iam', region_name=region)) - boto_iam_connection = IAMConnection() + iam_client = session.client('iam', region_name=region) + iam_resource = session.resource('iam', region_name=region) role = iam_resource.Role(role_name) # normal policies for attached_policy in role.attached_policies.all(): @@ -103,17 +94,16 @@ # inline policies for inline_policy in role.policies.all(): printq(f'Deleting inline policy: {inline_policy.policy_name} from role {role.name}', quiet) - # couldn't find an easy way to remove inline policies with boto3; use boto - boto_iam_connection.delete_role_policy(role.name, inline_policy.policy_name) + iam_client.delete_role_policy(RoleName=role.name, PolicyName=inline_policy.policy_name) iam_client.delete_role(RoleName=role_name) printq(f'Role {role_name} successfully deleted.', quiet) -@retry(errors=[BotoServerError]) +@retry(errors=[AWSServerErrors]) def delete_iam_instance_profile( instance_profile_name: str, region: Optional[str] = None, quiet: bool = True ) -> None: - iam_resource = cast(IAMServiceResource, session.resource("iam", region_name=region)) + iam_resource = session.resource("iam", region_name=region) instance_profile = iam_resource.InstanceProfile(instance_profile_name) if instance_profile.roles is not None: for role in instance_profile.roles: @@ -123,11 +113,11 @@ printq(f'Instance profile "{instance_profile_name}" successfully deleted.', quiet) -@retry(errors=[BotoServerError]) +@retry(errors=[AWSServerErrors]) def delete_sdb_domain( sdb_domain_name: str, region: Optional[str] = None, quiet: bool = True ) -> None: - sdb_client = cast(SimpleDBClient, session.client("sdb", region_name=region)) + sdb_client = session.client("sdb", region_name=region) sdb_client.delete_domain(DomainName=sdb_domain_name) printq(f'SDB Domain: "{sdb_domain_name}" successfully deleted.', quiet) @@ -162,7 +152,7 @@ def retry_s3(delays: Iterable[float] = DEFAULT_DELAYS, timeout: float =
DEFAULT_ """ return old_retry(delays=delays, timeout=timeout, predicate=predicate) -@retry(errors=[BotoServerError]) +@retry(errors=[AWSServerErrors]) def delete_s3_bucket( s3_resource: "S3ServiceResource", bucket: str, @@ -195,7 +185,7 @@ def delete_s3_bucket( def create_s3_bucket( s3_resource: "S3ServiceResource", bucket_name: str, - region: Union["BucketLocationConstraintType", Literal["us-east-1"]], + region: AWSRegionName, ) -> "Bucket": """ Create an AWS S3 bucket, using the given Boto3 S3 session, with the @@ -238,7 +228,7 @@ def enable_public_objects(bucket_name: str) -> None: would be a very awkward way to do it. So we restore the old behavior. """ - s3_client = cast(S3Client, session.client('s3')) + s3_client = session.client('s3') # Even though the new default is for public access to be prohibited, this # is implemented by adding new things attached to the bucket. If we remove @@ -261,7 +251,7 @@ def get_bucket_region(bucket_name: str, endpoint_url: Optional[str] = None, only :param only_strategies: For testing, use only strategies with 1-based numbers in this set. """ - s3_client = cast(S3Client, session.client('s3', endpoint_url=endpoint_url)) + s3_client = session.client('s3', endpoint_url=endpoint_url) def attempt_get_bucket_location() -> Optional[str]: """ @@ -283,7 +273,7 @@ def attempt_get_bucket_location_from_us_east_1() -> Optional[str]: # It could also be because AWS open data buckets (which we tend to # encounter this problem for) tend to actually themselves be in # us-east-1. - backup_s3_client = cast(S3Client, session.client('s3', region_name='us-east-1')) + backup_s3_client = session.client('s3', region_name='us-east-1') return backup_s3_client.get_bucket_location(Bucket=bucket_name).get('LocationConstraint', None) def attempt_head_bucket() -> Optional[str]: @@ -368,11 +358,11 @@ def get_object_for_url(url: ParseResult, existing: Optional[bool] = None) -> "Ob try: # Get the bucket's region to avoid a redirect per request region = get_bucket_region(bucket_name, endpoint_url=endpoint_url) - s3 = cast(S3ServiceResource, session.resource('s3', region_name=region, endpoint_url=endpoint_url)) + s3 = session.resource('s3', region_name=region, endpoint_url=endpoint_url) except ClientError: # Probably don't have permission. # TODO: check if it is that - s3 = cast(S3ServiceResource, session.resource('s3', endpoint_url=endpoint_url)) + s3 = session.resource('s3', endpoint_url=endpoint_url) obj = s3.Object(bucket_name, key_name) objExists = True @@ -394,7 +384,7 @@ def get_object_for_url(url: ParseResult, existing: Optional[bool] = None) -> "Ob return obj -@retry(errors=[BotoServerError]) +@retry(errors=[AWSServerErrors]) def list_objects_for_url(url: ParseResult) -> List[str]: """ Extracts a key (object) from a given parsed s3:// URL. 
The URL will be @@ -419,7 +409,7 @@ def list_objects_for_url(url: ParseResult) -> List[str]: if host: endpoint_url = f'{protocol}://{host}' + f':{port}' if port else '' - client = cast(S3Client, session.client('s3', endpoint_url=endpoint_url)) + client = session.client('s3', endpoint_url=endpoint_url) listing = [] diff --git a/src/toil/lib/ec2.py b/src/toil/lib/ec2.py index cf32925b99..8b3baa94b3 100644 --- a/src/toil/lib/ec2.py +++ b/src/toil/lib/ec2.py @@ -1,13 +1,13 @@ import logging import time from base64 import b64encode -from operator import attrgetter -from typing import Dict, Iterable, List, Optional, Union +from operator import itemgetter +from typing import Dict, Iterable, List, Optional, Union, TYPE_CHECKING, Generator, Callable, Mapping, Any +import botocore.client from boto3.resources.base import ServiceResource from boto.ec2.instance import Instance as Boto2Instance from boto.ec2.spotinstancerequest import SpotInstanceRequest -from botocore.client import BaseClient from toil.lib.aws.session import establish_boto3_session from toil.lib.aws.utils import flatten_tags @@ -18,6 +18,11 @@ old_retry, retry) +from mypy_boto3_ec2.client import EC2Client +from mypy_boto3_autoscaling.client import AutoScalingClient +from mypy_boto3_ec2.type_defs import SpotInstanceRequestTypeDef, DescribeInstancesResultTypeDef, InstanceTypeDef +from mypy_boto3_ec2.service_resource import EC2ServiceResource, Instance + a_short_time = 5 a_long_time = 60 * 60 logger = logging.getLogger(__name__) @@ -38,6 +43,7 @@ def not_found(e): # Not the right kind of error return False + def inconsistencies_detected(e): if get_error_code(e) == 'InvalidGroup.NotFound': return True @@ -45,6 +51,7 @@ def inconsistencies_detected(e): matches = ('invalid iam instance profile' in m) or ('no associated iam roles' in m) return matches + # We also define these error categories for the new retry decorator INCONSISTENCY_ERRORS = [ErrorCondition(boto_error_codes=['InvalidGroup.NotFound']), ErrorCondition(error_message_must_include='Invalid IAM Instance Profile'), @@ -62,9 +69,10 @@ def __init__(self, resource, to_state, state): super().__init__( "Expected state of %s to be '%s' but got '%s'" % (resource, to_state, state)) - -def wait_transition(resource, from_states, to_state, - state_getter=attrgetter('state')): + + +def wait_transition(boto3_ec2: EC2Client, resource: InstanceTypeDef, from_states: Iterable[str], to_state: str, + state_getter: Callable[[InstanceTypeDef], str]=lambda x: x.get('State').get('Name')): """ Wait until the specified EC2 resource (instance, image, volume, ...) transitions from any of the given 'from' states to the specified 'to' state. 
If the instance is found in a state @@ -76,22 +84,24 @@ def wait_transition(resource, from_states, to_state, :param to_state: the state of the resource when this method returns """ state = state_getter(resource) + instance_id = resource["InstanceId"] while state in from_states: time.sleep(a_short_time) for attempt in retry_ec2(): with attempt: - resource.update(validate=True) + described = boto3_ec2.describe_instances(InstanceIds=[instance_id]) + resource = described["Reservations"][0]["Instances"][0] # there should only be one requested state = state_getter(resource) if state != to_state: raise UnexpectedResourceState(resource, to_state, state) -def wait_instances_running(ec2, instances: Iterable[Boto2Instance]) -> Iterable[Boto2Instance]: +def wait_instances_running(boto3_ec2: EC2Client, instances: Iterable[InstanceTypeDef]) -> Generator[InstanceTypeDef, None, None]: """ Wait until no instance in the given iterable is 'pending'. Yield every instance that entered the running state as soon as it does. - :param boto.ec2.connection.EC2Connection ec2: the EC2 connection to use for making requests - :param Iterable[Boto2Instance] instances: the instances to wait on - :rtype: Iterable[Boto2Instance] + :param EC2Client boto3_ec2: the EC2 connection to use for making requests + :param Iterable[InstanceTypeDef] instances: the instances to wait on + :rtype: Iterable[InstanceTypeDef] """ @@ -100,17 +110,18 @@ def wait_instances_running(ec2, instances: Iterable[Boto2Instance]) -> Iterable[ while True: pending_ids = set() for i in instances: - if i.state == 'pending': - pending_ids.add(i.id) - elif i.state == 'running': - if i.id in running_ids: + i: InstanceTypeDef + if i['State']['Name'] == 'pending': + pending_ids.add(i['InstanceId']) + elif i['State']['Name'] == 'running': + if i['InstanceId'] in running_ids: raise RuntimeError("An instance was already added to the list of running instance IDs. Maybe there is a duplicate.") - running_ids.add(i.id) + running_ids.add(i['InstanceId']) yield i else: - if i.id in other_ids: + if i['InstanceId'] in other_ids: raise RuntimeError("An instance was already added to the list of other instances. Maybe there is a duplicate.") - other_ids.add(i.id) + other_ids.add(i['InstanceId']) yield i logger.info('%i instance(s) pending, %i running, %i other.', *list(map(len, (pending_ids, running_ids, other_ids)))) @@ -121,14 +132,16 @@ def wait_instances_running(ec2, instances: Iterable[Boto2Instance]) -> Iterable[ time.sleep(seconds) for attempt in retry_ec2(): with attempt: - instances = ec2.get_only_instances(list(pending_ids)) + described_instances = boto3_ec2.describe_instances(InstanceIds=list(pending_ids)) + instances = [instance for reservation in described_instances["Reservations"] for instance in reservation["Instances"]] -def wait_spot_requests_active(ec2, requests: Iterable[SpotInstanceRequest], timeout: float = None, tentative: bool = False) -> Iterable[List[SpotInstanceRequest]]: +def wait_spot_requests_active(boto3_ec2: EC2Client, requests: Iterable[SpotInstanceRequestTypeDef], timeout: Optional[float] = None, tentative: bool = False) -> Iterable[List[SpotInstanceRequestTypeDef]]: """ Wait until no spot request in the given iterator is in the 'open' state or, optionally, a timeout occurs. Yield spot requests as soon as they leave the 'open' state. + :param boto3_ec2: ec2 client :param requests: The requests to wait on. :param timeout: Maximum time in seconds to spend waiting or None to wait forever.
If a @@ -145,11 +158,11 @@ def wait_spot_requests_active(ec2, requests: Iterable[SpotInstanceRequest], time other_ids = set() open_ids = None - def cancel(): + def cancel() -> None: logger.warning('Cancelling remaining %i spot requests.', len(open_ids)) - ec2.cancel_spot_instance_requests(list(open_ids)) + boto3_ec2.cancel_spot_instance_requests(SpotInstanceRequestIds=list(open_ids)) - def spot_request_not_found(e): + def spot_request_not_found(e: Exception) -> bool: return get_error_code(e) == 'InvalidSpotInstanceRequestID.NotFound' try: @@ -157,30 +170,31 @@ def spot_request_not_found(e): open_ids, eval_ids, fulfill_ids = set(), set(), set() batch = [] for r in requests: - if r.state == 'open': - open_ids.add(r.id) - if r.status.code == 'pending-evaluation': - eval_ids.add(r.id) - elif r.status.code == 'pending-fulfillment': - fulfill_ids.add(r.id) + r: SpotInstanceRequestTypeDef # pycharm thinks it is a string + if r['State'] == 'open': + open_ids.add(r['SpotInstanceRequestId']) + if r['Status']['Code'] == 'pending-evaluation': + eval_ids.add(r['SpotInstanceRequestId']) + elif r['Status']['Code'] == 'pending-fulfillment': + fulfill_ids.add(r['SpotInstanceRequestId']) else: logger.info( 'Request %s entered status %s indicating that it will not be ' - 'fulfilled anytime soon.', r.id, r.status.code) - elif r.state == 'active': - if r.id in active_ids: + 'fulfilled anytime soon.', r['SpotInstanceRequestId'], r['Status']['Code']) + elif r['State'] == 'active': + if r['SpotInstanceRequestId'] in active_ids: raise RuntimeError("A request was already added to the list of active requests. Maybe there are duplicate requests.") - active_ids.add(r.id) + active_ids.add(r['SpotInstanceRequestId']) batch.append(r) else: - if r.id in other_ids: + if r['SpotInstanceRequestId'] in other_ids: raise RuntimeError("A request was already added to the list of other IDs. Maybe there are duplicate requests.") - other_ids.add(r.id) + other_ids.add(r['SpotInstanceRequestId']) batch.append(r) if batch: yield batch logger.info('%i spot request(s) are open (%i of which are pending evaluation and %i ' - 'are pending fulfillment), %i are active and %i are in another state.', + 'are pending fulfillment), %i are active and %i are in another state.', *list(map(len, (open_ids, eval_ids, fulfill_ids, active_ids, other_ids)))) if not open_ids or tentative and not eval_ids and not fulfill_ids: break @@ -192,8 +206,7 @@ def spot_request_not_found(e): time.sleep(sleep_time) for attempt in retry_ec2(retry_while=spot_request_not_found): with attempt: - requests = ec2.get_all_spot_instance_requests( - list(open_ids)) + requests = boto3_ec2.describe_spot_instance_requests(SpotInstanceRequestIds=list(open_ids))['SpotInstanceRequests'] except BaseException: if open_ids: with panic(logger): @@ -204,29 +217,32 @@ def spot_request_not_found(e): cancel() -def create_spot_instances(ec2, price, image_id, spec, num_instances=1, timeout=None, tentative=False, tags=None) -> Iterable[List[Boto2Instance]]: +def create_spot_instances(boto3_ec2: EC2Client, price, image_id, spec, num_instances=1, timeout=None, tentative=False, tags=None) -> Generator[DescribeInstancesResultTypeDef, None, None]: """ Create instances on the spot market.
""" + def spotRequestNotFound(e): return getattr(e, 'error_code', None) == "InvalidSpotInstanceRequestID.NotFound" + spec['LaunchSpecification'].update({'ImageId': image_id}) # boto3 image id is in the launch specification for attempt in retry_ec2(retry_for=a_long_time, retry_while=inconsistencies_detected): with attempt: - requests = ec2.request_spot_instances( - price, image_id, count=num_instances, **spec) + requests_dict = boto3_ec2.request_spot_instances( + SpotPrice=price, InstanceCount=num_instances, **spec) + requests = requests_dict['SpotInstanceRequests'] if tags is not None: - for requestID in (request.id for request in requests): + for requestID in (request['SpotInstanceRequestId'] for request in requests): for attempt in retry_ec2(retry_while=spotRequestNotFound): with attempt: - ec2.create_tags([requestID], tags) + boto3_ec2.create_tags(Resources=[requestID], Tags=tags) num_active, num_other = 0, 0 # noinspection PyUnboundLocalVariable,PyTypeChecker # request_spot_instances's type annotation is wrong - for batch in wait_spot_requests_active(ec2, + for batch in wait_spot_requests_active(boto3_ec2, requests, timeout=timeout, tentative=tentative): @@ -244,7 +260,12 @@ def spotRequestNotFound(e): if instance_ids: # This next line is the reason we batch. It's so we can get multiple instances in # a single request. - yield ec2.get_only_instances(instance_ids) + for instance_id in instance_ids: + for attempt in retry_ec2(): + with attempt: + # Increase hop limit from 1 to use Instance Metadata V2 + boto3_ec2.modify_instance_metadata_options(InstanceId=instance_id, HttpPutResponseHopLimit=3) + yield boto3_ec2.describe_instances(InstanceIds=instance_ids) if not num_active: message = 'None of the spot requests entered the active state' if tentative: @@ -255,22 +276,43 @@ def spotRequestNotFound(e): logger.warning('%i request(s) entered a state other than active.', num_other) -def create_ondemand_instances(ec2, image_id, spec, num_instances=1) -> List[Boto2Instance]: +def create_ondemand_instances(boto3_ec2: EC2Client, image_id: str, spec: Mapping[str, Any], num_instances: int=1) -> List[InstanceTypeDef]: """ Requests the RunInstances EC2 API call but accounts for the race between recently created instance profiles, IAM roles and an instance creation that refers to them. :rtype: List[Boto2Instance] """ - instance_type = spec['instance_type'] + instance_type = spec['InstanceType'] logger.info('Creating %s instance(s) ... ', instance_type) + boto_instance_list = [] for attempt in retry_ec2(retry_for=a_long_time, retry_while=inconsistencies_detected): with attempt: - return ec2.run_instances(image_id, - min_count=num_instances, - max_count=num_instances, - **spec).instances + boto_instance_list: List[InstanceTypeDef] = boto3_ec2.run_instances(ImageId=image_id, + MinCount=num_instances, + MaxCount=num_instances, + **spec)['Instances'] + + return boto_instance_list + + +def increase_instance_hop_limit(boto3_ec2: EC2Client, boto_instance_list: List[InstanceTypeDef]) -> None: + """ + Increase the default HTTP hop limit, as we are running Toil and Kubernetes inside a Docker container, so the default + hop limit of 1 will not be enough when grabbing metadata information with ec2_metadata + + Must be called after the instances are guaranteed to be running. 
+ + :param boto_instance_list: List of boto instances to modify + :return: + """ + for boto_instance in boto_instance_list: + instance_id = boto_instance['InstanceId'] + for attempt in retry_ec2(): + with attempt: + # Increase hop limit from 1 to use Instance Metadata V2 + boto3_ec2.modify_instance_metadata_options(InstanceId=instance_id, HttpPutResponseHopLimit=3) def prune(bushy: dict) -> dict: @@ -289,6 +331,7 @@ def prune(bushy: dict) -> dict: # catch, and to wait on IAM items. iam_client = establish_boto3_session().client('iam') + # exception is generated by a factory so we weirdly need a client instance to reference it @retry(errors=[iam_client.exceptions.NoSuchEntityException], intervals=[1, 1, 2, 4, 8, 16, 32, 64]) @@ -301,7 +344,7 @@ def wait_until_instance_profile_arn_exists(instance_profile_arn: str): @retry(intervals=[5, 5, 10, 20, 20, 20, 20], errors=INCONSISTENCY_ERRORS) -def create_instances(ec2_resource: ServiceResource, +def create_instances(ec2_resource: EC2ServiceResource, image_id: str, key_name: str, instance_type: str, @@ -312,7 +355,7 @@ def create_instances(ec2_resource: ServiceResource, instance_profile_arn: Optional[str] = None, placement_az: Optional[str] = None, subnet_id: str = None, - tags: Optional[Dict[str, str]] = None) -> List[dict]: + tags: Optional[Dict[str, str]] = None) -> List[Instance]: """ Replaces create_ondemand_instances. Uses boto3 and returns a list of Boto3 instance dicts. @@ -336,7 +379,10 @@ def create_instances(ec2_resource: ServiceResource, 'InstanceType': instance_type, 'UserData': user_data, 'BlockDeviceMappings': block_device_map, - 'SubnetId': subnet_id} + 'SubnetId': subnet_id, + # Metadata V2 defaults hops to 1, which is an issue when running inside a docker container + # https://github.com/adamchainz/ec2-metadata?tab=readme-ov-file#instance-metadata-service-version-2 + 'MetadataOptions': {'HttpPutResponseHopLimit': 3}} if instance_profile_arn: # We could just retry when we get an error because the ARN doesn't @@ -357,8 +403,9 @@ def create_instances(ec2_resource: ServiceResource, return ec2_resource.create_instances(**prune(request)) + @retry(intervals=[5, 5, 10, 20, 20, 20, 20], errors=INCONSISTENCY_ERRORS) -def create_launch_template(ec2_client: BaseClient, +def create_launch_template(ec2_client: EC2Client, template_name: str, image_id: str, key_name: str, @@ -400,7 +447,10 @@ def create_launch_template(ec2_client: BaseClient, 'InstanceType': instance_type, 'UserData': user_data, 'BlockDeviceMappings': block_device_map, - 'SubnetId': subnet_id} + 'SubnetId': subnet_id, + # Increase hop limit from 1 to use Instance Metadata V2 + 'MetadataOptions': {'HttpPutResponseHopLimit': 3} + } if instance_profile_arn: # We could just retry when we get an error because the ARN doesn't @@ -413,6 +463,7 @@ def create_launch_template(ec2_client: BaseClient, if placement_az: template['Placement'] = {'AvailabilityZone': placement_az} + flat_tags = [] if tags: # Tag everything when we make it. 
flat_tags = flatten_tags(tags) @@ -429,17 +480,16 @@ def create_launch_template(ec2_client: BaseClient, @retry(intervals=[5, 5, 10, 20, 20, 20, 20], errors=INCONSISTENCY_ERRORS) -def create_auto_scaling_group(autoscaling_client: BaseClient, +def create_auto_scaling_group(autoscaling_client: AutoScalingClient, asg_name: str, launch_template_ids: Dict[str, str], vpc_subnets: List[str], min_size: int, max_size: int, - instance_types: Optional[List[str]] = None, + instance_types: Optional[Iterable[str]] = None, spot_bid: Optional[float] = None, spot_cheapest: bool = False, tags: Optional[Dict[str, str]] = None) -> None: - """ Create a new Auto Scaling Group with the given name (which is also its unique identifier). @@ -472,7 +522,7 @@ def create_auto_scaling_group(autoscaling_client: BaseClient, """ if instance_types is None: - instance_types = [] + instance_types: List[str] = [] if instance_types is not None and len(instance_types) > 20: raise RuntimeError(f"Too many instance types ({len(instance_types)}) in group; AWS supports only 20.") @@ -493,8 +543,8 @@ def get_launch_template_spec(instance_type): # We need to use a launch template per instance type so that different # instance types with specified EBS storage size overrides will get their # storage. - mip = {'LaunchTemplate': {'LaunchTemplateSpecification': get_launch_template_spec(next(iter(instance_types))), - 'Overrides': [{'InstanceType': t, 'LaunchTemplateSpecification': get_launch_template_spec(t)} for t in instance_types]}} + mip = {'LaunchTemplate': {'LaunchTemplateSpecification': get_launch_template_spec(next(iter(instance_types))), # noqa + 'Overrides': [{'InstanceType': t, 'LaunchTemplateSpecification': get_launch_template_spec(t)} for t in instance_types]}} # noqa if spot_bid is not None: # Ask for spot instances by saying everything above base capacity of 0 should be spot. diff --git a/src/toil/lib/retry.py b/src/toil/lib/retry.py index 49ee4033c1..b06481bab7 100644 --- a/src/toil/lib/retry.py +++ b/src/toil/lib/retry.py @@ -142,7 +142,7 @@ def boto_bucket(bucket_name): Sequence, Tuple, Type, - Union) + Union, TypeVar) import requests.exceptions import urllib3.exceptions @@ -224,13 +224,16 @@ def __init__(self, ) +# There is a better way to type hint this with python 3.10 +# https://stackoverflow.com/a/68290080 +RT = TypeVar("RT") def retry( intervals: Optional[List] = None, infinite_retries: bool = False, errors: Optional[Sequence[Union[ErrorCondition, Type[Exception]]]] = None, log_message: Optional[Tuple[Callable, str]] = None, prepare: Optional[List[Callable]] = None, -) -> Callable[[Any], Any]: +) -> Callable[[Callable[..., RT]], Callable[..., RT]]: """ Retry a function if it fails with any Exception defined in "errors". 
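Because `retry()` is now generic over `RT`, a decorated function keeps its declared return type instead of collapsing to `Any`. A small illustration under that assumption; `count_widgets` is a hypothetical function, not Toil code:

```python
from toil.lib.retry import retry

@retry(intervals=[1, 2, 4])
def count_widgets() -> int:  # hypothetical function
    return 42

n: int = count_widgets()  # mypy now checks this without a cast
```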
@@ -281,9 +284,9 @@ def retry( if error_condition.retry_on_this_condition: retriable_errors.add(error_condition.error) - def decorate(func): + def decorate(func: Callable[..., RT]) -> Callable[..., RT]: @functools.wraps(func) - def call(*args, **kwargs): + def call(*args, **kwargs) -> RT: intervals_remaining = copy.deepcopy(intervals) while True: try: @@ -488,13 +491,15 @@ def error_meets_conditions(e, error_conditions): DEFAULT_DELAYS = (0, 1, 1, 4, 16, 64) DEFAULT_TIMEOUT = 300 +E = TypeVar("E", bound=Exception) # so mypy understands passed through types + # TODO: Replace the use of this with retry() # The aws provisioner and jobstore need a large refactoring to be boto3 compliant, so this is # still used there to avoid the duplication of future work def old_retry( delays: Iterable[float] = DEFAULT_DELAYS, timeout: float = DEFAULT_TIMEOUT, - predicate: Callable[[Exception], bool] = lambda e: False, + predicate: Callable[[E], bool] = lambda e: False, ) -> Generator[ContextManager, None, None]: """ Deprecated. diff --git a/src/toil/provisioners/__init__.py b/src/toil/provisioners/__init__.py index 7cf935995c..3a7e8c571f 100644 --- a/src/toil/provisioners/__init__.py +++ b/src/toil/provisioners/__init__.py @@ -174,9 +174,14 @@ def check_valid_node_types(provisioner, node_types: List[Tuple[Set[str], Optiona class NoSuchClusterException(Exception): """Indicates that the specified cluster does not exist.""" - def __init__(self, cluster_name): + def __init__(self, cluster_name: str) -> None: super().__init__(f"The cluster '{cluster_name}' could not be found") +class NoSuchZoneException(Exception): + """Indicates that a valid zone could not be found.""" + def __init__(self) -> None: + super().__init__("No valid zone could be found!") + class ClusterTypeNotSupportedException(Exception): """Indicates that a provisioner does not support a given cluster type.""" diff --git a/src/toil/provisioners/abstractProvisioner.py b/src/toil/provisioners/abstractProvisioner.py index 576cf06f77..615989cf2a 100644 --- a/src/toil/provisioners/abstractProvisioner.py +++ b/src/toil/provisioners/abstractProvisioner.py @@ -162,7 +162,7 @@ def __init__( for override in nodeStorageOverrides or []: nodeShape, storageOverride = override.split(':') self._nodeStorageOverrides[nodeShape] = int(storageOverride) - self._leaderPrivateIP = None + self._leaderPrivateIP: Optional[str] = None # This will hold an SSH public key for Mesos clusters, or the # Kubernetes joining information as a dict for Kubernetes clusters. self._leaderWorkerAuthentication = None @@ -1236,7 +1236,7 @@ def addKubernetesWorker(self, config: InstanceConfiguration, authVars: Dict[str, WantedBy=multi-user.target ''').format(**values)) - def _getIgnitionUserData(self, role, keyPath=None, preemptible=False, architecture='amd64'): + def _getIgnitionUserData(self, role: str, keyPath: Optional[str] = None, preemptible: bool = False, architecture: str = 'amd64') -> str: """ Return the text (not bytes) user data to pass to a provisioned node.
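The new `NoSuchZoneException` gives zone-resolution failures a concrete error type instead of a bare `None`. A hedged sketch of the intended call-site pattern (the provisioner presumably does something equivalent):

```python
from toil.provisioners import NoSuchZoneException
from toil.provisioners.aws import get_best_aws_zone

zone = get_best_aws_zone()
if zone is None:
    # Fail loudly rather than passing None further down.
    raise NoSuchZoneException()
```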
diff --git a/src/toil/provisioners/aws/__init__.py b/src/toil/provisioners/aws/__init__.py index 1055962825..694780591d 100644 --- a/src/toil/provisioners/aws/__init__.py +++ b/src/toil/provisioners/aws/__init__.py @@ -18,6 +18,8 @@ from statistics import mean, stdev from typing import List, Optional +from botocore.client import BaseClient + from toil.lib.aws import (get_aws_zone_from_boto, get_aws_zone_from_environment, get_aws_zone_from_environment_region, @@ -27,8 +29,10 @@ ZoneTuple = namedtuple('ZoneTuple', ['name', 'price_deviation']) + def get_aws_zone_from_spot_market(spotBid: Optional[float], nodeType: Optional[str], - boto2_ec2: Optional["boto.connection.AWSAuthConnection"], zone_options: Optional[List[str]]) -> Optional[str]: + boto3_ec2: Optional[BaseClient], zone_options: Optional[List[str]]) -> Optional[str]: """ If a spot bid, node type, and Boto3 EC2 client are specified, picks a zone where instances are easy to buy from the zones in the region of the @@ -40,21 +44,22 @@ """ if spotBid: # if spot bid is present, all the other parameters must be as well - assert bool(spotBid) == bool(nodeType) == bool(boto2_ec2) + assert bool(spotBid) == bool(nodeType) == bool(boto3_ec2) # if the zone is unset and we are using the spot market, optimize our # choice based on the spot history if zone_options is None: # We can use all the zones in the region - zone_options = [z.name for z in boto2_ec2.get_all_zones()] + zone_options = [z['ZoneName'] for z in boto3_ec2.describe_availability_zones()['AvailabilityZones']] - return optimize_spot_bid(boto2_ec2, instance_type=nodeType, spot_bid=float(spotBid), zone_options=zone_options) + return optimize_spot_bid(boto3_ec2, instance_type=nodeType, spot_bid=float(spotBid), zone_options=zone_options) else: return None def get_best_aws_zone(spotBid: Optional[float] = None, nodeType: Optional[str] = None, - boto2_ec2: Optional["boto.connection.AWSAuthConnection"] = None, zone_options: Optional[List[str]] = None) -> Optional[str]: + boto3_ec2: Optional[BaseClient] = None, + zone_options: Optional[List[str]] = None) -> Optional[str]: """ Get the right AWS zone to use. @@ -81,12 +86,13 @@ """ return get_aws_zone_from_environment() or \ get_aws_zone_from_metadata() or \ - get_aws_zone_from_spot_market(spotBid, nodeType, boto2_ec2, zone_options) or \ + get_aws_zone_from_spot_market(spotBid, nodeType, boto3_ec2, zone_options) or \ get_aws_zone_from_environment_region() or \ get_aws_zone_from_boto() -def choose_spot_zone(zones: List[str], bid: float, spot_history: List['boto.ec2.spotpricehistory.SpotPriceHistory']) -> str: +def choose_spot_zone(zones: List[str], bid: float, + spot_history: List['boto.ec2.spotpricehistory.SpotPriceHistory']) -> str: """ Returns the zone to put the spot request based on, in order of priority: @@ -137,7 +143,7 @@ def choose_spot_zone(zones: List[str], bid: float, spot_history: List['boto.ec2. return min(markets_under_bid or markets_over_bid, key=attrgetter('price_deviation')).name -def optimize_spot_bid(boto2_ec2, instance_type, spot_bid, zone_options: List[str]): +def optimize_spot_bid(boto3_ec2: BaseClient, instance_type: str, spot_bid: float, zone_options: List[str]): """ Check whether the bid is in line with history and makes an effort to place the instance in a sensible zone.
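For reference, the boto3 calls that replace boto2's `get_all_zones()` and `get_spot_price_history()` return dicts rather than objects, which is why the attribute accesses in this file become key lookups. A sketch with illustrative region and instance-type values:

```python
import datetime
from toil.lib.aws.session import client

ec2 = client('ec2', region_name='us-west-2')  # placeholder region
zones = [z['ZoneName']
         for z in ec2.describe_availability_zones()['AvailabilityZones']]

one_week_ago = datetime.datetime.now() - datetime.timedelta(days=7)
history = ec2.describe_spot_price_history(
    StartTime=one_week_ago.isoformat(),
    InstanceTypes=['m5.large'],  # placeholder instance type
    ProductDescriptions=['Linux/UNIX'],
)['SpotPriceHistory']
# Each datum looks like {'AvailabilityZone': 'us-west-2a',
# 'SpotPrice': '0.0416', 'Timestamp': datetime.datetime(...)}.
```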
@@ -145,7 +151,7 @@ def optimize_spot_bid(boto2_ec2, instance_type, spot_bid, zone_options: List[str :param zone_options: The collection of allowed zones to consider, within the region associated with the Boto3 client. """ - spot_history = _get_spot_history(boto2_ec2, instance_type) + spot_history = _get_spot_history(boto3_ec2, instance_type) if spot_history: _check_spot_bid(spot_bid, spot_history) most_stable_zone = choose_spot_zone(zone_options, spot_bid, spot_history) @@ -183,20 +189,20 @@ - average = mean([datum.price for datum in spot_history]) + average = mean([float(datum["SpotPrice"]) for datum in spot_history]) if spot_bid > average * 2: logger.warning("Your bid $ %f is more than double this instance type's average " - "spot price ($ %f) over the last week", spot_bid, average) + "spot price ($ %f) over the last week", spot_bid, average) -def _get_spot_history(boto2_ec2, instance_type): +def _get_spot_history(boto3_ec2: BaseClient, instance_type: str): """ Returns list of 1,000 most recent spot market data points represented as SpotPriceTypeDef dicts. Note: The most recent object/data point will be first in the list. :rtype: list[SpotPriceTypeDef] """ one_week_ago = datetime.datetime.now() - datetime.timedelta(days=7) - spot_data = boto2_ec2.get_spot_price_history(start_time=one_week_ago.isoformat(), - instance_type=instance_type, - product_description="Linux/UNIX") + spot_data = boto3_ec2.describe_spot_price_history(StartTime=one_week_ago.isoformat(), + InstanceTypes=[instance_type], + ProductDescriptions=["Linux/UNIX"])["SpotPriceHistory"] - spot_data.sort(key=attrgetter("timestamp"), reverse=True) + spot_data.sort(key=lambda datum: datum["Timestamp"], reverse=True) return spot_data diff --git a/src/toil/provisioners/aws/awsProvisioner.py b/src/toil/provisioners/aws/awsProvisioner.py index 43c0e9f8b8..48a8edc470 100644 --- a/src/toil/provisioners/aws/awsProvisioner.py +++ b/src/toil/provisioners/aws/awsProvisioner.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
+from __future__ import annotations + import json import logging import os @@ -29,34 +31,34 @@ Iterable, List, Optional, - Set) + Set, + Union, + Literal, + cast, + TypeVar) from urllib.parse import unquote # We need these to exist as attributes we can get off of the boto object -import boto.ec2 -import boto.iam -import boto.vpc -from boto.ec2.blockdevicemapping import \ - BlockDeviceMapping as Boto2BlockDeviceMapping -from boto.ec2.blockdevicemapping import BlockDeviceType as Boto2BlockDeviceType -from boto.ec2.instance import Instance as Boto2Instance -from boto.exception import BotoServerError, EC2ResponseError -from boto.utils import get_instance_metadata from botocore.exceptions import ClientError +from mypy_boto3_autoscaling.client import AutoScalingClient +from mypy_boto3_ec2.service_resource import Instance +from mypy_boto3_iam.type_defs import InstanceProfileTypeDef, RoleTypeDef, ListRolePoliciesResponseTypeDef +from mypy_extensions import VarArg, KwArg -from toil.lib.aws import zone_to_region +from toil.lib.aws import zone_to_region, AWSRegionName, AWSServerErrors from toil.lib.aws.ami import get_flatcar_ami from toil.lib.aws.iam import (CLUSTER_LAUNCHING_PERMISSIONS, get_policy_permissions, policy_permissions_allow) from toil.lib.aws.session import AWSConnectionManager -from toil.lib.aws.utils import create_s3_bucket +from toil.lib.aws.utils import create_s3_bucket, flatten_tags from toil.lib.conversions import human2bytes from toil.lib.ec2 import (a_short_time, create_auto_scaling_group, create_instances, create_launch_template, create_ondemand_instances, + increase_instance_hop_limit, create_spot_instances, wait_instances_running, wait_transition, @@ -73,12 +75,19 @@ old_retry, retry) from toil.provisioners import (ClusterCombinationNotSupportedException, - NoSuchClusterException) + NoSuchClusterException, + NoSuchZoneException) from toil.provisioners.abstractProvisioner import (AbstractProvisioner, ManagedNodesNotSupportedException, Shape) from toil.provisioners.aws import get_best_aws_zone from toil.provisioners.node import Node +from toil.lib.aws.session import client as get_client + +from mypy_boto3_ec2.client import EC2Client +from mypy_boto3_iam.client import IAMClient +from mypy_boto3_ec2.type_defs import DescribeInstancesResultTypeDef, InstanceTypeDef, TagTypeDef, BlockDeviceMappingTypeDef, EbsBlockDeviceTypeDef, FilterTypeDef, SpotInstanceRequestTypeDef, TagDescriptionTypeDef, SecurityGroupTypeDef, \ + CreateSecurityGroupResultTypeDef, IpPermissionTypeDef, ReservationTypeDef logger = logging.getLogger(__name__) logging.getLogger("boto").setLevel(logging.CRITICAL) @@ -98,14 +107,8 @@ # The suffix of the S3 bucket associated with the cluster _S3_BUCKET_INTERNAL_SUFFIX = '--internal' -# prevent removal of these imports -str(boto.ec2) -str(boto.iam) -str(boto.vpc) - - -def awsRetryPredicate(e): +def awsRetryPredicate(e: Exception) -> bool: if isinstance(e, socket.gaierror): # Could be a DNS outage: # socket.gaierror: [Errno -2] Name or service not known @@ -135,33 +138,38 @@ def expectedShutdownErrors(e: Exception) -> bool: return get_error_status(e) == 400 and 'dependent object' in get_error_body(e) -def awsRetry(f): +F = TypeVar("F") # so mypy understands passed through types + + +def awsRetry(f: Callable[..., F]) -> Callable[..., F]: """ - This decorator retries the wrapped function if aws throws unexpected errors - errors. + This decorator retries the wrapped function if aws throws unexpected errors. 
+ It should wrap any function that makes use of boto """ + @wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: for attempt in old_retry(delays=truncExpBackoff(), timeout=300, predicate=awsRetryPredicate): with attempt: return f(*args, **kwargs) + return wrapper -def awsFilterImpairedNodes(nodes, ec2): +def awsFilterImpairedNodes(nodes: List[InstanceTypeDef], boto3_ec2: EC2Client) -> List[InstanceTypeDef]: # if TOIL_AWS_NODE_DEBUG is set don't terminate nodes with # failing status checks so they can be debugged nodeDebug = os.environ.get('TOIL_AWS_NODE_DEBUG') in ('True', 'TRUE', 'true', True) if not nodeDebug: return nodes - nodeIDs = [node.id for node in nodes] - statuses = ec2.get_all_instance_status(instance_ids=nodeIDs) - statusMap = {status.id: status.instance_status for status in statuses} - healthyNodes = [node for node in nodes if statusMap.get(node.id, None) != 'impaired'] - impairedNodes = [node.id for node in nodes if statusMap.get(node.id, None) == 'impaired'] + nodeIDs = [node["InstanceId"] for node in nodes] + statuses = boto3_ec2.describe_instance_status(InstanceIds=nodeIDs) + statusMap = {status["InstanceId"]: status["InstanceStatus"]["Status"] for status in statuses["InstanceStatuses"]} + healthyNodes = [node for node in nodes if statusMap.get(node["InstanceId"], None) != 'impaired'] + impairedNodes = [node["InstanceId"] for node in nodes if statusMap.get(node["InstanceId"], None) == 'impaired'] logger.warning('TOIL_AWS_NODE_DEBUG is set and nodes %s have failed EC2 status checks so ' 'will not be terminated.', ' '.join(impairedNodes)) return healthyNodes @@ -170,8 +178,23 @@ def awsFilterImpairedNodes(nodes, ec2): class InvalidClusterStateException(Exception): pass + +def collapse_tags(instance_tags: List[TagTypeDef]) -> Dict[str, str]: + """ + Collapse tags from boto3 format to node format + :param instance_tags: tags as list of TagTypeDef + :return: Dict of tags + """ + collapsed_tags: Dict[str, str] = dict() + for tag in instance_tags: + if tag.get("Key") is not None: + collapsed_tags[tag["Key"]] = tag["Value"] + return collapsed_tags + + class AWSProvisioner(AbstractProvisioner): - def __init__(self, clusterName, clusterType, zone, nodeStorage, nodeStorageOverrides, sseKey): + def __init__(self, clusterName: Optional[str], clusterType: Optional[str], zone: Optional[str], + nodeStorage: int, nodeStorageOverrides: Optional[List[str]], sseKey: Optional[str]): self.cloud = 'aws' self._sseKey = sseKey # self._zone will be filled in by base class constructor @@ -186,7 +209,7 @@ def __init__(self, clusterName, clusterType, zone, nodeStorage, nodeStorageOverr # Determine our region to work in, before readClusterSettings() which # might need it. 
TODO: support multiple regions in one cluster - self._region = zone_to_region(zone) + self._region: AWSRegionName = zone_to_region(zone) # Set up our connections to AWS self.aws = AWSConnectionManager() @@ -198,15 +221,17 @@ def __init__(self, clusterName, clusterType, zone, nodeStorage, nodeStorageOverr # Call base class constructor, which will call createClusterSettings() # or readClusterSettings() super().__init__(clusterName, clusterType, zone, nodeStorage, nodeStorageOverrides) + self._leader_subnet: str = self._get_default_subnet(self._zone) + self._tags: Dict[str, Any] = {} # After self.clusterName is set, generate a valid name for the S3 bucket associated with this cluster suffix = _S3_BUCKET_INTERNAL_SUFFIX self.s3_bucket_name = self.clusterName[:_S3_BUCKET_MAX_NAME_LEN - len(suffix)] + suffix - def supportedClusterTypes(self): + def supportedClusterTypes(self) -> Set[str]: return {'mesos', 'kubernetes'} - def createClusterSettings(self): + def createClusterSettings(self) -> None: """ Create a new set of cluster settings for a cluster to be deployed into AWS. @@ -216,41 +241,51 @@ def createClusterSettings(self): # constructor. assert self._zone is not None - def readClusterSettings(self): + def readClusterSettings(self) -> None: """ Reads the cluster settings from the instance metadata, which assumes the instance is the leader. """ - instanceMetaData = get_instance_metadata() - ec2 = self.aws.boto2(self._region, 'ec2') - instance = ec2.get_all_instances(instance_ids=[instanceMetaData["instance-id"]])[0].instances[0] + from ec2_metadata import ec2_metadata + boto3_ec2 = self.aws.client(self._region, 'ec2') + instance: InstanceTypeDef = boto3_ec2.describe_instances(InstanceIds=[ec2_metadata.instance_id])["Reservations"][0]["Instances"][0] # The cluster name is the same as the name of the leader. - self.clusterName = str(instance.tags["Name"]) + self.clusterName: str = "default-toil-cluster-name" + for tag in instance["Tags"]: + if tag.get("Key") == "Name": + self.clusterName = tag["Value"] # Determine what subnet we, the leader, are in - self._leader_subnet = instance.subnet_id + self._leader_subnet = instance["SubnetId"] # Determine where to deploy workers. self._worker_subnets_by_zone = self._get_good_subnets_like(self._leader_subnet) - self._leaderPrivateIP = instanceMetaData['local-ipv4'] # this is PRIVATE IP - self._keyName = list(instanceMetaData['public-keys'].keys())[0] - self._tags = {k: v for k, v in self.getLeader().tags.items() if k != _TAG_KEY_TOIL_NODE_TYPE} + self._leaderPrivateIP = ec2_metadata.private_ipv4 # this is PRIVATE IP + self._tags = {k: v for k, v in (self.getLeader().tags or {}).items() if k != _TAG_KEY_TOIL_NODE_TYPE} # Grab the ARN name of the instance profile (a str) to apply to workers - self._leaderProfileArn = instanceMetaData['iam']['info']['InstanceProfileArn'] + leader_info = None + for attempt in old_retry(timeout=300, predicate=lambda e: True): + with attempt: + leader_info = ec2_metadata.iam_info + if leader_info is None: + raise RuntimeError("Could not get EC2 metadata IAM info") + if leader_info is None: + # This is more for mypy as it is unable to see that the retry will guarantee this is not None + # and that this is not reachable + raise RuntimeError(f"Leader IAM metadata is unreachable.") + self._leaderProfileArn = leader_info["InstanceProfileArn"] + # The existing metadata API returns a single string if there is one security group, but # a list when there are multiple: change the format to always be a list. 
- rawSecurityGroups = instanceMetaData['security-groups'] - self._leaderSecurityGroupNames = {rawSecurityGroups} if not isinstance(rawSecurityGroups, list) else set(rawSecurityGroups) + rawSecurityGroups = ec2_metadata.security_groups + self._leaderSecurityGroupNames: Set[str] = set(rawSecurityGroups) # Since we have access to the names, we don't also need to use any IDs - self._leaderSecurityGroupIDs = set() + self._leaderSecurityGroupIDs: Set[str] = set() # Let the base provisioner work out how to deploy duly authorized # workers for this leader. self._setLeaderWorkerAuthentication() - @retry(errors=[ErrorCondition( - error=ClientError, - error_codes=[404, 500, 502, 503, 504] - )]) + @retry(errors=[AWSServerErrors]) def _write_file_to_cloud(self, key: str, contents: bytes) -> str: bucket_name = self.s3_bucket_name @@ -289,7 +324,7 @@ def _read_file_from_cloud(self, key: str) -> bytes: obj = self.aws.resource(self._region, 's3').Object(bucket_name, key) try: - return obj.get().get('Body').read() + return obj.get()['Body'].read() except ClientError as e: if get_error_status(e) == 404: logger.warning(f'Trying to read non-existent file "{key}" from {bucket_name}.') @@ -305,11 +340,11 @@ def launchCluster(self, owner: str, keyName: str, botoPath: str, - userTags: Optional[dict], + userTags: Optional[Dict[str, str]], vpcSubnet: Optional[str], awsEc2ProfileArn: Optional[str], - awsEc2ExtraSecurityGroupIds: Optional[list], - **kwargs): + awsEc2ExtraSecurityGroupIds: Optional[List[str]], + **kwargs: Dict[str, Any]) -> None: """ Starts a single leader node and populates this class with the leader's metadata. @@ -352,9 +387,6 @@ def launchCluster(self, if vpcSubnet: # This is where we put the leader self._leader_subnet = vpcSubnet - else: - # Find the default subnet for the zone - self._leader_subnet = self._get_default_subnet(self._zone) profileArn = awsEc2ProfileArn or self._createProfileArn() @@ -370,7 +402,7 @@ def launchCluster(self, if userTags is not None: self._tags.update(userTags) - #All user specified tags have been set + # All user specified tags have been set userData = self._getIgnitionUserData('leader', architecture=self._architecture) if self.clusterType == 'kubernetes': @@ -383,18 +415,18 @@ def launchCluster(self, leader_tags[_TAG_KEY_TOIL_NODE_TYPE] = 'leader' logger.debug('Launching leader with tags: %s', leader_tags) - instances = create_instances(self.aws.resource(self._region, 'ec2'), - image_id=self._discoverAMI(), - num_instances=1, - key_name=self._keyName, - security_group_ids=createdSGs + (awsEc2ExtraSecurityGroupIds or []), - instance_type=leader_type.name, - user_data=userData, - block_device_map=bdms, - instance_profile_arn=profileArn, - placement_az=self._zone, - subnet_id=self._leader_subnet, - tags=leader_tags) + instances: List[Instance] = create_instances(self.aws.resource(self._region, 'ec2'), + image_id=self._discoverAMI(), + num_instances=1, + key_name=self._keyName, + security_group_ids=createdSGs + (awsEc2ExtraSecurityGroupIds or []), + instance_type=leader_type.name, + user_data=userData, + block_device_map=bdms, + instance_profile_arn=profileArn, + placement_az=self._zone, + subnet_id=self._leader_subnet, + tags=leader_tags) # wait for the leader to exist at all leader = instances[0] @@ -425,7 +457,7 @@ def launchCluster(self, leaderNode = Node(publicIP=leader.public_ip_address, privateIP=leader.private_ip_address, name=leader.id, launchTime=leader.launch_time, nodeType=leader_type.name, preemptible=False, - tags=leader.tags) + 
tags=collapse_tags(leader.tags)) leaderNode.waitForNode('toil_leader') # Download credentials @@ -483,7 +515,7 @@ def _get_good_subnets_like(self, base_subnet_id: str) -> Dict[str, List[str]]: acls = set(self._get_subnet_acls(base_subnet_id)) # Compose a filter that selects the subnets we might want - filters = [{ + filters: List[FilterTypeDef] = [{ 'Name': 'vpc-id', 'Values': [vpc_id] }, { @@ -495,7 +527,7 @@ def _get_good_subnets_like(self, base_subnet_id: str) -> Dict[str, List[str]]: }] # Fill in this collection - by_az = {} + by_az: Dict[str, List[str]] = {} # Go get all the subnets. There's no way to page manually here so it # must page automatically. @@ -546,7 +578,7 @@ def _get_default_subnet(self, zone: str) -> str: """ # Compose a filter that selects the default subnet in the AZ - filters = [{ + filters: List[FilterTypeDef] = [{ 'Name': 'default-for-az', 'Values': ['true'] }, { @@ -586,7 +618,7 @@ def getKubernetesCloudProvider(self) -> Optional[str]: return 'aws' - def getNodeShape(self, instance_type: str, preemptible=False) -> Shape: + def getNodeShape(self, instance_type: str, preemptible: bool = False) -> Shape: """ Get the Shape for the given instance type (e.g. 't2.medium'). """ @@ -603,13 +635,13 @@ def getNodeShape(self, instance_type: str, preemptible=False) -> Shape: # mesos about whether a job can run on a particular node type memory = (type_info.memory - 0.1) * 2 ** 30 return Shape(wallTime=60 * 60, - memory=memory, + memory=int(memory), cores=type_info.cores, - disk=disk, + disk=int(disk), preemptible=preemptible) @staticmethod - def retryPredicate(e): + def retryPredicate(e: Exception) -> bool: return awsRetryPredicate(e) def destroyCluster(self) -> None: @@ -619,8 +651,8 @@ def destroyCluster(self) -> None: # The leader may create more instances while we're terminating the workers. vpcId = None try: - leader = self._getLeaderInstance() - vpcId = leader.vpc_id + leader = self._getLeaderInstanceBoto3() + vpcId = leader.get("VpcId") logger.info('Terminating the leader first ...') self._terminateInstances([leader]) except (NoSuchClusterException, InvalidClusterStateException): @@ -651,14 +683,16 @@ def destroyCluster(self) -> None: # Do the workers after the ASGs because some may belong to ASGs logger.info('Terminating any remaining workers ...') removed = False - instances = self._get_nodes_in_cluster(include_stopped_nodes=True) + instances = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True) spotIDs = self._getSpotRequestIDs() + boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name="ec2") if spotIDs: - self.aws.boto2(self._region, 'ec2').cancel_spot_instance_requests(request_ids=spotIDs) + boto3_ec2.cancel_spot_instance_requests(SpotInstanceRequestIds=spotIDs) removed = True - instancesToTerminate = awsFilterImpairedNodes(instances, self.aws.boto2(self._region, 'ec2')) + instancesToTerminate = awsFilterImpairedNodes(instances, self.aws.client(self._region, 'ec2')) if instancesToTerminate: - vpcId = vpcId or instancesToTerminate[0].vpc_id + vpcId = vpcId or instancesToTerminate[0].get("VpcId") self._terminateInstances(instancesToTerminate) removed = True if removed: @@ -672,7 +706,7 @@ def destroyCluster(self) -> None: # for some LaunchTemplate.
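The FilterTypeDef dicts used above are the standard boto3 server-side query mechanism; a minimal sketch of the default-subnet lookup pattern, assuming a Boto 3 EC2 client and the filter names used in this patch:

    from typing import List
    from mypy_boto3_ec2 import EC2Client
    from mypy_boto3_ec2.type_defs import FilterTypeDef

    def default_subnet_id(boto3_ec2: EC2Client, zone: str) -> str:
        # Ask EC2 for the subnet flagged default-for-az in the given zone.
        filters: List[FilterTypeDef] = [{'Name': 'default-for-az', 'Values': ['true']},
                                        {'Name': 'availability-zone', 'Values': [zone]}]
        subnets = boto3_ec2.describe_subnets(Filters=filters)['Subnets']
        if not subnets:
            raise RuntimeError(f'No default subnet in zone {zone}')
        return subnets[0]['SubnetId']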
mistake = False for ltID in self._get_launch_template_ids(): - response = self.aws.client(self._region, 'ec2').delete_launch_template(LaunchTemplateId=ltID) + response = boto3_ec2.delete_launch_template(LaunchTemplateId=ltID) if 'LaunchTemplate' not in response: mistake = True else: @@ -694,16 +728,17 @@ def destroyCluster(self) -> None: removed = False for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors): with attempt: - for sg in self.aws.boto2(self._region, 'ec2').get_all_security_groups(): + security_groups: List[SecurityGroupTypeDef] = boto3_ec2.describe_security_groups()["SecurityGroups"] + for security_group in security_groups: # TODO: If we terminate the leader and the workers but # miss the security group, we won't find it now because # we won't have vpcId set. - if sg.name == self.clusterName and vpcId and sg.vpc_id == vpcId: + if security_group.get("GroupName") == self.clusterName and vpcId and security_group.get("VpcId") == vpcId: try: - self.aws.boto2(self._region, 'ec2').delete_security_group(group_id=sg.id) + boto3_ec2.delete_security_group(GroupId=security_group["GroupId"]) removed = True - except BotoServerError as e: - if e.error_code == 'InvalidGroup.NotFound': + except ClientError as e: + if get_error_code(e) == 'InvalidGroup.NotFound': pass else: raise @@ -777,10 +812,9 @@ def _recover_node_type_bid(self, node_type: Set[str], spot_bid: Optional[float]) return spot_bid - def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> int: + def addNodes(self, nodeTypes: Set[str], numNodes: int, preemptible: bool, spotBid: Optional[float] = None) -> int: # Grab the AWS connection we need - ec2 = self.aws.boto2(self._region, 'ec2') - + boto3_ec2 = get_client(service_name='ec2', region_name=self._region) assert self._leaderPrivateIP if preemptible: @@ -792,7 +826,7 @@ def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> node_type = next(iter(nodeTypes)) type_info = E2Instances[node_type] root_vol_size = self._nodeStorageOverrides.get(node_type, self._nodeStorage) - bdm = self._getBoto2BlockDeviceMapping(type_info, + bdm = self._getBoto3BlockDeviceMapping(type_info, rootVolSize=root_vol_size) # Pick a zone and subnet_id to launch into @@ -803,7 +837,7 @@ def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> # We're allowed to pick from any of these zones. zone_options = list(self._worker_subnets_by_zone.keys()) - zone = get_best_aws_zone(spotBid, type_info.name, ec2, zone_options) + zone = get_best_aws_zone(spotBid, type_info.name, boto3_ec2, zone_options) else: # We don't need to ever do any balancing across zones for on-demand # instances. Just pick a zone. @@ -814,6 +848,9 @@ def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> # The workers aren't allowed in the leader's zone. # Pick an arbitrary zone we can use. zone = next(iter(self._worker_subnets_by_zone.keys())) + if zone is None: + logger.exception("Could not find a valid zone. Make sure TOIL_AWS_ZONE is set or spot bids are not too low.") + raise NoSuchZoneException() if self._leader_subnet in self._worker_subnets_by_zone.get(zone, []): # The leader's subnet is an option for this zone, so use it. 
subnet_id = self._leader_subnet @@ -822,21 +859,40 @@ def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> subnet_id = next(iter(self._worker_subnets_by_zone[zone])) keyPath = self._sseKey if self._sseKey else None - userData = self._getIgnitionUserData('worker', keyPath, preemptible, self._architecture) + userData: str = self._getIgnitionUserData('worker', keyPath, preemptible, self._architecture) + userDataBytes: bytes = b"" if isinstance(userData, str): # Spot-market provisioning requires bytes for user data. - userData = userData.encode('utf-8') - - kwargs = {'key_name': self._keyName, - 'security_group_ids': self._getSecurityGroupIDs(), - 'instance_type': type_info.name, - 'user_data': userData, - 'block_device_map': bdm, - 'instance_profile_arn': self._leaderProfileArn, - 'placement': zone, - 'subnet_id': subnet_id} - - instancesLaunched = [] + userDataBytes = userData.encode('utf-8') + + spot_kwargs = {'KeyName': self._keyName, + 'LaunchSpecification': { + 'SecurityGroupIds': self._getSecurityGroupIDs(), + 'InstanceType': type_info.name, + 'UserData': userDataBytes, + 'BlockDeviceMappings': bdm, + 'IamInstanceProfile': { + 'Arn': self._leaderProfileArn + }, + 'Placement': { + 'AvailabilityZone': zone + }, + 'SubnetId': subnet_id} + } + on_demand_kwargs = {'KeyName': self._keyName, + 'SecurityGroupIds': self._getSecurityGroupIDs(), + 'InstanceType': type_info.name, + 'UserData': userDataBytes, + 'BlockDeviceMappings': bdm, + 'IamInstanceProfile': { + 'Arn': self._leaderProfileArn + }, + 'Placement': { + 'AvailabilityZone': zone + }, + 'SubnetId': subnet_id} + + instancesLaunched: List[InstanceTypeDef] = [] for attempt in old_retry(predicate=awsRetryPredicate): with attempt: @@ -845,41 +901,45 @@ def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> # every request in this method if not preemptible: logger.debug('Launching %s non-preemptible nodes', numNodes) - instancesLaunched = create_ondemand_instances(ec2, + instancesLaunched = create_ondemand_instances(boto3_ec2=boto3_ec2, image_id=self._discoverAMI(), - spec=kwargs, num_instances=numNodes) + spec=on_demand_kwargs, num_instances=numNodes) else: logger.debug('Launching %s preemptible nodes', numNodes) # force generator to evaluate - instancesLaunched = list(create_spot_instances(ec2=ec2, - price=spotBid, - image_id=self._discoverAMI(), - tags={_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName}, - spec=kwargs, - num_instances=numNodes, - tentative=True) - ) + generatedInstancesLaunched: List[DescribeInstancesResultTypeDef] = list(create_spot_instances(boto3_ec2=boto3_ec2, + price=spotBid, + image_id=self._discoverAMI(), + tags={_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName}, + spec=spot_kwargs, + num_instances=numNodes, + tentative=True) + ) # flatten the list - instancesLaunched = [item for sublist in instancesLaunched for item in sublist] + flatten_reservations: List[ReservationTypeDef] = [reservation for subdict in generatedInstancesLaunched for reservation in subdict["Reservations"]] + # Flatten the per-response reservations, then each reservation's instance list, into one list of launched instances + instancesLaunched = [instance for instances in flatten_reservations for instance in instances['Instances']] for attempt in old_retry(predicate=awsRetryPredicate): with attempt: - wait_instances_running(ec2, instancesLaunched) + list(wait_instances_running(boto3_ec2, instancesLaunched)) # ensure all instances are running + +
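For reference, a describe-instances style response nests the actual instances two levels deep; a minimal sketch of the flattening performed above, assuming a list of DescribeInstancesResultTypeDef pages:

    from typing import List
    from mypy_boto3_ec2.type_defs import DescribeInstancesResultTypeDef, InstanceTypeDef

    def flatten_instances(pages: List[DescribeInstancesResultTypeDef]) -> List[InstanceTypeDef]:
        # Each response page holds Reservations; each reservation holds Instances.
        return [instance
                for page in pages
                for reservation in page['Reservations']
                for instance in reservation['Instances']]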
increase_instance_hop_limit(boto3_ec2, instancesLaunched) self._tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker' - AWSProvisioner._addTags(instancesLaunched, self._tags) + AWSProvisioner._addTags(boto3_ec2, instancesLaunched, self._tags) if self._sseKey: for i in instancesLaunched: - self._waitForIP(i) + self._waitForIP(boto3_ec2, i) - node = Node(publicIP=i.ip_address, privateIP=i.private_ip_address, name=i.id, - launchTime=i.launch_time, nodeType=i.instance_type, preemptible=preemptible, - tags=i.tags) + node = Node(publicIP=i['PublicIpAddress'], privateIP=i['PrivateIpAddress'], name=i['InstanceId'], + launchTime=i['LaunchTime'], nodeType=i['InstanceType'], preemptible=preemptible, + tags=collapse_tags(i.get('Tags', []))) node.waitForNode('toil_worker') node.coreRsync([self._sseKey, ':' + self._sseKey], applianceName='toil_worker') logger.debug('Launched %s new instance(s)', numNodes) return len(instancesLaunched) - def addManagedNodes(self, nodeTypes: Set[str], minNodes, maxNodes, preemptible, spotBid=None) -> None: + def addManagedNodes(self, nodeTypes: Set[str], minNodes: int, maxNodes: int, preemptible: bool, spotBid: Optional[float] = None) -> None: if self.clusterType != 'kubernetes': raise ManagedNodesNotSupportedException("Managed nodes only supported for Kubernetes clusters") @@ -901,17 +961,17 @@ def addManagedNodes(self, nodeTypes: Set[str], minNodes, maxNodes, preemptible, def getProvisionedWorkers(self, instance_type: Optional[str] = None, preemptible: Optional[bool] = None) -> List[Node]: assert self._leaderPrivateIP - entireCluster = self._get_nodes_in_cluster(instance_type=instance_type) + entireCluster = self._get_nodes_in_cluster_boto3(instance_type=instance_type) logger.debug('All nodes in cluster: %s', entireCluster) - workerInstances = [i for i in entireCluster if i.private_ip_address != self._leaderPrivateIP] + workerInstances: List[InstanceTypeDef] = [i for i in entireCluster if i["PrivateIpAddress"] != self._leaderPrivateIP] logger.debug('All workers found in cluster: %s', workerInstances) if preemptible is not None: - workerInstances = [i for i in workerInstances if preemptible == (i.spot_instance_request_id is not None)] + workerInstances = [i for i in workerInstances if preemptible == (i.get("SpotInstanceRequestId") is not None)] logger.debug('%spreemptible workers found in cluster: %s', 'non-' if not preemptible else '', workerInstances) - workerInstances = awsFilterImpairedNodes(workerInstances, self.aws.boto2(self._region, 'ec2')) - return [Node(publicIP=i.ip_address, privateIP=i.private_ip_address, - name=i.id, launchTime=i.launch_time, nodeType=i.instance_type, - preemptible=i.spot_instance_request_id is not None, tags=i.tags) + workerInstances = awsFilterImpairedNodes(workerInstances, self.aws.client(self._region, 'ec2')) + return [Node(publicIP=i.get("PublicIpAddress"), privateIP=i["PrivateIpAddress"], + name=i["InstanceId"], launchTime=i["LaunchTime"], nodeType=i["InstanceType"], + preemptible=i.get("SpotInstanceRequestId") is not None, tags=collapse_tags(i.get("Tags", []))) for i in workerInstances] @memoize @@ -952,37 +1012,65 @@ def _is_our_namespaced_name(self, namespaced_name: str) -> bool: denamespaced = '/' + '_'.join(s.replace('_', '/') for s in namespaced_name.split('__')) return denamespaced.startswith(self._toNameSpace()) + def _getLeaderInstanceBoto3(self) -> InstanceTypeDef: + """ + Get the Boto 3 instance for the cluster's leader.
+ :return: InstanceTypeDef + """ + # Tags are stored differently in Boto 3 + instances: List[InstanceTypeDef] = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True) + instances.sort(key=lambda x: x["LaunchTime"]) + try: + leader = instances[0] # assume leader was launched first + except IndexError: + raise NoSuchClusterException(self.clusterName) + if leader.get("Tags") is not None: + # There may be no node-type tag at all, so default to None rather than raising StopIteration + tag_value = next((item["Value"] for item in leader["Tags"] if item.get("Key") == _TAG_KEY_TOIL_NODE_TYPE), None) + else: + tag_value = None + if (tag_value or 'leader') != 'leader': + raise InvalidClusterStateException( + 'Invalid cluster state! The first launched instance appears not to be the leader ' + 'as it is missing the "leader" tag. The safest recovery is to destroy the cluster ' + 'and restart the job. Incorrect Leader ID: %s' % leader["InstanceId"] + ) + return leader - def _getLeaderInstance(self) -> Boto2Instance: + def _getLeaderInstance(self) -> InstanceTypeDef: """ - Get the Boto 2 instance for the cluster's leader. + Get the Boto 3 instance record for the cluster's leader. """ - instances = self._get_nodes_in_cluster(include_stopped_nodes=True) - instances.sort(key=lambda x: x.launch_time) + instances = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True) + instances.sort(key=lambda x: x["LaunchTime"]) try: - leader = instances[0] # assume leader was launched first + leader: InstanceTypeDef = instances[0] # assume leader was launched first except IndexError: raise NoSuchClusterException(self.clusterName) - if (leader.tags.get(_TAG_KEY_TOIL_NODE_TYPE) or 'leader') != 'leader': + tagged_node_type: str = 'leader' + for tag in leader.get("Tags", []): + # If a tag specifying the node type exists, use it + if tag.get("Key") is not None and tag["Key"] == _TAG_KEY_TOIL_NODE_TYPE: + tagged_node_type = tag["Value"] + if tagged_node_type != 'leader': raise InvalidClusterStateException( 'Invalid cluster state! The first launched instance appears not to be the leader ' 'as it is missing the "leader" tag. The safest recovery is to destroy the cluster ' - 'and restart the job. Incorrect Leader ID: %s' % leader.id + 'and restart the job. Incorrect Leader ID: %s' % leader["InstanceId"] ) return leader - def getLeader(self, wait=False) -> Node: + def getLeader(self, wait: bool = False) -> Node: """ Get the leader for the cluster as a Toil Node object. """ - leader = self._getLeaderInstance() + leader: InstanceTypeDef = self._getLeaderInstanceBoto3() - leaderNode = Node(publicIP=leader.ip_address, privateIP=leader.private_ip_address, - name=leader.id, launchTime=leader.launch_time, nodeType=None, - preemptible=False, tags=leader.tags) + leaderNode = Node(publicIP=leader["PublicIpAddress"], privateIP=leader["PrivateIpAddress"], + name=leader["InstanceId"], launchTime=leader["LaunchTime"], nodeType=None, + preemptible=False, tags=collapse_tags(leader.get("Tags", []))) if wait: logger.debug("Waiting for toil_leader to enter 'running' state...") - wait_instances_running(self.aws.boto2(self._region, 'ec2'), [leader]) + wait_instances_running(self.aws.client(self._region, 'ec2'), [leader]) logger.debug('...
toil_leader is running') - self._waitForIP(leader) + self._waitForIP(self.aws.client(self._region, 'ec2'), leader) leaderNode.waitForNode('toil_leader') @@ -991,17 +1079,20 @@ def getLeader(self, wait=False) -> Node: @classmethod @awsRetry - def _addTag(cls, instance: Boto2Instance, key: str, value: str): - instance.add_tag(key, value) + def _addTag(cls, boto3_ec2: EC2Client, instance: InstanceTypeDef, key: str, value: str) -> None: + if instance.get('Tags') is None: + instance['Tags'] = [] + new_tag: TagTypeDef = {"Key": key, "Value": value} + boto3_ec2.create_tags(Resources=[instance["InstanceId"]], Tags=[new_tag]) + # Keep the local snapshot in sync with what was just tagged in EC2 + instance['Tags'].append(new_tag) @classmethod - def _addTags(cls, instances: List[Boto2Instance], tags: Dict[str, str]): + def _addTags(cls, boto3_ec2: EC2Client, instances: List[InstanceTypeDef], tags: Dict[str, str]) -> None: for instance in instances: for key, value in tags.items(): - cls._addTag(instance, key, value) + cls._addTag(boto3_ec2, instance, key, value) @classmethod - def _waitForIP(cls, instance: Boto2Instance): + def _waitForIP(cls, boto3_ec2: EC2Client, instance: InstanceTypeDef) -> None: """ Wait until the instance has a public IP address assigned to it. @@ -1010,32 +1101,32 @@ def _waitForIP(cls, instance: Boto2Instance): logger.debug('Waiting for ip...') while True: time.sleep(a_short_time) - instance.update() - if instance.ip_address or instance.public_dns_name or instance.private_ip_address: + # The TypedDict is a static snapshot, so re-describe the instance on each pass + instance = boto3_ec2.describe_instances(InstanceIds=[instance["InstanceId"]])["Reservations"][0]["Instances"][0] + if instance.get("PublicIpAddress") or instance.get("PublicDnsName") or instance.get("PrivateIpAddress"): logger.debug('...got ip') break - def _terminateInstances(self, instances: List[Boto2Instance]): - instanceIDs = [x.id for x in instances] + def _terminateInstances(self, instances: List[InstanceTypeDef]) -> None: + instanceIDs = [x["InstanceId"] for x in instances] self._terminateIDs(instanceIDs) logger.info('... Waiting for instance(s) to shut down...') for instance in instances: - wait_transition(instance, {'pending', 'running', 'shutting-down', 'stopping', 'stopped'}, 'terminated') + wait_transition(self.aws.client(region=self._region, service_name="ec2"), instance, {'pending', 'running', 'shutting-down', 'stopping', 'stopped'}, 'terminated') logger.info('Instance(s) terminated.') @awsRetry - def _terminateIDs(self, instanceIDs: List[str]): + def _terminateIDs(self, instanceIDs: List[str]) -> None: logger.info('Terminating instance(s): %s', instanceIDs) - self.aws.boto2(self._region, 'ec2').terminate_instances(instance_ids=instanceIDs) + boto3_ec2 = self.aws.client(region=self._region, service_name="ec2") + boto3_ec2.terminate_instances(InstanceIds=instanceIDs) logger.info('Instance(s) terminated.') @awsRetry - def _deleteRoles(self, names: List[str]): + def _deleteRoles(self, names: List[str]) -> None: """ Delete all the given named IAM roles. Detaches but does not delete associated instance profiles.
""" - + boto3_iam = self.aws.client(region=self._region, service_name="iam") for role_name in names: for profile_name in self._getRoleInstanceProfileNames(role_name): # We can't delete either the role or the profile while they @@ -1043,60 +1134,64 @@ def _deleteRoles(self, names: List[str]): for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors): with attempt: - self.aws.client(self._region, 'iam').remove_role_from_instance_profile(InstanceProfileName=profile_name, - RoleName=role_name) + boto3_iam.remove_role_from_instance_profile(InstanceProfileName=profile_name, + RoleName=role_name) # We also need to drop all inline policies for policy_name in self._getRoleInlinePolicyNames(role_name): for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors): with attempt: - self.aws.client(self._region, 'iam').delete_role_policy(PolicyName=policy_name, - RoleName=role_name) + boto3_iam.delete_role_policy(PolicyName=policy_name, + RoleName=role_name) for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors): with attempt: - self.aws.client(self._region, 'iam').delete_role(RoleName=role_name) + boto3_iam.delete_role(RoleName=role_name) logger.debug('... Successfully deleted IAM role %s', role_name) - @awsRetry - def _deleteInstanceProfiles(self, names: List[str]): + def _deleteInstanceProfiles(self, names: List[str]) -> None: """ Delete all the given named IAM instance profiles. All roles must already be detached. """ - + boto3_iam = self.aws.client(region=self._region, service_name="iam") for profile_name in names: for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors): with attempt: - self.aws.client(self._region, 'iam').delete_instance_profile(InstanceProfileName=profile_name) + boto3_iam.delete_instance_profile(InstanceProfileName=profile_name) logger.debug('... Succesfully deleted instance profile %s', profile_name) @classmethod - def _getBoto2BlockDeviceMapping(cls, type_info: InstanceType, rootVolSize: int = 50) -> Boto2BlockDeviceMapping: + def _getBoto3BlockDeviceMapping(cls, type_info: InstanceType, rootVolSize: int = 50) -> List[BlockDeviceMappingTypeDef]: # determine number of ephemeral drives via cgcloud-lib (actually this is moved into toil's lib bdtKeys = [''] + [f'/dev/xvd{c}' for c in string.ascii_lowercase[1:]] - bdm = Boto2BlockDeviceMapping() + bdm_list: List[BlockDeviceMappingTypeDef] = [] # Change root volume size to allow for bigger Docker instances - root_vol = Boto2BlockDeviceType(delete_on_termination=True) - root_vol.size = rootVolSize - bdm["/dev/xvda"] = root_vol + root_vol: EbsBlockDeviceTypeDef = {"DeleteOnTermination": True, + "VolumeSize": rootVolSize} + bdm: BlockDeviceMappingTypeDef = {"DeviceName": "/dev/xvda", "Ebs": root_vol} + bdm_list.append(bdm) # The first disk is already attached for us so start with 2nd. # Disk count is weirdly a float in our instance database, so make it an int here. 
for disk in range(1, int(type_info.disks) + 1): - bdm[bdtKeys[disk]] = Boto2BlockDeviceType( - ephemeral_name=f'ephemeral{disk - 1}') # ephemeral counts start at 0 + bdm = {} + bdm["DeviceName"] = bdtKeys[disk] + bdm["VirtualName"] = f"ephemeral{disk - 1}" # ephemeral counts start at 0 + # Instance-store (ephemeral) mappings take VirtualName only; a mapping may not set both VirtualName and Ebs + bdm_list.append(bdm) - logger.debug('Device mapping: %s', bdm) - return bdm + logger.debug('Device mapping: %s', bdm_list) + return bdm_list @classmethod - def _getBoto3BlockDeviceMappings(cls, type_info: InstanceType, rootVolSize: int = 50) -> List[dict]: + def _getBoto3BlockDeviceMappings(cls, type_info: InstanceType, rootVolSize: int = 50) -> List[BlockDeviceMappingTypeDef]: """ Get block device mappings for the root volume for a worker. """ # Start with the root - bdms = [{ + bdms: List[BlockDeviceMappingTypeDef] = [{ 'DeviceName': '/dev/xvda', 'Ebs': { 'DeleteOnTermination': True, @@ -1121,96 +1216,99 @@ def _getBoto3BlockDeviceMappings(cls, type_info: InstanceType, rootVolSize: int return bdms @awsRetry - def _get_nodes_in_cluster(self, instance_type: Optional[str] = None, include_stopped_nodes=False) -> List[Boto2Instance]: + def _get_nodes_in_cluster_boto3(self, instance_type: Optional[str] = None, include_stopped_nodes: bool = False) -> List[InstanceTypeDef]: """ - Get Boto2 instance objects for all nodes in the cluster. + Get Boto3 instance objects for all nodes in the cluster. """ + boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name='ec2') + instance_filter: FilterTypeDef = {'Name': 'instance.group-name', 'Values': [self.clusterName]} + describe_response: DescribeInstancesResultTypeDef = boto3_ec2.describe_instances(Filters=[instance_filter]) + all_instances: List[InstanceTypeDef] = [] + for reservation in describe_response['Reservations']: + instances = reservation['Instances'] + all_instances.extend(instances) - all_instances = self.aws.boto2(self._region, 'ec2').get_only_instances(filters={'instance.group-name': self.clusterName}) - def instanceFilter(i): + def instanceFilter(i: InstanceTypeDef) -> bool: # filter by type only if nodeType is true - rightType = not instance_type or i.instance_type == instance_type - rightState = i.state == 'running' or i.state == 'pending' + rightType = not instance_type or i['InstanceType'] == instance_type + rightState = i['State']['Name'] == 'running' or i['State']['Name'] == 'pending' if include_stopped_nodes: - rightState = rightState or i.state == 'stopping' or i.state == 'stopped' + rightState = rightState or i['State']['Name'] == 'stopping' or i['State']['Name'] == 'stopped' return rightType and rightState return [i for i in all_instances if instanceFilter(i)] - def _filter_nodes_in_cluster(self, instance_type: Optional[str] = None, preemptible: bool = False) -> List[Boto2Instance]: - """ - Get Boto2 instance objects for the nodes in the cluster filtered by preemptability. - """ - - instances = self._get_nodes_in_cluster(instance_type, include_stopped_nodes=False) - - if preemptible: - return [i for i in instances if i.spot_instance_request_id is not None] - - return [i for i in instances if i.spot_instance_request_id is None] - def _getSpotRequestIDs(self) -> List[str]: """ Get the IDs of all spot requests associated with the cluster. """ # Grab the connection we need to use for this operation.
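A compact, illustrative restatement of the state filtering that _get_nodes_in_cluster_boto3() applies above:

    from mypy_boto3_ec2.type_defs import InstanceTypeDef

    def in_wanted_state(i: InstanceTypeDef, include_stopped: bool = False) -> bool:
        # Running and pending instances always count; stopping/stopped only on request.
        states = {'running', 'pending'}
        if include_stopped:
            states |= {'stopping', 'stopped'}
        return i['State']['Name'] in states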
- ec2 = self.aws.boto2(self._region, 'ec2') + ec2: EC2Client = self.aws.client(self._region, 'ec2') - requests = ec2.get_all_spot_instance_requests() - tags = ec2.get_all_tags({'tag:': {_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName}}) - idsToCancel = [tag.id for tag in tags] - return [request for request in requests if request.id in idsToCancel] + requests: List[SpotInstanceRequestTypeDef] = ec2.describe_spot_instance_requests()["SpotInstanceRequests"] + tag_filter: FilterTypeDef = {"Name": "tag:" + _TAG_KEY_TOIL_CLUSTER_NAME, "Values": [self.clusterName]} + tags: List[TagDescriptionTypeDef] = ec2.describe_tags(Filters=[tag_filter])["Tags"] + idsToCancel = [tag["ResourceId"] for tag in tags] + # The cluster-name tags live on the spot requests themselves, so match on the request IDs + return [request["SpotInstanceRequestId"] for request in requests if request["SpotInstanceRequestId"] in idsToCancel] def _createSecurityGroups(self) -> List[str]: """ Create security groups for the cluster. Returns a list of their IDs. """ + def group_not_found(e: ClientError) -> bool: + retry = (get_error_status(e) == 400 and 'does not exist in default VPC' in get_error_body(e)) + return retry + # Grab the connection we need to use for this operation. # The VPC connection can do anything the EC2 one can do, but also look at subnets. - vpc = self.aws.boto2(self._region, 'vpc') + boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name="ec2") - def groupNotFound(e): - retry = (e.status == 400 and 'does not exist in default VPC' in e.body) - return retry - # Security groups need to belong to the same VPC as the leader. If we - # put the leader in a particular non-default subnet, it may be in a - # particular non-default VPC, which we need to know about. - vpcId = None + vpc_id = None if self._leader_subnet: - subnets = vpc.get_all_subnets(subnet_ids=[self._leader_subnet]) + subnets = boto3_ec2.describe_subnets(SubnetIds=[self._leader_subnet])["Subnets"] if len(subnets) > 0: - vpcId = subnets[0].vpc_id - # security group create/get. ssh + all ports open within the group + vpc_id = subnets[0]["VpcId"] try: - web = vpc.create_security_group(self.clusterName, - 'Toil appliance security group', vpc_id=vpcId) - except EC2ResponseError as e: - if e.status == 400 and 'already exists' in e.body: - pass # group exists- nothing to do + # Security groups need to belong to the same VPC as the leader. If we + # put the leader in a particular non-default subnet, it may be in a + # particular non-default VPC, which we need to know about. + other = {"GroupName": self.clusterName, "Description": "Toil appliance security group"} + if vpc_id is not None: + other["VpcId"] = vpc_id + # mypy stubs don't explicitly state kwargs even though documentation allows it, and mypy gets confused + web_response: CreateSecurityGroupResultTypeDef = boto3_ec2.create_security_group(**other) # type: ignore[arg-type] + except ClientError as e: + if get_error_status(e) == 400 and 'already exists' in get_error_body(e): + pass else: raise else: - for attempt in old_retry(predicate=groupNotFound, timeout=300): - with attempt: - # open port 22 for ssh-ing - web.authorize(ip_protocol='tcp', from_port=22, to_port=22, cidr_ip='0.0.0.0/0') - # TODO: boto2 doesn't support IPv6 here but we need to.
- for attempt in old_retry(predicate=groupNotFound, timeout=300): + for attempt in old_retry(predicate=group_not_found, timeout=300): with attempt: - # the following authorizes all TCP access within the web security group - web.authorize(ip_protocol='tcp', from_port=0, to_port=65535, src_group=web) - for attempt in old_retry(predicate=groupNotFound, timeout=300): - with attempt: - # We also want to open up UDP, both for user code and for the RealtimeLogger - web.authorize(ip_protocol='udp', from_port=0, to_port=65535, src_group=web) + ip_permissions: List[IpPermissionTypeDef] = [{"IpProtocol": "tcp", + "FromPort": 22, + "ToPort": 22, + "IpRanges": [ + {"CidrIp": "0.0.0.0/0"} + ], + "Ipv6Ranges": [{"CidrIpv6": "::/0"}]}] + for protocol in ("tcp", "udp"): + ip_permissions.append({"IpProtocol": protocol, + "FromPort": 0, + "ToPort": 65535, + "UserIdGroupPairs": + [{"GroupId": web_response["GroupId"], + "GroupName": self.clusterName}]}) + boto3_ec2.authorize_security_group_ingress(IpPermissions=ip_permissions, GroupName=self.clusterName, GroupId=web_response["GroupId"]) out = [] - for sg in vpc.get_all_security_groups(): - if sg.name == self.clusterName and (vpcId is None or sg.vpc_id == vpcId): - out.append(sg) - return [sg.id for sg in out] + for sg in boto3_ec2.describe_security_groups()["SecurityGroups"]: + if sg["GroupName"] == self.clusterName and (vpc_id is None or sg["VpcId"] == vpc_id): + out.append(sg["GroupId"]) + return out @awsRetry def _getSecurityGroupIDs(self) -> List[str]: @@ -1222,13 +1320,13 @@ def _getSecurityGroupIDs(self) -> List[str]: # Depending on if we enumerated them on the leader or locally, we might # know the required security groups by name, ID, or both. - sgs = [sg for sg in self.aws.boto2(self._region, 'ec2').get_all_security_groups() - if (sg.name in self._leaderSecurityGroupNames or - sg.id in self._leaderSecurityGroupIDs)] - return [sg.id for sg in sgs] + boto3_ec2 = self.aws.client(region=self._region, service_name='ec2') + return [sg["GroupId"] for sg in boto3_ec2.describe_security_groups()["SecurityGroups"] + if (sg["GroupName"] in self._leaderSecurityGroupNames or + sg["GroupId"] in self._leaderSecurityGroupIDs)] @awsRetry - def _get_launch_template_ids(self, filters: Optional[List[Dict[str, List[str]]]] = None) -> List[str]: + def _get_launch_template_ids(self, filters: Optional[List[FilterTypeDef]] = None) -> List[str]: """ Find all launch templates associated with the cluster. @@ -1236,10 +1334,10 @@ def _get_launch_template_ids(self, filters: Optional[List[Dict[str, List[str]]]] """ # Grab the connection we need to use for this operation. - ec2 = self.aws.client(self._region, 'ec2') + ec2: EC2Client = self.aws.client(self._region, 'ec2') # How do we match the right templates? - combined_filters = [{'Name': 'tag:' + _TAG_KEY_TOIL_CLUSTER_NAME, 'Values': [self.clusterName]}] + combined_filters: List[FilterTypeDef] = [{'Name': 'tag:' + _TAG_KEY_TOIL_CLUSTER_NAME, 'Values': [self.clusterName]}] if filters: # Add any user-specified filters @@ -1254,7 +1352,7 @@ def _get_launch_template_ids(self, filters: Optional[List[Dict[str, List[str]]]] allTemplateIDs += [item['LaunchTemplateId'] for item in response.get('LaunchTemplates', [])] if 'NextToken' in response: # There are more pages. Get the next one, supplying the token. 
- response = ec2.describe_launch_templates(Filters=filters, + # Keep the full cluster-scoped filter set when paging + response = ec2.describe_launch_templates(Filters=combined_filters, NextToken=response['NextToken'], MaxResults=200) else: @@ -1286,10 +1384,10 @@ def _get_worker_launch_template(self, instance_type: str, preemptible: bool = Fa lt_name = self._name_worker_launch_template(instance_type, preemptible=preemptible) # How do we match the right templates? - filters = [{'Name': 'launch-template-name', 'Values': [lt_name]}] + filters: List[FilterTypeDef] = [{'Name': 'launch-template-name', 'Values': [lt_name]}] # Get the templates - templates = self._get_launch_template_ids(filters=filters) + templates: List[str] = self._get_launch_template_ids(filters=filters) if len(templates) > 1: # There shouldn't ever be multiple templates with our reserved name @@ -1305,7 +1403,7 @@ def _get_worker_launch_template(self, instance_type: str, preemptible: bool = Fa # writes). Recurse to try again, because now it exists. logger.info('Waiting %f seconds for template %s to be available', backoff, lt_name) time.sleep(backoff) - return self._get_worker_launch_template(instance_type, preemptible=preemptible, backoff=backoff*2) + return self._get_worker_launch_template(instance_type, preemptible=preemptible, backoff=backoff * 2) else: raise else: @@ -1345,7 +1443,7 @@ def _create_worker_launch_template(self, instance_type: str, preemptible: bool = assert self._leaderPrivateIP type_info = E2Instances[instance_type] - rootVolSize=self._nodeStorageOverrides.get(instance_type, self._nodeStorage) + rootVolSize = self._nodeStorageOverrides.get(instance_type, self._nodeStorage) bdms = self._getBoto3BlockDeviceMappings(type_info, rootVolSize=rootVolSize) keyPath = self._sseKey if self._sseKey else None @@ -1377,16 +1475,16 @@ def _getAutoScalingGroupNames(self) -> List[str]: """ # Grab the connection we need to use for this operation. - autoscaling = self.aws.client(self._region, 'autoscaling') + autoscaling: AutoScalingClient = self.aws.client(self._region, 'autoscaling') # AWS won't filter ASGs server-side for us in describe_auto_scaling_groups. # So we search instances of applied tags for the ASGs they are on. # The ASGs tagged with our cluster are our ASGs. # The filtering is on different fields of the tag object itself. - filters = [{'Name': 'key', - 'Values': [_TAG_KEY_TOIL_CLUSTER_NAME]}, - {'Name': 'value', - 'Values': [self.clusterName]}] + filters: List[FilterTypeDef] = [{'Name': 'key', + 'Values': [_TAG_KEY_TOIL_CLUSTER_NAME]}, + {'Name': 'value', + 'Values': [self.clusterName]}] matchedASGs = [] # Get the first page with no NextToken @@ -1461,7 +1559,7 @@ def _createWorkerAutoScalingGroup(self, if self.clusterType == 'kubernetes': # We also need to tag it with Kubernetes autoscaler info (empty tags) tags['k8s.io/cluster-autoscaler/' + self.clusterName] = '' - assert(self.clusterName != 'enabled') + assert (self.clusterName != 'enabled') tags['k8s.io/cluster-autoscaler/enabled'] = '' tags['k8s.io/cluster-autoscaler/node-template/resources/ephemeral-storage'] = f'{min_gigs}G' @@ -1481,7 +1579,7 @@ def _createWorkerAutoScalingGroup(self, return asg_name - def _boto2_pager(self, requestor_callable: Callable, result_attribute_name: str) -> Iterable[Dict[str, Any]]: + def _boto2_pager(self, requestor_callable: Callable[..., Any], result_attribute_name: str) -> Iterable[Dict[str, Any]]: """ Yield all the results from calling the given Boto 2 method and paging through all the results using the "marker" field.
Results are to be @@ -1489,14 +1587,15 @@ def _boto2_pager(self, requestor_callable: Callable, result_attribute_name: str) """ marker = None while True: - result = requestor_callable(marker=marker) + result = requestor_callable(marker=marker) # type: ignore[call-arg] yield from getattr(result, result_attribute_name) if result.is_truncated == 'true': marker = result.marker else: break - def _pager(self, requestor_callable: Callable, result_attribute_name: str, **kwargs) -> Iterable[Dict[str, Any]]: + def _pager(self, requestor_callable: Callable[..., Any], result_attribute_name: str, + **kwargs: Any) -> Iterable[Any]: """ Yield all the results from calling the given Boto 3 method with the given keyword arguments, paging through the results using the Marker or @@ -1505,7 +1604,7 @@ def _pager(self, requestor_callable: Callable, result_attribute_name: str, **kwa """ # Recover the Boto3 client, and the name of the operation - client = requestor_callable.__self__ + client = requestor_callable.__self__ # type: ignore[attr-defined] op_name = requestor_callable.__name__ # grab a Boto 3 built-in paginator. See @@ -1523,10 +1622,12 @@ def _getRoleNames(self) -> List[str]: """ results = [] - for result in self._boto2_pager(self.aws.boto2(self._region, 'iam').list_roles, 'roles'): + boto3_iam = self.aws.client(self._region, 'iam') + for result in self._pager(boto3_iam.list_roles, 'Roles'): # For each role object # Grab out the name - name = result['role_name'] + result = cast(RoleTypeDef, result) + name = result['RoleName'] if self._is_our_namespaced_name(name): # If it looks like ours, it is ours. results.append(name) @@ -1539,11 +1640,12 @@ def _getInstanceProfileNames(self) -> List[str]: """ results = [] - for result in self._boto2_pager(self.aws.boto2(self._region, 'iam').list_instance_profiles, - 'instance_profiles'): - # For each Boto2 role object + boto3_iam = self.aws.client(self._region, 'iam') + for result in self._pager(boto3_iam.list_instance_profiles, + 'InstanceProfiles'): + # For each instance profile object # Grab out the name - name = result['instance_profile_name'] + name = result['InstanceProfileName'] if self._is_our_namespaced_name(name): # If it looks like ours, it is ours. results.append(name) @@ -1558,9 +1660,9 @@ def _getRoleInstanceProfileNames(self, role_name: str) -> List[str]: """ # Grab the connection we need to use for this operation. - iam = self.aws.client(self._region, 'iam') + boto3_iam: IAMClient = self.aws.client(self._region, 'iam') - return [item['InstanceProfileName'] for item in self._pager(iam.list_instance_profiles_for_role, + return [item['InstanceProfileName'] for item in self._pager(boto3_iam.list_instance_profiles_for_role, 'InstanceProfiles', RoleName=role_name)] @@ -1575,11 +1677,11 @@ def _getRolePolicyArns(self, role_name: str) -> List[str]: """ # Grab the connection we need to use for this operation. - iam = self.aws.client(self._region, 'iam') + boto3_iam: IAMClient = self.aws.client(self._region, 'iam') # TODO: we don't currently use attached policies. - return [item['PolicyArn'] for item in self._pager(iam.list_attached_role_policies, + return [item['PolicyArn'] for item in self._pager(boto3_iam.list_attached_role_policies, 'AttachedPolicies', RoleName=role_name)] @@ -1591,20 +1693,18 @@ def _getRoleInlinePolicyNames(self, role_name: str) -> List[str]: """ # Grab the connection we need to use for this operation.
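The _pager() helper above defers to Boto 3's built-in paginators, which own the Marker/NextToken bookkeeping; a minimal sketch of the same idea for any paginated client operation:

    from typing import Any, Iterable

    def pager_sketch(client: Any, op_name: str, result_key: str, **kwargs: Any) -> Iterable[Any]:
        # Boto 3 knows which operations paginate and which token field each one uses.
        for page in client.get_paginator(op_name).paginate(**kwargs):
            yield from page[result_key]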
- iam = self.aws.client(self._region, 'iam') + boto3_iam: IAMClient = self.aws.client(self._region, 'iam') - return list(self._pager(iam.list_role_policies, - 'PolicyNames', - RoleName=role_name)) + return list(self._pager(boto3_iam.list_role_policies, 'PolicyNames', RoleName=role_name)) - def full_policy(self, resource: str) -> dict: + def full_policy(self, resource: str) -> Dict[str, Any]: """ Produce a dict describing the JSON form of a full-access-granting AWS IAM policy for the service with the given name (e.g. 's3'). """ return dict(Version="2012-10-17", Statement=[dict(Effect="Allow", Resource="*", Action=f"{resource}:*")]) - def kubernetes_policy(self) -> dict: + def kubernetes_policy(self) -> Dict[str, Any]: """ Get the Kubernetes policy grants not provided by the full grants on EC2 and IAM. See @@ -1671,45 +1771,42 @@ def _setup_iam_ec2_role(self, local_role_name: str, policies: Dict[str, Any]) -> """ # Grab the connection we need to use for this operation. - iam = self.aws.boto2(self._region, 'iam') + boto3_iam: IAMClient = self.aws.client(self._region, 'iam') # Make sure we can tell our roles apart from roles for other clusters aws_role_name = self._namespace_name(local_role_name) try: # Make the role logger.debug('Creating IAM role %s...', aws_role_name) - iam.create_role(aws_role_name, assume_role_policy_document=json.dumps({ + assume_role_policy_document = json.dumps({ "Version": "2012-10-17", "Statement": [{ "Effect": "Allow", "Principal": {"Service": ["ec2.amazonaws.com"]}, "Action": ["sts:AssumeRole"]} - ]})) + ]}) + boto3_iam.create_role(RoleName=aws_role_name, AssumeRolePolicyDocument=assume_role_policy_document) logger.debug('Created new IAM role') - except BotoServerError as e: - if e.status == 409 and e.error_code == 'EntityAlreadyExists': + except ClientError as e: + if get_error_status(e) == 409 and get_error_code(e) == 'EntityAlreadyExists': logger.debug('IAM role already exists. Reusing.') else: raise # Delete superfluous policies - policy_names = set(iam.list_role_policies(aws_role_name).policy_names) + policy_names = set(boto3_iam.list_role_policies(RoleName=aws_role_name)["PolicyNames"]) for policy_name in policy_names.difference(set(list(policies.keys()))): - iam.delete_role_policy(aws_role_name, policy_name) + boto3_iam.delete_role_policy(RoleName=aws_role_name, PolicyName=policy_name) # Create expected policies for policy_name, policy in policies.items(): current_policy = None try: - current_policy = json.loads(unquote( - iam.get_role_policy(aws_role_name, policy_name).policy_document)) - except BotoServerError as e: - if e.status == 404 and e.error_code == 'NoSuchEntity': - pass - else: - raise + current_policy = boto3_iam.get_role_policy(RoleName=aws_role_name, PolicyName=policy_name)["PolicyDocument"] + except boto3_iam.exceptions.NoSuchEntityException: + pass if current_policy != policy: - iam.put_role_policy(aws_role_name, policy_name, json.dumps(policy)) + boto3_iam.put_role_policy(RoleName=aws_role_name, PolicyName=policy_name, PolicyDocument=json.dumps(policy)) # Now the role has the right policies so it is ready. return aws_role_name @@ -1724,7 +1821,7 @@ def _createProfileArn(self) -> str: """ # Grab the connection we need to use for this operation. 
- iam = self.aws.boto2(self._region, 'iam') + boto3_iam: IAMClient = self.aws.client(self._region, 'iam') policy = dict(iam_full=self.full_policy('iam'), ec2_full=self.full_policy('ec2'), s3_full=self.full_policy('s3'), sbd_full=self.full_policy('sdb')) @@ -1735,45 +1832,41 @@ def _createProfileArn(self) -> str: iamRoleName = self._setup_iam_ec2_role(_INSTANCE_PROFILE_ROLE_NAME, policy) try: - profile = iam.get_instance_profile(iamRoleName) - logger.debug("Have preexisting instance profile: %s", profile.get_instance_profile_response.get_instance_profile_result.instance_profile) - except BotoServerError as e: - if e.status == 404: - profile = iam.create_instance_profile(iamRoleName) - profile = profile.create_instance_profile_response.create_instance_profile_result - logger.debug("Created new instance profile: %s", profile.instance_profile) - else: - raise + profile_result = boto3_iam.get_instance_profile(InstanceProfileName=iamRoleName) + profile: InstanceProfileTypeDef = profile_result["InstanceProfile"] + logger.debug("Have preexisting instance profile: %s", profile) + except boto3_iam.exceptions.NoSuchEntityException: + profile_result = boto3_iam.create_instance_profile(InstanceProfileName=iamRoleName) + profile = profile_result["InstanceProfile"] + logger.debug("Created new instance profile: %s", profile) else: - profile = profile.get_instance_profile_response.get_instance_profile_result - profile = profile.instance_profile + profile = profile_result["InstanceProfile"] - profile_arn = profile.arn + profile_arn: str = profile["Arn"] # Now we have the profile ARN, but we want to make sure it really is # visible by name in a different session. wait_until_instance_profile_arn_exists(profile_arn) - if len(profile.roles) > 1: + if len(profile["Roles"]) > 1: # This is too many roles. We probably grabbed something we should # not have by mistake, and this is some important profile for # something else. raise RuntimeError(f'Did not expect instance profile {profile_arn} to contain ' f'more than one role; is it really a Toil-managed profile?') - elif len(profile.roles) == 1: - # this should be profile.roles[0].role_name - if profile.roles.member.role_name == iamRoleName: + elif len(profile["Roles"]) == 1: + if profile["Roles"][0]["RoleName"] == iamRoleName: return profile_arn else: # Drop this wrong role and use the fallback code for 0 roles - iam.remove_role_from_instance_profile(iamRoleName, - profile.roles.member.role_name) + boto3_iam.remove_role_from_instance_profile(InstanceProfileName=iamRoleName, + RoleName=profile["Roles"][0]["RoleName"]) # If we get here, we had 0 roles on the profile, or we had 1 but we removed it. 
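The profile handling above follows the usual IAM get-or-create idempotency pattern; a minimal sketch, assuming a Boto 3 IAM client:

    from mypy_boto3_iam import IAMClient
    from mypy_boto3_iam.type_defs import InstanceProfileTypeDef

    def get_or_create_profile(boto3_iam: IAMClient, name: str) -> InstanceProfileTypeDef:
        # IAM signals a missing profile with NoSuchEntityException rather than an empty result.
        try:
            return boto3_iam.get_instance_profile(InstanceProfileName=name)['InstanceProfile']
        except boto3_iam.exceptions.NoSuchEntityException:
            return boto3_iam.create_instance_profile(InstanceProfileName=name)['InstanceProfile']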
- for attempt in old_retry(predicate=lambda err: err.status == 404): + for attempt in old_retry(predicate=lambda err: get_error_status(err) == 404): with attempt: # Put the IAM role on the profile - iam.add_role_to_instance_profile(profile.instance_profile_name, iamRoleName) + boto3_iam.add_role_to_instance_profile(InstanceProfileName=profile["InstanceProfileName"], RoleName=iamRoleName) logger.debug("Associated role %s with profile", iamRoleName) return profile_arn diff --git a/src/toil/provisioners/node.py b/src/toil/provisioners/node.py index 7ee220c2d8..9b4f304430 100644 --- a/src/toil/provisioners/node.py +++ b/src/toil/provisioners/node.py @@ -18,6 +18,7 @@ import subprocess import time from itertools import count +from typing import Union, Dict, Optional, List, Any from toil.lib.memoize import parse_iso_utc @@ -29,7 +30,8 @@ class Node: maxWaitTime = 7 * 60 - def __init__(self, publicIP, privateIP, name, launchTime, nodeType, preemptible, tags=None, use_private_ip=None): + def __init__(self, publicIP: str, privateIP: str, name: str, launchTime: Union[datetime.datetime, str], + nodeType: Optional[str], preemptible: bool, tags: Optional[Dict[str, str]] = None, use_private_ip: Optional[bool] = None) -> None: self.publicIP = publicIP self.privateIP = privateIP if use_private_ip: @@ -37,7 +39,13 @@ def __init__(self, publicIP, privateIP, name, launchTime, nodeType, preemptible, else: self.effectiveIP = self.publicIP or self.privateIP self.name = name - self.launchTime = launchTime + if isinstance(launchTime, datetime.datetime): + self.launchTime = launchTime + else: + try: + self.launchTime = parse_iso_utc(launchTime) + except ValueError: + self.launchTime = datetime.datetime.fromisoformat(launchTime) self.nodeType = nodeType self.preemptible = preemptible self.tags = tags @@ -65,12 +73,12 @@ def remainingBillingInterval(self) -> float: """ if self.launchTime: now = datetime.datetime.utcnow() - delta = now - parse_iso_utc(self.launchTime) + delta = now - self.launchTime return 1 - delta.total_seconds() / 3600.0 % 1.0 else: return 1 - def waitForNode(self, role, keyName='core'): + def waitForNode(self, role: str, keyName: str='core') -> None: self._waitForSSHPort() # wait here so docker commands can be used reliably afterwards self._waitForSSHKeys(keyName=keyName) @@ -288,7 +296,7 @@ def coreSSH(self, *args, **kwargs): % (' '.join(args), exit_code, stdout, stderr)) return stdout - def coreRsync(self, args, applianceName='toil_leader', **kwargs): + def coreRsync(self, args: List[str], applianceName: str = 'toil_leader', **kwargs: Any) -> int: remoteRsync = "docker exec -i %s rsync -v" % applianceName # Access rsync inside appliance parsedArgs = [] sshCommand = "ssh" diff --git a/src/toil/test/__init__.py b/src/toil/test/__init__.py index 6642d8de25..fb0770735a 100644 --- a/src/toil/test/__init__.py +++ b/src/toil/test/__init__.py @@ -56,7 +56,6 @@ from toil import ApplianceImageNotFound, applianceSelf, toilPackageDirPath from toil.lib.accelerators import (have_working_nvidia_docker_runtime, have_working_nvidia_smi) -from toil.lib.aws import running_on_ec2 from toil.lib.io import mkdtemp from toil.lib.iterables import concat from toil.lib.memoize import memoize @@ -127,6 +126,7 @@ def awsRegion(cls) -> str: Use us-west-2 unless running on EC2, in which case use the region in which the instance is located """ + from toil.lib.aws import running_on_ec2 return cls._region() if running_on_ec2() else 'us-west-2' @classmethod @@ -378,7 +378,7 @@ def needs_aws_s3(test_item: MT) -> MT: return 
unittest.skip("Install Toil with the 'aws' extra to include this test.")( test_item ) - + from toil.lib.aws import running_on_ec2 if not (boto_credentials or os.path.exists(os.path.expanduser('~/.aws/credentials')) or running_on_ec2()): return unittest.skip("Configure AWS credentials to include this test.")(test_item) return test_item diff --git a/src/toil/test/lib/aws/test_iam.py b/src/toil/test/lib/aws/test_iam.py index 7df3b75475..c1b5310468 100644 --- a/src/toil/test/lib/aws/test_iam.py +++ b/src/toil/test/lib/aws/test_iam.py @@ -15,7 +15,7 @@ import logging import boto3 -from moto import mock_iam +from moto import mock_aws from toil.lib.aws import iam from toil.test import ToilTest @@ -46,7 +46,7 @@ def test_wildcard_handling(self): assert iam.permission_matches_any("iam:*", ["*"]) is True assert iam.permission_matches_any("ec2:*", ['iam:*']) is False - @mock_iam + @mock_aws def test_get_policy_permissions(self): mock_iam = boto3.client("iam") diff --git a/src/toil/test/provisioners/aws/awsProvisionerTest.py b/src/toil/test/provisioners/aws/awsProvisionerTest.py index 387ad89148..0e9b6be74b 100644 --- a/src/toil/test/provisioners/aws/awsProvisionerTest.py +++ b/src/toil/test/provisioners/aws/awsProvisionerTest.py @@ -19,9 +19,13 @@ from abc import abstractmethod from inspect import getsource from textwrap import dedent +from typing import Optional, List from uuid import uuid4 +import botocore.exceptions import pytest +from mypy_boto3_ec2 import EC2Client +from mypy_boto3_ec2.type_defs import EbsInstanceBlockDeviceTypeDef, InstanceTypeDef, InstanceBlockDeviceMappingTypeDef, FilterTypeDef, DescribeVolumesResultTypeDef, VolumeTypeDef from toil.provisioners import cluster_factory from toil.provisioners.aws.awsProvisioner import AWSProvisioner @@ -113,15 +117,20 @@ def data(self, filename): def rsyncUtil(self, src, dest): subprocess.check_call(['toil', 'rsync-cluster', '--insecure', '-p=aws', '-z', self.zone, self.clusterName] + [src, dest]) - def getRootVolID(self): - instances = self.cluster._get_nodes_in_cluster() - instances.sort(key=lambda x: x.launch_time) - leader = instances[0] # assume leader was launched first - - from boto.ec2.blockdevicemapping import BlockDeviceType - rootBlockDevice = leader.block_device_mapping["/dev/xvda"] - assert isinstance(rootBlockDevice, BlockDeviceType) - return rootBlockDevice.volume_id + def getRootVolID(self) -> str: + instances: List[InstanceTypeDef] = self.cluster._get_nodes_in_cluster_boto3() + instances.sort(key=lambda x: x.get("LaunchTime")) + leader: InstanceTypeDef = instances[0] # assume leader was launched first + + bdm: Optional[List[InstanceBlockDeviceMappingTypeDef]] = leader.get("BlockDeviceMappings") + assert bdm is not None + root_block_device: Optional[EbsInstanceBlockDeviceTypeDef] = None + for device in bdm: + if device["DeviceName"] == "/dev/xvda": + root_block_device = device["Ebs"] + assert root_block_device is not None # There should be a device named "/dev/xvda" + assert root_block_device.get("VolumeId") is not None + return root_block_device["VolumeId"] @abstractmethod def _getScript(self): @@ -191,21 +200,20 @@ def _test(self, preemptibleJobs=False): assert len(self.cluster._getRoleNames()) == 1 - from boto.exception import EC2ResponseError volumeID = self.getRootVolID() self.cluster.destroyCluster() + boto3_ec2: EC2Client = self.aws.client(region=self.region, service_name="ec2") + volume_filter: FilterTypeDef = {"Name": "volume-id", "Values": [volumeID]} + volumes: Optional[List[VolumeTypeDef]] = None for attempt in 
range(6): # https://github.com/BD2KGenomics/toil/issues/1567 # retry this for up to 1 minute until the volume disappears - try: - self.boto2_ec2.get_all_volumes(volume_ids=[volumeID]) - time.sleep(10) - except EC2ResponseError as e: - if e.status == 400 and 'InvalidVolume.NotFound' in e.code: - break - else: - raise - else: + volumes = boto3_ec2.describe_volumes(Filters=[volume_filter])["Volumes"] + if len(volumes) == 0: + # None are left, so they have been properly deleted + break + time.sleep(10) + if volumes is None or len(volumes) > 0: self.fail('Volume with ID %s was not cleaned up properly' % volumeID) assert len(self.cluster._getRoleNames()) == 0 @@ -246,16 +254,19 @@ def launchCluster(self): # add arguments to test that we can specify leader storage self.createClusterUtil(args=['--leaderStorage', str(self.requestedLeaderStorage)]) - def getRootVolID(self): + def getRootVolID(self) -> str: """ Adds in test to check that EBS volume is build with adequate size. Otherwise is functionally equivalent to parent. :return: volumeID """ volumeID = super().getRootVolID() - rootVolume = self.boto2_ec2.get_all_volumes(volume_ids=[volumeID])[0] + boto3_ec2: EC2Client = self.aws.client(region=self.region, service_name="ec2") + volume_filter: FilterTypeDef = {"Name": "volume-id", "Values": [volumeID]} + volumes: DescribeVolumesResultTypeDef = boto3_ec2.describe_volumes(Filters=[volume_filter]) + root_volume: VolumeTypeDef = volumes["Volumes"][0] # should be first # test that the leader is given adequate storage - self.assertGreaterEqual(rootVolume.size, self.requestedLeaderStorage) + self.assertGreaterEqual(root_volume["Size"], self.requestedLeaderStorage) return volumeID @integrative @@ -290,8 +301,6 @@ def __init__(self, name): self.requestedNodeStorage = 20 def launchCluster(self): - from boto.ec2.blockdevicemapping import BlockDeviceType - from toil.lib.ec2 import wait_instances_running self.createClusterUtil(args=['--leaderStorage', str(self.requestedLeaderStorage), '--nodeTypes', ",".join(self.instanceTypes), @@ -303,8 +312,8 @@ def launchCluster(self): # visible to EC2 read requests immediately after the create returns, # which is the last thing that starting the cluster does. 
@@ -246,16 +254,19 @@ def launchCluster(self):
         # add arguments to test that we can specify leader storage
         self.createClusterUtil(args=['--leaderStorage', str(self.requestedLeaderStorage)])
 
-    def getRootVolID(self):
+    def getRootVolID(self) -> str:
         """
         Adds in test to check that EBS volume is build with adequate size.
         Otherwise is functionally equivalent to parent.
         :return: volumeID
         """
         volumeID = super().getRootVolID()
-        rootVolume = self.boto2_ec2.get_all_volumes(volume_ids=[volumeID])[0]
+        boto3_ec2: EC2Client = self.aws.client(region=self.region, service_name="ec2")
+        volume_filter: FilterTypeDef = {"Name": "volume-id", "Values": [volumeID]}
+        volumes: DescribeVolumesResultTypeDef = boto3_ec2.describe_volumes(Filters=[volume_filter])
+        root_volume: VolumeTypeDef = volumes["Volumes"][0]  # should be first
         # test that the leader is given adequate storage
-        self.assertGreaterEqual(rootVolume.size, self.requestedLeaderStorage)
+        self.assertGreaterEqual(root_volume["Size"], self.requestedLeaderStorage)
         return volumeID
 
 @integrative
@@ -290,8 +301,6 @@ def __init__(self, name):
         self.requestedNodeStorage = 20
 
     def launchCluster(self):
-        from boto.ec2.blockdevicemapping import BlockDeviceType
-        from toil.lib.ec2 import wait_instances_running
         self.createClusterUtil(args=['--leaderStorage', str(self.requestedLeaderStorage),
                                      '--nodeTypes', ",".join(self.instanceTypes),
@@ -303,8 +312,8 @@ def launchCluster(self):
         # visible to EC2 read requests immediately after the create returns,
         # which is the last thing that starting the cluster does.
         time.sleep(10)
-        nodes = self.cluster._get_nodes_in_cluster()
-        nodes.sort(key=lambda x: x.launch_time)
+        nodes: List[InstanceTypeDef] = self.cluster._get_nodes_in_cluster_boto3()
+        nodes.sort(key=lambda x: x.get("LaunchTime"))
         # assuming that leader is first
         workers = nodes[1:]
         # test that two worker nodes were created
@@ -312,11 +321,22 @@ def launchCluster(self):
         # test that workers have expected storage size
         # just use the first worker
         worker = workers[0]
-        worker = next(wait_instances_running(self.boto2_ec2, [worker]))
-        rootBlockDevice = worker.block_device_mapping["/dev/xvda"]
-        self.assertTrue(isinstance(rootBlockDevice, BlockDeviceType))
-        rootVolume = self.boto2_ec2.get_all_volumes(volume_ids=[rootBlockDevice.volume_id])[0]
-        self.assertGreaterEqual(rootVolume.size, self.requestedNodeStorage)
+        boto3_ec2: EC2Client = self.aws.client(region=self.region, service_name="ec2")
+
+        worker: InstanceTypeDef = next(wait_instances_running(boto3_ec2, [worker]))
+
+        bdm: Optional[List[InstanceBlockDeviceMappingTypeDef]] = worker.get("BlockDeviceMappings")
+        assert bdm is not None
+        root_block_device: Optional[EbsInstanceBlockDeviceTypeDef] = None
+        for device in bdm:
+            if device["DeviceName"] == "/dev/xvda":
+                root_block_device = device["Ebs"]
+        assert root_block_device is not None
+        assert root_block_device.get("VolumeId") is not None  # TypedDicts cannot have runtime type checks
+
+        volume_filter: FilterTypeDef = {"Name": "volume-id", "Values": [root_block_device["VolumeId"]]}
+        root_volume: VolumeTypeDef = boto3_ec2.describe_volumes(Filters=[volume_filter])["Volumes"][0]  # should be first
+        self.assertGreaterEqual(root_volume.get("Size"), self.requestedNodeStorage)
 
     def _runScript(self, toilOptions):
         # Autoscale even though we have static nodes
@@ -337,9 +357,6 @@ def __init__(self, name):
         self.requestedNodeStorage = 20
 
     def launchCluster(self):
-        from boto.ec2.blockdevicemapping import BlockDeviceType  # noqa
-
-        from toil.lib.ec2 import wait_instances_running  # noqa
         self.createClusterUtil(args=['--leaderStorage', str(self.requestedLeaderStorage),
                                      '--nodeTypes', ",".join(self.instanceTypes),
                                      '--workers', ",".join([f'0-{c}' for c in self.numWorkers]),
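
Note: getRootVolID and launchCluster above now duplicate the scan of BlockDeviceMappings for the "/dev/xvda" root device. If a third copy ever appears, it could be factored into a helper along these lines (a sketch only, not part of the patch; the helper name is invented here):

    from typing import Optional
    from mypy_boto3_ec2.type_defs import (EbsInstanceBlockDeviceTypeDef,
                                          InstanceTypeDef)

    def root_ebs_device(instance: InstanceTypeDef,
                        device_name: str = "/dev/xvda") -> EbsInstanceBlockDeviceTypeDef:
        """Return the EBS mapping of the instance's root device, failing loudly if absent."""
        for device in instance.get("BlockDeviceMappings") or []:
            if device["DeviceName"] == device_name:
                ebs: Optional[EbsInstanceBlockDeviceTypeDef] = device.get("Ebs")
                assert ebs is not None and ebs.get("VolumeId") is not None
                return ebs
        raise AssertionError(f"no block device named {device_name}")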
diff --git a/src/toil/test/provisioners/clusterTest.py b/src/toil/test/provisioners/clusterTest.py
index c1d4f2fcaa..6a129eff5f 100644
--- a/src/toil/test/provisioners/clusterTest.py
+++ b/src/toil/test/provisioners/clusterTest.py
@@ -20,6 +20,7 @@
 from typing import Optional, List
 
 from toil.lib.aws import zone_to_region
+from toil.lib.aws.session import AWSConnectionManager
 from toil.lib.retry import retry
 from toil.provisioners.aws import get_best_aws_zone
 from toil.test import ToilTest, needs_aws_ec2, needs_fetchable_appliance
@@ -40,7 +41,12 @@ def __init__(self, methodName: str) -> None:
         # We need a boto2 connection to EC2 to check on the cluster.
         # Since we are protected by needs_aws_ec2 we can import from boto.
         import boto.ec2
-        self.boto2_ec2 = boto.ec2.connect_to_region(zone_to_region(self.zone))
+        self.region = zone_to_region(self.zone)
+        self.boto2_ec2 = boto.ec2.connect_to_region(self.region)
+
+        # Get connection to AWS with boto3/boto2
+        self.aws = AWSConnectionManager()
+
         # Where should we put our virtualenv?
         self.venvDir = '/tmp/venv'
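
Note: the base cluster test now keeps both a boto2 connection and an AWSConnectionManager, so the provisioner tests above can fetch boto3 clients on demand. The usage pattern, as exercised by those tests (a sketch; the region string and volume ID are placeholders):

    from toil.lib.aws.session import AWSConnectionManager

    aws = AWSConnectionManager()
    # Hands back a boto3 EC2 client for the requested region.
    ec2 = aws.client(region="us-west-2", service_name="ec2")
    volumes = ec2.describe_volumes(
        Filters=[{"Name": "volume-id", "Values": ["vol-0123456789abcdef0"]}]
    )["Volumes"]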
+ """ + + def __init__(self, name): + super().__init__(name) + self.clusterName = 'wdl-integration-test-' + str(uuid4()) + # t2.medium is the minimum t2 instance that permits Kubernetes + self.leaderNodeType = "t2.medium" + self.instanceTypes = ["t2.medium"] + self.clusterType = "kubernetes" + + def setUp(self) -> None: + super().setUp() + self.jobStore = f'aws:{self.awsRegion()}:wdl-test-{uuid4()}' + + def launchCluster(self) -> None: + self.createClusterUtil(args=['--leaderStorage', str(self.requestedLeaderStorage), + '--nodeTypes', ",".join(self.instanceTypes), + '-w', ",".join(self.numWorkers), + '--nodeStorage', str(self.requestedLeaderStorage)]) + + def test_wdl_kubernetes_cluster(self): + """ + Test that a wdl workflow works on a kubernetes cluster. Launches a cluster with 1 worker. This runs a wdl + workflow that performs an image pull on the worker. + :return: + """ + self.numWorkers = "1" + self.requestedLeaderStorage = 30 + # create the cluster + self.launchCluster() + # get leader + self.cluster = cluster_factory( + provisioner="aws", zone=self.zone, clusterName=self.clusterName + ) + self.leader = self.cluster.getLeader() + + url = "https://github.com/DataBiosphere/wdl-conformance-tests.git" + commit = "09b9659cd01473e836738a2e0dd205df0adb49c5" + wdl_dir = "wdl_conformance_tests" + + # get the wdl-conformance-tests repo to get WDL tasks to run + self.sshUtil([ + "bash", + "-c", + f"git clone {url} {wdl_dir} && cd {wdl_dir} && git checkout {commit}" + ]) + + # run on kubernetes batchsystem + toil_options = ['--batchSystem=kubernetes', + f"--jobstore={self.jobStore}"] + + # run WDL workflow that will run singularity + test_options = [f"tests/md5sum/md5sum.wdl", f"tests/md5sum/md5sum.json"] + self.sshUtil([ + "bash", + "-c", + f"cd {wdl_dir} && toil-wdl-runner {' '.join(test_options)} {' '.join(toil_options)}"]) + + +if __name__ == "__main__": + unittest.main() # run all tests