diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index e547b5cb..8268ca3f 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -66,8 +66,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: mod._parameters[attr] = value # type: ignore[assignment] elif hasattr(mod, '_buffers') and attr in mod._buffers: mod._buffers[attr] = value - elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator] - mod._meta_parameters[attr] = value # type: ignore[operator,index] + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: + mod._meta_parameters[attr] = value else: setattr(mod, attr, value) # pylint: enable=protected-access