Skip to content

Commit

Permalink
feat: updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
1101-1 committed Sep 17, 2024
1 parent 9c561f0 commit 563b504
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 45 deletions.
70 changes: 35 additions & 35 deletions plugins/aws/fix_plugin_aws/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,41 +74,41 @@
+ waf.resources
)
regional_resources: List[Type[AwsResource]] = (
# sagemaker.resources # start with sagemaker, because it is very slow
# + acm.resources
# + apigateway.resources
# + autoscaling.resources
# + athena.resources
# + config.resources
# + cloudformation.resources
# + cloudtrail.resources
# + cloudwatch.resources
# + cognito.resources
# + dynamodb.resources
# + ec2.resources
# + efs.resources
# + ecs.resources
# + ecr.resources
# + eks.resources
# + elasticbeanstalk.resources
# + elasticache.resources
# + elb.resources
# + elbv2.resources
# + glacier.resources
# + kinesis.resources
# + kms.resources
# + lambda_.resources
# + opensearch.resources
# + rds.resources
# + secretsmanager.resources
# + service_quotas.resources
# + sns.resources
# + ssm.resources
# + sqs.resources
# + redshift.resources
# + backup.resources
# + amazonq.resources
bedrock.resources
sagemaker.resources # start with sagemaker, because it is very slow
+ acm.resources
+ apigateway.resources
+ autoscaling.resources
+ athena.resources
+ config.resources
+ cloudformation.resources
+ cloudtrail.resources
+ cloudwatch.resources
+ cognito.resources
+ dynamodb.resources
+ ec2.resources
+ efs.resources
+ ecs.resources
+ ecr.resources
+ eks.resources
+ elasticbeanstalk.resources
+ elasticache.resources
+ elb.resources
+ elbv2.resources
+ glacier.resources
+ kinesis.resources
+ kms.resources
+ lambda_.resources
+ opensearch.resources
+ rds.resources
+ secretsmanager.resources
+ service_quotas.resources
+ sns.resources
+ ssm.resources
+ sqs.resources
+ redshift.resources
+ backup.resources
+ amazonq.resources
+ bedrock.resources
)
all_resources: List[Type[AwsResource]] = global_resources + regional_resources

Expand Down
17 changes: 16 additions & 1 deletion plugins/aws/fix_plugin_aws/resource/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def service_name(cls) -> str:
@define(eq=False, slots=False)
class AwsBedrockFoundationModel(BaseAIModel, AwsResource):
kind: ClassVar[str] = "aws_bedrock_foundation_model"
kind_display: ClassVar[str] = "AWS Bedrock Foundation Model"
kind_description: ClassVar[str] = (
"AWS Bedrock Foundation Model represents the base machine learning models provided by AWS Bedrock. "
"These models are pre-trained and can be used as a starting point for various machine learning tasks."
)
kind_service: ClassVar[Optional[str]] = service_name
aws_metadata: ClassVar[Dict[str, Any]] = {"provider_link_tpl": "https://{region_id}.console.aws.amazon.com/bedrock/home?region={region_id}#/providers?model={id}"} # fmt: skip
metadata: ClassVar[Dict[str, Any]] = {"icon": "resource", "group": "ai"}
api_spec: ClassVar[AwsApiSpec] = AwsApiSpec("bedrock", "list-foundation-models", "modelSummaries")
mapping: ClassVar[Dict[str, Bender]] = {
"id": S("modelId"),
Expand Down Expand Up @@ -133,6 +141,7 @@ class AwsBedrockCustomModel(BedrockTaggable, BaseAIModel, AwsResource):
metadata: ClassVar[Dict[str, Any]] = {"icon": "resource", "group": "ai"}
reference_kinds: ClassVar[ModelReference] = {
"successors": {"default": ["aws_bedrock_model_customization_job", AwsKmsKey.kind]},
"predecessors": {"default": [AwsBedrockFoundationModel.kind]},
}
api_spec: ClassVar[AwsApiSpec] = AwsApiSpec("bedrock", "list-custom-models", "modelSummaries")
mapping: ClassVar[Dict[str, Bender]] = {
Expand Down Expand Up @@ -173,6 +182,8 @@ class AwsBedrockCustomModel(BedrockTaggable, BaseAIModel, AwsResource):
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if job_arn := self.job_arn:
builder.add_edge(self, clazz=AwsBedrockModelCustomizationJob, id=job_arn)
if base_model_arn := self.base_model_arn:
builder.add_edge(self, reverse=True, clazz=AwsBedrockFoundationModel, arn=base_model_arn)
if model_kms_key_arn := self.model_kms_key_arn:
builder.add_edge(self, clazz=AwsKmsKey, arn=model_kms_key_arn)

Expand Down Expand Up @@ -529,7 +540,9 @@ class AwsBedrockModelCustomizationJob(BedrockTaggable, BaseAIJob, AwsResource):
aws_metadata: ClassVar[Dict[str, Any]] = {"provider_link_tpl": "https://{region_id}.console.aws.amazon.com/bedrock/home?region={region_id}#/custom-models/item/?arn={arn}"} # fmt: skip
metadata: ClassVar[Dict[str, Any]] = {"icon": "job", "group": "ai"}
reference_kinds: ClassVar[ModelReference] = {
"predecessors": {"default": [AwsEc2Subnet.kind, AwsEc2SecurityGroup.kind, AwsIamRole.kind]},
"predecessors": {
"default": [AwsEc2Subnet.kind, AwsEc2SecurityGroup.kind, AwsIamRole.kind, AwsBedrockFoundationModel.kind]
},
"successors": {"default": [AwsKmsKey.kind, AwsS3Bucket.kind]},
}
api_spec: ClassVar[AwsApiSpec] = AwsApiSpec(
Expand Down Expand Up @@ -591,6 +604,8 @@ class AwsBedrockModelCustomizationJob(BedrockTaggable, BaseAIJob, AwsResource):
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if role_arn := self.role_arn:
builder.add_edge(self, reverse=True, clazz=AwsIamRole, arn=role_arn)
if base_model_arn := self.base_model_arn:
builder.add_edge(self, reverse=True, clazz=AwsBedrockFoundationModel, arn=base_model_arn)
if model_kms_key_arn := self.output_model_kms_key_arn:
builder.add_edge(self, clazz=AwsKmsKey, arn=model_kms_key_arn)
if output_data_config := self.output_data_config:
Expand Down
4 changes: 2 additions & 2 deletions plugins/aws/test/collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def count_kind(clazz: Type[AwsResource]) -> int:
# make sure all threads have been joined
assert len(threading.enumerate()) == 1
# ensure the correct number of nodes and edges
assert count_kind(AwsResource) == 260
assert len(account_collector.graph.edges) == 576
assert count_kind(AwsResource) == 261
assert len(account_collector.graph.edges) == 579
assert len(account_collector.graph.deferred_edges) == 2
for node in account_collector.graph.nodes:
if isinstance(node, AwsRegion):
Expand Down
5 changes: 5 additions & 0 deletions plugins/aws/test/resources/bedrock_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AwsBedrockAgent,
AwsBedrockAgentPrompt,
AwsBedrockAgentFlow,
AwsBedrockFoundationModel,
)
from test.resources import round_trip_for

Expand Down Expand Up @@ -41,3 +42,7 @@ def test_bedrock_agent_prompts() -> None:

def test_bedrock_agent_flows() -> None:
round_trip_for(AwsBedrockAgentFlow, ignore_checking_props=True)


def test_bedrock_foundation_model() -> None:
round_trip_for(AwsBedrockFoundationModel, ignore_checking_props=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"modelSummaries": [
{
"modelArn": "foo",
"modelId": "foo",
"modelName": "foo",
"providerName": "foo",
"inputModalities": [
"IMAGE",
"IMAGE",
"IMAGE"
],
"outputModalities": [
"IMAGE",
"IMAGE",
"IMAGE"
],
"responseStreamingSupported": true,
"customizationsSupported": [
"CONTINUED_PRE_TRAINING",
"CONTINUED_PRE_TRAINING",
"CONTINUED_PRE_TRAINING"
],
"inferenceTypesSupported": [
"PROVISIONED",
"PROVISIONED",
"PROVISIONED"
],
"modelLifecycle": {
"status": "LEGACY"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"description": "foo",
"knowledgeBaseId": "foo",
"name": "foo",
"status": "ACTIVE",
"updatedAt": "2024-09-17T12:11:47Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"knowledgeBaseSummaries": [
{
"description": "foo",
"knowledgeBaseId": "foo",
"name": "foo",
"status": "ACTIVE",
"updatedAt": "2024-09-17T12:11:47Z"
}
],
"nextToken": "foo"
}
14 changes: 7 additions & 7 deletions plugins/aws/tools/aws_model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,12 +968,12 @@ def default_imports() -> str:
# ),
],
"bedrock": [
AwsFixModel(
api_action="list-foundation-models",
result_property="modelSummaries",
result_shape="ListFoundationModelsResponse",
prefix="Bedrock",
)
# AwsFixModel(
# api_action="list-foundation-models",
# result_property="modelSummaries",
# result_shape="ListFoundationModelsResponse",
# prefix="Bedrock",
# )
],
"bedrock-agent": [
# AwsFixModel(
Expand All @@ -988,7 +988,7 @@ def default_imports() -> str:

if __name__ == "__main__":
"""print some test data"""
print(json.dumps(create_test_response("bedrock", "list-foundation-models"), indent=2))
print(json.dumps(create_test_response("bedrock-agent", "get-knowledge-base"), indent=2))

"""print the class models"""
# print(default_imports())
Expand Down

0 comments on commit 563b504

Please sign in to comment.