From 9787fe49016aff4d19765b518878ddb30494168f Mon Sep 17 00:00:00 2001 From: Jakub Zenon Kujawa Date: Wed, 18 Dec 2024 18:57:22 +0100 Subject: [PATCH] feat(ocr): enhance documentation --- src/ocr/ocr.py | 132 ++++++++++++++++++++++++++++++-- src/summarization/summarizer.py | 120 ++++++++++++++++++++++++++--- 2 files changed, 237 insertions(+), 15 deletions(-) diff --git a/src/ocr/ocr.py b/src/ocr/ocr.py index bc5707f..d9f3fda 100644 --- a/src/ocr/ocr.py +++ b/src/ocr/ocr.py @@ -111,11 +111,47 @@ class OCRProcessor: """Handle OCR processing operations.""" def __init__(self, document_repository: DocumentRepository) -> None: + """Initialize OCR processor with repository. + + Args: + document_repository (DocumentRepository): The repository instance for document storage and retrieval. + + Returns: + None + + """ """Initialize OCR processor with repository.""" self.repository = document_repository async def validate_file(self, file: UploadFile) -> str: - """Validate uploaded file against constraints.""" + """Validate uploaded file against constraints. + + This method performs several validation checks on the uploaded file: + 1. Validates filename existence + 2. Ensures filename uniqueness + 3. Validates file type against accepted types + 4. Checks file size constraints + 5. For PDFs: validates page count + + Args: + file (UploadFile): The file object to validate, containing attributes like + filename, content_type, size, and file stream. + + Returns: + str: The validated filename (potentially modified for uniqueness) + + Raises: + HTTPException: With appropriate status codes for validation failures: + - 400: Missing filename, empty file, or invalid PDF + - 413: File size exceeds limit + - 415: Unsupported media type + + Example: + ``` + validated_filename = await ocr_instance.validate_file(uploaded_file) + ``` + + """ logger.info("Validating file: %s", file.filename) if not file.filename: raise HTTPException( @@ -180,7 +216,19 @@ async def validate_file(self, file: UploadFile) -> str: @staticmethod def convert_file_to_images(file: UploadFile) -> list[Image.Image]: - """Convert uploaded file to list of PIL Images.""" + """Convert uploaded file to list of PIL Images. + + This function takes an uploaded file and converts it into a list of PIL Image objects. + For PDF files, it converts each page into a separate image. For other image files, + it returns a single image in a list. + + Args: + file (UploadFile): The uploaded file object containing either a PDF or image file. + + Returns: + list[Image.Image]: A list of PIL Image objects representing the file contents. + + """ logger.info("Converting file to images: %s", file.filename) if file.content_type == "application/pdf": @@ -190,7 +238,22 @@ def convert_file_to_images(file: UploadFile) -> list[Image.Image]: @staticmethod def convert_to_base64(images: list[Image.Image]) -> list[str]: - """Convert PIL Images to base64 strings.""" + """Convert PIL Images to base64 strings. + + This function takes a list of PIL Image objects and converts each image to a base64-encoded string. + Images in RGBA or P mode are converted to RGB before encoding. Each image is saved in JPEG format. + + Args: + images (list[Image.Image]): A list of PIL Image objects to be converted + + Returns: + list[str]: A list of base64-encoded strings representing the images + + Example: + >>> images = [Image.open('image1.png'), Image.open('image2.jpg')] + >>> base64_strings = convert_to_base64(images) + + """ encoded_images = [] for img in images: @@ -208,7 +271,18 @@ def convert_to_base64(images: list[Image.Image]) -> list[str]: return encoded_images async def save_document(self, file_name: str) -> Document: - """Save document metadata to database.""" + """Save document metadata to database. + + Args: + file_name (str): Name of the document file to be saved + + Returns: + Document: Created document instance with metadata + + Raises: + HTTPException: If document creation fails with 500 status code + + """ try: return await self.repository.create( DocumentCreate( @@ -232,7 +306,25 @@ async def get_docs( get_repository, ), ) -> list[dict]: - """Get all documents.""" + """Asynchronously retrieves all documents from the repository. + + Args: + repository (AsyncGenerator[DocumentRepository, None]): An async generator that yields a DocumentRepository instance. + Defaults to the repository provided by get_repository dependency. + + Returns: + list[dict]: A list of documents, where each document is represented as a dictionary. + + Raises: + HTTPException: If there's an error retrieving documents from the repository. + Returns a 500 Internal Server Error status code. + + Example: + >>> async def example(): + ... docs = await get_docs() + ... print(docs) # [{doc1_data}, {doc2_data}, ...] + + """ try: documents = await repository.get_all() return [doc.to_dict() for doc in documents] @@ -251,7 +343,35 @@ async def process_document( get_repository, ), ) -> dict[str, str]: - """Process document for OCR and forward results.""" + """Process and analyze an uploaded document through OCR and subsequent processing. + + This asynchronous function handles document upload, encryption, OCR processing, and forwards + the results to a downstream processing service. + + Args: + file (UploadFile): The uploaded file to be processed + repository (AsyncGenerator[DocumentRepository, None]): Repository dependency for document storage + + Returns: + dict[str, str]: JSON response from the downstream processing service + + Raises: + HTTPException: With status code 500 if: + - Document processing fails in downstream service + - Any other processing error occurs + + Flow: + 1. Validates and saves uploaded file + 2. Encrypts the saved file + 3. Performs OCR on file contents + 4. Saves initial document record + 5. Forwards OCR results to processor service + 6. Returns processor service response + + Note: + The function includes built-in cleanup in case of failures at any stage + + """ logger.info("Processing document: %s", file.filename) processor = OCRProcessor(repository) diff --git a/src/summarization/summarizer.py b/src/summarization/summarizer.py index 1a7849b..dca5566 100644 --- a/src/summarization/summarizer.py +++ b/src/summarization/summarizer.py @@ -34,7 +34,22 @@ class OllamaOptions(TypedDict): - """Type definition for Ollama API options.""" + """A TypedDict class representing configuration options for Ollama API. + + Attributes: + temperature (float): Controls randomness in the response. Higher values (e.g., 0.8) make the output more random, + while lower values (e.g., 0.2) make it more focused and deterministic. + top_p (float): Controls diversity via nucleus sampling. Restricts cumulative probability of tokens considered + for sampling. Range 0.0-1.0. + frequency_penalty (float): Reduces likelihood of repeated tokens. Positive values discourage repetition while + negative values encourage it. + presence_penalty (float): Influences likelihood of discussing new topics. Positive values encourage covering + new topics while negative values focus on existing ones. + num_ctx (int): Sets the size of the context window (in tokens) for the model's input/output. + num_gpu (int): Specifies the number of GPUs to utilize for model inference. + num_thread (int): Defines the number of CPU threads to use for processing. + + """ temperature: float top_p: float @@ -46,7 +61,19 @@ class OllamaOptions(TypedDict): class OllamaRequest(TypedDict, total=False): - """Type definition for Ollama API request.""" + """A TypedDict class representing a request to the Ollama API. + + Attributes: + model (str): The name of the model to use for the request. Required. + prompt (str): The prompt text to send to the model. Required. + format (str, optional): The desired output format. + options (OllamaOptions, optional): Additional options for the model. + system (str, optional): System prompt to modify model behavior. + stream (bool, optional): Whether to stream the response. + raw (bool, optional): Whether to return raw response. + keep_alive (Union[str, int], optional): Duration to keep the model loaded. + + """ model: str # required prompt: str # required @@ -59,7 +86,38 @@ class OllamaRequest(TypedDict, total=False): class Summarizer: - """Text summarization using direct Ollama integration.""" + """A class for generating text summaries using the Ollama API. + + This class handles the preparation, validation, and processing of text summarization + requests through the Ollama API. It includes functionality for JSON response parsing, + error handling, and content validation. + + Attributes: + SUMMARY_SYSTEM_PROMPT (str): Template for system-level instructions to the model. + SUMMARY_PROMPT (str): Template for the main summarization prompt. + config (SummarizationConfig): Configuration settings for the summarizer. + base_url (str): Base URL for the Ollama API endpoint. + + Example: + ```python + config = SummarizationConfig() + summarizer = Summarizer(config) + summary = await summarizer.generate_summary( + text="Long text to summarize", + min_length=100, + max_length=300, + classification="technical" + ``` + + Note: + The class expects the Ollama API to be available and properly configured. + It handles responses in JSON format and includes fallback mechanisms for + non-JSON responses. + + HTTPException: When API communication fails or response validation fails. + ValueError: When response format or content validation fails. + + """ SUMMARY_SYSTEM_PROMPT = """You are an expert summarization analyst specializing in precise content distillation. @@ -118,7 +176,20 @@ class Summarizer: Generate summary:""" def __init__(self, config: SummarizationConfig) -> None: - """Initialize summarizer with configuration.""" + """Initialize the summarizer with configuration settings. + + Args: + config (SummarizationConfig): Configuration object containing settings for summarization, + including OLLAMA_BASE_URL for API endpoint. + + Returns: + None + + Example: + >>> config = SummarizationConfig(OLLAMA_BASE_URL="http://localhost:11434") + >>> summarizer = Summarizer(config) + + """ self.config = config self.base_url = f"{config.OLLAMA_BASE_URL}/api/generate" @@ -129,7 +200,29 @@ def _prepare_request( max_length: int, classification: str, ) -> OllamaRequest: - """Prepare the request payload for Ollama API.""" + """Prepare a request for the Ollama API to generate a summary of the given text. + + This method constructs a request dictionary with the necessary parameters for text summarization, + including system and user prompts, model configuration, and various generation parameters. + + Args: + text (str): The input text to be summarized. + min_length (int): The minimum length of the generated summary. + max_length (int): The maximum length of the generated summary. + classification (str): The classification category of the text. + + Returns: + OllamaRequest: A dictionary containing all necessary parameters for the Ollama API request, + including: + - model: The name of the language model to use + - prompt: The formatted user prompt + - system: The formatted system prompt + - format: The expected response format (json) + - stream: Whether to stream the response + - options: Model-specific parameters (temperature, top_p, etc.) + - keep_alive: Duration to keep the model loaded + + """ system_prompt = self.SUMMARY_SYSTEM_PROMPT.format( min_length=min_length, max_length=max_length, @@ -162,16 +255,25 @@ def _prepare_request( } def _extract_json_from_text(self, text: str) -> dict[str, Any]: - """Extract JSON from text that might contain additional content. + r"""Extract a JSON object from a text string. + + This method processes the input text by removing any markdown JSON code block markers + and attempts to parse the first JSON object found in the text. Args: - text: Text that might contain JSON + text (str): The input text containing a JSON object, potentially within markdown code blocks. Returns: - Extracted JSON as dict + dict[str, Any]: The parsed JSON object as a dictionary. Raises: - ValueError: If no valid JSON found + ValueError: If no JSON object is found in the text or if the JSON parsing fails. + + Example: + >>> text = "```json\\n{\\\"key\\\": \\\"value\\\"}\\n```" + >>> result = _extract_json_from_text(text) + >>> print(result) + {'key': 'value'} """ # Remove any markdown code blocks