diff --git a/haystack/components/converters/docx.py b/haystack/components/converters/docx.py index 8ffc3888a8..dc0a51f485 100644 --- a/haystack/components/converters/docx.py +++ b/haystack/components/converters/docx.py @@ -2,12 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 +import csv import io from dataclasses import dataclass +from enum import Enum +from io import StringIO from pathlib import Path from typing import Any, Dict, List, Optional, Union -from haystack import Document, component, logging +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata from haystack.dataclasses import ByteStream from haystack.lazy_imports import LazyImport @@ -17,6 +20,7 @@ with LazyImport("Run 'pip install python-docx'") as docx_import: import docx from docx.document import Document as DocxDocument + from docx.table import Table from docx.text.paragraph import Paragraph @@ -59,6 +63,30 @@ class DOCXMetadata: version: str +class DOCXTableFormat(Enum): + """ + Supported formats for storing DOCX tabular data in a Document. + """ + + MARKDOWN = "markdown" + CSV = "csv" + + def __str__(self): + return self.value + + @staticmethod + def from_str(string: str) -> "DOCXTableFormat": + """ + Convert a string to a DOCXTableFormat enum. + """ + enum_map = {e.value: e for e in DOCXTableFormat} + table_format = enum_map.get(string.lower()) + if table_format is None: + msg = f"Unknown table format '{string}'. Supported formats are: {list(enum_map.keys())}" + raise ValueError(msg) + return table_format + + @component class DOCXToDocument: """ @@ -69,9 +97,9 @@ class DOCXToDocument: Usage example: ```python - from haystack.components.converters.docx import DOCXToDocument + from haystack.components.converters.docx import DOCXToDocument, DOCXTableFormat - converter = DOCXToDocument() + converter = DOCXToDocument(table_format=DOCXTableFormat.CSV) results = converter.run(sources=["sample.docx"], meta={"date_added": datetime.now().isoformat()}) documents = results["documents"] print(documents[0].content) @@ -79,11 +107,38 @@ class DOCXToDocument: ``` """ - def __init__(self): + def __init__(self, table_format: Union[str, DOCXTableFormat] = DOCXTableFormat.CSV): """ Create a DOCXToDocument component. + + :param table_format: The format for table output. Can be either DOCXTableFormat.MARKDOWN, + DOCXTableFormat.CSV, "markdown", or "csv". Defaults to DOCXTableFormat.CSV. """ docx_import.check() + self.table_format = DOCXTableFormat.from_str(table_format) if isinstance(table_format, str) else table_format + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict(self, table_format=str(self.table_format)) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DOCXToDocument": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + if "table_format" in data["init_parameters"]: + data["init_parameters"]["table_format"] = DOCXTableFormat.from_str(data["init_parameters"]["table_format"]) + return default_from_dict(cls, data) @component.output_types(documents=List[Document]) def run( @@ -118,9 +173,9 @@ def run( logger.warning("Could not read {source}. Skipping it. 
Error: {error}", source=source, error=e) continue try: - file = docx.Document(io.BytesIO(bytestream.data)) - paragraphs = self._extract_paragraphs_with_page_breaks(file.paragraphs) - text = "\n".join(paragraphs) + docx_document = docx.Document(io.BytesIO(bytestream.data)) + elements = self._extract_elements(docx_document) + text = "\n".join(elements) except Exception as e: logger.warning( "Could not read {source} and convert it to a DOCX Document, skipping. Error: {error}", @@ -129,52 +184,116 @@ def run( ) continue - docx_metadata = self._get_docx_metadata(document=file) + docx_metadata = self._get_docx_metadata(document=docx_document) merged_metadata = {**bytestream.meta, **metadata, "docx": docx_metadata} document = Document(content=text, meta=merged_metadata) documents.append(document) return {"documents": documents} - def _extract_paragraphs_with_page_breaks(self, paragraphs: List["Paragraph"]) -> List[str]: + def _extract_elements(self, document: "DocxDocument") -> List[str]: """ - Extracts paragraphs from a DOCX file, including page breaks. + Extracts elements from a DOCX file. - Page breaks (both soft and hard page breaks) are not automatically extracted by python-docx as '\f' chars. - This means we need to add them in ourselves, as done here. This allows the correct page number - to be associated with each document if the file contents are split, e.g. by DocumentSplitter. + :param document: The DOCX Document object. + :returns: List of strings (paragraph texts and table representations) with page breaks added as '\f' characters. + """ + elements = [] + for element in document.element.body: + if element.tag.endswith("p"): + paragraph = Paragraph(element, document) + if paragraph.contains_page_break: + para_text = self._process_paragraph_with_page_breaks(paragraph) + else: + para_text = paragraph.text + elements.append(para_text) + elif element.tag.endswith("tbl"): + table = docx.table.Table(element, document) + table_str = ( + self._table_to_markdown(table) + if self.table_format == DOCXTableFormat.MARKDOWN + else self._table_to_csv(table) + ) + elements.append(table_str) - :param paragraphs: - List of paragraphs from a DOCX file. + return elements - :returns: - List of strings (paragraph text fields) with all page breaks added in as '\f' characters. - """ - paragraph_texts = [] - for para in paragraphs: - if para.contains_page_break: - para_text_w_page_breaks = "" - # Usually, just 1 page break exists, but could be more if paragraph is really long, so we loop over them - for pb_index, page_break in enumerate(para.rendered_page_breaks): - # Can only extract text from first paragraph page break, unfortunately - if pb_index == 0: - if page_break.preceding_paragraph_fragment: - para_text_w_page_breaks += page_break.preceding_paragraph_fragment.text - para_text_w_page_breaks += "\f" - if page_break.following_paragraph_fragment: - # following_paragraph_fragment contains all text for remainder of paragraph. - # However, if the remainder of the paragraph spans multiple page breaks, it won't include - # those later page breaks so we have to add them at end of text in the `else` block below. - # This is not ideal, but this case should be very rare and this is likely good enough. - para_text_w_page_breaks += page_break.following_paragraph_fragment.text - else: - para_text_w_page_breaks += "\f" - - paragraph_texts.append(para_text_w_page_breaks) + def _process_paragraph_with_page_breaks(self, paragraph: "Paragraph") -> str: + """ + Processes a paragraph with page breaks. 
+ + :param paragraph: The DOCX paragraph to process. + :returns: A string with page breaks added as '\f' characters. + """ + para_text = "" + # Usually, just 1 page break exists, but could be more if paragraph is really long, so we loop over them + for pb_index, page_break in enumerate(paragraph.rendered_page_breaks): + # Can only extract text from first paragraph page break, unfortunately + if pb_index == 0: + if page_break.preceding_paragraph_fragment: + para_text += page_break.preceding_paragraph_fragment.text + para_text += "\f" + if page_break.following_paragraph_fragment: + # following_paragraph_fragment contains all text for remainder of paragraph. + # However, if the remainder of the paragraph spans multiple page breaks, it won't include + # those later page breaks so we have to add them at end of text in the `else` block below. + # This is not ideal, but this case should be very rare and this is likely good enough. + para_text += page_break.following_paragraph_fragment.text else: - paragraph_texts.append(para.text) + para_text += "\f" + return para_text + + def _table_to_markdown(self, table: "Table") -> str: + """ + Converts a DOCX table to a Markdown string. + + :param table: The DOCX table to convert. + :returns: A Markdown string representation of the table. + """ + markdown: List[str] = [] + max_col_widths: List[int] = [] + + # Calculate max width for each column + for row in table.rows: + for i, cell in enumerate(row.cells): + cell_text = cell.text.strip() + if i >= len(max_col_widths): + max_col_widths.append(len(cell_text)) + else: + max_col_widths[i] = max(max_col_widths[i], len(cell_text)) + + # Process rows + for i, row in enumerate(table.rows): + md_row = [cell.text.strip().ljust(max_col_widths[j]) for j, cell in enumerate(row.cells)] + markdown.append("| " + " | ".join(md_row) + " |") + + # Add separator after header row + if i == 0: + separator = ["-" * max_col_widths[j] for j in range(len(row.cells))] + markdown.append("| " + " | ".join(separator) + " |") + + return "\n".join(markdown) + + def _table_to_csv(self, table: "Table") -> str: + """ + Converts a DOCX table to a CSV string. + + :param table: The DOCX table to convert. + :returns: A CSV string representation of the table. 
+ """ + csv_output = StringIO() + csv_writer = csv.writer(csv_output, quoting=csv.QUOTE_MINIMAL) + + # Process rows + for row in table.rows: + csv_row = [cell.text.strip() for cell in row.cells] + csv_writer.writerow(csv_row) + + # Get the CSV as a string and strip any trailing newlines + csv_string = csv_output.getvalue().strip() + csv_output.close() - return paragraph_texts + return csv_string def _get_docx_metadata(self, document: "DocxDocument") -> DOCXMetadata: """ @@ -191,15 +310,15 @@ def _get_docx_metadata(self, document: "DocxDocument") -> DOCXMetadata: category=document.core_properties.category, comments=document.core_properties.comments, content_status=document.core_properties.content_status, - created=document.core_properties.created.isoformat() if document.core_properties.created else None, + created=(document.core_properties.created.isoformat() if document.core_properties.created else None), identifier=document.core_properties.identifier, keywords=document.core_properties.keywords, language=document.core_properties.language, last_modified_by=document.core_properties.last_modified_by, - last_printed=document.core_properties.last_printed.isoformat() - if document.core_properties.last_printed - else None, - modified=document.core_properties.modified.isoformat() if document.core_properties.modified else None, + last_printed=( + document.core_properties.last_printed.isoformat() if document.core_properties.last_printed else None + ), + modified=(document.core_properties.modified.isoformat() if document.core_properties.modified else None), revision=document.core_properties.revision, subject=document.core_properties.subject, title=document.core_properties.title, diff --git a/releasenotes/notes/enhance-docx-table-extraction-3232d3059d220550.yaml b/releasenotes/notes/enhance-docx-table-extraction-3232d3059d220550.yaml new file mode 100644 index 0000000000..3c7c52e0dd --- /dev/null +++ b/releasenotes/notes/enhance-docx-table-extraction-3232d3059d220550.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Enhanced DOCX converter to support table extraction in addition to paragraph content. The converter supports both CSV and Markdown table formats, providing flexible options for representing tabular data extracted from DOCX documents. 
diff --git a/test/components/converters/test_docx_file_to_document.py b/test/components/converters/test_docx_file_to_document.py index d529f2a06d..64dfaa79f4 100644 --- a/test/components/converters/test_docx_file_to_document.py +++ b/test/components/converters/test_docx_file_to_document.py @@ -1,11 +1,12 @@ -import logging import json - +import logging import pytest +import csv +from io import StringIO +from haystack import Document, Pipeline +from haystack.components.converters.docx import DOCXMetadata, DOCXToDocument, DOCXTableFormat from haystack.dataclasses import ByteStream -from haystack import Document -from haystack.components.converters.docx import DOCXToDocument, DOCXMetadata @pytest.fixture @@ -17,6 +18,96 @@ class TestDOCXToDocument: def test_init(self, docx_converter): assert isinstance(docx_converter, DOCXToDocument) + def test_init_with_string(self): + converter = DOCXToDocument(table_format="markdown") + assert isinstance(converter, DOCXToDocument) + assert converter.table_format == DOCXTableFormat.MARKDOWN + + def test_init_with_invalid_string(self): + with pytest.raises(ValueError, match="Unknown table format 'invalid_format'"): + DOCXToDocument(table_format="invalid_format") + + def test_to_dict(self): + converter = DOCXToDocument() + data = converter.to_dict() + assert data == { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "csv"}, + } + + def test_to_dict_custom_parameters(self): + converter = DOCXToDocument(table_format="markdown") + data = converter.to_dict() + assert data == { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "markdown"}, + } + + converter = DOCXToDocument(table_format="csv") + data = converter.to_dict() + assert data == { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "csv"}, + } + + converter = DOCXToDocument(table_format=DOCXTableFormat.MARKDOWN) + data = converter.to_dict() + assert data == { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "markdown"}, + } + + converter = DOCXToDocument(table_format=DOCXTableFormat.CSV) + data = converter.to_dict() + assert data == { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "csv"}, + } + + def test_from_dict(self): + data = { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "csv"}, + } + converter = DOCXToDocument.from_dict(data) + assert converter.table_format == DOCXTableFormat.CSV + + def test_from_dict_custom_parameters(self): + data = { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "markdown"}, + } + converter = DOCXToDocument.from_dict(data) + assert converter.table_format == DOCXTableFormat.MARKDOWN + + def test_from_dict_invalid_table_format(self): + data = { + "type": "haystack.components.converters.docx.DOCXToDocument", + "init_parameters": {"table_format": "invalid_format"}, + } + with pytest.raises(ValueError, match="Unknown table format 'invalid_format'"): + DOCXToDocument.from_dict(data) + + def test_from_dict_empty_init_parameters(self): + data = {"type": "haystack.components.converters.docx.DOCXToDocument", "init_parameters": {}} + converter = DOCXToDocument.from_dict(data) + assert converter.table_format == DOCXTableFormat.CSV + + def test_pipeline_serde(self): + pipeline = Pipeline() + converter = 
DOCXToDocument(table_format=DOCXTableFormat.MARKDOWN) + pipeline.add_component("converter", converter) + + pipeline_str = pipeline.dumps() + assert "haystack.components.converters.docx.DOCXToDocument" in pipeline_str + assert "table_format" in pipeline_str + assert "markdown" in pipeline_str + + new_pipeline = Pipeline.loads(pipeline_str) + new_converter = new_pipeline.get_component("converter") + assert isinstance(new_converter, DOCXToDocument) + assert new_converter.table_format == DOCXTableFormat.MARKDOWN + def test_run(self, test_files_path, docx_converter): """ Test if the component runs correctly @@ -48,6 +139,113 @@ def test_run(self, test_files_path, docx_converter): ), } + def test_run_with_table(self, test_files_path): + """ + Test if the component runs correctly + """ + docx_converter = DOCXToDocument(table_format=DOCXTableFormat.MARKDOWN) + paths = [test_files_path / "docx" / "sample_docx.docx"] + output = docx_converter.run(sources=paths) + docs = output["documents"] + assert len(docs) == 1 + assert "Donald Trump" in docs[0].content ## :-) + assert docs[0].meta.keys() == {"file_path", "docx"} + assert docs[0].meta == { + "file_path": str(paths[0]), + "docx": DOCXMetadata( + author="Saha, Anirban", + category="", + comments="", + content_status="", + created="2020-07-14T08:14:00+00:00", + identifier="", + keywords="", + language="", + last_modified_by="Saha, Anirban", + last_printed=None, + modified="2020-07-14T08:16:00+00:00", + revision=1, + subject="", + title="", + version="", + ), + } + # let's now detect that the table markdown is correctly added and that order of elements is correct + content_parts = docs[0].content.split("\n\n") + table_index = next(i for i, part in enumerate(content_parts) if "| This | Is | Just a |" in part) + # check that natural order of the document is preserved + assert any("Donald Trump" in part for part in content_parts[:table_index]), "Text before table not found" + assert any( + "Now we are in Page 2" in part for part in content_parts[table_index + 1 :] + ), "Text after table not found" + + @pytest.mark.parametrize("table_format", ["markdown", "csv"]) + def test_table_between_two_paragraphs(self, test_files_path, table_format): + docx_converter = DOCXToDocument(table_format=table_format) + paths = [test_files_path / "docx" / "sample_docx_3.docx"] + output = docx_converter.run(sources=paths) + + content = output["documents"][0].content + + paragraphs_one = content.find("Table: AI Use Cases in Different Industries") + paragraphs_two = content.find("Paragraph 2:") + table = content[ + paragraphs_one + len("Table: AI Use Cases in Different Industries") + 1 : paragraphs_two + ].strip() + + if table_format == "markdown": + split = list(filter(None, table.split("\n"))) + expected_table_header = "| Industry | AI Use Case | Impact |" + expected_last_row = "| Finance | Fraud detection and prevention | Reduced financial losses |" + + assert split[0] == expected_table_header + assert split[-1] == expected_last_row + if table_format == "csv": # CSV format + csv_reader = csv.reader(StringIO(table)) + rows = list(csv_reader) + assert len(rows) == 3 # Header + 2 data rows + assert rows[0] == ["Industry", "AI Use Case", "Impact"] + assert rows[-1] == ["Finance", "Fraud detection and prevention", "Reduced financial losses"] + + @pytest.mark.parametrize("table_format", ["markdown", "csv"]) + def test_table_content_correct_parsing(self, test_files_path, table_format): + docx_converter = DOCXToDocument(table_format=table_format) + paths = [test_files_path / "docx" 
/ "sample_docx_3.docx"] + output = docx_converter.run(sources=paths) + content = output["documents"][0].content + + paragraphs_one = content.find("Table: AI Use Cases in Different Industries") + paragraphs_two = content.find("Paragraph 2:") + table = content[ + paragraphs_one + len("Table: AI Use Cases in Different Industries") + 1 : paragraphs_two + ].strip() + + if table_format == "markdown": + split = list(filter(None, table.split("\n"))) + assert len(split) == 4 + + expected_table_header = "| Industry | AI Use Case | Impact |" + expected_table_top_border = "| ---------- | ------------------------------ | ------------------------- |" + expected_table_row_one = "| Healthcare | Predictive diagnostics | Improved patient outcomes |" + expected_table_row_two = "| Finance | Fraud detection and prevention | Reduced financial losses |" + + assert split[0] == expected_table_header + assert split[1] == expected_table_top_border + assert split[2] == expected_table_row_one + assert split[3] == expected_table_row_two + if table_format == "csv": # CSV format + csv_reader = csv.reader(StringIO(table)) + rows = list(csv_reader) + assert len(rows) == 3 # Header + 2 data rows + + expected_header = ["Industry", "AI Use Case", "Impact"] + expected_row_one = ["Healthcare", "Predictive diagnostics", "Improved patient outcomes"] + expected_row_two = ["Finance", "Fraud detection and prevention", "Reduced financial losses"] + + assert rows[0] == expected_header + assert rows[1] == expected_row_one + assert rows[2] == expected_row_two + def test_run_with_additional_meta(self, test_files_path, docx_converter): paths = [test_files_path / "docx" / "sample_docx_1.docx"] output = docx_converter.run(sources=paths, meta={"language": "it", "author": "test_author"}) diff --git a/test/test_files/docx/sample_docx_3.docx b/test/test_files/docx/sample_docx_3.docx new file mode 100644 index 0000000000..f3100fa9ab Binary files /dev/null and b/test/test_files/docx/sample_docx_3.docx differ