diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index e79e308693..bbc7d9ab72 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -5368,6 +5368,13 @@ def deploy( system_labels=system_labels, ) + def _should_enable_dedicated_endpoint(self, fast_tryout_enabled: bool) -> bool: + """Check if dedicated endpoint should be enabled for this endpoint. + + Returns True if endpoint should be a dedicated endpoint. + """ + return fast_tryout_enabled + @base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False) def _deploy( self, @@ -5548,6 +5555,9 @@ def _deploy( location=self.location, credentials=self.credentials, encryption_spec_key_name=encryption_spec_key_name, + dedicated_endpoint_enabled=self._should_enable_dedicated_endpoint( + fast_tryout_enabled + ), ) else: endpoint = PrivateEndpoint.create( diff --git a/google/cloud/aiplatform/preview/models.py b/google/cloud/aiplatform/preview/models.py index 1c6b6bc972..11837cd854 100644 --- a/google/cloud/aiplatform/preview/models.py +++ b/google/cloud/aiplatform/preview/models.py @@ -1548,6 +1548,13 @@ def deploy( system_labels=system_labels, ) + def _should_enable_dedicated_endpoint(self, fast_tryout_enabled: bool) -> bool: + """Check if dedicated endpoint should be enabled for this endpoint. + + Returns True if endpoint should be a dedicated endpoint. + """ + return fast_tryout_enabled + @base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False) def _deploy( self, @@ -1689,6 +1696,9 @@ def _deploy( location=self.location, credentials=self.credentials, encryption_spec_key_name=encryption_spec_key_name, + dedicated_endpoint_enabled=self._should_enable_dedicated_endpoint( + fast_tryout_enabled + ), ) else: endpoint = models.PrivateEndpoint.create( diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index e5ab109d42..edde450308 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -556,6 +556,20 @@ def get_endpoint_mock(): yield get_endpoint_mock +@pytest.fixture +def create_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "create_endpoint" + ) as create_endpoint_mock: + create_endpoint_lro_mock = mock.Mock(ga_operation.Operation) + create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint( + name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME, + display_name=test_constants.EndpointConstants._TEST_DISPLAY_NAME, + ) + create_endpoint_mock.return_value = create_endpoint_lro_mock + yield create_endpoint_mock + + @pytest.fixture def get_model_mock(): with mock.patch.object( @@ -2531,7 +2545,7 @@ def test_deploy_disable_container_logging(self, deploy_model_mock, sync): ) @pytest.mark.parametrize("sync", [True, False]) def test_preview_deploy_with_fast_tryout_enabled( - self, preview_deploy_model_mock, sync + self, preview_deploy_model_mock, create_endpoint_mock, sync ): test_model = models.Model(_TEST_ID).preview test_model._gca_resource.supported_deployment_resources_types.append( @@ -2551,6 +2565,19 @@ def test_preview_deploy_with_fast_tryout_enabled( if not sync: test_endpoint.wait() + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_MODEL_NAME + "_endpoint", + dedicated_endpoint_enabled=True, + ) + + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + endpoint=expected_endpoint, + metadata=(), + timeout=None, + endpoint_id=None, + ) + expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE, @@ -2583,7 +2610,9 @@ def test_preview_deploy_with_fast_tryout_enabled( "get_endpoint_mock", ) @pytest.mark.parametrize("sync", [True, False]) - def test_deploy_with_fast_tryout_enabled(self, deploy_model_mock, sync): + def test_deploy_with_fast_tryout_enabled( + self, deploy_model_mock, create_endpoint_mock, sync + ): test_model = models.Model(_TEST_ID) test_model._gca_resource.supported_deployment_resources_types.append( aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES @@ -2602,6 +2631,19 @@ def test_deploy_with_fast_tryout_enabled(self, deploy_model_mock, sync): if not sync: test_endpoint.wait() + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_MODEL_NAME + "_endpoint", + dedicated_endpoint_enabled=True, + ) + + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + endpoint=expected_endpoint, + metadata=(), + timeout=None, + endpoint_id=None, + ) + expected_machine_spec = gca_machine_resources.MachineSpec( machine_type=_TEST_MACHINE_TYPE, accelerator_type=_TEST_ACCELERATOR_TYPE,