add example script smoke tests (#199)
mattseddon authored Aug 8, 2024
1 parent 6a16e6f commit a6b1411
Showing 10 changed files with 190 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmarks.yml
@@ -11,7 +11,7 @@ env:
   FORCE_COLOR: "1"
 
 jobs:
-  build:
+  run:
     if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-benchmarks') }}
     runs-on: ubuntu-latest
 
29 changes: 29 additions & 0 deletions .github/workflows/tests.yml
@@ -199,3 +199,32 @@ jobs:
           --splits=6 --group=${{ matrix.group }} --durations-path=../../.github/.test_durations
           tests ../datachain/tests
         working-directory: backend/datachain_server
+
+
+  examples:
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 60
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest-16-cores, macos-latest, windows-latest-8-cores]
+        pyv: ['3.9', '3.12']
+        group: ['get_started', 'llm_and_nlp or computer_vision', 'multimodal']
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.pyv }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.pyv }}
+          cache: 'pip'
+
+      - name: Upgrade nox and uv
+        run: |
+          python -m pip install --upgrade 'nox[uv]'
+          nox --version
+          uv --version
+
+      - name: Run examples
+        run: nox -s examples -p ${{ matrix.pyv }} -- -m "${{ matrix.group }}"
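Each matrix cell forwards its `group` value through nox to pytest as a marker expression, so a single cell can be reproduced locally. A minimal sketch, assuming `nox[uv]` is installed and the repo root is the working directory:

# Sketch: reproduce one CI matrix cell locally; the group string is one of
# the matrix values above.
import subprocess

subprocess.run(
    ["nox", "-s", "examples", "-p", "3.12", "--", "-m", "llm_and_nlp or computer_vision"],
    check=True,
)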
24 changes: 13 additions & 11 deletions examples/get_started/common_sql_functions.py
@@ -10,13 +10,13 @@ def num_chars_udf(file):
     return ([],)
 
 
-ds = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/")
-ds.map(num_chars_udf, params=["file"], output={"num_chars": list[str]}).select(
+dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/")
+dc.map(num_chars_udf, params=["file"], output={"num_chars": list[str]}).select(
     "file.path", "num_chars"
 ).show(5)
 
 (
-    ds.mutate(
+    dc.mutate(
         length=string.length(path.name(C("file.path"))),
         parts=string.split(path.name(C("file.path")), literal(".")),
     )
@@ -25,22 +25,24 @@
 )
 
 (
-    ds.mutate(
+    dc.mutate(
         stem=path.file_stem(path.name(C("file.path"))),
         ext=path.file_ext(path.name(C("file.path"))),
     )
     .select("file.path", "stem", "ext")
     .show(5)
 )
 
+
+chain = dc.mutate(
+    a=array.length(string.split(C("file.path"), literal("/"))),
+    b=array.length(string.split(path.name(C("file.path")), literal("0"))),
+)
+
 (
-    ds.mutate(
-        a=array.length(string.split(C("file.path"), literal("/"))),
-        b=array.length(string.split(path.name(C("file.path")), literal("0"))),
-    )
-    .mutate(
-        greatest=greatest(C("a"), C("b")),
-        least=least(C("a"), C("b")),
+    chain.mutate(
+        greatest=greatest(chain.column("a"), C("b")),
+        least=least(chain.column("a"), C("b")),
     )
     .select("a", "b", "greatest", "least")
     .show(10)
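The arithmetic block is now built in two steps: the intermediate chain gets a name, and the follow-up mutate references its columns via `chain.column("a")`, apparently so the freshly added column is resolved against the chain's own schema rather than by a bare `C("a")` name. A minimal sketch of the idiom on toy values (the `from_values` data is hypothetical, not part of this commit):

from datachain import DataChain
from datachain.sql.functions import greatest, least

toy = DataChain.from_values(a=[1, 5, 3], b=[4, 2, 6])  # hypothetical toy chain
toy.mutate(
    hi=greatest(toy.column("a"), toy.column("b")),
    lo=least(toy.column("a"), toy.column("b")),
).show()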
5 changes: 3 additions & 2 deletions examples/get_started/torch-loader.py
@@ -1,5 +1,6 @@
 # pip install Pillow torchvision
 
+import os
 from posixpath import basename
 
 import torch
@@ -11,6 +12,7 @@
 from datachain.torch import label_to_int
 
 STORAGE = "gs://datachain-demo/dogs-and-cats/"
+NUM_EPOCHS = os.getenv("NUM_EPOCHS", "3")
 
 # Define transformation for data preprocessing
 transform = v2.Compose(
@@ -66,8 +68,7 @@ def forward(self, x):
 optimizer = optim.Adam(model.parameters(), lr=0.001)
 
 # Train the model
-num_epochs = 3
-for epoch in range(num_epochs):
+for epoch in range(int(NUM_EPOCHS)):
     for i, data in enumerate(train_loader):
         inputs, labels = data
         optimizer.zero_grad()
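The hard-coded epoch count is now an environment override so the smoke tests can shrink training to one pass (`test_get_started_examples` below sets `NUM_EPOCHS=1`). The pattern in isolation:

import os

# Default to "3" for local runs; the CI smoke test passes NUM_EPOCHS=1.
NUM_EPOCHS = os.getenv("NUM_EPOCHS", "3")
for epoch in range(int(NUM_EPOCHS)):
    ...  # one training pass over train_loader, as in the example above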
2 changes: 1 addition & 1 deletion examples/llm_and_nlp/unstructured-text.py
@@ -8,7 +8,7 @@
 # "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx"
 
 from transformers import pipeline
-from unstructured.partition.auto import partition
+from unstructured.partition.pdf import partition_pdf as partition
 from unstructured.staging.base import convert_to_dataframe
 
 from datachain import C, DataChain
23 changes: 17 additions & 6 deletions examples/multimodal/wds.py
@@ -1,27 +1,38 @@
+import os
+
 from datachain import C, DataChain
 from datachain.lib.webdataset import process_webdataset
 from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
 from datachain.sql.functions import path
 
+IMAGE_TARS = os.getenv(
+    "IMAGE_TARS", "gs://datachain-demo/datacomp-small/shards/000000[0-5]*.tar"
+)
+PARQUET_METADATA = os.getenv(
+    "PARQUET_METADATA", "gs://datachain-demo/datacomp-small/metadata/0020f*.parquet"
+)
+NPZ_METADATA = os.getenv(
+    "NPZ_METADATA", "gs://datachain-demo/datacomp-small/metadata/0020f*.npz"
+)
+
 wds_images = (
-    DataChain.from_storage("gs://datachain-demo/datacomp-small/shards/")
-    .filter(C("file.path").glob("*000000[0-5]*.tar"))  # from *00.tar to *59.tar
+    DataChain.from_storage(IMAGE_TARS)
     .settings(cache=True)
     .gen(laion=process_webdataset(spec=WDSLaion), params="file")
 )
 
 wds_with_pq = (
-    DataChain.from_parquet("gs://datachain-demo/datacomp-small/metadata/0020f*.parquet")
+    DataChain.from_parquet(PARQUET_METADATA)
     .settings(cache=True)
     .merge(wds_images, on="uid", right_on="laion.json.uid", inner=True)
-    .mutate(stem=path.file_stem(path.name(C("source.file.path"))))
+    .mutate(stem=path.file_stem(C("source.file.path")))
 )
 
 res = (
-    DataChain.from_storage("gs://datachain-demo/datacomp-small/metadata/0020f*.npz")
+    DataChain.from_storage(NPZ_METADATA)
     .settings(cache=True)
     .gen(emd=process_laion_meta)
-    .mutate(stem=path.file_stem(path.name(C("emd.file.path"))))
+    .mutate(stem=path.file_stem(C("emd.file.path")))
     .merge(
         wds_with_pq,
         on=["stem", "emd.index"],
6 changes: 4 additions & 2 deletions examples/multimodal/wds_filtered.py
@@ -1,11 +1,13 @@
 import datachain.error
 from datachain import C, DataChain
+from datachain.lib.model_store import ModelStore
 from datachain.lib.webdataset import process_webdataset
-from datachain.lib.webdataset_laion import WDSLaion
+from datachain.lib.webdataset_laion import LaionMeta, WDSLaion
 from datachain.sql import literal
 from datachain.sql.functions import array, greatest, least, string
 
 name = "wds"
+ModelStore.register(LaionMeta)
 try:
     wds = DataChain.from_dataset(name=name)
 except datachain.error.DatasetNotFoundError:
@@ -18,7 +20,6 @@
 )
 
 wds.print_schema()
-wds.show(3)
 
 filtered = (
     wds.filter(string.length(C("laion.txt")) > 5)
@@ -33,6 +34,7 @@
     )
     .save()
 )
+
 filtered.show(3)
 
 print(f"wds count: {wds.count():>6}")
11 changes: 11 additions & 0 deletions noxfile.py
@@ -74,3 +74,14 @@ def dev(session: nox.Session) -> None:
 
     python = os.path.join(venv_dir, "bin/python")
     session.run(python, "-m", "pip", "install", "-e", ".[dev]", external=True)
+
+
+@nox.session(python=["3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"])
+def examples(session: nox.Session) -> None:
+    session.install(".[examples]")
+    session.run(
+        "pytest",
+        "-m",
+        "examples",
+        *session.posargs,
+    )
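The session installs the `examples` extra and forwards `session.posargs` straight to pytest, which is how CI runs `nox -s examples -p <pyv> -- -m "<group>"`. Once the extra is installed, the same selection works without nox; a sketch:

# Sketch: run one marker group directly (assumes `pip install -e '.[examples]'`
# has been run in the repo root).
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pytest", "-m", "multimodal", "tests/examples"],
    check=True,
)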
16 changes: 14 additions & 2 deletions pyproject.toml
@@ -93,6 +93,13 @@ dev = [
     "types-PyYAML",
     "types-requests"
 ]
+examples = [
+    "datachain[tests]",
+    "defusedxml",
+    "accelerate",
+    "unstructured[pdf]",
+    "pdfplumber==0.11.1"
+]
 
 [project.urls]
 Documentation = "https://datachain.dvc.ai"
@@ -110,10 +117,15 @@ namespaces = false
 [tool.setuptools_scm]
 
 [tool.pytest.ini_options]
-addopts = "-rfEs -m 'not benchmark'"
+addopts = "-rfEs -m 'not benchmark and not examples'"
 markers = [
     "benchmark: benchmarks.",
-    "e2e: End-to-end tests"
+    "e2e: End-to-end tests",
+    "examples: All examples",
+    "computer_vision: Computer vision examples",
+    "get_started: Get started examples",
+    "llm_and_nlp: LLM and NLP examples",
+    "multimodal: Multimodal examples"
 ]
 asyncio_mode = "auto"
 filterwarnings = [
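Note the interplay with the markers above: `addopts` now deselects example tests by default, and an explicit `-m` on the command line (which nox appends via posargs) wins as the last marker expression. A sketch:

import pytest

# A bare run applies addopts' "-m 'not benchmark and not examples'" and skips
# every example; an explicit -m overrides it and selects one group.
pytest.main(["-m", "get_started", "tests/examples"])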
97 changes: 97 additions & 0 deletions tests/examples/test_examples.py
@@ -0,0 +1,97 @@
import glob
import os
import subprocess
import sys
from typing import Optional

import pytest

get_started_examples = [
    filename
    for filename in glob.glob("examples/get_started/**/*.py", recursive=True)
    # torch-loader will not finish within an hour on Linux runner
    if "torch" not in filename or os.environ.get("RUNNER_OS") != "Linux"
]

llm_and_nlp_examples = [
    filename
    for filename in glob.glob("examples/llm_and_nlp/**/*.py", recursive=True)
    # no anthropic token
    if "claude" not in filename
]

multimodal_examples = [
    filename
    for filename in glob.glob("examples/multimodal/**/*.py", recursive=True)
    # no OpenAI token
    # and hf download painfully slow
    if "openai" not in filename and "hf" not in filename
]

computer_vision_examples = [
    filename
    for filename in glob.glob("examples/computer_vision/**/*.py", recursive=True)
    # fashion product images tutorial out of scope
    # and hf download painfully slow
    if "image_desc" not in filename and "fashion_product_images" not in filename
]


def smoke_test(example: str, env: Optional[dict] = None):
    try:
        completed_process = subprocess.run(  # noqa: S603
            [sys.executable, example],
            env={**os.environ, **(env or {})},
            capture_output=True,
            cwd=os.path.abspath(os.path.join(__file__, "..", "..", "..")),
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Example failed: {example}")
        print()
        print()
        print("stdout:")
        print(e.stdout.decode("utf-8"))
        print()
        print()
        print("stderr:")
        print(e.stderr.decode("utf-8"))
        pytest.fail("subprocess returned a non-zero exit code")

    assert completed_process.stdout
    assert completed_process.stderr


@pytest.mark.examples
@pytest.mark.get_started
@pytest.mark.parametrize("example", get_started_examples)
def test_get_started_examples(example):
    smoke_test(example, {"NUM_EPOCHS": "1"})


@pytest.mark.examples
@pytest.mark.llm_and_nlp
@pytest.mark.parametrize("example", llm_and_nlp_examples)
def test_llm_and_nlp_examples(example):
    smoke_test(example)


@pytest.mark.examples
@pytest.mark.multimodal
@pytest.mark.parametrize("example", multimodal_examples)
def test_multimodal(example):
    smoke_test(
        example,
        {
            "IMAGE_TARS": "gs://datachain-demo/datacomp-small/shards/00001286.tar",
            "PARQUET_METADATA": "gs://datachain-demo/datacomp-small/metadata/036d6b9ae87a00e738f8fc554130b65b.parquet",
            "NPZ_METADATA": "gs://datachain-demo/datacomp-small/metadata/036d6b9ae87a00e738f8fc554130b65b.npz",
        },
    )


@pytest.mark.examples
@pytest.mark.computer_vision
@pytest.mark.parametrize("example", computer_vision_examples)
def test_computer_vision_examples(example):
    smoke_test(example)
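When a single script regresses, it can be exercised in isolation with the same override the parametrized test applies; a sketch:

# Sketch: smoke-test one example directly, mirroring test_get_started_examples.
smoke_test("examples/get_started/torch-loader.py", {"NUM_EPOCHS": "1"})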
