diff --git a/examples/terraform-packer-example/main.tf b/examples/terraform-packer-example/main.tf index 87b65808c..c5906a76a 100644 --- a/examples/terraform-packer-example/main.tf +++ b/examples/terraform-packer-example/main.tf @@ -24,9 +24,14 @@ provider "aws" { # --------------------------------------------------------------------------------------------------------------------- resource "aws_instance" "example" { - ami = var.ami_id - instance_type = var.instance_type - user_data = data.template_file.user_data.rendered + ami = var.ami_id + instance_type = var.instance_type + + user_data = templatefile("${path.module}/user-data/user-data.sh", { + instance_text = var.instance_text + instance_port = var.instance_port + }) + vpc_security_group_ids = [aws_security_group.example.id] tags = { @@ -51,17 +56,3 @@ resource "aws_security_group" "example" { cidr_blocks = ["0.0.0.0/0"] } } - -# --------------------------------------------------------------------------------------------------------------------- -# CREATE THE USER DATA SCRIPT THAT WILL RUN DURING BOOT ON THE EC2 INSTANCE -# --------------------------------------------------------------------------------------------------------------------- - -data "template_file" "user_data" { - template = file("${path.module}/user-data/user-data.sh") - - vars = { - instance_text = var.instance_text - instance_port = var.instance_port - } -} - diff --git a/go.mod b/go.mod index 67e651deb..09cf39508 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect github.com/aws/aws-lambda-go v1.47.0 - github.com/aws/aws-sdk-go v1.44.122 + github.com/aws/aws-sdk-go v1.44.122 // indirect github.com/ghodss/yaml v1.0.0 github.com/go-errors/errors v1.0.2-0.20180813162953-d98b870cc4e0 // indirect github.com/go-sql-driver/mysql v1.8.1 @@ -48,7 +48,28 @@ require ( require ( cloud.google.com/go/cloudbuild v1.19.0 - github.com/gogo/protobuf v1.3.2 + github.com/aws/aws-sdk-go-v2 v1.32.5 + github.com/aws/aws-sdk-go-v2/config v1.28.5 + github.com/aws/aws-sdk-go-v2/credentials v1.17.46 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.39 + github.com/aws/aws-sdk-go-v2/service/acm v1.30.6 + github.com/aws/aws-sdk-go-v2/service/autoscaling v1.49.0 + github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.43.3 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.37.1 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.190.0 + github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 + github.com/aws/aws-sdk-go-v2/service/ecs v1.50.0 + github.com/aws/aws-sdk-go-v2/service/iam v1.38.1 + github.com/aws/aws-sdk-go-v2/service/kms v1.37.6 + github.com/aws/aws-sdk-go-v2/service/lambda v1.66.1 + github.com/aws/aws-sdk-go-v2/service/rds v1.90.0 + github.com/aws/aws-sdk-go-v2/service/route53 v1.46.2 + github.com/aws/aws-sdk-go-v2/service/s3 v1.67.1 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6 + github.com/aws/aws-sdk-go-v2/service/sns v1.33.5 + github.com/aws/aws-sdk-go-v2/service/sqs v1.37.1 + github.com/aws/aws-sdk-go-v2/service/ssm v1.55.6 + github.com/aws/aws-sdk-go-v2/service/sts v1.33.1 github.com/gonvenience/ytbx v1.4.4 github.com/homeport/dyff v1.6.0 github.com/jackc/pgx/v5 v5.7.1 @@ -77,6 +98,20 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.24 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.6 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.5 // indirect + github.com/aws/smithy-go v1.22.1 // indirect github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect @@ -99,6 +134,7 @@ require ( github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.3 // indirect + github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/gonvenience/bunt v1.3.5 // indirect diff --git a/go.sum b/go.sum index b07cf808c..114e73c88 100644 --- a/go.sum +++ b/go.sum @@ -254,6 +254,78 @@ github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1s github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= github.com/aws/aws-sdk-go v1.44.122 h1:p6mw01WBaNpbdP2xrisz5tIkcNwzj/HysobNoaAHjgo= github.com/aws/aws-sdk-go v1.44.122/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= +github.com/aws/aws-sdk-go-v2 v1.32.5 h1:U8vdWJuY7ruAkzaOdD7guwJjD06YSKmnKCJs7s3IkIo= +github.com/aws/aws-sdk-go-v2 v1.32.5/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= +github.com/aws/aws-sdk-go-v2/config v1.28.5 h1:Za41twdCXbuyyWv9LndXxZZv3QhTG1DinqlFsSuvtI0= +github.com/aws/aws-sdk-go-v2/config v1.28.5/go.mod h1:4VsPbHP8JdcdUDmbTVgNL/8w9SqOkM5jyY8ljIxLO3o= +github.com/aws/aws-sdk-go-v2/credentials v1.17.46 h1:AU7RcriIo2lXjUfHFnFKYsLCwgbz1E7Mm95ieIRDNUg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.46/go.mod h1:1FmYyLGL08KQXQ6mcTlifyFXfJVCNJTVGuQP4m0d/UA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.20 h1:sDSXIrlsFSFJtWKLQS4PUWRvrT580rrnuLydJrCQ/yA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.20/go.mod h1:WZ/c+w0ofps+/OUqMwWgnfrgzZH1DZO1RIkktICsqnY= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.39 h1:Bdepdtm7SAUxPIZj6x4qg5al04R6tZa965T/j597XxM= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.39/go.mod h1:AudGmEyVwvi3k5MVpEZP2NEVF1HqtZoMze42Uq1RTiE= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24 h1:4usbeaes3yJnCFC7kfeyhkdkPtoRYPa/hTmCqMpKpLI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24/go.mod h1:5CI1JemjVwde8m2WG3cz23qHKPOxbpkq0HaoreEgLIY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24 h1:N1zsICrQglfzaBnrfM0Ys00860C+QFwu6u/5+LomP+o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24/go.mod h1:dCn9HbJ8+K31i8IQ8EWmWj0EiIk0+vKiHNMxTTYveAg= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.24 h1:JX70yGKLj25+lMC5Yyh8wBtvB01GDilyRuJvXJ4piD0= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.24/go.mod h1:+Ln60j9SUTD0LEwnhEB0Xhg61DHqplBrbZpLgyjoEHg= +github.com/aws/aws-sdk-go-v2/service/acm v1.30.6 h1:fDg0RlN30Xf/yYzEUL/WXqhmgFsjVb/I3230oCfyI5w= +github.com/aws/aws-sdk-go-v2/service/acm v1.30.6/go.mod h1:zRR6jE3v/TcbfO8C2P+H0Z+kShiKKVaVyoIl8NQRjyg= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.49.0 h1:j3aQus6aqR1bqI6ljUpuYKrUhVqOI/JCTt1LmA1LsA0= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.49.0/go.mod h1:I1+/2m+IhnK5qEbhS3CrzjeiVloo9sItE/2K+so0fkU= +github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.43.3 h1:hKIu7ziYNid9JAuPX5TMgfEKiGyJiPO7Icdc920uLMI= +github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.43.3/go.mod h1:Qbr4yfpNqVNl69l/GEDK+8wxLf/vHi0ChoiSDzD7thU= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.37.1 h1:vucMirlM6D+RDU8ncKaSZ/5dGrXNajozVwpmWNPn2gQ= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.37.1/go.mod h1:fceORfs010mNxZbQhfqUjUeHlTwANmIT4mvHamuUaUg= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.190.0 h1:k97fGog9Tl0woxTiSIHN14Qs5ehqK6GXejUwkhJYyL0= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.190.0/go.mod h1:mzj8EEjIHSN2oZRXiw1Dd+uB4HZTl7hC8nBzX9IZMWw= +github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 h1:zg+3FGHA0PBs0KM25qE/rOf2o5zsjNa1g/Qq83+SDI0= +github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6/go.mod h1:ZSq54Z9SIsOTf1Efwgw1msilSs4XVEfVQiP9nYVnKpM= +github.com/aws/aws-sdk-go-v2/service/ecs v1.50.0 h1:NW+6/MPclDxOWcuZZxIJSMt6cVPWVojmJ4R3HsICCsI= +github.com/aws/aws-sdk-go-v2/service/ecs v1.50.0/go.mod h1:dPTOvmjJQ1T7Q+2+Xs2KSPrMvx+p0rpyV+HsQVnUK4o= +github.com/aws/aws-sdk-go-v2/service/iam v1.38.1 h1:hfkzDZHBp9jAT4zcd5mtqckpU4E3Ax0LQaEWWk1VgN8= +github.com/aws/aws-sdk-go-v2/service/iam v1.38.1/go.mod h1:u36ahDtZcQHGmVm/r+0L1sfKX4fzLEMdCqiKRKkUMVM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.5 h1:gvZOjQKPxFXy1ft3QnEyXmT+IqneM9QAUWlM3r0mfqw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.5/go.mod h1:DLWnfvIcm9IET/mmjdxeXbBKmTCm0ZB8p1za9BVteM8= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.5 h1:3Y457U2eGukmjYjeHG6kanZpDzJADa2m0ADqnuePYVQ= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.5/go.mod h1:CfwEHGkTjYZpkQ/5PvcbEtT7AJlG68KkEvmtwU8z3/U= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5 h1:wtpJ4zcwrSbwhECWQoI/g6WM9zqCcSpHDJIWSbMLOu4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5/go.mod h1:qu/W9HXQbbQ4+1+JcZp0ZNPV31ym537ZJN+fiS7Ti8E= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.5 h1:P1doBzv5VEg1ONxnJss1Kh5ZG/ewoIE4MQtKKc6Crgg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.5/go.mod h1:NOP+euMW7W3Ukt28tAxPuoWao4rhhqJD3QEBk7oCg7w= +github.com/aws/aws-sdk-go-v2/service/kms v1.37.6 h1:CZImQdb1QbU9sGgJ9IswhVkxAcjkkD1eQTMA1KHWk+E= +github.com/aws/aws-sdk-go-v2/service/kms v1.37.6/go.mod h1:YJDdlK0zsyxVBxGU48AR/Mi8DMrGdc1E3Yij4fNrONA= +github.com/aws/aws-sdk-go-v2/service/lambda v1.66.1 h1:eJz5qvOPvA7KipzQNycPxPz7ets082W91BJKJpVRFL4= +github.com/aws/aws-sdk-go-v2/service/lambda v1.66.1/go.mod h1:guz2K3x4FKSdDaoeB+TPVgJNU9oj2gftbp5cR8ela1A= +github.com/aws/aws-sdk-go-v2/service/rds v1.90.0 h1:Lg3GkzGkgqY03qsLSXPFyxW59t/lSoXaK9SWa8EKCiI= +github.com/aws/aws-sdk-go-v2/service/rds v1.90.0/go.mod h1:h2jc7IleH3xHY7y+h8FH7WAZcz3IVLOB6/jXotIQ/qU= +github.com/aws/aws-sdk-go-v2/service/route53 v1.46.2 h1:wmt05tPp/CaRZpPV5B4SaJ5TwkHKom07/BzHoLdkY1o= +github.com/aws/aws-sdk-go-v2/service/route53 v1.46.2/go.mod h1:d+K9HESMpGb1EU9/UmmpInbGIUcAkwmcY6ZO/A3zZsw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.67.1 h1:LXLnDfjT/P6SPIaCE86xCOjJROPn4FNB2EdN68vMK5c= +github.com/aws/aws-sdk-go-v2/service/s3 v1.67.1/go.mod h1:ralv4XawHjEMaHOWnTFushl0WRqim/gQWesAMF6hTow= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6 h1:1KDMKvOKNrpD667ORbZ/+4OgvUoaok1gg/MLzrHF9fw= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6/go.mod h1:DmtyfCfONhOyVAJ6ZMTrDSFIeyCBlEO93Qkfhxwbxu0= +github.com/aws/aws-sdk-go-v2/service/sns v1.33.5 h1:nJDOsZumqKsejsiGKgpezFzI2oatHmQi/kKKC4wS8v4= +github.com/aws/aws-sdk-go-v2/service/sns v1.33.5/go.mod h1:SODr0Lu3lFdT0SGsGX1TzFTapwveBrT5wztVoYtppm8= +github.com/aws/aws-sdk-go-v2/service/sqs v1.37.1 h1:39WvSrVq9DD6UHkD+fx5x19P5KpRQfNdtgReDVNbelc= +github.com/aws/aws-sdk-go-v2/service/sqs v1.37.1/go.mod h1:3gwPzC9LER/BTQdQZ3r6dUktb1rSjABF1D3Sr6nS7VU= +github.com/aws/aws-sdk-go-v2/service/ssm v1.55.6 h1:mh6Osa3cjwaaVSzJ92a8x1dBh8XQ7ekKLHyhjtx5RRw= +github.com/aws/aws-sdk-go-v2/service/ssm v1.55.6/go.mod h1:l9qF25TzH95FhcIak6e4vt79KE4I7M2Nf59eMUVjj6c= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.6 h1:3zu537oLmsPfDMyjnUS2g+F2vITgy5pB74tHI+JBNoM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.6/go.mod h1:WJSZH2ZvepM6t6jwu4w/Z45Eoi75lPN7DcydSRtJg6Y= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.5 h1:K0OQAsDywb0ltlFrZm0JHPY3yZp/S9OaoLU33S7vPS8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.5/go.mod h1:ORITg+fyuMoeiQFiVGoqB3OydVTLkClw/ljbblMq6Cc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.1 h1:6SZUVRQNvExYlMLbHdlKB48x0fLbc2iVROyaNEwBHbU= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.1/go.mod h1:GqWyYCwLXnlUB1lOAXQyNSPqPLQJvmo8J0DWBzp9mtg= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d h1:xDfNPAt8lFiC1UJrqV3uuy861HCTo708pDMbjHHdCas= github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ00z/TKoufEY6K/a0k6AhaJrQKdFe6OfVXsa4= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= diff --git a/modules/aws/account.go b/modules/aws/account.go index e64bd36d0..d4d93d37b 100644 --- a/modules/aws/account.go +++ b/modules/aws/account.go @@ -1,11 +1,12 @@ package aws import ( + "context" "errors" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -26,12 +27,12 @@ func GetAccountIdE(t testing.TestingT) (string, error) { return "", err } - identity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + identity, err := stsClient.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{}) if err != nil { return "", err } - return aws.StringValue(identity.Account), nil + return aws.ToString(identity.Account), nil } // An IAM arn is of the format arn:aws:iam::123456789012:user/test. The account id is the number after arn:aws:iam::, @@ -47,10 +48,10 @@ func extractAccountIDFromARN(arn string) (string, error) { } // NewStsClientE creates a new STS client. -func NewStsClientE(t testing.TestingT, region string) (*sts.STS, error) { +func NewStsClientE(t testing.TestingT, region string) (*sts.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return sts.New(sess), nil + return sts.NewFromConfig(*sess), nil } diff --git a/modules/aws/acm.go b/modules/aws/acm.go index 88ac5f9de..ea00a9d53 100644 --- a/modules/aws/acm.go +++ b/modules/aws/acm.go @@ -1,7 +1,9 @@ package aws import ( - "github.com/aws/aws-sdk-go/service/acm" + "context" + + "github.com/aws/aws-sdk-go-v2/service/acm" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -22,7 +24,7 @@ func GetAcmCertificateArnE(t testing.TestingT, awsRegion string, certDomainName return "", err } - result, err := acmClient.ListCertificates(&acm.ListCertificatesInput{}) + result, err := acmClient.ListCertificates(context.Background(), &acm.ListCertificatesInput{}) if err != nil { return "", err } @@ -37,7 +39,7 @@ func GetAcmCertificateArnE(t testing.TestingT, awsRegion string, certDomainName } // NewAcmClient create a new ACM client. -func NewAcmClient(t testing.TestingT, region string) *acm.ACM { +func NewAcmClient(t testing.TestingT, region string) *acm.Client { client, err := NewAcmClientE(t, region) if err != nil { t.Fatal(err) @@ -46,11 +48,11 @@ func NewAcmClient(t testing.TestingT, region string) *acm.ACM { } // NewAcmClientE creates a new ACM client. -func NewAcmClientE(t testing.TestingT, awsRegion string) (*acm.ACM, error) { +func NewAcmClientE(t testing.TestingT, awsRegion string) (*acm.Client, error) { sess, err := NewAuthenticatedSession(awsRegion) if err != nil { return nil, err } - return acm.New(sess), nil + return acm.NewFromConfig(*sess), nil } diff --git a/modules/aws/ami.go b/modules/aws/ami.go index 466980b23..de1af5868 100644 --- a/modules/aws/ami.go +++ b/modules/aws/ami.go @@ -1,12 +1,14 @@ package aws import ( + "context" "fmt" "sort" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -57,7 +59,7 @@ func GetEbsSnapshotsForAmi(t testing.TestingT, region string, ami string) []stri return snapshots } -// GetEbsSnapshotsForAmi retrieves the EBS snapshots which back the given AMI +// GetEbsSnapshotsForAmiE retrieves the EBS snapshots which back the given AMI func GetEbsSnapshotsForAmiE(t testing.TestingT, region string, ami string) ([]string, error) { logger.Default.Logf(t, "Retrieving EBS snapshots backing AMI %s", ami) ec2Client, err := NewEc2ClientE(t, region) @@ -65,9 +67,9 @@ func GetEbsSnapshotsForAmiE(t testing.TestingT, region string, ami string) ([]st return nil, err } - images, err := ec2Client.DescribeImages(&ec2.DescribeImagesInput{ - ImageIds: []*string{ - aws.String(ami), + images, err := ec2Client.DescribeImages(context.Background(), &ec2.DescribeImagesInput{ + ImageIds: []string{ + ami, }, }) if err != nil { @@ -78,7 +80,7 @@ func GetEbsSnapshotsForAmiE(t testing.TestingT, region string, ami string) ([]st for _, image := range images.Images { for _, mapping := range image.BlockDeviceMappings { if mapping.Ebs != nil && mapping.Ebs.SnapshotId != nil { - snapshots = append(snapshots, aws.StringValue(mapping.Ebs.SnapshotId)) + snapshots = append(snapshots, aws.ToString(mapping.Ebs.SnapshotId)) } } } @@ -106,18 +108,18 @@ func GetMostRecentAmiIdE(t testing.TestingT, region string, ownerId string, filt return "", err } - ec2Filters := []*ec2.Filter{} + var ec2Filters []types.Filter for name, values := range filters { - ec2Filters = append(ec2Filters, &ec2.Filter{Name: aws.String(name), Values: aws.StringSlice(values)}) + ec2Filters = append(ec2Filters, types.Filter{Name: aws.String(name), Values: values}) } input := ec2.DescribeImagesInput{ Filters: ec2Filters, IncludeDeprecated: aws.Bool(true), - Owners: []*string{aws.String(ownerId)}, + Owners: []string{ownerId}, } - out, err := ec2Client.DescribeImages(&input) + out, err := ec2Client.DescribeImages(context.Background(), &input) if err != nil { return "", err } @@ -127,11 +129,11 @@ func GetMostRecentAmiIdE(t testing.TestingT, region string, ownerId string, filt } mostRecentImage := mostRecentAMI(out.Images) - return aws.StringValue(mostRecentImage.ImageId), nil + return aws.ToString(mostRecentImage.ImageId), nil } // Image sorting code borrowed from: https://github.com/hashicorp/packer/blob/7f4112ba229309cfc0ebaa10ded2abdfaf1b22c8/builder/amazon/common/step_source_ami_info.go -type imageSort []*ec2.Image +type imageSort []types.Image func (a imageSort) Len() int { return len(a) } func (a imageSort) Swap(i, j int) { a[i], a[j] = a[j], a[i] } @@ -142,7 +144,7 @@ func (a imageSort) Less(i, j int) bool { } // mostRecentAMI returns the most recent AMI out of a slice of images. -func mostRecentAMI(images []*ec2.Image) *ec2.Image { +func mostRecentAMI(images []types.Image) types.Image { sortedImages := images sort.Sort(imageSort(sortedImages)) return sortedImages[len(sortedImages)-1] diff --git a/modules/aws/asg.go b/modules/aws/asg.go index d69c5ab5d..d88d70f4b 100644 --- a/modules/aws/asg.go +++ b/modules/aws/asg.go @@ -1,11 +1,12 @@ package aws import ( + "context" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" "github.com/stretchr/testify/require" "github.com/gruntwork-io/terratest/modules/logger" @@ -34,8 +35,8 @@ func GetCapacityInfoForAsgE(t testing.TestingT, asgName string, awsRegion string return AsgCapacityInfo{}, err } - input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []*string{aws.String(asgName)}} - output, err := asgClient.DescribeAutoScalingGroups(&input) + input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []string{asgName}} + output, err := asgClient.DescribeAutoScalingGroups(context.Background(), &input) if err != nil { return AsgCapacityInfo{}, err } @@ -44,9 +45,9 @@ func GetCapacityInfoForAsgE(t testing.TestingT, asgName string, awsRegion string return AsgCapacityInfo{}, NewNotFoundError("ASG", asgName, awsRegion) } capacityInfo := AsgCapacityInfo{ - MinCapacity: *groups[0].MinSize, - MaxCapacity: *groups[0].MaxSize, - DesiredCapacity: *groups[0].DesiredCapacity, + MinCapacity: int64(*groups[0].MinSize), + MaxCapacity: int64(*groups[0].MaxSize), + DesiredCapacity: int64(*groups[0].DesiredCapacity), CurrentCapacity: int64(len(groups[0].Instances)), } return capacityInfo, nil @@ -68,16 +69,16 @@ func GetInstanceIdsForAsgE(t testing.TestingT, asgName string, awsRegion string) return nil, err } - input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []*string{aws.String(asgName)}} - output, err := asgClient.DescribeAutoScalingGroups(&input) + input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []string{asgName}} + output, err := asgClient.DescribeAutoScalingGroups(context.Background(), &input) if err != nil { return nil, err } - instanceIDs := []string{} + var instanceIDs []string for _, asg := range output.AutoScalingGroups { for _, instance := range asg.Instances { - instanceIDs = append(instanceIDs, aws.StringValue(instance.InstanceId)) + instanceIDs = append(instanceIDs, aws.ToString(instance.InstanceId)) } } @@ -125,7 +126,7 @@ func WaitForCapacityE( } // NewAsgClient creates an Auto Scaling Group client. -func NewAsgClient(t testing.TestingT, region string) *autoscaling.AutoScaling { +func NewAsgClient(t testing.TestingT, region string) *autoscaling.Client { client, err := NewAsgClientE(t, region) if err != nil { t.Fatal(err) @@ -134,11 +135,11 @@ func NewAsgClient(t testing.TestingT, region string) *autoscaling.AutoScaling { } // NewAsgClientE creates an Auto Scaling Group client. -func NewAsgClientE(t testing.TestingT, region string) (*autoscaling.AutoScaling, error) { +func NewAsgClientE(t testing.TestingT, region string) (*autoscaling.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return autoscaling.New(sess), nil + return autoscaling.NewFromConfig(*sess), nil } diff --git a/modules/aws/asg_test.go b/modules/aws/asg_test.go index b6b1dda7c..bda73f8fe 100644 --- a/modules/aws/asg_test.go +++ b/modules/aws/asg_test.go @@ -1,13 +1,16 @@ package aws import ( + "context" "fmt" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + autoscalingTypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -47,56 +50,74 @@ func TestGetInstanceIdsForAsg(t *testing.T) { assert.Equal(t, len(instanceIds), 1) } -// The following functions were adapted from the tests for cloud-nuke - -func createTestAutoScalingGroup(t *testing.T, name string, region string, desiredCount int64) { - instance := createTestEC2Instance(t, region, name) +func createTestAutoScalingGroup(t *testing.T, name string, region string, desiredCount int32) { + azs := GetAvailabilityZones(t, region) + ec2Client := NewEc2Client(t, region) + imageID := GetAmazonLinuxAmi(t, region) + template, err := ec2Client.CreateLaunchTemplate(context.Background(), &ec2.CreateLaunchTemplateInput{ + LaunchTemplateData: &types.RequestLaunchTemplateData{ + ImageId: aws.String(imageID), + InstanceType: types.InstanceType(GetRecommendedInstanceType(t, region, []string{"t2.micro, t3.micro", "t2.small", "t3.small"})), + }, + LaunchTemplateName: aws.String(name), + }) + require.NoError(t, err) asgClient := NewAsgClient(t, region) param := &autoscaling.CreateAutoScalingGroupInput{ AutoScalingGroupName: &name, - InstanceId: instance.InstanceId, - DesiredCapacity: aws.Int64(desiredCount), - MinSize: aws.Int64(1), - MaxSize: aws.Int64(3), + LaunchTemplate: &autoscalingTypes.LaunchTemplateSpecification{ + LaunchTemplateId: template.LaunchTemplate.LaunchTemplateId, + Version: aws.String("$Latest"), + }, + AvailabilityZones: azs, + DesiredCapacity: aws.Int32(desiredCount), + MinSize: aws.Int32(1), + MaxSize: aws.Int32(3), } - _, err := asgClient.CreateAutoScalingGroup(param) + _, err = asgClient.CreateAutoScalingGroup(context.Background(), param) require.NoError(t, err) - err = asgClient.WaitUntilGroupExists(&autoscaling.DescribeAutoScalingGroupsInput{ - AutoScalingGroupNames: []*string{&name}, - }) + waiter := autoscaling.NewGroupExistsWaiter(asgClient) + err = waiter.Wait(context.Background(), &autoscaling.DescribeAutoScalingGroupsInput{ + AutoScalingGroupNames: []string{name}, + }, 42*time.Minute) require.NoError(t, err) } -func createTestEC2Instance(t *testing.T, region string, name string) ec2.Instance { +func createTestEC2Instance(t *testing.T, region string, name string) types.Instance { ec2Client := NewEc2Client(t, region) imageID := GetAmazonLinuxAmi(t, region) params := &ec2.RunInstancesInput{ ImageId: aws.String(imageID), - InstanceType: aws.String(GetRecommendedInstanceType(t, region, []string{"t2.micro, t3.micro", "t2.small", "t3.small"})), - MinCount: aws.Int64(1), - MaxCount: aws.Int64(1), + InstanceType: types.InstanceType(GetRecommendedInstanceType(t, region, []string{"t2.micro, t3.micro", "t2.small", "t3.small"})), + MinCount: aws.Int32(1), + MaxCount: aws.Int32(1), } - runResult, err := ec2Client.RunInstances(params) + runResult, err := ec2Client.RunInstances(context.Background(), params) require.NoError(t, err) require.NotEqual(t, len(runResult.Instances), 0) - err = ec2Client.WaitUntilInstanceExists(&ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: aws.String("instance-id"), - Values: []*string{runResult.Instances[0].InstanceId}, + waiter := ec2.NewInstanceExistsWaiter(ec2Client) + err = waiter.Wait( + context.Background(), + &ec2.DescribeInstancesInput{ + Filters: []types.Filter{ + { + Name: aws.String("instance-id"), + Values: []string{*runResult.Instances[0].InstanceId}, + }, }, }, - }) + 42*time.Minute, + ) require.NoError(t, err) // Add test tag to the created instance - _, err = ec2Client.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{runResult.Instances[0].InstanceId}, - Tags: []*ec2.Tag{ + _, err = ec2Client.CreateTags(context.Background(), &ec2.CreateTagsInput{ + Resources: []string{*runResult.Instances[0].InstanceId}, + Tags: []types.Tag{ { Key: aws.String("Name"), Value: aws.String(name), @@ -106,17 +127,18 @@ func createTestEC2Instance(t *testing.T, region string, name string) ec2.Instanc require.NoError(t, err) // EC2 Instance must be in a running before this function returns - err = ec2Client.WaitUntilInstanceRunning(&ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ + runningWaiter := ec2.NewInstanceRunningWaiter(ec2Client) + err = runningWaiter.Wait(context.Background(), &ec2.DescribeInstancesInput{ + Filters: []types.Filter{ + { Name: aws.String("instance-id"), - Values: []*string{runResult.Instances[0].InstanceId}, + Values: []string{*runResult.Instances[0].InstanceId}, }, }, - }) + }, 42*time.Minute) require.NoError(t, err) - return *runResult.Instances[0] + return runResult.Instances[0] } func terminateEc2InstancesByName(t *testing.T, region string, names []string) { @@ -134,10 +156,18 @@ func deleteAutoScalingGroup(t *testing.T, name string, region string) { asgClient := NewAsgClient(t, region) input := &autoscaling.DeleteAutoScalingGroupInput{AutoScalingGroupName: aws.String(name)} - _, err := asgClient.DeleteAutoScalingGroup(input) + _, err := asgClient.DeleteAutoScalingGroup(context.Background(), input) require.NoError(t, err) - err = asgClient.WaitUntilGroupNotExists(&autoscaling.DescribeAutoScalingGroupsInput{ - AutoScalingGroupNames: []*string{aws.String(name)}, + + waiter := autoscaling.NewGroupNotExistsWaiter(asgClient) + err = waiter.Wait(context.Background(), &autoscaling.DescribeAutoScalingGroupsInput{ + AutoScalingGroupNames: []string{name}, + }, 40*time.Minute) + require.NoError(t, err) + + ec2Client := NewEc2Client(t, region) + _, err = ec2Client.DeleteLaunchTemplate(context.Background(), &ec2.DeleteLaunchTemplateInput{ + LaunchTemplateName: aws.String(name), }) require.NoError(t, err) } @@ -146,15 +176,15 @@ func scaleAsgToZero(t *testing.T, name string, region string) { asgClient := NewAsgClient(t, region) input := &autoscaling.UpdateAutoScalingGroupInput{ AutoScalingGroupName: aws.String(name), - DesiredCapacity: aws.Int64(0), - MinSize: aws.Int64(0), - MaxSize: aws.Int64(0), + DesiredCapacity: aws.Int32(0), + MinSize: aws.Int32(0), + MaxSize: aws.Int32(0), } - _, err := asgClient.UpdateAutoScalingGroup(input) + _, err := asgClient.UpdateAutoScalingGroup(context.Background(), input) require.NoError(t, err) WaitForCapacity(t, name, region, 40, 15*time.Second) // There is an eventual consistency bug where even though the ASG is scaled down, AWS sometimes still views a - // scaling activity so we add a 5 second pause here to work around it. + // scaling activity so we add a 5-second pause here to work around it. time.Sleep(5 * time.Second) } diff --git a/modules/aws/auth.go b/modules/aws/auth.go index f2aa6f78c..fa8ab2683 100644 --- a/modules/aws/auth.go +++ b/modules/aws/auth.go @@ -1,16 +1,18 @@ package aws import ( + "context" "fmt" "os" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/pquerna/otp/totp" ) @@ -18,9 +20,9 @@ const ( AuthAssumeRoleEnvVar = "TERRATEST_IAM_ROLE" // OS environment variable name through which Assume Role ARN may be passed for authentication ) -// NewAuthenticatedSession creates an AWS session following to standard AWS authentication workflow. +// NewAuthenticatedSession creates an AWS Config following to standard AWS authentication workflow. // If AuthAssumeIamRoleEnvVar environment variable is set, assumes IAM role specified in it. -func NewAuthenticatedSession(region string) (*session.Session, error) { +func NewAuthenticatedSession(region string) (*aws.Config, error) { if assumeRoleArn, ok := os.LookupEnv(AuthAssumeRoleEnvVar); ok { return NewAuthenticatedSessionFromRole(region, assumeRoleArn) } else { @@ -28,76 +30,58 @@ func NewAuthenticatedSession(region string) (*session.Session, error) { } } -// NewAuthenticatedSessionFromDefaultCredentials gets an AWS Session, checking that the user has credentials properly configured in their environment. -func NewAuthenticatedSessionFromDefaultCredentials(region string) (*session.Session, error) { - awsConfig := aws.NewConfig().WithRegion(region) - - sessionOptions := session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - } - - sess, err := session.NewSessionWithOptions(sessionOptions) +// NewAuthenticatedSessionFromDefaultCredentials gets an AWS Config, checking that the user has credentials properly configured in their environment. +func NewAuthenticatedSessionFromDefaultCredentials(region string) (*aws.Config, error) { + cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region)) if err != nil { - return nil, err - } - - if _, err = sess.Config.Credentials.Get(); err != nil { return nil, CredentialsError{UnderlyingErr: err} } - return sess, nil + return &cfg, nil } -// NewAuthenticatedSessionFromRole returns a new AWS Session after assuming the +// NewAuthenticatedSessionFromRole returns a new AWS Config after assuming the // role whose ARN is provided in roleARN. If the credentials are not properly // configured in the underlying environment, an error is returned. -func NewAuthenticatedSessionFromRole(region string, roleARN string) (*session.Session, error) { - sess, err := CreateAwsSessionFromRole(region, roleARN) +func NewAuthenticatedSessionFromRole(region string, roleARN string) (*aws.Config, error) { + cfg, err := NewAuthenticatedSessionFromDefaultCredentials(region) if err != nil { return nil, err } - if _, err = sess.Config.Credentials.Get(); err != nil { - return nil, CredentialsError{UnderlyingErr: err} - } - - return sess, nil -} + client := sts.NewFromConfig(*cfg) -// CreateAwsSessionFromRole returns a new AWS session after assuming the role -// whose ARN is provided in roleARN. -func CreateAwsSessionFromRole(region string, roleARN string) (*session.Session, error) { - sess, err := session.NewSession(aws.NewConfig().WithRegion(region)) + roleProvider := stscreds.NewAssumeRoleProvider(client, roleARN) + retrieve, err := roleProvider.Retrieve(context.Background()) if err != nil { - return nil, err + return nil, CredentialsError{UnderlyingErr: err} } - sess = AssumeRole(sess, roleARN) - return sess, err -} -// AssumeRole mutates the provided session by obtaining new credentials by -// assuming the role provided in roleARN. -func AssumeRole(sess *session.Session, roleARN string) *session.Session { - sess.Config.Credentials = stscreds.NewCredentials(sess, roleARN) - return sess + return &aws.Config{ + Region: region, + Credentials: aws.NewCredentialsCache(credentials.StaticCredentialsProvider{ + Value: retrieve, + }), + }, nil } -// CreateAwsSessionWithCreds creates a new AWS session using explicit credentials. This is useful if you want to create an IAM User dynamically and -// create an AWS session authenticated as the new IAM User. -func CreateAwsSessionWithCreds(region string, accessKeyID string, secretAccessKey string) (*session.Session, error) { - creds := CreateAwsCredentials(accessKeyID, secretAccessKey) - return session.NewSession(aws.NewConfig().WithRegion(region).WithCredentials(creds)) +// CreateAwsSessionWithCreds creates a new AWS Config using explicit credentials. This is useful if you want to create an IAM User dynamically and +// create an AWS Config authenticated as the new IAM User. +func CreateAwsSessionWithCreds(region string, accessKeyID string, secretAccessKey string) (*aws.Config, error) { + return &aws.Config{ + Region: region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")), + }, nil } -// CreateAwsSessionWithMfa creates a new AWS session authenticated using an MFA token retrieved using the given STS client and MFA Device. -func CreateAwsSessionWithMfa(region string, stsClient *sts.STS, mfaDevice *iam.VirtualMFADevice) (*session.Session, error) { +// CreateAwsSessionWithMfa creates a new AWS Config authenticated using an MFA token retrieved using the given STS client and MFA Device. +func CreateAwsSessionWithMfa(region string, stsClient *sts.Client, mfaDevice *types.VirtualMFADevice) (*aws.Config, error) { tokenCode, err := GetTimeBasedOneTimePassword(mfaDevice) if err != nil { return nil, err } - output, err := stsClient.GetSessionToken(&sts.GetSessionTokenInput{ + output, err := stsClient.GetSessionToken(context.Background(), &sts.GetSessionTokenInput{ SerialNumber: mfaDevice.SerialNumber, TokenCode: aws.String(tokenCode), }) @@ -109,29 +93,14 @@ func CreateAwsSessionWithMfa(region string, stsClient *sts.STS, mfaDevice *iam.V secretAccessKey := *output.Credentials.SecretAccessKey sessionToken := *output.Credentials.SessionToken - creds := CreateAwsCredentialsWithSessionToken(accessKeyID, secretAccessKey, sessionToken) - return session.NewSession(aws.NewConfig().WithRegion(region).WithCredentials(creds)) -} - -// CreateAwsCredentials creates an AWS Credentials configuration with specific AWS credentials. -func CreateAwsCredentials(accessKeyID string, secretAccessKey string) *credentials.Credentials { - creds := credentials.Value{AccessKeyID: accessKeyID, SecretAccessKey: secretAccessKey} - return credentials.NewStaticCredentialsFromCreds(creds) -} - -// CreateAwsCredentialsWithSessionToken creates an AWS Credentials configuration with temporary AWS credentials by including a session token (used for -// authenticating with MFA). -func CreateAwsCredentialsWithSessionToken(accessKeyID, secretAccessKey, sessionToken string) *credentials.Credentials { - creds := credentials.Value{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SessionToken: sessionToken, - } - return credentials.NewStaticCredentialsFromCreds(creds) + return &aws.Config{ + Region: region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, sessionToken)), + }, nil } // GetTimeBasedOneTimePassword gets a One-Time Password from the given mfaDevice. Per the RFC 6238 standard, this value will be different every 30 seconds. -func GetTimeBasedOneTimePassword(mfaDevice *iam.VirtualMFADevice) (string, error) { +func GetTimeBasedOneTimePassword(mfaDevice *types.VirtualMFADevice) (string, error) { base32StringSeed := string(mfaDevice.Base32StringSeed) otp, err := totp.GenerateCode(base32StringSeed, time.Now()) @@ -143,8 +112,8 @@ func GetTimeBasedOneTimePassword(mfaDevice *iam.VirtualMFADevice) (string, error } // ReadPasswordPolicyMinPasswordLength returns the minimal password length. -func ReadPasswordPolicyMinPasswordLength(iamClient *iam.IAM) (int, error) { - output, err := iamClient.GetAccountPasswordPolicy(&iam.GetAccountPasswordPolicyInput{}) +func ReadPasswordPolicyMinPasswordLength(iamClient *iam.Client) (int, error) { + output, err := iamClient.GetAccountPasswordPolicy(context.Background(), &iam.GetAccountPasswordPolicyInput{}) if err != nil { return -1, err } diff --git a/modules/aws/cloudwatch.go b/modules/aws/cloudwatch.go index d5af76e28..f24783e11 100644 --- a/modules/aws/cloudwatch.go +++ b/modules/aws/cloudwatch.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudwatchlogs" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -22,7 +24,7 @@ func GetCloudWatchLogEntriesE(t testing.TestingT, awsRegion string, logStreamNam return nil, err } - output, err := client.GetLogEvents(&cloudwatchlogs.GetLogEventsInput{ + output, err := client.GetLogEvents(context.Background(), &cloudwatchlogs.GetLogEventsInput{ LogGroupName: aws.String(logGroupName), LogStreamName: aws.String(logStreamName), }) @@ -31,7 +33,7 @@ func GetCloudWatchLogEntriesE(t testing.TestingT, awsRegion string, logStreamNam return nil, err } - entries := []string{} + var entries []string for _, event := range output.Events { entries = append(entries, *event.Message) } @@ -40,7 +42,7 @@ func GetCloudWatchLogEntriesE(t testing.TestingT, awsRegion string, logStreamNam } // NewCloudWatchLogsClient creates a new CloudWatch Logs client. -func NewCloudWatchLogsClient(t testing.TestingT, region string) *cloudwatchlogs.CloudWatchLogs { +func NewCloudWatchLogsClient(t testing.TestingT, region string) *cloudwatchlogs.Client { client, err := NewCloudWatchLogsClientE(t, region) if err != nil { t.Fatal(err) @@ -49,10 +51,10 @@ func NewCloudWatchLogsClient(t testing.TestingT, region string) *cloudwatchlogs. } // NewCloudWatchLogsClientE creates a new CloudWatch Logs client. -func NewCloudWatchLogsClientE(t testing.TestingT, region string) (*cloudwatchlogs.CloudWatchLogs, error) { +func NewCloudWatchLogsClientE(t testing.TestingT, region string) (*cloudwatchlogs.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return cloudwatchlogs.New(sess), nil + return cloudwatchlogs.NewFromConfig(*sess), nil } diff --git a/modules/aws/dynamodb.go b/modules/aws/dynamodb.go index 447b17ece..cbd44e4b8 100644 --- a/modules/aws/dynamodb.go +++ b/modules/aws/dynamodb.go @@ -1,23 +1,26 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" ) // GetDynamoDbTableTags fetches resource tags of a specified dynamoDB table. This will fail the test if there are any errors -func GetDynamoDbTableTags(t testing.TestingT, region string, tableName string) []*dynamodb.Tag { +func GetDynamoDbTableTags(t testing.TestingT, region string, tableName string) []types.Tag { tags, err := GetDynamoDbTableTagsE(t, region, tableName) require.NoError(t, err) return tags } // GetDynamoDbTableTagsE fetches resource tags of a specified dynamoDB table. -func GetDynamoDbTableTagsE(t testing.TestingT, region string, tableName string) ([]*dynamodb.Tag, error) { +func GetDynamoDbTableTagsE(t testing.TestingT, region string, tableName string) ([]types.Tag, error) { table := GetDynamoDBTable(t, region, tableName) - out, err := NewDynamoDBClient(t, region).ListTagsOfResource(&dynamodb.ListTagsOfResourceInput{ + out, err := NewDynamoDBClient(t, region).ListTagsOfResource(context.Background(), &dynamodb.ListTagsOfResourceInput{ ResourceArn: table.TableArn, }) if err != nil { @@ -27,15 +30,15 @@ func GetDynamoDbTableTagsE(t testing.TestingT, region string, tableName string) } // GetDynamoDBTableTimeToLive fetches information about the TTL configuration of a specified dynamoDB table. This will fail the test if there are any errors. -func GetDynamoDBTableTimeToLive(t testing.TestingT, region string, tableName string) *dynamodb.TimeToLiveDescription { +func GetDynamoDBTableTimeToLive(t testing.TestingT, region string, tableName string) *types.TimeToLiveDescription { ttl, err := GetDynamoDBTableTimeToLiveE(t, region, tableName) require.NoError(t, err) return ttl } // GetDynamoDBTableTimeToLiveE fetches information about the TTL configuration of a specified dynamoDB table. -func GetDynamoDBTableTimeToLiveE(t testing.TestingT, region string, tableName string) (*dynamodb.TimeToLiveDescription, error) { - out, err := NewDynamoDBClient(t, region).DescribeTimeToLive(&dynamodb.DescribeTimeToLiveInput{ +func GetDynamoDBTableTimeToLiveE(t testing.TestingT, region string, tableName string) (*types.TimeToLiveDescription, error) { + out, err := NewDynamoDBClient(t, region).DescribeTimeToLive(context.Background(), &dynamodb.DescribeTimeToLiveInput{ TableName: aws.String(tableName), }) if err != nil { @@ -45,15 +48,15 @@ func GetDynamoDBTableTimeToLiveE(t testing.TestingT, region string, tableName st } // GetDynamoDBTable fetches information about the specified dynamoDB table. This will fail the test if there are any errors. -func GetDynamoDBTable(t testing.TestingT, region string, tableName string) *dynamodb.TableDescription { +func GetDynamoDBTable(t testing.TestingT, region string, tableName string) *types.TableDescription { table, err := GetDynamoDBTableE(t, region, tableName) require.NoError(t, err) return table } // GetDynamoDBTableE fetches information about the specified dynamoDB table. -func GetDynamoDBTableE(t testing.TestingT, region string, tableName string) (*dynamodb.TableDescription, error) { - out, err := NewDynamoDBClient(t, region).DescribeTable(&dynamodb.DescribeTableInput{ +func GetDynamoDBTableE(t testing.TestingT, region string, tableName string) (*types.TableDescription, error) { + out, err := NewDynamoDBClient(t, region).DescribeTable(context.Background(), &dynamodb.DescribeTableInput{ TableName: aws.String(tableName), }) if err != nil { @@ -63,17 +66,17 @@ func GetDynamoDBTableE(t testing.TestingT, region string, tableName string) (*dy } // NewDynamoDBClient creates a DynamoDB client. -func NewDynamoDBClient(t testing.TestingT, region string) *dynamodb.DynamoDB { +func NewDynamoDBClient(t testing.TestingT, region string) *dynamodb.Client { client, err := NewDynamoDBClientE(t, region) require.NoError(t, err) return client } // NewDynamoDBClientE creates a DynamoDB client. -func NewDynamoDBClientE(t testing.TestingT, region string) (*dynamodb.DynamoDB, error) { +func NewDynamoDBClientE(t testing.TestingT, region string) (*dynamodb.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return dynamodb.New(sess), nil + return dynamodb.NewFromConfig(*sess), nil } diff --git a/modules/aws/ebs.go b/modules/aws/ebs.go index 5f8e2e1fd..5bcd2dee8 100644 --- a/modules/aws/ebs.go +++ b/modules/aws/ebs.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -15,7 +17,7 @@ func DeleteEbsSnapshot(t testing.TestingT, region string, snapshot string) { } } -// DeleteEbsSnapshot deletes the given EBS snapshot +// DeleteEbsSnapshotE deletes the given EBS snapshot func DeleteEbsSnapshotE(t testing.TestingT, region string, snapshot string) error { logger.Default.Logf(t, "Deleting EBS snapshot %s", snapshot) ec2Client, err := NewEc2ClientE(t, region) @@ -23,7 +25,7 @@ func DeleteEbsSnapshotE(t testing.TestingT, region string, snapshot string) erro return err } - _, err = ec2Client.DeleteSnapshot(&ec2.DeleteSnapshotInput{ + _, err = ec2Client.DeleteSnapshot(context.Background(), &ec2.DeleteSnapshotInput{ SnapshotId: aws.String(snapshot), }) return err diff --git a/modules/aws/ec2-syslog.go b/modules/aws/ec2-syslog.go index 43acbd47a..f622a465c 100644 --- a/modules/aws/ec2-syslog.go +++ b/modules/aws/ec2-syslog.go @@ -1,18 +1,12 @@ package aws import ( - "encoding/base64" "fmt" - "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/gruntwork-io/terratest/modules/logger" - "github.com/gruntwork-io/terratest/modules/retry" "github.com/gruntwork-io/terratest/modules/testing" ) -// (Deprecated) See the FetchContentsOfFileFromInstance method for a more powerful solution. +// GetSyslogForInstance (Deprecated) See the FetchContentsOfFileFromInstance method for a more powerful solution. // // GetSyslogForInstance gets the syslog for the Instance with the given ID in the given region. This should be available ~1 minute after an // Instance boots and is very useful for debugging boot-time issues, such as an error in User Data. @@ -24,57 +18,19 @@ func GetSyslogForInstance(t testing.TestingT, instanceID string, awsRegion strin return out } -// (Deprecated) See the FetchContentsOfFileFromInstanceE method for a more powerful solution. +// GetSyslogForInstanceE (Deprecated) See the FetchContentsOfFileFromInstanceE method for a more powerful solution. // // GetSyslogForInstanceE gets the syslog for the Instance with the given ID in the given region. This should be available ~1 minute after an // Instance boots and is very useful for debugging boot-time issues, such as an error in User Data. func GetSyslogForInstanceE(t testing.TestingT, instanceID string, region string) (string, error) { - description := fmt.Sprintf("Fetching syslog for Instance %s in %s", instanceID, region) - maxRetries := 120 - timeBetweenRetries := 5 * time.Second - - logger.Default.Logf(t, "%s", description) - - client, err := NewEc2ClientE(t, region) - if err != nil { - return "", err - } - - input := ec2.GetConsoleOutputInput{ - InstanceId: aws.String(instanceID), - } - - syslogB64, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) { - out, err := client.GetConsoleOutput(&input) - if err != nil { - return "", err - } - - syslog := aws.StringValue(out.Output) - if syslog == "" { - return "", fmt.Errorf("Syslog is not yet available for instance %s in %s", instanceID, region) - } - - return syslog, nil - }) - - if err != nil { - return "", err - } - - syslogBytes, err := base64.StdEncoding.DecodeString(syslogB64) - if err != nil { - return "", err - } - - return string(syslogBytes), nil + return "", fmt.Errorf("(Deprecated) use FetchContentsOfFileFromInstanceE method instead") } -// (Deprecated) See the FetchContentsOfFilesFromAsg method for a more powerful solution. +// GetSyslogForInstancesInAsg (Deprecated) See the FetchContentsOfFilesFromAsg method for a more powerful solution. // // GetSyslogForInstancesInAsg gets the syslog for each of the Instances in the given ASG in the given region. These logs should be available ~1 // minute after the Instance boots and are very useful for debugging boot-time issues, such as an error in User Data. -// Returns a map of Instance Id -> Syslog for that Instance. +// Returns a map of Instance ID -> Syslog for that Instance. func GetSyslogForInstancesInAsg(t testing.TestingT, asgName string, awsRegion string) map[string]string { out, err := GetSyslogForInstancesInAsgE(t, asgName, awsRegion) if err != nil { @@ -83,27 +39,11 @@ func GetSyslogForInstancesInAsg(t testing.TestingT, asgName string, awsRegion st return out } -// (Deprecated) See the FetchContentsOfFilesFromAsgE method for a more powerful solution. +// GetSyslogForInstancesInAsgE (Deprecated) See the FetchContentsOfFilesFromAsgE method for a more powerful solution. // // GetSyslogForInstancesInAsgE gets the syslog for each of the Instances in the given ASG in the given region. These logs should be available ~1 // minute after the Instance boots and are very useful for debugging boot-time issues, such as an error in User Data. -// Returns a map of Instance Id -> Syslog for that Instance. +// Returns a map of Instance ID -> Syslog for that Instance. func GetSyslogForInstancesInAsgE(t testing.TestingT, asgName string, awsRegion string) (map[string]string, error) { - logger.Default.Logf(t, "Fetching syslog for each Instance in ASG %s in %s", asgName, awsRegion) - - instanceIDs, err := GetEc2InstanceIdsByTagE(t, awsRegion, "aws:autoscaling:groupName", asgName) - if err != nil { - return nil, err - } - - logs := map[string]string{} - for _, id := range instanceIDs { - syslog, err := GetSyslogForInstanceE(t, id, awsRegion) - if err != nil { - return nil, err - } - logs[id] = syslog - } - - return logs, nil + return nil, fmt.Errorf("(Deprecated) use FetchContentsOfFilesFromAsgE method instead") } diff --git a/modules/aws/ec2.go b/modules/aws/ec2.go index 62886a67d..99bc26a16 100644 --- a/modules/aws/ec2.go +++ b/modules/aws/ec2.go @@ -1,10 +1,12 @@ package aws import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -44,17 +46,17 @@ func GetPrivateIpsOfEc2Instances(t testing.TestingT, instanceIDs []string, awsRe func GetPrivateIpsOfEc2InstancesE(t testing.TestingT, instanceIDs []string, awsRegion string) (map[string]string, error) { ec2Client := NewEc2Client(t, awsRegion) // TODO: implement pagination for cases that extend beyond limit (1000 instances) - input := ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice(instanceIDs)} - output, err := ec2Client.DescribeInstances(&input) + input := ec2.DescribeInstancesInput{InstanceIds: instanceIDs} + output, err := ec2Client.DescribeInstances(context.Background(), &input) if err != nil { return nil, err } ips := map[string]string{} - for _, reserveration := range output.Reservations { - for _, instance := range reserveration.Instances { - ips[aws.StringValue(instance.InstanceId)] = aws.StringValue(instance.PrivateIpAddress) + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + ips[aws.ToString(instance.InstanceId)] = aws.ToString(instance.PrivateIpAddress) } } @@ -98,17 +100,17 @@ func GetPrivateHostnamesOfEc2InstancesE(t testing.TestingT, instanceIDs []string return nil, err } // TODO: implement pagination for cases that extend beyond limit (1000 instances) - input := ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice(instanceIDs)} - output, err := ec2Client.DescribeInstances(&input) + input := ec2.DescribeInstancesInput{InstanceIds: instanceIDs} + output, err := ec2Client.DescribeInstances(context.Background(), &input) if err != nil { return nil, err } hostnames := map[string]string{} - for _, reserveration := range output.Reservations { - for _, instance := range reserveration.Instances { - hostnames[aws.StringValue(instance.InstanceId)] = aws.StringValue(instance.PrivateDnsName) + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + hostnames[aws.ToString(instance.InstanceId)] = aws.ToString(instance.PrivateDnsName) } } @@ -149,17 +151,17 @@ func GetPublicIpsOfEc2Instances(t testing.TestingT, instanceIDs []string, awsReg func GetPublicIpsOfEc2InstancesE(t testing.TestingT, instanceIDs []string, awsRegion string) (map[string]string, error) { ec2Client := NewEc2Client(t, awsRegion) // TODO: implement pagination for cases that extend beyond limit (1000 instances) - input := ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice(instanceIDs)} - output, err := ec2Client.DescribeInstances(&input) + input := ec2.DescribeInstancesInput{InstanceIds: instanceIDs} + output, err := ec2Client.DescribeInstances(context.Background(), &input) if err != nil { return nil, err } ips := map[string]string{} - for _, reserveration := range output.Reservations { - for _, instance := range reserveration.Instances { - ips[aws.StringValue(instance.InstanceId)] = aws.StringValue(instance.PublicIpAddress) + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + ips[aws.ToString(instance.InstanceId)] = aws.ToString(instance.PublicIpAddress) } } @@ -189,7 +191,7 @@ func GetEc2InstanceIdsByFilters(t testing.TestingT, region string, ec2Filters ma return out } -// GetEc2InstanceIdsByFilters returns all the IDs of EC2 instances in the given region which match to EC2 filter list +// GetEc2InstanceIdsByFiltersE returns all the IDs of EC2 instances in the given region which match to EC2 filter list // as per https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/#DescribeInstancesInput. func GetEc2InstanceIdsByFiltersE(t testing.TestingT, region string, ec2Filters map[string][]string) ([]string, error) { client, err := NewEc2ClientE(t, region) @@ -197,19 +199,19 @@ func GetEc2InstanceIdsByFiltersE(t testing.TestingT, region string, ec2Filters m return nil, err } - ec2FilterList := []*ec2.Filter{} + var ec2FilterList []types.Filter for name, values := range ec2Filters { - ec2FilterList = append(ec2FilterList, &ec2.Filter{Name: aws.String(name), Values: aws.StringSlice(values)}) + ec2FilterList = append(ec2FilterList, types.Filter{Name: aws.String(name), Values: values}) } // TODO: implement pagination for cases that extend beyond limit (1000 instances) - output, err := client.DescribeInstances(&ec2.DescribeInstancesInput{Filters: ec2FilterList}) + output, err := client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{Filters: ec2FilterList}) if err != nil { return nil, err } - instanceIDs := []string{} + var instanceIDs []string for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { @@ -235,19 +237,19 @@ func GetTagsForEc2InstanceE(t testing.TestingT, region string, instanceID string } input := ec2.DescribeTagsInput{ - Filters: []*ec2.Filter{ + Filters: []types.Filter{ { Name: aws.String("resource-type"), - Values: aws.StringSlice([]string{"instance"}), + Values: []string{"instance"}, }, { Name: aws.String("resource-id"), - Values: aws.StringSlice([]string{instanceID}), + Values: []string{instanceID}, }, }, } - out, err := client.DescribeTags(&input) + out, err := client.DescribeTags(context.Background(), &input) if err != nil { return nil, err } @@ -255,7 +257,7 @@ func GetTagsForEc2InstanceE(t testing.TestingT, region string, instanceID string tags := map[string]string{} for _, tag := range out.Tags { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -275,7 +277,7 @@ func DeleteAmiE(t testing.TestingT, region string, imageID string) error { return err } - _, err = client.DeregisterImage(&ec2.DeregisterImageInput{ImageId: aws.String(imageID)}) + _, err = client.DeregisterImage(context.Background(), &ec2.DeregisterImageInput{ImageId: aws.String(imageID)}) return err } @@ -291,16 +293,16 @@ func AddTagsToResourceE(t testing.TestingT, region string, resource string, tags return err } - var awsTags []*ec2.Tag + var awsTags []types.Tag for key, value := range tags { - awsTags = append(awsTags, &ec2.Tag{ + awsTags = append(awsTags, types.Tag{ Key: aws.String(key), Value: aws.String(value), }) } - _, err = client.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{aws.String(resource)}, + _, err = client.CreateTags(context.Background(), &ec2.CreateTagsInput{ + Resources: []string{resource}, Tags: awsTags, }) @@ -321,9 +323,9 @@ func TerminateInstanceE(t testing.TestingT, region string, instanceID string) er return err } - _, err = client.TerminateInstances(&ec2.TerminateInstancesInput{ - InstanceIds: []*string{ - aws.String(instanceID), + _, err = client.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{ + InstanceIds: []string{ + instanceID, }, }) @@ -344,7 +346,7 @@ func GetAmiPubliclyAccessibleE(t testing.TestingT, awsRegion string, amiID strin return false, err } for _, launchPermission := range launchPermissions { - if aws.StringValue(launchPermission.Group) == "all" { + if string(launchPermission.Group) == "all" { return true, nil } } @@ -360,30 +362,30 @@ func GetAccountsWithLaunchPermissionsForAmi(t testing.TestingT, awsRegion string // GetAccountsWithLaunchPermissionsForAmiE returns list of accounts that the AMI is shared with func GetAccountsWithLaunchPermissionsForAmiE(t testing.TestingT, awsRegion string, amiID string) ([]string, error) { - accountIDs := []string{} + var accountIDs []string launchPermissions, err := GetLaunchPermissionsForAmiE(t, awsRegion, amiID) if err != nil { return accountIDs, err } for _, launchPermission := range launchPermissions { - if aws.StringValue(launchPermission.UserId) != "" { - accountIDs = append(accountIDs, aws.StringValue(launchPermission.UserId)) + if aws.ToString(launchPermission.UserId) != "" { + accountIDs = append(accountIDs, aws.ToString(launchPermission.UserId)) } } return accountIDs, nil } // GetLaunchPermissionsForAmiE returns launchPermissions as configured in AWS -func GetLaunchPermissionsForAmiE(t testing.TestingT, awsRegion string, amiID string) ([]*ec2.LaunchPermission, error) { +func GetLaunchPermissionsForAmiE(t testing.TestingT, awsRegion string, amiID string) ([]types.LaunchPermission, error) { client := NewEc2Client(t, awsRegion) input := &ec2.DescribeImageAttributeInput{ - Attribute: aws.String("launchPermission"), + Attribute: types.ImageAttributeNameLaunchPermission, ImageId: aws.String(amiID), } - output, err := client.DescribeImageAttribute(input) + output, err := client.DescribeImageAttribute(context.Background(), input) if err != nil { - return []*ec2.LaunchPermission{}, err + return []types.LaunchPermission{}, err } return output.LaunchPermissions, nil } @@ -422,7 +424,7 @@ func GetRecommendedInstanceTypeE(t testing.TestingT, region string, instanceType // AZs. If you have code that needs to run on a "small" instance across all AZs in many different regions, you can // use this function to automatically figure out which instance type you should use. // This function expects an authenticated EC2 client from the AWS SDK Go library. -func GetRecommendedInstanceTypeWithClientE(t testing.TestingT, ec2Client *ec2.EC2, instanceTypeOptions []string) (string, error) { +func GetRecommendedInstanceTypeWithClientE(t testing.TestingT, ec2Client *ec2.Client, instanceTypeOptions []string) (string, error) { availabilityZones, err := getAllAvailabilityZonesE(ec2Client) if err != nil { return "", err @@ -439,7 +441,7 @@ func GetRecommendedInstanceTypeWithClientE(t testing.TestingT, ec2Client *ec2.EC // pickRecommendedInstanceTypeE returns the first instance type from instanceTypeOptions that is available in all the // AZs in availabilityZones based on the availability data in instanceTypeOfferings. If none of the instance types are // available in all AZs, this function returns an error. -func pickRecommendedInstanceTypeE(availabilityZones []string, instanceTypeOfferings []*ec2.InstanceTypeOffering, instanceTypeOptions []string) (string, error) { +func pickRecommendedInstanceTypeE(availabilityZones []string, instanceTypeOfferings []types.InstanceTypeOffering, instanceTypeOptions []string) (string, error) { // O(n^3) for the win! for _, instanceType := range instanceTypeOptions { if instanceTypeExistsInAllAzs(instanceType, availabilityZones, instanceTypeOfferings) { @@ -450,9 +452,9 @@ func pickRecommendedInstanceTypeE(availabilityZones []string, instanceTypeOfferi return "", NoInstanceTypeError{InstanceTypeOptions: instanceTypeOptions, Azs: availabilityZones} } -// instanceTypeExistsInAllAzs returns true if the given inistance type exists in all the given availabilityZones based +// instanceTypeExistsInAllAzs returns true if the given instance type exists in all the given availabilityZones based // on the availability data in instanceTypeOfferings -func instanceTypeExistsInAllAzs(instanceType string, availabilityZones []string, instanceTypeOfferings []*ec2.InstanceTypeOffering) bool { +func instanceTypeExistsInAllAzs(instanceType string, availabilityZones []string, instanceTypeOfferings []types.InstanceTypeOffering) bool { if len(availabilityZones) == 0 || len(instanceTypeOfferings) == 0 { return false } @@ -468,9 +470,9 @@ func instanceTypeExistsInAllAzs(instanceType string, availabilityZones []string, // hasOffering returns true if the given availability zone and instance type are one of the offerings in // instanceTypeOfferings -func hasOffering(instanceTypeOfferings []*ec2.InstanceTypeOffering, availabilityZone string, instanceType string) bool { +func hasOffering(instanceTypeOfferings []types.InstanceTypeOffering, availabilityZone string, instanceType string) bool { for _, offering := range instanceTypeOfferings { - if aws.StringValue(offering.InstanceType) == instanceType && aws.StringValue(offering.Location) == availabilityZone { + if string(offering.InstanceType) == instanceType && aws.ToString(offering.Location) == availabilityZone { return true } } @@ -480,18 +482,18 @@ func hasOffering(instanceTypeOfferings []*ec2.InstanceTypeOffering, availability // getInstanceTypeOfferingsE returns the instance types from the given list that are available in the region configured // in the given EC2 client -func getInstanceTypeOfferingsE(client *ec2.EC2, instanceTypeOptions []string) ([]*ec2.InstanceTypeOffering, error) { +func getInstanceTypeOfferingsE(client *ec2.Client, instanceTypeOptions []string) ([]types.InstanceTypeOffering, error) { input := ec2.DescribeInstanceTypeOfferingsInput{ - LocationType: aws.String(ec2.LocationTypeAvailabilityZone), - Filters: []*ec2.Filter{ + LocationType: types.LocationTypeAvailabilityZone, + Filters: []types.Filter{ { Name: aws.String("instance-type"), - Values: aws.StringSlice(instanceTypeOptions), + Values: instanceTypeOptions, }, }, } - out, err := client.DescribeInstanceTypeOfferings(&input) + out, err := client.DescribeInstanceTypeOfferings(context.Background(), &input) if err != nil { return nil, err } @@ -500,17 +502,17 @@ func getInstanceTypeOfferingsE(client *ec2.EC2, instanceTypeOptions []string) ([ } // getAllAvailabilityZonesE returns all the available AZs in the region configured in the given EC2 client -func getAllAvailabilityZonesE(client *ec2.EC2) ([]string, error) { +func getAllAvailabilityZonesE(client *ec2.Client) ([]string, error) { input := ec2.DescribeAvailabilityZonesInput{ - Filters: []*ec2.Filter{ + Filters: []types.Filter{ { Name: aws.String("state"), - Values: aws.StringSlice([]string{"available"}), + Values: []string{"available"}, }, }, } - out, err := client.DescribeAvailabilityZones(&input) + out, err := client.DescribeAvailabilityZones(context.Background(), &input) if err != nil { return nil, err } @@ -518,25 +520,25 @@ func getAllAvailabilityZonesE(client *ec2.EC2) ([]string, error) { var azs []string for _, az := range out.AvailabilityZones { - azs = append(azs, aws.StringValue(az.ZoneName)) + azs = append(azs, aws.ToString(az.ZoneName)) } return azs, nil } // NewEc2Client creates an EC2 client. -func NewEc2Client(t testing.TestingT, region string) *ec2.EC2 { +func NewEc2Client(t testing.TestingT, region string) *ec2.Client { client, err := NewEc2ClientE(t, region) require.NoError(t, err) return client } // NewEc2ClientE creates an EC2 client. -func NewEc2ClientE(t testing.TestingT, region string) (*ec2.EC2, error) { +func NewEc2ClientE(t testing.TestingT, region string) (*ec2.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ec2.New(sess), nil + return ec2.NewFromConfig(*sess), nil } diff --git a/modules/aws/ec2_test.go b/modules/aws/ec2_test.go index 236e55bbb..2d5c9e729 100644 --- a/modules/aws/ec2_test.go +++ b/modules/aws/ec2_test.go @@ -5,9 +5,8 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -56,7 +55,7 @@ func TestGetRecommendedInstanceType(t *testing.T) { t.Run(fmt.Sprintf("%s-%s", testCase.region, strings.Join(testCase.instanceTypeOptions, "-")), func(t *testing.T) { t.Parallel() instanceType := GetRecommendedInstanceType(t, testCase.region, testCase.instanceTypeOptions) - // We could hard-code the expected result (e.g., as of July, 2020, we expect eu-west-1 to return t2.micro + // We could hard-code the expected result (e.g., as of July 2020, we expect eu-west-1 to return t2.micro // and ap-northeast-2 to return t3.micro), but the result will likely change over time, so to avoid a // brittle test, we simply check that we get _one_ result. Combined with the unit test below, this hopefully // is enough to be confident this function works correctly. @@ -69,7 +68,7 @@ func TestPickRecommendedInstanceTypeHappyPath(t *testing.T) { testCases := []struct { name string availabilityZones []string - instanceTypeOfferings []*ec2.InstanceTypeOffering + instanceTypeOfferings []types.InstanceTypeOffering instanceTypeOptions []string expected string }{ @@ -136,7 +135,7 @@ func TestPickRecommendedInstanceTypeErrors(t *testing.T) { testCases := []struct { name string availabilityZones []string - instanceTypeOfferings []*ec2.InstanceTypeOffering + instanceTypeOfferings []types.InstanceTypeOffering instanceTypeOptions []string }{ { @@ -184,15 +183,15 @@ func TestPickRecommendedInstanceTypeErrors(t *testing.T) { } } -func offerings(offerings map[string][]string) []*ec2.InstanceTypeOffering { - var out []*ec2.InstanceTypeOffering +func offerings(offerings map[string][]string) []types.InstanceTypeOffering { + var out []types.InstanceTypeOffering for az, instanceTypes := range offerings { for _, instanceType := range instanceTypes { - offering := &ec2.InstanceTypeOffering{ - InstanceType: aws.String(instanceType), + offering := types.InstanceTypeOffering{ + InstanceType: types.InstanceType(instanceType), Location: aws.String(az), - LocationType: aws.String(ec2.LocationTypeAvailabilityZone), + LocationType: types.LocationTypeAvailabilityZone, } out = append(out, offering) } diff --git a/modules/aws/ecr.go b/modules/aws/ecr.go index 71f03a6a2..40d4da749 100644 --- a/modules/aws/ecr.go +++ b/modules/aws/ecr.go @@ -1,10 +1,12 @@ package aws import ( + "context" goerrors "errors" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecr" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecr" + "github.com/aws/aws-sdk-go-v2/service/ecr/types" "github.com/gruntwork-io/go-commons/errors" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" @@ -12,16 +14,16 @@ import ( ) // CreateECRRepo creates a new ECR Repository. This will fail the test and stop execution if there is an error. -func CreateECRRepo(t testing.TestingT, region string, name string) *ecr.Repository { +func CreateECRRepo(t testing.TestingT, region string, name string) *types.Repository { repo, err := CreateECRRepoE(t, region, name) require.NoError(t, err) return repo } // CreateECRRepoE creates a new ECR Repository. -func CreateECRRepoE(t testing.TestingT, region string, name string) (*ecr.Repository, error) { +func CreateECRRepoE(t testing.TestingT, region string, name string) (*types.Repository, error) { client := NewECRClient(t, region) - resp, err := client.CreateRepository(&ecr.CreateRepositoryInput{RepositoryName: aws.String(name)}) + resp, err := client.CreateRepository(context.Background(), &ecr.CreateRepositoryInput{RepositoryName: aws.String(name)}) if err != nil { return nil, err } @@ -30,7 +32,7 @@ func CreateECRRepoE(t testing.TestingT, region string, name string) (*ecr.Reposi // GetECRRepo gets an ECR repository by name. This will fail the test and stop execution if there is an error. // An error occurs if a repository with the given name does not exist in the given region. -func GetECRRepo(t testing.TestingT, region string, name string) *ecr.Repository { +func GetECRRepo(t testing.TestingT, region string, name string) *types.Repository { repo, err := GetECRRepoE(t, region, name) require.NoError(t, err) return repo @@ -38,35 +40,35 @@ func GetECRRepo(t testing.TestingT, region string, name string) *ecr.Repository // GetECRRepoE gets an ECR Repository by name. // An error occurs if a repository with the given name does not exist in the given region. -func GetECRRepoE(t testing.TestingT, region string, name string) (*ecr.Repository, error) { +func GetECRRepoE(t testing.TestingT, region string, name string) (*types.Repository, error) { client := NewECRClient(t, region) - repositoryNames := []*string{aws.String(name)} - resp, err := client.DescribeRepositories(&ecr.DescribeRepositoriesInput{RepositoryNames: repositoryNames}) + repositoryNames := []string{name} + resp, err := client.DescribeRepositories(context.Background(), &ecr.DescribeRepositoriesInput{RepositoryNames: repositoryNames}) if err != nil { return nil, err } if len(resp.Repositories) != 1 { - return nil, errors.WithStackTrace(goerrors.New(("An unexpected condition occurred. Please file an issue at github.com/gruntwork-io/terratest"))) + return nil, errors.WithStackTrace(goerrors.New("an unexpected condition occurred. Please file an issue at github.com/gruntwork-io/terratest")) } - return resp.Repositories[0], nil + return &resp.Repositories[0], nil } // DeleteECRRepo will force delete the ECR repo by deleting all images prior to deleting the ECR repository. // This will fail the test and stop execution if there is an error. -func DeleteECRRepo(t testing.TestingT, region string, repo *ecr.Repository) { +func DeleteECRRepo(t testing.TestingT, region string, repo *types.Repository) { err := DeleteECRRepoE(t, region, repo) require.NoError(t, err) } // DeleteECRRepoE will force delete the ECR repo by deleting all images prior to deleting the ECR repository. -func DeleteECRRepoE(t testing.TestingT, region string, repo *ecr.Repository) error { +func DeleteECRRepoE(t testing.TestingT, region string, repo *types.Repository) error { client := NewECRClient(t, region) - resp, err := client.ListImages(&ecr.ListImagesInput{RepositoryName: repo.RepositoryName}) + resp, err := client.ListImages(context.Background(), &ecr.ListImagesInput{RepositoryName: repo.RepositoryName}) if err != nil { return err } if len(resp.ImageIds) > 0 { - _, err = client.BatchDeleteImage(&ecr.BatchDeleteImageInput{ + _, err = client.BatchDeleteImage(context.Background(), &ecr.BatchDeleteImageInput{ RepositoryName: repo.RepositoryName, ImageIds: resp.ImageIds, }) @@ -75,7 +77,7 @@ func DeleteECRRepoE(t testing.TestingT, region string, repo *ecr.Repository) err } } - _, err = client.DeleteRepository(&ecr.DeleteRepositoryInput{RepositoryName: repo.RepositoryName}) + _, err = client.DeleteRepository(context.Background(), &ecr.DeleteRepositoryInput{RepositoryName: repo.RepositoryName}) if err != nil { return err } @@ -84,33 +86,33 @@ func DeleteECRRepoE(t testing.TestingT, region string, repo *ecr.Repository) err // NewECRClient returns a client for the Elastic Container Registry. This will fail the test and // stop execution if there is an error. -func NewECRClient(t testing.TestingT, region string) *ecr.ECR { +func NewECRClient(t testing.TestingT, region string) *ecr.Client { sess, err := NewECRClientE(t, region) require.NoError(t, err) return sess } -// NewECRClient returns a client for the Elastic Container Registry. -func NewECRClientE(t testing.TestingT, region string) (*ecr.ECR, error) { +// NewECRClientE returns a client for the Elastic Container Registry. +func NewECRClientE(t testing.TestingT, region string) (*ecr.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ecr.New(sess), nil + return ecr.NewFromConfig(*sess), nil } // GetECRRepoLifecyclePolicy gets the policies for the given ECR repository. // This will fail the test and stop execution if there is an error. -func GetECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *ecr.Repository) string { +func GetECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *types.Repository) string { policy, err := GetECRRepoLifecyclePolicyE(t, region, repo) require.NoError(t, err) return policy } // GetECRRepoLifecyclePolicyE gets the policies for the given ECR repository. -func GetECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Repository) (string, error) { +func GetECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *types.Repository) (string, error) { client := NewECRClient(t, region) - resp, err := client.GetLifecyclePolicy(&ecr.GetLifecyclePolicyInput{RepositoryName: repo.RepositoryName}) + resp, err := client.GetLifecyclePolicy(context.Background(), &ecr.GetLifecyclePolicyInput{RepositoryName: repo.RepositoryName}) if err != nil { return "", err } @@ -119,13 +121,13 @@ func GetECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Rep // PutECRRepoLifecyclePolicy puts the given policy for the given ECR repository. // This will fail the test and stop execution if there is an error. -func PutECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *ecr.Repository, policy string) { +func PutECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *types.Repository, policy string) { err := PutECRRepoLifecyclePolicyE(t, region, repo, policy) require.NoError(t, err) } -// PutEcrRepoLifecyclePolicy puts the given policy for the given ECR repository. -func PutECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Repository, policy string) error { +// PutECRRepoLifecyclePolicyE puts the given policy for the given ECR repository. +func PutECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *types.Repository, policy string) error { logger.Default.Logf(t, "Applying policy for repository %s in %s", *repo.RepositoryName, region) client, err := NewECRClientE(t, region) @@ -138,6 +140,6 @@ func PutECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Rep LifecyclePolicyText: aws.String(policy), } - _, err = client.PutLifecyclePolicy(input) + _, err = client.PutLifecyclePolicy(context.Background(), input) return err } diff --git a/modules/aws/ecr_test.go b/modules/aws/ecr_test.go index 30a4c342a..bc0be1d7b 100644 --- a/modules/aws/ecr_test.go +++ b/modules/aws/ecr_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,11 +20,11 @@ func TestEcrRepo(t *testing.T) { defer DeleteECRRepo(t, region, repo1) require.NoError(t, err) - assert.Equal(t, ecrRepoName, aws.StringValue(repo1.RepositoryName)) + assert.Equal(t, ecrRepoName, aws.ToString(repo1.RepositoryName)) repo2, err := GetECRRepoE(t, region, ecrRepoName) require.NoError(t, err) - assert.Equal(t, ecrRepoName, aws.StringValue(repo2.RepositoryName)) + assert.Equal(t, ecrRepoName, aws.ToString(repo2.RepositoryName)) } func TestGetEcrRepoLifecyclePolicyError(t *testing.T) { @@ -36,7 +36,7 @@ func TestGetEcrRepoLifecyclePolicyError(t *testing.T) { defer DeleteECRRepo(t, region, repo1) require.NoError(t, err) - assert.Equal(t, ecrRepoName, aws.StringValue(repo1.RepositoryName)) + assert.Equal(t, ecrRepoName, aws.ToString(repo1.RepositoryName)) _, err = GetECRRepoLifecyclePolicyE(t, region, repo1) require.Error(t, err) diff --git a/modules/aws/ecs.go b/modules/aws/ecs.go index 81658738e..29b463c76 100644 --- a/modules/aws/ecs.go +++ b/modules/aws/ecs.go @@ -1,29 +1,31 @@ package aws import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" ) // GetEcsCluster fetches information about specified ECS cluster. -func GetEcsCluster(t testing.TestingT, region string, name string) *ecs.Cluster { +func GetEcsCluster(t testing.TestingT, region string, name string) *types.Cluster { cluster, err := GetEcsClusterE(t, region, name) require.NoError(t, err) return cluster } // GetEcsClusterE fetches information about specified ECS cluster. -func GetEcsClusterE(t testing.TestingT, region string, name string) (*ecs.Cluster, error) { - return GetEcsClusterWithIncludeE(t, region, name, []string{}) +func GetEcsClusterE(t testing.TestingT, region string, name string) (*types.Cluster, error) { + return GetEcsClusterWithIncludeE(t, region, name, []types.ClusterField{}) } // GetEcsClusterWithInclude fetches extended information about specified ECS cluster. // The `include` parameter specifies a list of `ecs.ClusterField*` constants, such as `ecs.ClusterFieldTags`. -func GetEcsClusterWithInclude(t testing.TestingT, region string, name string, include []string) *ecs.Cluster { +func GetEcsClusterWithInclude(t testing.TestingT, region string, name string, include []types.ClusterField) *types.Cluster { clusterInfo, err := GetEcsClusterWithIncludeE(t, region, name, include) require.NoError(t, err) return clusterInfo @@ -31,52 +33,53 @@ func GetEcsClusterWithInclude(t testing.TestingT, region string, name string, in // GetEcsClusterWithIncludeE fetches extended information about specified ECS cluster. // The `include` parameter specifies a list of `ecs.ClusterField*` constants, such as `ecs.ClusterFieldTags`. -func GetEcsClusterWithIncludeE(t testing.TestingT, region string, name string, include []string) (*ecs.Cluster, error) { +func GetEcsClusterWithIncludeE(t testing.TestingT, region string, name string, include []types.ClusterField) (*types.Cluster, error) { client, err := NewEcsClientE(t, region) if err != nil { return nil, err } + input := &ecs.DescribeClustersInput{ - Clusters: []*string{ - aws.String(name), + Clusters: []string{ + name, }, - Include: aws.StringSlice(include), + Include: include, } - output, err := client.DescribeClusters(input) + output, err := client.DescribeClusters(context.Background(), input) if err != nil { return nil, err } numClusters := len(output.Clusters) if numClusters != 1 { - return nil, fmt.Errorf("Expected to find 1 ECS cluster named '%s' in region '%v', but found '%d'", + return nil, fmt.Errorf("expected to find 1 ECS cluster named '%s' in region '%v', but found '%d'", name, region, numClusters) } - return output.Clusters[0], nil + return &output.Clusters[0], nil } // GetDefaultEcsClusterE fetches information about default ECS cluster. -func GetDefaultEcsClusterE(t testing.TestingT, region string) (*ecs.Cluster, error) { +func GetDefaultEcsClusterE(t testing.TestingT, region string) (*types.Cluster, error) { return GetEcsClusterE(t, region, "default") } // GetDefaultEcsCluster fetches information about default ECS cluster. -func GetDefaultEcsCluster(t testing.TestingT, region string) *ecs.Cluster { +func GetDefaultEcsCluster(t testing.TestingT, region string) *types.Cluster { return GetEcsCluster(t, region, "default") } // CreateEcsCluster creates ECS cluster in the given region under the given name. -func CreateEcsCluster(t testing.TestingT, region string, name string) *ecs.Cluster { +func CreateEcsCluster(t testing.TestingT, region string, name string) *types.Cluster { cluster, err := CreateEcsClusterE(t, region, name) require.NoError(t, err) return cluster } // CreateEcsClusterE creates ECS cluster in the given region under the given name. -func CreateEcsClusterE(t testing.TestingT, region string, name string) (*ecs.Cluster, error) { +func CreateEcsClusterE(t testing.TestingT, region string, name string) (*types.Cluster, error) { client := NewEcsClient(t, region) - cluster, err := client.CreateCluster(&ecs.CreateClusterInput{ + cluster, err := client.CreateCluster(context.Background(), &ecs.CreateClusterInput{ ClusterName: aws.String(name), }) if err != nil { @@ -85,33 +88,33 @@ func CreateEcsClusterE(t testing.TestingT, region string, name string) (*ecs.Clu return cluster.Cluster, nil } -func DeleteEcsCluster(t testing.TestingT, region string, cluster *ecs.Cluster) { +func DeleteEcsCluster(t testing.TestingT, region string, cluster *types.Cluster) { err := DeleteEcsClusterE(t, region, cluster) require.NoError(t, err) } // DeleteEcsClusterE deletes existing ECS cluster in the given region. -func DeleteEcsClusterE(t testing.TestingT, region string, cluster *ecs.Cluster) error { +func DeleteEcsClusterE(t testing.TestingT, region string, cluster *types.Cluster) error { client := NewEcsClient(t, region) - _, err := client.DeleteCluster(&ecs.DeleteClusterInput{ + _, err := client.DeleteCluster(context.Background(), &ecs.DeleteClusterInput{ Cluster: aws.String(*cluster.ClusterName), }) return err } // GetEcsService fetches information about specified ECS service. -func GetEcsService(t testing.TestingT, region string, clusterName string, serviceName string) *ecs.Service { +func GetEcsService(t testing.TestingT, region string, clusterName string, serviceName string) *types.Service { service, err := GetEcsServiceE(t, region, clusterName, serviceName) require.NoError(t, err) return service } // GetEcsServiceE fetches information about specified ECS service. -func GetEcsServiceE(t testing.TestingT, region string, clusterName string, serviceName string) (*ecs.Service, error) { - output, err := NewEcsClient(t, region).DescribeServices(&ecs.DescribeServicesInput{ +func GetEcsServiceE(t testing.TestingT, region string, clusterName string, serviceName string) (*types.Service, error) { + output, err := NewEcsClient(t, region).DescribeServices(context.Background(), &ecs.DescribeServicesInput{ Cluster: aws.String(clusterName), - Services: []*string{ - aws.String(serviceName), + Services: []string{ + serviceName, }, }) if err != nil { @@ -121,22 +124,22 @@ func GetEcsServiceE(t testing.TestingT, region string, clusterName string, servi numServices := len(output.Services) if numServices != 1 { return nil, fmt.Errorf( - "Expected to find 1 ECS service named '%s' in cluster '%s' in region '%v', but found '%d'", + "expected to find 1 ECS service named '%s' in cluster '%s' in region '%v', but found '%d'", serviceName, clusterName, region, numServices) } - return output.Services[0], nil + return &output.Services[0], nil } // GetEcsTaskDefinition fetches information about specified ECS task definition. -func GetEcsTaskDefinition(t testing.TestingT, region string, taskDefinition string) *ecs.TaskDefinition { +func GetEcsTaskDefinition(t testing.TestingT, region string, taskDefinition string) *types.TaskDefinition { task, err := GetEcsTaskDefinitionE(t, region, taskDefinition) require.NoError(t, err) return task } // GetEcsTaskDefinitionE fetches information about specified ECS task definition. -func GetEcsTaskDefinitionE(t testing.TestingT, region string, taskDefinition string) (*ecs.TaskDefinition, error) { - output, err := NewEcsClient(t, region).DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{ +func GetEcsTaskDefinitionE(t testing.TestingT, region string, taskDefinition string) (*types.TaskDefinition, error) { + output, err := NewEcsClient(t, region).DescribeTaskDefinition(context.Background(), &ecs.DescribeTaskDefinitionInput{ TaskDefinition: aws.String(taskDefinition), }) if err != nil { @@ -146,17 +149,17 @@ func GetEcsTaskDefinitionE(t testing.TestingT, region string, taskDefinition str } // NewEcsClient creates en ECS client. -func NewEcsClient(t testing.TestingT, region string) *ecs.ECS { +func NewEcsClient(t testing.TestingT, region string) *ecs.Client { client, err := NewEcsClientE(t, region) require.NoError(t, err) return client } // NewEcsClientE creates an ECS client. -func NewEcsClientE(t testing.TestingT, region string) (*ecs.ECS, error) { +func NewEcsClientE(t testing.TestingT, region string) (*ecs.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ecs.New(sess), nil + return ecs.NewFromConfig(*sess), nil } diff --git a/modules/aws/ecs_test.go b/modules/aws/ecs_test.go index 3f7ffecb8..d9537888a 100644 --- a/modules/aws/ecs_test.go +++ b/modules/aws/ecs_test.go @@ -1,10 +1,12 @@ package aws import ( + "context" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -30,13 +32,13 @@ func TestEcsClusterWithInclude(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) clusterName := "terratest-" + random.UniqueId() - tags := []*ecs.Tag{&ecs.Tag{ + tags := []types.Tag{{ Key: aws.String("test-tag"), Value: aws.String("hello-world"), }} client := NewEcsClient(t, region) - c1, err := client.CreateCluster(&ecs.CreateClusterInput{ + c1, err := client.CreateCluster(context.Background(), &ecs.CreateClusterInput{ ClusterName: aws.String(clusterName), Tags: tags, }) @@ -44,19 +46,19 @@ func TestEcsClusterWithInclude(t *testing.T) { defer DeleteEcsCluster(t, region, c1.Cluster) - assert.Equal(t, clusterName, aws.StringValue(c1.Cluster.ClusterName)) + assert.Equal(t, clusterName, aws.ToString(c1.Cluster.ClusterName)) - c2, err := GetEcsClusterWithIncludeE(t, region, clusterName, []string{ecs.ClusterFieldTags}) + c2, err := GetEcsClusterWithIncludeE(t, region, clusterName, []types.ClusterField{types.ClusterFieldTags}) assert.NoError(t, err) - assert.Equal(t, clusterName, aws.StringValue(c2.ClusterName)) + assert.Equal(t, clusterName, aws.ToString(c2.ClusterName)) assert.Equal(t, tags, c2.Tags) assert.Empty(t, c2.Statistics) - c3, err := GetEcsClusterWithIncludeE(t, region, clusterName, []string{ecs.ClusterFieldStatistics}) + c3, err := GetEcsClusterWithIncludeE(t, region, clusterName, []types.ClusterField{types.ClusterFieldStatistics}) assert.NoError(t, err) - assert.Equal(t, clusterName, aws.StringValue(c3.ClusterName)) + assert.Equal(t, clusterName, aws.ToString(c3.ClusterName)) assert.NotEmpty(t, c3.Statistics) assert.Empty(t, c3.Tags) } diff --git a/modules/aws/iam.go b/modules/aws/iam.go index 5aaa716b3..83cf08954 100644 --- a/modules/aws/iam.go +++ b/modules/aws/iam.go @@ -1,10 +1,12 @@ package aws import ( + "context" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -25,7 +27,7 @@ func GetIamCurrentUserNameE(t testing.TestingT) (string, error) { return "", err } - resp, err := iamClient.GetUser(&iam.GetUserInput{}) + resp, err := iamClient.GetUser(context.Background(), &iam.GetUserInput{}) if err != nil { return "", err } @@ -49,7 +51,7 @@ func GetIamCurrentUserArnE(t testing.TestingT) (string, error) { return "", err } - resp, err := iamClient.GetUser(&iam.GetUserInput{}) + resp, err := iamClient.GetUser(context.Background(), &iam.GetUserInput{}) if err != nil { return "", err } @@ -58,7 +60,7 @@ func GetIamCurrentUserArnE(t testing.TestingT) (string, error) { } // CreateMfaDevice creates an MFA device using the given IAM client. -func CreateMfaDevice(t testing.TestingT, iamClient *iam.IAM, deviceName string) *iam.VirtualMFADevice { +func CreateMfaDevice(t testing.TestingT, iamClient *iam.Client, deviceName string) *types.VirtualMFADevice { mfaDevice, err := CreateMfaDeviceE(t, iamClient, deviceName) if err != nil { t.Fatal(err) @@ -67,10 +69,10 @@ func CreateMfaDevice(t testing.TestingT, iamClient *iam.IAM, deviceName string) } // CreateMfaDeviceE creates an MFA device using the given IAM client. -func CreateMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, deviceName string) (*iam.VirtualMFADevice, error) { +func CreateMfaDeviceE(t testing.TestingT, iamClient *iam.Client, deviceName string) (*types.VirtualMFADevice, error) { logger.Default.Logf(t, "Creating an MFA device called %s", deviceName) - output, err := iamClient.CreateVirtualMFADevice(&iam.CreateVirtualMFADeviceInput{ + output, err := iamClient.CreateVirtualMFADevice(context.Background(), &iam.CreateVirtualMFADeviceInput{ VirtualMFADeviceName: aws.String(deviceName), }) if err != nil { @@ -86,7 +88,7 @@ func CreateMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, deviceName string) // EnableMfaDevice enables a newly created MFA Device by supplying the first two one-time passwords, so that it can be used for future // logins by the given IAM User. -func EnableMfaDevice(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.VirtualMFADevice) { +func EnableMfaDevice(t testing.TestingT, iamClient *iam.Client, mfaDevice *types.VirtualMFADevice) { err := EnableMfaDeviceE(t, iamClient, mfaDevice) if err != nil { t.Fatal(err) @@ -95,8 +97,8 @@ func EnableMfaDevice(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.Virt // EnableMfaDeviceE enables a newly created MFA Device by supplying the first two one-time passwords, so that it can be used for future // logins by the given IAM User. -func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.VirtualMFADevice) error { - logger.Default.Logf(t, "Enabling MFA device %s", aws.StringValue(mfaDevice.SerialNumber)) +func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.Client, mfaDevice *types.VirtualMFADevice) error { + logger.Default.Logf(t, "Enabling MFA device %s", aws.ToString(mfaDevice.SerialNumber)) iamUserName, err := GetIamCurrentUserArnE(t) if err != nil { @@ -116,7 +118,7 @@ func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.Vir return err } - _, err = iamClient.EnableMFADevice(&iam.EnableMFADeviceInput{ + _, err = iamClient.EnableMFADevice(context.Background(), &iam.EnableMFADeviceInput{ AuthenticationCode1: aws.String(authCode1), AuthenticationCode2: aws.String(authCode2), SerialNumber: mfaDevice.SerialNumber, @@ -127,14 +129,14 @@ func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.Vir return err } - logger.Default.Logf(t, "Waiting for MFA Device enablement to propagate.") + logger.Log(t, "Waiting for MFA Device enablement to propagate.") time.Sleep(10 * time.Second) return nil } // NewIamClient creates a new IAM client. -func NewIamClient(t testing.TestingT, region string) *iam.IAM { +func NewIamClient(t testing.TestingT, region string) *iam.Client { client, err := NewIamClientE(t, region) if err != nil { t.Fatal(err) @@ -143,10 +145,10 @@ func NewIamClient(t testing.TestingT, region string) *iam.IAM { } // NewIamClientE creates a new IAM client. -func NewIamClientE(t testing.TestingT, region string) (*iam.IAM, error) { +func NewIamClientE(t testing.TestingT, region string) (*iam.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return iam.New(sess), nil + return iam.NewFromConfig(*sess), nil } diff --git a/modules/aws/keypair.go b/modules/aws/keypair.go index 6cbe8fc3f..741cef10f 100644 --- a/modules/aws/keypair.go +++ b/modules/aws/keypair.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/ssh" "github.com/gruntwork-io/terratest/modules/testing" @@ -57,7 +59,7 @@ func ImportEC2KeyPairE(t testing.TestingT, region string, name string, keyPair * PublicKeyMaterial: []byte(keyPair.PublicKey), } - _, err = client.ImportKeyPair(params) + _, err = client.ImportKeyPair(context.Background(), params) if err != nil { return nil, err } @@ -86,6 +88,6 @@ func DeleteEC2KeyPairE(t testing.TestingT, keyPair *Ec2Keypair) error { KeyName: aws.String(keyPair.Name), } - _, err = client.DeleteKeyPair(params) + _, err = client.DeleteKeyPair(context.Background(), params) return err } diff --git a/modules/aws/keypair_test.go b/modules/aws/keypair_test.go index d3f04f2b5..8b9b494c4 100644 --- a/modules/aws/keypair_test.go +++ b/modules/aws/keypair_test.go @@ -1,12 +1,12 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -32,10 +32,10 @@ func keyPairExists(t *testing.T, keyPair *Ec2Keypair) bool { client := NewEc2Client(t, keyPair.Region) input := ec2.DescribeKeyPairsInput{ - KeyNames: aws.StringSlice([]string{keyPair.Name}), + KeyNames: []string{keyPair.Name}, } - out, err := client.DescribeKeyPairs(&input) + out, err := client.DescribeKeyPairs(context.Background(), &input) if err != nil { if strings.Contains(err.Error(), "InvalidKeyPair.NotFound") { return false diff --git a/modules/aws/kms.go b/modules/aws/kms.go index 07cfd1fa8..d10442971 100644 --- a/modules/aws/kms.go +++ b/modules/aws/kms.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kms" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -24,7 +26,7 @@ func GetCmkArnE(t testing.TestingT, region string, cmkID string) (string, error) return "", err } - result, err := kmsClient.DescribeKey(&kms.DescribeKeyInput{ + result, err := kmsClient.DescribeKey(context.Background(), &kms.DescribeKeyInput{ KeyId: aws.String(cmkID), }) @@ -36,7 +38,7 @@ func GetCmkArnE(t testing.TestingT, region string, cmkID string) (string, error) } // NewKmsClient creates a KMS client. -func NewKmsClient(t testing.TestingT, region string) *kms.KMS { +func NewKmsClient(t testing.TestingT, region string) *kms.Client { client, err := NewKmsClientE(t, region) if err != nil { t.Fatal(err) @@ -45,11 +47,11 @@ func NewKmsClient(t testing.TestingT, region string) *kms.KMS { } // NewKmsClientE creates a KMS client. -func NewKmsClientE(t testing.TestingT, region string) (*kms.KMS, error) { +func NewKmsClientE(t testing.TestingT, region string) (*kms.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return kms.New(sess), nil + return kms.NewFromConfig(*sess), nil } diff --git a/modules/aws/lambda.go b/modules/aws/lambda.go index 5630613ca..3ffbec3d4 100644 --- a/modules/aws/lambda.go +++ b/modules/aws/lambda.go @@ -1,11 +1,13 @@ package aws import ( + "context" "encoding/json" "errors" "fmt" - "github.com/aws/aws-sdk-go/service/lambda" + "github.com/aws/aws-sdk-go-v2/service/lambda" + "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" ) @@ -59,7 +61,7 @@ type LambdaOutput struct { // The HTTP status code for a successful request is in the 200 range. // For RequestResponse invocation type, the status code is 200. // For the DryRun invocation type, the status code is 204. - StatusCode *int64 + StatusCode int32 } // InvokeFunction invokes a lambda function. @@ -89,14 +91,14 @@ func InvokeFunctionE(t testing.TestingT, region, functionName string, payload in invokeInput.Payload = payloadJson } - out, err := lambdaClient.Invoke(invokeInput) + out, err := lambdaClient.Invoke(context.Background(), invokeInput) require.NoError(t, err) if err != nil { return nil, err } if out.FunctionError != nil { - return out.Payload, &FunctionError{Message: *out.FunctionError, StatusCode: *out.StatusCode, Payload: out.Payload} + return out.Payload, &FunctionError{Message: *out.FunctionError, StatusCode: out.StatusCode, Payload: out.Payload} } return out.Payload, nil @@ -123,7 +125,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, } // Verify the InvocationType is one of the allowed values and report - // an error if it's not. By default the InvocationType will be + // an error if it's not. By default, the InvocationType will be // "RequestResponse". invocationType, err := input.InvocationType.Value() if err != nil { @@ -132,7 +134,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, invokeInput := &lambda.InvokeInput{ FunctionName: &functionName, - InvocationType: &invocationType, + InvocationType: types.InvocationType(invocationType), } if input.Payload != nil { @@ -143,7 +145,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, invokeInput.Payload = payloadJson } - out, err := lambdaClient.Invoke(invokeInput) + out, err := lambdaClient.Invoke(context.Background(), invokeInput) if err != nil { return nil, err } @@ -165,7 +167,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, type FunctionError struct { Message string - StatusCode int64 + StatusCode int32 Payload []byte } @@ -174,18 +176,18 @@ func (err *FunctionError) Error() string { } // NewLambdaClient creates a new Lambda client. -func NewLambdaClient(t testing.TestingT, region string) *lambda.Lambda { +func NewLambdaClient(t testing.TestingT, region string) *lambda.Client { client, err := NewLambdaClientE(t, region) require.NoError(t, err) return client } // NewLambdaClientE creates a new Lambda client. -func NewLambdaClientE(t testing.TestingT, region string) (*lambda.Lambda, error) { +func NewLambdaClientE(t testing.TestingT, region string) (*lambda.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return lambda.New(sess), nil + return lambda.NewFromConfig(*sess), nil } diff --git a/modules/aws/rds.go b/modules/aws/rds.go index ab0967de4..d7a8b1e9b 100644 --- a/modules/aws/rds.go +++ b/modules/aws/rds.go @@ -1,11 +1,13 @@ package aws import ( + "context" "database/sql" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + "github.com/aws/aws-sdk-go-v2/service/rds/types" _ "github.com/go-sql-driver/mysql" "github.com/gruntwork-io/terratest/modules/testing" _ "github.com/jackc/pgx/v5/stdlib" @@ -28,11 +30,11 @@ func GetAddressOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion return "", err } - return aws.StringValue(dbInstance.Endpoint.Address), nil + return aws.ToString(dbInstance.Endpoint.Address), nil } // GetPortOfRdsInstance gets the address of the given RDS Instance in the given region. -func GetPortOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) int64 { +func GetPortOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) int32 { port, err := GetPortOfRdsInstanceE(t, dbInstanceID, awsRegion) if err != nil { t.Fatal(err) @@ -41,7 +43,7 @@ func GetPortOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion str } // GetPortOfRdsInstanceE gets the address of the given RDS Instance in the given region. -func GetPortOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) (int64, error) { +func GetPortOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) (int32, error) { dbInstance, err := GetRdsInstanceDetailsE(t, dbInstanceID, awsRegion) if err != nil { return -1, err @@ -51,7 +53,7 @@ func GetPortOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion st } // GetWhetherSchemaExistsInRdsMySqlInstance checks whether the specified schema/table name exists in the RDS instance -func GetWhetherSchemaExistsInRdsMySqlInstance(t testing.TestingT, dbUrl string, dbPort int64, dbUsername string, dbPassword string, expectedSchemaName string) bool { +func GetWhetherSchemaExistsInRdsMySqlInstance(t testing.TestingT, dbUrl string, dbPort int32, dbUsername string, dbPassword string, expectedSchemaName string) bool { output, err := GetWhetherSchemaExistsInRdsMySqlInstanceE(t, dbUrl, dbPort, dbUsername, dbPassword, expectedSchemaName) if err != nil { t.Fatal(err) @@ -60,7 +62,7 @@ func GetWhetherSchemaExistsInRdsMySqlInstance(t testing.TestingT, dbUrl string, } // GetWhetherSchemaExistsInRdsMySqlInstanceE checks whether the specified schema/table name exists in the RDS instance -func GetWhetherSchemaExistsInRdsMySqlInstanceE(t testing.TestingT, dbUrl string, dbPort int64, dbUsername string, dbPassword string, expectedSchemaName string) (bool, error) { +func GetWhetherSchemaExistsInRdsMySqlInstanceE(t testing.TestingT, dbUrl string, dbPort int32, dbUsername string, dbPassword string, expectedSchemaName string) (bool, error) { connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/", dbUsername, dbPassword, dbUrl, dbPort) db, connErr := sql.Open("mysql", connectionString) if connErr != nil { @@ -80,7 +82,7 @@ func GetWhetherSchemaExistsInRdsMySqlInstanceE(t testing.TestingT, dbUrl string, } // GetWhetherSchemaExistsInRdsPostgresInstance checks whether the specified schema/table name exists in the RDS instance -func GetWhetherSchemaExistsInRdsPostgresInstance(t testing.TestingT, dbUrl string, dbPort int64, dbUsername string, dbPassword string, expectedSchemaName string) bool { +func GetWhetherSchemaExistsInRdsPostgresInstance(t testing.TestingT, dbUrl string, dbPort int32, dbUsername string, dbPassword string, expectedSchemaName string) bool { output, err := GetWhetherSchemaExistsInRdsPostgresInstanceE(t, dbUrl, dbPort, dbUsername, dbPassword, expectedSchemaName) if err != nil { t.Fatal(err) @@ -89,7 +91,7 @@ func GetWhetherSchemaExistsInRdsPostgresInstance(t testing.TestingT, dbUrl strin } // GetWhetherSchemaExistsInRdsPostgresInstanceE checks whether the specified schema/table name exists in the RDS instance -func GetWhetherSchemaExistsInRdsPostgresInstanceE(t testing.TestingT, dbUrl string, dbPort int64, dbUsername string, dbPassword string, expectedSchemaName string) (bool, error) { +func GetWhetherSchemaExistsInRdsPostgresInstanceE(t testing.TestingT, dbUrl string, dbPort int32, dbUsername string, dbPassword string, expectedSchemaName string) (bool, error) { connectionString := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", dbUrl, dbPort, dbUsername, dbPassword, expectedSchemaName) db, connErr := sql.Open("pgx", connectionString) @@ -122,8 +124,8 @@ func GetParameterValueForParameterOfRdsInstance(t testing.TestingT, parameterNam func GetParameterValueForParameterOfRdsInstanceE(t testing.TestingT, parameterName string, dbInstanceID string, awsRegion string) (string, error) { output := GetAllParametersOfRdsInstance(t, dbInstanceID, awsRegion) for _, parameter := range output { - if aws.StringValue(parameter.ParameterName) == parameterName { - return aws.StringValue(parameter.ParameterValue), nil + if aws.ToString(parameter.ParameterName) == parameterName { + return aws.ToString(parameter.ParameterValue), nil } } return "", ParameterForDbInstanceNotFound{ParameterName: parameterName, DbInstanceID: dbInstanceID, AwsRegion: awsRegion} @@ -143,10 +145,10 @@ func GetOptionSettingForOfRdsInstanceE(t testing.TestingT, optionName string, op optionGroupName := GetOptionGroupNameOfRdsInstance(t, dbInstanceID, awsRegion) options := GetOptionsOfOptionGroup(t, optionGroupName, awsRegion) for _, option := range options { - if aws.StringValue(option.OptionName) == optionName { + if aws.ToString(option.OptionName) == optionName { for _, optionSetting := range option.OptionSettings { - if aws.StringValue(optionSetting.Name) == optionSettingName { - return aws.StringValue(optionSetting.Value), nil + if aws.ToString(optionSetting.Name) == optionSettingName { + return aws.ToString(optionSetting.Value), nil } } } @@ -169,11 +171,11 @@ func GetOptionGroupNameOfRdsInstanceE(t testing.TestingT, dbInstanceID string, a if err != nil { return "", err } - return aws.StringValue(dbInstance.OptionGroupMemberships[0].OptionGroupName), nil + return aws.ToString(dbInstance.OptionGroupMemberships[0].OptionGroupName), nil } // GetOptionsOfOptionGroup gets the options of the option group specified -func GetOptionsOfOptionGroup(t testing.TestingT, optionGroupName string, awsRegion string) []*rds.Option { +func GetOptionsOfOptionGroup(t testing.TestingT, optionGroupName string, awsRegion string) []types.Option { output, err := GetOptionsOfOptionGroupE(t, optionGroupName, awsRegion) if err != nil { t.Fatal(err) @@ -182,18 +184,18 @@ func GetOptionsOfOptionGroup(t testing.TestingT, optionGroupName string, awsRegi } // GetOptionsOfOptionGroupE gets the options of the option group specified -func GetOptionsOfOptionGroupE(t testing.TestingT, optionGroupName string, awsRegion string) ([]*rds.Option, error) { +func GetOptionsOfOptionGroupE(t testing.TestingT, optionGroupName string, awsRegion string) ([]types.Option, error) { rdsClient := NewRdsClient(t, awsRegion) input := rds.DescribeOptionGroupsInput{OptionGroupName: aws.String(optionGroupName)} - output, err := rdsClient.DescribeOptionGroups(&input) + output, err := rdsClient.DescribeOptionGroups(context.Background(), &input) if err != nil { - return []*rds.Option{}, err + return []types.Option{}, err } return output.OptionGroupsList[0].Options, nil } // GetAllParametersOfRdsInstance gets all the parameters defined in the parameter group for the RDS instance in the given region. -func GetAllParametersOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) []*rds.Parameter { +func GetAllParametersOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) []types.Parameter { parameters, err := GetAllParametersOfRdsInstanceE(t, dbInstanceID, awsRegion) if err != nil { t.Fatal(err) @@ -202,36 +204,36 @@ func GetAllParametersOfRdsInstance(t testing.TestingT, dbInstanceID string, awsR } // GetAllParametersOfRdsInstanceE gets all the parameters defined in the parameter group for the RDS instance in the given region. -func GetAllParametersOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) ([]*rds.Parameter, error) { +func GetAllParametersOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) ([]types.Parameter, error) { dbInstance, dbInstanceErr := GetRdsInstanceDetailsE(t, dbInstanceID, awsRegion) if dbInstanceErr != nil { - return []*rds.Parameter{}, dbInstanceErr + return []types.Parameter{}, dbInstanceErr } - parameterGroupName := aws.StringValue(dbInstance.DBParameterGroups[0].DBParameterGroupName) + parameterGroupName := aws.ToString(dbInstance.DBParameterGroups[0].DBParameterGroupName) rdsClient := NewRdsClient(t, awsRegion) input := rds.DescribeDBParametersInput{DBParameterGroupName: aws.String(parameterGroupName)} - output, err := rdsClient.DescribeDBParameters(&input) + output, err := rdsClient.DescribeDBParameters(context.Background(), &input) if err != nil { - return []*rds.Parameter{}, err + return []types.Parameter{}, err } return output.Parameters, nil } // GetRdsInstanceDetailsE gets the details of a single DB instance whose identifier is passed. -func GetRdsInstanceDetailsE(t testing.TestingT, dbInstanceID string, awsRegion string) (*rds.DBInstance, error) { +func GetRdsInstanceDetailsE(t testing.TestingT, dbInstanceID string, awsRegion string) (*types.DBInstance, error) { rdsClient := NewRdsClient(t, awsRegion) input := rds.DescribeDBInstancesInput{DBInstanceIdentifier: aws.String(dbInstanceID)} - output, err := rdsClient.DescribeDBInstances(&input) + output, err := rdsClient.DescribeDBInstances(context.Background(), &input) if err != nil { return nil, err } - return output.DBInstances[0], nil + return &output.DBInstances[0], nil } // NewRdsClient creates an RDS client. -func NewRdsClient(t testing.TestingT, region string) *rds.RDS { +func NewRdsClient(t testing.TestingT, region string) *rds.Client { client, err := NewRdsClientE(t, region) if err != nil { t.Fatal(err) @@ -240,18 +242,18 @@ func NewRdsClient(t testing.TestingT, region string) *rds.RDS { } // NewRdsClientE creates an RDS client. -func NewRdsClientE(t testing.TestingT, region string) (*rds.RDS, error) { +func NewRdsClientE(t testing.TestingT, region string) (*rds.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return rds.New(sess), nil + return rds.NewFromConfig(*sess), nil } // GetRecommendedRdsInstanceType takes in a list of RDS instance types (e.g., "db.t2.micro", "db.t3.micro") and returns the // first instance type in the list that is available in the given region and for the given database engine type. -// If none of the instances provided are avaiable for your combination of region and database engine, this function will exit with an error. +// If none of the instances provided are available for your combination of region and database engine, this function will exit with an error. func GetRecommendedRdsInstanceType(t testing.TestingT, region string, engine string, engineVersion string, instanceTypeOptions []string) string { out, err := GetRecommendedRdsInstanceTypeE(t, region, engine, engineVersion, instanceTypeOptions) require.NoError(t, err) @@ -260,7 +262,7 @@ func GetRecommendedRdsInstanceType(t testing.TestingT, region string, engine str // GetRecommendedRdsInstanceTypeE takes in a list of RDS instance types (e.g., "db.t2.micro", "db.t3.micro") and returns the // first instance type in the list that is available in the given region and for the given database engine type. -// If none of the instances provided are avaiable for your combination of region and database engine, this function will return an error. +// If none of the instances provided are available for your combination of region and database engine, this function will return an error. func GetRecommendedRdsInstanceTypeE(t testing.TestingT, region string, engine string, engineVersion string, instanceTypeOptions []string) (string, error) { client, err := NewRdsClientE(t, region) if err != nil { @@ -271,9 +273,9 @@ func GetRecommendedRdsInstanceTypeE(t testing.TestingT, region string, engine st // GetRecommendedRdsInstanceTypeWithClientE takes in a list of RDS instance types (e.g., "db.t2.micro", "db.t3.micro") and returns the // first instance type in the list that is available in the given region and for the given database engine type. -// If none of the instances provided are avaiable for your combination of region and database engine, this function will return an error. +// If none of the instances provided are available for your combination of region and database engine, this function will return an error. // This function expects an authenticated RDS client from the AWS SDK Go library. -func GetRecommendedRdsInstanceTypeWithClientE(t testing.TestingT, rdsClient *rds.RDS, engine string, engineVersion string, instanceTypeOptions []string) (string, error) { +func GetRecommendedRdsInstanceTypeWithClientE(t testing.TestingT, rdsClient *rds.Client, engine string, engineVersion string, instanceTypeOptions []string) (string, error) { for _, instanceTypeOption := range instanceTypeOptions { instanceTypeExists, err := instanceTypeExistsForEngineAndRegionE(rdsClient, engine, engineVersion, instanceTypeOption) if err != nil { @@ -289,14 +291,14 @@ func GetRecommendedRdsInstanceTypeWithClientE(t testing.TestingT, rdsClient *rds // instanceTypeExistsForEngineAndRegionE returns a boolean that represents whether the provided instance type (e.g. db.t2.micro) exists for the given region and db engine type // This function will return an error if the RDS AWS SDK call fails. -func instanceTypeExistsForEngineAndRegionE(client *rds.RDS, engine string, engineVersion string, instanceType string) (bool, error) { +func instanceTypeExistsForEngineAndRegionE(client *rds.Client, engine string, engineVersion string, instanceType string) (bool, error) { input := rds.DescribeOrderableDBInstanceOptionsInput{ Engine: aws.String(engine), EngineVersion: aws.String(engineVersion), DBInstanceClass: aws.String(instanceType), } - out, err := client.DescribeOrderableDBInstanceOptions(&input) + out, err := client.DescribeOrderableDBInstanceOptions(context.Background(), &input) if err != nil { return false, err } @@ -326,7 +328,7 @@ func GetValidEngineVersionE(t testing.TestingT, region string, engine string, ma Engine: aws.String(engine), EngineVersion: aws.String(majorVersion), } - out, err := client.DescribeDBEngineVersions(&input) + out, err := client.DescribeDBEngineVersions(context.Background(), &input) if err != nil || len(out.DBEngineVersions) == 0 { return "", err } diff --git a/modules/aws/region.go b/modules/aws/region.go index 8a02c9a67..1cfe8e73f 100644 --- a/modules/aws/region.go +++ b/modules/aws/region.go @@ -1,12 +1,13 @@ package aws import ( + "context" "fmt" "os" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/gruntwork-io/terratest/modules/collections" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/random" @@ -113,14 +114,14 @@ func GetAllAwsRegionsE(t testing.TestingT) ([]string, error) { return nil, err } - out, err := ec2Client.DescribeRegions(&ec2.DescribeRegionsInput{}) + out, err := ec2Client.DescribeRegions(context.Background(), &ec2.DescribeRegionsInput{}) if err != nil { return nil, err } - regions := []string{} + var regions []string for _, region := range out.Regions { - regions = append(regions, aws.StringValue(region.RegionName)) + regions = append(regions, aws.ToString(region.RegionName)) } return regions, nil @@ -146,14 +147,14 @@ func GetAvailabilityZonesE(t testing.TestingT, region string) ([]string, error) return nil, err } - resp, err := ec2Client.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) + resp, err := ec2Client.DescribeAvailabilityZones(context.Background(), &ec2.DescribeAvailabilityZonesInput{}) if err != nil { return nil, err } var out []string for _, availabilityZone := range resp.AvailabilityZones { - out = append(out, aws.StringValue(availabilityZone.ZoneName)) + out = append(out, aws.ToString(availabilityZone.ZoneName)) } return out, nil @@ -168,7 +169,7 @@ func GetRegionsForService(t testing.TestingT, serviceName string) []string { return out } -// GetRegionsForService gets all AWS regions in which a service is available and returns errors. +// GetRegionsForServiceE gets all AWS regions in which a service is available and returns errors. // See https://docs.aws.amazon.com/systems-manager/latest/userguide/parameter-store-public-parameters-global-infrastructure.html func GetRegionsForServiceE(t testing.TestingT, serviceName string) ([]string, error) { // These values are available in any region, defaulting to us-east-1 since it's the oldest @@ -179,12 +180,11 @@ func GetRegionsForServiceE(t testing.TestingT, serviceName string) ([]string, er } paramPath := "/aws/service/global-infrastructure/services/%s/regions" - req, resp := ssmClient.GetParametersByPathRequest(&ssm.GetParametersByPathInput{ + resp, err := ssmClient.GetParametersByPath(context.Background(), &ssm.GetParametersByPathInput{ Path: aws.String(fmt.Sprintf(paramPath, serviceName)), }) - ssmErr := req.Send() - if ssmErr != nil { + if err != nil { return nil, err } diff --git a/modules/aws/route53.go b/modules/aws/route53.go index abb7a56ba..30ddd0166 100644 --- a/modules/aws/route53.go +++ b/modules/aws/route53.go @@ -1,17 +1,19 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/gogo/protobuf/proto" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/route53" + "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/stretchr/testify/require" ) // GetRoute53Record returns a Route 53 Record -func GetRoute53Record(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) *route53.ResourceRecordSet { +func GetRoute53Record(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) *types.ResourceRecordSet { r, err := GetRoute53RecordE(t, hostedZoneID, recordName, recordType, awsRegion) require.NoError(t, err) @@ -19,35 +21,33 @@ func GetRoute53Record(t *testing.T, hostedZoneID, recordName, recordType, awsReg } // GetRoute53RecordE returns a Route 53 Record -func GetRoute53RecordE(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) (record *route53.ResourceRecordSet, err error) { +func GetRoute53RecordE(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) (*types.ResourceRecordSet, error) { route53Client, err := NewRoute53ClientE(t, awsRegion) if err != nil { return nil, err } - o, err := route53Client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{ + o, err := route53Client.ListResourceRecordSets(context.Background(), &route53.ListResourceRecordSetsInput{ HostedZoneId: &hostedZoneID, StartRecordName: &recordName, - StartRecordType: &recordType, - MaxItems: proto.String("1"), + StartRecordType: types.RRType(recordType), + MaxItems: aws.Int32(1), }) if err != nil { - return + return nil, err } - for _, record = range o.ResourceRecordSets { + + for _, record := range o.ResourceRecordSets { if strings.EqualFold(recordName+".", *record.Name) { - break + return &record, nil } - record = nil - } - if record == nil { - err = fmt.Errorf("record not found") } - return + + return nil, fmt.Errorf("record not found") } -// NewRoute53ClientE creates a route 53 client. -func NewRoute53Client(t *testing.T, region string) *route53.Route53 { +// NewRoute53Client creates a route 53 client. +func NewRoute53Client(t *testing.T, region string) *route53.Client { c, err := NewRoute53ClientE(t, region) require.NoError(t, err) @@ -55,11 +55,11 @@ func NewRoute53Client(t *testing.T, region string) *route53.Route53 { } // NewRoute53ClientE creates a route 53 client. -func NewRoute53ClientE(t *testing.T, region string) (*route53.Route53, error) { +func NewRoute53ClientE(t *testing.T, region string) (*route53.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return route53.New(sess), nil + return route53.NewFromConfig(*sess), nil } diff --git a/modules/aws/route53_test.go b/modules/aws/route53_test.go index 3f98d3183..5c048c1f4 100644 --- a/modules/aws/route53_test.go +++ b/modules/aws/route53_test.go @@ -1,13 +1,14 @@ package aws import ( + "context" "fmt" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/gogo/protobuf/proto" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/route53" + "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,35 +20,35 @@ func TestRoute53Record(t *testing.T) { require.NoError(t, err) domain := fmt.Sprintf("terratest%dexample.com", time.Now().UnixNano()) - hostedZone, err := c.CreateHostedZone(&route53.CreateHostedZoneInput{ + hostedZone, err := c.CreateHostedZone(context.Background(), &route53.CreateHostedZoneInput{ Name: aws.String(domain), CallerReference: aws.String(fmt.Sprint(time.Now().UnixNano())), }) require.NoError(t, err) t.Cleanup(func() { - _, err := c.DeleteHostedZone(&route53.DeleteHostedZoneInput{ + _, err := c.DeleteHostedZone(context.Background(), &route53.DeleteHostedZoneInput{ Id: hostedZone.HostedZone.Id, }) require.NoError(t, err) }) recordName := fmt.Sprintf("record.%s", domain) - resourceRecordSet := &route53.ResourceRecordSet{ + resourceRecordSet := &types.ResourceRecordSet{ Name: &recordName, - Type: aws.String("A"), + Type: types.RRTypeA, TTL: aws.Int64(60), - ResourceRecords: []*route53.ResourceRecord{ + ResourceRecords: []types.ResourceRecord{ { Value: aws.String("127.0.0.1"), }, }, } - _, err = c.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{ + _, err = c.ChangeResourceRecordSets(context.Background(), &route53.ChangeResourceRecordSetsInput{ HostedZoneId: hostedZone.HostedZone.Id, - ChangeBatch: &route53.ChangeBatch{ - Changes: []*route53.Change{ + ChangeBatch: &types.ChangeBatch{ + Changes: []types.Change{ { - Action: proto.String("CREATE"), + Action: types.ChangeActionCreate, ResourceRecordSet: resourceRecordSet, }, }, @@ -55,12 +56,12 @@ func TestRoute53Record(t *testing.T) { }) require.NoError(t, err) t.Cleanup(func() { - _, err := c.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{ + _, err := c.ChangeResourceRecordSets(context.Background(), &route53.ChangeResourceRecordSetsInput{ HostedZoneId: hostedZone.HostedZone.Id, - ChangeBatch: &route53.ChangeBatch{ - Changes: []*route53.Change{ + ChangeBatch: &types.ChangeBatch{ + Changes: []types.Change{ { - Action: proto.String("DELETE"), + Action: types.ChangeActionDelete, ResourceRecordSet: resourceRecordSet, }, }, @@ -70,10 +71,10 @@ func TestRoute53Record(t *testing.T) { }) t.Run("ExistingRecord", func(t *testing.T) { - route53Record := GetRoute53Record(t, *hostedZone.HostedZone.Id, recordName, *resourceRecordSet.Type, region) + route53Record := GetRoute53Record(t, *hostedZone.HostedZone.Id, recordName, string(resourceRecordSet.Type), region) require.NotNil(t, route53Record) assert.Equal(t, recordName+".", *route53Record.Name) - assert.Equal(t, *resourceRecordSet.Type, *route53Record.Type) + assert.Equal(t, resourceRecordSet.Type, route53Record.Type) assert.Equal(t, "127.0.0.1", *route53Record.ResourceRecords[0].Value) }) diff --git a/modules/aws/s3.go b/modules/aws/s3.go index 17a464de0..518d141fe 100644 --- a/modules/aws/s3.go +++ b/modules/aws/s3.go @@ -2,12 +2,14 @@ package aws import ( "bytes" + "context" "fmt" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -28,13 +30,13 @@ func FindS3BucketWithTagE(t testing.TestingT, awsRegion string, key string, valu return "", err } - resp, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) + resp, err := s3Client.ListBuckets(context.Background(), &s3.ListBucketsInput{}) if err != nil { return "", err } for _, bucket := range resp.Buckets { - tagResponse, err := s3Client.GetBucketTagging(&s3.GetBucketTaggingInput{Bucket: bucket.Name}) + tagResponse, err := s3Client.GetBucketTagging(context.Background(), &s3.GetBucketTaggingInput{Bucket: bucket.Name}) if err != nil { if strings.Contains(err.Error(), "NoSuchBucket") { @@ -77,7 +79,7 @@ func GetS3BucketTagsE(t testing.TestingT, awsRegion string, bucket string) (map[ return nil, err } - out, err := s3Client.GetBucketTagging(&s3.GetBucketTaggingInput{ + out, err := s3Client.GetBucketTagging(context.Background(), &s3.GetBucketTaggingInput{ Bucket: &bucket, }) if err != nil { @@ -86,7 +88,7 @@ func GetS3BucketTagsE(t testing.TestingT, awsRegion string, bucket string) (map[ tags := map[string]string{} for _, tag := range out.TagSet { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -107,7 +109,7 @@ func GetS3ObjectContentsE(t testing.TestingT, awsRegion string, bucket string, k return "", err } - res, err := s3Client.GetObject(&s3.GetObjectInput{ + res, err := s3Client.GetObject(context.Background(), &s3.GetObjectInput{ Bucket: &bucket, Key: &key, }) @@ -144,21 +146,27 @@ func CreateS3BucketE(t testing.TestingT, region string, name string) error { } params := &s3.CreateBucketInput{ - Bucket: aws.String(name), - // https://github.com/aws/aws-sdk-go/blob/v1.44.122/service/s3/api.go#L41646 - ObjectOwnership: aws.String(s3.ObjectOwnershipObjectWriter), + Bucket: aws.String(name), + ObjectOwnership: types.ObjectOwnershipObjectWriter, } - _, err = s3Client.CreateBucket(params) + + if region != "us-east-1" { + params.CreateBucketConfiguration = &types.CreateBucketConfiguration{ + LocationConstraint: types.BucketLocationConstraint(region), + } + } + + _, err = s3Client.CreateBucket(context.Background(), params) return err } -// PutS3BucketPolicy applies an IAM resource policy to a given S3 bucket to create it's bucket policy +// PutS3BucketPolicy applies an IAM resource policy to a given S3 bucket to create its bucket policy func PutS3BucketPolicy(t testing.TestingT, region string, bucketName string, policyJSONString string) { err := PutS3BucketPolicyE(t, region, bucketName, policyJSONString) require.NoError(t, err) } -// PutS3BucketPolicyE applies an IAM resource policy to a given S3 bucket to create it's bucket policy +// PutS3BucketPolicyE applies an IAM resource policy to a given S3 bucket to create its bucket policy func PutS3BucketPolicyE(t testing.TestingT, region string, bucketName string, policyJSONString string) error { logger.Default.Logf(t, "Applying bucket policy for bucket %s in %s", bucketName, region) @@ -172,7 +180,7 @@ func PutS3BucketPolicyE(t testing.TestingT, region string, bucketName string, po Policy: aws.String(policyJSONString), } - _, err = s3Client.PutBucketPolicy(input) + _, err = s3Client.PutBucketPolicy(context.Background(), input) return err } @@ -193,13 +201,13 @@ func PutS3BucketVersioningE(t testing.TestingT, region string, bucketName string input := &s3.PutBucketVersioningInput{ Bucket: aws.String(bucketName), - VersioningConfiguration: &s3.VersioningConfiguration{ - MFADelete: aws.String("Disabled"), - Status: aws.String("Enabled"), + VersioningConfiguration: &types.VersioningConfiguration{ + MFADelete: types.MFADeleteDisabled, + Status: types.BucketVersioningStatusEnabled, }, } - _, err = s3Client.PutBucketVersioning(input) + _, err = s3Client.PutBucketVersioning(context.Background(), input) return err } @@ -221,7 +229,7 @@ func DeleteS3BucketE(t testing.TestingT, region string, name string) error { params := &s3.DeleteBucketInput{ Bucket: aws.String(name), } - _, err = s3Client.DeleteBucket(params) + _, err = s3Client.DeleteBucket(context.Background(), params) return err } @@ -246,53 +254,53 @@ func EmptyS3BucketE(t testing.TestingT, region string, name string) error { for { // Requesting a batch of objects from s3 bucket - bucketObjects, err := s3Client.ListObjectVersions(params) + bucketObjects, err := s3Client.ListObjectVersions(context.Background(), params) if err != nil { return err } - //Checks if the bucket is already empty + // Checks if the bucket is already empty if len((*bucketObjects).Versions) == 0 { logger.Default.Logf(t, "Bucket %s is already empty", name) return nil } - //creating an array of pointers of ObjectIdentifier - objectsToDelete := make([]*s3.ObjectIdentifier, 0, 1000) + // creating an array of pointers of ObjectIdentifier + objectsToDelete := make([]types.ObjectIdentifier, 0, 1000) for _, object := range (*bucketObjects).Versions { - obj := s3.ObjectIdentifier{ + obj := types.ObjectIdentifier{ Key: object.Key, VersionId: object.VersionId, } - objectsToDelete = append(objectsToDelete, &obj) + objectsToDelete = append(objectsToDelete, obj) } for _, object := range (*bucketObjects).DeleteMarkers { - obj := s3.ObjectIdentifier{ + obj := types.ObjectIdentifier{ Key: object.Key, VersionId: object.VersionId, } - objectsToDelete = append(objectsToDelete, &obj) + objectsToDelete = append(objectsToDelete, obj) } - //Creating JSON payload for bulk delete - deleteArray := s3.Delete{Objects: objectsToDelete} + // Creating JSON payload for bulk delete + deleteArray := types.Delete{Objects: objectsToDelete} deleteParams := &s3.DeleteObjectsInput{ Bucket: aws.String(name), Delete: &deleteArray, } - //Running the Bulk delete job (limit 1000) - _, err = s3Client.DeleteObjects(deleteParams) + // Running the Bulk delete job (limit 1000) + _, err = s3Client.DeleteObjects(context.Background(), deleteParams) if err != nil { return err } - if *(*bucketObjects).IsTruncated { //if there are more objects in the bucket, IsTruncated = true + if *(*bucketObjects).IsTruncated { // if there are more objects in the bucket, IsTruncated = true // params.Marker = (*deleteParams).Delete.Objects[len((*deleteParams).Delete.Objects)-1].Key params.KeyMarker = bucketObjects.NextKeyMarker logger.Default.Logf(t, "Requesting next batch | %s", *(params.KeyMarker)) - } else { //if all objects in the bucket have been cleaned up. + } else { // if all objects in the bucket have been cleaned up. break } } @@ -316,7 +324,7 @@ func GetS3BucketLoggingTargetE(t testing.TestingT, awsRegion string, bucket stri return "", err } - res, err := s3Client.GetBucketLogging(&s3.GetBucketLoggingInput{ + res, err := s3Client.GetBucketLogging(context.Background(), &s3.GetBucketLoggingInput{ Bucket: &bucket, }) @@ -328,7 +336,7 @@ func GetS3BucketLoggingTargetE(t testing.TestingT, awsRegion string, bucket stri return "", S3AccessLoggingNotEnabledErr{bucket, awsRegion} } - return aws.StringValue(res.LoggingEnabled.TargetBucket), nil + return aws.ToString(res.LoggingEnabled.TargetBucket), nil } // GetS3BucketLoggingTargetPrefix fetches the given bucket's logging object prefix and returns it as a string @@ -347,7 +355,7 @@ func GetS3BucketLoggingTargetPrefixE(t testing.TestingT, awsRegion string, bucke return "", err } - res, err := s3Client.GetBucketLogging(&s3.GetBucketLoggingInput{ + res, err := s3Client.GetBucketLogging(context.Background(), &s3.GetBucketLoggingInput{ Bucket: &bucket, }) @@ -359,7 +367,7 @@ func GetS3BucketLoggingTargetPrefixE(t testing.TestingT, awsRegion string, bucke return "", S3AccessLoggingNotEnabledErr{bucket, awsRegion} } - return aws.StringValue(res.LoggingEnabled.TargetPrefix), nil + return aws.ToString(res.LoggingEnabled.TargetPrefix), nil } // GetS3BucketVersioning fetches the given bucket's versioning configuration status and returns it as a string @@ -377,14 +385,14 @@ func GetS3BucketVersioningE(t testing.TestingT, awsRegion string, bucket string) return "", err } - res, err := s3Client.GetBucketVersioning(&s3.GetBucketVersioningInput{ + res, err := s3Client.GetBucketVersioning(context.Background(), &s3.GetBucketVersioningInput{ Bucket: &bucket, }) if err != nil { return "", err } - return aws.StringValue(res.Status), nil + return string(res.Status), nil } // GetS3BucketPolicy fetches the given bucket's resource policy and returns it as a string @@ -402,14 +410,14 @@ func GetS3BucketPolicyE(t testing.TestingT, awsRegion string, bucket string) (st return "", err } - res, err := s3Client.GetBucketPolicy(&s3.GetBucketPolicyInput{ + res, err := s3Client.GetBucketPolicy(context.Background(), &s3.GetBucketPolicyInput{ Bucket: &bucket, }) if err != nil { return "", err } - return aws.StringValue(res.Policy), nil + return aws.ToString(res.Policy), nil } // AssertS3BucketExists checks if the given S3 bucket exists in the given region and fail the test if it does not. @@ -428,7 +436,7 @@ func AssertS3BucketExistsE(t testing.TestingT, region string, name string) error params := &s3.HeadBucketInput{ Bucket: aws.String(name), } - _, err = s3Client.HeadBucket(params) + _, err = s3Client.HeadBucket(context.Background(), params) return err } @@ -471,7 +479,7 @@ func AssertS3BucketPolicyExistsE(t testing.TestingT, region string, bucketName s } // NewS3Client creates an S3 client. -func NewS3Client(t testing.TestingT, region string) *s3.S3 { +func NewS3Client(t testing.TestingT, region string) *s3.Client { client, err := NewS3ClientE(t, region) require.NoError(t, err) @@ -479,30 +487,30 @@ func NewS3Client(t testing.TestingT, region string) *s3.S3 { } // NewS3ClientE creates an S3 client. -func NewS3ClientE(t testing.TestingT, region string) (*s3.S3, error) { +func NewS3ClientE(t testing.TestingT, region string) (*s3.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return s3.New(sess), nil + return s3.NewFromConfig(*sess), nil } // NewS3Uploader creates an S3 Uploader. -func NewS3Uploader(t testing.TestingT, region string) *s3manager.Uploader { +func NewS3Uploader(t testing.TestingT, region string) *manager.Uploader { uploader, err := NewS3UploaderE(t, region) require.NoError(t, err) return uploader } // NewS3UploaderE creates an S3 Uploader. -func NewS3UploaderE(t testing.TestingT, region string) (*s3manager.Uploader, error) { +func NewS3UploaderE(t testing.TestingT, region string) (*manager.Uploader, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return s3manager.NewUploader(sess), nil + return manager.NewUploader(s3.NewFromConfig(*sess)), nil } // S3AccessLoggingNotEnabledErr is a custom error that occurs when acess logging hasn't been enabled on the S3 Bucket diff --git a/modules/aws/s3_test.go b/modules/aws/s3_test.go index dafde4843..fc00ef87f 100644 --- a/modules/aws/s3_test.go +++ b/modules/aws/s3_test.go @@ -2,15 +2,16 @@ package aws import ( + "context" "fmt" "math/rand" "strconv" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" @@ -22,7 +23,7 @@ func TestCreateAndDestroyS3Bucket(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) id := random.UniqueId() - logger.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) + logger.Default.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) s3BucketName := "gruntwork-terratest-" + strings.ToLower(id) @@ -35,7 +36,7 @@ func TestAssertS3BucketExistsNoFalseNegative(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) s3BucketName := "gruntwork-terratest-" + strings.ToLower(random.UniqueId()) - logger.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) + logger.Default.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) CreateS3Bucket(t, region, s3BucketName) defer DeleteS3Bucket(t, region, s3BucketName) @@ -48,10 +49,10 @@ func TestAssertS3BucketExistsNoFalsePositive(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) s3BucketName := "gruntwork-terratest-" + strings.ToLower(random.UniqueId()) - logger.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) + logger.Default.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) // We elect not to create the S3 bucket to confirm that our function correctly reports it doesn't exist. - //aws.CreateS3Bucket(region, s3BucketName) + // aws.CreateS3Bucket(region, s3BucketName) err := AssertS3BucketExistsE(t, region, s3BucketName) if err == nil { @@ -64,7 +65,7 @@ func TestAssertS3BucketVersioningEnabled(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) s3BucketName := "gruntwork-terratest-" + strings.ToLower(random.UniqueId()) - logger.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) + logger.Default.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) CreateS3Bucket(t, region, s3BucketName) defer DeleteS3Bucket(t, region, s3BucketName) @@ -76,10 +77,9 @@ func TestAssertS3BucketVersioningEnabled(t *testing.T) { func TestEmptyS3Bucket(t *testing.T) { t.Parallel() - // region := GetRandomStableRegion(t, nil, nil) - region := "us-east-1" + region := GetRandomStableRegion(t, nil, nil) id := random.UniqueId() - logger.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) + logger.Default.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) s3BucketName := "gruntwork-terratest-" + strings.ToLower(id) @@ -100,7 +100,7 @@ func TestEmptyS3BucketVersioned(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) id := random.UniqueId() - logger.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) + logger.Default.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) s3BucketName := "gruntwork-terratest-" + strings.ToLower(id) @@ -114,13 +114,13 @@ func TestEmptyS3BucketVersioned(t *testing.T) { versionInput := &s3.PutBucketVersioningInput{ Bucket: aws.String(s3BucketName), - VersioningConfiguration: &s3.VersioningConfiguration{ - MFADelete: aws.String("Disabled"), - Status: aws.String("Enabled"), + VersioningConfiguration: &types.VersioningConfiguration{ + MFADelete: types.MFADeleteDisabled, + Status: types.BucketVersioningStatusEnabled, }, } - _, err = s3Client.PutBucketVersioning(versionInput) + _, err = s3Client.PutBucketVersioning(context.Background(), versionInput) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestAssertS3BucketPolicyExists(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) id := random.UniqueId() - logger.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) + logger.Default.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) s3BucketName := "gruntwork-terratest-" + strings.ToLower(id) exampleBucketPolicy := fmt.Sprintf(`{"Version":"2012-10-17","Statement":[{"Effect":"Deny","Principal":{"AWS":["*"]},"Action":"s3:Get*","Resource":"arn:aws:s3:::%s/*","Condition":{"Bool":{"aws:SecureTransport":"false"}}}]}`, s3BucketName) @@ -152,7 +152,7 @@ func TestGetS3BucketTags(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) id := random.UniqueId() - logger.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) + logger.Default.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) s3BucketName := "gruntwork-terratest-" + strings.ToLower(id) CreateS3Bucket(t, region, s3BucketName) @@ -163,10 +163,10 @@ func TestGetS3BucketTags(t *testing.T) { t.Fatal(err) } - _, err = s3Client.PutBucketTagging(&s3.PutBucketTaggingInput{ + _, err = s3Client.PutBucketTagging(context.Background(), &s3.PutBucketTaggingInput{ Bucket: &s3BucketName, - Tagging: &s3.Tagging{ - TagSet: []*s3.Tag{ + Tagging: &types.Tagging{ + TagSet: []types.Tag{ { Key: aws.String("Key1"), Value: aws.String("Value1"), @@ -188,9 +188,9 @@ func TestGetS3BucketTags(t *testing.T) { assert.True(t, actualTags["NonExistentKey"] == "") } -func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName string) { +func testEmptyBucket(t *testing.T, s3Client *s3.Client, region string, s3BucketName string) { expectedFileCount := rand.Intn(1000) - logger.Logf(t, "Uploading %s files to bucket %s", strconv.Itoa(expectedFileCount), s3BucketName) + logger.Default.Logf(t, "Uploading %s files to bucket %s", strconv.Itoa(expectedFileCount), s3BucketName) deleted := 0 @@ -199,7 +199,7 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName key := fmt.Sprintf("test-%s", strconv.Itoa(i)) body := strings.NewReader("This is the body") - params := &s3manager.UploadInput{ + params := &s3.PutObjectInput{ Bucket: aws.String(s3BucketName), Key: &key, Body: body, @@ -207,14 +207,14 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName uploader := NewS3Uploader(t, region) - _, err := uploader.Upload(params) + _, err := uploader.Upload(context.Background(), params) if err != nil { t.Fatal(err) } // Delete the first 10 files to be able to test if all files, including delete markers are deleted if i < 10 { - _, err := s3Client.DeleteObject(&s3.DeleteObjectInput{ + _, err := s3Client.DeleteObject(context.Background(), &s3.DeleteObjectInput{ Bucket: aws.String(s3BucketName), Key: aws.String(key), }) @@ -225,21 +225,21 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName } if i != 0 && i%100 == 0 { - logger.Logf(t, "Uploaded %s files to bucket %s successfully", strconv.Itoa(i), s3BucketName) + logger.Default.Logf(t, "Uploaded %s files to bucket %s successfully", strconv.Itoa(i), s3BucketName) } } - logger.Logf(t, "Uploaded %s files to bucket %s successfully", strconv.Itoa(expectedFileCount), s3BucketName) + logger.Default.Logf(t, "Uploaded %s files to bucket %s successfully", strconv.Itoa(expectedFileCount), s3BucketName) // verify bucket contains 1 file now listObjectsParams := &s3.ListObjectsV2Input{ Bucket: aws.String(s3BucketName), } - logger.Logf(t, "Verifying %s files were uploaded to bucket %s", strconv.Itoa(expectedFileCount), s3BucketName) + logger.Default.Logf(t, "Verifying %s files were uploaded to bucket %s", strconv.Itoa(expectedFileCount), s3BucketName) actualCount := 0 for { - bucketObjects, err := s3Client.ListObjectsV2(listObjectsParams) + bucketObjects, err := s3Client.ListObjectsV2(context.Background(), listObjectsParams) if err != nil { t.Fatal(err) } @@ -256,12 +256,12 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName require.Equal(t, expectedFileCount-deleted, actualCount) - //empty bucket - logger.Logf(t, "Emptying bucket %s", s3BucketName) + // empty bucket + logger.Default.Logf(t, "Emptying bucket %s", s3BucketName) EmptyS3Bucket(t, region, s3BucketName) // verify the bucket is empty - bucketObjects, err := s3Client.ListObjectsV2(listObjectsParams) + bucketObjects, err := s3Client.ListObjectsV2(context.Background(), listObjectsParams) if err != nil { t.Fatal(err) } diff --git a/modules/aws/secretsmanager.go b/modules/aws/secretsmanager.go index 6dcc23f14..a3eb8d2e8 100644 --- a/modules/aws/secretsmanager.go +++ b/modules/aws/secretsmanager.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/secretsmanager" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -21,7 +23,7 @@ func CreateSecretStringWithDefaultKeyE(t testing.TestingT, awsRegion, descriptio client := NewSecretsManagerClient(t, awsRegion) - secret, err := client.CreateSecret(&secretsmanager.CreateSecretInput{ + secret, err := client.CreateSecret(context.Background(), &secretsmanager.CreateSecretInput{ Description: aws.String(description), Name: aws.String(name), SecretString: aws.String(secretString), @@ -31,7 +33,7 @@ func CreateSecretStringWithDefaultKeyE(t testing.TestingT, awsRegion, descriptio return "", err } - return aws.StringValue(secret.ARN), nil + return aws.ToString(secret.ARN), nil } // GetSecretValue takes the friendly name or ARN of a secret and returns the plaintext value @@ -47,29 +49,29 @@ func GetSecretValueE(t testing.TestingT, awsRegion, id string) (string, error) { client := NewSecretsManagerClient(t, awsRegion) - secret, err := client.GetSecretValue(&secretsmanager.GetSecretValueInput{ + secret, err := client.GetSecretValue(context.Background(), &secretsmanager.GetSecretValueInput{ SecretId: aws.String(id), }) if err != nil { return "", err } - return aws.StringValue(secret.SecretString), nil + return aws.ToString(secret.SecretString), nil } -// UpdateSecretString updates a secret in Secrets Manager to a new string value +// PutSecretString updates a secret in Secrets Manager to a new string value func PutSecretString(t testing.TestingT, awsRegion, id string, secretString string) { err := PutSecretStringE(t, awsRegion, id, secretString) require.NoError(t, err) } -// UpdateSecretStringE updates a secret in Secrets Manager to a new string value +// PutSecretStringE updates a secret in Secrets Manager to a new string value func PutSecretStringE(t testing.TestingT, awsRegion, id string, secretString string) error { logger.Default.Logf(t, "Updating secret with ID %s", id) client := NewSecretsManagerClient(t, awsRegion) - _, err := client.PutSecretValue(&secretsmanager.PutSecretValueInput{ + _, err := client.PutSecretValue(context.Background(), &secretsmanager.PutSecretValueInput{ SecretId: aws.String(id), SecretString: aws.String(secretString), }) @@ -77,19 +79,19 @@ func PutSecretStringE(t testing.TestingT, awsRegion, id string, secretString str return err } -// DeleteSecret deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30 day recovery window. +// DeleteSecret deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30-day recovery window. func DeleteSecret(t testing.TestingT, awsRegion, id string, forceDelete bool) { err := DeleteSecretE(t, awsRegion, id, forceDelete) require.NoError(t, err) } -// DeleteSecretE deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30 day recovery window. +// DeleteSecretE deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30-day recovery window. func DeleteSecretE(t testing.TestingT, awsRegion, id string, forceDelete bool) error { logger.Default.Logf(t, "Deleting secret with ID %s", id) client := NewSecretsManagerClient(t, awsRegion) - _, err := client.DeleteSecret(&secretsmanager.DeleteSecretInput{ + _, err := client.DeleteSecret(context.Background(), &secretsmanager.DeleteSecretInput{ ForceDeleteWithoutRecovery: aws.Bool(forceDelete), SecretId: aws.String(id), }) @@ -98,18 +100,18 @@ func DeleteSecretE(t testing.TestingT, awsRegion, id string, forceDelete bool) e } // NewSecretsManagerClient creates a new SecretsManager client. -func NewSecretsManagerClient(t testing.TestingT, region string) *secretsmanager.SecretsManager { +func NewSecretsManagerClient(t testing.TestingT, region string) *secretsmanager.Client { client, err := NewSecretsManagerClientE(t, region) require.NoError(t, err) return client } // NewSecretsManagerClientE creates a new SecretsManager client. -func NewSecretsManagerClientE(t testing.TestingT, region string) (*secretsmanager.SecretsManager, error) { +func NewSecretsManagerClientE(t testing.TestingT, region string) (*secretsmanager.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return secretsmanager.New(sess), nil + return secretsmanager.NewFromConfig(*sess), nil } diff --git a/modules/aws/sns.go b/modules/aws/sns.go index 4deb3b8c9..2243c927d 100644 --- a/modules/aws/sns.go +++ b/modules/aws/sns.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sns" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -29,12 +31,12 @@ func CreateSnsTopicE(t testing.TestingT, region string, snsTopicName string) (st Name: &snsTopicName, } - output, err := snsClient.CreateTopic(createTopicInput) + output, err := snsClient.CreateTopic(context.Background(), createTopicInput) if err != nil { return "", err } - return aws.StringValue(output.TopicArn), err + return aws.ToString(output.TopicArn), err } // DeleteSNSTopic deletes an SNS Topic. @@ -58,12 +60,12 @@ func DeleteSNSTopicE(t testing.TestingT, region string, snsTopicArn string) erro TopicArn: aws.String(snsTopicArn), } - _, err = snsClient.DeleteTopic(deleteTopicInput) + _, err = snsClient.DeleteTopic(context.Background(), deleteTopicInput) return err } // NewSnsClient creates a new SNS client. -func NewSnsClient(t testing.TestingT, region string) *sns.SNS { +func NewSnsClient(t testing.TestingT, region string) *sns.Client { client, err := NewSnsClientE(t, region) if err != nil { t.Fatal(err) @@ -72,11 +74,11 @@ func NewSnsClient(t testing.TestingT, region string) *sns.SNS { } // NewSnsClientE creates a new SNS client. -func NewSnsClientE(t testing.TestingT, region string) (*sns.SNS, error) { +func NewSnsClientE(t testing.TestingT, region string) (*sns.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return sns.New(sess), nil + return sns.NewFromConfig(*sess), nil } diff --git a/modules/aws/sns_test.go b/modules/aws/sns_test.go index 3445e7e96..a442c934b 100644 --- a/modules/aws/sns_test.go +++ b/modules/aws/sns_test.go @@ -1,12 +1,13 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -29,7 +30,7 @@ func snsTopicExists(t *testing.T, region string, arn string) bool { input := sns.GetTopicAttributesInput{TopicArn: aws.String(arn)} - if _, err := snsClient.GetTopicAttributes(&input); err != nil { + if _, err := snsClient.GetTopicAttributes(context.Background(), &input); err != nil { if strings.Contains(err.Error(), "NotFound") { return false } diff --git a/modules/aws/sqs.go b/modules/aws/sqs.go index 8c11c1d2e..729fd0d0a 100644 --- a/modules/aws/sqs.go +++ b/modules/aws/sqs.go @@ -1,12 +1,14 @@ package aws import ( + "context" "fmt" "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/google/uuid" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" @@ -37,7 +39,7 @@ func CreateRandomQueueE(t testing.TestingT, awsRegion string, prefix string) (st channelName := fmt.Sprintf("%s-%s", prefix, channel.String()) - queue, err := sqsClient.CreateQueue(&sqs.CreateQueueInput{ + queue, err := sqsClient.CreateQueue(context.Background(), &sqs.CreateQueueInput{ QueueName: aws.String(channelName), }) @@ -45,7 +47,7 @@ func CreateRandomQueueE(t testing.TestingT, awsRegion string, prefix string) (st return "", err } - return aws.StringValue(queue.QueueUrl), nil + return aws.ToString(queue.QueueUrl), nil } // CreateRandomFifoQueue creates a new FIFO SQS queue with a random name that starts with the given prefix and return the queue URL. @@ -73,11 +75,11 @@ func CreateRandomFifoQueueE(t testing.TestingT, awsRegion string, prefix string) channelName := fmt.Sprintf("%s-%s.fifo", prefix, channel.String()) - queue, err := sqsClient.CreateQueue(&sqs.CreateQueueInput{ + queue, err := sqsClient.CreateQueue(context.Background(), &sqs.CreateQueueInput{ QueueName: aws.String(channelName), - Attributes: map[string]*string{ - "ContentBasedDeduplication": aws.String("true"), - "FifoQueue": aws.String("true"), + Attributes: map[string]string{ + "ContentBasedDeduplication": "true", + "FifoQueue": "true", }, }) @@ -85,7 +87,7 @@ func CreateRandomFifoQueueE(t testing.TestingT, awsRegion string, prefix string) return "", err } - return aws.StringValue(queue.QueueUrl), nil + return aws.ToString(queue.QueueUrl), nil } // DeleteQueue deletes the SQS queue with the given URL. @@ -105,7 +107,7 @@ func DeleteQueueE(t testing.TestingT, awsRegion string, queueURL string) error { return err } - _, err = sqsClient.DeleteQueue(&sqs.DeleteQueueInput{ + _, err = sqsClient.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ QueueUrl: aws.String(queueURL), }) @@ -129,7 +131,7 @@ func DeleteMessageFromQueueE(t testing.TestingT, awsRegion string, queueURL stri return err } - _, err = sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ + _, err = sqsClient.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{ ReceiptHandle: &receipt, QueueUrl: &queueURL, }) @@ -154,7 +156,7 @@ func SendMessageToQueueE(t testing.TestingT, awsRegion string, queueURL string, return err } - res, err := sqsClient.SendMessage(&sqs.SendMessageInput{ + res, err := sqsClient.SendMessage(context.Background(), &sqs.SendMessageInput{ MessageBody: &message, QueueUrl: &queueURL, }) @@ -167,12 +169,12 @@ func SendMessageToQueueE(t testing.TestingT, awsRegion string, queueURL string, return err } - logger.Default.Logf(t, "Message id %s sent to queue %s", aws.StringValue(res.MessageId), queueURL) + logger.Default.Logf(t, "Message id %s sent to queue %s", aws.ToString(res.MessageId), queueURL) return nil } -// SendMessageToFifoQueue sends the given message to the FIFO SQS queue with the given URL. +// SendMessageFifoToQueue sends the given message to the FIFO SQS queue with the given URL. func SendMessageFifoToQueue(t testing.TestingT, awsRegion string, queueURL string, message string, messageGroupID string) { err := SendMessageToFifoQueueE(t, awsRegion, queueURL, message, messageGroupID) if err != nil { @@ -189,7 +191,7 @@ func SendMessageToFifoQueueE(t testing.TestingT, awsRegion string, queueURL stri return err } - res, err := sqsClient.SendMessage(&sqs.SendMessageInput{ + res, err := sqsClient.SendMessage(context.Background(), &sqs.SendMessageInput{ MessageBody: &message, QueueUrl: &queueURL, MessageGroupId: &messageGroupID, @@ -203,7 +205,7 @@ func SendMessageToFifoQueueE(t testing.TestingT, awsRegion string, queueURL stri return err } - logger.Default.Logf(t, "Message id %s sent to queue %s", aws.StringValue(res.MessageId), queueURL) + logger.Default.Logf(t, "Message id %s sent to queue %s", aws.ToString(res.MessageId), queueURL) return nil } @@ -232,12 +234,12 @@ func WaitForQueueMessage(t testing.TestingT, awsRegion string, queueURL string, for i := 0; i < cycles; i++ { logger.Default.Logf(t, "Waiting for message on %s (%ss)", queueURL, strconv.Itoa(i*cycleLength)) - result, err := sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{ - QueueUrl: aws.String(queueURL), - AttributeNames: aws.StringSlice([]string{"SentTimestamp"}), - MaxNumberOfMessages: aws.Int64(1), - MessageAttributeNames: aws.StringSlice([]string{"All"}), - WaitTimeSeconds: aws.Int64(int64(cycleLength)), + result, err := sqsClient.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{ + QueueUrl: aws.String(queueURL), + MessageSystemAttributeNames: []types.MessageSystemAttributeName{types.MessageSystemAttributeNameSentTimestamp}, + MaxNumberOfMessages: int32(1), + MessageAttributeNames: []string{"All"}, + WaitTimeSeconds: int32(cycleLength), }) if err != nil { @@ -254,7 +256,7 @@ func WaitForQueueMessage(t testing.TestingT, awsRegion string, queueURL string, } // NewSqsClient creates a new SQS client. -func NewSqsClient(t testing.TestingT, region string) *sqs.SQS { +func NewSqsClient(t testing.TestingT, region string) *sqs.Client { client, err := NewSqsClientE(t, region) if err != nil { t.Fatal(err) @@ -263,13 +265,13 @@ func NewSqsClient(t testing.TestingT, region string) *sqs.SQS { } // NewSqsClientE creates a new SQS client. -func NewSqsClientE(t testing.TestingT, region string) (*sqs.SQS, error) { +func NewSqsClientE(t testing.TestingT, region string) (*sqs.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return sqs.New(sess), nil + return sqs.NewFromConfig(*sess), nil } // ReceiveMessageTimeout is an error that occurs if receiving a message times out. diff --git a/modules/aws/sqs_test.go b/modules/aws/sqs_test.go index 6200e8879..5f975d500 100644 --- a/modules/aws/sqs_test.go +++ b/modules/aws/sqs_test.go @@ -1,12 +1,13 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -71,7 +72,7 @@ func queueExists(t *testing.T, region string, url string) bool { input := sqs.GetQueueAttributesInput{QueueUrl: aws.String(url)} - if _, err := sqsClient.GetQueueAttributes(&input); err != nil { + if _, err := sqsClient.GetQueueAttributes(context.Background(), &input); err != nil { if strings.Contains(err.Error(), "NonExistentQueue") { return false } diff --git a/modules/aws/ssm.go b/modules/aws/ssm.go index 460b85501..679a40197 100644 --- a/modules/aws/ssm.go +++ b/modules/aws/ssm.go @@ -1,11 +1,14 @@ package aws import ( + "context" + "errors" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/retry" "github.com/gruntwork-io/terratest/modules/testing" @@ -29,9 +32,9 @@ func GetParameterE(t testing.TestingT, awsRegion string, keyName string) (string return GetParameterWithClientE(t, ssmClient, keyName) } -// GetParameterE retrieves the latest version of SSM Parameter at keyName with decryption with the ability to provide the SSM client. -func GetParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string) (string, error) { - resp, err := client.GetParameter(&ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)}) +// GetParameterWithClientE retrieves the latest version of SSM Parameter at keyName with decryption with the ability to provide the SSM client. +func GetParameterWithClientE(t testing.TestingT, client *ssm.Client, keyName string) (string, error) { + resp, err := client.GetParameter(context.Background(), &ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)}) if err != nil { return "", err } @@ -56,14 +59,19 @@ func PutParameterE(t testing.TestingT, awsRegion string, keyName string, keyDesc return PutParameterWithClientE(t, ssmClient, keyName, keyDescription, keyValue) } -// PutParameterE creates new version of SSM Parameter at keyName with keyValue as SecureString with the ability to provide the SSM client. -func PutParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string, keyDescription string, keyValue string) (int64, error) { - resp, err := client.PutParameter(&ssm.PutParameterInput{Name: aws.String(keyName), Description: aws.String(keyDescription), Value: aws.String(keyValue), Type: aws.String("SecureString")}) +// PutParameterWithClientE creates new version of SSM Parameter at keyName with keyValue as SecureString with the ability to provide the SSM client. +func PutParameterWithClientE(t testing.TestingT, client *ssm.Client, keyName string, keyDescription string, keyValue string) (int64, error) { + resp, err := client.PutParameter(context.Background(), &ssm.PutParameterInput{ + Name: aws.String(keyName), + Description: aws.String(keyDescription), + Value: aws.String(keyValue), + Type: types.ParameterTypeSecureString, + }) if err != nil { return 0, err } - return *resp.Version, nil + return resp.Version, nil } // DeleteParameter deletes all versions of SSM Parameter at keyName. @@ -81,9 +89,9 @@ func DeleteParameterE(t testing.TestingT, awsRegion string, keyName string) erro return DeleteParameterWithClientE(t, ssmClient, keyName) } -// DeleteParameterE deletes all versions of SSM Parameter at keyName with the ability to provide the SSM client. -func DeleteParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string) error { - _, err := client.DeleteParameter(&ssm.DeleteParameterInput{Name: aws.String(keyName)}) +// DeleteParameterWithClientE deletes all versions of SSM Parameter at keyName with the ability to provide the SSM client. +func DeleteParameterWithClientE(t testing.TestingT, client *ssm.Client, keyName string) error { + _, err := client.DeleteParameter(context.Background(), &ssm.DeleteParameterInput{Name: aws.String(keyName)}) if err != nil { return err } @@ -91,21 +99,21 @@ func DeleteParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName str return nil } -// NewSsmClient creates a SSM client. -func NewSsmClient(t testing.TestingT, region string) *ssm.SSM { +// NewSsmClient creates an SSM client. +func NewSsmClient(t testing.TestingT, region string) *ssm.Client { client, err := NewSsmClientE(t, region) require.NoError(t, err) return client } // NewSsmClientE creates an SSM client. -func NewSsmClientE(t testing.TestingT, region string) (*ssm.SSM, error) { +func NewSsmClientE(t testing.TestingT, region string) (*ssm.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ssm.New(sess), nil + return ssm.NewFromConfig(*sess), nil } // WaitForSsmInstanceE waits until the instance get registered to the SSM inventory. @@ -117,23 +125,23 @@ func WaitForSsmInstanceE(t testing.TestingT, awsRegion, instanceID string, timeo return WaitForSsmInstanceWithClientE(t, client, instanceID, timeout) } -// WaitForSsmInstanceE waits until the instance get registered to the SSM inventory with the ability to provide the SSM client. -func WaitForSsmInstanceWithClientE(t testing.TestingT, client *ssm.SSM, instanceID string, timeout time.Duration) error { +// WaitForSsmInstanceWithClientE waits until the instance get registered to the SSM inventory with the ability to provide the SSM client. +func WaitForSsmInstanceWithClientE(t testing.TestingT, client *ssm.Client, instanceID string, timeout time.Duration) error { timeBetweenRetries := 2 * time.Second maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) description := fmt.Sprintf("Waiting for %s to appear in the SSM inventory", instanceID) input := &ssm.GetInventoryInput{ - Filters: []*ssm.InventoryFilter{ + Filters: []types.InventoryFilter{ { Key: aws.String("AWS:InstanceInformation.InstanceId"), - Type: aws.String("Equal"), - Values: aws.StringSlice([]string{instanceID}), + Type: types.InventoryQueryOperatorTypeEqual, + Values: []string{instanceID}, }, }, } _, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) { - resp, err := client.GetInventory(input) + resp, err := client.GetInventory(context.Background(), input) if err != nil { return "", err @@ -173,7 +181,7 @@ func CheckSsmCommandE(t testing.TestingT, awsRegion, instanceID, command string, } // CheckSSMCommandWithClientE checks that you can run the given command on the given instance through AWS SSM with the ability to provide the SSM client. Returns the result and an error if one occurs. -func CheckSSMCommandWithClientE(t testing.TestingT, client *ssm.SSM, instanceID, command string, timeout time.Duration) (*CommandOutput, error) { +func CheckSSMCommandWithClientE(t testing.TestingT, client *ssm.Client, instanceID, command string, timeout time.Duration) (*CommandOutput, error) { return CheckSSMCommandWithClientWithDocumentE(t, client, instanceID, command, "AWS-RunShellScript", timeout) } @@ -197,19 +205,22 @@ func CheckSsmCommandWithDocumentE(t testing.TestingT, awsRegion, instanceID, com } // CheckSSMCommandWithClientWithDocumentE checks that you can run the given command on the given instance through AWS SSM with the ability to provide the SSM client with specified Command Doc type. Returns the result and an error if one occurs. -func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.SSM, instanceID, command string, commandDocName string, timeout time.Duration) (*CommandOutput, error) { +func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.Client, instanceID, command string, commandDocName string, timeout time.Duration) (*CommandOutput, error) { timeBetweenRetries := 2 * time.Second maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) - resp, err := client.SendCommand(&ssm.SendCommandInput{ - Comment: aws.String("Terratest SSM"), - DocumentName: aws.String(commandDocName), - InstanceIds: aws.StringSlice([]string{instanceID}), - Parameters: map[string][]*string{ - "commands": aws.StringSlice([]string{command}), + resp, err := client.SendCommand( + context.Background(), + &ssm.SendCommandInput{ + Comment: aws.String("Terratest SSM"), + DocumentName: aws.String(commandDocName), + InstanceIds: []string{instanceID}, + Parameters: map[string][]string{ + "commands": {command}, + }, }, - }) + ) if err != nil { return nil, err } @@ -225,7 +236,7 @@ func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.SSM, result := &CommandOutput{} _, err = retry.DoWithRetryableErrorsE(t, description, retryableErrors, maxRetries, timeBetweenRetries, func() (string, error) { - resp, err := client.GetCommandInvocation(&ssm.GetCommandInvocationInput{ + resp, err := client.GetCommandInvocation(context.Background(), &ssm.GetCommandInvocationInput{ CommandId: resp.Command.CommandId, InstanceId: &instanceID, }) @@ -234,25 +245,26 @@ func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.SSM, return "", err } - result.Stderr = aws.StringValue(resp.StandardErrorContent) - result.Stdout = aws.StringValue(resp.StandardOutputContent) - result.ExitCode = aws.Int64Value(resp.ResponseCode) + result.Stderr = aws.ToString(resp.StandardErrorContent) + result.Stdout = aws.ToString(resp.StandardOutputContent) + result.ExitCode = int64(resp.ResponseCode) - status := aws.StringValue(resp.Status) + status := resp.Status - if status == ssm.CommandInvocationStatusSuccess { + if status == types.CommandInvocationStatusSuccess { return "", nil } - if status == ssm.CommandInvocationStatusFailed { - return "", fmt.Errorf(aws.StringValue(resp.StatusDetails)) + if status == types.CommandInvocationStatusFailed { + return "", fmt.Errorf(aws.ToString(resp.StatusDetails)) } return "", fmt.Errorf("bad status: %s", status) }) if err != nil { - if actualErr, ok := err.(retry.FatalError); ok { + var actualErr retry.FatalError + if errors.As(err, &actualErr) { return result, actualErr.Underlying } return result, fmt.Errorf("unexpected error: %v", err) diff --git a/modules/aws/ssm_test.go b/modules/aws/ssm_test.go index 2c865f468..0b8ff8cd2 100644 --- a/modules/aws/ssm_test.go +++ b/modules/aws/ssm_test.go @@ -17,9 +17,9 @@ func TestParameterIsFound(t *testing.T) { expectedValue := fmt.Sprintf("test-value-%s", random.UniqueId()) expectedDescription := fmt.Sprintf("test-description-%s", random.UniqueId()) version := PutParameter(t, awsRegion, expectedName, expectedDescription, expectedValue) - logger.Logf(t, "Created parameter with version %d", version) + logger.Default.Logf(t, "Created parameter with version %d", version) keyValue := GetParameter(t, awsRegion, expectedName) - logger.Logf(t, "Found key with name %s", expectedName) + logger.Default.Logf(t, "Found key with name %s", expectedName) assert.Equal(t, expectedValue, keyValue) } @@ -29,10 +29,10 @@ func TestParameterIsDeleted(t *testing.T) { expectedValue := fmt.Sprintf("test-value-%s", random.UniqueId()) expectedDescription := fmt.Sprintf("test-description-%s", random.UniqueId()) version := PutParameter(t, awsRegion, expectedName, expectedDescription, expectedValue) - logger.Logf(t, "Created parameter with version %d", version) + logger.Default.Logf(t, "Created parameter with version %d", version) DeleteParameter(t, awsRegion, expectedName) - logger.Logf(t, "Deleted paramter %s", expectedName) + logger.Default.Logf(t, "Deleted paramter %s", expectedName) actualValue, err := GetParameterE(t, awsRegion, expectedName) assert.Equal(t, actualValue, "") diff --git a/modules/aws/vpc.go b/modules/aws/vpc.go index 5ee142a3f..8091a3353 100644 --- a/modules/aws/vpc.go +++ b/modules/aws/vpc.go @@ -1,12 +1,14 @@ package aws import ( + "context" "fmt" "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/random" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -50,45 +52,45 @@ func GetDefaultVpc(t testing.TestingT, region string) *Vpc { // GetDefaultVpcE fetches information about the default VPC in the given region. func GetDefaultVpcE(t testing.TestingT, region string) (*Vpc, error) { - defaultVpcFilter := ec2.Filter{Name: aws.String(isDefaultFilterName), Values: []*string{aws.String(isDefaultFilterValue)}} - vpcs, err := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region) + defaultVpcFilter := types.Filter{Name: aws.String(isDefaultFilterName), Values: []string{isDefaultFilterValue}} + vpcs, err := GetVpcsE(t, []types.Filter{defaultVpcFilter}, region) numVpcs := len(vpcs) if numVpcs != 1 { - return nil, fmt.Errorf("Expected to find one default VPC in region %s but found %s", region, strconv.Itoa(numVpcs)) + return nil, fmt.Errorf("expected to find one default VPC in region %s but found %s", region, strconv.Itoa(numVpcs)) } return vpcs[0], err } -// GetVpcById fetches information about a VPC with given Id in the given region. +// GetVpcById fetches information about a VPC with given ID in the given region. func GetVpcById(t testing.TestingT, vpcId string, region string) *Vpc { vpc, err := GetVpcByIdE(t, vpcId, region) require.NoError(t, err) return vpc } -// GetVpcByIdE fetches information about a VPC with given Id in the given region. +// GetVpcByIdE fetches information about a VPC with given ID in the given region. func GetVpcByIdE(t testing.TestingT, vpcId string, region string) (*Vpc, error) { - vpcIdFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcId}} - vpcs, err := GetVpcsE(t, []*ec2.Filter{&vpcIdFilter}, region) + vpcIdFilter := types.Filter{Name: aws.String(vpcIDFilterName), Values: []string{vpcId}} + vpcs, err := GetVpcsE(t, []types.Filter{vpcIdFilter}, region) numVpcs := len(vpcs) if numVpcs != 1 { - return nil, fmt.Errorf("Expected to find one VPC with ID %s in region %s but found %s", vpcId, region, strconv.Itoa(numVpcs)) + return nil, fmt.Errorf("expected to find one VPC with ID %s in region %s but found %s", vpcId, region, strconv.Itoa(numVpcs)) } return vpcs[0], err } -// GetVpcsE fetches informations about VPCs from given regions limited by filters -func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, error) { +// GetVpcsE fetches information about VPCs from given regions limited by filters +func GetVpcsE(t testing.TestingT, filters []types.Filter, region string) ([]*Vpc, error) { client, err := NewEc2ClientE(t, region) if err != nil { return nil, err } - vpcs, err := client.DescribeVpcs(&ec2.DescribeVpcsInput{Filters: filters}) + vpcs, err := client.DescribeVpcs(context.Background(), &ec2.DescribeVpcsInput{Filters: filters}) if err != nil { return nil, err } @@ -97,13 +99,13 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, retVal := make([]*Vpc, numVpcs) for i, vpc := range vpcs.Vpcs { - vpcIdFilter := generateVpcIdFilter(aws.StringValue(vpc.VpcId)) - subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIdFilter}) + vpcIdFilter := generateVpcIdFilter(aws.ToString(vpc.VpcId)) + subnets, err := GetSubnetsForVpcE(t, region, []types.Filter{vpcIdFilter}) if err != nil { return nil, err } - tags, err := GetTagsForVpcE(t, aws.StringValue(vpc.VpcId), region) + tags, err := GetTagsForVpcE(t, aws.ToString(vpc.VpcId), region) if err != nil { return nil, err } @@ -125,7 +127,7 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, }() retVal[i] = &Vpc{ - Id: aws.StringValue(vpc.VpcId), + Id: aws.ToString(vpc.VpcId), Name: FindVpcName(vpc), Subnets: subnets, Tags: tags, @@ -140,7 +142,7 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, // FindVpcName extracts the VPC name from its tags (if any). Fall back to "Default" if it's the default VPC or empty string // otherwise. -func FindVpcName(vpc *ec2.Vpc) string { +func FindVpcName(vpc types.Vpc) string { for _, tag := range vpc.Tags { if *tag.Key == "Name" { return *tag.Value @@ -157,7 +159,7 @@ func FindVpcName(vpc *ec2.Vpc) string { // GetSubnetsForVpc gets the subnets in the specified VPC. func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet { vpcIDFilter := generateVpcIdFilter(vpcID) - subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter}) + subnets, err := GetSubnetsForVpcE(t, region, []types.Filter{vpcIDFilter}) if err != nil { t.Fatal(err) } @@ -167,11 +169,11 @@ func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet // GetAzDefaultSubnetsForVpc gets the default az subnets in the specified VPC. func GetAzDefaultSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet { vpcIDFilter := generateVpcIdFilter(vpcID) - defaultForAzFilter := ec2.Filter{ + defaultForAzFilter := types.Filter{ Name: aws.String(defaultForAzFilterName), - Values: []*string{aws.String("true")}, + Values: []string{"true"}, } - subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter, &defaultForAzFilter}) + subnets, err := GetSubnetsForVpcE(t, region, []types.Filter{vpcIDFilter, defaultForAzFilter}) if err != nil { t.Fatal(err) } @@ -179,27 +181,27 @@ func GetAzDefaultSubnetsForVpc(t testing.TestingT, vpcID string, region string) } // generateVpcIdFilter is a helper method to generate vpc id filter -func generateVpcIdFilter(vpcID string) ec2.Filter { - return ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}} +func generateVpcIdFilter(vpcID string) types.Filter { + return types.Filter{Name: aws.String(vpcIDFilterName), Values: []string{vpcID}} } // GetSubnetsForVpcE gets the subnets in the specified VPC. -func GetSubnetsForVpcE(t testing.TestingT, region string, filters []*ec2.Filter) ([]Subnet, error) { +func GetSubnetsForVpcE(t testing.TestingT, region string, filters []types.Filter) ([]Subnet, error) { client, err := NewEc2ClientE(t, region) if err != nil { return nil, err } - subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: filters}) + subnetOutput, err := client.DescribeSubnets(context.Background(), &ec2.DescribeSubnetsInput{Filters: filters}) if err != nil { return nil, err } - subnets := []Subnet{} + var subnets []Subnet for _, ec2Subnet := range subnetOutput.Subnets { subnetTags := GetTagsForSubnet(t, *ec2Subnet.SubnetId, region) - subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone), DefaultForAz: aws.BoolValue(ec2Subnet.DefaultForAz), Tags: subnetTags} + subnet := Subnet{Id: aws.ToString(ec2Subnet.SubnetId), AvailabilityZone: aws.ToString(ec2Subnet.AvailabilityZone), DefaultForAz: aws.ToBool(ec2Subnet.DefaultForAz), Tags: subnetTags} subnets = append(subnets, subnet) } @@ -219,14 +221,14 @@ func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string client, err := NewEc2ClientE(t, region) require.NoError(t, err) - vpcResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(vpcResourceTypeFilterValue)}} - vpcResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&vpcID}} - tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&vpcResourceTypeFilter, &vpcResourceIdFilter}}) + vpcResourceTypeFilter := types.Filter{Name: aws.String(resourceIdFilterName), Values: []string{vpcResourceTypeFilterValue}} + vpcResourceIdFilter := types.Filter{Name: aws.String(resourceTypeFilterName), Values: []string{vpcID}} + tagsOutput, err := client.DescribeTags(context.Background(), &ec2.DescribeTagsInput{Filters: []types.Filter{vpcResourceTypeFilter, vpcResourceIdFilter}}) require.NoError(t, err) tags := map[string]string{} for _, tag := range tagsOutput.Tags { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -244,12 +246,12 @@ func GetDefaultSubnetIDsForVpcE(t testing.TestingT, vpc Vpc) ([]string, error) { if vpc.Name != defaultVPCName { // You cannot create a default subnet in a nondefault VPC // https://docs.aws.amazon.com/vpc/latest/userguide/default-vpc.html - return nil, fmt.Errorf("Only default VPCs have default subnets but VPC with id %s is not default VPC", vpc.Id) + return nil, fmt.Errorf("only default VPCs have default subnets but VPC with id %s is not default VPC", vpc.Id) } - subnetIDs := []string{} + var subnetIDs []string numSubnets := len(vpc.Subnets) if numSubnets == 0 { - return nil, fmt.Errorf("Expected to find at least one subnet in vpc with ID %s but found zero", vpc.Id) + return nil, fmt.Errorf("expected to find at least one subnet in vpc with ID %s but found zero", vpc.Id) } for _, subnet := range vpc.Subnets { @@ -273,14 +275,14 @@ func GetTagsForSubnetE(t testing.TestingT, subnetId string, region string) (map[ client, err := NewEc2ClientE(t, region) require.NoError(t, err) - subnetResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(subnetResourceTypeFilterValue)}} - subnetResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&subnetId}} - tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&subnetResourceTypeFilter, &subnetResourceIdFilter}}) + subnetResourceTypeFilter := types.Filter{Name: aws.String(resourceIdFilterName), Values: []string{subnetResourceTypeFilterValue}} + subnetResourceIdFilter := types.Filter{Name: aws.String(resourceTypeFilterName), Values: []string{subnetId}} + tagsOutput, err := client.DescribeTags(context.Background(), &ec2.DescribeTagsInput{Filters: []types.Filter{subnetResourceTypeFilter, subnetResourceIdFilter}}) require.NoError(t, err) tags := map[string]string{} for _, tag := range tagsOutput.Tags { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -297,9 +299,9 @@ func IsPublicSubnet(t testing.TestingT, subnetId string, region string) bool { func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, error) { subnetIdFilterName := "association.subnet-id" - subnetIdFilter := ec2.Filter{ + subnetIdFilter := types.Filter{ Name: &subnetIdFilterName, - Values: []*string{&subnetId}, + Values: []string{subnetId}, } client, err := NewEc2ClientE(t, region) @@ -307,7 +309,7 @@ func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, return false, err } - rts, err := client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&subnetIdFilter}}) + rts, err := client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{Filters: []types.Filter{subnetIdFilter}}) if err != nil { return false, err } @@ -322,7 +324,7 @@ func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, for _, rt := range rts.RouteTables { for _, r := range rt.Routes { - if strings.HasPrefix(aws.StringValue(r.GatewayId), "igw-") { + if strings.HasPrefix(aws.ToString(r.GatewayId), "igw-") { return true, nil } } @@ -341,28 +343,28 @@ func getImplicitRouteTableForSubnetE(t testing.TestingT, subnetId string, region return nil, err } - subnetFilter := ec2.Filter{ + subnetFilter := types.Filter{ Name: &subnetFilterName, - Values: []*string{&subnetId}, + Values: []string{subnetId}, } - subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&subnetFilter}}) + subnetOutput, err := client.DescribeSubnets(context.Background(), &ec2.DescribeSubnetsInput{Filters: []types.Filter{subnetFilter}}) if err != nil { return nil, err } numSubnets := len(subnetOutput.Subnets) if numSubnets != 1 { - return nil, fmt.Errorf("Expected to find one subnet with id %s but found %s", subnetId, strconv.Itoa(numSubnets)) + return nil, fmt.Errorf("expected to find one subnet with id %s but found %s", subnetId, strconv.Itoa(numSubnets)) } - mainRouteFilter := ec2.Filter{ + mainRouteFilter := types.Filter{ Name: &mainRouteFilterName, - Values: []*string{&mainRouteFilterValue}, + Values: []string{mainRouteFilterValue}, } - vpcFilter := ec2.Filter{ + vpcFilter := types.Filter{ Name: aws.String(vpcIDFilterName), - Values: []*string{subnetOutput.Subnets[0].VpcId}, + Values: []string{*subnetOutput.Subnets[0].VpcId}, } - return client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&mainRouteFilter, &vpcFilter}}) + return client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{Filters: []types.Filter{mainRouteFilter, vpcFilter}}) } // GetRandomPrivateCidrBlock gets a random CIDR block from the range of acceptable private IP addresses per RFC 1918 diff --git a/modules/aws/vpc_test.go b/modules/aws/vpc_test.go index 8a060efc8..8d04320f5 100644 --- a/modules/aws/vpc_test.go +++ b/modules/aws/vpc_test.go @@ -1,13 +1,15 @@ package aws import ( + "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) func TestGetDefaultVpc(t *testing.T) { @@ -42,8 +44,8 @@ func TestGetVpcsE(t *testing.T) { isDefaultFilterName := "isDefault" isDefaultFilterValue := "true" - defaultVpcFilter := ec2.Filter{Name: &isDefaultFilterName, Values: []*string{&isDefaultFilterValue}} - vpcs, _ := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region) + defaultVpcFilter := types.Filter{Name: &isDefaultFilterName, Values: []string{isDefaultFilterValue}} + vpcs, _ := GetVpcsE(t, []types.Filter{defaultVpcFilter}, region) require.Equal(t, len(vpcs), 1) assert.NotEmpty(t, vpcs[0].Name) @@ -164,7 +166,7 @@ func TestGetDefaultAzSubnets(t *testing.T) { vpc := GetDefaultVpc(t, region) // Note: cannot know exact list of default azs aheard of time, but we know that - //it must be greater than 0 for default vpc. + // it must be greater than 0 for default vpc. subnets := GetAzDefaultSubnetsForVpc(t, vpc.Id, region) assert.NotZero(t, len(subnets)) } @@ -172,16 +174,16 @@ func TestGetDefaultAzSubnets(t *testing.T) { func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region string) { ec2Client := NewEc2Client(t, region) - createIGWOut, igerr := ec2Client.CreateInternetGateway(&ec2.CreateInternetGatewayInput{}) + createIGWOut, igerr := ec2Client.CreateInternetGateway(context.Background(), &ec2.CreateInternetGatewayInput{}) require.NoError(t, igerr) - _, aigerr := ec2Client.AttachInternetGateway(&ec2.AttachInternetGatewayInput{ + _, aigerr := ec2Client.AttachInternetGateway(context.Background(), &ec2.AttachInternetGatewayInput{ InternetGatewayId: createIGWOut.InternetGateway.InternetGatewayId, VpcId: aws.String(vpcId), }) require.NoError(t, aigerr) - _, err := ec2Client.CreateRoute(&ec2.CreateRouteInput{ + _, err := ec2Client.CreateRoute(context.Background(), &ec2.CreateRouteInput{ RouteTableId: aws.String(routeTableId), DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: createIGWOut.InternetGateway.InternetGatewayId, @@ -190,10 +192,10 @@ func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region s require.NoError(t, err) } -func createRouteTable(t *testing.T, vpcId string, region string) ec2.RouteTable { +func createRouteTable(t *testing.T, vpcId string, region string) types.RouteTable { ec2Client := NewEc2Client(t, region) - createRouteTableOutput, err := ec2Client.CreateRouteTable(&ec2.CreateRouteTableInput{ + createRouteTableOutput, err := ec2Client.CreateRouteTable(context.Background(), &ec2.CreateRouteTableInput{ VpcId: aws.String(vpcId), }) @@ -201,16 +203,16 @@ func createRouteTable(t *testing.T, vpcId string, region string) ec2.RouteTable return *createRouteTableOutput.RouteTable } -func createSubnet(t *testing.T, vpcId string, routeTableId string, region string) ec2.Subnet { +func createSubnet(t *testing.T, vpcId string, routeTableId string, region string) types.Subnet { ec2Client := NewEc2Client(t, region) - createSubnetOutput, err := ec2Client.CreateSubnet(&ec2.CreateSubnetInput{ + createSubnetOutput, err := ec2Client.CreateSubnet(context.Background(), &ec2.CreateSubnetInput{ CidrBlock: aws.String("10.10.1.0/24"), VpcId: aws.String(vpcId), }) require.NoError(t, err) - _, err = ec2Client.AssociateRouteTable(&ec2.AssociateRouteTableInput{ + _, err = ec2Client.AssociateRouteTable(context.Background(), &ec2.AssociateRouteTableInput{ RouteTableId: aws.String(routeTableId), SubnetId: aws.String(*createSubnetOutput.Subnet.SubnetId), }) @@ -219,10 +221,10 @@ func createSubnet(t *testing.T, vpcId string, routeTableId string, region string return *createSubnetOutput.Subnet } -func createVpc(t *testing.T, region string) ec2.Vpc { +func createVpc(t *testing.T, region string) types.Vpc { ec2Client := NewEc2Client(t, region) - createVpcOutput, err := ec2Client.CreateVpc(&ec2.CreateVpcInput{ + createVpcOutput, err := ec2Client.CreateVpc(context.Background(), &ec2.CreateVpcInput{ CidrBlock: aws.String("10.10.0.0/16"), }) @@ -234,29 +236,29 @@ func deleteRouteTables(t *testing.T, vpcId string, region string) { ec2Client := NewEc2Client(t, region) vpcIDFilterName := "vpc-id" - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIDFilter := types.Filter{Name: &vpcIDFilterName, Values: []string{vpcId}} // "You can't delete the main route table." mainRTFilterName := "association.main" mainRTFilterValue := "false" - notMainRTFilter := ec2.Filter{Name: &mainRTFilterName, Values: []*string{&mainRTFilterValue}} + notMainRTFilter := types.Filter{Name: &mainRTFilterName, Values: []string{mainRTFilterValue}} - filters := []*ec2.Filter{&vpcIDFilter, ¬MainRTFilter} + filters := []types.Filter{vpcIDFilter, notMainRTFilter} - rtOutput, err := ec2Client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: filters}) + rtOutput, err := ec2Client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{Filters: filters}) require.NoError(t, err) for _, rt := range rtOutput.RouteTables { // "You must disassociate the route table from any subnets before you can delete it." for _, assoc := range rt.Associations { - _, disassocErr := ec2Client.DisassociateRouteTable(&ec2.DisassociateRouteTableInput{ + _, disassocErr := ec2Client.DisassociateRouteTable(context.Background(), &ec2.DisassociateRouteTableInput{ AssociationId: assoc.RouteTableAssociationId, }) require.NoError(t, disassocErr) } - _, err := ec2Client.DeleteRouteTable(&ec2.DeleteRouteTableInput{ + _, err := ec2Client.DeleteRouteTable(context.Background(), &ec2.DeleteRouteTableInput{ RouteTableId: rt.RouteTableId, }) require.NoError(t, err) @@ -266,13 +268,13 @@ func deleteRouteTables(t *testing.T, vpcId string, region string) { func deleteSubnets(t *testing.T, vpcId string, region string) { ec2Client := NewEc2Client(t, region) vpcIDFilterName := "vpc-id" - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIDFilter := types.Filter{Name: &vpcIDFilterName, Values: []string{vpcId}} - subnetsOutput, err := ec2Client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIDFilter}}) + subnetsOutput, err := ec2Client.DescribeSubnets(context.Background(), &ec2.DescribeSubnetsInput{Filters: []types.Filter{vpcIDFilter}}) require.NoError(t, err) for _, subnet := range subnetsOutput.Subnets { - _, err := ec2Client.DeleteSubnet(&ec2.DeleteSubnetInput{ + _, err := ec2Client.DeleteSubnet(context.Background(), &ec2.DeleteSubnetInput{ SubnetId: subnet.SubnetId, }) require.NoError(t, err) @@ -282,20 +284,20 @@ func deleteSubnets(t *testing.T, vpcId string, region string) { func deleteInternetGateways(t *testing.T, vpcId string, region string) { ec2Client := NewEc2Client(t, region) vpcIDFilterName := "attachment.vpc-id" - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIDFilter := types.Filter{Name: &vpcIDFilterName, Values: []string{vpcId}} - igwOutput, err := ec2Client.DescribeInternetGateways(&ec2.DescribeInternetGatewaysInput{Filters: []*ec2.Filter{&vpcIDFilter}}) + igwOutput, err := ec2Client.DescribeInternetGateways(context.Background(), &ec2.DescribeInternetGatewaysInput{Filters: []types.Filter{vpcIDFilter}}) require.NoError(t, err) for _, igw := range igwOutput.InternetGateways { - _, detachErr := ec2Client.DetachInternetGateway(&ec2.DetachInternetGatewayInput{ + _, detachErr := ec2Client.DetachInternetGateway(context.Background(), &ec2.DetachInternetGatewayInput{ InternetGatewayId: igw.InternetGatewayId, VpcId: aws.String(vpcId), }) require.NoError(t, detachErr) - _, err := ec2Client.DeleteInternetGateway(&ec2.DeleteInternetGatewayInput{ + _, err := ec2Client.DeleteInternetGateway(context.Background(), &ec2.DeleteInternetGatewayInput{ InternetGatewayId: igw.InternetGatewayId, }) require.NoError(t, err) @@ -309,7 +311,7 @@ func deleteVpc(t *testing.T, vpcId string, region string) { deleteSubnets(t, vpcId, region) deleteInternetGateways(t, vpcId, region) - _, err := ec2Client.DeleteVpc(&ec2.DeleteVpcInput{ + _, err := ec2Client.DeleteVpc(context.Background(), &ec2.DeleteVpcInput{ VpcId: aws.String(vpcId), }) require.NoError(t, err) diff --git a/test/packer_basic_example_test.go b/test/packer_basic_example_test.go index 2952c941b..8ba867c8f 100644 --- a/test/packer_basic_example_test.go +++ b/test/packer_basic_example_test.go @@ -1,13 +1,15 @@ package test import ( + "context" "fmt" "os" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" terratest_aws "github.com/gruntwork-io/terratest/modules/aws" "github.com/gruntwork-io/terratest/modules/packer" "github.com/gruntwork-io/terratest/modules/random" @@ -76,10 +78,9 @@ func TestPackerBasicExample(t *testing.T) { assert.Contains(t, accountsWithLaunchPermissions, requestingAccount) // website::tag::3::Check AMI's properties. - // Check if AMI is public - MakeAmiPublic(t, amiID, ec2Client) + // Check if AMI is private amiIsPublic := terratest_aws.GetAmiPubliclyAccessible(t, awsRegion, amiID) - assert.True(t, amiIsPublic) + assert.False(t, amiIsPublic) } // An example of how to test the Packer template in examples/packer-basic-example using Terratest @@ -110,7 +111,7 @@ func TestPackerBasicExampleWithVarFile(t *testing.T) { // The path to where the Packer template is located Template: "../examples/packer-basic-example/build.pkr.hcl", - // Variable file to to pass to our Packer build using -var-file option + // Variable file to pass to our Packer build using -var-file option VarFiles: []string{ varFile.Name(), }, @@ -144,10 +145,9 @@ func TestPackerBasicExampleWithVarFile(t *testing.T) { assert.NotContains(t, accountsWithLaunchPermissions, randomAccount) assert.Contains(t, accountsWithLaunchPermissions, requestingAccount) - // Check if AMI is public - MakeAmiPublic(t, amiID, ec2Client) + // Check if AMI is private amiIsPublic := terratest_aws.GetAmiPubliclyAccessible(t, awsRegion, amiID) - assert.True(t, amiIsPublic) + assert.False(t, amiIsPublic) } func TestPackerMultipleConcurrentAmis(t *testing.T) { @@ -196,35 +196,18 @@ func TestPackerMultipleConcurrentAmis(t *testing.T) { } } -func ShareAmi(t *testing.T, amiID string, accountID string, ec2Client *ec2.EC2) { +func ShareAmi(t *testing.T, amiID string, accountID string, ec2Client *ec2.Client) { input := &ec2.ModifyImageAttributeInput{ ImageId: aws.String(amiID), - LaunchPermission: &ec2.LaunchPermissionModifications{ - Add: []*ec2.LaunchPermission{ + LaunchPermission: &types.LaunchPermissionModifications{ + Add: []types.LaunchPermission{ { UserId: aws.String(accountID), }, }, }, } - _, err := ec2Client.ModifyImageAttribute(input) - if err != nil { - t.Fatal(err) - } -} - -func MakeAmiPublic(t *testing.T, amiID string, ec2Client *ec2.EC2) { - input := &ec2.ModifyImageAttributeInput{ - ImageId: aws.String(amiID), - LaunchPermission: &ec2.LaunchPermissionModifications{ - Add: []*ec2.LaunchPermission{ - { - Group: aws.String("all"), - }, - }, - }, - } - _, err := ec2Client.ModifyImageAttribute(input) + _, err := ec2Client.ModifyImageAttribute(context.Background(), input) if err != nil { t.Fatal(err) } diff --git a/test/terraform_aws_dynamodb_example_test.go b/test/terraform_aws_dynamodb_example_test.go index c3b5ff3d9..5d97fab84 100644 --- a/test/terraform_aws_dynamodb_example_test.go +++ b/test/terraform_aws_dynamodb_example_test.go @@ -4,8 +4,8 @@ import ( "fmt" "testing" - awsSDK "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + awsSDK "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/gruntwork-io/terratest/modules/aws" "github.com/gruntwork-io/terratest/modules/random" "github.com/gruntwork-io/terratest/modules/terraform" @@ -22,11 +22,11 @@ func TestTerraformAwsDynamoDBExample(t *testing.T) { // Set up expected values to be checked later expectedTableName := fmt.Sprintf("terratest-aws-dynamodb-example-table-%s", random.UniqueId()) expectedKmsKeyArn := aws.GetCmkArn(t, awsRegion, "alias/aws/dynamodb") - expectedKeySchema := []*dynamodb.KeySchemaElement{ - {AttributeName: awsSDK.String("userId"), KeyType: awsSDK.String("HASH")}, - {AttributeName: awsSDK.String("department"), KeyType: awsSDK.String("RANGE")}, + expectedKeySchema := []types.KeySchemaElement{ + {AttributeName: awsSDK.String("userId"), KeyType: types.KeyTypeHash}, + {AttributeName: awsSDK.String("department"), KeyType: types.KeyTypeRange}, } - expectedTags := []*dynamodb.Tag{ + expectedTags := []types.Tag{ {Key: awsSDK.String("Environment"), Value: awsSDK.String("production")}, } @@ -52,18 +52,18 @@ func TestTerraformAwsDynamoDBExample(t *testing.T) { // Look up the DynamoDB table by name table := aws.GetDynamoDBTable(t, awsRegion, expectedTableName) - assert.Equal(t, "ACTIVE", awsSDK.StringValue(table.TableStatus)) + assert.Equal(t, "ACTIVE", string(table.TableStatus)) assert.ElementsMatch(t, expectedKeySchema, table.KeySchema) // Verify server-side encryption configuration - assert.Equal(t, expectedKmsKeyArn, awsSDK.StringValue(table.SSEDescription.KMSMasterKeyArn)) - assert.Equal(t, "ENABLED", awsSDK.StringValue(table.SSEDescription.Status)) - assert.Equal(t, "KMS", awsSDK.StringValue(table.SSEDescription.SSEType)) + assert.Equal(t, expectedKmsKeyArn, awsSDK.ToString(table.SSEDescription.KMSMasterKeyArn)) + assert.Equal(t, "ENABLED", string(table.SSEDescription.Status)) + assert.Equal(t, "KMS", string(table.SSEDescription.SSEType)) // Verify TTL configuration ttl := aws.GetDynamoDBTableTimeToLive(t, awsRegion, expectedTableName) - assert.Equal(t, "expires", awsSDK.StringValue(ttl.AttributeName)) - assert.Equal(t, "ENABLED", awsSDK.StringValue(ttl.TimeToLiveStatus)) + assert.Equal(t, "expires", awsSDK.ToString(ttl.AttributeName)) + assert.Equal(t, "ENABLED", string(ttl.TimeToLiveStatus)) // Verify resource tags tags := aws.GetDynamoDbTableTags(t, awsRegion, expectedTableName) diff --git a/test/terraform_aws_ecs_example_test.go b/test/terraform_aws_ecs_example_test.go index c2bba03d1..6b195526c 100644 --- a/test/terraform_aws_ecs_example_test.go +++ b/test/terraform_aws_ecs_example_test.go @@ -4,11 +4,12 @@ import ( "fmt" "testing" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/terratest/modules/aws" "github.com/gruntwork-io/terratest/modules/random" "github.com/gruntwork-io/terratest/modules/terraform" - awsSDK "github.com/aws/aws-sdk-go/aws" + awsSDK "github.com/aws/aws-sdk-go-v2/aws" "github.com/stretchr/testify/assert" ) @@ -48,18 +49,18 @@ func TestTerraformAwsEcsExample(t *testing.T) { // Look up the ECS cluster by name cluster := aws.GetEcsCluster(t, awsRegion, expectedClusterName) - assert.Equal(t, int64(1), awsSDK.Int64Value(cluster.ActiveServicesCount)) + assert.Equal(t, int32(1), cluster.ActiveServicesCount) // Look up the ECS service by name service := aws.GetEcsService(t, awsRegion, expectedClusterName, expectedServiceName) - assert.Equal(t, int64(0), awsSDK.Int64Value(service.DesiredCount)) - assert.Equal(t, "FARGATE", awsSDK.StringValue(service.LaunchType)) + assert.Equal(t, int32(0), service.DesiredCount) + assert.Equal(t, types.LaunchTypeFargate, service.LaunchType) // Look up the ECS task definition by ARN task := aws.GetEcsTaskDefinition(t, awsRegion, taskDefinition) - assert.Equal(t, "256", awsSDK.StringValue(task.Cpu)) - assert.Equal(t, "512", awsSDK.StringValue(task.Memory)) - assert.Equal(t, "awsvpc", awsSDK.StringValue(task.NetworkMode)) + assert.Equal(t, "256", awsSDK.ToString(task.Cpu)) + assert.Equal(t, "512", awsSDK.ToString(task.Memory)) + assert.Equal(t, types.NetworkModeAwsvpc, task.NetworkMode) } diff --git a/test/terraform_aws_lambda_example_test.go b/test/terraform_aws_lambda_example_test.go index 08c510166..5de0f8a77 100644 --- a/test/terraform_aws_lambda_example_test.go +++ b/test/terraform_aws_lambda_example_test.go @@ -136,7 +136,7 @@ func TestTerraformAwsLambdaWithParamsExample(t *testing.T) { // With "DryRun", there's no message in the output, but there is // a status code which will have a value of 204 for a successful // invocation. - assert.Equal(t, int(*out.StatusCode), 204) + assert.Equal(t, int(out.StatusCode), 204) // Invoke the function, this time causing the Lambda to error and // capturing the error. diff --git a/test/terraform_aws_rds_example_test.go b/test/terraform_aws_rds_example_test.go index c69e9f92f..efac235ea 100644 --- a/test/terraform_aws_rds_example_test.go +++ b/test/terraform_aws_rds_example_test.go @@ -20,7 +20,7 @@ func TestTerraformAwsRdsExample(t *testing.T) { majorEngineVersion string engineFamily string licenseModel string - schemaCheck func(t *testing.T, dbUrl string, dbPort int64, dbUsername string, dbPassword string, expectedSchemaName string) bool + schemaCheck func(t *testing.T, dbUrl string, dbPort int32, dbUsername string, dbPassword string, expectedSchemaName string) bool expectedOptins map[struct { opName string setName string @@ -33,7 +33,7 @@ func TestTerraformAwsRdsExample(t *testing.T) { majorEngineVersion: "5.7", engineFamily: "mysql5.7", licenseModel: "general-public-license", - schemaCheck: func(t *testing.T, dbUrl string, dbPort int64, dbUsername, dbPassword, expectedSchemaName string) bool { + schemaCheck: func(t *testing.T, dbUrl string, dbPort int32, dbUsername, dbPassword, expectedSchemaName string) bool { return aws.GetWhetherSchemaExistsInRdsMySqlInstance(t, dbUrl, dbPort, dbUsername, dbPassword, expectedSchemaName) }, expectedOptins: map[struct { @@ -53,7 +53,7 @@ func TestTerraformAwsRdsExample(t *testing.T) { majorEngineVersion: "13", engineFamily: "postgres13", licenseModel: "postgresql-license", - schemaCheck: func(t *testing.T, dbUrl string, dbPort int64, dbUsername, dbPassword, expectedSchemaName string) bool { + schemaCheck: func(t *testing.T, dbUrl string, dbPort int32, dbUsername, dbPassword, expectedSchemaName string) bool { return aws.GetWhetherSchemaExistsInRdsPostgresInstance(t, dbUrl, dbPort, dbUsername, dbPassword, expectedSchemaName) }, }, @@ -67,7 +67,7 @@ func TestTerraformAwsRdsExample(t *testing.T) { // Give this RDS Instance a unique ID for a name tag so we can distinguish it from any other RDS Instance running // in your AWS account expectedName := fmt.Sprintf("terratest-aws-rds-example-%s", strings.ToLower(random.UniqueId())) - expectedPort := int64(3306) + expectedPort := int32(3306) expectedDatabaseName := "terratest" username := "username" password := "password" diff --git a/test/terraform_packer_example_test.go b/test/terraform_packer_example_test.go index 3afd72162..081257935 100644 --- a/test/terraform_packer_example_test.go +++ b/test/terraform_packer_example_test.go @@ -8,7 +8,6 @@ import ( "github.com/gruntwork-io/terratest/modules/aws" http_helper "github.com/gruntwork-io/terratest/modules/http-helper" - "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/packer" "github.com/gruntwork-io/terratest/modules/random" "github.com/gruntwork-io/terratest/modules/terraform" @@ -37,13 +36,6 @@ func TestTerraformPackerExample(t *testing.T) { undeployUsingTerraform(t, workingDir) }) - // At the end of the test, fetch the most recent syslog entries from each Instance. This can be useful for - // debugging issues without having to manually SSH to the server. - defer test_structure.RunTestStage(t, "logs", func() { - awsRegion := test_structure.LoadString(t, workingDir, "awsRegion") - fetchSyslogForInstance(t, awsRegion, workingDir) - }) - // Build the AMI for the web app test_structure.RunTestStage(t, "build_ami", func() { // Pick a random AWS region to test in. This helps ensure your code works in all regions. @@ -156,18 +148,6 @@ func undeployUsingTerraform(t *testing.T, workingDir string) { terraform.Destroy(t, terraformOptions) } -// Fetch the most recent syslogs for the instance. This is a handy way to see what happened on the Instance as part of -// your test log output, without having to re-run the test and manually SSH to the Instance. -func fetchSyslogForInstance(t *testing.T, awsRegion string, workingDir string) { - // Load the Terraform Options saved by the earlier deploy_terraform stage - terraformOptions := test_structure.LoadTerraformOptions(t, workingDir) - - instanceID := terraform.OutputRequired(t, terraformOptions, "instance_id") - logs := aws.GetSyslogForInstance(t, instanceID, awsRegion) - - logger.Logf(t, "Most recent syslog for Instance %s:\n\n%s\n", instanceID, logs) -} - // Validate the web server has been deployed and is working func validateInstanceRunningWebServer(t *testing.T, workingDir string) { // Load the Terraform Options saved by the earlier deploy_terraform stage