Skip to content

Commit

Permalink
[omm] deplopulate __init__ and delint (#1368)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies authored Sep 25, 2023
1 parent 1adf457 commit 35a8ed4
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 131 deletions.
4 changes: 2 additions & 2 deletions open-media-match/.devcontainer/startup.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
set -e
export OMM_CONFIG=/workspace/.devcontainer/omm_config.py
flask --app OpenMediaMatch db upgrade --directory src/openMediaMatch/migrations
flask --app OpenMediaMatch run --debug
flask --app OpenMediaMatch.app db upgrade --directory /workspace/src/openMediaMatch/migrations
flask --app OpenMediaMatch.app run --debug
77 changes: 0 additions & 77 deletions open-media-match/src/OpenMediaMatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,78 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import sys

import flask
import flask_migrate
import flask_sqlalchemy

db = flask_sqlalchemy.SQLAlchemy()
migrate = flask_migrate.Migrate()


def create_app():
"""
Create and configure the Flask app
"""
app = flask.Flask(__name__)
if "OMM_CONFIG" in os.environ:
app.config.from_envvar("OMM_CONFIG")
elif sys.argv[0].endswith("/flask"): # Default for flask CLI
# The devcontainer settings. If you are using the CLI outside
# the devcontainer and getting an error, just override the env
app.config.from_pyfile("/workspace/.devcontainer/omm_config.py")
else:
raise RuntimeError("No flask config given - try populating OMM_CONFIG env")
app.config.update(
SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"),
SQLALCHEMY_TRACK_MODIFICATIONS=False,
)
app.db = db

db.init_app(app)
migrate.init_app(app, db)

@app.route("/")
def index():
"""
Sanity check endpoint showing a basic status page
TODO: in development mode, this could show some useful additional info
"""
return flask.render_template(
"index.html.j2", production=app.config.get("PRODUCTION")
)

@app.route("/status")
def status():
"""
Liveness/readiness check endpoint for your favourite Layer 7 load balancer
"""
return "I-AM-ALIVE\n"

# Register Flask blueprints for whichever server roles are enabled...
# URL prefixing facilitates easy Layer 7 routing :)
# Linters complain about imports off the top level, but this is needed
# to prevent circular imports
from .blueprints import hashing, matching, curation

if app.config.get("ROLE_HASHER", False):
app.register_blueprint(hashing.bp, url_prefix="/h")

if app.config.get("ROLE_MATCHER", False):
app.register_blueprint(matching.bp, url_prefix="/m")

if app.config.get("ROLE_CURATOR", False):
app.register_blueprint(curation.bp, url_prefix="/c")

from . import models

@app.cli.command("seed")
def seed_data():
# TODO: This is a placeholder for where some useful seed data can be loaded;
# particularly important for development
bank = models.Bank(name="bad_stuff", enabled=True)
db.session.add(bank)
db.session.commit()

return app
76 changes: 74 additions & 2 deletions open-media-match/src/OpenMediaMatch/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,75 @@
from OpenMediaMatch import create_app
# Copyright (c) Meta Platforms, Inc. and affiliates.

app = create_app()
import os
import sys

import flask
import flask_migrate

from OpenMediaMatch import database
from OpenMediaMatch.blueprints import hashing, matching, curation


def create_app():
"""
Create and configure the Flask app
"""
app = flask.Flask(__name__)
migrate = flask_migrate.Migrate()

if "OMM_CONFIG" in os.environ:
app.config.from_envvar("OMM_CONFIG")
elif sys.argv[0].endswith("/flask"): # Default for flask CLI
# The devcontainer settings. If you are using the CLI outside
# the devcontainer and getting an error, just override the env
app.config.from_pyfile("/workspace/.devcontainer/omm_config.py")
else:
raise RuntimeError("No flask config given - try populating OMM_CONFIG env")
app.config.update(
SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"),
SQLALCHEMY_TRACK_MODIFICATIONS=False,
)

database.db.init_app(app)
migrate.init_app(app, database.db)

@app.route("/")
def index():
"""
Sanity check endpoint showing a basic status page
TODO: in development mode, this could show some useful additional info
"""
return flask.render_template(
"index.html.j2", production=app.config.get("PRODUCTION")
)

@app.route("/status")
def status():
"""
Liveness/readiness check endpoint for your favourite Layer 7 load balancer
"""
return "I-AM-ALIVE\n"

# Register Flask blueprints for whichever server roles are enabled...
# URL prefixing facilitates easy Layer 7 routing :)
# Linters complain about imports off the top level, but this is needed
# to prevent circular imports

if app.config.get("ROLE_HASHER", False):
app.register_blueprint(hashing.bp, url_prefix="/h")

if app.config.get("ROLE_MATCHER", False):
app.register_blueprint(matching.bp, url_prefix="/m")

if app.config.get("ROLE_CURATOR", False):
app.register_blueprint(curation.bp, url_prefix="/c")

@app.cli.command("seed")
def seed_data():
# TODO: This is a placeholder for where some useful seed data can be loaded;
# particularly important for development
bank = database.Bank(name="bad_stuff", enabled=True)
database.db.session.add(bank)
database.db.session.commit()

return app
29 changes: 14 additions & 15 deletions open-media-match/src/OpenMediaMatch/blueprints/curation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from flask import Blueprint
from flask import abort, request, current_app, jsonify
from flask import request, current_app, jsonify

from OpenMediaMatch import models
from OpenMediaMatch import database

import json

bp = Blueprint("curation", __name__)


@bp.route("/banks", methods=["GET"])
def banks_index():
banks = (
current_app.db.session.execute(current_app.db.select(models.Bank))
database.db.session.execute(current_app.db.select(database.Bank))
.scalars()
.all()
)
Expand All @@ -20,7 +19,7 @@ def banks_index():

@bp.route("/bank/<int:bank_id>", methods=["GET"])
def bank_show_by_id(bank_id: int):
bank = models.Bank.query.get(bank_id)
bank = database.Bank.query.get(bank_id)
if not bank:
return jsonify({"message": "bank not found"}), 404
return jsonify(bank)
Expand All @@ -29,8 +28,8 @@ def bank_show_by_id(bank_id: int):
@bp.route("/bank/<bank_name>", methods=["GET"])
def bank_show_by_name(bank_name: str):
bank = (
current_app.db.session.execute(
current_app.db.select(models.Bank).where(models.Bank.name == bank_name)
database.db.session.execute(
database.db.select(database.Bank).where(database.Bank.name == bank_name)
)
.scalars()
.all()
Expand All @@ -43,16 +42,16 @@ def bank_create():
data = request.get_json()
if not "name" in data:
return jsonify({"message": "Field `name` is required"}), 400
bank = models.Bank(name=data["name"], enabled=bool(data.get("enabled", True)))
current_app.db.session.add(bank)
current_app.db.session.commit()
bank = database.Bank(name=data["name"], enabled=bool(data.get("enabled", True)))
database.db.session.add(bank)
database.db.session.commit()
return jsonify({"message": "Created successfully"}), 201


@bp.route("/bank/<int:bank_id>", methods=["PUT"])
def bank_update(bank_id: int):
data = request.get_json()
bank = models.Bank.query.get(bank_id)
bank = database.Bank.query.get(bank_id)
if not bank:
return jsonify({"message": "bank not found"}), 404

Expand All @@ -61,15 +60,15 @@ def bank_update(bank_id: int):
if "enabled" in data:
bank.enabled = bool(data["enabled"])

current_app.db.session.commit()
database.db.session.commit()
return jsonify(bank)


@bp.route("/bank/<int:bank_id>", methods=["DELETE"])
def bank_delete(bank_id: int):
bank = models.Bank.query.get(bank_id)
bank = database.Bank.query.get(bank_id)
if not bank:
return jsonify({"message": "bank not found"}), 404
current_app.db.session.delete(bank)
current_app.db.session.commit()
database.db.session.delete(bank)
database.db.session.commit()
return jsonify({"message": f"Bank {bank.name} ({bank.id}) deleted"})
20 changes: 11 additions & 9 deletions open-media-match/src/OpenMediaMatch/blueprints/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from threatexchange.content_type.video import VideoContent
from threatexchange.signal_type.signal_base import FileHasher, SignalType

from OpenMediaMatch import app_resources
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.utils import abort_to_json

bp = Blueprint("hashing", __name__)


@bp.route("/hash")
@abort_to_json
def hash_media():
"""
Fetch content and return its hash.
Expand Down Expand Up @@ -57,7 +59,7 @@ def hash_media():
return ret


def _parse_request_content_type(url_content_type: str) -> ContentType:
def _parse_request_content_type(url_content_type: str) -> t.Type[ContentType]:
arg = request.args.get("content_type", "")
if not arg:
if url_content_type.lower().startswith("image"):
Expand All @@ -71,20 +73,20 @@ def _parse_request_content_type(url_content_type: str) -> ContentType:
"if you know the expected type, provide it with content_type",
)

storage = app_resources.get_storage()
content_type_config = storage.get_content_type_configs().get(arg)
content_type_config = get_storage().get_content_type_configs().get(arg)
if content_type_config is None:
return {"message": f"no such content_type: '{arg}'"}, 400
abort(400, f"no such content_type: '{arg}'")

if not content_type_config.enabled:
return {"message": f"content_type {arg} is disabled"}, 400
abort(400, f"content_type {arg} is disabled")

return content_type_config.content_type


def _parse_request_signal_type(content_type: ContentType) -> t.Mapping[str, SignalType]:
storage = app_resources.get_storage()
signal_types = storage.get_enabled_signal_types_for_content_type(content_type)
def _parse_request_signal_type(
content_type: t.Type[ContentType],
) -> t.Mapping[str, t.Type[SignalType]]:
signal_types = get_storage().get_enabled_signal_types_for_content_type(content_type)
if not signal_types:
abort(500, "No signal types configured!")
signal_type_args = request.args.get("types", None)
Expand Down
4 changes: 2 additions & 2 deletions open-media-match/src/OpenMediaMatch/blueprints/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from flask import Blueprint
from flask import abort

from OpenMediaMatch.app_resources import (
from OpenMediaMatch.utils import (
abort_to_json,
require_request_param,
get_storage,
)
from OpenMediaMatch.persistence import get_storage

bp = Blueprint("matching", __name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
SQLAlchemy backed relational data.
We are trying to make all the persistent data accessed instead
through the storage interface. However, during development, that's
slower than just slinging sql on the tables, so you may see direct
references which are meant to be reaped at some future time.
"""

from dataclasses import dataclass
from OpenMediaMatch import db

import flask_sqlalchemy

# Initializing this at import time seems to be the only correct
# way to do this
db = flask_sqlalchemy.SQLAlchemy()


@dataclass
class Bank(db.Model):
class Bank(db.Model): # type: ignore[name-defined]
__tablename__ = "banks"
id: int = db.Column(db.Integer, primary_key=True, autoincrement=True)
name: str = db.Column(db.String(255), nullable=False)
enabled: bool = db.Column(db.Boolean, nullable=False)


@dataclass
class Hash(db.Model): # Should this be Signal?
class Hash(db.Model): # type: ignore[name-defined] # Should this be Signal?
__tablename__ = "hashes"
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
enabled = db.Column(db.Boolean, nullable=False)
Expand Down
38 changes: 38 additions & 0 deletions open-media-match/src/OpenMediaMatch/persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
A shim around persistence layer of the OMM instance.
This includes relational data, blob storage, etc.
It doesn't include logging (just use current_app.logger).
We haven't made all of the hard decisions on the storage yet, and
think future deployers may change their mind about which backends to
use. We know we are going to have more than relational data, so
SQLAlchemy isn't going to be enough. Thus an even more abstract
accessor.
"""

import typing as t

from flask import g

from OpenMediaMatch.storage.interface import IUnifiedStore
from OpenMediaMatch.storage.default import DefaultOMMStore

T = t.TypeVar("T", bound=t.Callable[..., t.Any])


def get_storage() -> IUnifiedStore:
"""
Get the storage object, which is just a wrapper around the real storage.
"""
if "storage" not in g:
# dougneal, you'll need to eventually add constructor arguments
# for this to pass in the postgres/database object. We're just
# hiding flask bits from pytx bits
g.storage = DefaultOMMStore()
return g.storage
2 changes: 1 addition & 1 deletion open-media-match/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_enabled_signal_types(self) -> t.Mapping[str, t.Type[SignalType]]:
@t.final
def get_enabled_signal_types_for_content_type(
self, content_type: t.Type[ContentType]
) -> t.Mapping[str, SignalType]:
) -> t.Mapping[str, t.Type[SignalType]]:
"""Helper shortcut for getting enabled types for a piece of content"""
return {
k: v.signal_type
Expand Down
Loading

0 comments on commit 35a8ed4

Please sign in to comment.