Commit 4873956
modify Neva to work with video neva inference
xuanzic committed Apr 19, 2024
1 parent b3a3b30 commit 4873956
Showing 3 changed files with 7 additions and 2 deletions.
@@ -687,7 +687,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_

     def get_forward_output_only_func(self):
         def fwd_output_only_func(dataloader_iter, model):
-            batch, _, _ = next(dataloader_iter)
+            batch = next(dataloader_iter)
             extra_arg = {}
             (
                 tokens,
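For context, a minimal runnable sketch (not taken from the commit) of the iterator contract this hunk adjusts: the inference path is assumed to hand the forward function an iterator that yields the batch directly, so the old three-way unpacking no longer matches.

# Minimal sketch; the dict keys and values are hypothetical.
batch_payload = {"tokens": [[1, 2, 3]], "media": None}

old_style_iter = iter([(batch_payload, 0, 0)])  # yields (batch, batch_idx, dataloader_idx)
batch, _, _ = next(old_style_iter)              # old unpacking matched the tuple

new_style_iter = iter([batch_payload])          # yields the batch alone
batch = next(new_style_iter)                    # new code reads it directly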
nemo/collections/multimodal/parts/utils.py (6 changes: 5 additions & 1 deletion)
@@ -320,7 +320,7 @@ def dummy():


 def create_neva_model_and_processor(cfg):
-    from nemo.collections.multimodal.models.neva.neva_model import MegatronNevaModel
+    from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel

     plugins = []
     if cfg.get('cluster_type', None) == 'BCP':
@@ -388,6 +388,7 @@ def create_neva_model_and_processor(cfg):
             (
                 app_state.tensor_model_parallel_rank,
                 app_state.pipeline_model_parallel_rank,
+                app_state.expert_model_parallel_rank,
                 app_state.model_parallel_size,
                 app_state.data_parallel_size,
                 app_state.pipeline_model_parallel_split_rank,
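A small illustration (variable names and values are hypothetical, not from the commit) of why the extra field is needed: the left-hand tuple has to match the arity of whatever the parallel-state helper now returns once an expert-model-parallel rank is included.

# Hypothetical return value standing in for the parallel-state helper's output.
returned_ranks = (0, 0, 0, 1, 1, None)
(
    tensor_model_parallel_rank,
    pipeline_model_parallel_rank,
    expert_model_parallel_rank,  # newly unpacked field
    model_parallel_size,
    data_parallel_size,
    pipeline_model_parallel_split_rank,
) = returned_ranks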
@@ -402,6 +403,9 @@ def create_neva_model_and_processor(cfg):
         checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
         # TODO: This wont work properly (We need to set model.llm.from_pretrained model.vision.from_pretrained to nul)
         model = MegatronNevaModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
+        neva_cfg = OmegaConf.load(cfg.hparams_file)
+        neva_cfg = neva_cfg.cfg

     else:
         raise ValueError("need at least a nemo file or checkpoint dir")
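As a rough sketch of the added hparams handling (file path hypothetical): the Lightning-style hparams YAML referenced by cfg.hparams_file is assumed to nest the model configuration under a top-level `cfg` key, which the new lines unwrap with OmegaConf.

from omegaconf import OmegaConf

# Hypothetical path; in the commit this comes from cfg.hparams_file.
hparams = OmegaConf.load("/path/to/hparams.yaml")
neva_cfg = hparams.cfg  # unwrap the model config node stored under 'cfg'
print(OmegaConf.to_yaml(neva_cfg))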

@@ -414,6 +414,7 @@ def __init__(self, model):
             image_processor=None,
             add_extra_token=add_extra_token,
             context_length=self.cfg.encoder_seq_length,
+            media_type=self.data_cfg.media_type
         )

     def clip_max_len(self, maxlen: int) -> int:
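The new media_type argument forwards the data config's media type into the generation setup, which is what the video NeVA inference path relies on. A hedged sketch of where that value could come from (keys and values assumed, not taken from the commit):

from omegaconf import OmegaConf

# Hypothetical data-config fragment for video inference; the original image path
# would carry media_type: image instead.
data_cfg = OmegaConf.create({"media_type": "video", "num_frames": 8})
print(data_cfg.media_type)  # -> video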
