Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: xuanzic <[email protected]>
  • Loading branch information
xuanzic committed Aug 2, 2024
1 parent 002d8f9 commit 1ca19ac
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1327,19 +1327,23 @@ def build_train_valid_test_datasets_blend(self):

def build_train_valid_test_datasets(self):
logging.info('Building Neva datasets.')

if isinstance(self.cfg.data.data_path, (list, ListConfig)):
if len(self.cfg.data.data_path) > 1:
# Only consider data blending if there are multiple dataset paths
if self.cfg.data.get('concat_sampling_probabilities') is None:
logging.warning("No sampling probabilities provided. Defaulting to uniform sampling.")
self.cfg.data.concat_sampling_probabilities = [1 / len(self.cfg.data.data_path)] * len(self.cfg.data.data_path)
self.cfg.data.concat_sampling_probabilities = [1 / len(self.cfg.data.data_path)] * len(
self.cfg.data.data_path
)
elif sum(self.cfg.data.concat_sampling_probabilities) != 1:
raise ValueError("Concat_sampling_probabilities must sum up to 1.")
return self.build_train_valid_test_datasets_blend()
elif len(self.cfg.data.data_path) == 1:
elif len(self.cfg.data.data_path) == 1:
if self.cfg.data.concat_sampling_probabilities is not None:
logging.warning("Using sampling probabilities with a single dataset has no effect. Defaulting to None and not using blend dataset.")
logging.warning(
"Using sampling probabilities with a single dataset has no effect. Defaulting to None and not using blend dataset."
)
self.cfg.data.concat_sampling_probabilities = None
self.cfg.data.data_path = self.cfg.data.data_path[0]
else:
Expand All @@ -1351,7 +1355,7 @@ def build_train_valid_test_datasets(self):

if self.cfg.data.get("packed_sequence", False):
assert self.cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence"

self._train_ds = NevaPackedSeqDatatset(
self.cfg.data.data_path, self.cfg.mm_cfg.vision_encoder.get("crop_size")
)
Expand Down

0 comments on commit 1ca19ac

Please sign in to comment.