diff --git a/alembic/versions/20230810_0df58829fc1a_add_discovery_service_tables.py b/alembic/versions/20230810_0df58829fc1a_add_discovery_service_tables.py new file mode 100644 index 0000000000..70b14f1dcc --- /dev/null +++ b/alembic/versions/20230810_0df58829fc1a_add_discovery_service_tables.py @@ -0,0 +1,151 @@ +"""Add discovery service tables + +Revision ID: 0df58829fc1a +Revises: 2f1a51aa0ee8 +Create Date: 2023-08-10 15:49:36.784169+00:00 + +""" +import sqlalchemy as sa + +from alembic import op +from api.discovery.opds_registration import OpdsRegistrationService +from core.migration.migrate_external_integration import ( + _migrate_external_integration, + get_configuration_settings, + get_integrations, + get_library_for_integration, +) +from core.migration.util import drop_enum, pg_update_enum + +# revision identifiers, used by Alembic. +revision = "0df58829fc1a" +down_revision = "2f1a51aa0ee8" +branch_labels = None +depends_on = None + +old_goals_enum = [ + "PATRON_AUTH_GOAL", + "LICENSE_GOAL", +] + +new_goals_enum = old_goals_enum + ["DISCOVERY_GOAL"] + + +def upgrade() -> None: + op.create_table( + "discovery_service_registrations", + sa.Column( + "status", + sa.Enum("SUCCESS", "FAILURE", name="registrationstatus"), + nullable=False, + ), + sa.Column( + "stage", + sa.Enum("TESTING", "PRODUCTION", name="registrationstage"), + nullable=False, + ), + sa.Column("web_client", sa.Unicode(), nullable=True), + sa.Column("short_name", sa.Unicode(), nullable=True), + sa.Column("shared_secret", sa.Unicode(), nullable=True), + sa.Column("integration_id", sa.Integer(), nullable=False), + sa.Column("library_id", sa.Integer(), nullable=False), + sa.Column("vendor_id", sa.Unicode(), nullable=True), + sa.ForeignKeyConstraint( + ["integration_id"], ["integration_configurations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["library_id"], ["libraries.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("integration_id", "library_id"), + ) + pg_update_enum( + op, + 
"integration_configurations", + "goal", + "goals", + old_goals_enum, + new_goals_enum, + ) + + # Migrate data + connection = op.get_bind() + external_integrations = get_integrations(connection, "discovery") + for external_integration in external_integrations: + # This should always be the case, but we want to make sure + assert external_integration.protocol == "OPDS Registration" + + # Create the settings and library settings dicts from the configurationsettings + settings_dict, library_settings, self_test_result = get_configuration_settings( + connection, external_integration + ) + + # Write the configurationsettings into the integration_configurations table + integration_configuration_id = _migrate_external_integration( + connection, + external_integration, + OpdsRegistrationService, + "DISCOVERY_GOAL", + settings_dict, + self_test_result, + ) + + # Get the libraries that are associated with this external integration + interation_libraries = get_library_for_integration( + connection, external_integration.id + ) + + vendor_id = settings_dict.get("vendor_id") + + # Write the library settings into the discovery_service_registrations table + for library in interation_libraries: + library_id = library.library_id + library_settings_dict = library_settings[library_id] + + status = library_settings_dict.get("library-registration-status") + if status is None: + status = "FAILURE" + else: + status = status.upper() + + stage = library_settings_dict.get("library-registration-stage") + if stage is None: + stage = "TESTING" + else: + stage = stage.upper() + + web_client = library_settings_dict.get("library-registration-web-client") + short_name = library_settings_dict.get("username") + shared_secret = library_settings_dict.get("password") + + connection.execute( + "insert into discovery_service_registrations " + "(status, stage, web_client, short_name, shared_secret, integration_id, library_id, vendor_id) " + "values (%s, %s, %s, %s, %s, %s, %s, %s)", + ( + status, + stage, + 
web_client, + short_name, + shared_secret, + integration_configuration_id, + library_id, + vendor_id, + ), + ) + + +def downgrade() -> None: + connection = op.get_bind() + connection.execute( + "DELETE from integration_configurations where goal = %s", "DISCOVERY_GOAL" + ) + + op.drop_table("discovery_service_registrations") + drop_enum(op, "registrationstatus") + drop_enum(op, "registrationstage") + pg_update_enum( + op, + "integration_configurations", + "goal", + "goals", + new_goals_enum, + old_goals_enum, + ) diff --git a/api/admin/controller/discovery_service_library_registrations.py b/api/admin/controller/discovery_service_library_registrations.py index 85f21ec0a0..61ab9047a0 100644 --- a/api/admin/controller/discovery_service_library_registrations.py +++ b/api/admin/controller/discovery_service_library_registrations.py @@ -1,69 +1,84 @@ +from __future__ import annotations + import json +from typing import Any, Dict import flask from flask import Response, url_for from flask_babel import lazy_gettext as _ +from sqlalchemy import select +from sqlalchemy.orm import Session -from api.admin.controller.settings import SettingsController +from api.admin.controller.base import AdminPermissionsControllerMixin from api.admin.problem_details import MISSING_SERVICE, NO_SUCH_LIBRARY -from api.registration.registry import Registration, RemoteRegistry -from core.model import ExternalIntegration, Library, get_one -from core.util.http import HTTP -from core.util.problem_detail import ProblemDetail +from api.controller import CirculationManager +from api.discovery.opds_registration import OpdsRegistrationService +from api.integration.registry.discovery import DiscoveryRegistry +from core.integration.goals import Goals +from core.model import IntegrationConfiguration, Library, get_one +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStage, +) +from core.problem_details import INVALID_INPUT +from core.util.problem_detail 
import ProblemDetail, ProblemError -class DiscoveryServiceLibraryRegistrationsController(SettingsController): +class DiscoveryServiceLibraryRegistrationsController(AdminPermissionsControllerMixin): """List the libraries that have been registered with a specific - RemoteRegistry, and allow the admin to register a library with - a RemoteRegistry. - - :param registration_class: Mock class to use instead of Registration. + OpdsRegistrationService, and allow the admin to register a library with + a OpdsRegistrationService. """ - def __init__(self, manager): - super().__init__(manager) - self.goal = ExternalIntegration.DISCOVERY_GOAL + def __init__(self, manager: CirculationManager): + self._db: Session = manager._db + self.goal = Goals.DISCOVERY_GOAL + self.registry = DiscoveryRegistry() def process_discovery_service_library_registrations( self, - registration_class=None, - do_get=HTTP.debuggable_get, - do_post=HTTP.debuggable_post, - ): - registration_class = registration_class or Registration + ) -> Response | Dict[str, Any] | ProblemDetail: self.require_system_admin() - if flask.request.method == "GET": - return self.process_get(do_get) - else: - return self.process_post(registration_class, do_get, do_post) + try: + if flask.request.method == "GET": + return self.process_get() + else: + return self.process_post() + except ProblemError as e: + self._db.rollback() + return e.problem_detail - def process_get(self, do_get=HTTP.debuggable_get): + def process_get(self) -> Dict[str, Any]: """Make a list of all discovery services, each with the list of libraries registered with that service and the status of the registration.""" services = [] - for registry in RemoteRegistry.for_protocol_and_goal( - self._db, ExternalIntegration.OPDS_REGISTRATION, self.goal - ): - result = registry.fetch_registration_document(do_get=do_get) - if isinstance(result, ProblemDetail): - # Unlike most cases like this, a ProblemDetail doesn't + integration_query = 
select(IntegrationConfiguration).where( + IntegrationConfiguration.goal == self.goal, + IntegrationConfiguration.protocol + == self.registry.get_protocol(OpdsRegistrationService), + ) + integrations = self._db.scalars(integration_query).all() + for integration in integrations: + registry = OpdsRegistrationService.for_integration(self._db, integration) + try: + access_problem = None + ( + terms_of_service_link, + terms_of_service_html, + ) = registry.fetch_registration_document() + except ProblemError as e: + # Unlike most cases like this, a ProblemError doesn't # mean the whole request is ruined -- just that one of # the discovery services isn't working. Turn the # ProblemDetail into a JSON object and return it for # handling on the client side. - access_problem = json.loads(result.response[0]) + access_problem = json.loads(e.problem_detail.response[0]) terms_of_service_link = terms_of_service_html = None - else: - access_problem = None - terms_of_service_link, terms_of_service_html = result - libraries = [] - for registration in registry.registrations: - library_info = self.get_library_info(registration) - if library_info: - libraries.append(library_info) + + libraries = [self.get_library_info(r) for r in registry.registrations] services.append( dict( @@ -77,59 +92,77 @@ def process_get(self, do_get=HTTP.debuggable_get): return dict(library_registrations=services) - def get_library_info(self, registration): + def get_library_info( + self, registration: DiscoveryServiceRegistration + ) -> Dict[str, str]: """Find the relevant information about the library which the user is trying to register""" - library = registration.library - library_info = dict(short_name=library.short_name) - status = registration.status_field.value - stage_field = registration.stage_field.value - if stage_field: - library_info["stage"] = stage_field + library_info = {"short_name": str(registration.library.short_name)} + status = registration.status + stage = registration.stage + if stage: + 
library_info["stage"] = stage.value if status: - library_info["status"] = status - return library_info + library_info["status"] = status.value - def look_up_registry(self, integration_id): - """Find the RemoteRegistry that the user is trying to register the library with, + return library_info + + def look_up_registry(self, integration_id: int) -> OpdsRegistrationService: + """Find the OpdsRegistrationService that the user is trying to register the library with, and check that it actually exists.""" - registry = RemoteRegistry.for_integration_id( - self._db, integration_id, self.goal - ) + registry = OpdsRegistrationService.for_integration(self._db, integration_id) if not registry: - return MISSING_SERVICE + raise ProblemError(problem_detail=MISSING_SERVICE) return registry - def look_up_library(self, library_short_name): + def look_up_library(self, library_short_name: str) -> Library: """Find the library the user is trying to register, and check that it actually exists.""" library = get_one(self._db, Library, short_name=library_short_name) if not library: - return NO_SUCH_LIBRARY + raise ProblemError(problem_detail=NO_SUCH_LIBRARY) return library - def process_post(self, registration_class, do_get, do_post): - """Attempt to register a library with a RemoteRegistry.""" + def process_post(self) -> Response: + """Attempt to register a library with a OpdsRegistrationService.""" - integration_id = flask.request.form.get("integration_id") + integration_id = flask.request.form.get("integration_id", type=int) library_short_name = flask.request.form.get("library_short_name") - stage = ( - flask.request.form.get("registration_stage") or Registration.TESTING_STAGE - ) + stage_string = flask.request.form.get("registration_stage") + if integration_id is None: + raise ProblemError( + problem_detail=INVALID_INPUT.detailed( + "Missing required parameter 'integration_id'" + ) + ) registry = self.look_up_registry(integration_id) - if isinstance(registry, ProblemDetail): - return 
registry + if library_short_name is None: + raise ProblemError( + problem_detail=INVALID_INPUT.detailed( + "Missing required parameter 'library_short_name'" + ) + ) library = self.look_up_library(library_short_name) - if isinstance(library, ProblemDetail): - return library - registration = registration_class(registry, library) - registered = registration.push(stage, url_for, do_get=do_get, do_post=do_post) - if isinstance(registered, ProblemDetail): - return registered + if stage_string is None: + raise ProblemError( + problem_detail=INVALID_INPUT.detailed( + "Missing required parameter 'registration_stage'" + ) + ) + try: + stage = RegistrationStage(stage_string) + except ValueError: + raise ProblemError( + problem_detail=INVALID_INPUT.detailed( + f"'{stage_string}' is not a valid registration stage" + ) + ) + + registry.register_library(library, stage, url_for) return Response(str(_("Success")), 200) diff --git a/api/admin/controller/discovery_services.py b/api/admin/controller/discovery_services.py index e52757ada7..ad2b9b3eda 100644 --- a/api/admin/controller/discovery_services.py +++ b/api/admin/controller/discovery_services.py @@ -1,153 +1,135 @@ +from typing import Union + import flask from flask import Response -from flask_babel import lazy_gettext as _ +from sqlalchemy import and_, select -from api.admin.controller.settings import SettingsController +from api.admin.controller.base import AdminPermissionsControllerMixin +from api.admin.controller.integration_settings import IntegrationSettingsController +from api.admin.form_data import ProcessFormData from api.admin.problem_details import ( - CANNOT_CHANGE_PROTOCOL, INCOMPLETE_CONFIGURATION, - MISSING_SERVICE, + INTEGRATION_URL_ALREADY_IN_USE, NO_PROTOCOL_FOR_NEW_SERVICE, + UNKNOWN_PROTOCOL, +) +from api.discovery.opds_registration import OpdsRegistrationService +from api.integration.registry.discovery import DiscoveryRegistry +from core.model import ( + IntegrationConfiguration, + json_serializer, + 
site_configuration_has_changed, ) -from api.registration.registry import RemoteRegistry -from core.model import ExternalIntegration, get_one_or_create -from core.util.problem_detail import ProblemDetail - - -class DiscoveryServicesController(SettingsController): - def __init__(self, manager): - super().__init__(manager) - self.opds_registration = ExternalIntegration.OPDS_REGISTRATION - self.protocols = [ - { - "name": self.opds_registration, - "sitewide": True, - "settings": [ - { - "key": ExternalIntegration.URL, - "label": _("URL"), - "required": True, - "format": "url", - }, - ], - "supports_registration": True, - "supports_staging": True, - } - ] - self.goal = ExternalIntegration.DISCOVERY_GOAL - - def process_discovery_services(self): +from core.util.problem_detail import ProblemDetail, ProblemError + + +class DiscoveryServicesController( + IntegrationSettingsController[OpdsRegistrationService], + AdminPermissionsControllerMixin, +): + def default_registry(self) -> DiscoveryRegistry: + return DiscoveryRegistry() + + def process_discovery_services(self) -> Union[Response, ProblemDetail]: self.require_system_admin() if flask.request.method == "GET": return self.process_get() else: return self.process_post() - def process_get(self): - registries = list( - RemoteRegistry.for_protocol_and_goal( - self._db, self.opds_registration, self.goal - ) - ) - if not registries: - # There are no registries at all. Set up the default - # library registry. 
+ def process_get(self) -> Response: + if len(self.configured_services) == 0: self.set_up_default_registry() - services = self._get_integration_info(self.goal, self.protocols) - return dict( - discovery_services=services, - protocols=self.protocols, + return Response( + json_serializer( + { + "discovery_services": self.configured_services, + "protocols": list(self.protocols.values()), + } + ), + status=200, + mimetype="application/json", ) - def set_up_default_registry(self): + def set_up_default_registry(self) -> None: """Set up the default library registry; no other registries exist yet.""" - - service, is_new = get_one_or_create( - self._db, - ExternalIntegration, - protocol=self.opds_registration, - goal=self.goal, + protocol = self.registry.get_protocol(OpdsRegistrationService) + assert protocol is not None + default_registry = self.create_new_service( + name=OpdsRegistrationService.DEFAULT_LIBRARY_REGISTRY_NAME, + protocol=protocol, ) - if is_new: - service.url = RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL - service.name = RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_NAME - - def process_post(self): - name = flask.request.form.get("name") - protocol = flask.request.form.get("protocol") - fields = {"name": name, "protocol": protocol} - form_field_error = self.validate_form_fields(**fields) - if form_field_error: - return form_field_error - - id = flask.request.form.get("id") - is_new = False - if id: - # Find an existing service in order to edit it - service = self.look_up_service_from_registry(protocol, id) - else: - service, is_new = self._create_integration( - self.protocols, protocol, self.goal - ) - - if isinstance(service, ProblemDetail): - self._db.rollback() - return service - - name_error = self.check_name_unique(service, name) - if name_error: + settings = OpdsRegistrationService.settings_class()( + url=OpdsRegistrationService.DEFAULT_LIBRARY_REGISTRY_URL + ) + default_registry.settings_dict = settings.dict() + + def process_post(self) -> Union[Response, 
ProblemDetail]: + try: + form_data = flask.request.form + protocol = form_data.get("protocol", None, str) + id = form_data.get("id", None, int) + name = form_data.get("name", None, str) + + if protocol is None and id is None: + raise ProblemError(NO_PROTOCOL_FOR_NEW_SERVICE) + + if protocol is None or protocol not in self.registry: + self.log.warning(f"Unknown service protocol: {protocol}") + raise ProblemError(UNKNOWN_PROTOCOL) + + if id is not None: + # Find an existing service to edit + service = self.get_existing_service(id, name, protocol) + response_code = 200 + else: + # Create a new service + if name is None: + raise ProblemError(INCOMPLETE_CONFIGURATION) + service = self.create_new_service(name, protocol) + response_code = 201 + + impl_cls = self.registry[protocol] + settings_class = impl_cls.settings_class() + validated_settings = ProcessFormData.get_settings(settings_class, form_data) + service.settings_dict = validated_settings.dict() + + # Make sure that the URL of the service is unique. 
+ self.check_url_unique(service, validated_settings.url) + + # Trigger a site configuration change + site_configuration_has_changed(self._db) + + except ProblemError as e: self._db.rollback() - return name_error + return e.problem_detail - url = flask.request.form.get("url") - url_not_unique = self.check_url_unique(service, url, protocol, self.goal) - if url_not_unique: - self._db.rollback() - return url_not_unique + return Response(str(service.id), response_code) - protocol_error = self.set_protocols(service, protocol) - if protocol_error: + def process_delete(self, service_id: int) -> Union[Response, ProblemDetail]: + self.require_system_admin() + try: + return self.delete_service(service_id) + except ProblemError as e: self._db.rollback() - return protocol_error - - service.name = name - - if is_new: - return Response(str(service.id), 201) - else: - return Response(str(service.id), 200) - - def validate_form_fields(self, **fields): - """The 'name' and 'protocol' fields cannot be blank, and the protocol must - be selected from the list of recognized protocols. 
The URL must be valid.""" - - name = fields.get("name") - protocol = fields.get("protocol") - if not name: - return INCOMPLETE_CONFIGURATION - if not protocol: - return NO_PROTOCOL_FOR_NEW_SERVICE - - error = self.validate_protocol() - if error: - return error - - wrong_format = self.validate_formats() - if wrong_format: - return wrong_format - - def look_up_service_from_registry(self, protocol, id): - """Find an existing service, and make sure that the user is not trying to edit - its protocol.""" - - registry = RemoteRegistry.for_integration_id(self._db, id, self.goal) - if not registry: - return MISSING_SERVICE - service = registry.integration - if protocol != service.protocol: - return CANNOT_CHANGE_PROTOCOL - return service - - def process_delete(self, service_id): - return self._delete_integration(service_id, ExternalIntegration.DISCOVERY_GOAL) + return e.problem_detail + + def check_url_unique(self, service: IntegrationConfiguration, url: str) -> None: + """Check that the URL of the service is unique. + + :raises ProblemError: If the URL is not unique. 
+ """ + + existing_service = self._db.scalars( + select(IntegrationConfiguration).where( + and_( + IntegrationConfiguration.goal == service.goal, + IntegrationConfiguration.protocol == service.protocol, + IntegrationConfiguration.settings_dict.contains({"url": url}), + IntegrationConfiguration.id != service.id, + ) + ) + ).one_or_none() + if existing_service: + raise ProblemError(problem_detail=INTEGRATION_URL_ALREADY_IN_USE) diff --git a/api/admin/controller/integration_settings.py b/api/admin/controller/integration_settings.py new file mode 100644 index 0000000000..b1e8840f2f --- /dev/null +++ b/api/admin/controller/integration_settings.py @@ -0,0 +1,310 @@ +import json +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, NamedTuple, Optional, Type, TypeVar + +import flask +from flask import Response + +from api.admin.problem_details import ( + CANNOT_CHANGE_PROTOCOL, + INTEGRATION_NAME_ALREADY_IN_USE, + MISSING_SERVICE, + NO_SUCH_LIBRARY, +) +from api.controller import CirculationManager +from core.integration.base import ( + HasIntegrationConfiguration, + HasLibraryIntegrationConfiguration, +) +from core.integration.registry import IntegrationRegistry +from core.integration.settings import BaseSettings +from core.model import ( + IntegrationConfiguration, + IntegrationLibraryConfiguration, + Library, + create, + get_one, +) +from core.problem_details import INTERNAL_SERVER_ERROR, INVALID_INPUT +from core.util.cache import memoize +from core.util.problem_detail import ProblemError + +T = TypeVar("T", bound=HasIntegrationConfiguration) + + +class UpdatedLibrarySettingsTuple(NamedTuple): + integration: IntegrationLibraryConfiguration + settings: Dict[str, Any] + + +class ChangedLibrariesTuple(NamedTuple): + new: List[UpdatedLibrarySettingsTuple] + updated: List[UpdatedLibrarySettingsTuple] + removed: List[IntegrationLibraryConfiguration] + + +class IntegrationSettingsController(ABC, Generic[T]): + def __init__( + self, 
+ manager: CirculationManager, + registry: Optional[IntegrationRegistry[T]] = None, + ): + self._db = manager._db + self.registry = registry or self.default_registry() + self.log = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}") + + @abstractmethod + def default_registry(self) -> IntegrationRegistry[T]: + """ + Return the IntegrationRegistry for the controller's goal. + """ + ... + + @memoize(ttls=1800) + def _cached_protocols(self) -> Dict[str, Dict[str, Any]]: + """Cached result for integration implementations""" + protocols = {} + for name, api in self.registry: + protocol = { + "name": name, + "label": api.label(), + "description": api.description(), + "settings": api.settings_class().configuration_form(self._db), + } + if issubclass(api, HasLibraryIntegrationConfiguration): + protocol[ + "library_settings" + ] = api.library_settings_class().configuration_form(self._db) + protocol.update(api.protocol_details(self._db)) + protocols[name] = protocol + return protocols + + @property + def protocols(self) -> Dict[str, Dict[str, Any]]: + """Use a property for implementations to allow expiring cached results""" + return self._cached_protocols() + + @property + def configured_services(self) -> List[Dict[str, Any]]: + """Return a list of all currently configured services for the controller's goal.""" + configured_services = [] + for service in ( + self._db.query(IntegrationConfiguration) + .filter(IntegrationConfiguration.goal == self.registry.goal) + .order_by(IntegrationConfiguration.name) + ): + if service.protocol not in self.registry: + self.log.warning( + f"Unknown protocol: {service.protocol} for goal {self.registry.goal}" + ) + continue + + service_info = { + "id": service.id, + "name": service.name, + "protocol": service.protocol, + "settings": service.settings_dict, + } + + api = self.registry[service.protocol] + if issubclass(api, HasLibraryIntegrationConfiguration): + libraries = [] + for library_settings in service.library_configurations: 
+ library_info = {"short_name": library_settings.library.short_name} + library_info.update(library_settings.settings_dict) + libraries.append(library_info) + service_info["libraries"] = libraries + + configured_services.append(service_info) + return configured_services + + def get_existing_service( + self, service_id: int, name: Optional[str], protocol: str + ) -> IntegrationConfiguration: + """ + Query for an existing service to edit. + + Raises ProblemError if the service doesn't exist, or if the protocol + doesn't match. If the name is provided, the service will be renamed if + necessary and a ProblemError will be raised if the name is already in + use. + """ + service: Optional[IntegrationConfiguration] = get_one( + self._db, + IntegrationConfiguration, + id=service_id, + goal=self.registry.goal, + ) + if service is None: + raise ProblemError(MISSING_SERVICE) + if service.protocol != protocol: + raise ProblemError(CANNOT_CHANGE_PROTOCOL) + if name is not None and service.name != name: + service_with_name = get_one(self._db, IntegrationConfiguration, name=name) + if service_with_name is not None: + raise ProblemError(INTEGRATION_NAME_ALREADY_IN_USE) + service.name = name + + return service + + def create_new_service(self, name: str, protocol: str) -> IntegrationConfiguration: + """ + Create a new service. + + Returns the new IntegrationConfiguration on success and raises a ProblemError + on any errors. + """ + # Create a new service + service_with_name = get_one(self._db, IntegrationConfiguration, name=name) + if service_with_name is not None: + raise ProblemError(INTEGRATION_NAME_ALREADY_IN_USE) + + new_service, _ = create( + self._db, + IntegrationConfiguration, + protocol=protocol, + goal=self.registry.goal, + name=name, + ) + if not new_service: + raise ProblemError( + INTERNAL_SERVER_ERROR.detailed( + f"Could not create the '{self.registry.goal.value}' integration." 
+ ) + ) + return new_service + + def get_library(self, short_name: str) -> Library: + """ + Get a library by its short name. + """ + library: Optional[Library] = get_one(self._db, Library, short_name=short_name) + if library is None: + raise ProblemError( + NO_SUCH_LIBRARY.detailed( + f"You attempted to add the integration to {short_name}, but it does not exist.", + ) + ) + return library + + def create_library_settings( + self, service: IntegrationConfiguration, short_name: str + ) -> IntegrationLibraryConfiguration: + """ + Create a new IntegrationLibraryConfiguration for the given IntegrationConfiguration and library. + """ + library = self.get_library(short_name) + library_settings, _ = create( + self._db, + IntegrationLibraryConfiguration, + library=library, + parent_id=service.id, + ) + if not library_settings: + raise ProblemError( + INTERNAL_SERVER_ERROR.detailed( + "Could not create the library configuration" + ) + ) + return library_settings + + def get_changed_libraries( + self, service: IntegrationConfiguration, libraries_data: str + ) -> ChangedLibrariesTuple: + """ + Return a tuple of lists of libraries that have had their library settings + added, updated, or removed. 
+ """ + libraries = json.loads(libraries_data) + existing_library_settings = { + c.library.short_name: c for c in service.library_configurations + } + submitted_library_settings = {l.get("short_name"): l for l in libraries} + + removed = [ + existing_library_settings[library] + for library in existing_library_settings.keys() + - submitted_library_settings.keys() + ] + updated = [ + UpdatedLibrarySettingsTuple( + integration=existing_library_settings[library], + settings=submitted_library_settings[library], + ) + for library in existing_library_settings.keys() + & submitted_library_settings.keys() + if library and self.get_library(library) + ] + new = [ + UpdatedLibrarySettingsTuple( + integration=self.create_library_settings(service, library), + settings=submitted_library_settings[library], + ) + for library in submitted_library_settings.keys() + - existing_library_settings.keys() + ] + return ChangedLibrariesTuple(new=new, updated=updated, removed=removed) + + def process_deleted_libraries( + self, removed: List[IntegrationLibraryConfiguration] + ) -> None: + """ + Delete any IntegrationLibraryConfigurations that were removed. + """ + for library_integration in removed: + self._db.delete(library_integration) + + def process_updated_libraries( + self, + libraries: List[UpdatedLibrarySettingsTuple], + settings_class: Type[BaseSettings], + ) -> None: + """ + Update the settings for any IntegrationLibraryConfigurations that were updated or added. + """ + for integration, settings in libraries: + validated_settings = settings_class(**settings) + integration.settings_dict = validated_settings.dict() + + def process_libraries( + self, + service: IntegrationConfiguration, + libraries_data: str, + settings_class: Type[BaseSettings], + ) -> None: + """ + Process the library settings for a service. This will create new + IntegrationLibraryConfigurations for any libraries that don't have one, + update the settings for any that do, and delete any that were removed. 
+ """ + new, updated, removed = self.get_changed_libraries(service, libraries_data) + + self.process_deleted_libraries(removed) + self.process_updated_libraries(new, settings_class) + self.process_updated_libraries(updated, settings_class) + + def delete_service(self, service_id: int) -> Response: + """ + Delete a service. + + Returns a Response on success suitable to return to the frontend + and raises a ProblemError on any errors. + """ + if flask.request.method != "DELETE": + raise ProblemError( + problem_detail=INVALID_INPUT.detailed( + "Method not allowed for this endpoint" + ) + ) + + integration = get_one( + self._db, + IntegrationConfiguration, + id=service_id, + goal=self.registry.goal, + ) + if not integration: + raise ProblemError(problem_detail=MISSING_SERVICE) + self._db.delete(integration) + return Response("Deleted", 200) diff --git a/api/admin/controller/metadata_services.py b/api/admin/controller/metadata_services.py index 719d3df6ca..24e0003cf0 100644 --- a/api/admin/controller/metadata_services.py +++ b/api/admin/controller/metadata_services.py @@ -10,7 +10,6 @@ from api.novelist import NoveListAPI from api.nyt import NYTBestSellerAPI from core.model import ExternalIntegration, get_one -from core.util.http import HTTP from core.util.problem_detail import ProblemDetail @@ -63,7 +62,7 @@ def find_protocol_class(self, integration): "No metadata self-test class for protocol %s" % integration.protocol ) - def process_post(self, do_get=HTTP.debuggable_get, do_post=HTTP.debuggable_post): + def process_post(self): name = flask.request.form.get("name") protocol = flask.request.form.get("protocol") url = flask.request.form.get("url") diff --git a/api/admin/controller/patron_auth_services.py b/api/admin/controller/patron_auth_services.py index 417ba72844..ec58c28e5e 100644 --- a/api/admin/controller/patron_auth_services.py +++ b/api/admin/controller/patron_auth_services.py @@ -1,70 +1,35 @@ -import json -import logging -from itertools import chain -from 
typing import Any, Dict, List, Optional, Set, Type, Union +from typing import List, Set, Type, Union import flask from flask import Response -from flask_babel import lazy_gettext as _ from api.admin.controller.base import AdminPermissionsControllerMixin +from api.admin.controller.integration_settings import ( + IntegrationSettingsController, + UpdatedLibrarySettingsTuple, +) from api.admin.form_data import ProcessFormData from api.admin.problem_details import * from api.authentication.base import AuthenticationProvider from api.authentication.basic import BasicAuthenticationProvider -from api.controller import CirculationManager from api.integration.registry.patron_auth import PatronAuthRegistry from core.integration.goals import Goals from core.integration.registry import IntegrationRegistry from core.integration.settings import BaseSettings -from core.model import ( - Library, - create, - get_one, - json_serializer, - site_configuration_has_changed, -) +from core.model import json_serializer, site_configuration_has_changed from core.model.integration import ( IntegrationConfiguration, IntegrationLibraryConfiguration, ) -from core.util.cache import memoize from core.util.problem_detail import ProblemDetail, ProblemError -class PatronAuthServicesController(AdminPermissionsControllerMixin): - def __init__( - self, - manager: CirculationManager, - auth_registry: Optional[IntegrationRegistry[AuthenticationProvider]] = None, - ): - self._db = manager._db - self.registry = auth_registry if auth_registry else PatronAuthRegistry() - self.type = _("patron authentication service") - self.log = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}") - self._apis = None - - @memoize(ttls=1800) - def _cached_protocols(self) -> Dict[str, Dict[str, Any]]: - """Cached result for integration implementations""" - protocols = {} - for name, api in self.registry: - - protocols[name] = { - "name": name, - "label": api.label(), - "description": api.description(), - 
"settings": api.settings_class().configuration_form(self._db), - "library_settings": api.library_settings_class().configuration_form( - self._db - ), - } - return protocols - - @property - def protocols(self) -> Dict[str, Dict[str, Any]]: - """Use a property for implementations to allow expiring cached results""" - return self._cached_protocols() +class PatronAuthServicesController( + IntegrationSettingsController[AuthenticationProvider], + AdminPermissionsControllerMixin, +): + def default_registry(self) -> IntegrationRegistry[AuthenticationProvider]: + return PatronAuthRegistry() @property def basic_auth_protocols(self) -> Set[str]: @@ -74,36 +39,6 @@ def basic_auth_protocols(self) -> Set[str]: if issubclass(api, BasicAuthenticationProvider) } - @property - def configured_services(self) -> List[Dict[str, Any]]: - configured_services = [] - for service in ( - self._db.query(IntegrationConfiguration) - .filter(IntegrationConfiguration.goal == Goals.PATRON_AUTH_GOAL) - .order_by(IntegrationConfiguration.name) - ): - if service.protocol not in self.registry: - self.log.warning( - f"Unknown patron authentication service implementation: {service.protocol}" - ) - continue - - libraries = [] - for library_settings in service.library_configurations: - library_info = {"short_name": library_settings.library.short_name} - library_info.update(library_settings.settings_dict) - libraries.append(library_info) - - service_info = { - "id": service.id, - "name": service.name, - "protocol": service.protocol, - "settings": service.settings_dict, - "libraries": libraries, - } - configured_services.append(service_info) - return configured_services - def process_patron_auth_services(self) -> Union[Response, ProblemDetail]: self.require_system_admin() @@ -124,126 +59,6 @@ def process_get(self) -> Response: mimetype="application/json", ) - def get_existing_service( - self, service_id: int, name: Optional[str], protocol: str - ) -> IntegrationConfiguration: - # Find an existing service to 
edit - auth_service: Optional[IntegrationConfiguration] = get_one( - self._db, - IntegrationConfiguration, - id=service_id, - goal=Goals.PATRON_AUTH_GOAL, - ) - if auth_service is None: - raise ProblemError(MISSING_SERVICE) - if auth_service.protocol != protocol: - raise ProblemError(CANNOT_CHANGE_PROTOCOL) - if name is not None and auth_service.name != name: - service_with_name = get_one(self._db, IntegrationConfiguration, name=name) - if service_with_name is not None: - raise ProblemError(INTEGRATION_NAME_ALREADY_IN_USE) - auth_service.name = name - - return auth_service - - def create_new_service(self, name: str, protocol: str) -> IntegrationConfiguration: - # Create a new service - service_with_name = get_one(self._db, IntegrationConfiguration, name=name) - if service_with_name is not None: - raise ProblemError(INTEGRATION_NAME_ALREADY_IN_USE) - - auth_service, _ = create( - self._db, - IntegrationConfiguration, - protocol=protocol, - goal=Goals.PATRON_AUTH_GOAL, - name=name, - ) - if not auth_service: - raise ProblemError( - INTERNAL_SERVER_ERROR.detailed( - "Could not create the Authentication integration." 
- ) - ) - return auth_service - - def remove_library_settings( - self, library_settings: IntegrationLibraryConfiguration - ) -> None: - self._db.delete(library_settings) - - def get_library(self, short_name: str) -> Library: - library: Optional[Library] = get_one(self._db, Library, short_name=short_name) - if library is None: - raise ProblemError( - NO_SUCH_LIBRARY.detailed( - f"You attempted to add the integration to {short_name}, but it does not exist.", - ) - ) - return library - - def create_library_settings( - self, auth_service: IntegrationConfiguration, short_name: str - ) -> IntegrationLibraryConfiguration: - library = self.get_library(short_name) - library_settings, _ = create( - self._db, - IntegrationLibraryConfiguration, - library=library, - parent_id=auth_service.id, - ) - if not library_settings: - raise ProblemError( - INTERNAL_SERVER_ERROR.detailed( - "Could not create the library configuration" - ) - ) - return library_settings - - def process_libraries( - self, - auth_service: IntegrationConfiguration, - libraries_data: str, - settings_class: Type[BaseSettings], - ) -> None: - # Update libraries - libraries = json.loads(libraries_data) - existing_library_settings = { - c.library.short_name: c for c in auth_service.library_configurations - } - submitted_library_settings = {l.get("short_name"): l for l in libraries} - - removed = [ - existing_library_settings[library] - for library in existing_library_settings.keys() - - submitted_library_settings.keys() - ] - updated = [ - (existing_library_settings[library], submitted_library_settings[library]) - for library in existing_library_settings.keys() - & submitted_library_settings.keys() - if library and self.get_library(library) - ] - new = [ - ( - self.create_library_settings(auth_service, library), - submitted_library_settings[library], - ) - for library in submitted_library_settings.keys() - - existing_library_settings.keys() - ] - - # Remove libraries that are no longer configured - for 
library_settings in removed: - self.remove_library_settings(library_settings) - - # Update new and existing libraries settings - for integration, settings in chain(new, updated): - validated_settings = settings_class(**settings) - integration.settings_dict = validated_settings.dict() - # Make sure library doesn't have multiple auth basic auth services - self.check_library_integrations(integration.library) - def process_post(self) -> Union[Response, ProblemDetail]: try: form_data = flask.request.form @@ -293,8 +108,12 @@ def process_post(self) -> Union[Response, ProblemDetail]: return Response(str(auth_service.id), response_code) - def check_library_integrations(self, library: Library) -> None: + def library_integration_validation( + self, integration: IntegrationLibraryConfiguration + ) -> None: """Check that the library didn't end up with multiple basic auth services.""" + + library = integration.library basic_auth_integrations = ( self._db.query(IntegrationConfiguration) .join(IntegrationLibraryConfiguration) @@ -313,18 +132,19 @@ def check_library_integrations(self, library: Library) -> None: ) ) + def process_updated_libraries( + self, + libraries: List[UpdatedLibrarySettingsTuple], + settings_class: Type[BaseSettings], + ) -> None: + super().process_updated_libraries(libraries, settings_class) + for integration, _ in libraries: + self.library_integration_validation(integration) + def process_delete(self, service_id: int) -> Union[Response, ProblemDetail]: - if flask.request.method != "DELETE": - return INVALID_INPUT.detailed(_("Method not allowed for this endpoint")) # type: ignore[no-any-return] self.require_system_admin() - - integration = get_one( - self._db, - IntegrationConfiguration, - id=service_id, - goal=Goals.PATRON_AUTH_GOAL, - ) - if not integration: - return MISSING_SERVICE - self._db.delete(integration) - return Response(str(_("Deleted")), 200) + try: + return self.delete_service(service_id) + except ProblemError as e: + self._db.rollback() + 
return e.problem_detail diff --git a/api/admin/controller/reset_password.py b/api/admin/controller/reset_password.py index 06844aa2d6..00a6391c93 100644 --- a/api/admin/controller/reset_password.py +++ b/api/admin/controller/reset_password.py @@ -195,11 +195,11 @@ def reset_password( def _response_with_message_and_redirect_button( self, - message: str, + message: Optional[str], redirect_button_link: str, redirect_button_text: str, is_error: bool = False, - status_code: int = 200, + status_code: Optional[int] = 200, ) -> Response: style = error_style if is_error else body_style diff --git a/api/admin/exceptions.py b/api/admin/exceptions.py index b23e12b5b0..37386cae95 100644 --- a/api/admin/exceptions.py +++ b/api/admin/exceptions.py @@ -1,16 +1,19 @@ -from .problem_details import * +from typing import Any + +from api.admin.problem_details import ADMIN_NOT_AUTHORIZED +from core.util.problem_detail import ProblemDetail class AdminNotAuthorized(Exception): status_code = 403 - def __init__(self, *args: object) -> None: + def __init__(self, *args: Any) -> None: self.message = None if len(args) > 0: self.message = args[0] super().__init__(*args) - def as_problem_detail_document(self, debug=False): + def as_problem_detail_document(self, debug=False) -> ProblemDetail: return ( ADMIN_NOT_AUTHORIZED.detailed(self.message) if self.message diff --git a/api/admin/validator.py b/api/admin/validator.py index 8751f09839..2582497bf1 100644 --- a/api/admin/validator.py +++ b/api/admin/validator.py @@ -2,7 +2,7 @@ from flask_babel import lazy_gettext as _ -from api.admin.exceptions import * +from api.admin.problem_details import INVALID_EMAIL, INVALID_NUMBER, INVALID_URL class Validator: diff --git a/api/adobe_vendor_id.py b/api/adobe_vendor_id.py index fd9723c272..7ccd920d36 100644 --- a/api/adobe_vendor_id.py +++ b/api/adobe_vendor_id.py @@ -10,21 +10,21 @@ import jwt from jwt.algorithms import HMACAlgorithm from jwt.exceptions import InvalidIssuedAtError +from sqlalchemy import 
select from sqlalchemy.orm import Query from sqlalchemy.orm.session import Session -from api.registration.constants import RegistrationConstants -from core.model import ( - ConfigurationSetting, - Credential, - DataSource, - ExternalIntegration, - Library, - Patron, +from core.integration.goals import Goals +from core.model import Credential, DataSource, IntegrationConfiguration, Library, Patron +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStatus, ) from core.util.datetime_helpers import datetime_utc, utc_now from .config import CannotLoadConfiguration +from .discovery.opds_registration import OpdsRegistrationService +from .integration.registry.discovery import DiscoveryRegistry if sys.version_info >= (3, 11): from typing import Self @@ -99,8 +99,6 @@ def __init__( self.short_token_signer = HMACAlgorithm(HMACAlgorithm.SHA256) self.short_token_signing_key = self.short_token_signer.prepare_key(self.secret) - VENDOR_ID_KEY = "vendor_id" - @classmethod def from_config( cls, library: Library, _db: Optional[Session] = None @@ -124,60 +122,41 @@ def from_config( # Use a version of the library library = _db.merge(library, load=False) - # Try to find an external integration with a configured Vendor ID. - integrations = ( - _db.query(ExternalIntegration) - .outerjoin(ExternalIntegration.libraries) - .filter( - ExternalIntegration.protocol == ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.goal == ExternalIntegration.DISCOVERY_GOAL, - Library.id == library.id, + # Find the first registration that has a vendor ID. 
+ protocol = DiscoveryRegistry().get_protocol(OpdsRegistrationService) + registration = _db.scalars( + select(DiscoveryServiceRegistration) + .join(IntegrationConfiguration) + .where( + DiscoveryServiceRegistration.library == library, + DiscoveryServiceRegistration.vendor_id != None, + DiscoveryServiceRegistration.status == RegistrationStatus.SUCCESS, + IntegrationConfiguration.protocol == protocol, + IntegrationConfiguration.goal == Goals.DISCOVERY_GOAL, ) - ) + ).first() - for possible_integration in integrations: - vendor_id = ConfigurationSetting.for_externalintegration( - cls.VENDOR_ID_KEY, possible_integration - ).value - registration_status = ( - ConfigurationSetting.for_library_and_externalintegration( - _db, - RegistrationConstants.LIBRARY_REGISTRATION_STATUS, - library, - possible_integration, - ).value - ) - if ( - vendor_id - and registration_status == RegistrationConstants.SUCCESS_STATUS - ): - integration = possible_integration - break - else: + if registration is None: + # No vendor ID is configured for this library. return None library_uri = library.settings.website + vendor_id = registration.vendor_id + short_name = registration.short_name + shared_secret = registration.shared_secret - vendor_id = integration.setting(cls.VENDOR_ID_KEY).value - library_short_name = ConfigurationSetting.for_library_and_externalintegration( - _db, ExternalIntegration.USERNAME, library, integration - ).value - secret = ConfigurationSetting.for_library_and_externalintegration( - _db, ExternalIntegration.PASSWORD, library, integration - ).value - - if not vendor_id or not library_uri or not library_short_name or not secret: + if not vendor_id or not library_uri or not short_name or not shared_secret: raise CannotLoadConfiguration( "Short Client Token configuration is incomplete. " "vendor_id (%s), username (%s), password (%s) and " "Library website_url (%s) must all be defined." 
- % (vendor_id, library_uri, library_short_name, secret) + % (vendor_id, library_uri, short_name, shared_secret) ) - if "|" in library_short_name: + if "|" in short_name: raise CannotLoadConfiguration( "Library short name cannot contain the pipe character." ) - return cls(vendor_id, library_uri, library_short_name, secret) + return cls(vendor_id, library_uri, short_name, shared_secret) @classmethod def adobe_relevant_credentials(self, patron: Patron) -> Query[Credential]: diff --git a/api/controller.py b/api/controller.py index 0749b4b3ac..420408384c 100644 --- a/api/controller.py +++ b/api/controller.py @@ -18,6 +18,7 @@ from flask_babel import lazy_gettext as _ from lxml import etree from pydantic import ValidationError +from sqlalchemy import select from sqlalchemy.orm import eagerload from sqlalchemy.orm.exc import NoResultFound @@ -74,6 +75,7 @@ DuplicateDeviceTokenError, InvalidTokenTypeError, ) +from core.model.discovery_service_registration import DiscoveryServiceRegistration from core.opds import AcquisitionFeed, NavigationFacets, NavigationFeed from core.opds2 import AcquisitonFeedOPDS2 from core.opensearch import OpenSearchDocument @@ -369,13 +371,13 @@ def get_domain(url): if domain: patron_web_domains.add(domain) - from api.registration.registry import Registration - - for setting in self._db.query(ConfigurationSetting).filter( - ConfigurationSetting.key == Registration.LIBRARY_REGISTRATION_WEB_CLIENT - ): - if setting.value: - patron_web_domains.add(get_domain(setting.value)) + domains = self._db.execute( + select(DiscoveryServiceRegistration.web_client).where( + DiscoveryServiceRegistration.web_client != None + ) + ).all() + for row in domains: + patron_web_domains.add(get_domain(row.web_client)) self.patron_web_domains = patron_web_domains self.setup_configuration_dependent_controllers() diff --git a/api/discovery/opds_registration.py b/api/discovery/opds_registration.py new file mode 100644 index 0000000000..49cb6bae08 --- /dev/null +++ 
b/api/discovery/opds_registration.py @@ -0,0 +1,528 @@ +from __future__ import annotations + +import base64 +import json +import sys +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + Union, + overload, +) + +from Crypto.Cipher.PKCS1_OAEP import PKCS1OAEP_Cipher +from flask_babel import lazy_gettext as _ +from html_sanitizer import Sanitizer +from pydantic import HttpUrl +from requests import Response +from sqlalchemy import select +from sqlalchemy.orm.session import Session + +from api.config import Configuration +from api.problem_details import * +from core.integration.base import HasIntegrationConfiguration +from core.integration.goals import Goals +from core.integration.settings import BaseSettings, ConfigurationFormItem, FormField +from core.model import IntegrationConfiguration, Library, get_one, get_one_or_create +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStage, + RegistrationStatus, +) +from core.util.http import HTTP +from core.util.problem_detail import ProblemError + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + +class OpdsRegistrationServiceSettings(BaseSettings): + url: HttpUrl = FormField( + ..., + form=ConfigurationFormItem( + label=_("URL"), + required=True, + ), + ) + + +class OpdsRegistrationService(HasIntegrationConfiguration): + """A circulation manager's view of a remote service that supports + the OPDS Directory Registration Protocol: + + https://github.com/NYPL-Simplified/Simplified/wiki/OPDS-Directory-Registration-Protocol + + In practical terms, this is a library registry (which has + DISCOVERY_GOAL and wants to help patrons find their libraries). 
+ """ + + DEFAULT_LIBRARY_REGISTRY_URL = "https://registry.thepalaceproject.org" + DEFAULT_LIBRARY_REGISTRY_NAME = "Palace Library Registry" + + OPDS_2_TYPE = "application/opds+json" + + def __init__( + self, + integration: IntegrationConfiguration, + settings: OpdsRegistrationServiceSettings, + ) -> None: + """Constructor.""" + self.integration = integration + self.settings = settings + + @classmethod + def label(cls) -> str: + """Get the label of this integration.""" + return "OPDS Registration" + + @classmethod + def description(cls) -> str: + """Get the description of this integration.""" + return "Register your library for discovery in the app with a library registry." + + @classmethod + def protocol_details(cls, db: Session) -> dict[str, Any]: + return { + "sitewide": True, + "supports_registration": True, + "supports_staging": True, + } + + @classmethod + def settings_class(cls) -> Type[OpdsRegistrationServiceSettings]: + """Get the settings for this integration.""" + return OpdsRegistrationServiceSettings + + @classmethod + @overload + def for_integration(cls, _db: Session, integration: int) -> Optional[Self]: + ... + + @classmethod + @overload + def for_integration( + cls, _db: Session, integration: IntegrationConfiguration + ) -> Self: + ... + + @classmethod + def for_integration( + cls, _db: Session, integration: int | IntegrationConfiguration + ) -> Optional[Self]: + """ + Find a OpdsRegistrationService object configured by the given IntegrationConfiguration ID. 
+ """ + if isinstance(integration, int): + integration_obj = get_one(_db, IntegrationConfiguration, id=integration) + else: + integration_obj = integration + if integration_obj is None: + return None + + settings = cls.settings_class().construct(**integration_obj.settings_dict) + return cls(integration_obj, settings) + + @staticmethod + def get_request(url: str) -> Response: + return HTTP.debuggable_get(url) + + @staticmethod + def post_request( + url: str, payload: Union[str, Dict[str, Any]], **kwargs: Any + ) -> Response: + return HTTP.debuggable_post(url, payload, **kwargs) + + @classmethod + def for_protocol_goal_and_url( + cls, _db: Session, protocol: str, goal: Goals, url: str + ) -> Optional[Self]: + """Get a LibraryRegistry for the given protocol, goal, and + URL. Create the corresponding ExternalIntegration if necessary. + """ + settings = cls.settings_class().construct(url=url) # type: ignore[arg-type] + query = select(IntegrationConfiguration).where( + IntegrationConfiguration.goal == goal, + IntegrationConfiguration.protocol == protocol, + IntegrationConfiguration.settings_dict.contains(settings.dict()), + ) + integration = _db.scalars(query).one_or_none() + if not integration: + return None + return cls(integration, settings) + + @property + def registrations(self) -> List[DiscoveryServiceRegistration]: + """Find all of this site's registrations with this OpdsRegistrationService. + + :yield: A sequence of Registration objects. + """ + session = Session.object_session(self.integration) + return session.scalars( + select(DiscoveryServiceRegistration).where( + DiscoveryServiceRegistration.integration_id == self.integration.id, + ) + ).all() + + def fetch_catalog( + self, + ) -> Tuple[str, str]: + """Fetch the root catalog for this OpdsRegistrationService. + + :return: A ProblemDetail if there's a problem communicating + with the service or parsing the catalog; otherwise a 2-tuple + (registration URL, Adobe vendor ID). 
+ """ + catalog_url = self.settings.url + response = self.get_request(catalog_url) + return self._extract_catalog_information(response) + + @classmethod + def _extract_catalog_information(cls, response: Response) -> Tuple[str, str]: + """From an OPDS catalog, extract information that's essential to + kickstarting the OPDS Directory Registration Protocol. + + :param response: A requests-style Response object. + + :return A ProblemDetail if there's a problem accessing the + catalog; otherwise a 2-tuple (registration URL, Adobe vendor + ID). + """ + catalog, links = cls._extract_links(response) + if catalog: + vendor_id = catalog.get("metadata", {}).get("adobe_vendor_id") + else: + vendor_id = None + register_url = None + for link in links: + if link.get("rel") == "register": + register_url = link.get("href") + break + if not register_url: + raise ProblemError( + problem_detail=REMOTE_INTEGRATION_FAILED.detailed( + _( + "The service at %(url)s did not provide a register link.", + url=response.url, + ) + ) + ) + return register_url, vendor_id + + def fetch_registration_document( + self, + ) -> Tuple[Optional[str], Optional[str]]: + """Fetch a discovery service's registration document and extract + useful information from it. + + :return: A ProblemDetail if there's a problem accessing the + service; otherwise, a 2-tuple (terms_of_service_link, + terms_of_service_html), containing information about the + Terms of Service that govern a circulation manager's + registration with the discovery service. + """ + registration_url, vendor_id = self.fetch_catalog() + response = self.get_request(registration_url) + return self._extract_registration_information(response) + + @classmethod + def _extract_registration_information( + cls, response: Response + ) -> Tuple[Optional[str], Optional[str]]: + """From an OPDS registration document, extract information that's + useful to kickstarting the OPDS Directory Registration Protocol. 
+ + The registration document is completely optional, so an + invalid or unintelligible document is treated the same as a + missing document. + + :return: A 2-tuple (terms_of_service_link, + terms_of_service_html), containing information about the + Terms of Service that govern a circulation manager's + registration with the discovery service. If the + registration document is missing or malformed, both values + will be None. + """ + tos_link = None + tos_html = None + try: + catalog, links = cls._extract_links(response) + except ProblemError: + return None, None + for link in links: + if link.get("rel") != "terms-of-service": + continue + url = link.get("href") or "" + is_http = any( + [url.startswith(protocol + "://") for protocol in ("http", "https")] + ) + if is_http and not tos_link: + tos_link = url + elif url.startswith("data:") and not tos_html: + try: + tos_html = cls._decode_data_url(url) + except Exception as e: + tos_html = None + return tos_link, tos_html + + @classmethod + def _extract_links( + cls, response: Response + ) -> Tuple[Optional[Dict[str, Any]], List[Dict[str, str]]]: + """Parse an OPDS 2 feed out of a Requests response object. + + :return: A 2-tuple (parsed_catalog, links), + with `links` being a list of dictionaries, each containing + one OPDS link. + """ + # The response must contain an OPDS 2 catalog. + type = response.headers.get("Content-Type") + if not (type and type.startswith(cls.OPDS_2_TYPE)): + raise ProblemError( + problem_detail=REMOTE_INTEGRATION_FAILED.detailed( + _("The service at %(url)s did not return OPDS.", url=response.url) + ) + ) + + catalog = response.json() + links = catalog.get("links", []) + return catalog, links + + @classmethod + def _decode_data_url(cls, url: str) -> str: + """Convert a data: URL to a string of sanitized HTML. + + :raise ValueError: If the data: URL is invalid, in an + unexpected format, or does not have a supported media type. + :return: A string. 
+ """ + if not url.startswith("data:"): + raise ValueError("Not a data: URL: %s" % url) + parts = url.split(",") + if len(parts) != 2: + raise ValueError("Invalid data: URL: %s" % url) + header, encoded = parts + if not header.endswith(";base64"): + raise ValueError("data: URL not base64-encoded: %s" % url) + media_type = header[len("data:") : -len(";base64")] + if not any(media_type.startswith(x) for x in ("text/html", "text/plain")): + raise ValueError("Unsupported media type in data: URL: %s" % media_type) + html = base64.b64decode(encoded.encode("utf-8")).decode("utf-8") + return Sanitizer().sanitize(html) # type: ignore[no-any-return] + + def register_library( + self, + library: Library, + stage: RegistrationStage, + url_for: Callable[..., str], + ) -> Literal[True]: + """Attempt to register a library with a OpdsRegistrationService. + + NOTE: This method is designed to be used in a + controller. Other callers may use this method, but they must be + able to render a ProblemDetail when there's a failure. + + NOTE: The application server must be running when this method + is called, because part of the OPDS Directory Registration + Protocol is the remote server retrieving the library's + Authentication For OPDS document. + + :param stage: Either TESTING_STAGE or PRODUCTION_STAGE + :param url_for: Flask url_for() or equivalent, used to generate URLs + for the application server. + + :return: Raise a ProblemError if there was a problem; otherwise True. + """ + db = Session.object_session(library) + registration, _ = get_one_or_create( + db, + DiscoveryServiceRegistration, + library=library, + integration=self.integration, + ) + + # Assume that the registration will fail. + # + # TODO: If a registration has previously succeeded, failure to + # re-register probably means a maintenance of the status quo, + # not a change of success to failure. But we don't have any way + # of being sure. 
+ registration.status = RegistrationStatus.FAILURE + + # If the library has no private key, we can't register it. This should never + # happen because the column isn't nullable. We add an assertion here just in + # case, so we get a stack trace if it does happen. + assert library.private_key is not None + cipher = Configuration.cipher(library.private_key) + + # Before we can start the registration protocol, we must fetch + # the remote catalog's URL and extract the link to the + # registration resource that kicks off the protocol. + register_url, vendor_id = self.fetch_catalog() + + if vendor_id: + registration.vendor_id = vendor_id + + # Build the document we'll be sending to the registration URL. + payload = self._create_registration_payload(library, stage, url_for) + headers = self._create_registration_headers(registration) + + # Send the document. + response = self._send_registration_request(register_url, headers, payload) + catalog = json.loads(response.content) + + # Process the result. + return self._process_registration_result(registration, catalog, cipher, stage) + + @staticmethod + def _create_registration_payload( + library: Library, + stage: RegistrationStage, + url_for: Callable[..., str], + ) -> Dict[str, str]: + """Collect the key-value pairs to be sent when kicking off the + registration protocol. + + :param library: The library to be registered. + :param stage: The registrant's opinion about what stage this + registration should be in. + :param url_for: An implementation of Flask url_for. + + :return: A dictionary suitable for passing into requests.post. + """ + auth_document_url = url_for( + "authentication_document", + library_short_name=library.short_name, + _external=True, + ) + payload = dict(url=auth_document_url, stage=stage.value) + + # Find the email address the administrator should use if they notice + # a problem with the way the library is using an integration. 
+ contact = Configuration.configuration_contact_uri(library) + if contact: + payload["contact"] = contact + return payload + + @staticmethod + def _create_registration_headers( + registration: DiscoveryServiceRegistration, + ) -> Dict[str, str]: + shared_secret = registration.shared_secret + headers = {} + if shared_secret: + headers["Authorization"] = f"Bearer {shared_secret}" + return headers + + @classmethod + def _send_registration_request( + cls, + register_url: str, + headers: Dict[str, str], + payload: Dict[str, str], + ) -> Response: + """Send the request that actually kicks off the OPDS Directory + Registration Protocol. + + :return: A requests-like Response object or raise a ProblemError on failure. + """ + response = cls.post_request( + register_url, + headers=headers, + payload=payload, + timeout=60, + allowed_response_codes=["2xx", "3xx"], + ) + return response + + @classmethod + def _decrypt_shared_secret( + cls, cipher: PKCS1OAEP_Cipher, cipher_text: str + ) -> bytes: + """Attempt to decrypt an encrypted shared secret. + + :param cipher: A Cipher object. + + :param shared_secret: A byte string. + + :return: The decrypted shared secret, as a bytestring, or + raise as ProblemError if it could not be decrypted. + """ + try: + shared_secret = cipher.decrypt(base64.b64decode(cipher_text)) + except ValueError: + raise ProblemError( + problem_detail=SHARED_SECRET_DECRYPTION_ERROR.detailed( + f"Could not decrypt shared secret: '{cipher_text}'" + ) + ) + return shared_secret + + @classmethod + def _process_registration_result( + cls, + registration: DiscoveryServiceRegistration, + catalog: Dict[str, Any] | Any, + cipher: PKCS1OAEP_Cipher, + desired_stage: RegistrationStage, + ) -> Literal[True]: + """We just sent out a registration request and got an OPDS catalog + in return. Process that catalog. + + :param catalog: A dictionary derived from an OPDS 2 catalog. + :param cipher: A Cipher object. 
+ :param desired_stage: Our opinion, as communicated to the + server, about whether this library is ready to go into + production. + """ + # Since every library has a public key, the catalog should have provided + # credentials for future authenticated communication, + # e.g. through Short Client Tokens or authenticated API + # requests. + if not isinstance(catalog, dict): + raise ProblemError( + problem_detail=INTEGRATION_ERROR.detailed( + f"Remote service served '{catalog}', which I can't make sense of as an OPDS document.", + ) + ) + metadata: Dict[str, str] = catalog.get("metadata", {}) + short_name = metadata.get("short_name") + encrypted_shared_secret = metadata.get("shared_secret") + links = catalog.get("links", []) + + web_client_url = None + for link in links: + if link.get("rel") == "self" and link.get("type") == "text/html": + web_client_url = link.get("href") + break + + if short_name: + registration.short_name = short_name + if encrypted_shared_secret: + # NOTE: we can only store Unicode data in the + # ConfigurationSetting.value, so this requires that the + # shared secret encoded as UTF-8. This works for the + # library registry product, which uses a long string of + # hex digits as its shared secret. + registration.shared_secret = cls._decrypt_shared_secret( + cipher, encrypted_shared_secret + ).decode("utf-8") + + # We have successfully completed the registration. + registration.status = RegistrationStatus.SUCCESS + + # Our opinion about the proper stage of this library was successfully + # communicated to the registry. + registration.stage = desired_stage + + # Store the web client URL as a ConfigurationSetting. 
+ registration.web_client = web_client_url + + return True diff --git a/api/discovery/registration_script.py b/api/discovery/registration_script.py new file mode 100644 index 0000000000..4d75cd7b5f --- /dev/null +++ b/api/discovery/registration_script.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from argparse import ArgumentParser +from typing import Callable, List, Literal, Optional + +from flask import url_for +from sqlalchemy.orm import Session + +from api.config import Configuration +from api.controller import CirculationManager +from api.discovery.opds_registration import OpdsRegistrationService +from api.integration.registry.discovery import DiscoveryRegistry +from api.util.flask import PalaceFlask +from core.integration.goals import Goals +from core.model import ConfigurationSetting, Library, get_one +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStage, +) +from core.scripts import LibraryInputScript +from core.util.problem_detail import ProblemDetail, ProblemError + + +class LibraryRegistrationScript(LibraryInputScript): + """Register local libraries with a remote library registry.""" + + @classmethod + def arg_parser(cls, _db: Session) -> ArgumentParser: # type: ignore[override] + parser = LibraryInputScript.arg_parser(_db) + parser.add_argument( + "--registry-url", + help="Register libraries with the given registry.", + default=OpdsRegistrationService.DEFAULT_LIBRARY_REGISTRY_URL, + ) + parser.add_argument( + "--stage", + help="Register these libraries in the 'testing' stage or the 'production' stage.", + choices=[stage.value for stage in RegistrationStage], + ) + return parser # type: ignore[no-any-return] + + def do_run( + self, + cmd_args: Optional[List[str]] = None, + manager: Optional[CirculationManager] = None, + ) -> PalaceFlask | Literal[False]: + parsed = self.parse_command_line(self._db, cmd_args) + + url = parsed.registry_url + protocol = 
DiscoveryRegistry().get_protocol(OpdsRegistrationService) + registry = OpdsRegistrationService.for_protocol_goal_and_url( + self._db, protocol, Goals.DISCOVERY_GOAL, url # type: ignore[arg-type] + ) + if registry is None: + self.log.error(f'No OPDS Registration service found for "{url}"') + return False + + try: + stage = RegistrationStage(parsed.stage) if parsed.stage else None + except ValueError: + self.log.error( + f'Invalid registration stage "{parsed.stage}". ' + f'Must be one of {", ".join([stage.value for stage in RegistrationStage])}.' + ) + return False + + # Set up an application context so we have access to url_for. + from api.app import app + + app.manager = manager or CirculationManager(self._db) + base_url = ConfigurationSetting.sitewide( + self._db, Configuration.BASE_URL_KEY + ).value + ctx = app.test_request_context(base_url=base_url) + ctx.push() + for library in parsed.libraries: + if not stage: + # Check if the library has already been registered. + registration = get_one( + self._db, + DiscoveryServiceRegistration, + library=library, + integration=registry.integration, + ) + if registration and registration.stage is not None: + library_stage = registration.stage + else: + # Don't know what stage to register this library in, so it defaults to test. + library_stage = RegistrationStage.TESTING + else: + library_stage = stage + + self.process_library(registry, library, library_stage, url_for) + ctx.pop() + + # For testing purposes, return the application object that was + # created. 
+ return app + + def process_library( # type: ignore[override] + self, + registry: OpdsRegistrationService, + library: Library, + stage: RegistrationStage, + url_for: Callable[..., str], + ) -> bool | ProblemDetail: + """Push one Library's registration to the given OpdsRegistrationService.""" + + self.log.info("Processing library %r", library.short_name) + self.log.info("Registering with %s as %s", registry.settings.url, stage.value) + try: + registry.register_library(library, stage, url_for) + except ProblemError as e: + data, status_code, headers = e.problem_detail.response + self.log.exception( + "Could not complete registration. Problem detail document: %r" % data + ) + return e.problem_detail + except Exception as e: + self.log.exception(f"Exception during registration: {e}") + return False + + self.log.info("Success.") + return True diff --git a/api/integration/registry/discovery.py b/api/integration/registry/discovery.py new file mode 100644 index 0000000000..5178b97c30 --- /dev/null +++ b/api/integration/registry/discovery.py @@ -0,0 +1,10 @@ +from api.discovery.opds_registration import OpdsRegistrationService +from core.integration.goals import Goals +from core.integration.registry import IntegrationRegistry + + +class DiscoveryRegistry(IntegrationRegistry[OpdsRegistrationService]): + def __init__(self) -> None: + super().__init__(Goals.DISCOVERY_GOAL) + + self.register(OpdsRegistrationService, canonical="OPDS Registration") diff --git a/api/marc.py b/api/marc.py index b43cc76eb8..1ad5e8c549 100644 --- a/api/marc.py +++ b/api/marc.py @@ -3,10 +3,12 @@ import urllib.request from pymarc import Field, Subfield +from sqlalchemy import select from core.config import Configuration from core.marc import Annotator, MARCExporter from core.model import ConfigurationSetting, Session +from core.model.discovery_service_registration import DiscoveryServiceRegistration class LibraryAnnotator(Annotator): @@ -68,16 +70,14 @@ def add_web_client_urls(self, record, library, 
identifier, integration=None): if marc_setting: settings.append(marc_setting) - from api.registration.registry import Registration - settings += [ - s.value - for s in _db.query(ConfigurationSetting).filter( - ConfigurationSetting.key - == Registration.LIBRARY_REGISTRATION_WEB_CLIENT, - ConfigurationSetting.library_id == library.id, - ) - if s.value + s.web_client + for s in _db.execute( + select(DiscoveryServiceRegistration.web_client).where( + DiscoveryServiceRegistration.library == library, + DiscoveryServiceRegistration.web_client != None, + ) + ).all() ] qualified_identifier = urllib.parse.quote( diff --git a/api/registration/__init__.py b/api/registration/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/registration/constants.py b/api/registration/constants.py deleted file mode 100644 index 0356efd114..0000000000 --- a/api/registration/constants.py +++ /dev/null @@ -1,23 +0,0 @@ -class RegistrationConstants: - """Constants used for library registration.""" - - # A library registration attempt may succeed or fail. - LIBRARY_REGISTRATION_STATUS = "library-registration-status" - SUCCESS_STATUS = "success" - FAILURE_STATUS = "failure" - - # A library may be registered in a 'testing' stage or a - # 'production' stage. This represents the _library's_ opinion - # about whether the integration is ready for production. The - # library won't actually be in production (whatever that means for - # a given integration) until the _remote_ also thinks it should. - LIBRARY_REGISTRATION_STAGE = "library-registration-stage" - TESTING_STAGE = "testing" - PRODUCTION_STAGE = "production" - VALID_REGISTRATION_STAGES = [TESTING_STAGE, PRODUCTION_STAGE] - - # A registry may provide access to a web client. If so, we'll store - # the URL so we can enable CORS headers in requests from that client, - # and use it in MARC records so the library's main catalog can link - # to it. 
- LIBRARY_REGISTRATION_WEB_CLIENT = "library-registration-web-client" diff --git a/api/registration/registry.py b/api/registration/registry.py deleted file mode 100644 index c5315515a9..0000000000 --- a/api/registration/registry.py +++ /dev/null @@ -1,596 +0,0 @@ -import base64 -import json -import logging - -import feedparser -from flask import url_for -from flask_babel import lazy_gettext as _ -from html_sanitizer import Sanitizer -from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.orm.session import Session - -from api.adobe_vendor_id import AuthdataUtility -from api.config import Configuration -from api.controller import CirculationManager -from api.problem_details import * -from core.model import ConfigurationSetting, ExternalIntegration, create, get_one -from core.scripts import LibraryInputScript -from core.util.http import HTTP -from core.util.problem_detail import JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE -from core.util.problem_detail import ProblemDetail - -from .constants import RegistrationConstants - - -class RemoteRegistry: - """A circulation manager's view of a remote service that supports - the OPDS Directory Registration Protocol: - - https://github.com/NYPL-Simplified/Simplified/wiki/OPDS-Directory-Registration-Protocol - - In practical terms, this may be a library registry (which has - DISCOVERY_GOAL and wants to help patrons find their libraries) or - it may be a shared ODL collection (which has LICENSE_GOAL). - """ - - DEFAULT_LIBRARY_REGISTRY_URL = "https://registry.thepalaceproject.org" - DEFAULT_LIBRARY_REGISTRY_NAME = "Palace Library Registry" - - OPDS_1_PREFIX = "application/atom+xml;profile=opds-catalog" - OPDS_2_TYPE = "application/opds+json" - - def __init__(self, integration): - """Constructor.""" - self.integration = integration - - @classmethod - def for_integration_id(cls, _db, integration_id, goal): - """Find a LibraryRegistry object configured - by the given ExternalIntegration ID. 
- - :param goal: The ExternalIntegration's .goal must be this goal. - """ - integration = get_one(_db, ExternalIntegration, goal=goal, id=integration_id) - if not integration: - return None - return cls(integration) - - @classmethod - def for_protocol_and_goal(cls, _db, protocol, goal): - """Find all LibraryRegistry objects with the given protocol and goal.""" - for i in _db.query(ExternalIntegration).filter( - ExternalIntegration.goal == goal, - ExternalIntegration.protocol == protocol, - ): - yield cls(i) - - @classmethod - def for_protocol_goal_and_url(cls, _db, protocol, goal, url): - """Get a LibraryRegistry for the given protocol, goal, and - URL. Create the corresponding ExternalIntegration if necessary. - """ - try: - integration = ExternalIntegration.with_setting_value( - _db, protocol, goal, ExternalIntegration.URL, url - ).one() - except NoResultFound: - integration = None - if not integration: - integration, is_new = create( - _db, ExternalIntegration, protocol=protocol, goal=goal - ) - integration.setting(ExternalIntegration.URL).value = url - return cls(integration) - - @property - def registrations(self): - """Find all of this site's successful registrations with - this RemoteRegistry. - - :yield: A sequence of Registration objects. - """ - for x in self.integration.libraries: - yield Registration(self, x) - - def fetch_catalog(self, catalog_url=None, do_get=HTTP.debuggable_get): - """Fetch the root catalog for this RemoteRegistry. - - :return: A ProblemDetail if there's a problem communicating - with the service or parsing the catalog; otherwise a 2-tuple - (registration URL, Adobe vendor ID). 
- """ - catalog_url = catalog_url or self.integration.url - response = do_get(catalog_url) - if isinstance(response, ProblemDetail): - return response - return self._extract_catalog_information(response) - - @classmethod - def _extract_catalog_information(cls, response): - """From an OPDS catalog, extract information that's essential to - kickstarting the OPDS Directory Registration Protocol. - - :param response: A requests-style Response object. - - :return A ProblemDetail if there's a problem accessing the - catalog; otherwise a 2-tuple (registration URL, Adobe vendor - ID). - """ - result = cls._extract_links(response) - if isinstance(result, ProblemDetail): - return result - catalog, links = result - if catalog: - vendor_id = catalog.get("metadata", {}).get("adobe_vendor_id") - else: - vendor_id = None - register_url = None - for link in links: - if link.get("rel") == "register": - register_url = link.get("href") - break - if not register_url: - return REMOTE_INTEGRATION_FAILED.detailed( - _( - "The service at %(url)s did not provide a register link.", - url=response.url, - ) - ) - return register_url, vendor_id - - def fetch_registration_document(self, do_get=HTTP.debuggable_get): - """Fetch a discovery service's registration document and extract - useful information from it. - - :return: A ProblemDetail if there's a problem accessing the - service; otherwise, a 2-tuple (terms_of_service_link, - terms_of_service_html), containing information about the - Terms of Service that govern a circulation manager's - registration with the discovery service. 
- """ - catalog = self.fetch_catalog(do_get=do_get) - if isinstance(catalog, ProblemDetail): - return catalog - registration_url, vendor_id = catalog - - response = do_get(registration_url) - if isinstance(response, ProblemDetail): - return response - ( - terms_of_service_link, - terms_of_service_html, - ) = self._extract_registration_information(response) - return terms_of_service_link, terms_of_service_html - - @classmethod - def _extract_registration_information(cls, response): - """From an OPDS registration document, extract information that's - useful to kickstarting the OPDS Directory Registration Protocol. - - The registration document is completely optional, so an - invalid or unintelligible document is treated the same as a - missing document. - - :return: A 2-tuple (terms_of_service_link, - terms_of_service_html), containing information about the - Terms of Service that govern a circulation manager's - registration with the discovery service. If the - registration document is missing or malformed, both values - will be None. - """ - tos_link = None - tos_html = None - result = cls._extract_links(response) - if isinstance(result, ProblemDetail): - return None, None - catalog, links = result - for link in links: - if link.get("rel") != "terms-of-service": - continue - url = link.get("href") - is_http = any( - [url.startswith(protocol + "://") for protocol in ("http", "https")] - ) - if is_http and not tos_link: - tos_link = url - elif url.startswith("data:") and not tos_html: - try: - tos_html = cls._decode_data_url(url) - except Exception as e: - tos_html = None - return tos_link, tos_html - - @classmethod - def _extract_links(cls, response): - """Parse an OPDS 1 or OPDS feed out of a Requests response object. - - :return: A 2-tuple (parsed_catalog, links), - with `links` being a list of dictionaries, each containing - one OPDS link. - """ - # The response must contain either an OPDS 2 catalog or an OPDS 1 feed. 
- type = response.headers.get("Content-Type") - if type and type.startswith(cls.OPDS_2_TYPE): - # This is an OPDS 2 catalog. - catalog = json.loads(response.content) - links = catalog.get("links", []) - elif type and type.startswith(cls.OPDS_1_PREFIX): - # This is an OPDS 1 feed. - feed = feedparser.parse(response.content) - links = feed.get("feed", {}).get("links", []) - catalog = None - else: - return REMOTE_INTEGRATION_FAILED.detailed( - _("The service at %(url)s did not return OPDS.", url=response.url) - ) - return catalog, links - - @classmethod - def _decode_data_url(cls, url): - """Convert a data: URL to a string of sanitized HTML. - - :raise ValueError: If the data: URL is invalid, in an - unexpected format, or does not have a supported media type. - :return: A string. - """ - if not url.startswith("data:"): - raise ValueError("Not a data: URL: %s" % url) - parts = url.split(",") - if len(parts) != 2: - raise ValueError("Invalid data: URL: %s" % url) - header, encoded = parts - if not header.endswith(";base64"): - raise ValueError("data: URL not base64-encoded: %s" % url) - media_type = header[len("data:") : -len(";base64")] - if not any(media_type.startswith(x) for x in ("text/html", "text/plain")): - raise ValueError("Unsupported media type in data: URL: %s" % media_type) - html = base64.b64decode(encoded.encode("utf-8")).decode("utf-8") - return Sanitizer().sanitize(html) - - -class Registration(RegistrationConstants): - """A library's registration for a particular registry. - - The registration does not correspond to one specific data model - object -- it's a relationship between a Library and an - ExternalIntegration, and a set of ConfigurationSettings that - configure the relationship between the two. 
- """ - - def __init__(self, registry, library): - self.registry = registry - self.integration = self.registry.integration - self.library = library - self._db = Session.object_session(self.integration) - - if not library in self.integration.libraries: - self.integration.libraries.append(library) - - # Find or create all the ConfigurationSettings that configure - # this relationship between library and registry. - # Has the registration succeeded? (Initial value: no.) - self.status_field = self.setting( - self.LIBRARY_REGISTRATION_STATUS, self.FAILURE_STATUS - ) - - # Does the library want to be in the testing or production stage? - # (Initial value: testing.) - self.stage_field = self.setting( - self.LIBRARY_REGISTRATION_STAGE, self.TESTING_STAGE - ) - - # If the registry provides a web client for the library, it will - # be stored in this setting. - self.web_client_field = self.setting(self.LIBRARY_REGISTRATION_WEB_CLIENT) - - def setting(self, key, default_value=None): - """Find or create a ConfigurationSetting that configures this - relationship between library and registry. - - :param key: Name of the ConfigurationSetting. - :return: A 2-tuple (ConfigurationSetting, is_new) - """ - setting = ConfigurationSetting.for_library_and_externalintegration( - self._db, key, self.library, self.integration - ) - if setting.value is None and default_value is not None: - setting.value = default_value - return setting - - def push( - self, - stage, - url_for, - catalog_url=None, - do_get=HTTP.debuggable_get, - do_post=HTTP.debuggable_post, - ): - """Attempt to register a library with a RemoteRegistry. - - NOTE: This method is designed to be used in a - controller. Other callers may use this method, but they must be - able to render a ProblemDetail when there's a failure. 
- - NOTE: The application server must be running when this method - is called, because part of the OPDS Directory Registration - Protocol is the remote server retrieving the library's - Authentication For OPDS document. - - :param stage: Either TESTING_STAGE or PRODUCTION_STAGE - :param url_for: Flask url_for() or equivalent, used to generate URLs - for the application server. - :param do_get: Mockable method to make a GET request. - :param do_post: Mockable method to make a POST request. - - :return: A ProblemDetail if there was a problem; otherwise True. - """ - # Assume that the registration will fail. - # - # TODO: If a registration has previously succeeded, failure to - # re-register probably means a maintenance of the status quo, - # not a change of success to failure. But we don't have any way - # of being sure. - self.status_field.value = self.FAILURE_STATUS - - if stage not in self.VALID_REGISTRATION_STAGES: - return INVALID_INPUT.detailed( - _("%r is not a valid registration stage") % stage - ) - - cipher = Configuration.cipher(self.library.private_key) - - # Before we can start the registration protocol, we must fetch - # the remote catalog's URL and extract the link to the - # registration resource that kicks off the protocol. - result = self.registry.fetch_catalog(catalog_url, do_get) - if isinstance(result, ProblemDetail): - return result - register_url, vendor_id = result - - # Store the vendor id as a ConfigurationSetting on the integration - # -- it'll be the same value for all libraries. - if vendor_id: - ConfigurationSetting.for_externalintegration( - AuthdataUtility.VENDOR_ID_KEY, self.integration - ).value = vendor_id - - # Build the document we'll be sending to the registration URL. - payload = self._create_registration_payload(url_for, stage) - - if isinstance(payload, ProblemDetail): - return payload - - headers = self._create_registration_headers() - if isinstance(headers, ProblemDetail): - return headers - - # Send the document. 
- response = self._send_registration_request( - register_url, headers, payload, do_post - ) - - if isinstance(response, ProblemDetail): - return response - catalog = json.loads(response.content) - - # Process the result. - return self._process_registration_result(catalog, cipher, stage) - - def _create_registration_payload(self, url_for, stage): - """Collect the key-value pairs to be sent when kicking off the - registration protocol. - - :param url_for: An implementation of Flask url_for. - :param state: The registrant's opinion about what stage this - registration should be in. - :return: A dictionary suitable for passing into requests.post. - """ - auth_document_url = url_for( - "authentication_document", - library_short_name=self.library.short_name, - _external=True, - ) - payload = dict(url=auth_document_url, stage=stage) - - # Find the email address the administrator should use if they notice - # a problem with the way the library is using an integration. - contact = Configuration.configuration_contact_uri(self.library) - if contact: - payload["contact"] = contact - return payload - - def _create_registration_headers(self): - shared_secret = self.setting(ExternalIntegration.PASSWORD).value - headers = {} - if shared_secret: - headers["Authorization"] = "Bearer %s" % shared_secret - return headers - - @classmethod - def _send_registration_request(cls, register_url, headers, payload, do_post): - """Send the request that actually kicks off the OPDS Directory - Registration Protocol. - - :return: Either a ProblemDetail or a requests-like Response object. - """ - # Allow 400 and 401 so we can provide a more useful error message. 
- response = do_post( - register_url, - headers=headers, - payload=payload, - timeout=60, - allowed_response_codes=["2xx", "3xx", "400", "401"], - ) - if response.status_code in [400, 401]: - if response.headers.get("Content-Type") == PROBLEM_DETAIL_JSON_MEDIA_TYPE: - problem = json.loads(response.content) - return INTEGRATION_ERROR.detailed( - _( - 'Remote service returned: "%(problem)s"', - problem=problem.get("detail"), - ) - ) - else: - return INTEGRATION_ERROR.detailed( - _( - 'Remote service returned: "%(problem)s"', - problem=response.content.decode("utf-8"), - ) - ) - return response - - @classmethod - def _decrypt_shared_secret(cls, cipher, shared_secret): - """Attempt to decrypt an encrypted shared secret. - - :param cipher: A Cipher object. - - :param shared_secret: A byte string. - - :return: The decrypted shared secret, as a bytestring, or - a ProblemDetail if it could not be decrypted. - """ - try: - shared_secret = cipher.decrypt(base64.b64decode(shared_secret)) - except ValueError as e: - return SHARED_SECRET_DECRYPTION_ERROR.detailed( - _("Could not decrypt shared secret %s") % shared_secret - ) - return shared_secret - - def _process_registration_result(self, catalog, cipher, desired_stage): - """We just sent out a registration request and got an OPDS catalog - in return. Process that catalog. - - :param catalog: A dictionary derived from an OPDS 2 catalog. - :param cipher: A Cipher object. - :param desired_stage: Our opinion, as communicated to the - server, about whether this library is ready to go into - production. - """ - # Since every library has a public key, the catalog should have provided - # credentials for future authenticated communication, - # e.g. through Short Client Tokens or authenticated API - # requests. 
- if not isinstance(catalog, dict): - return INTEGRATION_ERROR.detailed( - _( - "Remote service served %(representation)r, which I can't make sense of as an OPDS document.", - representation=catalog, - ) - ) - metadata = catalog.get("metadata", {}) - short_name = metadata.get("short_name") - shared_secret = metadata.get("shared_secret") - links = catalog.get("links", []) - - web_client_url = None - for link in links: - if link.get("rel") == "self" and link.get("type") == "text/html": - web_client_url = link.get("href") - break - - if short_name: - setting = self.setting(ExternalIntegration.USERNAME) - setting.value = short_name - if shared_secret: - shared_secret = self._decrypt_shared_secret(cipher, shared_secret) - if isinstance(shared_secret, ProblemDetail): - return shared_secret - - setting = self.setting(ExternalIntegration.PASSWORD) - - # NOTE: we can only store Unicode data in the - # ConfigurationSetting.value, so this requires that the - # shared secret encoded as UTF-8. This works for the - # library registry product, which uses a long string of - # hex digits as its shared secret. - setting.value = shared_secret.decode("utf8") - - # We have successfully completed the registration. - self.status_field.value = self.SUCCESS_STATUS - - # Our opinion about the proper stage of this library was succesfully - # communicated to the registry. - self.stage_field.value = desired_stage - - # Store the web client URL as a ConfigurationSetting. 
- if web_client_url: - self.web_client_field.value = web_client_url - - return True - - -class LibraryRegistrationScript(LibraryInputScript): - """Register local libraries with a remote library registry.""" - - PROTOCOL = ExternalIntegration.OPDS_REGISTRATION - GOAL = ExternalIntegration.DISCOVERY_GOAL - - @classmethod - def arg_parser(cls, _db): - parser = LibraryInputScript.arg_parser(_db) - parser.add_argument( - "--registry-url", - help="Register libraries with the given registry.", - default=RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL, - ) - parser.add_argument( - "--stage", - help="Register these libraries in the 'testing' stage or the 'production' stage.", - choices=(Registration.TESTING_STAGE, Registration.PRODUCTION_STAGE), - ) - return parser - - def do_run(self, cmd_args=None, manager=None): - parser = self.arg_parser(self._db) - parsed = self.parse_command_line(self._db, cmd_args) - - url = parsed.registry_url - registry = RemoteRegistry.for_protocol_goal_and_url( - self._db, self.PROTOCOL, self.GOAL, url - ) - stage = parsed.stage - - # Set up an application context so we have access to url_for. - from api.app import app - - app.manager = manager or CirculationManager(self._db) - base_url = ConfigurationSetting.sitewide( - self._db, Configuration.BASE_URL_KEY - ).value - ctx = app.test_request_context(base_url=base_url) - ctx.push() - for library in parsed.libraries: - registration = Registration(registry, library) - library_stage = stage or registration.stage_field.value - self.process_library(registration, library_stage, url_for) - ctx.pop() - - # For testing purposes, return the application object that was - # created. 
- return app - - def process_library(self, registration, stage, url_for): - """Push one Library's registration to the given RemoteRegistry.""" - - logger = logging.getLogger( - "Registration of library %r" % registration.library.short_name - ) - logger.info( - "Registering with %s as %s", registration.registry.integration.url, stage - ) - try: - result = registration.push(stage, url_for) - except Exception as e: - logger.error("Exception during registration", exc_info=e) - return False - if isinstance(result, ProblemDetail): - data, status_code, headers = result.response - logger.error( - "Could not complete registration. Problem detail document: %r" % data - ) - return result - else: - logger.info("Success.") - return result diff --git a/api/saml/controller.py b/api/saml/controller.py index c33dd82f82..7323a2913e 100644 --- a/api/saml/controller.py +++ b/api/saml/controller.py @@ -107,7 +107,6 @@ def _error_uri(self, redirect_uri, problem_detail): problem_detail.status_code, problem_detail.title, problem_detail.detail, - problem_detail.instance, problem_detail.debug_message, ) params = {self.ERROR: problem_detail_json} diff --git a/api/selftest.py b/api/selftest.py index b5291bc776..3f26f58ff9 100644 --- a/api/selftest.py +++ b/api/selftest.py @@ -31,7 +31,7 @@ class _NoValidLibrarySelfTestPatron(BaseError): detail (optional) -- additional explanation of the error """ - def __init__(self, message: str, *, detail: Optional[str] = None): + def __init__(self, message: Optional[str], *, detail: Optional[str] = None): super().__init__(message=message) self.message = message self.detail = detail @@ -97,6 +97,8 @@ def _determine_self_test_patron( # If we get here, then we have failed to find a valid test patron # and will raise an exception. + message: Optional[str] + detail: Optional[str] if patron is None: message = "Library has no test patron configured." 
detail = ( diff --git a/bin/configuration/register_library b/bin/configuration/register_library index c8729b8fab..b0a28149b6 100755 --- a/bin/configuration/register_library +++ b/bin/configuration/register_library @@ -6,6 +6,6 @@ import sys bin_dir = os.path.split(__file__)[0] package_dir = os.path.join(bin_dir, "..", "..") sys.path.append(os.path.abspath(package_dir)) -from api.registration.registry import LibraryRegistrationScript +from api.discovery.registration_script import LibraryRegistrationScript LibraryRegistrationScript().run() diff --git a/core/integration/base.py b/core/integration/base.py index b505c5d42f..80a2e53742 100644 --- a/core/integration/base.py +++ b/core/integration/base.py @@ -1,7 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Type +from typing import Any, Type + +from sqlalchemy.orm import Session from core.integration.settings import BaseSettings @@ -25,6 +27,15 @@ def settings_class(cls) -> Type[BaseSettings]: """Get the settings for this integration""" ... + @classmethod + def protocol_details(cls, db: Session) -> dict[str, Any]: + """Add any additional details about this protocol to be + returned to the admin interface. + + The default implementation returns an empty dict. 
+ """ + return {} + + class HasLibraryIntegrationConfiguration(HasIntegrationConfiguration, ABC): @classmethod diff --git a/core/integration/goals.py b/core/integration/goals.py index 5f4f2c24a1..b7326f2ce8 100644 --- a/core/integration/goals.py +++ b/core/integration/goals.py @@ -8,3 +8,4 @@ class Goals(Enum): PATRON_AUTH_GOAL = "patron_auth" LICENSE_GOAL = "licenses" + DISCOVERY_GOAL = "discovery" diff --git a/core/migration/util.py b/core/migration/util.py new file mode 100644 index 0000000000..0085ce6a93 --- /dev/null +++ b/core/migration/util.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, List + +import sqlalchemy as sa + + +def pg_update_enum( + op: Any, + table: str, + column: str, + enum_name: str, + old_values: List[str], + new_values: List[str], +) -> None: + """ + Alembic migration helper function to update an enum type. + + Alembic performs its updates within a transaction, and Postgres does not allow + the addition of new enum values within a transaction. In order to be able to + update an enum within a transaction this function creates a temporary enum type + and uses it to update the column. It then drops the old enum type and creates + the new enum type. Finally, it updates the column to use the new enum type and + drops the temporary enum type.
+ + This is a cleaned up version of the code from: + https://stackoverflow.com/questions/14845203/altering-an-enum-field-using-alembic/45615354#45615354 + """ + + # Create SA Enum objects for the enums + tmp_enum_name = f"_tmp_{enum_name}" + tmp_enum = sa.Enum(*new_values, name=tmp_enum_name) + old_enum = sa.Enum(*old_values, name=enum_name) + new_enum = sa.Enum(*new_values, name=enum_name) + + # Create the tmp enum type + tmp_enum.create(op.get_bind()) + + # Alter the column to use the tmp enum type + op.alter_column( + table, + column, + type_=tmp_enum, + postgresql_using=f"{column}::text::{tmp_enum_name}", + ) + + # Drop the old enum type + old_enum.drop(op.get_bind()) + + # Create the new enum type + new_enum.create(op.get_bind()) + + # Alter the column to use the new enum type + op.alter_column( + table, column, type_=new_enum, postgresql_using=f"{column}::text::{enum_name}" + ) + + # Drop the tmp enum type + tmp_enum.drop(op.get_bind()) + + +def drop_enum(op: Any, enum_name: str, checkfirst: bool = True) -> None: + """ + Alembic migration helper function to drop an enum type. 
+ """ + sa.Enum(name=enum_name).drop(op.get_bind(), checkfirst=checkfirst) diff --git a/core/model/__init__.py b/core/model/__init__.py index f709131633..c08d7e3d95 100644 --- a/core/model/__init__.py +++ b/core/model/__init__.py @@ -3,16 +3,15 @@ import json import logging import os -import warnings -from typing import Any, Dict, Generator, List, Literal, Tuple, Type, TypeVar, Union +from typing import Any, Generator, List, Literal, Tuple, Type, TypeVar from contextlib2 import contextmanager from psycopg2.extensions import adapt as sqlescape from psycopg2.extras import NumericRange from pydantic.json import pydantic_encoder from sqlalchemy import create_engine -from sqlalchemy.engine import Connection, Engine -from sqlalchemy.exc import IntegrityError, SAWarning +from sqlalchemy.engine import Connection +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound @@ -558,6 +557,7 @@ def _bulk_operation(self): from .customlist import CustomList, CustomListEntry from .datasource import DataSource from .devicetokens import DeviceToken +from .discovery_service_registration import DiscoveryServiceRegistration from .edition import Edition from .hassessioncache import HasSessionCache from .identifier import Equivalency, Identifier diff --git a/core/model/configuration.py b/core/model/configuration.py index 227276b4bb..264d7c3dd2 100644 --- a/core/model/configuration.py +++ b/core/model/configuration.py @@ -161,10 +161,6 @@ class ExternalIntegration(Base): # Adobe Vendor ID, which manage access to DRM-dependent content. DRM_GOAL = "drm" - # These integrations are associated with external services that - # help patrons find libraries. - DISCOVERY_GOAL = "discovery" - # These integrations are associated with external services that # collect logs of server-side events. 
LOGGING_GOAL = "logging" @@ -235,9 +231,6 @@ class ExternalIntegration(Base): # Integrations with SEARCH_GOAL OPENSEARCH = "Opensearch" - # Integrations with DISCOVERY_GOAL - OPDS_REGISTRATION = "OPDS Registration" - # Integrations with ANALYTICS_GOAL GOOGLE_ANALYTICS = "Google Analytics" diff --git a/core/model/discovery_service_registration.py b/core/model/discovery_service_registration.py new file mode 100644 index 0000000000..e556aea5c7 --- /dev/null +++ b/core/model/discovery_service_registration.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING + +from sqlalchemy import Column +from sqlalchemy import Enum as AlchemyEnum +from sqlalchemy import ForeignKey, Integer, Unicode +from sqlalchemy.orm import Mapped, relationship + +from core.model import Base + +if TYPE_CHECKING: + from core.model import IntegrationConfiguration, Library + + +class RegistrationStage(Enum): + """The stage of a library's registration with a discovery service.""" + + TESTING = "testing" + PRODUCTION = "production" + + +class RegistrationStatus(Enum): + """The status of a library's registration with a discovery service.""" + + SUCCESS = "success" + FAILURE = "failure" + + +class DiscoveryServiceRegistration(Base): + """A library's registration with a discovery service.""" + + __tablename__ = "discovery_service_registrations" + + status = Column( + AlchemyEnum(RegistrationStatus), + default=RegistrationStatus.FAILURE, + nullable=False, + ) + stage = Column( + AlchemyEnum(RegistrationStage), + default=RegistrationStage.TESTING, + nullable=False, + ) + web_client = Column(Unicode) + + short_name = Column(Unicode) + shared_secret = Column(Unicode) + + # The IntegrationConfiguration this registration is associated with. 
+ integration_id = Column( + Integer, + ForeignKey("integration_configurations.id", ondelete="CASCADE"), + nullable=False, + primary_key=True, + ) + integration: Mapped[IntegrationConfiguration] = relationship( + "IntegrationConfiguration" + ) + + # The Library this registration is associated with. + library_id = Column( + Integer, + ForeignKey("libraries.id", ondelete="CASCADE"), + nullable=False, + primary_key=True, + ) + library: Mapped[Library] = relationship("Library") + + vendor_id = Column(Unicode) diff --git a/core/selftest.py b/core/selftest.py index 779937ba99..0b870bbd43 100644 --- a/core/selftest.py +++ b/core/selftest.py @@ -261,7 +261,7 @@ def run_test( def test_failure( cls, name: str, - message: Union[str, Exception], + message: Union[Optional[str], Exception], debug_message: Optional[str] = None, ) -> SelfTestResult: """Create a SelfTestResult for a known failure. diff --git a/core/util/http.py b/core/util/http.py index dab0c74c60..b52e16ef85 100644 --- a/core/util/http.py +++ b/core/util/http.py @@ -1,5 +1,6 @@ import logging -from typing import Callable, Optional +from json import JSONDecodeError +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse import requests @@ -13,6 +14,7 @@ from core.problem_details import INTEGRATION_ERROR from .problem_detail import JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE +from .problem_detail import ProblemError class RemoteIntegrationException(IntegrationException): @@ -381,14 +383,16 @@ def series(cls, status_code): return "%sxx" % (int(status_code) // 100) @classmethod - def debuggable_get(cls, url: str, **kwargs): + def debuggable_get(cls, url: str, **kwargs: Any) -> Response: """Make a GET request that returns a detailed problem detail document on error. 
""" return cls.debuggable_request("GET", url, **kwargs) @classmethod - def debuggable_post(cls, url: str, payload, **kwargs): + def debuggable_post( + cls, url: str, payload: Union[str, Dict[str, Any]], **kwargs: Any + ) -> Response: """Make a POST request that returns a detailed problem detail document on error. """ @@ -401,10 +405,10 @@ def debuggable_request( http_method: str, url: str, make_request_with: Optional[Callable[..., Response]] = None, - **kwargs, + **kwargs: Any, ) -> Response: - """Make a request that returns a detailed problem detail document on - error, rather than a generic "an integration error occured" + """Make a request that raises a ProblemError with a detailed problem detail + document on error, rather than a generic "an integration error occurred" message. :param http_method: HTTP method to use when making the request. @@ -430,52 +434,55 @@ def debuggable_request( def process_debuggable_response( cls, url: str, - response, - disallowed_response_codes=None, - allowed_response_codes=None, - expected_encoding="utf-8", - ): + response: Response, + allowed_response_codes: Optional[List[Union[str, int]]] = None, + disallowed_response_codes: Optional[List[Union[str, int]]] = None, + expected_encoding: str = "utf-8", + ) -> Response: """If there was a problem with an integration request, - return an appropriate ProblemDetail. Otherwise, return the + raise ProblemError with an appropriate ProblemDetail. Otherwise, return the response to the original request. :param response: A Response object from the requests library. - :param expected_encoding: Typically we expect HTTP responses to be UTF-8 + :param expected_encoding: Typically, we expect HTTP responses to be UTF-8 encoded, but for certain requests we can change the encoding type. 
""" allowed_response_codes = allowed_response_codes or ["2xx", "3xx"] - allowed_response_codes = list(map(str, allowed_response_codes)) + allowed_response_codes_str = list(map(str, allowed_response_codes)) + disallowed_response_codes = disallowed_response_codes or [] + disallowed_response_codes_str = list(map(str, disallowed_response_codes)) + code = response.status_code series = cls.series(code) - if str(code) in allowed_response_codes or series in allowed_response_codes: - # Whether or not it looks like there's been a problem, + if ( + str(code) in allowed_response_codes_str + or series in allowed_response_codes_str + ): + # Whether it looks like there's been a problem, # we've been told to let this response code through. return response content_type = response.headers.get("Content-Type") - response_content = response.content - if response_content and isinstance(response_content, bytes): - try: - response_content = response_content.decode(expected_encoding) - except Exception as e: - return RequestNetworkException(url, e) if content_type == PROBLEM_DETAIL_JSON_MEDIA_TYPE: # The server returned a problem detail document. Wrap it # in a new document that represents the integration # failure. - problem = INTEGRATION_ERROR.detailed( - _("Remote service returned a problem detail document: %r") - % (response_content) - ) - problem.debug_message = response_content - return problem + try: + problem_detail = INTEGRATION_ERROR.detailed( + f"Remote service returned a problem detail document: '{response.text}'" + ) + problem_detail.debug_message = response.text + raise ProblemError(problem_detail=problem_detail) + except JSONDecodeError: + # Failed to decode the problem detail document, we just fall through + # and raise the generic integration error. + pass + # There's been a problem. Return the message we got from the # server, verbatim. 
- return INTEGRATION_ERROR.detailed( - _("%s response from integration server: %r") - % ( - response.status_code, - response_content, + raise ProblemError( + problem_detail=INTEGRATION_ERROR.detailed( + f'{response.status_code} response from integration server: "{response.text}"' ) ) diff --git a/core/util/problem_detail.py b/core/util/problem_detail.py index 9635518d86..e2eed52e2c 100644 --- a/core/util/problem_detail.py +++ b/core/util/problem_detail.py @@ -2,9 +2,11 @@ As per http://datatracker.ietf.org/doc/draft-ietf-appsawg-http-problem/ """ +from __future__ import annotations + import json as j import logging -from typing import Optional +from typing import Dict, Optional, Tuple from flask_babel import LazyString from pydantic import BaseModel @@ -14,12 +16,16 @@ JSON_MEDIA_TYPE = "application/api-problem+json" -def json(type, status, title, detail=None, instance=None, debug_message=None): +def json( + type: str, + status: Optional[int], + title: Optional[str], + detail: Optional[str] = None, + debug_message: Optional[str] = None, +) -> str: d = dict(type=type, title=str(title), status=status) if detail: d["detail"] = str(detail) - if instance: - d["instance"] = instance if debug_message: d["debug_message"] = debug_message return j.dumps(d) @@ -41,22 +47,20 @@ class ProblemDetail: def __init__( self, - uri, - status_code=None, - title=None, - detail=None, - instance=None, - debug_message=None, + uri: str, + status_code: Optional[int] = None, + title: Optional[str] = None, + detail: Optional[str] = None, + debug_message: Optional[str] = None, ): self.uri = uri self.title = title self.status_code = status_code self.detail = detail - self.instance = instance self.debug_message = debug_message @property - def response(self): + def response(self) -> Tuple[str, int, Dict[str, str]]: """Create a Flask-style response.""" return ( json( @@ -64,7 +68,6 @@ def response(self): self.status_code, self.title, self.detail, - self.instance, self.debug_message, ), 
self.status_code or 400, @@ -72,9 +75,13 @@ def response(self): ) def detailed( - self, detail, status_code=None, title=None, instance=None, debug_message=None - ): - """Create a ProblemDetail for a more specific occurance of an existing + self, + detail: str, + status_code: Optional[int] = None, + title: Optional[str] = None, + debug_message: Optional[str] = None, + ) -> ProblemDetail: + """Create a ProblemDetail for a more specific occurrence of an existing ProblemDetail. The detailed error message will be shown to patrons. @@ -92,13 +99,16 @@ def detailed( status_code or self.status_code, title or self.title, detail, - instance, debug_message, ) def with_debug( - self, debug_message, detail=None, status_code=None, title=None, instance=None - ): + self, + debug_message: str, + detail: Optional[str] = None, + status_code: Optional[int] = None, + title: Optional[str] = None, + ) -> ProblemDetail: """Insert debugging information into a ProblemDetail. The original ProblemDetail's error message will be shown to @@ -110,17 +120,15 @@ def with_debug( status_code or self.status_code, title or self.title, detail or self.detail, - instance or self.instance, debug_message, ) - def __repr__(self): - return " str: + return " None: """Initialize a new instance of ProblemError class. 
:param problem_detail: ProblemDetail object diff --git a/pyproject.toml b/pyproject.toml index 6e503ae694..d53471d147 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,12 +67,16 @@ disallow_untyped_defs = true module = [ "api.admin.announcement_list_validator", "api.admin.config", + "api.admin.controller.discovery_service_library_registrations", + "api.admin.controller.discovery_services", + "api.admin.controller.integration_settings", "api.admin.controller.library_settings", "api.admin.controller.patron_auth_services", "api.admin.form_data", "api.admin.model.dashboard_statistics", "api.adobe_vendor_id", "api.circulation", + "api.discovery.*", "api.integration.*", "core.integration.*", "core.model.announcements", @@ -83,6 +87,7 @@ module = [ "core.settings.*", "core.util.authentication_for_opds", "core.util.cache", + "core.util.problem_detail", "tests.fixtures.authenticator", "tests.migration.*", ] diff --git a/tests/api/admin/controller/test_analytics_services.py b/tests/api/admin/controller/test_analytics_services.py index 59efcd8d1c..b7ffd79db6 100644 --- a/tests/api/admin/controller/test_analytics_services.py +++ b/tests/api/admin/controller/test_analytics_services.py @@ -4,7 +4,17 @@ import pytest from werkzeug.datastructures import ImmutableMultiDict -from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + CANNOT_CHANGE_PROTOCOL, + INCOMPLETE_CONFIGURATION, + INTEGRATION_NAME_ALREADY_IN_USE, + MISSING_ANALYTICS_NAME, + MISSING_SERVICE, + NO_PROTOCOL_FOR_NEW_SERVICE, + NO_SUCH_LIBRARY, + UNKNOWN_PROTOCOL, +) from api.google_analytics_provider import GoogleAnalyticsProvider from core.local_analytics_provider import LocalAnalyticsProvider from core.model import ( diff --git a/tests/api/admin/controller/test_catalog_services.py b/tests/api/admin/controller/test_catalog_services.py index 3dea1b7cd9..fda8a836a3 100644 --- a/tests/api/admin/controller/test_catalog_services.py +++ 
b/tests/api/admin/controller/test_catalog_services.py @@ -4,7 +4,15 @@ import pytest from werkzeug.datastructures import ImmutableMultiDict -from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + CANNOT_CHANGE_PROTOCOL, + INTEGRATION_NAME_ALREADY_IN_USE, + MISSING_INTEGRATION, + MISSING_SERVICE, + MULTIPLE_SERVICES_FOR_LIBRARY, + UNKNOWN_PROTOCOL, +) from core.marc import MARCExporter from core.model import ( AdminRole, diff --git a/tests/api/admin/controller/test_collections.py b/tests/api/admin/controller/test_collections.py index 1208d2895b..812f63e710 100644 --- a/tests/api/admin/controller/test_collections.py +++ b/tests/api/admin/controller/test_collections.py @@ -4,7 +4,22 @@ import pytest from werkzeug.datastructures import ImmutableMultiDict -from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + CANNOT_CHANGE_PROTOCOL, + CANNOT_DELETE_COLLECTION_WITH_CHILDREN, + COLLECTION_NAME_ALREADY_IN_USE, + INCOMPLETE_CONFIGURATION, + INTEGRATION_GOAL_CONFLICT, + MISSING_COLLECTION, + MISSING_COLLECTION_NAME, + MISSING_PARENT, + MISSING_SERVICE, + NO_PROTOCOL_FOR_NEW_SERVICE, + NO_SUCH_LIBRARY, + PROTOCOL_DOES_NOT_SUPPORT_PARENTS, + UNKNOWN_PROTOCOL, +) from api.selftest import HasCollectionSelfTests from core.model import ( Admin, diff --git a/tests/api/admin/controller/test_discovery_services.py b/tests/api/admin/controller/test_discovery_services.py index 3f7f0e7a6f..2abf538c39 100644 --- a/tests/api/admin/controller/test_discovery_services.py +++ b/tests/api/admin/controller/test_discovery_services.py @@ -1,19 +1,30 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import flask import pytest +from flask import Response from werkzeug.datastructures import ImmutableMultiDict -from api.admin.exceptions import ( +from api.admin.exceptions import AdminNotAuthorized +from 
api.admin.problem_details import ( INCOMPLETE_CONFIGURATION, INTEGRATION_NAME_ALREADY_IN_USE, INTEGRATION_URL_ALREADY_IN_USE, MISSING_SERVICE, NO_PROTOCOL_FOR_NEW_SERVICE, UNKNOWN_PROTOCOL, - AdminNotAuthorized, ) -from api.registration.registry import RemoteRegistry -from core.model import AdminRole, ExternalIntegration, create, get_one -from tests.fixtures.api_admin import SettingsControllerFixture +from api.discovery.opds_registration import OpdsRegistrationService +from api.integration.registry.discovery import DiscoveryRegistry +from core.integration.goals import Goals +from core.model import AdminRole, ExternalIntegration, IntegrationConfiguration, get_one +from core.util.problem_detail import ProblemDetail + +if TYPE_CHECKING: + from tests.fixtures.api_admin import SettingsControllerFixture + from tests.fixtures.database import IntegrationConfigurationFixture class TestDiscoveryServices: @@ -22,6 +33,11 @@ class TestDiscoveryServices: services. """ + @property + def protocol(self): + registry = DiscoveryRegistry() + return registry.get_protocol(OpdsRegistrationService) + def test_discovery_services_get_with_no_services_creates_default( self, settings_ctrl_fixture: SettingsControllerFixture ): @@ -29,17 +45,20 @@ def test_discovery_services_get_with_no_services_creates_default( response = ( settings_ctrl_fixture.manager.admin_discovery_services_controller.process_discovery_services() ) - [service] = response.get("discovery_services") - protocols = response.get("protocols") - assert ExternalIntegration.OPDS_REGISTRATION in [ - p.get("name") for p in protocols - ] + assert response.status_code == 200 + assert isinstance(response, Response) + json = response.get_json() + [service] = json.get("discovery_services") + protocols = json.get("protocols") + assert self.protocol in [p.get("name") for p in protocols] assert "settings" in protocols[0] - assert ExternalIntegration.OPDS_REGISTRATION == service.get("protocol") - assert 
RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL == service.get( + assert self.protocol == service.get("protocol") + assert OpdsRegistrationService.DEFAULT_LIBRARY_REGISTRY_URL == service.get( "settings" ).get(ExternalIntegration.URL) - assert RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_NAME == service.get("name") + assert OpdsRegistrationService.DEFAULT_LIBRARY_REGISTRY_NAME == service.get( + "name" + ) # Only system admins can see the discovery services. settings_ctrl_fixture.admin.remove_role(AdminRole.SYSTEM_ADMIN) @@ -50,30 +69,30 @@ def test_discovery_services_get_with_no_services_creates_default( ) def test_discovery_services_get_with_one_service( - self, settings_ctrl_fixture: SettingsControllerFixture + self, + settings_ctrl_fixture: SettingsControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, ): - discovery_service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + discovery_service = create_integration_configuration.discovery_service( + url=settings_ctrl_fixture.ctrl.db.fresh_str() ) - discovery_service.url = settings_ctrl_fixture.ctrl.db.fresh_str() - controller = settings_ctrl_fixture.manager.admin_discovery_services_controller with settings_ctrl_fixture.request_context_with_admin("/"): response = controller.process_discovery_services() - [service] = response.get("discovery_services") + assert isinstance(response, Response) + [service] = response.get_json().get("discovery_services") assert discovery_service.id == service.get("id") assert discovery_service.protocol == service.get("protocol") - assert discovery_service.url == service.get("settings").get( - ExternalIntegration.URL - ) + assert discovery_service.settings_dict["url"] == service.get( + "settings" + ).get(ExternalIntegration.URL) def test_discovery_services_post_errors( - self, settings_ctrl_fixture: SettingsControllerFixture + self, + 
settings_ctrl_fixture: SettingsControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, ): controller = settings_ctrl_fixture.manager.admin_discovery_services_controller with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): @@ -100,43 +119,35 @@ def test_discovery_services_post_errors( [ ("name", "Name"), ("id", "123"), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ("protocol", self.protocol), ] ) response = controller.process_discovery_services() assert response == MISSING_SERVICE - service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, - name="name", + integration_url = settings_ctrl_fixture.ctrl.db.fresh_url() + existing_integration = create_integration_configuration.discovery_service( + url=integration_url ) - with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): - assert isinstance(service.name, str) + assert isinstance(existing_integration.name, str) flask.request.form = ImmutableMultiDict( [ - ("name", service.name), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ("name", existing_integration.name), + ("protocol", self.protocol), + ("url", "http://test.com"), ] ) response = controller.process_discovery_services() assert response == INTEGRATION_NAME_ALREADY_IN_USE - existing_integration = settings_ctrl_fixture.ctrl.db.external_integration( - ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, - url=settings_ctrl_fixture.ctrl.db.fresh_url(), - ) with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): assert isinstance(existing_integration.protocol, str) flask.request.form = ImmutableMultiDict( [ ("name", "new name"), ("protocol", existing_integration.protocol), - ("url", existing_integration.url), + ("url", integration_url), ] ) response = controller.process_discovery_services() @@ -145,18 +156,19 
@@ def test_discovery_services_post_errors( with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): flask.request.form = ImmutableMultiDict( [ - ("id", str(service.id)), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ("id", str(existing_integration.id)), + ("protocol", self.protocol), ] ) response = controller.process_discovery_services() + assert isinstance(response, ProblemDetail) assert response.uri == INCOMPLETE_CONFIGURATION.uri settings_ctrl_fixture.admin.remove_role(AdminRole.SYSTEM_ADMIN) with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): flask.request.form = ImmutableMultiDict( [ - ("protocol", ExternalIntegration.OPDS_REGISTRATION), + ("protocol", self.protocol), (ExternalIntegration.URL, "registry url"), ] ) @@ -169,8 +181,8 @@ def test_discovery_services_post_create( flask.request.form = ImmutableMultiDict( [ ("name", "Name"), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - (ExternalIntegration.URL, "http://registry_url"), + ("protocol", self.protocol), + (ExternalIntegration.URL, "http://registry.url"), ] ) response = ( @@ -180,32 +192,34 @@ def test_discovery_services_post_create( service = get_one( settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - goal=ExternalIntegration.DISCOVERY_GOAL, + IntegrationConfiguration, + goal=Goals.DISCOVERY_GOAL, + ) + assert isinstance(service, IntegrationConfiguration) + assert isinstance(response, Response) + assert service.id == int(response.get_data(as_text=True)) + assert self.protocol == service.protocol + assert ( + OpdsRegistrationService.settings_class()(**service.settings_dict).url + == "http://registry.url" ) - assert isinstance(service, ExternalIntegration) - assert service.id == int(response.response[0]) - assert ExternalIntegration.OPDS_REGISTRATION == service.protocol - assert "http://registry_url" == service.url def test_discovery_services_post_edit( - self, settings_ctrl_fixture: SettingsControllerFixture + self, + 
settings_ctrl_fixture: SettingsControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, ): - discovery_service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + discovery_service = create_integration_configuration.discovery_service( + url="registry url" ) - discovery_service.url = "registry url" with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): flask.request.form = ImmutableMultiDict( [ ("name", "Name"), ("id", str(discovery_service.id)), - ("protocol", ExternalIntegration.OPDS_REGISTRATION), - (ExternalIntegration.URL, "http://new_registry_url"), + ("protocol", self.protocol), + (ExternalIntegration.URL, "http://new_registry_url.com"), ] ) response = ( @@ -213,54 +227,80 @@ def test_discovery_services_post_edit( ) assert response.status_code == 200 - assert discovery_service.id == int(response.response[0]) - assert ExternalIntegration.OPDS_REGISTRATION == discovery_service.protocol - assert "http://new_registry_url" == discovery_service.url - - def test_check_name_unique(self, settings_ctrl_fixture: SettingsControllerFixture): - kwargs = dict( - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + assert isinstance(response, Response) + assert discovery_service.id == int(response.get_data(as_text=True)) + assert self.protocol == discovery_service.protocol + assert ( + "http://new_registry_url.com" + == OpdsRegistrationService.settings_class()( + **discovery_service.settings_dict + ).url ) - existing_service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - name="existing service", - **kwargs - ) - new_service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - name="new service", - **kwargs - ) - - m = ( - 
settings_ctrl_fixture.manager.admin_discovery_services_controller.check_name_unique - ) + def test_check_name_unique( + self, + settings_ctrl_fixture: SettingsControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, + ): + existing_service = create_integration_configuration.discovery_service() + new_service = create_integration_configuration.discovery_service() # Try to change new service so that it has the same name as existing service # -- this is not allowed. - result = m(new_service, existing_service.name) - assert result == INTEGRATION_NAME_ALREADY_IN_USE + with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( + [ + ("name", str(existing_service.name)), + ("id", str(new_service.id)), + ("protocol", self.protocol), + ("url", "http://test.com"), + ] + ) + response = ( + settings_ctrl_fixture.manager.admin_discovery_services_controller.process_discovery_services() + ) + assert response == INTEGRATION_NAME_ALREADY_IN_USE # Try to edit existing service without changing its name -- this is fine. - assert None == m(existing_service, existing_service.name) + with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( + [ + ("name", str(existing_service.name)), + ("id", str(existing_service.id)), + ("protocol", self.protocol), + ("url", "http://test.com"), + ] + ) + response = ( + settings_ctrl_fixture.manager.admin_discovery_services_controller.process_discovery_services() + ) + assert isinstance(response, Response) + assert response.status_code == 200 # Changing the existing service's name is also fine. 
- assert None == m(existing_service, "new name") + with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( + [ + ("name", "New name"), + ("id", str(existing_service.id)), + ("protocol", self.protocol), + ("url", "http://test.com"), + ] + ) + response = ( + settings_ctrl_fixture.manager.admin_discovery_services_controller.process_discovery_services() + ) + assert isinstance(response, Response) + assert response.status_code == 200 def test_discovery_service_delete( - self, settings_ctrl_fixture: SettingsControllerFixture + self, + settings_ctrl_fixture: SettingsControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, ): - discovery_service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + discovery_service = create_integration_configuration.discovery_service( + url="registry url" ) - discovery_service.url = "registry url" with settings_ctrl_fixture.request_context_with_admin("/", method="DELETE"): settings_ctrl_fixture.admin.remove_role(AdminRole.SYSTEM_ADMIN) @@ -272,7 +312,7 @@ def test_discovery_service_delete( settings_ctrl_fixture.admin.add_role(AdminRole.SYSTEM_ADMIN) response = settings_ctrl_fixture.manager.admin_discovery_services_controller.process_delete( - discovery_service.id + discovery_service.id # type: ignore[arg-type] ) assert response.status_code == 200 diff --git a/tests/api/admin/controller/test_individual_admins.py b/tests/api/admin/controller/test_individual_admins.py index ecb30a0a07..9e90f4a267 100644 --- a/tests/api/admin/controller/test_individual_admins.py +++ b/tests/api/admin/controller/test_individual_admins.py @@ -4,8 +4,13 @@ import pytest from werkzeug.datastructures import MultiDict -from api.admin.exceptions import * -from api.admin.problem_details import * +from api.admin.exceptions import AdminNotAuthorized +from 
api.admin.problem_details import ( + ADMIN_AUTH_NOT_CONFIGURED, + INCOMPLETE_CONFIGURATION, + UNKNOWN_ROLE, +) +from api.problem_details import LIBRARY_NOT_FOUND from core.model import Admin, AdminRole, create, get_one diff --git a/tests/api/admin/controller/test_library.py b/tests/api/admin/controller/test_library.py index 55349f19d9..736a4135b9 100644 --- a/tests/api/admin/controller/test_library.py +++ b/tests/api/admin/controller/test_library.py @@ -15,8 +15,15 @@ from werkzeug.datastructures import FileStorage, ImmutableMultiDict from api.admin.controller.library_settings import LibrarySettingsController -from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + INCOMPLETE_CONFIGURATION, + INVALID_CONFIGURATION_OPTION, + LIBRARY_SHORT_NAME_ALREADY_IN_USE, + UNKNOWN_LANGUAGE, +) from api.config import Configuration +from api.problem_details import LIBRARY_NOT_FOUND from core.facets import FacetConstants from core.model import AdminRole, Library, get_one from core.model.announcements import SETTING_NAME as ANNOUNCEMENTS_SETTING_NAME diff --git a/tests/api/admin/controller/test_library_registrations.py b/tests/api/admin/controller/test_library_registrations.py index cf25899d05..2445f58311 100644 --- a/tests/api/admin/controller/test_library_registrations.py +++ b/tests/api/admin/controller/test_library_registrations.py @@ -1,89 +1,97 @@ -import json +from unittest.mock import MagicMock import flask import pytest -from flask import url_for -from werkzeug.datastructures import MultiDict - -from api.admin.exceptions import * -from api.registration.registry import Registration, RemoteRegistry -from core.model import AdminRole, ConfigurationSetting, ExternalIntegration, create -from core.util.http import HTTP -from tests.core.mock import DummyHTTPClient +from flask import Response, url_for +from requests_mock import Mocker +from werkzeug.datastructures import ImmutableMultiDict + +from 
api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import MISSING_SERVICE, NO_SUCH_LIBRARY +from api.discovery.opds_registration import OpdsRegistrationService +from api.problem_details import REMOTE_INTEGRATION_FAILED +from core.model import AdminRole, create +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStage, + RegistrationStatus, +) +from core.problem_details import INVALID_INPUT +from core.util.problem_detail import ProblemDetail, ProblemError +from tests.fixtures.api_admin import AdminControllerFixture +from tests.fixtures.database import IntegrationConfigurationFixture +from tests.fixtures.library import LibraryFixture class TestLibraryRegistration: - """Test the process of registering a library with a RemoteRegistry.""" + """Test the process of registering a library with a OpdsRegistrationService.""" - def test_discovery_service_library_registrations_get(self, settings_ctrl_fixture): - db = settings_ctrl_fixture.ctrl.db + def test_discovery_service_library_registrations_get( + self, + admin_ctrl_fixture: AdminControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, + library_fixture: LibraryFixture, + requests_mock: Mocker, + ) -> None: + db = admin_ctrl_fixture.ctrl.db # Here's a discovery service. - discovery_service, ignore = create( - db.session, - ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + discovery_service = create_integration_configuration.discovery_service( + url="http://service-url/" ) - # We'll be making a mock request to this URL later. - discovery_service.setting(ExternalIntegration.URL).value = "http://service-url/" - # We successfully registered this library with the service. 
- succeeded = db.library( + succeeded = library_fixture.library( name="Library 1", short_name="L1", ) - config = ConfigurationSetting.for_library_and_externalintegration - config( - db.session, "library-registration-status", succeeded, discovery_service - ).value = "success" + registration, _ = create( + db.session, + DiscoveryServiceRegistration, + library=succeeded, + integration=discovery_service, + ) + registration.status = RegistrationStatus.SUCCESS + registration.stage = RegistrationStage.PRODUCTION # We tried to register this library with the service but were # unsuccessful. - config( - db.session, "library-registration-stage", succeeded, discovery_service - ).value = "production" - failed = db.library( + failed = library_fixture.library( name="Library 2", short_name="L2", ) - config( - db.session, - "library-registration-status", - failed, - discovery_service, - ).value = "failure" - config( + registration, _ = create( db.session, - "library-registration-stage", - failed, - discovery_service, - ).value = "testing" + DiscoveryServiceRegistration, + library=failed, + integration=discovery_service, + ) + registration.status = RegistrationStatus.FAILURE + registration.stage = RegistrationStage.TESTING # We've never tried to register this library with the service. - unregistered = db.library( + unregistered = library_fixture.library( name="Library 3", short_name="L3", ) - discovery_service.libraries = [succeeded, failed] # When a client sends a GET request to the controller, the # controller is going to call - # RemoteRegistry.fetch_registration_document() to try and find + # OpdsRegistrationService.fetch_registration_document() to try and find # the discovery services' terms of service. That's going to # make one or two HTTP requests. - # First, let's try the scenario where the discovery serivce is + # First, let's try the scenario where the discovery service is # working and has a terms-of-service. 
- client = DummyHTTPClient() # In this case we'll make two requests. The first request will # ask for the root catalog, where we'll look for a # registration link. root_catalog = dict(links=[dict(href="http://register-here/", rel="register")]) - client.queue_requests_response( - 200, RemoteRegistry.OPDS_2_TYPE, content=json.dumps(root_catalog) + requests_mock.get( + "http://service-url/", + json=root_catalog, + headers={"Content-Type": OpdsRegistrationService.OPDS_2_TYPE}, ) # The second request will fetch that registration link -- then @@ -98,19 +106,30 @@ def test_discovery_service_library_registrations_get(self, settings_ctrl_fixture ), ] ) - client.queue_requests_response( - 200, RemoteRegistry.OPDS_2_TYPE, content=json.dumps(registration_document) + requests_mock.get( + "http://register-here/", + json=registration_document, + headers={"Content-Type": OpdsRegistrationService.OPDS_2_TYPE}, ) controller = ( - settings_ctrl_fixture.manager.admin_discovery_service_library_registrations_controller + admin_ctrl_fixture.ctrl.manager.admin_discovery_service_library_registrations_controller ) m = controller.process_discovery_service_library_registrations - with settings_ctrl_fixture.request_context_with_admin("/", method="GET"): - response = m(do_get=client.do_get) + with admin_ctrl_fixture.request_context_with_admin("/", method="GET"): + # When the user lacks the SYSTEM_ADMIN role, the + # controller won't even start processing their GET + # request. + pytest.raises(AdminNotAuthorized, m) + + # Add the admin role and try again. + admin_ctrl_fixture.admin.add_role(AdminRole.SYSTEM_ADMIN) + + response = m() # The document we get back from the controller is a # dictionary with useful information on all known # discovery integrations -- just one, in this case. 
+ assert isinstance(response, dict) [service] = response["library_registrations"] assert discovery_service.id == service["id"] @@ -118,7 +137,9 @@ def test_discovery_service_library_registrations_get(self, settings_ctrl_fixture # happened. The target of the first request is the URL to # the discovery service's main catalog. The second request # is to the "register" link found in that catalog. - assert ["http://service-url/", "http://register-here/"] == client.requests + assert ["service-url", "register-here"] == [ + r.hostname for r in requests_mock.request_history + ] # The TOS link and TOS HTML snippet were recovered from # the registration document served in response to the @@ -145,26 +166,29 @@ def test_discovery_service_library_registrations_get(self, settings_ctrl_fixture ) # Note that `unregistered`, the library that never tried - # to register with this discover service, is not included. + # to register with this discovery service, is not included. # Now let's try the controller method again, except this # time the discovery service's web server is down. The # first request will return a ProblemDetail document, and # there will be no second request. - client.requests = [] - client.queue_requests_response( - 502, - content=REMOTE_INTEGRATION_FAILED, + requests_mock.reset() + requests_mock.get( + "http://service-url/", + json=REMOTE_INTEGRATION_FAILED.response[0], + status_code=502, ) - response = m(do_get=client.do_get) + + response = m() # Everything looks good, except that there's no TOS data # available. 
+ assert isinstance(response, dict) [service] = response["library_registrations"] assert discovery_service.id == service["id"] assert 2 == len(service["libraries"]) - assert None == service["terms_of_service_link"] - assert None == service["terms_of_service_html"] + assert service["terms_of_service_link"] is None + assert service["terms_of_service_html"] is None # The problem detail document that prevented the TOS data # from showing up has been converted to a dictionary and @@ -172,61 +196,68 @@ def test_discovery_service_library_registrations_get(self, settings_ctrl_fixture # discovery service. assert REMOTE_INTEGRATION_FAILED.uri == service["access_problem"]["type"] - # When the user lacks the SYSTEM_ADMIN role, the - # controller won't even start processing their GET - # request. - settings_ctrl_fixture.admin.remove_role(AdminRole.SYSTEM_ADMIN) - db.session.flush() - pytest.raises(AdminNotAuthorized, m) - - def test_discovery_service_library_registrations_post(self, settings_ctrl_fixture): + def test_discovery_service_library_registrations_post( + self, + admin_ctrl_fixture: AdminControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, + library_fixture: LibraryFixture, + ) -> None: """Test what might happen when you POST to discovery_service_library_registrations. """ controller = ( - settings_ctrl_fixture.manager.admin_discovery_service_library_registrations_controller + admin_ctrl_fixture.manager.admin_discovery_service_library_registrations_controller ) m = controller.process_discovery_service_library_registrations # Here, the user doesn't have permission to start the # registration process. 
- settings_ctrl_fixture.admin.remove_role(AdminRole.SYSTEM_ADMIN) - with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): - pytest.raises( - AdminNotAuthorized, - m, - do_get=settings_ctrl_fixture.do_request, - do_post=settings_ctrl_fixture.do_request, - ) - settings_ctrl_fixture.admin.add_role(AdminRole.SYSTEM_ADMIN) + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + pytest.raises(AdminNotAuthorized, m) + + admin_ctrl_fixture.admin.add_role(AdminRole.SYSTEM_ADMIN) + + # We might not get an integration ID parameter. + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict() + response = m() + assert isinstance(response, ProblemDetail) + assert INVALID_INPUT.uri == response.uri # The integration ID might not correspond to a valid # ExternalIntegration. - with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict( + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( [ ("integration_id", "1234"), ] ) response = m() + assert isinstance(response, ProblemDetail) assert MISSING_SERVICE == response - # Create an ExternalIntegration to avoid that problem in future - # tests. - discovery_service, ignore = create( - settings_ctrl_fixture.ctrl.db.session, - ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + # Create an IntegrationConfiguration to avoid that problem in future tests. + discovery_service = create_integration_configuration.discovery_service( + url="http://register-here/" ) - discovery_service.url = "registry url" + + # We might not get a library short name. 
+ with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( + [ + ("integration_id", str(discovery_service.id)), + ] + ) + response = m() + assert isinstance(response, ProblemDetail) + assert INVALID_INPUT.uri == response.uri # The library name might not correspond to a real library. - with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): - flask.request.form = MultiDict( + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( [ - ("integration_id", discovery_service.id), + ("integration_id", str(discovery_service.id)), ("library_short_name", "not-a-library"), ] ) @@ -234,55 +265,65 @@ def test_discovery_service_library_registrations_post(self, settings_ctrl_fixtur assert NO_SUCH_LIBRARY == response # Take care of that problem. - library = settings_ctrl_fixture.ctrl.db.default_library() - form = MultiDict( + library = library_fixture.library() + + # We might not get a registration stage. + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( + [ + ("integration_id", str(discovery_service.id)), + ("library_short_name", str(library.short_name)), + ] + ) + response = m() + assert isinstance(response, ProblemDetail) + assert INVALID_INPUT.uri == response.uri + + # The registration stage might not be valid. 
+ with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): + flask.request.form = ImmutableMultiDict( + [ + ("integration_id", str(discovery_service.id)), + ("library_short_name", str(library.short_name)), + ("registration_stage", "not-a-stage"), + ] + ) + response = m() + assert isinstance(response, ProblemDetail) + assert INVALID_INPUT.uri == response.uri + + form = ImmutableMultiDict( [ - ("integration_id", discovery_service.id), - ("library_short_name", library.short_name), - ("registration_stage", Registration.TESTING_STAGE), + ("integration_id", str(discovery_service.id)), + ("library_short_name", str(library.short_name)), + ("registration_stage", RegistrationStage.TESTING.value), ] ) - # Registration.push might return a ProblemDetail for whatever - # reason. - class Mock(Registration): - # We reproduce the signature, even though it's not - # necessary for what we're testing, so that if the push() - # signature changes this test will fail. - def push(self, stage, url_for, catalog_url=None, do_get=None, do_post=None): - return REMOTE_INTEGRATION_FAILED + # The registration may fail for some reason. + mock_registry = MagicMock(spec=OpdsRegistrationService) + mock_registry.register_library.side_effect = ProblemError( + problem_detail=REMOTE_INTEGRATION_FAILED + ) + controller.look_up_registry = MagicMock(return_value=mock_registry) - with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): flask.request.form = form - response = m(registration_class=Mock) + response = m() assert REMOTE_INTEGRATION_FAILED == response # But if that doesn't happen, success! - class Mock(Registration): - """When asked to push a registration, do nothing and say it - worked. 
- """ + mock_registry = MagicMock(spec=OpdsRegistrationService) + mock_registry.register_library.return_value = True + controller.look_up_registry = MagicMock(return_value=mock_registry) - called_with = None - - def push(self, *args, **kwargs): - Mock.called_with = (args, kwargs) - return True - - with settings_ctrl_fixture.request_context_with_admin("/", method="POST"): + with admin_ctrl_fixture.request_context_with_admin("/", method="POST"): flask.request.form = form - response = controller.process_discovery_service_library_registrations( - registration_class=Mock - ) + response = controller.process_discovery_service_library_registrations() + assert isinstance(response, Response) assert 200 == response.status_code - # push() was called with the arguments we would expect. - args, kwargs = Mock.called_with - assert (Registration.TESTING_STAGE, url_for) == args - - # We would have made real HTTP requests. - assert HTTP.debuggable_post == kwargs.pop("do_post") - assert HTTP.debuggable_get == kwargs.pop("do_get") - - # No other keyword arguments were passed in. - assert {} == kwargs + # register_library() was called with the arguments we would expect. 
+ mock_registry.register_library.assert_called_once_with( + library, RegistrationStage.TESTING, url_for + ) diff --git a/tests/api/admin/controller/test_metadata_services.py b/tests/api/admin/controller/test_metadata_services.py index 8ea9c06e09..ba62edcf47 100644 --- a/tests/api/admin/controller/test_metadata_services.py +++ b/tests/api/admin/controller/test_metadata_services.py @@ -5,7 +5,16 @@ from werkzeug.datastructures import MultiDict from api.admin.controller.metadata_services import MetadataServicesController -from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + CANNOT_CHANGE_PROTOCOL, + INCOMPLETE_CONFIGURATION, + INTEGRATION_NAME_ALREADY_IN_USE, + MISSING_SERVICE, + NO_PROTOCOL_FOR_NEW_SERVICE, + NO_SUCH_LIBRARY, + UNKNOWN_PROTOCOL, +) from api.novelist import NoveListAPI from api.nyt import NYTBestSellerAPI from core.model import AdminRole, ExternalIntegration, create, get_one diff --git a/tests/api/admin/controller/test_patron_auth_self_tests.py b/tests/api/admin/controller/test_patron_auth_self_tests.py index 54f615fb31..12816805f4 100644 --- a/tests/api/admin/controller/test_patron_auth_self_tests.py +++ b/tests/api/admin/controller/test_patron_auth_self_tests.py @@ -142,6 +142,7 @@ def test_patron_auth_self_tests_post_with_no_libraries( response = controller.process_patron_auth_service_self_tests(auth_service.id) assert isinstance(response, ProblemDetail) assert response.title == FAILED_TO_RUN_SELF_TESTS.title + assert response.detail is not None assert "Failed to run self tests" in response.detail assert response.status_code == 400 diff --git a/tests/api/admin/controller/test_search_services.py b/tests/api/admin/controller/test_search_services.py index ab57b1fae6..ee92ebec1f 100644 --- a/tests/api/admin/controller/test_search_services.py +++ b/tests/api/admin/controller/test_search_services.py @@ -2,7 +2,15 @@ import pytest from werkzeug.datastructures import MultiDict 
-from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + INCOMPLETE_CONFIGURATION, + INTEGRATION_NAME_ALREADY_IN_USE, + MISSING_SERVICE, + MULTIPLE_SITEWIDE_SERVICES, + NO_PROTOCOL_FOR_NEW_SERVICE, + UNKNOWN_PROTOCOL, +) from core.external_search import ExternalSearchIndex from core.model import AdminRole, ExternalIntegration, create, get_one @@ -105,8 +113,8 @@ def test_search_services_post_errors(self, settings_ctrl_fixture): service, ignore = create( settings_ctrl_fixture.ctrl.db.session, ExternalIntegration, - protocol=ExternalIntegration.OPDS_REGISTRATION, - goal=ExternalIntegration.DISCOVERY_GOAL, + protocol="test", + goal=ExternalIntegration.LICENSE_GOAL, name="name", ) diff --git a/tests/api/admin/controller/test_sitewide_settings.py b/tests/api/admin/controller/test_sitewide_settings.py index 49e7aba19b..5c2cfdd497 100644 --- a/tests/api/admin/controller/test_sitewide_settings.py +++ b/tests/api/admin/controller/test_sitewide_settings.py @@ -2,7 +2,11 @@ import pytest from werkzeug.datastructures import MultiDict -from api.admin.exceptions import * +from api.admin.exceptions import AdminNotAuthorized +from api.admin.problem_details import ( + MISSING_SITEWIDE_SETTING_KEY, + MISSING_SITEWIDE_SETTING_VALUE, +) from api.config import Configuration from core.model import AdminRole, ConfigurationSetting diff --git a/tests/api/discovery/test_opds_registration.py b/tests/api/discovery/test_opds_registration.py new file mode 100644 index 0000000000..284f3163dc --- /dev/null +++ b/tests/api/discovery/test_opds_registration.py @@ -0,0 +1,822 @@ +import base64 +import json +import os +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, List, Optional +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from Crypto.Cipher import PKCS1_OAEP +from Crypto.PublicKey import RSA +from requests_mock import 
Mocker + +from api.config import Configuration +from api.discovery.opds_registration import OpdsRegistrationService +from api.discovery.registration_script import LibraryRegistrationScript +from api.problem_details import * +from core.model import ConfigurationSetting, Library, create, get_one +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStage, + RegistrationStatus, +) +from core.util.problem_detail import JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE +from core.util.problem_detail import ProblemDetail, ProblemError +from tests.api.mockapi.circulation import MockCirculationManager +from tests.core.mock import MockRequestsResponse +from tests.fixtures.database import ( + DatabaseTransactionFixture, + IntegrationConfigurationFixture, +) +from tests.fixtures.library import LibraryFixture + + +class RemoteRegistryFixture: + def __init__( + self, + db: DatabaseTransactionFixture, + integration_configuration: IntegrationConfigurationFixture, + ): + self.db = db + # Create an ExternalIntegration that can be used as the basis for + # a OpdsRegistrationService. 
+ self.registry_url = "http://registry.com/" + self.integration = integration_configuration.discovery_service( + url=self.registry_url + ) + assert self.integration.protocol is not None + self.protocol = self.integration.protocol + assert self.integration.goal is not None + self.goal = self.integration.goal + + self.registry = OpdsRegistrationService.for_integration( + db.session, self.integration + ) + + def create_registration( + self, library: Optional[Library] = None + ) -> DiscoveryServiceRegistration: + obj, _ = create( + self.db.session, + DiscoveryServiceRegistration, + library=library or self.db.default_library(), + integration=self.integration, + ) + return obj + + +@pytest.fixture(scope="function") +def remote_registry_fixture( + db: DatabaseTransactionFixture, + create_integration_configuration: IntegrationConfigurationFixture, +) -> RemoteRegistryFixture: + return RemoteRegistryFixture(db, create_integration_configuration) + + +class TestOpdsRegistrationService: + def test_constructor(self): + integration = MagicMock() + settings = MagicMock() + registry = OpdsRegistrationService(integration, settings) + assert integration == registry.integration + assert settings == registry.settings + + def test_for_integration(self, remote_registry_fixture: RemoteRegistryFixture): + """Test the ability to build a Registry for an ExternalIntegration + given its ID. + """ + db = remote_registry_fixture.db + m = OpdsRegistrationService.for_integration + assert remote_registry_fixture.integration.id is not None + registry = m( + db.session, + remote_registry_fixture.integration.id, + ) + assert isinstance(registry, OpdsRegistrationService) + assert remote_registry_fixture.integration == registry.integration + + # If the ID doesn't exist you get None. + assert m(db.session, -1) is None + + # You can also pass in the IntegrationConfiguration object itself. 
+ registry = m(db.session, remote_registry_fixture.integration) + assert isinstance(registry, OpdsRegistrationService) + assert remote_registry_fixture.integration == registry.integration + + def test_for_protocol_goal_and_url( + self, remote_registry_fixture: RemoteRegistryFixture + ): + db = remote_registry_fixture.db + m = OpdsRegistrationService.for_protocol_goal_and_url + + registry = m( + db.session, + remote_registry_fixture.protocol, + remote_registry_fixture.goal, + remote_registry_fixture.registry_url, + ) + assert isinstance(registry, OpdsRegistrationService) + assert remote_registry_fixture.integration == registry.integration + + # If the ExternalIntegration doesn't exist, we get None. + registry = m( + db.session, + remote_registry_fixture.protocol, + remote_registry_fixture.goal, + "http://registry2.com", + ) + assert registry is None + + def test_registrations(self, remote_registry_fixture: RemoteRegistryFixture): + db = remote_registry_fixture.db + + # Associate the default library with the registry. + remote_registry_fixture.create_registration(db.default_library()) + + # Create another library not associated with the registry. + library2 = db.library() + + # registrations() finds a single Registration. + [registration] = list(remote_registry_fixture.registry.registrations) + assert isinstance(registration, DiscoveryServiceRegistration) + assert db.default_library() == registration.library + + def test_fetch_catalog( + self, remote_registry_fixture: RemoteRegistryFixture, requests_mock: Mocker + ): + # The behavior of fetch_catalog() depends on what comes back + # when we ask the remote registry for its root catalog. + requests_mock.get(remote_registry_fixture.registry_url, text="A root catalog") + + # Test our ability to retrieve essential information from a + # remote registry's root catalog. 
+ func_mock = MagicMock(return_value="Essential information") + remote_registry_fixture.registry._extract_catalog_information = func_mock + + # If the response looks good, it's passed into + # _extract_catalog_information(), and the result of _that_ + # method is the return value of fetch_catalog. + assert ( + "Essential information" == remote_registry_fixture.registry.fetch_catalog() + ) + assert requests_mock.called_once + assert requests_mock.last_request is not None + assert requests_mock.last_request.url == remote_registry_fixture.registry_url + assert remote_registry_fixture.registry._extract_catalog_information.called + assert func_mock.call_args.args[0].text == "A root catalog" + + def test__extract_catalog_information( + self, remote_registry_fixture: RemoteRegistryFixture + ): + # Test our ability to extract a registration link and an + # Adobe Vendor ID from an OPDS 2 catalog. + def mock_request(document, type=OpdsRegistrationService.OPDS_2_TYPE) -> Any: + data = json.dumps(document) if isinstance(document, dict) else document + return MockRequestsResponse(200, {"Content-Type": type}, data) + + m = OpdsRegistrationService._extract_catalog_information + + # OPDS 2 feed with link and Adobe Vendor ID. + link = {"rel": "register", "href": "register url"} + metadata = {"adobe_vendor_id": "vendorid"} + request = mock_request(dict(links=[link], metadata=metadata)) + assert ("register url", "vendorid") == m(request) + + # OPDS 2 feed with link and no Adobe Vendor ID + request = mock_request(dict(links=[link])) + assert ("register url", None) == m(request) + + # OPDS 2 feed with no link. + with pytest.raises(ProblemError) as excinfo: + request = mock_request(dict(metadata=metadata)) + m(request) + + detail = excinfo.value.problem_detail + assert detail.detail is not None + assert ( + "The service at http://url/ did not provide a register link." + in detail.detail + ) + assert REMOTE_INTEGRATION_FAILED.uri == detail.uri + + # Non-OPDS document. 
+ with pytest.raises(ProblemError) as excinfo: + request = mock_request("plain text here", "text/plain") + m(request) + + detail = excinfo.value.problem_detail + assert detail.detail is not None + assert "The service at http://url/ did not return OPDS." in detail.detail + assert REMOTE_INTEGRATION_FAILED.uri == detail.uri + + def test_fetch_registration_document_error_catalog( + self, remote_registry_fixture: RemoteRegistryFixture + ): + # Test our ability to retrieve terms-of-service information + # from a remote registry, assuming the registry makes that + # information available. + + # First, test the case where we can't even get the catalog + # document. + remote_registry_fixture.registry.fetch_catalog = MagicMock( + side_effect=ProblemError(problem_detail=REMOTE_INTEGRATION_FAILED) + ) + with pytest.raises(ProblemError) as excinfo: + remote_registry_fixture.registry.fetch_registration_document() + + # The fetch_catalog method raised a ProblemError, + # which propagated up to fetch_registration_document. + assert REMOTE_INTEGRATION_FAILED == excinfo.value.problem_detail + remote_registry_fixture.registry.fetch_catalog.assert_called_once() + + def test_fetch_registration_document_error_registration_document( + self, remote_registry_fixture: RemoteRegistryFixture, requests_mock: Mocker + ): + # Test the case where we get the catalog document, but we can't + # get the registration document. + requests_mock.get( + "http://register-here/", + status_code=REMOTE_INTEGRATION_FAILED.status_code, # type: ignore[arg-type] + headers={"Content-Type": PROBLEM_DETAIL_JSON_MEDIA_TYPE}, + text=REMOTE_INTEGRATION_FAILED.response[0], + ) + remote_registry_fixture.registry.fetch_catalog = MagicMock( + return_value=("http://register-here/", "vendor id") + ) + + with pytest.raises(ProblemError) as excinfo: + remote_registry_fixture.registry.fetch_registration_document() + + # A request was made to the registration URL mentioned in the catalog. 
+ assert requests_mock.called_once + assert requests_mock.last_request is not None + assert "http://register-here/" == requests_mock.last_request.url + + # But the request returned a problem detail, which became a ProblemError + assert REMOTE_INTEGRATION_FAILED.uri == excinfo.value.problem_detail.uri + assert excinfo.value.problem_detail.detail is not None + assert ( + str(REMOTE_INTEGRATION_FAILED.detail) in excinfo.value.problem_detail.detail + ) + assert "Remote service returned" in excinfo.value.problem_detail.detail + + def test_fetch_registration_document( + self, remote_registry_fixture: RemoteRegistryFixture, requests_mock: Mocker + ): + # Finally, test the case where we can get both documents. + remote_registry_fixture.registry.fetch_catalog = MagicMock( + return_value=("http://register-here/", "vendor id") + ) + remote_registry_fixture.registry._extract_registration_information = MagicMock( + return_value=("TOS link", "TOS HTML data") + ) + + requests_mock.get("http://register-here/", text="a registration document") + result = remote_registry_fixture.registry.fetch_registration_document() + + # Another request was made to the registration URL. + assert requests_mock.called_once + assert requests_mock.last_request is not None + assert "http://register-here/" == requests_mock.last_request.url + remote_registry_fixture.registry.fetch_catalog.assert_called_once() + + # Our mock of _extract_registration_information was called + # with the mock response to that request. + remote_registry_fixture.registry._extract_registration_information.assert_called_once() + assert ( + remote_registry_fixture.registry._extract_registration_information.call_args.args[ + 0 + ].text + == "a registration document" + ) + + # The return value of _extract_registration_information was + # propagated as the return value of + # fetch_registration_document. 
+ assert ("TOS link", "TOS HTML data") == result + + def test__extract_registration_information( + self, remote_registry_fixture: RemoteRegistryFixture + ): + # Test our ability to extract terms-of-service information -- + # a link and/or some HTML or textual instructions -- from a + # registration document. + + def data_link(data, type="text/html"): + encoded = base64.b64encode(data.encode("utf-8")).decode("utf-8") + return dict(rel="terms-of-service", href=f"data:{type};base64,{encoded}") + + class Mock(OpdsRegistrationService): + decoded: str + + @classmethod + def _decode_data_url(cls, url): + cls.decoded = url + return "Decoded: " + OpdsRegistrationService._decode_data_url(url) + + def extract(document, type=OpdsRegistrationService.OPDS_2_TYPE): + if type == OpdsRegistrationService.OPDS_2_TYPE: + document = json.dumps(dict(links=document)) + response = MockRequestsResponse(200, {"Content-Type": type}, document) + return Mock._extract_registration_information(response) + + # OPDS 2 feed with TOS in http: and data: links. + tos_link = dict(rel="terms-of-service", href="http://tos/") + tos_data = data_link("

<p>Some HTML</p>") + assert ("http://tos/", "Decoded: <p>Some HTML</p>

") == extract( + [tos_link, tos_data] + ) + + # At this point it's clear that the data: URL found in + # `tos_data` was run through `_decode_data()`. This gives us + # permission to test all the fiddly bits of `_decode_data` in + # isolation, below. + assert tos_data["href"] == Mock.decoded + + # OPDS 2 feed with http: link only. + assert ("http://tos/", None) == extract([tos_link]) + + # OPDS 2 feed with data: link only. + assert (None, "Decoded:

<p>Some HTML</p>

") == extract([tos_data]) + + # OPDS 2 feed with no links. + assert (None, None) == extract([]) + + # Non-OPDS document. + assert (None, None) == extract("plain text here", "text/plain") + + # Unrecognized URI schemes are ignored. + ftp_link = dict(rel="terms-of-service", href="ftp://tos/") + assert (None, None) == extract([ftp_link]) + + def test__decode_data_url(self, remote_registry_fixture: RemoteRegistryFixture): + # Test edge cases of decoding data: URLs. + m = OpdsRegistrationService._decode_data_url + + def data_url(data, type="text/html"): + encoded = base64.b64encode(data.encode("utf-8")).decode("utf-8") + return f"data:{type};base64,{encoded}" + + # HTML is okay. + html = data_url("some HTML", "text/html;charset=utf-8") + assert "some HTML" == m(html) + + # Plain text is okay. + text = data_url("some plain text", "text/plain") + assert "some plain text" == m(text) + + # No other media type is allowed. + image = data_url("an image!", "image/png") + with pytest.raises(ValueError) as excinfo: + m(image) + assert "Unsupported media type in data: URL: image/png" in str(excinfo.value) + + # Incoming HTML is sanitized. + dirty_html = data_url("

<script>alert!</script><p>Some HTML</p>") + assert "<p>Some HTML</p>

" == m(dirty_html) + + # Now test various malformed data: URLs. + no_header = "foobar" + with pytest.raises(ValueError) as excinfo: + m(no_header) + assert "Not a data: URL: foobar" in str(excinfo.value) + + no_comma = "data:blah" + with pytest.raises(ValueError) as excinfo: + m(no_comma) + assert "Invalid data: URL: data:blah" in str(excinfo.value) + + too_many_commas = "data:blah,blah,blah" + with pytest.raises(ValueError) as excinfo: + m(too_many_commas) + assert "Invalid data: URL: data:blah,blah,blah" in str(excinfo.value) + + # data: URLs don't have to be base64-encoded, but those are the + # only kind we support. + not_encoded = "data:blah,content" + with pytest.raises(ValueError) as excinfo: + m(not_encoded) + assert "data: URL not base64-encoded: data:blah,content" in str(excinfo.value) + + def test_register_library( + self, + remote_registry_fixture: RemoteRegistryFixture, + library_fixture: LibraryFixture, + ): + db = remote_registry_fixture.db + + # Test the other methods orchestrated by the register_library() method. + registry = remote_registry_fixture.registry + registry.fetch_catalog = MagicMock(return_value=("register_url", "vendor_id")) + registry._create_registration_payload = MagicMock( + return_value={"payload": "this is it"} + ) + registry._create_registration_headers = MagicMock( + return_value=dict(Header="Value") + ) + registry._send_registration_request = MagicMock( + return_value=MockRequestsResponse(200, content=json.dumps("you did it!")) + ) + registry._process_registration_result = MagicMock(return_value=True) + + library = library_fixture.library() + stage = RegistrationStage.TESTING + url_for = MagicMock() + + register_library = partial(registry.register_library, library, stage, url_for) + + # Kick off the registration process, and make sure we get expected return. + result = register_library() + assert result is True + + # But there were many steps towards this result. 
+ + # First, fetch_catalog() was called, in an attempt + # to find the registration URL inside the root catalog. + registry.fetch_catalog.assert_called_once() + + # fetch_catalog() returned a registration URL and + # a vendor ID. The registration URL was used later on... + # + # The vendor ID was set on the registration in the database. + registration = get_one( + db.session, DiscoveryServiceRegistration, library=library + ) + assert registration is not None + assert "vendor_id" == registration.vendor_id + + # _create_registration_payload was called to create the body + # of the registration request. + registry._create_registration_payload.assert_called_once_with( + library, stage, url_for + ) + + # _create_registration_headers was called to create the headers + # sent along with the request. + registry._create_registration_headers.assert_called_once() + + # Then _send_registration_request was called, POSTing the + # payload to "register_url", the registration URL we got earlier. + registry._send_registration_request.assert_called_once_with( + "register_url", {"Header": "Value"}, dict(payload="this is it") + ) + + # Finally, the return value of that method was loaded as JSON + # and passed into _process_registration_result, along with + # a cipher created from the private key. (That cipher would be used + # to decrypt anything the foreign site signed using this site's + # public key.) + registry._process_registration_result.assert_called_once() + ( + actual_registration, + message, + cipher, + actual_stage, + ) = registry._process_registration_result.call_args.args + assert registration == actual_registration + assert "you did it!" == message + assert cipher._key.export_key("DER") == library.private_key + assert actual_stage == stage + + # Now in reverse order, let's replace the mocked methods so + # that they raise ProblemError exceptions. This tests that if + # there is a failure at any stage, the ProblemError is + # propagated. 
+ def create_exception(message: str) -> ProblemError: + return ProblemError(problem_detail=INVALID_REGISTRATION.detailed(message)) + + registry._process_registration_result = MagicMock( + side_effect=create_exception("could not process registration result") + ) + with pytest.raises(ProblemError) as excinfo: + register_library() + assert ( + "could not process registration result" + == excinfo.value.problem_detail.detail + ) + + registry._send_registration_request = MagicMock( + side_effect=create_exception("could not send registration request") + ) + with pytest.raises(ProblemError) as excinfo: + register_library() + assert ( + "could not send registration request" == excinfo.value.problem_detail.detail + ) + + registry._create_registration_payload = MagicMock( + side_effect=create_exception("could not create registration payload") + ) + with pytest.raises(ProblemError) as excinfo: + register_library() + assert ( + "could not create registration payload" + == excinfo.value.problem_detail.detail + ) + + registry.fetch_catalog = MagicMock( + side_effect=create_exception("could not fetch catalog") + ) + with pytest.raises(ProblemError) as excinfo: + register_library() + assert "could not fetch catalog" == excinfo.value.problem_detail.detail + + def test__create_registration_payload( + self, + remote_registry_fixture: RemoteRegistryFixture, + library_fixture: LibraryFixture, + ): + m = remote_registry_fixture.registry._create_registration_payload + + # Mock url_for to create good-looking callback URLs. + def url_for(controller, library_short_name, **kwargs): + return f"http://server/{library_short_name}/{controller}" + + # First, test with no configuration contact configured for the + # library. 
+ library = library_fixture.library() + stage = RegistrationStage.PRODUCTION + expect_url = url_for( + "authentication_document", + library.short_name, + ) + expect_payload = dict(url=expect_url, stage=stage.value) + assert expect_payload == m(library, stage, url_for) + + # If a contact is configured, it shows up in the payload. + contact = "mailto:ohno@library.org" + settings = library_fixture.settings(library) + settings.configuration_contact_email_address = contact # type: ignore[assignment] + expect_payload["contact"] = contact + assert expect_payload == m(library, stage, url_for) + + def test_create_registration_headers( + self, remote_registry_fixture: RemoteRegistryFixture + ): + db = remote_registry_fixture.db + m = remote_registry_fixture.registry._create_registration_headers + + # If no shared secret is configured, no custom headers are provided. + registration = remote_registry_fixture.create_registration() + assert {} == m(registration) + + # If a shared secret is configured, it shows up as part of + # the Authorization header. + registration.shared_secret = "a secret" + assert {"Authorization": "Bearer a secret"} == m(registration) + + def test__send_registration_request( + self, remote_registry_fixture: RemoteRegistryFixture, requests_mock: Mocker + ): + + # If everything goes well, the return value of do_post is + # passed through. + url = "http://url.com" + requests_mock.post(url, text="all good") + payload = {"payload": "payload"} + headers = {"headers": ""} + m = remote_registry_fixture.registry._send_registration_request + + result = m(url, headers, payload) + assert "all good" == result.text + + # Error handling is expected to be handled by post_request + # raising a ProblemError exception. + + # The remote sends a 401 response with a problem detail. 
+ requests_mock.post( + url, + status_code=401, + headers={"Content-Type": PROBLEM_DETAIL_JSON_MEDIA_TYPE}, + text=json.dumps(dict(detail="this is a problem detail")), + ) + with pytest.raises(ProblemError) as excinfo: + m(url, headers, payload) + assert REMOTE_INTEGRATION_FAILED.uri == excinfo.value.problem_detail.uri + assert excinfo.value.problem_detail.detail is not None + assert ( + 'Remote service returned a problem detail document: \'{"detail": "this is a problem detail"}\'' + in excinfo.value.problem_detail.detail + ) + + # The remote sends some other kind of 401 response. + requests_mock.post( + url, + status_code=401, + headers={"Content-Type": "text/html"}, + text="log in why don't you", + ) + with pytest.raises(ProblemError) as excinfo: + m(url, headers, payload) + + assert REMOTE_INTEGRATION_FAILED.uri == excinfo.value.problem_detail.uri + assert ( + '401 response from integration server: "log in why don\'t you"' + == excinfo.value.problem_detail.detail + ) + + def test__decrypt_shared_secret( + self, remote_registry_fixture: RemoteRegistryFixture + ): + key = RSA.generate(2048) + encryptor = PKCS1_OAEP.new(key) + + key2 = RSA.generate(2048) + encryptor2 = PKCS1_OAEP.new(key2) + + shared_secret = os.urandom(24) + encrypted_secret = base64.b64encode(encryptor.encrypt(shared_secret)).decode( + "utf-8" + ) + + # Success. + m = remote_registry_fixture.registry._decrypt_shared_secret + assert shared_secret == m(encryptor, encrypted_secret) + + # If we try to decrypt using the wrong key, a ProblemError is + # raised explaining the problem. 
+ with pytest.raises(ProblemError) as excinfo: + m(encryptor2, encrypted_secret) + + assert SHARED_SECRET_DECRYPTION_ERROR.uri == excinfo.value.problem_detail.uri + assert excinfo.value.problem_detail.detail is not None + assert encrypted_secret in excinfo.value.problem_detail.detail + + def test__process_registration_result( + self, remote_registry_fixture: RemoteRegistryFixture, monkeypatch: MonkeyPatch + ): + db = remote_registry_fixture.db + m = remote_registry_fixture.registry._process_registration_result + stage = RegistrationStage.TESTING + encryptor = MagicMock() + + reg = MagicMock(spec=DiscoveryServiceRegistration) + + # Result must be a dictionary. + with pytest.raises(ProblemError) as excinfo: + m(reg, "not a dictionary", encryptor, stage) + + problem = excinfo.value.problem_detail + assert INTEGRATION_ERROR.uri == problem.uri + assert ( + "Remote service served 'not a dictionary', which I can't make sense of as an OPDS document." + == problem.detail + ) + + # When the result is empty, the registration is marked as successful. + result = m(reg, dict(), encryptor, stage) + assert result is True + reg.status = RegistrationStatus.SUCCESS + + # The stage field has been set to the requested value. + reg.stage = stage + + # Now try with a result that includes a short name, + # a shared secret, and a web client URL. + mock = MagicMock(return_value="👉 cleartext 👈".encode()) + monkeypatch.setattr(OpdsRegistrationService, "_decrypt_shared_secret", mock) + + catalog = dict( + metadata=dict(short_name="SHORT", shared_secret="ciphertext", id="uuid"), + links=[dict(href="http://web/library", rel="self", type="text/html")], + ) + result = m(reg, catalog, encryptor, RegistrationStage.PRODUCTION) + assert result is True + + # Short name is set. + assert reg.short_name == "SHORT" + + # Shared secret was decrypted, decoded from UTF-8 and is set. 
+ mock.assert_called_once_with(encryptor, "ciphertext") + assert reg.shared_secret == "👉 cleartext 👈" + + # Web client URL is set. + assert reg.web_client == "http://web/library" + + assert reg.stage == RegistrationStage.PRODUCTION + + # Now simulate a problem decrypting the shared secret. + mock.side_effect = ProblemError(problem_detail=SHARED_SECRET_DECRYPTION_ERROR) + with pytest.raises(ProblemError) as excinfo: + m(reg, catalog, encryptor, stage) + + assert SHARED_SECRET_DECRYPTION_ERROR == excinfo.value.problem_detail + + +class TestLibraryRegistrationScript: + def test_do_run( + self, + db: DatabaseTransactionFixture, + library_fixture: LibraryFixture, + remote_registry_fixture: RemoteRegistryFixture, + ): + @dataclass + class Processed: + registry: OpdsRegistrationService + library: Library + stage: RegistrationStage + url_for: Callable[..., str] + + class Mock(LibraryRegistrationScript): + processed: List[Processed] = [] + + def process_library( # type: ignore[override] + self, + registry: OpdsRegistrationService, + library: Library, + stage: RegistrationStage, + url_for: Callable[..., str], + ): + self.processed.append(Processed(registry, library, stage, url_for)) + + script = Mock(db.session) + + base_url_setting = ConfigurationSetting.sitewide( + db.session, Configuration.BASE_URL_KEY + ) + base_url_setting.value = "http://test-circulation-manager/" + + library = library_fixture.library() + library2 = library_fixture.library() + + cmd_args = [ + str(library.short_name), + "--stage=testing", + "--registry-url=http://registry.com/", + ] + manager = MockCirculationManager(db.session) + script.do_run(cmd_args=cmd_args, manager=manager) + + # One library was processed. + processed = script.processed.pop() + assert [] == script.processed + assert library == processed.library + assert RegistrationStage.TESTING == processed.stage + + # Let's say the other library was earlier registered in production. 
+ registration = remote_registry_fixture.create_registration(library2) + registration.stage = RegistrationStage.PRODUCTION + + # Now run the script again without specifying a particular + # library or the --stage argument. + script.do_run(cmd_args=["--registry-url=http://registry.com/"], manager=manager) + + # Every library was processed. + assert {library, library2} == {x.library for x in script.processed} + + # Since no stage was provided, each library was registered + # using the stage already associated with it. + assert {RegistrationStage.TESTING, RegistrationStage.PRODUCTION} == { + x.stage for x in script.processed + } + + # Every library was registered with the specified registry. + assert {"http://registry.com/", "http://registry.com/"} == { + x.registry.settings.url for x in script.processed + } + + def test_process_library( + self, + db: DatabaseTransactionFixture, + remote_registry_fixture: RemoteRegistryFixture, + library_fixture: LibraryFixture, + ): + """Test the things that might happen when process_library is called.""" + script = LibraryRegistrationScript(db.session) + library = library_fixture.library() + registry = remote_registry_fixture.registry + + # First, simulate success. + registry.register_library = MagicMock(return_value=True) + stage = MagicMock() + url_for = MagicMock() + assert script.process_library(registry, library, stage, url_for) is True + + # The stage and url_for values were passed into register_library() + registry.register_library.assert_called_once_with(library, stage, url_for) + + # Next, simulate an exception raised during register_library() + # This can happen in real situations, though the next case + # we'll test is more common. + registry.register_library = MagicMock(side_effect=Exception("boo")) + + # We get False rather than the exception being propagated. + # Useful information about the exception is added to the logs, + # where someone actually running the script will see it. 
+ assert script.process_library(registry, library, stage, url_for) is False + + # Next, simulate register_library() returning a problem detail document. + registry.register_library = MagicMock( + side_effect=ProblemError(problem_detail=INVALID_INPUT.detailed("oops")) + ) + + result = script.process_library(registry, library, stage, url_for) + + # The problem document is returned. Useful information about + # the exception is also added to the logs, where someone + # actually running the script will see it. + assert isinstance(result, ProblemDetail) + assert INVALID_INPUT.uri == result.uri + assert "oops" == result.detail diff --git a/tests/api/test_adobe_vendor_id.py b/tests/api/test_adobe_vendor_id.py index 5eccce5523..7f8714652d 100644 --- a/tests/api/test_adobe_vendor_id.py +++ b/tests/api/test_adobe_vendor_id.py @@ -1,14 +1,20 @@ +from __future__ import annotations + import base64 import datetime +from typing import Type from unittest.mock import MagicMock import pytest from jwt import DecodeError, ExpiredSignatureError, InvalidIssuedAtError +from sqlalchemy import select from api.adobe_vendor_id import AuthdataUtility -from api.registration.constants import RegistrationConstants from core.config import CannotLoadConfiguration -from core.model import ConfigurationSetting, ExternalIntegration +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStatus, +) from core.util.datetime_helpers import datetime_utc, utc_now from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.library import LibraryFixture @@ -29,15 +35,14 @@ class TestAuthdataUtility: @pytest.mark.parametrize( "registration_status, authdata_utility_type", [ - (RegistrationConstants.SUCCESS_STATUS, AuthdataUtility), - (RegistrationConstants.FAILURE_STATUS, type(None)), - (None, type(None)), + (RegistrationStatus.SUCCESS, AuthdataUtility), + (RegistrationStatus.FAILURE, type(None)), ], ) def 
test_eligible_authdata_vendor_id_integrations( self, - registration_status, - authdata_utility_type, + registration_status: RegistrationStatus, + authdata_utility_type: Type[AuthdataUtility] | Type[None], authdata: AuthdataUtility, vendor_id_fixture: VendorIDFixture, ): @@ -45,19 +50,7 @@ def test_eligible_authdata_vendor_id_integrations( # a given library is eligible to provide an AuthdataUtility. library = vendor_id_fixture.db.default_library() vendor_id_fixture.initialize_adobe(library) - registry = ExternalIntegration.lookup( - vendor_id_fixture.db.session, - ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, - library=library, - ) - ConfigurationSetting.for_library_and_externalintegration( - vendor_id_fixture.db.session, - RegistrationConstants.LIBRARY_REGISTRATION_STATUS, - library, - registry, - ).value = registration_status - + vendor_id_fixture.registration.status = registration_status utility = AuthdataUtility.from_config(library) assert isinstance(utility, authdata_utility_type) @@ -76,33 +69,19 @@ def test_from_config( assert utility is not None assert library.short_name is not None - registry = ExternalIntegration.lookup( - vendor_id_fixture.db.session, - ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, - library=library, - ) - assert ( - library.short_name + "token" - == ConfigurationSetting.for_library_and_externalintegration( - vendor_id_fixture.db.session, - ExternalIntegration.USERNAME, - library, - registry, - ).value - ) - assert ( - library.short_name + " token secret" - == ConfigurationSetting.for_library_and_externalintegration( - vendor_id_fixture.db.session, - ExternalIntegration.PASSWORD, - library, - registry, - ).value - ) + registration = vendor_id_fixture.db.session.scalars( + select(DiscoveryServiceRegistration).where( + DiscoveryServiceRegistration.library_id == library.id, + DiscoveryServiceRegistration.integration_id + == vendor_id_fixture.registry.id, + ) + ).first() + assert 
registration is not None + assert registration.short_name == library.short_name + "token" + assert registration.shared_secret == library.short_name + " token secret" - assert VendorIDFixture.TEST_VENDOR_ID == utility.vendor_id - assert library_url == utility.library_uri + assert utility.vendor_id == VendorIDFixture.TEST_VENDOR_ID + assert utility.library_uri == library_url # If the Library object is disconnected from its database # session, as may happen in production... @@ -125,16 +104,10 @@ def test_from_config( # If an integration is set up but incomplete, from_config # raises CannotLoadConfiguration. - setting = ConfigurationSetting.for_library_and_externalintegration( - vendor_id_fixture.db.session, - ExternalIntegration.USERNAME, - library, - registry, - ) - old_short_name = setting.value - setting.value = None + old_short_name = registration.short_name + registration.short_name = None pytest.raises(CannotLoadConfiguration, AuthdataUtility.from_config, library) - setting.value = old_short_name + registration.short_name = old_short_name library_settings = library_fixture.settings(library) old_website = library_settings.website @@ -142,20 +115,14 @@ def test_from_config( pytest.raises(CannotLoadConfiguration, AuthdataUtility.from_config, library) library_settings.website = old_website - setting = ConfigurationSetting.for_library_and_externalintegration( - vendor_id_fixture.db.session, - ExternalIntegration.PASSWORD, - library, - registry, - ) - old_secret = setting.value - setting.value = None + old_secret = registration.shared_secret + registration.shared_secret = None pytest.raises(CannotLoadConfiguration, AuthdataUtility.from_config, library) - setting.value = old_secret + registration.shared_secret = old_secret # If there is no Adobe Vendor ID integration set up, # from_config() returns None. 
- vendor_id_fixture.db.session.delete(registry) + vendor_id_fixture.db.session.delete(registration) assert AuthdataUtility.from_config(library) is None def test_short_client_token_for_patron( diff --git a/tests/api/test_controller_cm.py b/tests/api/test_controller_cm.py index 4166e0f96c..f97796b238 100644 --- a/tests/api/test_controller_cm.py +++ b/tests/api/test_controller_cm.py @@ -6,21 +6,26 @@ from api.custom_index import CustomIndexView from api.opds import CirculationManagerAnnotator, LibraryAnnotator from api.problem_details import * -from api.registration.registry import Registration from core.external_search import MockExternalSearchIndex from core.lane import Facets, WorkList -from core.model import Admin, CachedFeed, ConfigurationSetting, ExternalIntegration +from core.model import Admin, CachedFeed, ConfigurationSetting, create +from core.model.discovery_service_registration import DiscoveryServiceRegistration from core.problem_details import * from core.util.problem_detail import ProblemDetail # TODO: we can drop this when we drop support for Python 3.6 and 3.7 from tests.fixtures.api_controller import CirculationControllerFixture +from tests.fixtures.database import IntegrationConfigurationFixture class TestCirculationManager: """Test the CirculationManager object itself.""" - def test_load_settings(self, circulation_fixture: CirculationControllerFixture): + def test_load_settings( + self, + circulation_fixture: CirculationControllerFixture, + create_integration_configuration: IntegrationConfigurationFixture, + ): # Here's a CirculationManager which we've been using for a while. 
manager = circulation_fixture.manager @@ -62,15 +67,16 @@ def mock_for_library(incoming_library): ConfigurationSetting.sitewide( circulation_fixture.db.session, Configuration.PATRON_WEB_HOSTNAMES ).value = "http://sitewide/1234" - registry = circulation_fixture.db.external_integration( - protocol="some protocol", goal=ExternalIntegration.DISCOVERY_GOAL - ) - ConfigurationSetting.for_library_and_externalintegration( + + # And a discovery service registration, that sets a web client url. + registry = create_integration_configuration.discovery_service() + create( circulation_fixture.db.session, - Registration.LIBRARY_REGISTRATION_WEB_CLIENT, - library, - registry, - ).value = "http://registration" + DiscoveryServiceRegistration, + library=library, + integration=registry, + web_client="http://registration", + ) ConfigurationSetting.sitewide( circulation_fixture.db.session, diff --git a/tests/api/test_controller_playtime_entries.py b/tests/api/test_controller_playtime_entries.py index cd5d080b4c..86a7f7b87d 100644 --- a/tests/api/test_controller_playtime_entries.py +++ b/tests/api/test_controller_playtime_entries.py @@ -263,5 +263,6 @@ def test_api_validation(self, circulation_fixture: CirculationControllerFixture) ) assert isinstance(response, ProblemDetail) assert response.status_code == 400 + assert response.detail is not None assert "timeEntries" in response.detail assert "field required" in response.detail diff --git a/tests/api/test_marc.py b/tests/api/test_marc.py index 4ace5a537f..b11dbacf06 100644 --- a/tests/api/test_marc.py +++ b/tests/api/test_marc.py @@ -5,11 +5,14 @@ from pymarc import Record from api.marc import LibraryAnnotator -from api.registration.registry import Registration from core.config import Configuration from core.marc import MARCExporter -from core.model import ConfigurationSetting, ExternalIntegration -from tests.fixtures.database import DatabaseTransactionFixture +from core.model import ConfigurationSetting, ExternalIntegration, create 
+from core.model.discovery_service_registration import DiscoveryServiceRegistration +from tests.fixtures.database import ( + DatabaseTransactionFixture, + IntegrationConfigurationFixture, +) class TestLibraryAnnotator: @@ -142,7 +145,11 @@ def add_formats(self, record, pool): assert [record, pool] == annotator.called_with.get("add_distributor") assert [record, pool] == annotator.called_with.get("add_formats") - def test_add_web_client_urls(self, db: DatabaseTransactionFixture): + def test_add_web_client_urls( + self, + db: DatabaseTransactionFixture, + create_integration_configuration: IntegrationConfigurationFixture, + ): # Web client URLs can come from either the MARC export integration or # a library registry integration. @@ -195,17 +202,14 @@ def test_add_web_client_urls(self, db: DatabaseTransactionFixture): assert [] == record.get_fields("856") # Add a URL from a library registry. - registry = db.external_integration( - ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, - libraries=[db.default_library()], - ) - ConfigurationSetting.for_library_and_externalintegration( + registry = create_integration_configuration.discovery_service() + create( db.session, - Registration.LIBRARY_REGISTRATION_WEB_CLIENT, - db.default_library(), - registry, - ).value = client_base_1 + DiscoveryServiceRegistration, + library=db.default_library(), + integration=registry, + web_client=client_base_1, + ) record = Record() annotator.add_web_client_urls(record, db.default_library(), identifier) diff --git a/tests/api/test_opds.py b/tests/api/test_opds.py index 9eb25d3d78..2da5851593 100644 --- a/tests/api/test_opds.py +++ b/tests/api/test_opds.py @@ -33,7 +33,6 @@ from core.lcp.credential import LCPCredentialFactory, LCPHashedPassphrase from core.model import ( CirculationEvent, - ConfigurationSetting, Contributor, DataSource, DeliveryMechanism, @@ -649,15 +648,8 @@ def test_adobe_id_tags_when_vendor_id_configured( # If the Adobe Vendor ID configuration is 
present but # incomplete, adobe_id_tags does nothing. - # Delete one setting from the existing integration to check - # this. - setting = ConfigurationSetting.for_library_and_externalintegration( - annotator_fixture.db.session, - ExternalIntegration.USERNAME, - library, - vendor_id_fixture.registry, - ) - annotator_fixture.db.session.delete(setting) + # Delete one setting from the registration to check this. + vendor_id_fixture.registration.short_name = None assert [] == annotator_fixture.annotator.adobe_id_tags("new identifier") def test_lcp_acquisition_link_contains_hashed_passphrase( @@ -1241,12 +1233,8 @@ def test_active_loan_feed( == licensor.attrib["{http://librarysimplified.org/terms/drm}vendor"] ) [client_token] = licensor - expected = ConfigurationSetting.for_library_and_externalintegration( - annotator_fixture.db.session, - ExternalIntegration.USERNAME, - annotator_fixture.db.default_library(), - vendor_id_fixture.registry, - ).value.upper() + assert vendor_id_fixture.registration.short_name is not None + expected = vendor_id_fixture.registration.short_name.upper() assert client_token.text.startswith(expected) assert adobe_patron_identifier in client_token.text diff --git a/tests/api/test_registry.py b/tests/api/test_registry.py deleted file mode 100644 index 4024adf645..0000000000 --- a/tests/api/test_registry.py +++ /dev/null @@ -1,944 +0,0 @@ -import base64 -import json -import os - -import pytest -from Crypto.Cipher import PKCS1_OAEP -from Crypto.PublicKey import RSA - -from api.adobe_vendor_id import AuthdataUtility -from api.config import Configuration -from api.problem_details import * -from api.registration.registry import ( - LibraryRegistrationScript, - Registration, - RemoteRegistry, -) -from core.model import ConfigurationSetting, ExternalIntegration -from core.util.http import HTTP -from core.util.problem_detail import JSON_MEDIA_TYPE as PROBLEM_DETAIL_JSON_MEDIA_TYPE -from core.util.problem_detail import ProblemDetail -from 
tests.api.mockapi.circulation import MockCirculationManager -from tests.core.mock import DummyHTTPClient, MockRequestsResponse -from tests.fixtures.database import DatabaseTransactionFixture -from tests.fixtures.library import LibraryFixture - - -class RemoteRegistryFixture: - def __init__(self, db: DatabaseTransactionFixture): - self.db = db - # Create an ExternalIntegration that can be used as the basis for - # a RemoteRegistry. - self.integration = db.external_integration( - protocol="some protocol", goal=ExternalIntegration.DISCOVERY_GOAL - ) - - -@pytest.fixture(scope="function") -def remote_registry_fixture(db: DatabaseTransactionFixture) -> RemoteRegistryFixture: - return RemoteRegistryFixture(db) - - -class TestRemoteRegistry: - def test_constructor(self, remote_registry_fixture: RemoteRegistryFixture): - registry = RemoteRegistry(remote_registry_fixture.integration) - assert remote_registry_fixture.integration == registry.integration - - def test_for_integration_id(self, remote_registry_fixture: RemoteRegistryFixture): - """Test the ability to build a Registry for an ExternalIntegration - given its ID. - """ - db = remote_registry_fixture.db - m = RemoteRegistry.for_integration_id - - registry = m( - db.session, - remote_registry_fixture.integration.id, - ExternalIntegration.DISCOVERY_GOAL, - ) - assert isinstance(registry, RemoteRegistry) - assert remote_registry_fixture.integration == registry.integration - - # If the ID doesn't exist you get None. - assert None == m(db.session, -1, ExternalIntegration.DISCOVERY_GOAL) - - # If the integration's goal doesn't match what you provided, - # you get None. - assert None == m( - db.session, remote_registry_fixture.integration.id, "some other goal" - ) - - def test_for_protocol_and_goal( - self, remote_registry_fixture: RemoteRegistryFixture - ): - db = remote_registry_fixture.db - - # Create two ExternalIntegrations that have different protocols - # or goals from our original. 
- same_goal_different_protocol = db.external_integration( - protocol="some other protocol", - goal=remote_registry_fixture.integration.goal, - ) - - same_protocol_different_goal = db.external_integration( - protocol=remote_registry_fixture.integration.protocol, - goal="some other goal", - ) - - # Only the original ExternalIntegration has both the requested - # protocol and goal, so only it becomes a RemoteRegistry. - [registry] = list( - RemoteRegistry.for_protocol_and_goal( - db.session, - remote_registry_fixture.integration.protocol, - remote_registry_fixture.integration.goal, - ) - ) - assert isinstance(registry, RemoteRegistry) - assert remote_registry_fixture.integration == registry.integration - - def test_for_protocol_goal_and_url( - self, remote_registry_fixture: RemoteRegistryFixture - ): - db = remote_registry_fixture.db - protocol = db.fresh_str() - goal = db.fresh_str() - url = db.fresh_url() - m = RemoteRegistry.for_protocol_goal_and_url - - registry = m(db.session, protocol, goal, url) - assert isinstance(registry, RemoteRegistry) - - # A new ExternalIntegration was created. - integration = registry.integration - assert protocol == integration.protocol - assert goal == integration.goal - assert url == integration.url - - # Calling the method again doesn't create a second - # ExternalIntegration. - registry2 = m(db.session, protocol, goal, url) - assert registry2.integration == integration - - def test_registrations(self, remote_registry_fixture: RemoteRegistryFixture): - db = remote_registry_fixture.db - registry = RemoteRegistry(remote_registry_fixture.integration) - - # Associate the default library with the registry. - Registration(registry, db.default_library()) - - # Create another library not associated with the registry. - library2 = db.library() - - # registrations() finds a single Registration. 
- [registration] = list(registry.registrations) - assert isinstance(registration, Registration) - assert registry == registration.registry - assert db.default_library() == registration.library - - def test_fetch_catalog(self, remote_registry_fixture: RemoteRegistryFixture): - db = remote_registry_fixture.db - - # Test our ability to retrieve essential information from a - # remote registry's root catalog. - class Mock(RemoteRegistry): - def _extract_catalog_information(self, response): - self.extracted_from = response - return "Essential information" - - # The behavior of fetch_catalog() depends on what comes back - # when we ask the remote registry for its root catalog. - client = DummyHTTPClient() - - # If the result is a problem detail document, that document is - # the return value of fetch_catalog(). - problem = REMOTE_INTEGRATION_FAILED.detailed("oops") - client.responses.append(problem) - registry = Mock(remote_registry_fixture.integration) - result = registry.fetch_catalog(do_get=client.do_get) - assert remote_registry_fixture.integration.url == client.requests.pop() - assert problem == result - - # If the response looks good, it's passed into - # _extract_catalog_information(), and the result of _that_ - # method is the return value of fetch_catalog. - client.queue_requests_response(200, content="A root catalog") - [queued] = client.responses - assert "Essential information" == registry.fetch_catalog( - "custom catalog URL", do_get=client.do_get - ) - assert "custom catalog URL" == client.requests.pop() - - def test__extract_catalog_information( - self, remote_registry_fixture: RemoteRegistryFixture - ): - # Test our ability to extract a registration link and an - # Adobe Vendor ID from an OPDS 1 or OPDS 2 catalog. 
- def extract(document, type=RemoteRegistry.OPDS_2_TYPE): - response = MockRequestsResponse(200, {"Content-Type": type}, document) - return RemoteRegistry._extract_catalog_information(response) - - def assert_no_link(*args, **kwargs): - """Verify that calling _extract_catalog_information on the - given feed fails because there is no link with rel="register" - """ - result = extract(*args, **kwargs) - assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert ( - "The service at http://url/ did not provide a register link." - == result.detail - ) - - # OPDS 2 feed with link and Adobe Vendor ID. - link = {"rel": "register", "href": "register url"} - metadata = {"adobe_vendor_id": "vendorid"} - feed = json.dumps(dict(links=[link], metadata=metadata)) - assert ("register url", "vendorid") == extract(feed) - - # OPDS 2 feed with link and no Adobe Vendor ID - feed = json.dumps(dict(links=[link])) - assert ("register url", None) == extract(feed) - - # OPDS 2 feed with no link. - feed = json.dumps(dict(metadata=metadata)) - assert_no_link(feed) - - # OPDS 1 feed with link. - feed = '' - assert ("register url", None) == extract( - feed, RemoteRegistry.OPDS_1_PREFIX + ";foo" - ) - - # OPDS 1 feed with no link. - feed = "" - assert_no_link(feed, RemoteRegistry.OPDS_1_PREFIX + ";foo") - - # Non-OPDS document. - result = extract("plain text here", "text/plain") - assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert "The service at http://url/ did not return OPDS." == result.detail - - def test_fetch_registration_document( - self, remote_registry_fixture: RemoteRegistryFixture - ): - # Test our ability to retrieve terms-of-service information - # from a remote registry, assuming the registry makes that - # information available. - - # First, test the case where we can't even get the catalog - # document. 
- class Mock0(RemoteRegistry): - def fetch_catalog(self, do_get): - self.fetch_catalog_called_with = do_get - return REMOTE_INTEGRATION_FAILED - - registry = Mock0(object()) - result = registry.fetch_registration_document() - - # Our mock fetch_catalog was called with a method that would - # have made a real HTTP request. - assert HTTP.debuggable_get == registry.fetch_catalog_called_with - - # But the fetch_catalog method returned a problem detail, - # which became the return value of - # fetch_registration_document. - assert REMOTE_INTEGRATION_FAILED == result - - # Test the case where we get the catalog document but we can't - # get the registration document. - client = DummyHTTPClient() - client.responses.append(REMOTE_INTEGRATION_FAILED) - - class Mock1(RemoteRegistry): - def fetch_catalog(self, do_get): - return "http://register-here/", "vendor id" - - def _extract_registration_information(self, response): - self._extract_registration_information_called_with = response - return "TOS link", "TOS HTML data" - - registry1 = Mock1(object()) - result = registry1.fetch_registration_document(client.do_get) - # A request was made to the registration URL mentioned in the catalog. - assert "http://register-here/" == client.requests.pop() - assert [] == client.requests - - # But the request returned a problem detail, which became the - # return value of the method. - assert REMOTE_INTEGRATION_FAILED == result - - # Finally, test the case where we can get both documents. - - client.queue_requests_response(200, content="a registration document") - result = registry1.fetch_registration_document(client.do_get) - - # Another request was made to the registration URL. - assert "http://register-here/" == client.requests.pop() - assert [] == client.requests - - # Our mock of _extract_registration_information was called - # with the mock response to that request. 
- response = registry1._extract_registration_information_called_with - assert b"a registration document" == response.content - - # The return value of _extract_registration_information was - # propagated as the return value of - # fetch_registration_document. - assert ("TOS link", "TOS HTML data") == result - - def test__extract_registration_information( - self, remote_registry_fixture: RemoteRegistryFixture - ): - # Test our ability to extract terms-of-service information -- - # a link and/or some HTML or textual instructions -- from a - # registration document. - - def data_link(data, type="text/html"): - encoded = base64.b64encode(data.encode("utf-8")).decode("utf-8") - return dict(rel="terms-of-service", href=f"data:{type};base64,{encoded}") - - class Mock(RemoteRegistry): - decoded: str - - @classmethod - def _decode_data_url(cls, url): - cls.decoded = url - return "Decoded: " + RemoteRegistry._decode_data_url(url) - - def extract(document, type=RemoteRegistry.OPDS_2_TYPE): - if type == RemoteRegistry.OPDS_2_TYPE: - document = json.dumps(dict(links=document)) - response = MockRequestsResponse(200, {"Content-Type": type}, document) - return Mock._extract_registration_information(response) - - # OPDS 2 feed with TOS in http: and data: links. - tos_link = dict(rel="terms-of-service", href="http://tos/") - tos_data = data_link("

Some HTML

") - assert ("http://tos/", "Decoded:

Some HTML

") == extract( - [tos_link, tos_data] - ) - - # At this point it's clear that the data: URL found in - # `tos_data` was run through `_decode_data()`. This gives us - # permission to test all the fiddly bits of `_decode_data` in - # isolation, below. - assert tos_data["href"] == Mock.decoded - - # OPDS 2 feed with http: link only. - assert ("http://tos/", None) == extract([tos_link]) - - # OPDS 2 feed with data: link only. - assert (None, "Decoded:

Some HTML

") == extract([tos_data]) - - # OPDS 2 feed with no links. - assert (None, None) == extract([]) - - # OPDS 1 feed with link. - feed = '' - assert ("http://tos/", None) == extract( - feed, RemoteRegistry.OPDS_1_PREFIX + ";foo" - ) - - # OPDS 1 feed with no link. - feed = "" - assert (None, None) == extract(feed, RemoteRegistry.OPDS_1_PREFIX + ";foo") - - # Non-OPDS document. - assert (None, None) == extract("plain text here", "text/plain") - - # Unrecognized URI schemes are ignored. - ftp_link = dict(rel="terms-of-service", href="ftp://tos/") - assert (None, None) == extract([ftp_link]) - - def test__decode_data_url(self, remote_registry_fixture: RemoteRegistryFixture): - # Test edge cases of decoding data: URLs. - m = RemoteRegistry._decode_data_url - - def data_url(data, type="text/html"): - encoded = base64.b64encode(data.encode("utf-8")).decode("utf-8") - return f"data:{type};base64,{encoded}" - - # HTML is okay. - html = data_url("some HTML", "text/html;charset=utf-8") - assert "some HTML" == m(html) - - # Plain text is okay. - text = data_url("some plain text", "text/plain") - assert "some plain text" == m(text) - - # No other media type is allowed. - image = data_url("an image!", "image/png") - with pytest.raises(ValueError) as excinfo: - m(image) - assert "Unsupported media type in data: URL: image/png" in str(excinfo.value) - - # Incoming HTML is sanitized. - dirty_html = data_url("

Some HTML

") - assert "

Some HTML

" == m(dirty_html) - - # Now test various malformed data: URLs. - no_header = "foobar" - with pytest.raises(ValueError) as excinfo: - m(no_header) - assert "Not a data: URL: foobar" in str(excinfo.value) - - no_comma = "data:blah" - with pytest.raises(ValueError) as excinfo: - m(no_comma) - assert "Invalid data: URL: data:blah" in str(excinfo.value) - - too_many_commas = "data:blah,blah,blah" - with pytest.raises(ValueError) as excinfo: - m(too_many_commas) - assert "Invalid data: URL: data:blah,blah,blah" in str(excinfo.value) - - # data: URLs don't have to be base64-encoded, but those are the - # only kind we support. - not_encoded = "data:blah,content" - with pytest.raises(ValueError) as excinfo: - m(not_encoded) - assert "data: URL not base64-encoded: data:blah,content" in str(excinfo.value) - - -class RegistrationFixture: - def __init__(self, db: DatabaseTransactionFixture): - self.db = db - # Create a RemoteRegistry. - self.integration = db.external_integration( - protocol="some protocol", goal="some goal" - ) - self.registry = RemoteRegistry(self.integration) - self.registration = Registration(self.registry, db.default_library()) - - -@pytest.fixture(scope="function") -def registration_fixture(db: DatabaseTransactionFixture) -> RegistrationFixture: - return RegistrationFixture(db) - - -class TestRegistration: - def test_constructor(self, registration_fixture: RegistrationFixture): - db = registration_fixture.db - - # The Registration constructor was called during setup to create - # self.registration. 
- reg = registration_fixture.registration - assert registration_fixture.registry == reg.registry - assert db.default_library() == reg.library - - settings = [x for x in reg.integration.settings if x.library is not None] - assert {reg.status_field, reg.stage_field, reg.web_client_field} == set( - settings - ) - assert Registration.FAILURE_STATUS == reg.status_field.value - assert Registration.TESTING_STAGE == reg.stage_field.value - assert None == reg.web_client_field.value - - # The Library has been associated with the ExternalIntegration. - assert [db.default_library()] == registration_fixture.integration.libraries - - # Creating another Registration doesn't add the library to the - # ExternalIntegration again or override existing values for the - # settings. - reg.status_field.value = "new status" - reg.stage_field.value = "new stage" - reg2 = Registration(registration_fixture.registry, db.default_library()) - assert [db.default_library()] == registration_fixture.integration.libraries - assert "new status" == reg2.status_field.value - assert "new stage" == reg2.stage_field.value - - def test_setting(self, registration_fixture: RegistrationFixture): - db = registration_fixture.db - m = registration_fixture.registration.setting - - def _find(key): - """Find a ConfigurationSetting associated with the library. - - This is necessary because ConfigurationSetting.value - creates _two_ ConfigurationSettings, one associated with - the library and one not associated with any library, to - store the default value. - """ - values = [ - x - for x in registration_fixture.registration.integration.settings - if x.library and x.key == key - ] - if len(values) == 1: - return values[0] - return None - - # Calling setting() creates a ConfigurationSetting object - # associated with the library. 
- setting = m("key") - assert "key" == setting.key - assert None == setting.value - assert db.default_library() == setting.library - assert setting == _find("key") - - # You can specify a default value, which is used only if the - # current value is None. - setting2 = m("key", "default") - assert setting == setting2 - assert "default" == setting.value - - setting3 = m("key", "default2") - assert setting == setting3 - assert "default" == setting.value - - def test_push(self, registration_fixture: RegistrationFixture): - db = registration_fixture.db - # Test the other methods orchestrated by the push() method. - - class MockRegistry(RemoteRegistry): - def fetch_catalog(self, catalog_url, do_get): - # Pretend to fetch a root catalog and extract a - # registration URL from it. - self.fetch_catalog_called_with = (catalog_url, do_get) - return "register_url", "vendor_id" - - class MockRegistration(Registration): - def _create_registration_payload(self, url_for, stage): - self.payload_ingredients = (url_for, stage) - return dict(payload="this is it") - - def _create_registration_headers(self): - self._create_registration_headers_called = True - return dict(Header="Value") - - def _send_registration_request( - self, register_url, headers, payload, do_post - ): - self._send_registration_request_called_with = ( - register_url, - headers, - payload, - do_post, - ) - return MockRequestsResponse(200, content=json.dumps("you did it!")) - - def _process_registration_result(self, catalog, encryptor, stage): - self._process_registration_result_called_with = ( - catalog, - encryptor, - stage, - ) - return "all done!" 
- - library = db.default_library() - registry = MockRegistry(registration_fixture.integration) - registration = MockRegistration(registry, library) - stage = Registration.TESTING_STAGE - url_for = object() - catalog_url = "http://catalog/" - do_get = object() - do_post = object() - - def push(): - return registration.push(stage, url_for, catalog_url, do_get, do_post) - - # Kick off the registration process, and make sure we get expected return. - result = push() - assert "all done!" == result - - # But there were many steps towards this result. - - # First, MockRegistry.fetch_catalog() was called, in an attempt - # to find the registration URL inside the root catalog. - assert (catalog_url, do_get) == registry.fetch_catalog_called_with - - # fetch_catalog() returned a registration URL and - # a vendor ID. The registration URL was used later on... - # - # The vendor ID was set as a ConfigurationSetting on - # the ExternalIntegration associated with this registry. - assert ( - "vendor_id" - == ConfigurationSetting.for_externalintegration( - AuthdataUtility.VENDOR_ID_KEY, registration_fixture.integration - ).value - ) - - # _create_registration_payload was called to create the body - # of the registration request. - assert (url_for, stage) == registration.payload_ingredients - - # _create_registration_headers was called to create the headers - # sent along with the request. - assert True == registration._create_registration_headers_called - - # Then _send_registration_request was called, POSTing the - # payload to "register_url", the registration URL we got earlier. - results = registration._send_registration_request_called_with - assert ( - "register_url", - {"Header": "Value"}, - dict(payload="this is it"), - do_post, - ) == results - - # Finally, the return value of that method was loaded as JSON - # and passed into _process_registration_result, along with - # a cipher created from the private key. 
(That cipher would be used - # to decrypt anything the foreign site signed using this site's - # public key.) - results = registration._process_registration_result_called_with - message, cipher, actual_stage = results - assert "you did it!" == message - assert cipher._key.export_key("DER") == library.private_key - assert actual_stage == stage - - # If a nonexistent stage is provided a ProblemDetail is the result. - result = registration.push( - "no such stage", url_for, catalog_url, do_get, do_post - ) - assert INVALID_INPUT.uri == result.uri - assert "'no such stage' is not a valid registration stage" == result.detail - - # Now in reverse order, let's replace the mocked methods so - # that they return ProblemDetail documents. This tests that if - # there is a failure at any stage, the ProblemDetail is - # propagated. - - # The push() function will no longer push anything, so rename it. - cause_problem = push - - def fail0(*args, **kwargs): - return INVALID_REGISTRATION.detailed( - "could not process registration result" - ) - - registration._process_registration_result = fail0 # type: ignore - problem = cause_problem() - assert "could not process registration result" == problem.detail - - def fail1(*args, **kwargs): - return INVALID_REGISTRATION.detailed("could not send registration request") - - registration._send_registration_request = fail1 # type: ignore - problem = cause_problem() - assert "could not send registration request" == problem.detail - - def fail2(*args, **kwargs): - return INVALID_REGISTRATION.detailed( - "could not create registration payload" - ) - - registration._create_registration_payload = fail2 # type: ignore - problem = cause_problem() - assert "could not create registration payload" == problem.detail - - def fail3(*args, **kwargs): - return INVALID_REGISTRATION.detailed("could not fetch catalog") - - registry.fetch_catalog = fail3 # type: ignore - problem = cause_problem() - assert "could not fetch catalog" == problem.detail - - def 
test__create_registration_payload( - self, registration_fixture: RegistrationFixture, library_fixture: LibraryFixture - ): - m = registration_fixture.registration._create_registration_payload - - # Mock url_for to create good-looking callback URLs. - def url_for(controller, library_short_name, **kwargs): - return f"http://server/{library_short_name}/{controller}" - - # First, test with no configuration contact configured for the - # library. - stage = object() - expect_url = url_for( - "authentication_document", - registration_fixture.registration.library.short_name, - ) - expect_payload = dict(url=expect_url, stage=stage) - assert expect_payload == m(url_for, stage) - - # If a contact is configured, it shows up in the payload. - contact = "mailto:ohno@library.org" - settings = library_fixture.settings(registration_fixture.registration.library) - settings.configuration_contact_email_address = contact # type: ignore[assignment] - expect_payload["contact"] = contact - assert expect_payload == m(url_for, stage) - - def test_create_registration_headers( - self, registration_fixture: RegistrationFixture - ): - db = registration_fixture.db - m = registration_fixture.registration._create_registration_headers - # If no shared secret is configured, no custom headers are provided. - expect_headers = {} # type: ignore - assert expect_headers == m() - - # If a shared secret is configured, it shows up as part of - # the Authorization header. 
- setting = ConfigurationSetting.for_library_and_externalintegration( - db.session, - ExternalIntegration.PASSWORD, - registration_fixture.registration.library, - registration_fixture.registration.registry.integration, - ).value = "a secret" - expect_headers["Authorization"] = "Bearer a secret" - assert expect_headers == m() - - def test__send_registration_request( - self, registration_fixture: RegistrationFixture - ): - class Mock: - def __init__(self, response): - self.response = response - - def do_post(self, url, payload, **kwargs): - self.called_with = (url, payload, kwargs) - return self.response - - # If everything goes well, the return value of do_post is - # passed through. - mock = Mock(MockRequestsResponse(200, content="all good")) - url = "url" - payload = "payload" - headers = "headers" - m = Registration._send_registration_request - result = m(url, headers, payload, mock.do_post) - assert mock.response == result - called_with = mock.called_with - assert called_with == ( - url, - payload, - dict( - headers=headers, - timeout=60, - allowed_response_codes=["2xx", "3xx", "400", "401"], - ), - ) - - # Most error handling is expected to be handled by do_post - # raising an exception, but certain responses get special - # treatment: - - # The remote sends a 401 response with a problem detail. - mock = Mock( - MockRequestsResponse( - 401, - {"Content-Type": PROBLEM_DETAIL_JSON_MEDIA_TYPE}, - content=json.dumps(dict(detail="this is a problem detail")), - ) - ) - result = m(url, headers, payload, mock.do_post) - assert isinstance(result, ProblemDetail) - assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert 'Remote service returned: "this is a problem detail"' == result.detail - - # The remote sends some other kind of 401 response. 
- mock = Mock( - MockRequestsResponse( - 401, {"Content-Type": "text/html"}, content="log in why don't you" - ) - ) - result = m(url, headers, payload, mock.do_post) - assert isinstance(result, ProblemDetail) - assert REMOTE_INTEGRATION_FAILED.uri == result.uri - assert 'Remote service returned: "log in why don\'t you"' == result.detail - - def test__decrypt_shared_secret(self, registration_fixture: RegistrationFixture): - key = RSA.generate(2048) - encryptor = PKCS1_OAEP.new(key) - - key2 = RSA.generate(2048) - encryptor2 = PKCS1_OAEP.new(key2) - - # NOTE: In a real case, shared_secret must be UTF-8 string, - # but this particular method will work even if it's not. - shared_secret = os.urandom(24) - encrypted_secret = base64.b64encode(encryptor.encrypt(shared_secret)) - - # Success. - m = Registration._decrypt_shared_secret - assert shared_secret == m(encryptor, encrypted_secret) - - # If we try to decrypt using the wrong key, a ProblemDetail is - # returned explaining the problem. - problem = m(encryptor2, encrypted_secret) - assert isinstance(problem, ProblemDetail) - assert SHARED_SECRET_DECRYPTION_ERROR.uri == problem.uri - assert encrypted_secret.decode("utf-8") in problem.detail - - def test__process_registration_result( - self, registration_fixture: RegistrationFixture - ): - db = registration_fixture.db - reg = registration_fixture.registration - m = reg._process_registration_result - - # Result must be a dictionary. - result = m("not a dictionary", None, None) - assert INTEGRATION_ERROR.uri == result.uri - assert ( - "Remote service served 'not a dictionary', which I can't make sense of as an OPDS document." - == result.detail - ) - - # When the result is empty, the registration is marked as successful. - new_stage = "new stage" - encryptor = object() - result = m(dict(), encryptor, new_stage) - assert True == result - assert reg.SUCCESS_STATUS == reg.status_field.value - - # The stage field has been set to the requested value. 
- assert new_stage == reg.stage_field.value - - # Now try with a result that includes a short name, - # a shared secret, and a web client URL. - - class Mock0(Registration): - def _decrypt_shared_secret(self, encryptor, shared_secret): - self._decrypt_shared_secret_called_with = (encryptor, shared_secret) - return "👉 cleartext 👈".encode() - - reg = Mock0(registration_fixture.registry, db.default_library()) - catalog = dict( - metadata=dict(short_name="SHORT", shared_secret="ciphertext", id="uuid"), - links=[dict(href="http://web/library", rel="self", type="text/html")], - ) - result = reg._process_registration_result( - catalog, encryptor, "another new stage" - ) - assert True == result - - # Short name is set. - assert "SHORT" == reg.setting(ExternalIntegration.USERNAME).value - - # Shared secret was decrypted, decoded from UTF-8 and is set. - assert (encryptor, "ciphertext") == reg._decrypt_shared_secret_called_with - assert "👉 cleartext 👈" == reg.setting(ExternalIntegration.PASSWORD).value - - # Web client URL is set. - assert ( - "http://web/library" - == reg.setting(reg.LIBRARY_REGISTRATION_WEB_CLIENT).value - ) - - assert "another new stage" == reg.stage_field.value - - # Now simulate a problem decrypting the shared secret. 
- class Mock1(Registration): - def _decrypt_shared_secret(self, encryptor, shared_secret): - return SHARED_SECRET_DECRYPTION_ERROR - - reg = Mock1(registration_fixture.registry, db.default_library()) - result = reg._process_registration_result( - catalog, encryptor, "another new stage" - ) - assert SHARED_SECRET_DECRYPTION_ERROR == result - - -class TestLibraryRegistrationScript: - def test_do_run(self, db: DatabaseTransactionFixture): - class Mock(LibraryRegistrationScript): - processed = [] - - def process_library(self, *args): - self.processed.append(args) - - script = Mock(db.session) - - base_url_setting = ConfigurationSetting.sitewide( - db.session, Configuration.BASE_URL_KEY - ) - base_url_setting.value = "http://test-circulation-manager/" - - library = db.default_library() - library2 = db.library() - - cmd_args = [ - library.short_name, - "--stage=testing", - "--registry-url=http://registry/", - ] - manager = MockCirculationManager(db.session) - script.do_run(cmd_args=cmd_args, manager=manager) - - # One library was processed. - (registration, stage, url_for) = script.processed.pop() - assert [] == script.processed - assert library == registration.library - assert Registration.TESTING_STAGE == stage - - # A new ExternalIntegration was created for the newly defined - # registry at http://registry/. - assert "http://registry/" == registration.integration.url - - # Let's say the other library was earlier registered in production. - registration_2 = Registration(registration.registry, library2) - registration_2.stage_field.value = Registration.PRODUCTION_STAGE - - # Now run the script again without specifying a particular - # library or the --stage argument. - script.do_run(cmd_args=[], manager=manager) - - # Every library was processed. - assert {library, library2} == {x[0].library for x in script.processed} - - for i in script.processed: - # Since no stage was provided, each library was registered - # using the stage already associated with it. 
- assert i[0].stage_field.value == i[1] - - # Every library was registered with the default - # library registry. - assert RemoteRegistry.DEFAULT_LIBRARY_REGISTRY_URL == i[0].integration.url - - def test_process_library(self, db: DatabaseTransactionFixture): - """Test the things that might happen when process_library is called.""" - script = LibraryRegistrationScript(db.session) - library = db.default_library() - integration = db.external_integration( - protocol="some protocol", goal=ExternalIntegration.DISCOVERY_GOAL - ) - registry = RemoteRegistry(integration) - - # First, simulate success. - class Success(Registration): - def push(self, stage, url_for): - self.pushed = (stage, url_for) - return True - - registration = Success(registry, library) - - stage = object() - url_for = object() - assert True == script.process_library(registration, stage, url_for) - - # The stage and url_for values were passed into - # Registration.push() - assert (stage, url_for) == registration.pushed - - # Next, simulate an exception raised during push() - # This can happen in real situations, though the next case - # we'll test is more common. - class FailsWithException(Registration): - def push(self, stage, url_for): - raise Exception("boo") - - registration_fail_exception = FailsWithException(registry, library) - # We get False rather than the exception being propagated. - # Useful information about the exception is added to the logs, - # where someone actually running the script will see it. - assert False == script.process_library( - registration_fail_exception, stage, url_for - ) - - # Next, simulate push() returning a problem detail document. - class FailsWithProblemDetail(Registration): - def push(self, stage, url_for): - return INVALID_INPUT.detailed("oops") - - registration_fail_problem = FailsWithProblemDetail(registry, library) - result = script.process_library(registration_fail_problem, stage, url_for) - - # The problem document is returned. 
Useful information about - # the exception is also added to the logs, where someone - # actually running the script will see it. - assert INVALID_INPUT.uri == result.uri - assert "oops" == result.detail diff --git a/tests/core/models/test_admin.py b/tests/core/models/test_admin.py index bb6df885b9..6598a5cff6 100644 --- a/tests/core/models/test_admin.py +++ b/tests/core/models/test_admin.py @@ -216,6 +216,7 @@ def test_validate_reset_password_token_and_fetch_admin( ) assert isinstance(expired_token, ProblemDetail) assert expired_token.uri == INVALID_RESET_PASSWORD_TOKEN.uri + assert expired_token.detail is not None assert "expired" in expired_token.detail # Valid token but invalid admin id - unsuccessful validation diff --git a/tests/core/models/test_discovery_service_registration.py b/tests/core/models/test_discovery_service_registration.py new file mode 100644 index 0000000000..f953d01ac5 --- /dev/null +++ b/tests/core/models/test_discovery_service_registration.py @@ -0,0 +1,91 @@ +from typing import Optional + +import pytest +from sqlalchemy import select + +from core.integration.goals import Goals +from core.model import IntegrationConfiguration, Library, create +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStage, + RegistrationStatus, +) +from tests.fixtures.database import ( + DatabaseTransactionFixture, + IntegrationConfigurationFixture, +) +from tests.fixtures.library import LibraryFixture + + +class RegistrationFixture: + def __call__( + self, + library: Optional[Library] = None, + integration: Optional[IntegrationConfiguration] = None, + ) -> DiscoveryServiceRegistration: + library = library or self.library_fixture.library() + integration = integration or self.integration_fixture( + "test", Goals.DISCOVERY_GOAL + ) + registration, _ = create( + self.db.session, + DiscoveryServiceRegistration, + library=library, + integration=integration, + ) + return registration + + def __init__( + self, + db: 
DatabaseTransactionFixture, + library_fixture: LibraryFixture, + integration_fixture: IntegrationConfigurationFixture, + ) -> None: + self.db = db + self.library_fixture = library_fixture + self.integration_fixture = integration_fixture + + +@pytest.fixture +def registration_fixture( + db: DatabaseTransactionFixture, + library_fixture: LibraryFixture, + create_integration_configuration: IntegrationConfigurationFixture, +) -> RegistrationFixture: + return RegistrationFixture(db, library_fixture, create_integration_configuration) + + +class TestDiscoveryServiceRegistration: + def test_constructor(self, registration_fixture: RegistrationFixture): + registration = registration_fixture() + + # We get default values for status and stage. + assert registration.status == RegistrationStatus.FAILURE + assert registration.stage == RegistrationStage.TESTING + + assert registration.web_client is None + + @pytest.mark.parametrize( + "parent", + [ + "library", + "integration", + ], + ) + def test_registration_deleted_when_parent_deleted( + self, + db: DatabaseTransactionFixture, + registration_fixture: RegistrationFixture, + parent: str, + ): + registration = registration_fixture() + + registrations = db.session.execute(select(DiscoveryServiceRegistration)).all() + assert len(registrations) == 1 + + parent = getattr(registration, parent) + db.session.delete(parent) + db.session.flush() + + registrations = db.session.execute(select(DiscoveryServiceRegistration)).all() + assert len(registrations) == 0 diff --git a/tests/core/util/test_http.py b/tests/core/util/test_http.py index 45b8355333..b615e53364 100644 --- a/tests/core/util/test_http.py +++ b/tests/core/util/test_http.py @@ -13,7 +13,7 @@ RequestNetworkException, RequestTimedOut, ) -from core.util.problem_detail import ProblemDetail +from core.util.problem_detail import ProblemDetail, ProblemError from tests.core.mock import MockRequestsResponse @@ -264,16 +264,20 @@ def test_process_debuggable_response(self): success = 
MockRequestsResponse(302, content="Success!") assert success == m("url", success) - # An error is turned into a detailed ProblemDetail + # An error is turned into a ProblemError error = MockRequestsResponse(500, content="Error!") - problem = m("url", error) + with pytest.raises(ProblemError) as excinfo: + m("url", error) + problem = excinfo.value.problem_detail assert isinstance(problem, ProblemDetail) assert INTEGRATION_ERROR.uri == problem.uri - assert "500 response from integration server: 'Error!'" == problem.detail + assert '500 response from integration server: "Error!"' == problem.detail content, status_code, headers = INVALID_INPUT.response error = MockRequestsResponse(status_code, headers, content) - problem = m("url", error) + with pytest.raises(ProblemError) as excinfo: + m("url", error) + problem = excinfo.value.problem_detail assert isinstance(problem, ProblemDetail) assert INTEGRATION_ERROR.uri == problem.uri assert ( diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index b874ff578a..647caf9412 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -18,6 +18,8 @@ from sqlalchemy.orm import Session import core.lane +from api.discovery.opds_registration import OpdsRegistrationService +from api.integration.registry.discovery import DiscoveryRegistry from core.analytics import Analytics from core.classifier import Classifier from core.config import Configuration @@ -1031,6 +1033,24 @@ def __call__( ) return integration + def discovery_service( + self, protocol: Optional[str] = None, url: Optional[str] = None + ) -> IntegrationConfiguration: + registry = DiscoveryRegistry() + if protocol is None: + protocol = registry.get_protocol(OpdsRegistrationService) + assert protocol is not None + + if url is not None: + settings_obj = registry[protocol].settings_class().construct(url=url) # type: ignore[arg-type] + settings_dict = settings_obj.dict() + else: + settings_dict = {} + + return self( + protocol=protocol, 
goal=Goals.DISCOVERY_GOAL, settings_dict=settings_dict + ) + @pytest.fixture def create_integration_configuration( diff --git a/tests/fixtures/vendor_id.py b/tests/fixtures/vendor_id.py index 8dbb6c84e9..ae2c192d21 100644 --- a/tests/fixtures/vendor_id.py +++ b/tests/fixtures/vendor_id.py @@ -2,10 +2,15 @@ import pytest -from api.adobe_vendor_id import AuthdataUtility -from api.registration.constants import RegistrationConstants -from core.model import ConfigurationSetting, ExternalIntegration, Library -from tests.fixtures.database import DatabaseTransactionFixture +from core.model import IntegrationConfiguration, Library, create +from core.model.discovery_service_registration import ( + DiscoveryServiceRegistration, + RegistrationStatus, +) +from tests.fixtures.database import ( + DatabaseTransactionFixture, + IntegrationConfigurationFixture, +) class VendorIDFixture: @@ -18,51 +23,45 @@ class VendorIDFixture: TEST_VENDOR_ID = "vendor id" db: DatabaseTransactionFixture - registry: ExternalIntegration + registry: IntegrationConfiguration + registration: DiscoveryServiceRegistration def initialize_adobe( self, vendor_id_library: Library, ): - # The libraries will share a registry integration. - self.registry = self.db.external_integration( - ExternalIntegration.OPDS_REGISTRATION, - ExternalIntegration.DISCOVERY_GOAL, - libraries=[vendor_id_library], + self.registry = self.integration_configuration.discovery_service() + self.registration, _ = create( + self.db.session, + DiscoveryServiceRegistration, + library=vendor_id_library, + integration=self.registry, + # The integration knows which Adobe Vendor ID server it gets its Adobe IDs from. + vendor_id=self.TEST_VENDOR_ID, ) - # The integration knows which Adobe Vendor ID server it gets its Adobe IDs from. - self.registry.set_setting(AuthdataUtility.VENDOR_ID_KEY, self.TEST_VENDOR_ID) - # The library given to this fixture will be setup to be able to generate # Short Client Tokens. 
assert vendor_id_library.short_name is not None short_name = vendor_id_library.short_name + "token" secret = vendor_id_library.short_name + " token secret" - ConfigurationSetting.for_library_and_externalintegration( - self.db.session, - ExternalIntegration.USERNAME, - vendor_id_library, - self.registry, - ).value = short_name - ConfigurationSetting.for_library_and_externalintegration( - self.db.session, - ExternalIntegration.PASSWORD, - vendor_id_library, - self.registry, - ).value = secret - ConfigurationSetting.for_library_and_externalintegration( - self.db.session, - RegistrationConstants.LIBRARY_REGISTRATION_STATUS, - vendor_id_library, - self.registry, - ).value = RegistrationConstants.SUCCESS_STATUS + self.registration.short_name = short_name + self.registration.shared_secret = secret + self.registration.status = RegistrationStatus.SUCCESS - def __init__(self, db: DatabaseTransactionFixture): + def __init__( + self, + db: DatabaseTransactionFixture, + integration_configuration: IntegrationConfigurationFixture, + ) -> None: assert isinstance(db, DatabaseTransactionFixture) self.db = db + self.integration_configuration = integration_configuration @pytest.fixture(scope="function") -def vendor_id_fixture(db: DatabaseTransactionFixture) -> VendorIDFixture: - return VendorIDFixture(db) +def vendor_id_fixture( + db: DatabaseTransactionFixture, + create_integration_configuration: IntegrationConfigurationFixture, +) -> VendorIDFixture: + return VendorIDFixture(db, create_integration_configuration)