Skip to content

Commit

Permalink
Fix test for upi standard transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
tiopramayudi committed Mar 5, 2024
1 parent 09a8e71 commit b6a7b99
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions python/sdk/test/pyfunc_upi_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import xgboost as xgb
from caraml.upi.utils import df_to_table, table_to_df
from caraml.upi.v1 import type_pb2, upi_pb2, upi_pb2_grpc, variable_pb2
from google.protobuf.json_format import MessageToDict, ParseDict
from merlin.deployment_mode import DeploymentMode
from merlin.endpoint import Status
from merlin.model import ModelType, PyFuncModel
Expand Down Expand Up @@ -63,7 +64,13 @@ class SimpleForwarder(PyFuncModel):
target_name = "probability"

def infer(self, request: dict, **kwargs):
return request
upi_request = upi_pb2.PredictValuesRequest()
ParseDict(request, upi_request)
upi_response = upi_pb2.PredictValuesResponse(
prediction_result_table=upi_request.prediction_table,
target_name=self.target_name,
)
return MessageToDict(upi_response)

def upiv1_infer(
self, request: upi_pb2.PredictValuesRequest, context: grpc.ServicerContext
Expand Down Expand Up @@ -107,12 +114,14 @@ def test_deploy(integration_test_url, project_name, use_google_oauth, requests):

@pytest.mark.pyfunc
@pytest.mark.integration
@pytest.mark.parametrize("upi_http_predictor", [False, True])
def test_pyfunc_with_standard_transformer(
integration_test_url, project_name, use_google_oauth, requests
integration_test_url, project_name, upi_http_predictor, use_google_oauth, requests
):
merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
merlin.set_project(project_name)
merlin.set_model("pyfunc-upi-std", ModelType.PYFUNC)
model_name = "pyfunc-upi-std-http" if upi_http_predictor else "pyfunc-upi-std"
merlin.set_model(model_name, ModelType.PYFUNC)

undeploy_all_version()
with merlin.new_model_version() as v:
Expand All @@ -125,10 +134,12 @@ def test_pyfunc_with_standard_transformer(
transformer_config_path = os.path.join(
"test/transformer", "upi_standard_transformer_no_feast.yaml"
)

upi_http_enabled_val = "true" if upi_http_predictor else "false"
transformer = StandardTransformer(
config_file=transformer_config_path,
enabled=True,
env_vars={"PREDICTOR_UPI_HTTP_ENABLED": "true"},
env_vars={"PREDICTOR_UPI_HTTP_ENABLED": upi_http_enabled_val},
)
endpoint = merlin.deploy(v, transformer=transformer, protocol=Protocol.UPI_V1)

Expand Down

0 comments on commit b6a7b99

Please sign in to comment.