fix #60: Creating pipeline insertion module (#65)
* fixes #60: correct yaml

* fixes #60: template yaml fix

* fixes #60: template lint error

* fixes #60: template fix

* fixes #60: try copilot fix

* fixes #60: Add unittest

* fixes #60: correct typo in template

* fixes #60: Correct lint error

* fixes #60: Add EOF

* fixes #60: Implement grammar correction

* fixes #60: Add environment variable check

* fixes #60: update environment variables readme

* fixes #60: remove act from devcontainer

* fixes #60: Correct Markdown lint error

* issue #60: Add a default key for pipeline

* issue #60: Modify tests to check error message

* issue #60: Create Fernet Key and Blob Service Client in main

* issue #60: update tests

* issue #60: add warning for image environment variable

* issue #60: Change to image validation

* issue #60: remove legacy folder

* issue #60: Eliminate catching generic exceptions

* issue #65: Move CONSTANT to upper function

* fixes #60: Move pipeline related files into
pipelines, and modify template value for model and
pipeline

---------

Co-authored-by: Jonathan Lopez <[email protected]>
Maxence Guindon and SonOfLope authored Apr 16, 2024
1 parent 0b4e0a4 commit 7c8f4c7
Showing 10 changed files with 746 additions and 98 deletions.
4 changes: 2 additions & 2 deletions TESTING.md
@@ -27,7 +27,7 @@ pipeline information.
**Preconditions:**

- [ ] Nachet backend is set up and running. Use the command `hypercorn -b :8080
app:app` to start the quartz server.
app:app` to start the quart server.
- [ ] The environment variables are all set.
- [ ] :exclamation: The frontend is not running yet

@@ -66,7 +66,7 @@ expected.
**Preconditions:**

- [ ] Nachet backend is set up and running. Use the command `hypercorn -b :8080
app:app` to start the quartz server.
app:app` to start the quart server.
- [ ] The environment variables are all set.
- [ ] The frontend is running.
- [ ] Start the frontend application
186 changes: 107 additions & 79 deletions app.py
@@ -8,17 +8,17 @@
import time
import warnings

import model.inference as inference
from model import request_function

from PIL import Image, UnidentifiedImageError
from datetime import date
from dotenv import load_dotenv
from quart import Quart, request, jsonify
from quart_cors import cors
from collections import namedtuple
from cryptography.fernet import Fernet

import azure_storage.azure_storage_api as azure_storage_api
import model.inference as inference
from model import request_function

class APIErrors(Exception):
pass
@@ -60,24 +60,45 @@ class MaxContentLengthWarning(APIWarnings):
pass

load_dotenv()

connection_string_regex = r"^DefaultEndpointsProtocol=https?;.*;FileEndpoint=https://[a-zA-Z0-9]+\.file\.core\.windows\.net/;$"
connection_string = os.getenv("NACHET_AZURE_STORAGE_CONNECTION_STRING")
pipeline_version_regex = r"\d\.\d\.\d"

CONNECTION_STRING = os.getenv("NACHET_AZURE_STORAGE_CONNECTION_STRING")

FERNET_KEY = os.getenv("NACHET_BLOB_PIPELINE_DECRYPTION_KEY")
PIPELINE_VERSION = os.getenv("NACHET_BLOB_PIPELINE_VERSION")
PIPELINE_BLOB_NAME = os.getenv("NACHET_BLOB_PIPELINE_NAME")

NACHET_DATA = os.getenv("NACHET_DATA")
NACHET_MODEL = os.getenv("NACHET_MODEL")
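
As a hedged illustration of how the two regexes above are meant to gate the configuration (the sample values are made up, not taken from this repository):

```python
import re

pipeline_version_regex = r"\d\.\d\.\d"

assert re.match(pipeline_version_regex, "0.1.0")      # accepted: digit.digit.digit
assert not re.match(pipeline_version_regex, "v1.0")   # rejected: no leading digit

# Note: re.match only anchors at the start of the string, so a suffix
# such as "0.1.0-beta" would also pass this check.
```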

Model = namedtuple(
'Model',
[
'entry_function',
'name',
'endpoint',
'api_key',
'inference_function',
'content_type',
'deployment_platform',
]
)
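
For orientation, a sketch of one `Model` record as `get_pipelines` below assembles it; every value here is a hypothetical placeholder, not a real endpoint or key:

```python
# All values below are hypothetical placeholders.
example_model = Model(
    entry_function=None,  # request function looked up by endpoint name
    name="seed-detector",
    endpoint="https://example.azure.com/score",
    api_key="decrypted-api-key",
    inference_function=None,
    content_type="application/json",
    deployment_platform="azure",
)
print(example_model.name, example_model.endpoint)
```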

try:
VALID_EXTENSION = json.loads(os.getenv("NACHET_VALID_EXTENSION"))
VALID_DIMENSION = json.loads(os.getenv("NACHET_VALID_DIMENSION"))
except TypeError:
except (TypeError, json.decoder.JSONDecodeError):
# For testing
VALID_DIMENSION = {"width": 1920, "height": 1080}
VALID_EXTENSION = {"jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"}
warnings.warn(
f"""
NACHET_VALID_EXTENSION or NACHET_VALID_DIMENSION is not set,
using default values: {", ".join(list(VALID_EXTENSION))} and dimension: {tuple(VALID_DIMENSION.values())}
""",
ImageWarning
)
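
Because both variables go through `json.loads`, they must hold JSON text; a minimal sketch of values that would parse, mirroring the fallback defaults rather than any production configuration:

```python
import json
import os

# Hypothetical values mirroring the defaults above.
os.environ["NACHET_VALID_EXTENSION"] = '["jpeg", "jpg", "png", "gif", "bmp", "tiff", "webp"]'
os.environ["NACHET_VALID_DIMENSION"] = '{"width": 1920, "height": 1080}'

print(json.loads(os.environ["NACHET_VALID_EXTENSION"]))  # list of extensions
print(json.loads(os.environ["NACHET_VALID_DIMENSION"]))  # {'width': 1920, 'height': 1080}
```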

try:
MAX_CONTENT_LENGTH_MEGABYTES = int(os.getenv("NACHET_MAX_CONTENT_LENGTH"))
@@ -113,6 +134,51 @@ class MaxContentLengthWarning(APIWarnings):
app.config["MAX_CONTENT_LENGTH"] = MAX_CONTENT_LENGTH_MEGABYTES * 1024 * 1024


@app.before_serving
async def before_serving():
try:
# Check: do environment variables exist?
if CONNECTION_STRING is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if FERNET_KEY is None:
raise ServerError("Missing environment variable: FERNET_KEY")

if PIPELINE_VERSION is None:
raise ServerError("Missing environment variable: PIPELINE_VERSION")

if PIPELINE_BLOB_NAME is None:
raise ServerError("Missing environment variable: PIPELINE_BLOB_NAME")

if NACHET_DATA is None:
raise ServerError("Missing environment variable: NACHET_DATA")

# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, CONNECTION_STRING)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if not bool(re.match(pipeline_version_regex, PIPELINE_VERSION)):
raise ServerError("Incorrect environment variable: PIPELINE_VERSION")

CACHE["seeds"] = await fetch_json(NACHET_DATA, "seeds", "seeds/all.json")
CACHE["endpoints"] = await get_pipelines(
CONNECTION_STRING, PIPELINE_BLOB_NAME,
PIPELINE_VERSION, Fernet(FERNET_KEY)
)

print(
f"""Server start with current configuration:\n
date: {date.today()}
file version of pipelines: {PIPELINE_VERSION}
pipelines: {[pipeline for pipeline in CACHE["pipelines"].keys()]}\n
"""
) #TODO Transform into logging

except ServerError as e:
print(e)
raise
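
The squashed commits mention tests that check these error messages; a minimal sketch of such a test, assuming `pytest` with `pytest-asyncio` (neither is shown in this diff):

```python
import pytest

import app


@pytest.mark.asyncio
async def test_missing_connection_string(monkeypatch):
    # Simulate the unset environment variable by clearing the module constant.
    monkeypatch.setattr(app, "CONNECTION_STRING", None)
    with pytest.raises(app.ServerError, match="NACHET_AZURE_STORAGE_CONNECTION_STRING"):
        await app.before_serving()
```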


@app.post("/del")
async def delete_directory():
"""
@@ -215,10 +281,21 @@ async def image_validation():
image_base64 = data["image"]

header, encoded_image = image_base64.split(",", 1)

image_bytes = base64.b64decode(encoded_image)

image = Image.open(io.BytesIO(image_bytes))

# size check
if image.size[0] > VALID_DIMENSION["width"] and image.size[1] > VALID_DIMENSION["height"]:
raise ImageValidationError(f"invalid file size: {image.size[0]}x{image.size[1]}")

# resizable check
try:
size = (100,150)
image.thumbnail(size)
except IOError:
raise ImageValidationError("invalid file not resizable")

magic_header = magic.from_buffer(image_bytes, mime=True)
image_extension = magic_header.split("/")[1]

@@ -232,23 +309,12 @@
if header.lower() != expected_header:
raise ImageValidationError(f"invalid file header: {header}")

# size check
if image.size[0] > VALID_DIMENSION["width"] and image.size[1] > VALID_DIMENSION["height"]:
raise ImageValidationError(f"invalid file size: {image.size[0]}x{image.size[1]}")

# resizable check
try:
size = (100,150)
image.thumbnail(size)
except IOError:
raise ImageValidationError("invalid file not resizable")

validator = await azure_storage_api.generate_hash(image_bytes)
CACHE['validators'].append(validator)

return jsonify([validator]), 200

except (FileNotFoundError, ValueError, TypeError, UnidentifiedImageError, ImageValidationError) as error:
except (UnidentifiedImageError, ImageValidationError) as error:
print(error)
return jsonify([error.args[0]]), 400
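
A sketch of a client payload that would pass the checks above; the file name is hypothetical, and the route decorator for `image_validation` sits outside this hunk:

```python
import base64

# Hypothetical local file; any image within the valid extensions/dimensions works.
with open("seed.png", "rb") as f:
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

# The handler splits on the first comma: data-URL header, then base64 payload.
payload = {"image": f"data:image/png;base64,{encoded_image}"}
```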

@@ -335,14 +401,6 @@ async def inference_request():
print(error)
return jsonify(["InferenceRequestError: " + error.args[0]]), 400

except Exception as error:
print(error)
return jsonify(["Unexpected error occured"]), 500

@app.get("/coffee")
async def get_coffee():
return jsonify("Tea is great!"), 418


@app.get("/seed-data/<seed_name>")
async def get_seed_data(seed_name):
@@ -363,8 +421,9 @@ async def reload_seed_data():
try:
await fetch_json(NACHET_DATA, 'seeds', "seeds/all.json")
return jsonify(["Seed data reloaded successfully"]), 200
except Exception as e:
return jsonify({"error": str(e)}), 500
except urllib.error.HTTPError as e:
return jsonify(
[f"An error happened when reloading the seed data: {e.args[0]}"]), 500


@app.get("/model-endpoints-metadata")
@@ -406,25 +465,28 @@ async def test():

return CACHE["endpoints"], 200


async def fetch_json(repo_URL, key, file_path):
"""
Fetches JSON document from a GitHub repository and caches it
"""
try:
if key != "endpoints":
json_url = os.path.join(repo_URL, file_path)
with urllib.request.urlopen(json_url) as response:
result = response.read()
result_json = json.loads(result.decode("utf-8"))
return result_json
Fetches JSON document from a GitHub repository.
except urllib.error.HTTPError as error:
raise ValueError(str(error))
except Exception as e:
raise ValueError(str(e))
Parameters:
- repo_URL (str): The URL of the GitHub repository.
- key (str): The key to identify the JSON document.
- file_path (str): The path to the JSON document in the repository.
Returns:
- dict: The JSON document as a Python dictionary.
"""
if key != "endpoints":
json_url = os.path.join(repo_URL, file_path)
with urllib.request.urlopen(json_url) as response:
result = response.read()
result_json = json.loads(result.decode("utf-8"))
return result_json
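
A hedged usage sketch of the rewritten `fetch_json`; the repository URL is a placeholder, the real one comes from `NACHET_DATA`:

```python
# Inside an async function, e.g. before_serving:
seeds = await fetch_json(
    "https://example.com/nachet-data",  # placeholder for NACHET_DATA
    "seeds",
    "seeds/all.json",
)
print(type(seeds))  # dict parsed from the fetched JSON
```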


async def get_pipelines():
async def get_pipelines(connection_string, pipeline_blob_name, pipeline_version, cipher_suite):
"""
Retrieves the pipelines from the Azure storage API.
@@ -433,16 +495,15 @@
"""
try:
app.config["BLOB_CLIENT"] = await azure_storage_api.get_blob_client(connection_string)
result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], PIPELINE_BLOB_NAME, PIPELINE_VERSION)
cipher_suite = Fernet(FERNET_KEY)
result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], pipeline_blob_name, pipeline_version)
except (azure_storage_api.AzureAPIErrors) as error:
print(error)
raise ServerError("server errror: could not retrieve the pipelines") from error

models = ()
for model in result_json.get("models"):
m = Model(
request_function.get(model.get("api_call_function")),
request_function.get(model.get("endpoint_name")),
model.get("model_name"),
# To protect sensitive data (API key and model endpoint), we encrypt it when
# it's pushed into the blob storage. Once we retrieve the data here in the
@@ -461,38 +522,5 @@
return result_json.get("pipelines")
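
Since endpoints and API keys are stored encrypted, the `cipher_suite` passed into `get_pipelines` must wrap the same Fernet key used when the pipeline was pushed; a minimal round-trip sketch (the endpoint URL is hypothetical):

```python
from cryptography.fernet import Fernet

key = Fernet.generate_key()  # in production: NACHET_BLOB_PIPELINE_DECRYPTION_KEY
cipher_suite = Fernet(key)

token = cipher_suite.encrypt(b"https://example.azure.com/score")
assert cipher_suite.decrypt(token) == b"https://example.azure.com/score"
```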


@app.before_serving
async def before_serving():
try:
# Check: do environment variables exist?
if connection_string is None:
raise ServerError("Missing environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

if FERNET_KEY is None:
raise ServerError("Missing environment variable: FERNET_KEY")

# Check: are environment variables correct?
if not bool(re.match(connection_string_regex, connection_string)):
raise ServerError("Incorrect environment variable: NACHET_AZURE_STORAGE_CONNECTION_STRING")

CACHE["seeds"] = await fetch_json(NACHET_DATA, "seeds", "seeds/all.json")
CACHE["endpoints"] = await get_pipelines()

print(
f"""Server start with current configuration:\n
date: {date.today()}
file version of pipelines: {PIPELINE_VERSION}
pipelines: {[pipeline for pipeline in CACHE["pipelines"].keys()]}\n
"""
) #TODO Transform into logging

except ServerError as e:
print(e)
raise

except Exception as e:
print(e)
raise ServerError("Failed to retrieve data from the repository")

if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=8080)
5 changes: 4 additions & 1 deletion azure_storage/azure_storage_api.py
@@ -72,7 +72,10 @@ async def generate_hash(image):

async def get_blob_client(connection_string: str):
"""
given a connection string, returns the blob client object
given a connection string and a container name, mounts the container and
returns the container client as an object that can be used in other
functions. if a specified container doesn't exist, it creates one with the
provided uuid, if create_container is True
"""
try:
blob_service_client = BlobServiceClient.from_connection_string(
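
For reference, a minimal sketch of the Azure SDK call this function wraps; the connection string is a placeholder, the real value comes from `NACHET_AZURE_STORAGE_CONNECTION_STRING`:

```python
from azure.storage.blob import BlobServiceClient

# Placeholder connection string with made-up account values.
connection_string = (
    "DefaultEndpointsProtocol=https;AccountName=example;"
    "AccountKey=example-key;EndpointSuffix=core.windows.net"
)
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
```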
7 changes: 3 additions & 4 deletions docs/nachet-inference-documentation.md
@@ -242,7 +242,7 @@ the `CACHE["endpoint"]` variable. This is the variable that feeds the `models`
information and metadata to the frontend.

```python
async def get_pipelines():
async def get_pipelines(connection_string, pipeline_blob_name, pipeline_version, cipher_suite):
"""
Retrieves the pipelines from the Azure storage API.
@@ -251,9 +251,8 @@ async def get_pipelines():
"""
try:
app.config["BLOB_CLIENT"] = await azure_storage_api.get_blob_client(connection_string)
result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], PIPELINE_BLOB_NAME, PIPELINE_VERSION)
cipher_suite = Fernet(FERNET_KEY)
except (ConnectionStringError, PipelineNotFoundError) as error:
result_json = await azure_storage_api.get_pipeline_info(app.config["BLOB_CLIENT"], pipeline_blob_name, pipeline_version)
except (azure_storage_api.AzureAPIErrors) as error:
print(error)
raise ServerError("server errror: could not retrieve the pipelines") from error
