diff --git a/python/metatensor-operations/metatensor/operations/slice.py b/python/metatensor-operations/metatensor/operations/slice.py index c402fc6c2..75930ecb8 100644 --- a/python/metatensor-operations/metatensor/operations/slice.py +++ b/python/metatensor-operations/metatensor/operations/slice.py @@ -12,22 +12,15 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock: if axis == "samples": - # only keep the same names as `labels` - all_samples = block.samples.view(labels.names) - # create an arrays of bools indicating which samples indices to keep - samples_mask = _dispatch.bool_array_like( - [all_samples.entry(i) in labels for i in range(len(all_samples))], - block.values, - ) - new_values = _dispatch.mask(block.values, 0, samples_mask) - new_samples = Labels( - block.samples.names, - _dispatch.mask(block.samples.values, 0, samples_mask), - ) + selected = block.samples.select(labels) + + bool_array = _dispatch.bool_array_like([], block.properties.values) + mask = _dispatch.zeros_like(bool_array, [len(block.samples)]) + mask[selected] = True new_block = TensorBlock( - values=new_values, - samples=new_samples, + values=block.values[selected], + samples=Labels(block.samples.names, block.samples.values[selected]), components=block.components, properties=block.properties, ) @@ -38,11 +31,11 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock: # sample_map contains at position old_sample the index of the # corresponding new sample sample_map = _dispatch.int_array_like( - int_list=[-1] * len(samples_mask), - like=samples_mask, + int_list=[-1] * len(block.samples), + like=block.samples.values, ) last = 0 - for i, picked in enumerate(samples_mask): + for i, picked in enumerate(mask): if picked: sample_map[i] = last last += 1 @@ -53,7 +46,7 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock: sample_column = gradient.samples.column("sample") if not isinstance(gradient.samples.values, TorchTensor) and isinstance( - samples_mask, TorchTensor + mask, TorchTensor ): # Torch complains if `sample_column` is numpy since it tries to convert # it to a Tensor, but the numpy array is read-only. Making a copy @@ -61,7 +54,7 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock: sample_column = sample_column.copy() # Create a samples filter for the Gradient TensorBlock - grad_samples_mask = samples_mask[_dispatch.to_index_array(sample_column)] + grad_samples_mask = mask[_dispatch.to_index_array(sample_column)] new_grad_samples_values = _dispatch.mask( gradient.samples.values, 0, grad_samples_mask @@ -100,20 +93,14 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock: else: assert axis == "properties" - # only keep the same names as `labels` - all_properties = block.properties.view(list(labels.names)) - # create an arrays of bools indicating which samples indices to keep - properties_mask = _dispatch.bool_array_like( - [all_properties.entry(i) in labels for i in range(len(all_properties))], - block.values, - ) - new_values = _dispatch.mask( - block.values, len(block.values.shape) - 1, properties_mask - ) - new_properties = Labels( - block.properties.names, - _dispatch.mask(block.properties.values, 0, properties_mask), - ) + + selected = block.properties.select(labels) + bool_array = _dispatch.bool_array_like([], block.properties.values) + mask = _dispatch.zeros_like(bool_array, [len(block.properties)]) + mask[selected] = True + + new_values = _dispatch.mask(block.values, len(block.values.shape) - 1, mask) + new_properties = Labels(block.properties.names, block.properties.values[mask]) new_block = TensorBlock( values=new_values, @@ -129,18 +116,16 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock: assert axis == "properties" new_grad_values = _dispatch.mask( - gradient.values, len(gradient.values.shape) - 1, properties_mask + gradient.values, len(gradient.values.shape) - 1, mask ) - new_grad_samples = gradient.samples - # Add sliced gradient to the TensorBlock new_block.add_gradient( parameter=parameter, gradient=TensorBlock( values=new_grad_values, - samples=new_grad_samples, + samples=gradient.samples, components=gradient.components, - properties=new_block.properties, + properties=new_properties, ), )