query: remove support for saving dataset query with a given name #389

Merged: 1 commit, Sep 4, 2024.
23 changes: 1 addition & 22 deletions src/datachain/catalog/catalog.py
```diff
@@ -1297,19 +1297,6 @@ def create_dataset_from_sources(
 
         return self.get_dataset(name)
 
-    def register_new_dataset(
-        self,
-        source_dataset: DatasetRecord,
-        source_version: int,
-        target_name: str,
-    ) -> DatasetRecord:
-        target_dataset = self.metastore.create_dataset(
-            target_name,
-            query_script=source_dataset.query_script,
-            schema=source_dataset.serialized_schema,
-        )
-        return self.register_dataset(source_dataset, source_version, target_dataset, 1)
-
     def register_dataset(
         self,
         dataset: DatasetRecord,
```

Review comment from @skshetry (Member, Author) on `def register_new_dataset`, Sep 4, 2024:

> This API is no longer used after removing save_as logic from query_wrapper below, so I have removed it.
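If the old behavior is ever needed again, it can be recomposed from the primitives that remain; a hedged sketch using only the calls visible in this diff (the helper name is hypothetical, and the trailing `1` assumes `register_dataset` takes a target version, as the removed code suggests):

```python
# Hypothetical stand-in for the removed Catalog.register_new_dataset;
# not part of datachain's API after this PR.
def register_under_new_name(catalog, source_dataset, source_version, target_name):
    # Create an empty target record carrying over the query script and schema.
    target = catalog.metastore.create_dataset(
        target_name,
        query_script=source_dataset.query_script,
        schema=source_dataset.serialized_schema,
    )
    # Register the source version as version 1 of the new target.
    return catalog.register_dataset(source_dataset, source_version, target, 1)
```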
```diff
@@ -1875,7 +1862,6 @@ def query(
         envs: Optional[Mapping[str, str]] = None,
         python_executable: Optional[str] = None,
         save: bool = False,
-        save_as: Optional[str] = None,
         preview_limit: int = 10,
         preview_offset: int = 0,
         preview_columns: Optional[list[str]] = None,
@@ -1927,7 +1913,6 @@
                 preview_limit,
                 preview_offset,
                 save,
-                save_as,
                 job_id,
             )
         finally:
@@ -1963,7 +1948,7 @@
 
         dataset: Optional[DatasetRecord] = None
         version: Optional[int] = None
-        if save or save_as:
+        if save:
            dataset, version = self.save_result(
                query_script, exec_result, output, version, job_id
            )
```
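Callers of the Python API now control persistence with the `save` flag alone; a minimal sketch of the updated entry point, assuming an existing `Catalog` instance named `catalog` (the script body and the name "my-ds" are illustrative):

```python
# The script names its own output; save_as= no longer exists on query().
query_script = """\
from datachain.query import DatasetQuery
DatasetQuery("s3://bucket-name").save("my-ds")
"""

result = catalog.query(query_script, save=True)
if result.dataset is not None:
    print(result.dataset.name, result.dataset.latest_version)
```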
```diff
@@ -1990,7 +1975,6 @@ def run_query(
         preview_limit: int,
         preview_offset: int,
         save: bool,
-        save_as: Optional[str],
         job_id: Optional[str],
     ) -> tuple[list[str], subprocess.Popen, str]:
         try:
```

```diff
             raise QueryScriptCompileError(
                 f"Query script failed to compile, reason: {exc}"
             ) from exc
-        if save_as and save_as.startswith(QUERY_DATASET_PREFIX):
-            raise ValueError(
-                f"Cannot use {QUERY_DATASET_PREFIX} prefix for dataset name"
-            )
         r, w = os.pipe()
         if os.name == "nt":
             import msvcrt
@@ -2039,7 +2019,6 @@
                 }
             ),
             "DATACHAIN_QUERY_SAVE": "1" if save else "",
-            "DATACHAIN_QUERY_SAVE_AS": save_as or "",
             "PYTHONUNBUFFERED": "1",
             "DATACHAIN_OUTPUT_FD": str(handle),
             "DATACHAIN_JOB_ID": job_id or "",
```
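run_query launches the script in a subprocess and passes intent through environment variables; after this change only the boolean save flag crosses that boundary. A hedged sketch of the surviving contract (the real code also wires up an output pipe, job id, and query params, as the diff shows):

```python
import os
import subprocess
import sys

# Only "save yes/no" is communicated now; DATACHAIN_QUERY_SAVE_AS is gone.
env = {
    **os.environ,
    "DATACHAIN_QUERY_SAVE": "1",  # ask query_wrapper to persist the result
    "PYTHONUNBUFFERED": "1",
}
proc = subprocess.Popen([sys.executable, "query_script.py"], env=env)
proc.wait()
```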
6 changes: 0 additions & 6 deletions src/datachain/cli.py
```diff
@@ -472,9 +472,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     query_parser.add_argument(
         "script", metavar="<script.py>", type=str, help="Filepath for script"
     )
-    query_parser.add_argument(
-        "dataset_name", nargs="?", type=str, help="Save result dataset as"
-    )
     query_parser.add_argument(
         "--parallel",
         nargs="?",
@@ -813,7 +810,6 @@ def show(
 def query(
     catalog: "Catalog",
     script: str,
-    dataset_name: Optional[str] = None,
     parallel: Optional[int] = None,
     limit: int = 10,
     offset: int = 0,
@@ -846,7 +842,6 @@
     result = catalog.query(
         script_content,
         python_executable=python_executable,
-        save_as=dataset_name,
         preview_limit=limit,
         preview_offset=offset,
         preview_columns=columns,
@@ -1042,7 +1037,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
         query(
             catalog,
             args.script,
-            dataset_name=args.dataset_name,
             parallel=args.parallel,
             limit=args.limit,
             offset=args.offset,
```
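For CLI users the migration is mechanical: the optional positional dataset name is gone, so naming moves into the script itself. A hedged sketch (the `datachain query` invocation is inferred from the parser above; script and dataset names are illustrative):

```python
# query_script.py
# Before: datachain query query_script.py my-dataset
# After:  datachain query query_script.py
from datachain.query import DatasetQuery

# The script now decides whether, and under what name, to save.
DatasetQuery("s3://bucket-name").save("my-ds")
```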
29 changes: 1 addition & 28 deletions src/datachain/query/dataset.py
```diff
@@ -1784,39 +1784,12 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
 
     catalog = dataset_query.catalog
     save = bool(os.getenv("DATACHAIN_QUERY_SAVE"))
-    save_as = os.getenv("DATACHAIN_QUERY_SAVE_AS")
 
     is_session_temp_dataset = dataset_query.name and dataset_query.name.startswith(
         dataset_query.session.get_temp_prefix()
     )
 
-    if save_as:
-        if dataset_query.attached:
-            dataset_name = dataset_query.name
-            version = dataset_query.version
-            assert dataset_name, "Dataset name should be provided in attached mode"
-            assert version, "Dataset version should be provided in attached mode"
-
-            dataset = catalog.get_dataset(dataset_name)
-
-            try:
-                target_dataset = catalog.get_dataset(save_as)
-            except DatasetNotFoundError:
-                target_dataset = None
-
-            if target_dataset:
-                dataset = catalog.register_dataset(dataset, version, target_dataset)
-            else:
-                dataset = catalog.register_new_dataset(dataset, version, save_as)
-
-            dataset_query = DatasetQuery(
-                name=dataset.name,
-                version=dataset.latest_version,
-                catalog=catalog,
-            )
-        else:
-            dataset_query = dataset_query.save(save_as)
-    elif save and (is_session_temp_dataset or not dataset_query.attached):
+    if save and (is_session_temp_dataset or not dataset_query.attached):
         name = catalog.generate_query_dataset_name()
         dataset_query = dataset_query.save(name)
 
```
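The wrapper's remaining branch is easiest to read through the two script shapes it distinguishes; an illustrative sketch (semantics inferred from this diff: "attached" means the result is already saved under a real name, and the `ds_query_` prefix for generated names is taken from the removed guard):

```python
from datachain.query import DatasetQuery

# Detached result: with catalog.query(..., save=True), query_wrapper
# persists this under a generated name (assumed ds_query_<...>).
DatasetQuery("s3://bucket-name")

# Attached result: already saved under an explicit name, so the
# wrapper leaves it untouched.
DatasetQuery("s3://bucket-name").save("my-ds")
```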
15 changes: 0 additions & 15 deletions tests/func/test_catalog.py
```diff
@@ -963,21 +963,6 @@ def test_query_fail_to_compile(cloud_test_catalog):
         catalog.query(query_script)
 
 
-def test_query_fail_wrong_dataset_name(cloud_test_catalog):
-    catalog = cloud_test_catalog.catalog
-
-    query_script = """\
-    from datachain.query import DatasetQuery
-    DatasetQuery("s3://bucket-name")
-    """
-    query_script = dedent(query_script)
-
-    with pytest.raises(
-        ValueError, match="Cannot use ds_query_ prefix for dataset name"
-    ):
-        catalog.query(query_script, save_as="ds_query_dataset")
-
-
 def test_query_subprocess_wrong_return_code(mock_popen, cloud_test_catalog):
     mock_popen.configure_mock(returncode=1)
     catalog = cloud_test_catalog.catalog
```
29 changes: 12 additions & 17 deletions tests/func/test_query.py
```diff
@@ -100,23 +100,22 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath,
 from datachain import C
 from datachain.sql.functions.path import name
 
-DatasetQuery({src_uri!r}, catalog=catalog).mutate(name=name(C.path))
+DatasetQuery({src_uri!r}, catalog=catalog).mutate(name=name(C.path)).save("my-ds")
 """
     query_script = setup_catalog(query_script, catalog_info_filepath)
 
     filepath = tmp_path / "query_script.py"
     filepath.write_text(query_script)
 
-    ds_name = "my-dataset"
-    query(catalog, str(filepath), ds_name, columns=["name"])
+    query(catalog, str(filepath), columns=["name"])
     captured = capsys.readouterr()
 
     header, *rows = captured.out.splitlines()
     assert header.strip() == "name"
     name_rows = {row.split()[1] for row in rows}
     assert name_rows == {"cat1", "cat2", "description", "dog1", "dog2", "dog3", "dog4"}
 
-    dataset = catalog.get_dataset(ds_name)
+    dataset = catalog.get_dataset("my-ds")
     assert dataset
     result_job_id = dataset.get_version(dataset.latest_version).job_id
     assert result_job_id
```
```diff
@@ -153,7 +152,7 @@ def test_query_cli_no_dataset_returned(
         QueryScriptRunError,
         match="Last line in a script was not an instance of DataChain",
     ):
-        query(catalog, str(filepath), "my-dataset", columns=["name"])
+        query(catalog, str(filepath), columns=["name"])
 
     latest_job = get_latest_job(catalog.metastore)
     assert latest_job
```


```diff
 @pytest.mark.parametrize(
-    "save,save_as",
-    (
-        (True, None),
-        (None, "my-dataset"),
-        (True, "my-dataset"),
-    ),
+    "save",
+    (True, False),
 )
 @pytest.mark.parametrize("save_dataset", (None, "new-dataset"))
 def test_query(
     save,
-    save_as,
     save_dataset,
     cloud_test_catalog_tmpfile,
     tmp_path,
```

```diff
 """
     query_script = setup_catalog(query_script, catalog_info_filepath)
 
-    result = catalog.query(query_script, save=save, save_as=save_as)
-    if save_as:
-        assert result.dataset.name == save_as
-        assert catalog.get_dataset(save_as)
-    elif save_dataset:
+    result = catalog.query(query_script, save=save)
+    if not save:
+        assert result.dataset is None
+        return
+
+    if save_dataset:
         assert result.dataset.name == save_dataset
         assert catalog.get_dataset(save_dataset)
     else:
```

Review comment from @skshetry (Member, Author) on lines +192 to +194, Sep 4, 2024:

> The previous test never exercised save=False with save_as=None, so this scenario was never hit. I am testing it here only to make the parametrization easier; #360 will require adjusting these tests anyway.