From 35a8ed47fb058cc1eb2e714513ffe5a00aa36340 Mon Sep 17 00:00:00 2001 From: David Callies Date: Mon, 25 Sep 2023 09:17:48 -0400 Subject: [PATCH] [omm] deplopulate __init__ and delint (#1368) --- open-media-match/.devcontainer/startup.sh | 4 +- .../src/OpenMediaMatch/__init__.py | 77 ------------------- open-media-match/src/OpenMediaMatch/app.py | 76 +++++++++++++++++- .../src/OpenMediaMatch/blueprints/curation.py | 29 ++++--- .../src/OpenMediaMatch/blueprints/hashing.py | 20 ++--- .../src/OpenMediaMatch/blueprints/matching.py | 4 +- .../OpenMediaMatch/{models.py => database.py} | 20 ++++- .../src/OpenMediaMatch/persistence.py | 38 +++++++++ .../src/OpenMediaMatch/storage/interface.py | 2 +- .../src/OpenMediaMatch/storage/mocked.py | 7 +- .../src/OpenMediaMatch/tests/test_api.py | 2 +- .../{app_resources.py => utils.py} | 15 ---- 12 files changed, 163 insertions(+), 131 deletions(-) rename open-media-match/src/OpenMediaMatch/{models.py => database.py} (51%) create mode 100644 open-media-match/src/OpenMediaMatch/persistence.py rename open-media-match/src/OpenMediaMatch/{app_resources.py => utils.py} (71%) diff --git a/open-media-match/.devcontainer/startup.sh b/open-media-match/.devcontainer/startup.sh index 48c44da9c..472a7961d 100755 --- a/open-media-match/.devcontainer/startup.sh +++ b/open-media-match/.devcontainer/startup.sh @@ -1,5 +1,5 @@ #!/bin/bash set -e export OMM_CONFIG=/workspace/.devcontainer/omm_config.py -flask --app OpenMediaMatch db upgrade --directory src/openMediaMatch/migrations -flask --app OpenMediaMatch run --debug +flask --app OpenMediaMatch.app db upgrade --directory /workspace/src/openMediaMatch/migrations +flask --app OpenMediaMatch.app run --debug diff --git a/open-media-match/src/OpenMediaMatch/__init__.py b/open-media-match/src/OpenMediaMatch/__init__.py index a873ba1cb..71ca4b12c 100644 --- a/open-media-match/src/OpenMediaMatch/__init__.py +++ b/open-media-match/src/OpenMediaMatch/__init__.py @@ -1,78 +1 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. - -import os -import sys - -import flask -import flask_migrate -import flask_sqlalchemy - -db = flask_sqlalchemy.SQLAlchemy() -migrate = flask_migrate.Migrate() - - -def create_app(): - """ - Create and configure the Flask app - """ - app = flask.Flask(__name__) - if "OMM_CONFIG" in os.environ: - app.config.from_envvar("OMM_CONFIG") - elif sys.argv[0].endswith("/flask"): # Default for flask CLI - # The devcontainer settings. If you are using the CLI outside - # the devcontainer and getting an error, just override the env - app.config.from_pyfile("/workspace/.devcontainer/omm_config.py") - else: - raise RuntimeError("No flask config given - try populating OMM_CONFIG env") - app.config.update( - SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"), - SQLALCHEMY_TRACK_MODIFICATIONS=False, - ) - app.db = db - - db.init_app(app) - migrate.init_app(app, db) - - @app.route("/") - def index(): - """ - Sanity check endpoint showing a basic status page - TODO: in development mode, this could show some useful additional info - """ - return flask.render_template( - "index.html.j2", production=app.config.get("PRODUCTION") - ) - - @app.route("/status") - def status(): - """ - Liveness/readiness check endpoint for your favourite Layer 7 load balancer - """ - return "I-AM-ALIVE\n" - - # Register Flask blueprints for whichever server roles are enabled... - # URL prefixing facilitates easy Layer 7 routing :) - # Linters complain about imports off the top level, but this is needed - # to prevent circular imports - from .blueprints import hashing, matching, curation - - if app.config.get("ROLE_HASHER", False): - app.register_blueprint(hashing.bp, url_prefix="/h") - - if app.config.get("ROLE_MATCHER", False): - app.register_blueprint(matching.bp, url_prefix="/m") - - if app.config.get("ROLE_CURATOR", False): - app.register_blueprint(curation.bp, url_prefix="/c") - - from . import models - - @app.cli.command("seed") - def seed_data(): - # TODO: This is a placeholder for where some useful seed data can be loaded; - # particularly important for development - bank = models.Bank(name="bad_stuff", enabled=True) - db.session.add(bank) - db.session.commit() - - return app diff --git a/open-media-match/src/OpenMediaMatch/app.py b/open-media-match/src/OpenMediaMatch/app.py index f61a26f0a..78061fad1 100644 --- a/open-media-match/src/OpenMediaMatch/app.py +++ b/open-media-match/src/OpenMediaMatch/app.py @@ -1,3 +1,75 @@ -from OpenMediaMatch import create_app +# Copyright (c) Meta Platforms, Inc. and affiliates. -app = create_app() +import os +import sys + +import flask +import flask_migrate + +from OpenMediaMatch import database +from OpenMediaMatch.blueprints import hashing, matching, curation + + +def create_app(): + """ + Create and configure the Flask app + """ + app = flask.Flask(__name__) + migrate = flask_migrate.Migrate() + + if "OMM_CONFIG" in os.environ: + app.config.from_envvar("OMM_CONFIG") + elif sys.argv[0].endswith("/flask"): # Default for flask CLI + # The devcontainer settings. If you are using the CLI outside + # the devcontainer and getting an error, just override the env + app.config.from_pyfile("/workspace/.devcontainer/omm_config.py") + else: + raise RuntimeError("No flask config given - try populating OMM_CONFIG env") + app.config.update( + SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"), + SQLALCHEMY_TRACK_MODIFICATIONS=False, + ) + + database.db.init_app(app) + migrate.init_app(app, database.db) + + @app.route("/") + def index(): + """ + Sanity check endpoint showing a basic status page + TODO: in development mode, this could show some useful additional info + """ + return flask.render_template( + "index.html.j2", production=app.config.get("PRODUCTION") + ) + + @app.route("/status") + def status(): + """ + Liveness/readiness check endpoint for your favourite Layer 7 load balancer + """ + return "I-AM-ALIVE\n" + + # Register Flask blueprints for whichever server roles are enabled... + # URL prefixing facilitates easy Layer 7 routing :) + # Linters complain about imports off the top level, but this is needed + # to prevent circular imports + + if app.config.get("ROLE_HASHER", False): + app.register_blueprint(hashing.bp, url_prefix="/h") + + if app.config.get("ROLE_MATCHER", False): + app.register_blueprint(matching.bp, url_prefix="/m") + + if app.config.get("ROLE_CURATOR", False): + app.register_blueprint(curation.bp, url_prefix="/c") + + @app.cli.command("seed") + def seed_data(): + # TODO: This is a placeholder for where some useful seed data can be loaded; + # particularly important for development + bank = database.Bank(name="bad_stuff", enabled=True) + database.db.session.add(bank) + database.db.session.commit() + + return app diff --git a/open-media-match/src/OpenMediaMatch/blueprints/curation.py b/open-media-match/src/OpenMediaMatch/blueprints/curation.py index 820b503aa..1d11007fd 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/curation.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/curation.py @@ -1,9 +1,8 @@ from flask import Blueprint -from flask import abort, request, current_app, jsonify +from flask import request, current_app, jsonify -from OpenMediaMatch import models +from OpenMediaMatch import database -import json bp = Blueprint("curation", __name__) @@ -11,7 +10,7 @@ @bp.route("/banks", methods=["GET"]) def banks_index(): banks = ( - current_app.db.session.execute(current_app.db.select(models.Bank)) + database.db.session.execute(current_app.db.select(database.Bank)) .scalars() .all() ) @@ -20,7 +19,7 @@ def banks_index(): @bp.route("/bank/", methods=["GET"]) def bank_show_by_id(bank_id: int): - bank = models.Bank.query.get(bank_id) + bank = database.Bank.query.get(bank_id) if not bank: return jsonify({"message": "bank not found"}), 404 return jsonify(bank) @@ -29,8 +28,8 @@ def bank_show_by_id(bank_id: int): @bp.route("/bank/", methods=["GET"]) def bank_show_by_name(bank_name: str): bank = ( - current_app.db.session.execute( - current_app.db.select(models.Bank).where(models.Bank.name == bank_name) + database.db.session.execute( + database.db.select(database.Bank).where(database.Bank.name == bank_name) ) .scalars() .all() @@ -43,16 +42,16 @@ def bank_create(): data = request.get_json() if not "name" in data: return jsonify({"message": "Field `name` is required"}), 400 - bank = models.Bank(name=data["name"], enabled=bool(data.get("enabled", True))) - current_app.db.session.add(bank) - current_app.db.session.commit() + bank = database.Bank(name=data["name"], enabled=bool(data.get("enabled", True))) + database.db.session.add(bank) + database.db.session.commit() return jsonify({"message": "Created successfully"}), 201 @bp.route("/bank/", methods=["PUT"]) def bank_update(bank_id: int): data = request.get_json() - bank = models.Bank.query.get(bank_id) + bank = database.Bank.query.get(bank_id) if not bank: return jsonify({"message": "bank not found"}), 404 @@ -61,15 +60,15 @@ def bank_update(bank_id: int): if "enabled" in data: bank.enabled = bool(data["enabled"]) - current_app.db.session.commit() + database.db.session.commit() return jsonify(bank) @bp.route("/bank/", methods=["DELETE"]) def bank_delete(bank_id: int): - bank = models.Bank.query.get(bank_id) + bank = database.Bank.query.get(bank_id) if not bank: return jsonify({"message": "bank not found"}), 404 - current_app.db.session.delete(bank) - current_app.db.session.commit() + database.db.session.delete(bank) + database.db.session.commit() return jsonify({"message": f"Bank {bank.name} ({bank.id}) deleted"}) diff --git a/open-media-match/src/OpenMediaMatch/blueprints/hashing.py b/open-media-match/src/OpenMediaMatch/blueprints/hashing.py index 67f995ffb..59b48dc6e 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/hashing.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/hashing.py @@ -17,12 +17,14 @@ from threatexchange.content_type.video import VideoContent from threatexchange.signal_type.signal_base import FileHasher, SignalType -from OpenMediaMatch import app_resources +from OpenMediaMatch.persistence import get_storage +from OpenMediaMatch.utils import abort_to_json bp = Blueprint("hashing", __name__) @bp.route("/hash") +@abort_to_json def hash_media(): """ Fetch content and return its hash. @@ -57,7 +59,7 @@ def hash_media(): return ret -def _parse_request_content_type(url_content_type: str) -> ContentType: +def _parse_request_content_type(url_content_type: str) -> t.Type[ContentType]: arg = request.args.get("content_type", "") if not arg: if url_content_type.lower().startswith("image"): @@ -71,20 +73,20 @@ def _parse_request_content_type(url_content_type: str) -> ContentType: "if you know the expected type, provide it with content_type", ) - storage = app_resources.get_storage() - content_type_config = storage.get_content_type_configs().get(arg) + content_type_config = get_storage().get_content_type_configs().get(arg) if content_type_config is None: - return {"message": f"no such content_type: '{arg}'"}, 400 + abort(400, f"no such content_type: '{arg}'") if not content_type_config.enabled: - return {"message": f"content_type {arg} is disabled"}, 400 + abort(400, f"content_type {arg} is disabled") return content_type_config.content_type -def _parse_request_signal_type(content_type: ContentType) -> t.Mapping[str, SignalType]: - storage = app_resources.get_storage() - signal_types = storage.get_enabled_signal_types_for_content_type(content_type) +def _parse_request_signal_type( + content_type: t.Type[ContentType], +) -> t.Mapping[str, t.Type[SignalType]]: + signal_types = get_storage().get_enabled_signal_types_for_content_type(content_type) if not signal_types: abort(500, "No signal types configured!") signal_type_args = request.args.get("types", None) diff --git a/open-media-match/src/OpenMediaMatch/blueprints/matching.py b/open-media-match/src/OpenMediaMatch/blueprints/matching.py index b53d0d176..b84435260 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/matching.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/matching.py @@ -7,11 +7,11 @@ from flask import Blueprint from flask import abort -from OpenMediaMatch.app_resources import ( +from OpenMediaMatch.utils import ( abort_to_json, require_request_param, - get_storage, ) +from OpenMediaMatch.persistence import get_storage bp = Blueprint("matching", __name__) diff --git a/open-media-match/src/OpenMediaMatch/models.py b/open-media-match/src/OpenMediaMatch/database.py similarity index 51% rename from open-media-match/src/OpenMediaMatch/models.py rename to open-media-match/src/OpenMediaMatch/database.py index 3801d6fcd..2e4611dde 100644 --- a/open-media-match/src/OpenMediaMatch/models.py +++ b/open-media-match/src/OpenMediaMatch/database.py @@ -1,11 +1,25 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +""" +SQLAlchemy backed relational data. + +We are trying to make all the persistent data accessed instead +through the storage interface. However, during development, that's +slower than just slinging sql on the tables, so you may see direct +references which are meant to be reaped at some future time. +""" + from dataclasses import dataclass -from OpenMediaMatch import db + +import flask_sqlalchemy + +# Initializing this at import time seems to be the only correct +# way to do this +db = flask_sqlalchemy.SQLAlchemy() @dataclass -class Bank(db.Model): +class Bank(db.Model): # type: ignore[name-defined] __tablename__ = "banks" id: int = db.Column(db.Integer, primary_key=True, autoincrement=True) name: str = db.Column(db.String(255), nullable=False) @@ -13,7 +27,7 @@ class Bank(db.Model): @dataclass -class Hash(db.Model): # Should this be Signal? +class Hash(db.Model): # type: ignore[name-defined] # Should this be Signal? __tablename__ = "hashes" id = db.Column(db.Integer, primary_key=True, autoincrement=True) enabled = db.Column(db.Boolean, nullable=False) diff --git a/open-media-match/src/OpenMediaMatch/persistence.py b/open-media-match/src/OpenMediaMatch/persistence.py new file mode 100644 index 000000000..0d4392636 --- /dev/null +++ b/open-media-match/src/OpenMediaMatch/persistence.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +A shim around persistence layer of the OMM instance. + +This includes relational data, blob storage, etc. + +It doesn't include logging (just use current_app.logger). + +We haven't made all of the hard decisions on the storage yet, and +think future deployers may change their mind about which backends to +use. We know we are going to have more than relational data, so +SQLAlchemy isn't going to be enough. Thus an even more abstract +accessor. + + +""" + +import typing as t + +from flask import g + +from OpenMediaMatch.storage.interface import IUnifiedStore +from OpenMediaMatch.storage.default import DefaultOMMStore + +T = t.TypeVar("T", bound=t.Callable[..., t.Any]) + + +def get_storage() -> IUnifiedStore: + """ + Get the storage object, which is just a wrapper around the real storage. + """ + if "storage" not in g: + # dougneal, you'll need to eventually add constructor arguments + # for this to pass in the postgres/database object. We're just + # hiding flask bits from pytx bits + g.storage = DefaultOMMStore() + return g.storage diff --git a/open-media-match/src/OpenMediaMatch/storage/interface.py b/open-media-match/src/OpenMediaMatch/storage/interface.py index 6dc68ff5d..b956007fb 100644 --- a/open-media-match/src/OpenMediaMatch/storage/interface.py +++ b/open-media-match/src/OpenMediaMatch/storage/interface.py @@ -77,7 +77,7 @@ def get_enabled_signal_types(self) -> t.Mapping[str, t.Type[SignalType]]: @t.final def get_enabled_signal_types_for_content_type( self, content_type: t.Type[ContentType] - ) -> t.Mapping[str, SignalType]: + ) -> t.Mapping[str, t.Type[SignalType]]: """Helper shortcut for getting enabled types for a piece of content""" return { k: v.signal_type diff --git a/open-media-match/src/OpenMediaMatch/storage/mocked.py b/open-media-match/src/OpenMediaMatch/storage/mocked.py index 7e54bb56e..cc81e52f8 100644 --- a/open-media-match/src/OpenMediaMatch/storage/mocked.py +++ b/open-media-match/src/OpenMediaMatch/storage/mocked.py @@ -30,10 +30,9 @@ def get_exchange_type_configs(self) -> t.Mapping[str, TSignalExchangeAPICls]: return {e.get_name(): e for e in (StaticSampleSignalExchangeAPI,)} def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]: - return { - s.get_name(): interface.SignalTypeConfig(True, s) - for s in (PdqSignal, VideoMD5Signal) - } + # Needed to bamboozle mypy into working + s_types: t.Sequence[t.Type[SignalType]] = (PdqSignal, VideoMD5Signal) + return {s.get_name(): interface.SignalTypeConfig(True, s) for s in s_types} def get_signal_type_index( self, signal_type: type[SignalType] diff --git a/open-media-match/src/OpenMediaMatch/tests/test_api.py b/open-media-match/src/OpenMediaMatch/tests/test_api.py index 57d4bf919..04a8976f3 100644 --- a/open-media-match/src/OpenMediaMatch/tests/test_api.py +++ b/open-media-match/src/OpenMediaMatch/tests/test_api.py @@ -1,7 +1,7 @@ import os import pytest -from OpenMediaMatch import create_app +from OpenMediaMatch.app import create_app @pytest.fixture() diff --git a/open-media-match/src/OpenMediaMatch/app_resources.py b/open-media-match/src/OpenMediaMatch/utils.py similarity index 71% rename from open-media-match/src/OpenMediaMatch/app_resources.py rename to open-media-match/src/OpenMediaMatch/utils.py index 1c2ce8c5d..c2b3f9d20 100644 --- a/open-media-match/src/OpenMediaMatch/app_resources.py +++ b/open-media-match/src/OpenMediaMatch/utils.py @@ -13,24 +13,9 @@ from flask import g, abort, request from werkzeug.exceptions import HTTPException -from OpenMediaMatch.storage.interface import IUnifiedStore -from OpenMediaMatch.storage.default import DefaultOMMStore - T = t.TypeVar("T", bound=t.Callable[..., t.Any]) -def get_storage() -> IUnifiedStore: - """ - Get the storage object, which is just a wrapper around the real storage. - """ - if "storage" not in g: - # dougneal, you'll need to eventually add constructor arguments - # for this to pass in the postgres/database object. We're just - # hiding flask bits from pytx bits - g.storage = DefaultOMMStore() - return g.storage - - def abort_to_json(fn: T) -> T: """ Wrap json endpoints to turn abort("message", code) to json