Commit 6150584
initial commit on adding blenddataset for neva
Signed-off-by: Vivian Chen <[email protected]>
Vivian Chen committed Aug 1, 2024
1 parent 425d5dd commit 6150584
Showing 2 changed files with 108 additions and 10 deletions.
15 changes: 10 additions & 5 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -1000,6 +1000,8 @@ def __len__(self):
return len(self.list_data_dict)

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
if isinstance(i, np.integer):
i = int(i)
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
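The np.integer cast added above appears to be what lets this dataset sit behind a blended sampler: index arrays built on the NumPy side yield np.integer values, and isinstance(np.int64(0), int) is False, so without the cast the isinstance(i, int) branch would not wrap sources in a list. A minimal standalone illustration (not part of the commit):

import numpy as np

i = np.int64(3)
print(isinstance(i, int))       # False: NumPy scalar ints are not Python ints
print(isinstance(int(i), int))  # True once coerced, so `sources = [sources]` runs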
@@ -1186,7 +1188,6 @@ class NevaDataset(LazySupervisedDataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, data_path: str, tokenizer, multimodal_cfg: dict, data_cfg: dict):

if data_path.endswith(".json"):
super(NevaDataset, self).__init__(data_path, tokenizer, multimodal_cfg, data_cfg)

@@ -1309,18 +1310,22 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
return batch


def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict:
def make_supervised_data_module(tokenizer, image_processor, model_cfg, data_file) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
data_cfg = model_cfg.data
mm_cfg = model_cfg.mm_cfg
add_extra_token = 1
if getattr(model_cfg, 'no_seqlen_plus_one_input_tokens', False):
add_extra_token = 0
crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224))

if not data_cfg.get("data_path"):
data_path = data_file
else:
data_path = data_cfg.data_path
# use blend
train_dataset = NevaDataset(
tokenizer=tokenizer,
data_path=data_cfg.data_path,
data_path=data_path,
multimodal_cfg=dict(
is_multimodal=data_cfg.is_multimodal,
sep_image_conv_front=data_cfg.sep_image_conv_front,
@@ -1349,7 +1354,7 @@ def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict:
)

return dict(train_dataset=train_dataset, eval_dataset=train_dataset)


class NevaPackedSeqDatatset(Dataset):
def __init__(self, data_path: str, crop_size: Tuple[int, int] = (224, 224)):
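For orientation, a minimal sketch (not part of the commit; tokenizer, image_processor, and model_cfg are assumed to already exist) of how the new data_file argument is consumed once per blend entry, mirroring the model-side loop added below; when cfg.data.data_path is unset, the file passed here becomes that dataset's data_path:

# Hypothetical caller-side loop; the real call site is in neva_model.py below.
train_datasets = []
for data_file in model_cfg.data.data_file_names:
    ds_dict = make_supervised_data_module(
        tokenizer=tokenizer,
        image_processor=image_processor,
        model_cfg=model_cfg,
        data_file=data_file,
    )
    train_datasets.append(ds_dict["train_dataset"])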
103 changes: 98 additions & 5 deletions nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -20,8 +20,9 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
from einops import rearrange, reduce, repeat
from omegaconf.dictconfig import DictConfig
from omegaconf import DictConfig, ListConfig
from pkg_resources import packaging
from pytorch_lightning.trainer.trainer import Trainer
from transformers import CLIPVisionModel, SiglipVisionModel
@@ -33,6 +34,10 @@
NevaPackedSeqDatatset,
make_supervised_data_module,
)
from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
get_datasets_weights_and_num_samples,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import (
CLIPVisionTransformer,
MegatronCLIPModel,
@@ -473,7 +478,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
if config.architectures[0] == "CLIPVisionModel" or config.architectures[0] == "CLIPModel":
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
@@ -483,7 +488,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
elif config.architectures[0] == "SiglipVisionModel":
elif config.architectures[0] == "SiglipVisionModel" or config.architectures[0] == "SiglipModel":
vision_encoder = SiglipVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
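The widened architecture checks in this hunk mean a full dual-tower checkpoint can be pointed at directly. As a quick sanity check (standalone, not part of the commit), the Hugging Face config of a stock CLIP checkpoint reports "CLIPModel", while a vision-tower-only export reports "CLIPVisionModel", and the vision classes load just the vision weights from either:

from transformers import AutoConfig, CLIPVisionModel

# "openai/clip-vit-large-patch14" is used here only as a familiar example checkpoint.
config = AutoConfig.from_pretrained("openai/clip-vit-large-patch14")
print(config.architectures)  # ['CLIPModel']

# Loading only the vision tower from the full CLIP checkpoint; the text-tower
# weights are ignored (transformers warns about the unused weights).
vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")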
@@ -1205,7 +1210,8 @@ def setup(self, stage=None):
else:
# TODO: consider adding a ModelPT guard to check if model is being restored.
# allowing restored models to optionally setup datasets
self.build_train_valid_test_datasets()
#self.build_train_valid_test_datasets()
self.build_train_valid_test_datasets_blend()
self.setup_training_data(self.cfg.data)
self.setup_validation_data(self.cfg.data)
self.setup_test_data(self.cfg.data)
@@ -1224,6 +1230,93 @@ def setup(self, stage=None):

if self.cfg.get('transformer_engine', False):
self.setup_transformer_engine_tp_groups()

def build_train_valid_test_datasets_blend(self):
logging.info('Building Blending Neva datasets.')

train_datasets = []
valid_datasets = []

data_cfg = self.cfg.data
is_packed_sequence = data_cfg.get("packed_sequence", False)

if is_packed_sequence:
assert self.cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence"

# Check if concat_sampling_probabilities is properly set
if data_cfg.get('concat_sampling_probabilities') is None or not isinstance(
data_cfg.concat_sampling_probabilities, ListConfig
):
raise ValueError("concat_sampling_probabilities must be a ListConfig with the same number of entries as data_file_names.")

if len(data_cfg.concat_sampling_probabilities) != len(data_cfg.data_file_names):
raise ValueError(
f"concat_sampling_probabilities must be of the same size as data_file_names. "
f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.data_file_names)}"
)

for data_file in data_cfg.data_file_names:
if is_packed_sequence:
train_dataset = NevaPackedSeqDatatset(
data_file, self.cfg.mm_cfg.vision_encoder.get("crop_size")
)
valid_dataset = NevaPackedSeqDatatset(
data_file, self.cfg.mm_cfg.vision_encoder.get("crop_size")
)
else:
ds_dict = make_supervised_data_module(
tokenizer=self.tokenizer,
image_processor=(
self.model.module.image_processor if hasattr(self.model, "module") else self.model.image_processor
),
model_cfg=self.cfg,
data_file=data_file,
)
train_dataset = ds_dict["train_dataset"]
valid_dataset = ds_dict["eval_dataset"]

train_datasets.append(train_dataset)
valid_datasets.append(valid_dataset)

# Create BlendableDataset for training
if self.trainer.max_steps is None or self.trainer.max_steps <= 0:
raise ValueError(f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}')

num_train_samples = self.trainer.max_steps * data_cfg.global_batch_size
_, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(
data_prefix=[weight for pair in zip(data_cfg.concat_sampling_probabilities, data_cfg.data_file_names) for weight in pair],
num_samples=[num_train_samples]
)
num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset])

logging.info(f"Number of train datasets: {len(train_datasets)}")
logging.info(f"Lengths of train datasets: {[len(ds) for ds in train_datasets]}")
logging.info(f"concat_sampling_probabilities: {data_cfg.concat_sampling_probabilities}")
logging.info(f"num_train_samples_after_blend: {num_train_samples_after_blend}")

self._train_ds = BlendableDataset(
datasets=train_datasets,
weights=data_cfg.concat_sampling_probabilities,
size=num_train_samples_after_blend
)

self._validation_ds = BlendableDataset(
datasets=valid_datasets,
weights=data_cfg.concat_sampling_probabilities,
size=num_train_samples_after_blend
)


logging.info(f'Length of train dataset: {len(self._train_ds)}')
logging.info(f'Length of validation dataset: {len(self._validation_ds)}')


######### Use ConcatDataset instead of BlendableDataset##########
# self._train_ds = ConcatDataset(train_datasets)
# self._validation_ds = ConcatDataset(valid_datasets)
##################################################################

return self._train_ds, self._validation_ds
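As a reference for wiring this path up, here is a minimal sketch of the data config it expects, expressed with OmegaConf since cfg.data arrives as a DictConfig; the file names, probabilities, and batch size are hypothetical and only the keys touched above are shown:

from omegaconf import OmegaConf

# Hypothetical blend of two instruction-tuning files; the probabilities list
# must have one entry per file, which the validation above enforces.
data_cfg = OmegaConf.create(
    {
        "data_file_names": ["blend_part_a.json", "blend_part_b.json"],
        "concat_sampling_probabilities": [0.7, 0.3],
        "global_batch_size": 128,
        "packed_sequence": False,
    }
)

# With trainer.max_steps = 1000, num_train_samples = 1000 * 128 = 128000, and
# get_datasets_weights_and_num_samples splits that total across the two
# datasets roughly in the 0.7 / 0.3 ratio before BlendableDataset is built.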

def build_train_valid_test_datasets(self):
logging.info('Building Neva datasets.')
@@ -1286,7 +1379,7 @@ def build_pretraining_data_loader(
raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
else:
raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"')

collate_func = DataCollatorForSupervisedDataset(self.cfg, self.tokenizer)
return torch.utils.data.DataLoader(
dataset,
