OptimizableBatch and stress relaxations #718

Merged · 126 commits · Dec 3, 2024
Changes from 118 commits
Commits (126)
d8fd81e
remove r_edges, radius, max_neigh and add deprecation warning
lbluque Mar 21, 2024
81bf8a2
edit typing and dont use dicts as default
lbluque Mar 22, 2024
c2bb916
use super() and remove overkill deprecation warning
lbluque Mar 22, 2024
d11e72a
Merge branch 'main' of https://github.com/Open-Catalyst-Project/ocp i…
lbluque Apr 3, 2024
15610f8
Merge branch 'main' of https://github.com/Open-Catalyst-Project/ocp i…
lbluque Apr 19, 2024
479e7af
set implemented_properties from config
lbluque Apr 19, 2024
858a782
make determine step a method
lbluque Apr 23, 2024
06b232c
allow calculator to operate on batches
lbluque Apr 23, 2024
32030c3
only update if old config is used
lbluque Apr 24, 2024
1c1bea8
reshape properties
lbluque Apr 24, 2024
360656c
no test classes in ase calculator
lbluque Apr 24, 2024
b6c640e
yaml load fix
lbluque May 7, 2024
7f2746a
fix Subset of metadata
lbluque May 15, 2024
df7989e
use mappingproxy
lbluque May 15, 2024
9726df7
Merge branch 'main' into calculator_updates
lbluque May 16, 2024
b16d572
expressive import
lbluque May 16, 2024
6c636e5
Merge branch 'main' into calculator_updates
lbluque May 24, 2024
f5a358a
remove duplicated code
lbluque May 28, 2024
708efb0
optimizable batch class for ase compatible batch relaxations
lbluque May 28, 2024
9d44c3b
fix optimizable batch
lbluque May 28, 2024
5d89cf2
optimizable goodies
lbluque May 28, 2024
cae838c
apply force constraints
lbluque May 29, 2024
74a5347
use optimizable batch instead and remove torchcalc
lbluque May 29, 2024
55f9205
update ml relaxations to use optimizable batch correctly
lbluque May 29, 2024
163428b
force_consistent check for ASE compat
lbluque May 29, 2024
9d48f08
force_consistent check for ASE compat
lbluque May 29, 2024
7034d20
check force_consistent
lbluque May 29, 2024
100e768
init docs in lbfgs
lbluque May 29, 2024
8b919d3
unitcellfilter for batch relaxations
lbluque Jun 1, 2024
3a70b71
ruff
lbluque Jun 1, 2024
7d7686e
UnitCellOptimizable as child class instead of filter
lbluque Jun 2, 2024
e8fbb88
allow running unit cell relaxations
lbluque Jun 2, 2024
5a9b389
ruff
lbluque Jun 2, 2024
37ffe6a
no grad in run_relaxations
lbluque Jun 2, 2024
0b887fc
make batched_dot and determine_step methods
lbluque Jun 2, 2024
4a2c6c5
imports
lbluque Jun 3, 2024
e7ab81b
rename to optimizableunitcellbatch
lbluque Jun 3, 2024
dde4e58
allow passing energy and forces explicitly to batch to atoms
lbluque Jun 4, 2024
28f8155
check convergence in optimizable and allow passing general results to…
lbluque Jun 4, 2024
ff08208
relaxation test
lbluque Jun 4, 2024
609d0dc
unit tests
lbluque Jun 4, 2024
d09c041
move update mask to optimizable
lbluque Jun 5, 2024
c9f5632
use energy instead of y
lbluque Jun 5, 2024
02403aa
all setting/getting positions and convergence in optimizable
lbluque Jun 6, 2024
bcaa4ad
more (unfinished) tests
lbluque Jun 6, 2024
71a0ab6
Merge branch 'main' into stress-relaxations
lbluque Jun 6, 2024
008bb1f
backwards compatible test
lbluque Jun 6, 2024
155c515
minor fixes
lbluque Jun 7, 2024
34ca2be
code cleanup
lbluque Jun 8, 2024
20a2da8
add/fix tests
lbluque Jun 18, 2024
bc422d1
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Jun 20, 2024
693602d
fix lbfgs
lbluque Jun 21, 2024
06b3a6b
assert using norm
lbluque Jun 21, 2024
8ab940a
Merge branch 'main' into stress-relaxations
lbluque Jun 24, 2024
5636f36
Merge branch 'main' into stress-relaxations
lbluque Jun 25, 2024
3e6b39f
Merge branch 'main' into stress-relaxations
lbluque Jun 26, 2024
464f9b1
Merge pull request #749 from FAIR-Chem/main
zulissimeta Jul 7, 2024
aed6b31
Merge branch 'main' into stress-relaxations
lbluque Jul 9, 2024
9a81fff
add eps to masked batches if using ASE optimizers
lbluque Jul 10, 2024
a455d76
match iterations from previous implementation
lbluque Jul 11, 2024
95a0496
Merge branch 'stress-relaxations' of https://github.com/FAIR-Chem/fai…
lbluque Jul 11, 2024
1b71182
use float64 for forces
lbluque Jul 11, 2024
316398b
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Jul 12, 2024
879754a
float32
lbluque Jul 12, 2024
d7f926c
use energy_relaxed instead of y_relaxed
lbluque Jul 14, 2024
d8506ee
energy_relaxed and more explicit error msg
lbluque Jul 14, 2024
6753942
default to batch_size 1 if not set in config
lbluque Jul 14, 2024
9f9c4a0
merge with upstream
lbluque Jul 14, 2024
638404c
keep float64 training
lbluque Jul 15, 2024
ac78372
rename y_relaxed -> energy_relaxed
lbluque Jul 26, 2024
63470ba
Merge remote-tracking branch 'origin/main' into stress-relaxations
lbluque Jul 26, 2024
5cdd560
Merge branch 'main' into stress-relaxations
lbluque Aug 2, 2024
3578ceb
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Aug 7, 2024
63408a3
rm expcell batch
lbluque Aug 7, 2024
adf55f0
convenience commit from no_experimental_resolve
zulissimeta Aug 7, 2024
5e027d8
use numatoms tensor for cell factor
lbluque Aug 9, 2024
3ccb173
remove positions tests (wrapping atoms gives different results)
lbluque Aug 9, 2024
cab1fee
allow wrapping positions in batch to atoms
lbluque Aug 9, 2024
50465e9
Merge branch 'main' into stress-relaxations
lbluque Aug 9, 2024
bf619b7
fix test
lbluque Aug 9, 2024
ee6b021
wrap_positions in batch_to_atoms
lbluque Aug 15, 2024
23067c7
take a2g properties from model
lbluque Aug 15, 2024
8768dbb
test lbfgs traj writes
lbluque Aug 15, 2024
c7debdb
remove comments
lbluque Aug 15, 2024
00554c7
use model generate graph
lbluque Aug 15, 2024
a81d639
fix cell_factor
lbluque Aug 16, 2024
08f9f11
fix using model in ddp
lbluque Aug 16, 2024
649444d
Merge branch 'main' into stress-relaxations
lbluque Aug 19, 2024
8d028c8
Merge branch 'main' into stress-relaxations
lbluque Aug 20, 2024
c693829
fix r_edges in OCPcalculator
lbluque Aug 22, 2024
4a51006
write initial and final structure if save_full is false
lbluque Aug 22, 2024
2e118ea
Merge branch 'main' into stress-relaxations
lbluque Aug 22, 2024
88b5a4b
Merge branch 'stress-relaxations' of https://github.com/FAIR-Chem/fai…
lbluque Aug 26, 2024
9474860
check unique atoms saved in trajectory
lbluque Aug 26, 2024
55ddaab
tighter tol
lbluque Aug 26, 2024
054ef59
update ASE release comment
lbluque Aug 26, 2024
a539f72
remove cumulative mask option
lbluque Aug 26, 2024
6e5eda3
remove left over cumulative_mask
lbluque Aug 28, 2024
10b313a
fix batching when sids as str
lbluque Aug 29, 2024
88386bb
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Sep 5, 2024
25339aa
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Sep 9, 2024
91286f2
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Sep 17, 2024
b30c06e
do not try to fetch energy and forces if no explicit results
lbluque Sep 20, 2024
4ce57b0
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into str…
lbluque Sep 20, 2024
5fcbd42
accept Path objects
lbluque Sep 20, 2024
7c6244f
clean up setting defaults
lbluque Sep 20, 2024
efcb0af
expose ml_relax in relaxation
lbluque Sep 20, 2024
20a0319
force set r_pbc True
lbluque Sep 20, 2024
831e3b4
make relax_opt optional
lbluque Sep 21, 2024
134db29
Merge branch 'main' into stress-relaxations
lbluque Sep 30, 2024
d015f98
no ema on inference only
lbluque Oct 2, 2024
9696e86
define ema none to avoid issues
lbluque Oct 4, 2024
c97c4d0
lower force threshold to make sure test does not converge
lbluque Oct 16, 2024
c3b7b2b
clean up exception msg
lbluque Oct 16, 2024
e783bf6
Merge branch 'main' into stress-relaxations
lbluque Nov 14, 2024
209a1ed
allow strings in batch
lbluque Nov 18, 2024
2311511
remove device argument from lbfgs
lbluque Nov 18, 2024
23d38ff
minor cleanup
lbluque Nov 18, 2024
4329063
fix optimizable import
zulissimeta Nov 22, 2024
aea2694
Merge branch 'main' into stress-relaxations
lbluque Nov 26, 2024
25f2bba
do not pass device in ml_relax
lbluque Dec 2, 2024
b3fd1bd
simplify enforce max neighbors
lbluque Dec 2, 2024
547534d
fix tests (still not testing stress)
lbluque Dec 2, 2024
6842a0f
Merge branch 'main' into stress-relaxations
lbluque Dec 2, 2024
2f69937
pin sphinx autoapi
zulissimeta Dec 3, 2024
097757d
typo in version
zulissimeta Dec 3, 2024
13 changes: 13 additions & 0 deletions src/fairchem/core/common/relaxation/__init__.py
@@ -0,0 +1,13 @@
"""
Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from .ml_relaxation import ml_relax
from .optimizers.optimizable import OptimizableBatch, OptimizableUnitCellBatch

__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
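For orientation (not part of the diff), these exports are what downstream code is expected to import; a minimal sketch using only the names listed in __all__ above:

from fairchem.core.common.relaxation import (
    OptimizableBatch,          # ASE-optimizable wrapper around a data batch
    OptimizableUnitCellBatch,  # variant that also relaxes the unit cell
    ml_relax,                  # batched ML relaxation driver
)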
108 changes: 77 additions & 31 deletions src/fairchem/core/common/relaxation/ase_utils.py
@@ -14,13 +14,15 @@

import copy
import logging
from typing import ClassVar
from types import MappingProxyType
from typing import TYPE_CHECKING

import torch
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.singlepoint import SinglePointCalculator as sp
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.geometry import wrap_positions

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
@@ -33,51 +35,93 @@
from fairchem.core.models.model_registry import model_name_to_local_file
from fairchem.core.preprocessing import AtomsToGraphs

if TYPE_CHECKING:
from pathlib import Path

def batch_to_atoms(batch):
from torch_geometric.data import Batch


# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE = MappingProxyType(
{"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
)

Collaborator: It's a little strange to have these hard-coded. What should a user do implementing new properties?

Collaborator (author): I agree this is strange, but it's the simplest solution I thought of. We can always generalize this to simply "tensor of rank X", but in that case we need a data structure that provides that information for each model output. We could use the properties interface and the properties defined in ASE to clean this up, but I would suggest doing that in a new PR:
https://gitlab.com/ase/ase/-/blob/master/ase/outputs.py
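To show what this reshape map does in practice, here is a small sketch (not part of the diff); it mirrors how OCPCalculator.calculate applies the map further down in this file, with illustrative values:

import numpy as np

from fairchem.core.common.relaxation.ase_utils import ASE_PROP_RESHAPE

# Illustrative values: a flat, system-level "stress" prediction with 9 components.
flat_stress = np.arange(9.0)

# Reshape per the map, then drop the leading singleton system dimension.
reshaped = flat_stress.reshape(ASE_PROP_RESHAPE["stress"]).squeeze()
print(reshaped.shape)  # (3, 3): one rank-2 tensor per system instead of a flat vector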


def batch_to_atoms(
batch: Batch,
results: dict[str, torch.Tensor] | None = None,
wrap_pos: bool = True,
eps: float = 1e-7,
) -> list[Atoms]:
"""Convert a data batch to ase Atoms

Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.

Returns:
list of Atoms
"""
n_systems = batch.natoms.shape[0]
natoms = batch.natoms.tolist()
numbers = torch.split(batch.atomic_numbers, natoms)
fixed = torch.split(batch.fixed.to(torch.bool), natoms)
forces = torch.split(batch.force, natoms)
if results is not None:
results = {
key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
if len(val) == len(batch)
else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
for key, val in results.items()
}

positions = torch.split(batch.pos, natoms)
tags = torch.split(batch.tags, natoms)
cells = batch.cell
energies = batch.energy.view(-1).tolist()

atoms_objects = []
for idx in range(n_systems):
pos = positions[idx].cpu().detach().numpy()
cell = cells[idx].cpu().detach().numpy()

# TODO take pbc from data
if wrap_pos:
pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)

atoms = Atoms(
numbers=numbers[idx].tolist(),
positions=positions[idx].cpu().detach().numpy(),
cell=cell,
positions=pos,
tags=tags[idx].tolist(),
cell=cells[idx].cpu().detach().numpy(),
constraint=FixAtoms(mask=fixed[idx].tolist()),
pbc=[True, True, True],
)
calc = sp(
atoms=atoms,
energy=energies[idx],
forces=forces[idx].cpu().detach().numpy(),
)
atoms.set_calculator(calc)

if results is not None:
calc = SinglePointCalculator(
atoms=atoms, **{key: val[idx] for key, val in results.items()}
)
atoms.set_calculator(calc)

atoms_objects.append(atoms)

return atoms_objects
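A usage sketch for the updated batch_to_atoms signature (not part of the diff); batch and predictions are hypothetical placeholders for a collated data batch and a dict of predicted tensors:

atoms_list = batch_to_atoms(
    batch,
    results={"energy": predictions["energy"], "forces": predictions["forces"]},
    wrap_pos=True,  # wrap positions back into the cell
    eps=1e-7,       # keep slightly negative coordinates from being wrapped across the cell
)
for atoms in atoms_list:
    # served by the SinglePointCalculator attached when `results` is given
    print(atoms.get_potential_energy(), atoms.get_forces().shape)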


class OCPCalculator(Calculator):
implemented_properties: ClassVar[list[str]] = ["energy", "forces"]
"""ASE based calculator using an OCP model"""

_reshaped_props = ASE_PROP_RESHAPE

def __init__(
self,
config_yml: str | None = None,
checkpoint_path: str | None = None,
checkpoint_path: str | Path | None = None,
model_name: str | None = None,
local_cache: str | None = None,
trainer: str | None = None,
cutoff: int = 6,
max_neighbors: int = 50,
cpu: bool = True,
seed: int | None = None,
) -> None:
@@ -96,16 +140,12 @@ def __init__(
Directory to save pretrained model checkpoints.
trainer (str):
OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE.
cutoff (int):
Cutoff radius to be used for data preprocessing.
max_neighbors (int):
Maximum amount of neighbors to store for a given atom.
cpu (bool):
Whether to load and run the model on CPU. Set `False` for GPU.
"""
setup_imports()
setup_logging()
Calculator.__init__(self)
super().__init__()

if model_name is not None:
if checkpoint_path is not None:
@@ -165,9 +205,8 @@ def __init__(
### backwards compatability with OCP v<2.0
config = update_config(config)

# Save config so obj can be transported over network (pkl)
self.config = copy.deepcopy(config)
self.config["checkpoint"] = checkpoint_path
self.config["checkpoint"] = str(checkpoint_path)
del config["dataset"]["src"]

self.trainer = registry.get_trainer_class(config["trainer"])(
@@ -199,14 +238,13 @@ def __init__(
self.trainer.set_seed(seed)

self.a2g = AtomsToGraphs(
max_neigh=max_neighbors,
radius=cutoff,
r_energy=False,
r_forces=False,
r_distances=False,
r_edges=False,
r_pbc=True,
r_edges=not self.trainer.model.otf_graph, # otf graph should not be a property of the model...
)
self.implemented_properties = list(self.config["outputs"].keys())

def load_checkpoint(
self, checkpoint_path: str, checkpoint: dict | None = None
@@ -217,6 +255,8 @@ def load_checkpoint(
Args:
checkpoint_path: string
Path to trained model
checkpoint: dict
A pretrained checkpoint dict
"""
try:
self.trainer.load_checkpoint(
@@ -225,14 +265,20 @@ def load_checkpoint(
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

def calculate(self, atoms: Atoms, properties, system_changes) -> None:
Calculator.calculate(self, atoms, properties, system_changes)
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None:
"""Calculate implemented properties for a single Atoms object or a Batch of them."""
super().calculate(atoms, properties, system_changes)
if isinstance(atoms, Atoms):
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
else:
batch = atoms

predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)

for key in predictions:
_pred = predictions[key]
_pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy()
if key in OCPCalculator._reshaped_props:
_pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze()
self.results[key] = _pred
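A usage sketch for the updated calculator (not part of the diff); the checkpoint path is a placeholder, and the available result keys depend on the outputs configured in the loaded checkpoint:

from ase.build import add_adsorbate, fcc111

from fairchem.core.common.relaxation.ase_utils import OCPCalculator

calc = OCPCalculator(checkpoint_path="/path/to/checkpoint.pt", cpu=True)

slab = fcc111("Cu", size=(2, 2, 3), vacuum=10.0)
add_adsorbate(slab, "O", height=1.2, position="fcc")
slab.calc = calc

print(slab.get_potential_energy())  # "energy" output from the trainer config
print(slab.get_forces().shape)      # (n_atoms, 3)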
117 changes: 79 additions & 38 deletions src/fairchem/core/common/relaxation/ml_relaxation.py
@@ -10,85 +10,126 @@
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from torch_geometric.data import Batch

from fairchem.core.common.typing import assert_is_instance
from fairchem.core.datasets.lmdb_dataset import data_list_collater

from .optimizers.lbfgs_torch import LBFGS, TorchCalc
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
from .optimizers.lbfgs_torch import LBFGS

if TYPE_CHECKING:
from fairchem.core.trainers import BaseTrainer


def ml_relax(
batch,
model,
batch: Batch,
model: BaseTrainer,
steps: int,
fmax: float,
relax_opt,
save_full_traj,
device: str = "cuda:0",
transform=None,
early_stop_batch: bool = False,
relax_opt: dict[str] | None = None,
relax_cell: bool = False,
relax_volume: bool = False,
save_full_traj: bool = True,
transform: torch.nn.Module | None = None,
mask_converged: bool = True,
):
"""
Runs ML-based relaxations.
"""Runs ML-based relaxations.

Args:
batch: object
model: object
steps: int
Max number of steps in the structure relaxation.
fmax: float
Structure relaxation terminates when the max force
of the system is no bigger than fmax.
relax_opt: str
Optimizer and corresponding parameters to be used for structure relaxations.
save_full_traj: bool
Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
batch: a data batch object.
model: a trainer object with model.
steps: Max number of steps in the structure relaxation.
fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax.
relax_opt: Optimizer parameters to be used for structure relaxations.
relax_cell: if true will use stress predictions to relax crystallographic cell.
The model given must predict stress
relax_volume: if true will relax the cell isotropically. the given model must predict stress.
save_full_traj: Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
mask_converged: whether to mask batches where all atoms are below convergence threshold
cumulative_mask: if true, once system is masked then it remains masked even if new predictions give forces
above threshold, ie. once masked always masked. Note if this is used make sure to check convergence with
the same fmax always
"""
relax_opt = relax_opt or {}
# if not pbc is set, ignore it when comparing batches
if not hasattr(batch, "pbc"):
OptimizableBatch.ignored_changes = {"pbc"}

batches = deque([batch])
relaxed_batches = []
while batches:
batch = batches.popleft()
oom = False
ids = batch.sid
calc = TorchCalc(model, transform)

# clone the batch otherwise you can not run batch.to_data_list
# see https://github.com/pyg-team/pytorch_geometric/issues/8439#issuecomment-1826747915
if relax_cell or relax_volume:
optimizable = OptimizableUnitCellBatch(
batch.clone(),
trainer=model,
transform=transform,
mask_converged=mask_converged,
hydrostatic_strain=relax_volume,
)
else:
optimizable = OptimizableBatch(
batch.clone(),
trainer=model,
transform=transform,
mask_converged=mask_converged,
)

# Run ML-based relaxation
traj_dir = relax_opt.get("traj_dir", None)
traj_dir = relax_opt.get("traj_dir")
relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None})

optimizer = LBFGS(
batch,
calc,
maxstep=relax_opt.get("maxstep", 0.2),
memory=relax_opt["memory"],
damping=relax_opt.get("damping", 1.2),
alpha=relax_opt.get("alpha", 80.0),
device=device,
optimizable_batch=optimizable,
save_full_traj=save_full_traj,
traj_dir=Path(traj_dir) if traj_dir is not None else None,
traj_names=ids,
early_stop_batch=early_stop_batch,
**relax_opt,
)

e: RuntimeError | None = None
try:
relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(relaxed_batch)
optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(optimizable.batch)
except RuntimeError as err:
e = err
oom = True
torch.cuda.empty_cache()

if oom:
# move OOM recovery code outside of except clause to allow tensors to be freed.
# move OOM recovery code outside off except clause to allow tensors to be freed.
data_list = batch.to_data_list()
if len(data_list) == 1:
raise assert_is_instance(e, RuntimeError)
logging.info(
f"Failed to relax batch with size: {len(data_list)}, splitting into two..."
)
mid = len(data_list) // 2
batches.appendleft(data_list_collater(data_list[:mid]))
batches.appendleft(data_list_collater(data_list[mid:]))
batches.appendleft(
data_list_collater(data_list[:mid], otf_graph=optimizable.otf_graph)
)
batches.appendleft(
data_list_collater(data_list[mid:], otf_graph=optimizable.otf_graph)
)

# reset for good measure
OptimizableBatch.ignored_changes = {}

relaxed_batch = Batch.from_data_list(relaxed_batches)

# Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str
# it will be incorrectly collated as a list of lists for each batch.
# but we can not use to_data_list in the relaxed batches (since they have been changed, see linked comment above).
# So instead just manually fix it for now. Remove this once pyg dependency is removed
if isinstance(relaxed_batch.sid, list):
relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list]

return Batch.from_data_list(relaxed_batches)
return relaxed_batch
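A usage sketch for the updated ml_relax signature (not part of the diff); trainer and batch are hypothetical placeholders for a loaded fairchem trainer and a collated batch of structures:

from pathlib import Path

from fairchem.core.common.relaxation import ml_relax

relaxed = ml_relax(
    batch=batch,
    model=trainer,
    steps=300,
    fmax=0.02,
    relax_opt={"traj_dir": Path("relaxation_trajs")},  # optimizer options; traj_dir is read above
    relax_cell=True,         # use predicted stress to relax the cell (model must predict stress)
    save_full_traj=False,    # write only the initial and final frames
)
print(relaxed.natoms)  # relaxed systems are collated back into a single Batch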