Skip to content

Commit

Permalink
Fix the serialization of partial functions in nemo 2.0 (#9668)
Browse files Browse the repository at this point in the history
* fix serialization of partial function

* update serialization to handle value.args

Signed-off-by: srabhi <[email protected]>

* add unit test

Signed-off-by: srabhi <[email protected]>

* remove redundant code from unit-test

Signed-off-by: srabhi <[email protected]>

---------

Signed-off-by: srabhi <[email protected]>
  • Loading branch information
sararb authored Jul 17, 2024
1 parent af4f0ed commit f65fea2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
4 changes: 4 additions & 0 deletions nemo/lightning/io/fdl_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"""

import types
from functools import partial

import fiddle as fdl
import libcst as cst
import torch
import torch.nn as nn
Expand Down Expand Up @@ -110,6 +112,8 @@ def enable():
def _modified_serialize(self, value, current_path, all_paths=None):
    """Serialize ``value``, special-casing builtin functions and partials.

    Args:
        self: The serializer instance; must provide ``_pyref`` and
            ``_original_serialize``.
        value: The object to serialize.
        current_path: Path of ``value`` inside the object being serialized;
            forwarded unchanged.
        all_paths: Forwarded unchanged to ``_original_serialize``.

    Returns:
        The result of ``self._pyref`` for builtin functions, otherwise the
        result of ``self._original_serialize``.
    """
    # Builtin functions are emitted as Python references and never reach
    # the original serializer.
    if isinstance(value, types.BuiltinFunctionType):
        return self._pyref(value, current_path)

    to_serialize = value
    if isinstance(value, partial):
        # Rebuild a functools.partial as a fiddle Partial carrying the same
        # callable, positional arguments and keyword arguments.
        to_serialize = fdl.Partial(value.func, *value.args, **value.keywords)

    return self._original_serialize(to_serialize, current_path, all_paths)

serialization.Serialization._original_serialize = serialization.Serialization._serialize
Expand Down
19 changes: 17 additions & 2 deletions tests/lightning/io/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from functools import partial

import pytest
import transformer_engine as te
from pytorch_lightning.loggers import TensorBoardLogger

Expand All @@ -7,8 +10,17 @@
from nemo.lightning import io


def dummy_extra(a, b, c=5):
    """Return ``a + b + c``; ``c`` defaults to 5."""
    first_two = a + b
    return first_two + c


@pytest.fixture
def partial_function_with_pos_and_key_args():
    """Provide a ``functools.partial`` of ``dummy_extra`` with ``a=10`` and ``c=15`` bound.

    One positional and one keyword argument are pre-bound, so callers only
    need to supply ``b``.
    """
    bound = partial(dummy_extra, 10, c=15)
    return bound


class TestLoad:
def test_reload_ckpt(self, tmpdir):
def test_reload_ckpt(self, tmpdir, partial_function_with_pos_and_key_args):
trainer = nl.Trainer(
devices=1,
accelerator="cpu",
Expand All @@ -26,10 +38,13 @@ def test_reload_ckpt(self, tmpdir):
tokenizer=tokenizer,
)

ckpt = io.TrainerContext(model, trainer)
ckpt = io.TrainerContext(model, trainer, extra={"dummy": partial_function_with_pos_and_key_args})
ckpt.io_dump(tmpdir)
loaded = io.load_context(tmpdir)

assert loaded.model.config.seq_length == ckpt.model.config.seq_length
assert loaded.model.__io__.tokenizer.vocab_file.startswith(str(tmpdir))
assert loaded.model.__io__.tokenizer.merges_file.startswith(str(tmpdir))

loaded_func = loaded.extra["dummy"]
assert loaded_func(b=2) == partial_function_with_pos_and_key_args(b=2)

0 comments on commit f65fea2

Please sign in to comment.