
Commit

Merge branch 'ilongin/266-refactor-from-storage' into ilongin/329-refactor-storages

ilongin committed Sep 2, 2024
2 parents 50be970 + 62045cd commit 2e125ae
Showing 10 changed files with 134 additions and 102 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -45,7 +45,8 @@ dependencies = [
"datamodel-code-generator>=0.25",
"Pillow>=10.0.0,<11",
"msgpack>=1.0.4,<2",
"psutil"
"psutil",
"huggingface_hub"
]

[project.optional-dependencies]
3 changes: 3 additions & 0 deletions src/datachain/client/fsspec.py
@@ -87,6 +87,7 @@ def __init__(
def get_implementation(url: str) -> type["Client"]:
from .azure import AzureClient
from .gcs import GCSClient
from .hf import HfClient
from .local import FileClient
from .s3 import ClientS3

@@ -104,6 +105,8 @@ def get_implementation(url: str) -> type["Client"]:
return AzureClient
if protocol == FileClient.protocol:
return FileClient
if protocol == HfClient.protocol:
return HfClient

raise NotImplementedError(f"Unsupported protocol: {protocol}")

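For context, a rough sketch of how the dispatch above behaves once HfClient is registered. This is not part of the diff; the URIs are made up, and it assumes get_implementation is callable directly on Client, as its signature suggests.

from datachain.client import Client

# "hf://" URIs now resolve to the new HfClient; unrecognized schemes still raise.
client_cls = Client.get_implementation("hf://datasets/some-user/some-dataset")
print(client_cls.protocol)  # "hf"

try:
    Client.get_implementation("foo://bucket/key")  # hypothetical, unregistered scheme
except NotImplementedError as exc:
    print(exc)  # Unsupported protocol: foo
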
47 changes: 47 additions & 0 deletions src/datachain/client/hf.py
@@ -0,0 +1,47 @@
import os
import posixpath
from typing import Any, cast

from huggingface_hub import HfFileSystem

from datachain.lib.file import File
from datachain.node import Entry

from .fsspec import Client


class HfClient(Client):
FS_CLASS = HfFileSystem
PREFIX = "hf://"
protocol = "hf"

@classmethod
def create_fs(cls, **kwargs) -> HfFileSystem:
if os.environ.get("HF_TOKEN"):
kwargs["token"] = os.environ["HF_TOKEN"]

return cast(HfFileSystem, super().create_fs(**kwargs))

def convert_info(self, v: dict[str, Any], path: str) -> Entry:
return Entry.from_file(
path=path,
size=v["size"],
version=v["last_commit"].oid,
etag=v.get("blob_id", ""),
last_modified=v["last_commit"].date,
)

def info_to_file(self, v: dict[str, Any], path: str) -> File:
return File(
path=path,
size=v["size"],
version=v["last_commit"].oid,
etag=v.get("blob_id", ""),
last_modified=v["last_commit"].date,
)

async def ls_dir(self, path):
return self.fs.ls(path, detail=True)

def rel_path(self, path):
return posixpath.relpath(path, self.name)
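
A quick sketch of what the new client builds on; this is not part of the commit. HfFileSystem comes from huggingface_hub, the dataset path is hypothetical, and the detail dicts shown are what ls_dir() returns and convert_info()/info_to_file() consume above.

import os

from huggingface_hub import HfFileSystem

# HfClient.create_fs() above forwards HF_TOKEN to the filesystem when it is set;
# the equivalent direct construction looks like this.
fs = HfFileSystem(token=os.environ.get("HF_TOKEN"))

# ls(..., detail=True) is what ls_dir() wraps; convert_info()/info_to_file()
# read size, last_commit.oid/.date and blob_id from these entries.
for info in fs.ls("datasets/some-user/some-dataset", detail=True):
    print(info["name"], info["size"])
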
26 changes: 21 additions & 5 deletions src/datachain/lib/arrow.py
@@ -7,7 +7,9 @@
from pyarrow.dataset import dataset
from tqdm import tqdm

from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import File, IndexedFile
from datachain.lib.model_store import ModelStore
from datachain.lib.udf import Generator

if TYPE_CHECKING:
@@ -59,7 +61,13 @@ def process(self, file: File):
vals = list(record.values())
if self.output_schema:
fields = self.output_schema.model_fields
vals = [self.output_schema(**dict(zip(fields, vals)))]
vals_dict = {}
for (field, field_info), val in zip(fields.items(), vals):
if ModelStore.is_pydantic(field_info.annotation):
vals_dict[field] = field_info.annotation(**val) # type: ignore[misc]
else:
vals_dict[field] = val
vals = [self.output_schema(**vals_dict)]
if self.source:
yield [IndexedFile(file=file, index=index), *vals]
else:
@@ -95,15 +103,15 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
if not column:
column = f"c{default_column}"
default_column += 1
dtype = arrow_type_mapper(field.type) # type: ignore[assignment]
if field.nullable:
dtype = arrow_type_mapper(field.type, column) # type: ignore[assignment]
if field.nullable and not ModelStore.is_pydantic(dtype):
dtype = Optional[dtype] # type: ignore[assignment]
output[column] = dtype

return output


def arrow_type_mapper(col_type: pa.DataType) -> type: # noqa: PLR0911
def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa: PLR0911
"""Convert pyarrow types to basic types."""
from datetime import datetime

@@ -123,7 +131,15 @@ def arrow_type_mapper(col_type: pa.DataType) -> type: # noqa: PLR0911
return str
if pa.types.is_list(col_type):
return list[arrow_type_mapper(col_type.value_type)] # type: ignore[return-value, misc]
if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
if pa.types.is_struct(col_type):
type_dict = {}
for field in col_type:
dtype = arrow_type_mapper(field.type, field.name)
if field.nullable and not ModelStore.is_pydantic(dtype):
dtype = Optional[dtype] # type: ignore[assignment]
type_dict[field.name] = dtype
return dict_to_data_model(column, type_dict)
if pa.types.is_map(col_type):
return dict
if isinstance(col_type, pa.lib.DictionaryType):
return arrow_type_mapper(col_type.value_type) # type: ignore[return-value]
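
A small illustration of the new struct handling, mirroring the unit test added further down; the column name passed in is arbitrary and only names the generated model.

from typing import Optional

import pyarrow as pa

from datachain.lib.arrow import arrow_type_mapper

# Struct columns now map to a generated Pydantic model instead of plain dict.
point = arrow_type_mapper(pa.struct({"x": pa.int32(), "y": pa.string()}), "point")

# Nullable struct members become Optional fields on the generated model.
assert point.model_fields["x"].annotation == Optional[int]
assert point.model_fields["y"].annotation == Optional[str]
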
2 changes: 1 addition & 1 deletion src/datachain/lib/dc.py
@@ -381,7 +381,7 @@ def from_storage(
in_memory=in_memory,
)
.gen(
list_bucket(list_uri, **session.catalog.client_config),
list_bucket(list_uri, client_config=session.catalog.client_config),
output={f"{object_name}": File},
)
.save(list_dataset_name, listing=True)
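
Put together with the new client, an end-to-end sketch of what this change enables; the URI is hypothetical and the call below is not part of the diff.

from datachain.lib.dc import DataChain

# from_storage() above now drives listing through
# list_bucket(list_uri, client_config=session.catalog.client_config).
chain = DataChain.from_storage("hf://datasets/some-user/some-dataset")
print(chain.count())
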
100 changes: 7 additions & 93 deletions src/datachain/lib/listing.py
@@ -1,16 +1,13 @@
import asyncio
import posixpath
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import Iterator
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Callable, Optional

from botocore.exceptions import ClientError
from fsspec.asyn import get_loop
from sqlalchemy.sql.expression import true

from datachain.asyn import iter_over_async
from datachain.client import Client
from datachain.error import ClientError as DataChainClientError
from datachain.lib.file import File
from datachain.query.schema import Column
from datachain.sql.functions import path as pathfunc
@@ -19,105 +16,22 @@
if TYPE_CHECKING:
from datachain.lib.dc import DataChain


ResultQueue = asyncio.Queue[Optional[Sequence[File]]]

DELIMITER = "/" # Path delimiter
FETCH_WORKERS = 100
LISTING_TTL = 4 * 60 * 60 # cached listing lasts 4 hours
LISTING_PREFIX = "lst__" # listing datasets start with this name


async def _fetch_dir(client, prefix, result_queue) -> set[str]:
path = f"{client.name}/{prefix}"
infos = await client.ls_dir(path)
files = []
subdirs = set()
for info in infos:
full_path = info["name"]
subprefix = client.rel_path(full_path)
if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
continue
if info["type"] == "directory":
subdirs.add(subprefix)
else:
files.append(client.info_to_file(info, subprefix))
if files:
await result_queue.put(files)
return subdirs


async def _fetch(
client, start_prefix: str, result_queue: ResultQueue, fetch_workers
) -> None:
loop = get_loop()

queue: asyncio.Queue[str] = asyncio.Queue()
queue.put_nowait(start_prefix)

async def process(queue) -> None:
while True:
prefix = await queue.get()
try:
subdirs = await _fetch_dir(client, prefix, result_queue)
for subdir in subdirs:
queue.put_nowait(subdir)
except Exception:
while not queue.empty():
queue.get_nowait()
queue.task_done()
raise

finally:
queue.task_done()

try:
workers: list[asyncio.Task] = [
loop.create_task(process(queue)) for _ in range(fetch_workers)
]

# Wait for all fetch tasks to complete
await queue.join()
# Stop the workers
excs = []
for worker in workers:
if worker.done() and (exc := worker.exception()):
excs.append(exc)
else:
worker.cancel()
if excs:
raise excs[0]
except ClientError as exc:
raise DataChainClientError(
exc.response.get("Error", {}).get("Message") or exc,
exc.response.get("Error", {}).get("Code"),
) from exc
finally:
# This ensures the progress bar is closed before any exceptions are raised
result_queue.put_nowait(None)


async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
"""Recursively goes through dir tree and yields files"""
result_queue: ResultQueue = asyncio.Queue()
loop = get_loop()
main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
while (files := await result_queue.get()) is not None:
for f in files:
yield f

await main_task


def list_bucket(uri: str, fetch_workers=FETCH_WORKERS, **kwargs) -> Callable:
def list_bucket(uri: str, client_config=None) -> Callable:
"""
Function that returns another generator function that yields File objects
from bucket where each File represents one bucket entry.
"""

def list_func() -> Iterator[File]:
client, path = Client.parse_url(uri, None, **kwargs) # type: ignore[arg-type]
yield from iter_over_async(_scandir(client, path, fetch_workers), get_loop())
config = client_config or {}
client, path = Client.parse_url(uri, None, **config) # type: ignore[arg-type]
for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
for entry in entries:
yield entry.to_file(client.uri)

return list_func

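For reference, a minimal sketch of the refactored list_bucket() in isolation, mirroring the call sites in from_storage() and the functional test below; the URI and client config are hypothetical.

from datachain.lib.dc import DataChain
from datachain.lib.listing import list_bucket

uri = "s3://my-bucket/images"  # hypothetical prefix
dc = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD).gen(
    file=list_bucket(uri, client_config={"anon": True})  # config keys are fsspec-specific
)
print(dc.count())  # number of File entries listed under the prefix
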
13 changes: 13 additions & 0 deletions src/datachain/node.py
@@ -4,6 +4,7 @@
import attrs

from datachain.cache import UniqueId
from datachain.lib.file import File
from datachain.storage import StorageURI
from datachain.utils import TIME_ZERO, time_to_str

@@ -189,6 +190,18 @@ def parent(self):
return ""
return split[0]

def to_file(self, source: str) -> File:
return File(
source=source,
path=self.path,
size=self.size,
version=self.version,
etag=self.etag,
is_latest=self.is_latest,
last_modified=self.last_modified,
location=self.location,
)


def get_path(parent: str, name: str):
return f"{parent}/{name}" if parent else name
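
A brief sketch of how the new Entry.to_file() is used; the values are made up and mirror what HfClient.convert_info() produces, and list_bucket() above calls entry.to_file(client.uri) on exactly these objects.

from datetime import datetime, timezone

from datachain.node import Entry

entry = Entry.from_file(
    path="datasets/some-user/some-dataset/data.csv",  # hypothetical listing entry
    size=1024,
    version="abc123",
    etag="def456",
    last_modified=datetime.now(timezone.utc),
)

# to_file() attaches the client URI as the File source.
file = entry.to_file("hf://")  # hypothetical source URI
print(file.source, file.path, file.size)
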
2 changes: 1 addition & 1 deletion tests/func/test_listing.py
@@ -10,7 +10,7 @@ def test_listing_generator(cloud_test_catalog, cloud_type):
uri = f"{ctc.src_uri}/cats"

dc = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD).gen(
file=list_bucket(uri, **ctc.catalog.client_config)
file=list_bucket(uri, client_config=ctc.catalog.client_config)
)
assert dc.count() == 2

1 change: 1 addition & 0 deletions tests/test_query_e2e.py
@@ -51,6 +51,7 @@
dogs-and-cats/cat.1001.jpg
"""
),
"listing": True,
},
{
"command": (
39 changes: 38 additions & 1 deletion tests/unit/lib/test_arrow.py
@@ -3,13 +3,15 @@

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from datachain.lib.arrow import (
ArrowGenerator,
arrow_type_mapper,
schema_to_output,
)
from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import File, IndexedFile


@@ -55,6 +57,34 @@ def test_arrow_generator_no_source(tmp_path, catalog):
assert o[1] == text


def test_arrow_generator_output_schema(tmp_path, catalog):
ids = [12345, 67890, 34, 0xF0123]
texts = ["28", "22", "we", "hello world"]
dicts = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}, {"a": 7, "b": 8}]
df = pd.DataFrame({"id": ids, "text": texts, "dict": dicts})
table = pa.Table.from_pandas(df)

name = "111.parquet"
pq_path = tmp_path / name
pq.write_table(table, pq_path)
stream = File(path=pq_path.as_posix(), source="file:///")
stream._set_stream(catalog, caching_enabled=False)

output_schema = dict_to_data_model("", schema_to_output(table.schema))
func = ArrowGenerator(output_schema=output_schema)
objs = list(func.process(stream))

assert len(objs) == len(ids)
for index, (o, id, text, dict) in enumerate(zip(objs, ids, texts, dicts)):
assert isinstance(o[0], IndexedFile)
assert isinstance(o[0].file, File)
assert o[0].index == index
assert o[1].id == id
assert o[1].text == text
assert o[1].dict.a == dict["a"]
assert o[1].dict.b == dict["b"]


@pytest.mark.parametrize(
"col_type,expected",
(
@@ -72,7 +102,6 @@ def test_arrow_generator_no_source(tmp_path, catalog):
(pa.date32(), datetime),
(pa.string(), str),
(pa.large_string(), str),
(pa.struct({"x": pa.int32(), "y": pa.string()}), dict),
(pa.map_(pa.string(), pa.int32()), dict),
(pa.dictionary(pa.int64(), pa.string()), str),
(pa.list_(pa.string()), list[str]),
@@ -82,6 +111,14 @@ def test_arrow_type_mapper(col_type, expected):
assert arrow_type_mapper(col_type) == expected


def test_arrow_type_mapper_struct():
col_type = pa.struct({"x": pa.int32(), "y": pa.string()})
fields = arrow_type_mapper(col_type).model_fields
assert list(fields.keys()) == ["x", "y"]
dtypes = [field.annotation for field in fields.values()]
assert dtypes == [Optional[int], Optional[str]]


def test_arrow_type_error():
col_type = pa.union(
[pa.field("a", pa.binary(10)), pa.field("b", pa.string())],