Skip to content

Commit

Permalink
fixes #51: Add get pipeline info successful test
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxence Guindon committed Feb 28, 2024
1 parent 3cddb43 commit 4a4a731
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 65 deletions.
26 changes: 15 additions & 11 deletions tests/test_azure_storage_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import unittest
import asyncio
from unittest.mock import patch, Mock, MagicMock
from azure_storage.azure_storage_api import (
mount_container,
Expand All @@ -9,8 +11,6 @@
GetBlobError,
)

import asyncio


class TestMountContainerFunction(unittest.TestCase):
@patch("azure.storage.blob.BlobServiceClient.from_connection_string")
Expand Down Expand Up @@ -148,38 +148,42 @@ def test_get_blob_unsuccessful(self, MockFromConnectionString):

class testGetPipeline(unittest.TestCase):
@patch("azure.storage.blob.BlobServiceClient.from_connection_string")
def test_get_pipeline_info_successful(self, MockFromConnectionString):
mock_blob_name = "test_blob"
mock_blob_content = b"v1"
def test_get_pipeline_info_successful(self, MockFromConnectionString,):

mock_blob_content = b'''{
"name": "test_blob.json",
"version": "v1"
}'''

mock_blob = Mock()
mock_blob.readall.return_value = mock_blob_content

mock_blob_client = MagicMock()
mock_blob_client = Mock()
mock_blob_client.configure_mock(name="test_blob.json")
mock_blob_client.download_blob.return_value = mock_blob

mock_container_client = MagicMock()
mock_container_client.exists.return_value = True
mock_container_client.list_blobs.return_value = [mock_blob_client]
mock_container_client.get_blob_client.return_value = mock_blob_client

mock_blob_service_client = MockFromConnectionString.return_value
mock_blob_service_client.get_container_client.return_value = (
mock_container_client
)

connection_string = "test_connection_string"
mock_blob_name = "test_blob"
mock_version = b"v1"
mock_version = "v1"

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(
get_pipeline_info(connection_string, mock_blob_name, mock_version)
)

print(result == mock_blob_content)
print(result == json.loads(mock_blob_content))

self.assertEqual(result, mock_blob_content)
self.assertEqual(result, json.loads(mock_blob_content))

@patch("azure.storage.blob.BlobServiceClient.from_connection_string")
def test_get_pipeline_info_unsuccessful(self, MockFromConnectionString):
Expand Down
54 changes: 0 additions & 54 deletions tests/test_inference_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,60 +16,6 @@
# missing one or more blob environment variable
# Return value of the function (JSON) in the following format:

result_json = {
'filename': 'tmp/tmp_file_name',
'boxes': [
{'box': {
'topX': 0.0,
'topY': 0.0,
'bottomX': 0.0,
'bottomY': 0.0
},
'label': 'label_name',
'score': 0.999
}
],
}

# or

result_json = {
'filename': 'tmp/tmp_file_name',
'boxes': [
{'box': {
'topX': 0.0,
'topY': 0.0,
'bottomX': 0.0,
'bottomY': 0.0
},
'label': 'label_name',
'score': 0.999,
'all_result': [{
{
'label': "seed_name",
'score': 0.999
},
{
'label': "seed_name",
'score': 0.002
},
{
'label': "seed_name",
'score': 0.002
},
{
'label': "seed_name",
'score': 0.002
},
{
'label': "seed_name",
'score': 0.002
}
}]
}
],
}

class TestInferenceRequest(unittest.TestCase):
def setUp(self):
self.app = app.app
Expand Down

0 comments on commit 4a4a731

Please sign in to comment.