Skip to content

Commit

Permalink
Merge pull request #2 from swiss-ai-center/1-update-version-of-packages
Browse files Browse the repository at this point in the history
Updated code with new package versions, changed condition in workflow
  • Loading branch information
andrptrc authored Jan 23, 2024
2 parents 9e2081e + 640f3c6 commit 1d64762
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 79 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Documentation: https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsuses
name: github_worflow
name: github_workflow
run-name: GitHub Workflow

env:
Expand Down Expand Up @@ -41,11 +41,11 @@ env:
# Logging level
PROD_LOG_LEVEL: ${{ vars.PROD_LOG_LEVEL }}
# Kube configuration
PROD_KUBE_CONFIG: ${{ secrets.DEV_KUBE_CONFIG }}
PROD_KUBE_CONFIG: ${{ secrets.PROD_KUBE_CONFIG }}

# Allow one concurrent deployment
concurrency:
group: github_worflow
group: github_workflow
cancel-in-progress: true

on:
Expand Down Expand Up @@ -89,7 +89,7 @@ jobs:
release:
needs: test
runs-on: ubuntu-latest
if: ${{ vars.RUN_CICD == 'true' && success() && (vars.DEPLOY_DEV == 'true' || vars.DEPLOY_PROD == 'true') }}
if: ${{ vars.RUN_CICD == 'true' && success() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/prod') && (vars.DEPLOY_DEV == 'true' || vars.DEPLOY_PROD == 'true') }}
steps:
- name: Clone repository
uses: actions/checkout@v3
Expand Down Expand Up @@ -176,7 +176,7 @@ jobs:
engine-announce-retries: ${{ env.PROD_ENGINE_ANNOUNCE_RETRIES }}
engine-announce-retry-delay: ${{ env.PROD_ENGINE_ANNOUNCE_RETRY_DELAY }}

- name: Deploy service on the Kubernetes cluster
- name: Deploy ${{ env.SERVICE_NAME }} on the Kubernetes cluster
uses: swiss-ai-center/common-code/.github/actions/execute-command-on-kubernetes-cluster@main
with:
kube-config: ${{ env.PROD_KUBE_CONFIG }}
Expand Down
88 changes: 88 additions & 0 deletions requirements-all.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
absl-py==2.1.0
aiobotocore==2.4.1
aiofiles==22.1.0
aiohttp==3.9.1
aioitertools==0.11.0
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.2.0
botocore==1.27.59
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
coverage==7.4.0
dnspython==2.5.0
email-validator==2.1.0.post1
exceptiongroup==1.2.0
fastapi==0.108.0
flake8==5.0.4
flatbuffers==23.5.26
frozenlist==1.4.1
gast==0.5.4
google-auth==2.26.2
google-auth-oauthlib==1.2.0
google-pasta==0.2.0
grpcio==1.60.0
h11==0.14.0
h5py==3.10.0
httpcore==0.16.3
httpx==0.23.1
idna==3.6
ImageHash==4.3.1
iniconfig==2.0.0
jmespath==1.0.1
keras==2.15.0
libclang==16.0.6
Markdown==3.5.2
MarkupSafe==2.1.4
mccabe==0.7.0
ml-dtypes==0.2.0
multidict==6.0.4
numpy==1.26.3
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.2
pillow==10.2.0
pip-chill==1.0.3
pluggy==1.3.0
protobuf==4.23.4
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycodestyle==2.9.1
pydantic==2.5.3
pydantic-settings==2.1.0
pydantic_core==2.14.6
pyflakes==2.5.0
pytest==7.2.0
pytest-asyncio==0.20.3
pytest-cov==4.0.0
pytest-httpserver==1.0.6
python-dateutil==2.8.2
python-dotenv==0.21.0
PyWavelets==1.5.0
PyYAML==6.0
requests==2.31.0
requests-oauthlib==1.3.1
rfc3986==1.5.0
rsa==4.9
scipy==1.12.0
six==1.16.0
sniffio==1.3.0
starlette==0.29.0
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.4.0
tomli==2.0.1
typing_extensions==4.9.0
urllib3==1.26.18
uvicorn==0.19.0
Werkzeug==2.3.6
wrapt==1.14.1
yarl==1.9.4
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
common-code[test] @ git+https://github.com/swiss-ai-center/common-code.git@main
fastapi
imagehash
pip-chill
tensorflow
imagehash==4.3.1
pip-chill==1.0.3
tensorflow==2.15.0
140 changes: 70 additions & 70 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from common_code.config import get_settings
from pydantic import Field
from common_code.http_client import HttpClient
from common_code.logger.logger import get_logger
from common_code.logger.logger import get_logger, Logger
from common_code.service.controller import router as service_router
from common_code.service.service import ServiceService
from common_code.storage.service import StorageService
Expand All @@ -17,6 +16,7 @@
from common_code.service.enums import ServiceStatus
from common_code.common.enums import FieldDescriptionType, ExecutionUnitTagName, ExecutionUnitTagAcronym
from common_code.common.models import FieldDescription, ExecutionUnitTag
from contextlib import asynccontextmanager

# Imports required by the service's model
import os
Expand Down Expand Up @@ -47,9 +47,9 @@ class MyService(Service):
"""

# Any additional fields must be excluded for Pydantic to work
base_model: object = Field(exclude=True)
nsfw_model: object = Field(exclude=True)
logger: object = Field(exclude=True)
_base_model: object
_nsfw_model: object
_logger: Logger

def __init__(self):
super().__init__(
Expand All @@ -73,26 +73,26 @@ def __init__(self):
],
has_ai=True,
)
self.logger = get_logger(settings)
self._logger = get_logger(settings)
# read the ai model here
self.logger.info("Loading the base model...")
self.base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
self._logger.info("Loading the base model...")
self._base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
include_top=False,
weights='imagenet',
input_shape=(IMG_SIZE, IMG_SIZE, CHANNELS))
self.logger.info("Base model loaded. Recreating structure of model before loading fine-tuned weights...")
self.nsfw_model = tf.keras.Sequential([
self.base_model,
self._logger.info("Base model loaded. Recreating structure of model before loading fine-tuned weights...")
self._nsfw_model = tf.keras.Sequential([
self._base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(16),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(N_CLASSES),
tf.keras.layers.Activation('softmax')
], name='MNV2')
self.logger.info('Loading weights from file: {}'.format(WEIGHT_FILE))
self.nsfw_model.load_weights(WEIGHT_FILE)
self.logger.info('Weights loaded.')
self._logger.info('Loading weights from file: {}'.format(WEIGHT_FILE))
self._nsfw_model.load_weights(WEIGHT_FILE)
self._logger.info('Weights loaded.')

def build_score_dict(self, scores, class_names):
"""
Expand All @@ -117,19 +117,19 @@ def predict_from_image(self, image_tensor):
category scores and the list of sub-category scores
"""
image_tensor = np.array([image_tensor])
self.logger.info("Image tensor shape: {}".format(image_tensor.shape))
self._logger.info("Image tensor shape: {}".format(image_tensor.shape))
pred_sub_cat = self.nsfw_model.predict(image_tensor, verbose=0)
self.logger.info("Prediction shape: {}".format(pred_sub_cat.shape))
self.logger.info("Prediction: {}".format(pred_sub_cat))
self._logger.info("Prediction shape: {}".format(pred_sub_cat.shape))
self._logger.info("Prediction: {}".format(pred_sub_cat))
pred_cat = np.zeros((1, 2))
pred_cat[:, 0] = np.sum(pred_sub_cat[:, :4], axis=1) # do the sum of nsfw sub-categories to compute nsfw pred
pred_cat[:, 1] = np.sum(pred_sub_cat[:, 4:], axis=1) # same thing for safe
# in the end, the pred_cat is a similar output tensor as pred_sub_cat but on 2 main categories nsfw and safe
# let's use the first prediction for now (disregarding the fliped image)
scores_dict_sub_cat = self.build_score_dict(pred_sub_cat[0], SUB_CAT_NAMES)
self.logger.info("Scores sub-cat: {}".format(scores_dict_sub_cat))
self._logger.info("Scores sub-cat: {}".format(scores_dict_sub_cat))
scores_dict_cat = self.build_score_dict(pred_cat[0], CAT_NAMES)
self.logger.info("Scores cat: {}".format(scores_dict_cat))
self._logger.info("Scores cat: {}".format(scores_dict_cat))
winner_sub_cat = pred_sub_cat.argmax(axis=1)[0]
winner_cat = pred_cat.argmax(axis=1)[0]
# get the prediction as category and subcategory
Expand All @@ -144,9 +144,9 @@ def process(self, data):
image = Image.open(buff)
image = image.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
image_tensor = np.array(image)
self.logger.info("Image shape: {}".format(image_tensor.shape))
self._logger.info("Image shape: {}".format(image_tensor.shape))
image_tensor = tf.keras.applications.mobilenet.preprocess_input(image_tensor)
self.logger.info("Image shape after preprocessing: {}".format(image_tensor.shape))
self._logger.info("Image shape after preprocessing: {}".format(image_tensor.shape))
prediction_category, prediction_subcategory, scores_dict_cat, scores_dict_sub_cat = \
self.predict_from_image(image_tensor)

Expand All @@ -162,6 +162,54 @@ def process(self, data):
}


service_service: ServiceService | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
# Manual instances because startup events doesn't support Dependency Injection
# https://github.com/tiangolo/fastapi/issues/2057
# https://github.com/tiangolo/fastapi/issues/425

# Global variable
global service_service

# Startup
logger = get_logger(settings)
http_client = HttpClient()
storage_service = StorageService(logger)
my_service = MyService()
tasks_service = TasksService(logger, settings, http_client, storage_service)
service_service = ServiceService(logger, settings, http_client, tasks_service)

tasks_service.set_service(my_service)

# Start the tasks service
tasks_service.start()

async def announce():
retries = settings.engine_announce_retries
for engine_url in settings.engine_urls:
announced = False
while not announced and retries > 0:
announced = await service_service.announce_service(my_service, engine_url)
retries -= 1
if not announced:
time.sleep(settings.engine_announce_retry_delay)
if retries == 0:
logger.warning(f"Aborting service announcement after "
f"{settings.engine_announce_retries} retries")

# Announce the service to its engine
asyncio.ensure_future(announce())

yield

# Shutdown
for engine_url in settings.engine_urls:
await service_service.graceful_shutdown(my_service, engine_url)


api_description = """
This service detects nudity, sexual and hentai content in images, or if the image is 'safe for work'.
"""
Expand All @@ -172,6 +220,7 @@ def process(self, data):

# Define the FastAPI application with information
app = FastAPI(
lifespan=lifespan,
title="NSFW image detection service API.",
description=api_description,
version="0.2.1",
Expand Down Expand Up @@ -207,52 +256,3 @@ def process(self, data):
@app.get("/", include_in_schema=False)
async def root():
return RedirectResponse("/docs", status_code=301)

service_service: ServiceService | None = None


@app.on_event("startup")
async def startup_event():
# Manual instances because startup events doesn't support Dependency Injection
# https://github.com/tiangolo/fastapi/issues/2057
# https://github.com/tiangolo/fastapi/issues/425

# Global variable
global service_service

logger = get_logger(settings)
http_client = HttpClient()
storage_service = StorageService(logger)
my_service = MyService()
tasks_service = TasksService(logger, settings, http_client, storage_service)
service_service = ServiceService(logger, settings, http_client, tasks_service)

tasks_service.set_service(my_service)

# Start the tasks service
tasks_service.start()

async def announce():
retries = settings.engine_announce_retries
for engine_url in settings.engine_urls:
announced = False
while not announced and retries > 0:
announced = await service_service.announce_service(my_service, engine_url)
retries -= 1
if not announced:
time.sleep(settings.engine_announce_retry_delay)
if retries == 0:
logger.warning(f"Aborting service announcement after "
f"{settings.engine_announce_retries} retries")

# Announce the service to its engine
asyncio.ensure_future(announce())


@app.on_event("shutdown")
async def shutdown_event():
# Global variable
global service_service
my_service = MyService()
for engine_url in settings.engine_urls:
await service_service.graceful_shutdown(my_service, engine_url)

0 comments on commit 1d64762

Please sign in to comment.