Checksum support, improve control with new cmd line switches #2134
base: main
Changes from 19 commits
@@ -3,11 +3,13 @@
import logging
import os
import re
from enum import Enum
from typing import List, Optional, Union

import fitz  # type: ignore
from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.blob import (
    BlobClient,
    BlobSasPermissions,
    UserDelegationKey,
    generate_blob_sas,
@@ -45,29 +47,65 @@ def __init__(
        self.subscriptionId = subscriptionId
        self.user_delegation_key: Optional[UserDelegationKey] = None

    async def _create_new_blob(self, file: File, container_client: ContainerClient) -> BlobClient:
        with open(file.content.name, "rb") as reopened_file:
            blob_name = BlobManager.blob_name_from_file_name(file.content.name)
            logger.info("Uploading blob for whole file -> %s", blob_name)
            return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)

    async def _file_blob_update_needed(self, blob_client: BlobClient, file: File) -> bool:
        # Get existing blob properties
        blob_properties = await blob_client.get_blob_properties()
        blob_metadata = blob_properties.metadata

        # Check if the md5 values are the same
        file_md5 = file.metadata.get("md5")
        blob_md5 = blob_metadata.get("md5")

        # If the file has an md5 value, check if it is different from the blob
        return bool(file_md5) and file_md5 != blob_md5

    async def upload_blob(self, file: File) -> Optional[List[str]]:
        async with BlobServiceClient(
            account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024
        ) as service_client, service_client.get_container_client(self.container) as container_client:
            if not await container_client.exists():
                await container_client.create_container()

            # Re-open and upload the original file
            # Re-open and upload the original file if the blob does not exist or the md5 values do not match
            class MD5Check(Enum):
                NOT_DONE = 0
                MATCH = 1
                NO_MATCH = 2

            md5_check = MD5Check.NOT_DONE

            # Upload the file to Azure Storage
            # file.url is only None if files are not uploaded yet, for datalake it is set
            if file.url is None:
                with open(file.content.name, "rb") as reopened_file:
                    blob_name = BlobManager.blob_name_from_file_name(file.content.name)
                    logger.info("Uploading blob for whole file -> %s", blob_name)
                    blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True)
                    file.url = blob_client.url
                blob_client = container_client.get_blob_client(file.url)

            if self.store_page_images:
                if not await blob_client.exists():
                    logger.info("Blob %s does not exist, uploading", file.url)
                    blob_client = await self._create_new_blob(file, container_client)
                else:
                    if await self._file_blob_update_needed(blob_client, file):
                        logger.info("Blob %s exists but md5 values do not match, updating", file.url)
                        md5_check = MD5Check.NO_MATCH
                        # Upload the file with the updated metadata
                        with open(file.content.name, "rb") as data:
                            await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata)
                    else:
                        logger.info("Blob %s exists and md5 values match, skipping upload", file.url)
                        md5_check = MD5Check.MATCH
                file.url = blob_client.url

            if md5_check != MD5Check.MATCH and self.store_page_images:
                if os.path.splitext(file.content.name)[1].lower() == ".pdf":
                    return await self.upload_pdf_blob_images(service_client, container_client, file)
                else:
                    logger.info("File %s is not a PDF, skipping image upload", file.content.name)

            return None

    def get_managedidentity_connectionstring(self):
        return f"ResourceId=/subscriptions/{self.subscriptionId}/resourceGroups/{self.resourceGroup}/providers/Microsoft.Storage/storageAccounts/{self.account};"
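_file_blob_update_needed above compares an md5 entry in file.metadata against the blob's stored metadata, but nothing in this hunk computes that digest. A minimal sketch of how it could be produced for a local file (the helper name and the wiring comment are assumptions, not part of this diff):

import hashlib


def compute_md5(path: str) -> str:
    """Hypothetical helper: hex MD5 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as file_handle:
        for chunk in iter(lambda: file_handle.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Assumed wiring: populate the metadata that _file_blob_update_needed later compares.
# file.metadata = {"md5": compute_md5(file.content.name)}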
@@ -93,7 +131,20 @@ async def upload_pdf_blob_images(

        for i in range(page_count):
            blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i)
            logger.info("Converting page %s to image and uploading -> %s", i, blob_name)

            blob_client = container_client.get_blob_client(blob_name)
            if await blob_client.exists():
                # Get existing blob properties
                blob_properties = await blob_client.get_blob_properties()
                blob_metadata = blob_properties.metadata

                # Check if the md5 values are the same
                file_md5 = file.metadata.get("md5")
                blob_md5 = blob_metadata.get("md5")
                if file_md5 == blob_md5:
                    continue  # document already uploaded

            logger.debug("Converting page %s to image and uploading -> %s", i, blob_name)

            doc = fitz.open(file.content.name)
            page = doc.load_page(i)

Review comment (on the blob_client.exists() check): Do we need this additional check here, given the check that happens in the function that calls this code?
@@ -120,15 +171,15 @@ async def upload_pdf_blob_images(
            new_img.save(output, format="PNG")
            output.seek(0)

            blob_client = await container_client.upload_blob(blob_name, output, overwrite=True)
            await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata)
            if not self.user_delegation_key:
                self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time)

            if blob_client.account_name is not None:
            if container_client.account_name is not None:
                sas_token = generate_blob_sas(
                    account_name=blob_client.account_name,
                    container_name=blob_client.container_name,
                    blob_name=blob_client.blob_name,
                    account_name=container_client.account_name,
                    container_name=container_client.container_name,
                    blob_name=blob_name,
                    user_delegation_key=self.user_delegation_key,
                    permission=BlobSasPermissions(read=True),
                    expiry=expiry_time,

Review comment (on the switch from blob_client to container_client): Am curious why the need to go from blob_client to container_client here? Does it matter?
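The SAS token produced above is not consumed inside this hunk; presumably it is appended to the image blob's URL to form the readable URI that upload_pdf_blob_images returns for image embeddings. A sketch of that final step, continuing the snippet above under that assumption (the sas_uris accumulator name is hypothetical):

# Hypothetical continuation: build the readable SAS URI for the uploaded page image.
sas_uri = f"{container_client.url}/{blob_name}?{sas_token}"
sas_uris.append(sas_uri)  # assumed list returned by upload_pdf_blob_images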
@@ -22,11 +22,11 @@ async def parse_file(
    if processor is None:
        logger.info("Skipping '%s', no parser found.", file.filename())
        return []
    logger.info("Ingesting '%s'", file.filename())
    logger.debug("Ingesting '%s'", file.filename())

    pages = [page async for page in processor.parser.parse(content=file.content)]
    logger.info("Splitting '%s' into sections", file.filename())
    logger.debug("Splitting '%s' into sections", file.filename())
    if image_embeddings:
        logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.")
        logger.debug("Each page will be split into smaller chunks of text, but images will be of the entire page.")
    sections = [
        Section(split_page, content=file, category=category) for split_page in processor.splitter.split_pages(pages)
    ]

Review comment (on the logging level changes): Looks like you changed the levels to debug, does the current output feel too verbose? I find it helpful.
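If the concern is verbosity rather than the messages themselves, one option is to keep the calls at debug and let a verbosity switch decide what the scripts logger shows; a small sketch, assuming a hypothetical verbose flag:

import logging

verbose = True  # e.g. parsed from a hypothetical --verbose command line switch

logging.basicConfig(level=logging.WARNING)  # keep third-party loggers quiet
logger = logging.getLogger("scripts")
logger.setLevel(logging.DEBUG if verbose else logging.INFO)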
@@ -44,6 +44,7 @@ def __init__(
        blob_manager: BlobManager,
        search_info: SearchInfo,
        file_processors: dict[str, FileProcessor],
        ignore_checksum: bool,
        document_action: DocumentAction = DocumentAction.Add,
        embeddings: Optional[OpenAIEmbeddings] = None,
        image_embeddings: Optional[ImageEmbeddings] = None,
@@ -55,6 +56,7 @@ def __init__(
        self.blob_manager = blob_manager
        self.file_processors = file_processors
        self.document_action = document_action
        self.ignore_checksum = ignore_checksum
        self.embeddings = embeddings
        self.image_embeddings = image_embeddings
        self.search_analyzer_name = search_analyzer_name
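The new ignore_checksum argument implies a matching command line switch (the PR title mentions new switches); a sketch of how prepdocs could expose it, with a hypothetical flag spelling:

import argparse

parser = argparse.ArgumentParser(description="Ingest documents into the search index (sketch)")
# Hypothetical flag name; the switch actually added by this PR may differ.
parser.add_argument(
    "--ignorechecksum",
    action="store_true",
    help="Re-process files even when the stored md5 checksum matches the local file",
)
args = parser.parse_args()

# Assumed wiring into the strategy constructed elsewhere in prepdocs:
# file_strategy = FileStrategy(..., ignore_checksum=args.ignorechecksum)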
@@ -77,29 +79,65 @@ async def run(self):
        search_manager = SearchManager(
            self.search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings
        )
        doccount = self.list_file_strategy.count_docs()
        logger.info(f"Processing {doccount} files")
        processed_count = 0
        if self.document_action == DocumentAction.Add:
            files = self.list_file_strategy.list()
            async for file in files:
                try:
                    sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
                    if sections:
                        blob_sas_uris = await self.blob_manager.upload_blob(file)
                        blob_image_embeddings: Optional[List[List[float]]] = None
                        if self.image_embeddings and blob_sas_uris:
                            blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
                        await search_manager.update_content(sections, blob_image_embeddings, url=file.url)
                    if self.ignore_checksum or not await search_manager.file_exists(file):
                        sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
                        if sections:
                            blob_sas_uris = await self.blob_manager.upload_blob(file)
                            blob_image_embeddings: Optional[List[List[float]]] = None
                            if self.image_embeddings and blob_sas_uris:
                                blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
                            await search_manager.update_content(
                                sections=sections, file=file, image_embeddings=blob_image_embeddings
                            )
                finally:
                    if file:
                        file.close()
                processed_count += 1
                if processed_count % 10 == 0:
                    remaining = max(doccount - processed_count, 1)
                    logger.info(f"{processed_count} processed, {remaining} documents remaining")

        elif self.document_action == DocumentAction.Remove:
            doccount = self.list_file_strategy.count_docs()
            paths = self.list_file_strategy.list_paths()
            async for path in paths:
                await self.blob_manager.remove_blob(path)
                await search_manager.remove_content(path)
                processed_count += 1
                if processed_count % 10 == 0:
                    remaining = max(doccount - processed_count, 1)
                    logger.info(f"{processed_count} removed, {remaining} documents remaining")

        elif self.document_action == DocumentAction.RemoveAll:
            await self.blob_manager.remove_blob()
            await search_manager.remove_content()

    async def process_file(self, file, search_manager):
        try:
            sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
            if sections:
                blob_sas_uris = await self.blob_manager.upload_blob(file)
                blob_image_embeddings: Optional[List[List[float]]] = None
                if self.image_embeddings and blob_sas_uris:
                    blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
                await search_manager.update_content(
                    sections=sections, file=file, image_embeddings=blob_image_embeddings
                )
        finally:
            if file:
                file.close()

    async def remove_file(self, path, search_manager):
        await self.blob_manager.remove_blob(path)
        await search_manager.remove_content(path)

Review comment (on process_file and remove_file): These two functions seem to be unused currently - perhaps your thought was to refactor the code above to call these two functions. I'll remove them for now to make the diff smaller.
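The Add branch now calls search_manager.file_exists(file), which is not part of this diff. Roughly, such a check could query the index for an existing document with the same source file and checksum; the sketch below is illustrative only, with assumed field names (sourcefile, md5) and a synchronous client where the real SearchManager is async:

from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient


def file_exists(search_client: SearchClient, filename: str, md5: str) -> bool:
    """Hypothetical: report whether the index already holds sections for this file and checksum."""
    results = search_client.search(
        search_text="",
        filter=f"sourcefile eq '{filename}' and md5 eq '{md5}'",
        top=1,
    )
    return any(True for _ in results)


# Example call with placeholder endpoint, index name, and key:
# client = SearchClient("https://<service>.search.windows.net", "<index>", AzureKeyCredential("<key>"))
# already_indexed = file_exists(client, "report.pdf", "0123456789abcdef0123456789abcdef")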
class UploadUserFileStrategy:
    """
@@ -124,7 +162,7 @@ async def add_file(self, file: File):
            logging.warning("Image embeddings are not currently supported for the user upload feature")
        sections = await parse_file(file, self.file_processors)
        if sections:
            await self.search_manager.update_content(sections, url=file.url)
            await self.search_manager.update_content(sections=sections, file=file)

    async def remove_file(self, filename: str, oid: str):
        if filename is None or filename == "":
@@ -9,9 +9,9 @@
from typing import IO, AsyncGenerator, Dict, List, Optional, Union

from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.filedatalake.aio import (
    DataLakeServiceClient,
)
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.storage.filedatalake.aio import DataLakeServiceClient

logger = logging.getLogger("scripts")
@@ -22,10 +22,17 @@ class File:
    This file might contain access control information about which users or groups can access it
    """

    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None):
    def __init__(
        self,
        content: IO,
        acls: Optional[dict[str, list]] = None,
        url: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ):
        self.content = content
        self.acls = acls or {}
        self.url = url
        self.metadata = metadata or {}

    def filename(self):
        return os.path.basename(self.content.name)
@@ -59,6 +66,10 @@ async def list_paths(self) -> AsyncGenerator[str, None]:
        if False:  # pragma: no cover - this is necessary for mypy to type check
            yield

    def count_docs(self) -> int:
        raise NotImplementedError  # pragma: no cover - concrete strategies override this


class LocalListFileStrategy(ListFileStrategy):
    """
@@ -110,6 +121,22 @@ def check_md5(self, path: str) -> bool:

        return False

    def count_docs(self) -> int:
        """
        Return the number of files that match the path pattern.
        """
        return sum(1 for _ in self._list_paths_sync(self.path_pattern))

    def _list_paths_sync(self, path_pattern: str):
        """
        Synchronous version of _list_paths to be used for counting files.
        """
        for path in glob(path_pattern):
            if os.path.isdir(path):
                yield from self._list_paths_sync(f"{path}/*")
            else:
                yield path
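For illustration, counting the local documents before ingestion could then look like this (assuming the constructor takes the glob pattern that count_docs reads from self.path_pattern):

# Hypothetical usage of the new counting helper on the local strategy defined above.
strategy = LocalListFileStrategy(path_pattern="data/*")
print(f"{strategy.count_docs()} files match the pattern")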
class ADLSGen2ListFileStrategy(ListFileStrategy):
    """
@@ -168,10 +195,33 @@ async def list(self) -> AsyncGenerator[File, None]:
                            acls["oids"].append(acl_parts[1])
                        if acl_parts[0] == "group" and "r" in acl_parts[2]:
                            acls["groups"].append(acl_parts[1])
                yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url)
                properties = await file_client.get_file_properties()
                yield File(
                    content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata
                )
            except Exception as data_lake_exception:
                logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file")
                try:
                    os.remove(temp_file_path)
                except Exception as file_delete_exception:
                    logger.error(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}")

    def count_docs(self) -> int:
        """
        Return the number of blobs in the specified folder within the Azure Blob Storage container.
        """
        # Create a BlobServiceClient using account URL and credentials
        service_client = BlobServiceClient(
            account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net",
            credential=DefaultAzureCredential(),
        )

        # Get the container client
        container_client = service_client.get_container_client(self.data_lake_filesystem)

        # Count blobs within the specified folder
        if self.data_lake_path != "/":
            return sum(1 for _ in container_client.list_blobs(name_starts_with=self.data_lake_path))
        else:
            return sum(1 for _ in container_client.list_blobs())

Review comment (on the BlobServiceClient construction): Can you use BlobServiceClient to interact with a DataLake Storage account? That's what you seem to be doing here, but I didn't realize that was possible.
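Blob-endpoint clients generally do work against ADLS Gen2 accounts thanks to multi-protocol access, but if mixing clients is a concern, roughly the same count can be taken through the Data Lake API itself; a sketch using the synchronous filedatalake client, mirroring the attribute values used above (illustration only, not the PR's code):

from azure.identity import DefaultAzureCredential
from azure.storage.filedatalake import DataLakeServiceClient


def count_paths(account: str, filesystem: str, folder: str) -> int:
    """Hypothetical alternative: count files through the Data Lake endpoint instead of the blob endpoint."""
    service_client = DataLakeServiceClient(
        account_url=f"https://{account}.dfs.core.windows.net",
        credential=DefaultAzureCredential(),
    )
    filesystem_client = service_client.get_file_system_client(filesystem)
    paths = filesystem_client.get_paths(path=None if folder == "/" else folder)
    return sum(1 for p in paths if not p.is_directory)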
Review comment: I don't see logic in the main prepdocs for setting the md5 of the local file metadata, I only see that in ADLS2. How does this work when not using the ADLS2 strategy?
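For illustration, one way the local flow could supply that value is for the local listing strategy to compute and attach the digest when it yields each File; the subclass below is a sketch, not part of this PR:

import hashlib


class ChecksumLocalListFileStrategy(LocalListFileStrategy):
    """Hypothetical variant: attach an md5 to locally listed files, mirroring the ADLS Gen2 metadata."""

    async def list(self):
        async for path in self.list_paths():
            with open(path, "rb") as file_handle:
                digest = hashlib.md5(file_handle.read()).hexdigest()
            yield File(content=open(path, "rb"), metadata={"md5": digest})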