Skip to content

Commit

Permalink
Use selection in slice/slice_block
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Aug 15, 2024
1 parent 269af49 commit de42f94
Showing 1 changed file with 23 additions and 38 deletions.
61 changes: 23 additions & 38 deletions python/metatensor-operations/metatensor/operations/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,15 @@

def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock:
if axis == "samples":
# only keep the same names as `labels`
all_samples = block.samples.view(labels.names)
# create an arrays of bools indicating which samples indices to keep
samples_mask = _dispatch.bool_array_like(
[all_samples.entry(i) in labels for i in range(len(all_samples))],
block.values,
)
new_values = _dispatch.mask(block.values, 0, samples_mask)
new_samples = Labels(
block.samples.names,
_dispatch.mask(block.samples.values, 0, samples_mask),
)
selected = block.samples.select(labels)

bool_array = _dispatch.bool_array_like([], block.properties.values)
mask = _dispatch.zeros_like(bool_array, [len(block.samples)])
mask[selected] = True

new_block = TensorBlock(
values=new_values,
samples=new_samples,
values=block.values[selected],
samples=Labels(block.samples.names, block.samples.values[selected]),
components=block.components,
properties=block.properties,
)
Expand All @@ -38,11 +31,11 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock:
# sample_map contains at position old_sample the index of the
# corresponding new sample
sample_map = _dispatch.int_array_like(
int_list=[-1] * len(samples_mask),
like=samples_mask,
int_list=[-1] * len(block.samples),
like=block.samples.values,
)
last = 0
for i, picked in enumerate(samples_mask):
for i, picked in enumerate(mask):
if picked:
sample_map[i] = last
last += 1
Expand All @@ -53,15 +46,15 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock:

sample_column = gradient.samples.column("sample")
if not isinstance(gradient.samples.values, TorchTensor) and isinstance(
samples_mask, TorchTensor
mask, TorchTensor
):
# Torch complains if `sample_column` is numpy since it tries to convert
# it to a Tensor, but the numpy array is read-only. Making a copy
# removes the read-only marker
sample_column = sample_column.copy()

# Create a samples filter for the Gradient TensorBlock
grad_samples_mask = samples_mask[_dispatch.to_index_array(sample_column)]
grad_samples_mask = mask[_dispatch.to_index_array(sample_column)]

new_grad_samples_values = _dispatch.mask(
gradient.samples.values, 0, grad_samples_mask
Expand Down Expand Up @@ -100,20 +93,14 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock:

else:
assert axis == "properties"
# only keep the same names as `labels`
all_properties = block.properties.view(list(labels.names))
# create an arrays of bools indicating which samples indices to keep
properties_mask = _dispatch.bool_array_like(
[all_properties.entry(i) in labels for i in range(len(all_properties))],
block.values,
)
new_values = _dispatch.mask(
block.values, len(block.values.shape) - 1, properties_mask
)
new_properties = Labels(
block.properties.names,
_dispatch.mask(block.properties.values, 0, properties_mask),
)

selected = block.properties.select(labels)
bool_array = _dispatch.bool_array_like([], block.properties.values)
mask = _dispatch.zeros_like(bool_array, [len(block.properties)])
mask[selected] = True

new_values = _dispatch.mask(block.values, len(block.values.shape) - 1, mask)
new_properties = Labels(block.properties.names, block.properties.values[mask])

new_block = TensorBlock(
values=new_values,
Expand All @@ -129,18 +116,16 @@ def _slice_block(block: TensorBlock, axis: str, labels: Labels) -> TensorBlock:

assert axis == "properties"
new_grad_values = _dispatch.mask(
gradient.values, len(gradient.values.shape) - 1, properties_mask
gradient.values, len(gradient.values.shape) - 1, mask
)
new_grad_samples = gradient.samples

# Add sliced gradient to the TensorBlock
new_block.add_gradient(
parameter=parameter,
gradient=TensorBlock(
values=new_grad_values,
samples=new_grad_samples,
samples=gradient.samples,
components=gradient.components,
properties=new_block.properties,
properties=new_properties,
),
)

Expand Down

0 comments on commit de42f94

Please sign in to comment.