Skip to content

Commit

Permalink
Preserve points feature_key in queries (#794)
Browse files Browse the repository at this point in the history
* Preserve points feature_key during queries

* add PR number to changelog

* fix docs; add sphinx-autobuild dep

---------

Co-authored-by: Luca Marconato <[email protected]>
  • Loading branch information
quentinblampey and LucaMarconato authored Nov 25, 2024
1 parent 94f0a31 commit 93615b2
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning][].
### Fixed

- Updated deprecated default stages of `pre-commit` #771
- Preserve points `feature_key` during queries #794

## [0.2.5] - 2024-06-11

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dev = [
]
docs = [
"sphinx>=4.5",
"sphinx-autobuild",
"sphinx-book-theme>=1.0.0",
"myst-nb",
"sphinxcontrib-bibtex>=1.0.0",
Expand Down
12 changes: 9 additions & 3 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
points_geopandas_to_dask_dataframe,
)
from spatialdata.models._utils import ValidAxis_t, get_spatial_axes
from spatialdata.models.models import ATTRS_KEY
from spatialdata.transformations.operations import set_transformation
from spatialdata.transformations.transformations import (
Affine,
Expand Down Expand Up @@ -712,9 +713,13 @@ def _(
points_df = p.compute().iloc[bounding_box_indices]
old_transformations = get_transformation(p, get_all=True)
assert isinstance(old_transformations, dict)
feature_key = p.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)

output.append(
PointsModel.parse(
dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy()
dd.from_pandas(points_df, npartitions=1),
transformations=old_transformations.copy(),
feature_key=feature_key,
)
)
if len(output) == 0:
Expand Down Expand Up @@ -925,10 +930,11 @@ def _(
queried_points = points_gdf.loc[joined["index_right"]]
ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True)
transformation = get_transformation(points, target_coordinate_system)
feature_key = points.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)
if "z" in ddf.columns:
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"}, feature_key=feature_key)
else:
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"})
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"}, feature_key=feature_key)
set_transformation(ddf, transformation, target_coordinate_system)
t = get_transformation(ddf, get_all=True)
assert isinstance(t, dict)
Expand Down
3 changes: 1 addition & 2 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,7 @@ def set_channel_names(element: DataArray | DataTree, channel_names: str | list[s
Returns
-------
element
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
"""
from spatialdata.models import Image2DModel, Image3DModel, get_model

Expand Down
7 changes: 5 additions & 2 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,8 +736,11 @@ def _(
elif isinstance(data, dd.DataFrame): # type: ignore[attr-defined]
table = data[[coordinates[ax] for ax in axes]]
table.columns = axes
if feature_key is not None and data[feature_key].dtype.name != "category":
table[feature_key] = data[feature_key].astype(str).astype("category")
if feature_key is not None:
if data[feature_key].dtype.name == "category":
table[feature_key] = data[feature_key]
else:
table[feature_key] = data[feature_key].astype(str).astype("category")
if instance_key is not None:
table[instance_key] = data[instance_key]
for c in [X, Y, Z]:
Expand Down
16 changes: 16 additions & 0 deletions tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ShapesModel,
TableModel,
)
from spatialdata.models.models import ATTRS_KEY
from spatialdata.testing import assert_spatial_data_objects_are_identical
from spatialdata.transformations import Identity, MapAxis, set_transformation
from tests.conftest import _make_points, _make_squares
Expand Down Expand Up @@ -205,6 +206,21 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, mul
if is_3d:
np.testing.assert_allclose(points_element["z"].compute(), original_z)

# the feature_key should be preserved
if not multiple_boxes:
assert (
points_result.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
== points_element.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
)
else:
for result in points_result:
if result is None:
continue
assert (
result.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
== points_element.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
)


def test_query_points_no_points():
"""Points bounding box query with no points in range should
Expand Down

0 comments on commit 93615b2

Please sign in to comment.