Skip to content

Commit

Permalink
feat(ocr): enhance documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
c0deplayer committed Dec 18, 2024
1 parent 1c8ef88 commit 9787fe4
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 15 deletions.
132 changes: 126 additions & 6 deletions src/ocr/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,47 @@ class OCRProcessor:
"""Handle OCR processing operations."""

def __init__(self, document_repository: DocumentRepository) -> None:
"""Initialize OCR processor with repository.
Args:
document_repository (DocumentRepository): The repository instance for document storage and retrieval.
Returns:
None
"""
"""Initialize OCR processor with repository."""
self.repository = document_repository

async def validate_file(self, file: UploadFile) -> str:
"""Validate uploaded file against constraints."""
"""Validate uploaded file against constraints.
This method performs several validation checks on the uploaded file:
1. Validates filename existence
2. Ensures filename uniqueness
3. Validates file type against accepted types
4. Checks file size constraints
5. For PDFs: validates page count
Args:
file (UploadFile): The file object to validate, containing attributes like
filename, content_type, size, and file stream.
Returns:
str: The validated filename (potentially modified for uniqueness)
Raises:
HTTPException: With appropriate status codes for validation failures:
- 400: Missing filename, empty file, or invalid PDF
- 413: File size exceeds limit
- 415: Unsupported media type
Example:
```
validated_filename = await ocr_instance.validate_file(uploaded_file)
```
"""
logger.info("Validating file: %s", file.filename)
if not file.filename:
raise HTTPException(
Expand Down Expand Up @@ -180,7 +216,19 @@ async def validate_file(self, file: UploadFile) -> str:

@staticmethod
def convert_file_to_images(file: UploadFile) -> list[Image.Image]:
"""Convert uploaded file to list of PIL Images."""
"""Convert uploaded file to list of PIL Images.
This function takes an uploaded file and converts it into a list of PIL Image objects.
For PDF files, it converts each page into a separate image. For other image files,
it returns a single image in a list.
Args:
file (UploadFile): The uploaded file object containing either a PDF or image file.
Returns:
list[Image.Image]: A list of PIL Image objects representing the file contents.
"""
logger.info("Converting file to images: %s", file.filename)

if file.content_type == "application/pdf":
Expand All @@ -190,7 +238,22 @@ def convert_file_to_images(file: UploadFile) -> list[Image.Image]:

@staticmethod
def convert_to_base64(images: list[Image.Image]) -> list[str]:
"""Convert PIL Images to base64 strings."""
"""Convert PIL Images to base64 strings.
This function takes a list of PIL Image objects and converts each image to a base64-encoded string.
Images in RGBA or P mode are converted to RGB before encoding. Each image is saved in JPEG format.
Args:
images (list[Image.Image]): A list of PIL Image objects to be converted
Returns:
list[str]: A list of base64-encoded strings representing the images
Example:
>>> images = [Image.open('image1.png'), Image.open('image2.jpg')]
>>> base64_strings = convert_to_base64(images)
"""
encoded_images = []

for img in images:
Expand All @@ -208,7 +271,18 @@ def convert_to_base64(images: list[Image.Image]) -> list[str]:
return encoded_images

async def save_document(self, file_name: str) -> Document:
"""Save document metadata to database."""
"""Save document metadata to database.
Args:
file_name (str): Name of the document file to be saved
Returns:
Document: Created document instance with metadata
Raises:
HTTPException: If document creation fails with 500 status code
"""
try:
return await self.repository.create(
DocumentCreate(
Expand All @@ -232,7 +306,25 @@ async def get_docs(
get_repository,
),
) -> list[dict]:
"""Get all documents."""
"""Asynchronously retrieves all documents from the repository.
Args:
repository (AsyncGenerator[DocumentRepository, None]): An async generator that yields a DocumentRepository instance.
Defaults to the repository provided by get_repository dependency.
Returns:
list[dict]: A list of documents, where each document is represented as a dictionary.
Raises:
HTTPException: If there's an error retrieving documents from the repository.
Returns a 500 Internal Server Error status code.
Example:
>>> async def example():
... docs = await get_docs()
... print(docs) # [{doc1_data}, {doc2_data}, ...]
"""
try:
documents = await repository.get_all()
return [doc.to_dict() for doc in documents]
Expand All @@ -251,7 +343,35 @@ async def process_document(
get_repository,
),
) -> dict[str, str]:
"""Process document for OCR and forward results."""
"""Process and analyze an uploaded document through OCR and subsequent processing.
This asynchronous function handles document upload, encryption, OCR processing, and forwards
the results to a downstream processing service.
Args:
file (UploadFile): The uploaded file to be processed
repository (AsyncGenerator[DocumentRepository, None]): Repository dependency for document storage
Returns:
dict[str, str]: JSON response from the downstream processing service
Raises:
HTTPException: With status code 500 if:
- Document processing fails in downstream service
- Any other processing error occurs
Flow:
1. Validates and saves uploaded file
2. Encrypts the saved file
3. Performs OCR on file contents
4. Saves initial document record
5. Forwards OCR results to processor service
6. Returns processor service response
Note:
The function includes built-in cleanup in case of failures at any stage
"""
logger.info("Processing document: %s", file.filename)

processor = OCRProcessor(repository)
Expand Down
120 changes: 111 additions & 9 deletions src/summarization/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,22 @@


class OllamaOptions(TypedDict):
"""Type definition for Ollama API options."""
"""A TypedDict class representing configuration options for Ollama API.
Attributes:
temperature (float): Controls randomness in the response. Higher values (e.g., 0.8) make the output more random,
while lower values (e.g., 0.2) make it more focused and deterministic.
top_p (float): Controls diversity via nucleus sampling. Restricts cumulative probability of tokens considered
for sampling. Range 0.0-1.0.
frequency_penalty (float): Reduces likelihood of repeated tokens. Positive values discourage repetition while
negative values encourage it.
presence_penalty (float): Influences likelihood of discussing new topics. Positive values encourage covering
new topics while negative values focus on existing ones.
num_ctx (int): Sets the size of the context window (in tokens) for the model's input/output.
num_gpu (int): Specifies the number of GPUs to utilize for model inference.
num_thread (int): Defines the number of CPU threads to use for processing.
"""

temperature: float
top_p: float
Expand All @@ -46,7 +61,19 @@ class OllamaOptions(TypedDict):


class OllamaRequest(TypedDict, total=False):
"""Type definition for Ollama API request."""
"""A TypedDict class representing a request to the Ollama API.
Attributes:
model (str): The name of the model to use for the request. Required.
prompt (str): The prompt text to send to the model. Required.
format (str, optional): The desired output format.
options (OllamaOptions, optional): Additional options for the model.
system (str, optional): System prompt to modify model behavior.
stream (bool, optional): Whether to stream the response.
raw (bool, optional): Whether to return raw response.
keep_alive (Union[str, int], optional): Duration to keep the model loaded.
"""

model: str # required
prompt: str # required
Expand All @@ -59,7 +86,38 @@ class OllamaRequest(TypedDict, total=False):


class Summarizer:
"""Text summarization using direct Ollama integration."""
"""A class for generating text summaries using the Ollama API.
This class handles the preparation, validation, and processing of text summarization
requests through the Ollama API. It includes functionality for JSON response parsing,
error handling, and content validation.
Attributes:
SUMMARY_SYSTEM_PROMPT (str): Template for system-level instructions to the model.
SUMMARY_PROMPT (str): Template for the main summarization prompt.
config (SummarizationConfig): Configuration settings for the summarizer.
base_url (str): Base URL for the Ollama API endpoint.
Example:
```python
config = SummarizationConfig()
summarizer = Summarizer(config)
summary = await summarizer.generate_summary(
text="Long text to summarize",
min_length=100,
max_length=300,
classification="technical"
```
Note:
The class expects the Ollama API to be available and properly configured.
It handles responses in JSON format and includes fallback mechanisms for
non-JSON responses.
HTTPException: When API communication fails or response validation fails.
ValueError: When response format or content validation fails.
"""

SUMMARY_SYSTEM_PROMPT = """You are an expert summarization analyst specializing in precise content distillation.
Expand Down Expand Up @@ -118,7 +176,20 @@ class Summarizer:
Generate summary:"""

def __init__(self, config: SummarizationConfig) -> None:
"""Initialize summarizer with configuration."""
"""Initialize the summarizer with configuration settings.
Args:
config (SummarizationConfig): Configuration object containing settings for summarization,
including OLLAMA_BASE_URL for API endpoint.
Returns:
None
Example:
>>> config = SummarizationConfig(OLLAMA_BASE_URL="http://localhost:11434")
>>> summarizer = Summarizer(config)
"""
self.config = config
self.base_url = f"{config.OLLAMA_BASE_URL}/api/generate"

Expand All @@ -129,7 +200,29 @@ def _prepare_request(
max_length: int,
classification: str,
) -> OllamaRequest:
"""Prepare the request payload for Ollama API."""
"""Prepare a request for the Ollama API to generate a summary of the given text.
This method constructs a request dictionary with the necessary parameters for text summarization,
including system and user prompts, model configuration, and various generation parameters.
Args:
text (str): The input text to be summarized.
min_length (int): The minimum length of the generated summary.
max_length (int): The maximum length of the generated summary.
classification (str): The classification category of the text.
Returns:
OllamaRequest: A dictionary containing all necessary parameters for the Ollama API request,
including:
- model: The name of the language model to use
- prompt: The formatted user prompt
- system: The formatted system prompt
- format: The expected response format (json)
- stream: Whether to stream the response
- options: Model-specific parameters (temperature, top_p, etc.)
- keep_alive: Duration to keep the model loaded
"""
system_prompt = self.SUMMARY_SYSTEM_PROMPT.format(
min_length=min_length,
max_length=max_length,
Expand Down Expand Up @@ -162,16 +255,25 @@ def _prepare_request(
}

def _extract_json_from_text(self, text: str) -> dict[str, Any]:
"""Extract JSON from text that might contain additional content.
r"""Extract a JSON object from a text string.
This method processes the input text by removing any markdown JSON code block markers
and attempts to parse the first JSON object found in the text.
Args:
text: Text that might contain JSON
text (str): The input text containing a JSON object, potentially within markdown code blocks.
Returns:
Extracted JSON as dict
dict[str, Any]: The parsed JSON object as a dictionary.
Raises:
ValueError: If no valid JSON found
ValueError: If no JSON object is found in the text or if the JSON parsing fails.
Example:
>>> text = "```json\\n{\\\"key\\\": \\\"value\\\"}\\n```"
>>> result = _extract_json_from_text(text)
>>> print(result)
{'key': 'value'}
"""
# Remove any markdown code blocks
Expand Down

0 comments on commit 9787fe4

Please sign in to comment.