Skip to content

Commit

Permalink
TLDR-791 add handling of client disconnection (#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
NastyBoget authored Dec 20, 2024
1 parent 0d964e4 commit e4ec06b
Show file tree
Hide file tree
Showing 19 changed files with 233 additions and 115 deletions.
24 changes: 10 additions & 14 deletions dedoc/api/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Iterator, List, Optional, Set

from dedoc.api.schema import LineMetadata, ParsedDocument, Table, TreeNode
from dedoc.data_structures.concrete_annotations.attach_annotation import AttachAnnotation
from dedoc.data_structures.concrete_annotations.bold_annotation import BoldAnnotation
from dedoc.data_structures.concrete_annotations.italic_annotation import ItalicAnnotation
Expand All @@ -10,10 +11,6 @@
from dedoc.data_structures.concrete_annotations.table_annotation import TableAnnotation
from dedoc.data_structures.concrete_annotations.underlined_annotation import UnderlinedAnnotation
from dedoc.data_structures.hierarchy_level import HierarchyLevel
from dedoc.data_structures.line_metadata import LineMetadata
from dedoc.data_structures.parsed_document import ParsedDocument
from dedoc.data_structures.table import Table
from dedoc.data_structures.tree_node import TreeNode
from dedoc.extensions import converted_mimes, recognized_mimes


Expand All @@ -39,7 +36,7 @@ def _node2tree(paragraph: TreeNode, depth: int, depths: Set[int] = None) -> str:
space = "".join(space)
node_result = []

node_result.append(f" {space} {paragraph.metadata.hierarchy_level.line_type}&nbsp{paragraph.node_id} ")
node_result.append(f" {space} {paragraph.metadata.paragraph_type}&nbsp{paragraph.node_id} ")
for text in __prettify_text(paragraph.text):
space = [space_symbol] * 4 * (depth - 1) + 4 * [space_symbol]
space = "".join(space)
Expand Down Expand Up @@ -98,7 +95,7 @@ def json2tree(paragraph: TreeNode) -> str:
depths = {d for d in depths if d <= depth}
space = [space_symbol] * 4 * (depth - 1) + 4 * ["-"]
space = __add_vertical_line(depths, space)
node_result.append(f"<p> <tt> <em> {space} {node.metadata.hierarchy_level.line_type}&nbsp{node.node_id} </em> </tt> </p>")
node_result.append(f"<p> <tt> <em> {space} {node.metadata.paragraph_type}&nbsp{node.node_id} </em> </tt> </p>")
for text in __prettify_text(node.text):
space = [space_symbol] * 4 * (depth - 1) + 4 * [space_symbol]
space = __add_vertical_line(depths, space)
Expand Down Expand Up @@ -136,14 +133,14 @@ def json2html(text: str,

ptext = __annotations2html(paragraph=paragraph, table2id=table2id, attach2id=attach2id, tabs=tabs)

if paragraph.metadata.hierarchy_level.line_type in [HierarchyLevel.header, HierarchyLevel.root]:
if paragraph.metadata.paragraph_type in [HierarchyLevel.header, HierarchyLevel.root]:
ptext = f"<strong>{ptext.strip()}</strong>"
elif paragraph.metadata.hierarchy_level.line_type == HierarchyLevel.list_item:
elif paragraph.metadata.paragraph_type == HierarchyLevel.list_item:
ptext = f"<em>{ptext.strip()}</em>"
else:
ptext = ptext.strip()

ptext = f'<p> {"&nbsp;" * tabs} {ptext} <sub> id = {paragraph.node_id} ; type = {paragraph.metadata.hierarchy_level.line_type} </sub></p>'
ptext = f'<p> {"&nbsp;" * tabs} {ptext} <sub> id = {paragraph.node_id} ; type = {paragraph.metadata.paragraph_type} </sub></p>'
if hasattr(paragraph.metadata, "uid"):
ptext = f'<div id="{paragraph.metadata.uid}">{ptext}</div>'
text += ptext
Expand Down Expand Up @@ -259,11 +256,10 @@ def table2html(table: Table, table2id: Dict[str, int]) -> str:
text += ' style="display: none" '
cell_node = TreeNode(
node_id="0",
text=cell.get_text(),
annotations=cell.get_annotations(),
metadata=LineMetadata(page_id=table.metadata.page_id, line_id=0),
subparagraphs=[],
parent=None
text="\n".join([line.text for line in cell.lines]),
annotations=cell.lines[0].annotations if cell.lines else [],
metadata=LineMetadata(page_id=0, line_id=0, paragraph_type=HierarchyLevel.raw_text),
subparagraphs=[]
)
text += f' colspan="{cell.colspan}" rowspan="{cell.rowspan}">{__annotations2html(cell_node, {}, {})}</td>\n'

Expand Down
34 changes: 34 additions & 0 deletions dedoc/api/cancellation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from contextlib import asynccontextmanager

from anyio import create_task_group
from fastapi import Request


@asynccontextmanager
async def cancel_on_disconnect(request: Request, logger: logging.Logger) -> None:
"""
Async context manager for async code that needs to be cancelled if client disconnects prematurely.
The client disconnect is monitored through the Request object.
Source: https://github.com/dorinclisu/runner-with-api
See discussion: https://github.com/fastapi/fastapi/discussions/8805
"""
async with create_task_group() as task_group:
async def watch_disconnect() -> None:
while True:
message = await request.receive()

if message["type"] == "http.disconnect":
client = f"{request.client.host}:{request.client.port}" if request.client else "-:-"
logger.warning(f"{client} - `{request.method} {request.url.path}` 499 DISCONNECTED")

task_group.cancel_scope.cancel()
break

task_group.start_soon(watch_disconnect)

try:
yield
finally:
task_group.cancel_scope.cancel()
40 changes: 15 additions & 25 deletions dedoc/api/dedoc_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import base64
import dataclasses
import importlib
import json
import os
import tempfile
import traceback
from typing import Optional

from fastapi import Depends, FastAPI, File, Request, Response, UploadFile
Expand All @@ -15,24 +13,23 @@
import dedoc.version
from dedoc.api.api_args import QueryParameters
from dedoc.api.api_utils import json2collapsed_tree, json2html, json2tree, json2txt
from dedoc.api.process_handler import ProcessHandler
from dedoc.api.schema.parsed_document import ParsedDocument
from dedoc.common.exceptions.dedoc_error import DedocError
from dedoc.common.exceptions.missing_file_error import MissingFileError
from dedoc.config import get_config
from dedoc.dedoc_manager import DedocManager
from dedoc.utils.utils import save_upload_file

config = get_config()
logger = config["logger"]
PORT = config["api_port"]
static_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "web")
static_files_dirs = config.get("static_files_dirs")

app = FastAPI()
app.mount("/web", StaticFiles(directory=config.get("static_path", static_path)), name="web")

module_api_args = importlib.import_module(config["import_path_init_api_args"])
logger = config["logger"]
manager = DedocManager(config=config)
process_handler = ProcessHandler(logger=logger)


@app.get("/")
Expand Down Expand Up @@ -62,27 +59,20 @@ def _get_static_file_path(request: Request) -> str:
return os.path.abspath(os.path.join(directory, file))


def __add_base64_info_to_attachments(document_tree: ParsedDocument, attachments_dir: str) -> None:
for attachment in document_tree.attachments:
with open(os.path.join(attachments_dir, attachment.metadata.temporary_file_name), "rb") as attachment_file:
attachment.metadata.add_attribute("base64", base64.b64encode(attachment_file.read()).decode("utf-8"))


@app.post("/upload", response_model=ParsedDocument)
async def upload(file: UploadFile = File(...), query_params: QueryParameters = Depends()) -> Response:
async def upload(request: Request, file: UploadFile = File(...), query_params: QueryParameters = Depends()) -> Response:
parameters = dataclasses.asdict(query_params)
if not file or file.filename == "":
raise MissingFileError("Error: Missing content in request_post file parameter", version=dedoc.version.__version__)

return_format = str(parameters.get("return_format", "json")).lower()

with tempfile.TemporaryDirectory() as tmpdir:
file_path = save_upload_file(file, tmpdir)
document_tree = manager.parse(file_path, parameters={**dict(parameters), "attachments_dir": tmpdir})
document_tree = await process_handler.handle(request=request, parameters=parameters, file_path=file_path, tmpdir=tmpdir)

if return_format == "html":
__add_base64_info_to_attachments(document_tree, tmpdir)
if document_tree is None:
return JSONResponse(status_code=499, content={})

return_format = str(parameters.get("return_format", "json")).lower()
if return_format == "html":
html_content = json2html(
text="",
Expand All @@ -102,24 +92,25 @@ async def upload(file: UploadFile = File(...), query_params: QueryParameters = D
return HTMLResponse(content=html_content)

if return_format == "ujson":
return UJSONResponse(content=document_tree.to_api_schema().model_dump())
return UJSONResponse(content=document_tree.model_dump())

if return_format == "collapsed_tree":
html_content = json2collapsed_tree(paragraph=document_tree.content.structure)
return HTMLResponse(content=html_content)

if return_format == "pretty_json":
return PlainTextResponse(content=json.dumps(document_tree.to_api_schema().model_dump(), ensure_ascii=False, indent=2))
return PlainTextResponse(content=json.dumps(document_tree.model_dump(), ensure_ascii=False, indent=2))

logger.info(f"Send result. File {file.filename} with parameters {parameters}")
return ORJSONResponse(content=document_tree.to_api_schema().model_dump())
return ORJSONResponse(content=document_tree.model_dump())


@app.get("/upload_example")
async def upload_example(file_name: str, return_format: Optional[str] = None) -> Response:
async def upload_example(request: Request, file_name: str, return_format: Optional[str] = None) -> Response:
file_path = os.path.join(static_path, "examples", file_name)
parameters = {} if return_format is None else {"return_format": return_format}
document_tree = manager.parse(file_path, parameters=parameters)
with tempfile.TemporaryDirectory() as tmpdir:
document_tree = await process_handler.handle(request=request, parameters=parameters, file_path=file_path, tmpdir=tmpdir)

if return_format == "html":
html_page = json2html(
Expand All @@ -130,12 +121,11 @@ async def upload_example(file_name: str, return_format: Optional[str] = None) ->
tabs=0
)
return HTMLResponse(content=html_page)
return ORJSONResponse(content=document_tree.to_api_schema().model_dump(), status_code=200)
return ORJSONResponse(content=document_tree.model_dump(), status_code=200)


@app.exception_handler(DedocError)
async def exception_handler(request: Request, exc: DedocError) -> Response:
logger.error(f"Exception {exc}\n{traceback.format_exc()}")
result = {"message": exc.msg}
if exc.filename:
result["file_name"] = exc.filename
Expand Down
115 changes: 115 additions & 0 deletions dedoc/api/process_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import asyncio
import base64
import logging
import os
import pickle
import signal
import traceback
from multiprocessing import Process, Queue
from typing import Optional
from urllib.request import Request

from anyio import get_cancelled_exc_class

from dedoc.api.cancellation import cancel_on_disconnect
from dedoc.api.schema import ParsedDocument
from dedoc.common.exceptions.dedoc_error import DedocError
from dedoc.config import get_config
from dedoc.dedoc_manager import DedocManager


class ProcessHandler:
"""
Class for file parsing by DedocManager with support for client disconnection.
If client disconnects during file parsing, the process of parsing is fully terminated and API is available to receive new connections.
Handler uses the following algorithm:
1. Master process is used for checking current connection (client disconnect)
2. Child process is working on the background and waiting for the input file in the input_queue
3. Master process calls the child process for parsing and transfers data through the input_queue
4. Child process is parsing file using DedocManager
5. The result of parsing is transferred to the master process through the output_queue
6. If client disconnects, the child process is terminated. The new child process with queues will start with the new request
"""
def __init__(self, logger: logging.Logger) -> None:
self.input_queue = Queue()
self.output_queue = Queue()
self.logger = logger
self.process = Process(target=self.__parse_file, args=[self.input_queue, self.output_queue])
self.process.start()

async def handle(self, request: Request, parameters: dict, file_path: str, tmpdir: str) -> Optional[ParsedDocument]:
"""
Handle request in a separate process.
Checks for client disconnection and terminate the child process if client disconnected.
"""
if self.process is None:
self.logger.info("Initialization of a new parsing process")
self.__init__(logger=self.logger)

self.logger.info("Putting file to the input queue")
self.input_queue.put(pickle.dumps((parameters, file_path, tmpdir)), block=True)

loop = asyncio.get_running_loop()
async with cancel_on_disconnect(request, self.logger):
try:
future = loop.run_in_executor(None, self.output_queue.get)
result = await future
except get_cancelled_exc_class():
self.logger.warning("Terminating the parsing process")
if self.process is not None:
self.process.terminate()
self.process = None
future.cancel(DedocError)
return None

result = pickle.loads(result)
if isinstance(result, ParsedDocument):
self.logger.info("Got the result from the output queue")
return result

raise DedocError.from_dict(result)

def __parse_file(self, input_queue: Queue, output_queue: Queue) -> None:
"""
Function for file parsing in a separate (child) process.
It's a background process, i.e. it is waiting for a task in the input queue.
The result of parsing is returned in the output queue.
Operations with `signal` are used for saving master process while killing child process.
See the issue for more details: https://github.com/fastapi/fastapi/issues/1487
"""
signal.set_wakeup_fd(-1)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)

manager = DedocManager(config=get_config())
manager.logger.info("Parsing process is waiting for the task in the input queue")

while True:
file_path = None
try:
parameters, file_path, tmp_dir = pickle.loads(input_queue.get(block=True))
manager.logger.info("Parsing process got task from the input queue")
return_format = str(parameters.get("return_format", "json")).lower()
document_tree = manager.parse(file_path, parameters={**dict(parameters), "attachments_dir": tmp_dir})

if return_format == "html":
self.__add_base64_info_to_attachments(document_tree, tmp_dir)

output_queue.put(pickle.dumps(document_tree.to_api_schema()), block=True)
manager.logger.info("Parsing process put task to the output queue")
except DedocError as e:
tb = traceback.format_exc()
manager.logger.error(f"Exception {e}: {e.msg_api}\n{tb}")
output_queue.put(pickle.dumps(e.__dict__), block=True)
except Exception as e:
exc_message = f"Exception {e}\n{traceback.format_exc()}"
filename = "" if file_path is None else os.path.basename(file_path)
manager.logger.error(exc_message)
output_queue.put(pickle.dumps({"msg": exc_message, "filename": filename}), block=True)

def __add_base64_info_to_attachments(self, document_tree: ParsedDocument, attachments_dir: str) -> None:
for attachment in document_tree.attachments:
with open(os.path.join(attachments_dir, attachment.metadata.temporary_file_name), "rb") as attachment_file:
attachment.metadata.add_attribute("base64", base64.b64encode(attachment_file.read()).decode("utf-8"))
6 changes: 1 addition & 5 deletions dedoc/common/exceptions/bad_file_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ class BadFileFormatError(DedocError):
"""

def __init__(self, msg: str, msg_api: Optional[str] = None, filename: Optional[str] = None, version: Optional[str] = None) -> None:
super(BadFileFormatError, self).__init__(msg_api=msg_api, msg=msg, filename=filename, version=version)
super(BadFileFormatError, self).__init__(msg_api=msg_api, msg=msg, filename=filename, version=version, code=415)

def __str__(self) -> str:
return f"BadFileFormatError({self.msg})"

@property
def code(self) -> int:
return 415
4 changes: 0 additions & 4 deletions dedoc/common/exceptions/bad_parameters_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ def __init__(self, msg: str, msg_api: Optional[str] = None, filename: Optional[s

def __str__(self) -> str:
return f"BadParametersError({self.msg})"

@property
def code(self) -> int:
return 400
6 changes: 1 addition & 5 deletions dedoc/common/exceptions/conversion_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ class ConversionError(DedocError):
"""

def __init__(self, msg: str, msg_api: Optional[str] = None, filename: Optional[str] = None, version: Optional[str] = None) -> None:
super(ConversionError, self).__init__(msg_api=msg_api, msg=msg, filename=filename, version=version)
super(ConversionError, self).__init__(msg_api=msg_api, msg=msg, filename=filename, version=version, code=415)

def __str__(self) -> str:
return f"ConversionError({self.msg})"

@property
def code(self) -> int:
return 415
Loading

0 comments on commit e4ec06b

Please sign in to comment.