Close all sessions and free resources when a test completes
Currently pytest hangs if a test fails because resources are not cleaned up. Use an autouse fixture so that cleanup runs during test teardown.
thodkatz committed Sep 9, 2024
1 parent cdb2c4a commit aac6768
Showing 4 changed files with 46 additions and 33 deletions.
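
The fix relies on pytest's yield-fixture teardown semantics: code after the yield in a fixture runs when the test finishes, whether it passed or failed. A minimal sketch of the pattern, with a hypothetical Resource class standing in for the servicer's sessions:

import pytest

class Resource:
    """Stand-in for a leased resource such as a model session."""
    def __init__(self):
        self.open = True

    def close(self):
        self.open = False

@pytest.fixture(autouse=True)
def resource():
    res = Resource()
    yield res  # the test body runs at this point
    # Teardown: runs even when the test raised, so the resource is
    # always released and pytest cannot hang on lingering state.
    res.close()

Because the fixture is marked autouse=True, every test in the module gets this setup and teardown without requesting the fixture explicitly.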
33 changes: 10 additions & 23 deletions tests/test_server/test_grpc/test_inference_servicer.py
@@ -27,6 +27,12 @@ def grpc_servicer(data_store):
return inference_servicer.InferenceServicer(TorchDevicePool(), SessionManager(), data_store)


@pytest.fixture(autouse=True)
def clean(grpc_servicer):
yield
grpc_servicer.close_all_sessions()


@pytest.fixture(scope="module")
def grpc_stub_cls(grpc_channel):
return inference_pb2_grpc.InferenceStub
@@ -47,15 +53,13 @@ def method_requiring_session(self, request, grpc_stub):
def test_model_session_creation(self, grpc_stub, bioimageio_model_bytes):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes))
assert model.id
grpc_stub.CloseModelSession(model)

def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_explicit_model_bytes):
id_ = data_store.put(bioimageio_dummy_explicit_model_bytes.getvalue())

rq = inference_pb2.CreateModelSessionRequest(model_uri=f"upload://{id_}", deviceIds=["cpu"])
model = grpc_stub.CreateModelSession(rq)
assert model.id
grpc_stub.CloseModelSession(model)

def test_model_session_creation_using_random_uri(self, grpc_stub):
rq = inference_pb2.CreateModelSessionRequest(model_uri="randomSchema://", deviceIds=["cpu"])
@@ -92,36 +96,28 @@ def test_if_model_create_fails_devices_are_released(self, grpc_stub):
model_blob=inference_pb2.Blob(content=b""), deviceIds=["cpu"]
)

model = None
with pytest.raises(Exception):
model = grpc_stub.CreateModelSession(model_req)
grpc_stub.CreateModelSession(model_req)

device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status

if model:
grpc_stub.CloseModelSession(model)

def test_use_device(self, grpc_stub, bioimageio_model_bytes):
device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status

model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))
grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))

device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status

grpc_stub.CloseModelSession(model)

def test_using_same_device_fails(self, grpc_stub, bioimageio_model_bytes):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))
grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))
with pytest.raises(grpc.RpcError):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))

grpc_stub.CloseModelSession(model)
grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))

def test_closing_session_releases_devices(self, grpc_stub, bioimageio_model_bytes):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes, device_ids=["cpu"]))
Expand Down Expand Up @@ -163,8 +159,6 @@ def test_call_predict_valid_explicit(self, grpc_stub, bioimageio_dummy_explicit_
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))

grpc_stub.CloseModelSession(model)

assert len(res.tensors) == 1
assert res.tensors[0].tensorId == output_tensor_id
assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
@@ -175,7 +169,6 @@ def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_e
input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
with pytest.raises(grpc.RpcError):
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.CloseModelSession(model)

@pytest.mark.parametrize(
"shape",
@@ -187,7 +180,6 @@ def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioima
input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
with pytest.raises(grpc.RpcError):
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.CloseModelSession(model)

def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimageio_dummy_model):
model_bytes, _ = bioimageio_dummy_model
@@ -197,7 +189,6 @@ def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimageio_dummy_model
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Spec invalidTensorName doesn't exist")
grpc_stub.CloseModelSession(model)

def test_call_predict_invalid_axes(self, grpc_stub, bioimageio_dummy_model):
model_bytes, tensor_id = bioimageio_dummy_model
@@ -207,15 +198,13 @@ def test_call_predict_invalid_axes(self, grpc_stub, bioimageio_dummy_model):
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axes")
grpc_stub.CloseModelSession(model)

@pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)])
def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.CloseModelSession(model)

@pytest.mark.skip
def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes):
@@ -227,8 +216,6 @@ def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_byte
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))

grpc_stub.CloseModelSession(model)

assert len(res.tensors) == 1
assert res.tensors[0].tensorId == output_tensor_id
assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
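
For context, the pattern this commit removes called CloseModelSession explicitly at the end of each test, after the assertions. Any failing assertion (or an unexpected RpcError) skipped that call and leaked the session together with its device lease, which is what made pytest hang. A simplified illustration of the old failure mode, not verbatim test code:

def test_model_session_creation(self, grpc_stub, bioimageio_model_bytes):
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes))
    assert model.id  # if this assertion fails...
    grpc_stub.CloseModelSession(model)  # ...this line never runs and the session leaks

Wrapping every test in try/finally would also fix this, but the autouse fixture removes that boilerplate from each test body.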
32 changes: 22 additions & 10 deletions tiktorch/server/device_pool.py
@@ -71,10 +71,22 @@ def cuda_version(self) -> Optional[str]:
@abc.abstractmethod
def list_devices(self) -> List[IDevice]:
"""
List devices available on server
List devices on server
"""
...

def list_available_devices(self) -> List[IDevice]:
"""
List available devices on server
"""
return [device for device in self.list_devices() if device.status == DeviceStatus.AVAILABLE]

def list_reserved_devices(self) -> List[IDevice]:
"""
List reserved devices on server
"""
return [device for device in self.list_devices() if device.status == DeviceStatus.IN_USE]

@abc.abstractmethod
def lease(self, device_ids: List[str]) -> ILease:
"""
@@ -116,8 +128,8 @@ def terminate(self) -> None:

class TorchDevicePool(IDevicePool):
def __init__(self):
self.__lease_id_by_device_id = {}
self.__device_ids_by_lease_id = defaultdict(list)
self.__device_id_to_lease_id = {}
self.__lease_id_to_device_ids = defaultdict(list)
self.__lock = threading.Lock()

@property
@@ -142,7 +154,7 @@ def list_devices(self) -> List[IDevice]:
devices: List[IDevice] = []
for id_ in ids:
status = DeviceStatus.AVAILABLE
if id_ in self.__lease_id_by_device_id:
if id_ in self.__device_id_to_lease_id:
status = DeviceStatus.IN_USE

devices.append(_Device(id_=id_, status=status))
@@ -156,21 +168,21 @@ def lease(self, device_ids: List[str]) -> ILease:
with self.__lock:
lease_id = uuid.uuid4().hex
for dev_id in device_ids:
if dev_id in self.__lease_id_by_device_id:
if dev_id in self.__device_id_to_lease_id:
raise Exception(f"Device {dev_id} is already in use")

for dev_id in device_ids:
self.__lease_id_by_device_id[dev_id] = lease_id
self.__device_ids_by_lease_id[lease_id].append(dev_id)
self.__device_id_to_lease_id[dev_id] = lease_id
self.__lease_id_to_device_ids[lease_id].append(dev_id)

return _Lease(self, id_=lease_id)

def _get_lease_devices(self, lease_id: str) -> List[IDevice]:
return [_Device(id_=dev_id, status=DeviceStatus.IN_USE) for dev_id in self.__device_ids_by_lease_id[lease_id]]
return [_Device(id_=dev_id, status=DeviceStatus.IN_USE) for dev_id in self.__lease_id_to_device_ids[lease_id]]

def _release_devices(self, lease_id: str) -> None:
with self.__lock:
dev_ids = self.__device_ids_by_lease_id.pop(lease_id, [])
dev_ids = self.__lease_id_to_device_ids.pop(lease_id, [])

for id_ in dev_ids:
del self.__lease_id_by_device_id[id_]
del self.__device_id_to_lease_id[id_]
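
A sketch of how the renamed bookkeeping and the new helpers fit together, assuming IDevice exposes an id attribute and that terminate() on the returned lease releases its devices (both consistent with the hunks above); the device ids are illustrative:

from tiktorch.server.device_pool import TorchDevicePool

pool = TorchDevicePool()
print([d.id for d in pool.list_available_devices()])  # e.g. ["cpu"]

lease = pool.lease(["cpu"])
print([d.id for d in pool.list_reserved_devices()])   # ["cpu"]

lease.terminate()  # hands the devices back to the pool
print(pool.list_reserved_devices())                   # []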
9 changes: 9 additions & 0 deletions tiktorch/server/grpc/inference_servicer.py
@@ -53,6 +53,15 @@ def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inf
self.__session_manager.close_session(request.id)
return inference_pb2.Empty()

def close_all_sessions(self):
"""
Not exposed by the gRPC API; intended for test cleanup.
Close all sessions, ensuring that no devices remain leased.
"""
self.__session_manager.close_all_sessions()
assert len(self.__device_pool.list_reserved_devices()) == 0

def GetLogs(self, request: inference_pb2.Empty, context):
yield inference_pb2.LogEntry(
timestamp=int(time.time()), level=inference_pb2.LogEntry.Level.INFO, content="Sending model logs"
)
5 changes: 5 additions & 0 deletions tiktorch/server/session_manager.py
@@ -82,6 +82,11 @@ def close_session(self, session_id: str) -> None:

logger.debug("Closed session %s", session_id)

def close_all_sessions(self):
all_ids = tuple(self.__session_by_id.keys())
for session_id in all_ids:
self.close_session(session_id)

def __init__(self) -> None:
self.__lock = threading.Lock()
self.__session_by_id: Dict[str, Session] = {}
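
One detail worth noting in close_all_sessions: the keys are snapshotted into a tuple before iterating, presumably because close_session removes the entry from __session_by_id, and deleting from a dict while iterating over its live keys view raises a RuntimeError. A standalone sketch of the hazard:

sessions = {"a": object(), "b": object()}

# for session_id in sessions:       # RuntimeError: dictionary changed size during iteration
#     del sessions[session_id]

for session_id in tuple(sessions):  # iterating over a snapshot of the keys is safe
    del sessions[session_id]

assert not sessions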
