Skip to content

Commit

Permalink
Black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
clearml committed Oct 3, 2024
1 parent ffcda55 commit 9726964
Showing 1 changed file with 42 additions and 56 deletions.
98 changes: 42 additions & 56 deletions clearml/automation/aws_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,34 @@
Task.add_requirements("boto3")
except ImportError as err:
raise ImportError(
"AwsAutoScaler requires 'boto3' package, it was not found\n"
"install with: pip install boto3"
"AwsAutoScaler requires 'boto3' package, it was not found\n" "install with: pip install boto3"
) from err


@attr.s
class AWSDriver(CloudDriver):
"""AWS Driver"""
aws_access_key_id = attr.ib(validator=instance_of(str), default='')
aws_secret_access_key = attr.ib(validator=instance_of(str), default='')
aws_session_token = attr.ib(validator=instance_of(str), default='')
aws_region = attr.ib(validator=instance_of(str), default='')

aws_access_key_id = attr.ib(validator=instance_of(str), default="")
aws_secret_access_key = attr.ib(validator=instance_of(str), default="")
aws_session_token = attr.ib(validator=instance_of(str), default="")
aws_region = attr.ib(validator=instance_of(str), default="")
use_credentials_chain = attr.ib(validator=instance_of(bool), default=False)
use_iam_instance_profile = attr.ib(validator=instance_of(bool), default=False)
iam_arn = attr.ib(validator=instance_of(str), default='')
iam_name = attr.ib(validator=instance_of(str), default='')
iam_arn = attr.ib(validator=instance_of(str), default="")
iam_name = attr.ib(validator=instance_of(str), default="")

@classmethod
def from_config(cls, config):
obj = super().from_config(config)
obj.aws_access_key_id = config['hyper_params'].get('cloud_credentials_key')
obj.aws_secret_access_key = config['hyper_params'].get('cloud_credentials_secret')
obj.aws_session_token = config['hyper_params'].get('cloud_credentials_token')
obj.aws_region = config['hyper_params'].get('cloud_credentials_region')
obj.use_credentials_chain = config['hyper_params'].get('use_credentials_chain', False)
obj.use_iam_instance_profile = config['hyper_params'].get('use_iam_instance_profile', False)
obj.iam_arn = config['hyper_params'].get('iam_arn')
obj.iam_name = config['hyper_params'].get('iam_name')
obj.aws_access_key_id = config["hyper_params"].get("cloud_credentials_key")
obj.aws_secret_access_key = config["hyper_params"].get("cloud_credentials_secret")
obj.aws_session_token = config["hyper_params"].get("cloud_credentials_token")
obj.aws_region = config["hyper_params"].get("cloud_credentials_region")
obj.use_credentials_chain = config["hyper_params"].get("use_credentials_chain", False)
obj.use_iam_instance_profile = config["hyper_params"].get("use_iam_instance_profile", False)
obj.iam_arn = config["hyper_params"].get("iam_arn")
obj.iam_name = config["hyper_params"].get("iam_name")
return obj

def __attrs_post_init__(self):
Expand All @@ -60,7 +60,7 @@ def spin_up_worker(self, resource_conf, worker_prefix, queue_name, task_id):
launch_specification = ConfigFactory.from_dict(
{
"ImageId": resource_conf["ami_id"],
"Monitoring": {'Enabled': bool(resource_conf.get('enable_monitoring', False))},
"Monitoring": {"Enabled": bool(resource_conf.get("enable_monitoring", False))},
"InstanceType": resource_conf["instance_type"],
}
)
Expand All @@ -70,9 +70,7 @@ def spin_up_worker(self, resource_conf, worker_prefix, queue_name, task_id):
launch_specification["BlockDeviceMappings"] = [
{
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"SnapshotId": resource_conf["ebs_snapshot_id"]
}
"Ebs": {"SnapshotId": resource_conf["ebs_snapshot_id"]},
}
]
elif resource_conf.get("ebs_device_name"):
Expand All @@ -81,8 +79,8 @@ def spin_up_worker(self, resource_conf, worker_prefix, queue_name, task_id):
"DeviceName": resource_conf["ebs_device_name"],
"Ebs": {
"VolumeSize": resource_conf.get("ebs_volume_size", 80),
"VolumeType": resource_conf.get("ebs_volume_type", "gp3")
}
"VolumeType": resource_conf.get("ebs_volume_type", "gp3"),
},
}
]

Expand All @@ -91,45 +89,33 @@ def spin_up_worker(self, resource_conf, worker_prefix, queue_name, task_id):
elif resource_conf.get("availability_zone", None):
launch_specification["Placement"] = {"AvailabilityZone": resource_conf["availability_zone"]}
else:
raise Exception('subnet_id or availability_zone must to be specified in the config')
raise Exception("subnet_id or availability_zone must to be specified in the config")
if resource_conf.get("key_name", None):
launch_specification["KeyName"] = resource_conf["key_name"]
if resource_conf.get("security_group_ids", None):
launch_specification["SecurityGroupIds"] = resource_conf[
"security_group_ids"
]
launch_specification["SecurityGroupIds"] = resource_conf["security_group_ids"]
# Adding iam role - you can have Arn OR Name, not both, Arn getting priority
if self.iam_arn:
launch_specification["IamInstanceProfile"] = {
'Arn': self.iam_arn,
"Arn": self.iam_arn,
}
elif self.iam_name:
launch_specification["IamInstanceProfile"] = {
'Name': self.iam_name
}
launch_specification["IamInstanceProfile"] = {"Name": self.iam_name}

if resource_conf["is_spot"]:
# Create a request for a spot instance in AWS
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode(
"ascii"
)
encoded_user_data = base64.b64encode(user_data.encode("ascii")).decode("ascii")
launch_specification["UserData"] = encoded_user_data
ConfigTree.merge_configs(
launch_specification, resource_conf.get("extra_configurations", {})
)
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))

instances = ec2.request_spot_instances(
LaunchSpecification=launch_specification
)
instances = ec2.request_spot_instances(LaunchSpecification=launch_specification)

# Wait until spot request is fulfilled
request_id = instances["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
waiter = ec2.get_waiter("spot_instance_request_fulfilled")
waiter.wait(SpotInstanceRequestIds=[request_id])
# Get the instance object for later use
response = ec2.describe_spot_instance_requests(
SpotInstanceRequestIds=[request_id]
)
response = ec2.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
instance_id = response["SpotInstanceRequests"][0]["InstanceId"]

else:
Expand All @@ -140,9 +126,7 @@ def spin_up_worker(self, resource_conf, worker_prefix, queue_name, task_id):
UserData=user_data,
InstanceInitiatedShutdownBehavior="terminate",
)
ConfigTree.merge_configs(
launch_specification, resource_conf.get("extra_configurations", {})
)
ConfigTree.merge_configs(launch_specification, resource_conf.get("extra_configurations", {}))

instances = ec2.run_instances(**launch_specification)

Expand All @@ -165,30 +149,32 @@ def spin_down_worker(self, instance_id):

def creds(self):
creds = {
'region_name': self.aws_region or None,
"region_name": self.aws_region or None,
}

if not self.use_credentials_chain:
creds.update({
'aws_secret_access_key': self.aws_secret_access_key or None,
'aws_access_key_id': self.aws_access_key_id or None,
'aws_session_token': self.aws_session_token or None,
})
creds.update(
{
"aws_secret_access_key": self.aws_secret_access_key or None,
"aws_access_key_id": self.aws_access_key_id or None,
"aws_session_token": self.aws_session_token or None,
}
)
return creds

def instance_id_command(self):
return 'curl http://169.254.169.254/latest/meta-data/instance-id'
return "curl http://169.254.169.254/latest/meta-data/instance-id"

def instance_type_key(self):
return 'instance_type'
return "instance_type"

def kind(self):
return 'AWS'
return "AWS"

def console_log(self, instance_id):
ec2 = boto3.client("ec2", **self.creds())
try:
out = ec2.get_console_output(InstanceId=instance_id)
return out.get('Output', '')
return out.get("Output", "")
except ClientError as err:
return 'error: cannot get logs for {}:\n{}'.format(instance_id, err)
return "error: cannot get logs for {}:\n{}".format(instance_id, err)

0 comments on commit 9726964

Please sign in to comment.