Skip to content

Commit

Permalink
chore: when fast_tryout_enabled is set, auto endpoint creation will c…
Browse files Browse the repository at this point in the history
…reate dedicated endpoint

PiperOrigin-RevId: 699249704
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 22, 2024
1 parent a56e4dd commit 1487846
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
10 changes: 10 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5368,6 +5368,13 @@ def deploy(
system_labels=system_labels,
)

def _should_enable_dedicated_endpoint(self, fast_tryout_enabled: bool) -> bool:
    """Decide whether the endpoint should be created as a dedicated endpoint.

    Args:
        fast_tryout_enabled (bool):
            The fast tryout flag supplied for the deployment.

    Returns:
        bool: The value of ``fast_tryout_enabled``, passed through unchanged;
        True means the endpoint should be a dedicated endpoint.
    """
    # The decision currently depends on the fast tryout flag alone.
    use_dedicated_endpoint = fast_tryout_enabled
    return use_dedicated_endpoint

@base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False)
def _deploy(
self,
Expand Down Expand Up @@ -5548,6 +5555,9 @@ def _deploy(
location=self.location,
credentials=self.credentials,
encryption_spec_key_name=encryption_spec_key_name,
dedicated_endpoint_enabled=self._should_enable_dedicated_endpoint(
fast_tryout_enabled
),
)
else:
endpoint = PrivateEndpoint.create(
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/aiplatform/preview/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,13 @@ def deploy(
system_labels=system_labels,
)

def _should_enable_dedicated_endpoint(self, fast_tryout_enabled: bool) -> bool:
    """Decide whether the endpoint should be created as a dedicated endpoint.

    Args:
        fast_tryout_enabled (bool):
            The fast tryout flag supplied for the deployment.

    Returns:
        bool: The value of ``fast_tryout_enabled``, passed through unchanged;
        True means the endpoint should be a dedicated endpoint.
    """
    # The decision currently depends on the fast tryout flag alone.
    use_dedicated_endpoint = fast_tryout_enabled
    return use_dedicated_endpoint

@base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False)
def _deploy(
self,
Expand Down Expand Up @@ -1689,6 +1696,9 @@ def _deploy(
location=self.location,
credentials=self.credentials,
encryption_spec_key_name=encryption_spec_key_name,
dedicated_endpoint_enabled=self._should_enable_dedicated_endpoint(
fast_tryout_enabled
),
)
else:
endpoint = models.PrivateEndpoint.create(
Expand Down
46 changes: 44 additions & 2 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,20 @@ def get_endpoint_mock():
yield get_endpoint_mock


@pytest.fixture
def create_endpoint_mock():
    """Patch ``EndpointServiceClient.create_endpoint`` and yield the patched mock.

    The patched method returns a mocked operation whose ``result()`` is an
    ``Endpoint`` carrying the test endpoint name and display name.
    """
    endpoint_patch = mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "create_endpoint"
    )
    with endpoint_patch as patched_create:
        # Stand-in for the operation object returned by create_endpoint.
        operation_stub = mock.Mock(ga_operation.Operation)
        created_endpoint = gca_endpoint.Endpoint(
            name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
            display_name=test_constants.EndpointConstants._TEST_DISPLAY_NAME,
        )
        operation_stub.result.return_value = created_endpoint
        patched_create.return_value = operation_stub
        yield patched_create


@pytest.fixture
def get_model_mock():
with mock.patch.object(
Expand Down Expand Up @@ -2531,7 +2545,7 @@ def test_deploy_disable_container_logging(self, deploy_model_mock, sync):
)
@pytest.mark.parametrize("sync", [True, False])
def test_preview_deploy_with_fast_tryout_enabled(
self, preview_deploy_model_mock, sync
self, preview_deploy_model_mock, create_endpoint_mock, sync
):
test_model = models.Model(_TEST_ID).preview
test_model._gca_resource.supported_deployment_resources_types.append(
Expand All @@ -2551,6 +2565,19 @@ def test_preview_deploy_with_fast_tryout_enabled(
if not sync:
test_endpoint.wait()

expected_endpoint = gca_endpoint.Endpoint(
display_name=_TEST_MODEL_NAME + "_endpoint",
dedicated_endpoint_enabled=True,
)

create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
metadata=(),
timeout=None,
endpoint_id=None,
)

expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec(
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
Expand Down Expand Up @@ -2583,7 +2610,9 @@ def test_preview_deploy_with_fast_tryout_enabled(
"get_endpoint_mock",
)
@pytest.mark.parametrize("sync", [True, False])
def test_deploy_with_fast_tryout_enabled(self, deploy_model_mock, sync):
def test_deploy_with_fast_tryout_enabled(
self, deploy_model_mock, create_endpoint_mock, sync
):
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
Expand All @@ -2602,6 +2631,19 @@ def test_deploy_with_fast_tryout_enabled(self, deploy_model_mock, sync):
if not sync:
test_endpoint.wait()

expected_endpoint = gca_endpoint.Endpoint(
display_name=_TEST_MODEL_NAME + "_endpoint",
dedicated_endpoint_enabled=True,
)

create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
metadata=(),
timeout=None,
endpoint_id=None,
)

expected_machine_spec = gca_machine_resources.MachineSpec(
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
Expand Down

0 comments on commit 1487846

Please sign in to comment.