diff --git a/gpt2_pretrain.py b/gpt2_pretrain.py
index 5bc61a4..1a45f95 100644
--- a/gpt2_pretrain.py
+++ b/gpt2_pretrain.py
@@ -47,7 +47,7 @@ def main():
     parser = HfArgumentParser((MMModelArguments, MMDataTrainingArguments, MMTrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if training_args.local_rank in [-1, 0]:
+    if training_args.process_index in [-1, 0]:
         from transformers4ime.utils.logger import TensorboardLogger
         TB_LOGGER = TensorboardLogger()
         TB_LOGGER.create(training_args.logging_dir)
diff --git a/ime_logo.png b/ime_logo.png
index 454b2d1..bae8c9b 100755
Binary files a/ime_logo.png and b/ime_logo.png differ
diff --git a/src/transformers4ime/data/loaders/text_only.py b/src/transformers4ime/data/loaders/text_only.py
index 14d4834..be2cc2c 100644
--- a/src/transformers4ime/data/loaders/text_only.py
+++ b/src/transformers4ime/data/loaders/text_only.py
@@ -9,6 +9,7 @@
 import webdataset as wds
 from more_itertools import unzip
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
 from transformers import PretrainedConfig
 
 from transformers4ime.data.arguments import MMModelArguments, MMTrainingArguments, MMDataTrainingArguments
@@ -25,10 +26,10 @@ def __init__(self, tokenizer, model_args: MMModelArguments, training_args: MMTra
         super().__init__(tokenizer, model_args, training_args, data_args, config)
         self.shards = self.get_shards(self.data_args.train_text_only_files)
         self.batch_size = self.training_args.text_only_per_device_train_batch_size
-        self.max_len = self.data_args.text_only_block_size
-        self.n_ctx = self.config.n_ctx
+        self.n_ctx = self.data_args.text_only_block_size
 
     def build_sample(self, example):
+        example = example['json']
         context_ids = self.tokenizer(example['content'], add_special_tokens=False)['input_ids']
         if len(context_ids) + 2 > self.n_ctx:
             context_ids = context_ids[:self.n_ctx - 2]
@@ -68,20 +69,17 @@ def collate_fn(inputs):
                  'labels': label_ids}
         return batch
 
-    def convert_to_features(self, data):
-        d_lists = collections.defaultdict(list)
-        # truth_ids, input_ids, segment_ids, attention_masks = [], [], [], []
-        inputs = [self.build_sample(item) for item in data]
-        return self.collate_fn(inputs)
+    def wrap_build_sample(self, example):
+        try:
+            return self.build_sample(example)
+        except Exception as e:
+            logger.warning([e, example])
+            return [None] * 4
 
     def __iter__(self):
         assert len(self.shards) >= self.training_args.world_size  # guarantee at least one shard for each device
         logging.info(f"Constructing data loader for image text: {len(self.shards)}")
-        dataset = (
-            wds.WebDataset(self.shards)
-                .shuffle(1000)
-                .decode()
-                .to_tuple("json", )
-        )
-        for d, in dataset.batched(self.batch_size):
-            yield self.convert_to_features(d)
+        dataset = wds.WebDataset(self.shards).shuffle(1000).decode().map(self.wrap_build_sample)
+        for batch in DataLoader(dataset, num_workers=8, batch_size=self.batch_size,
+                                collate_fn=self.collate_fn):
+            yield batch
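
Below is a minimal, self-contained sketch of the loading pattern this patch moves to: WebDataset streams and decodes the shards, per-sample construction is wrapped in a try/except so a malformed record is logged and skipped rather than aborting the epoch, and torch's DataLoader contributes worker parallelism and batching through a custom collate_fn. The shard pattern, field names, batch size, and padding value below are illustrative assumptions, not values taken from this repository.

import logging

import torch
import webdataset as wds
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


def build_sample(example):
    example = example['json']        # .decode() yields {'json': <decoded object>}
    ids = example['content']         # assumption: pre-tokenized id list; the repo tokenizes text here
    return torch.tensor(ids), torch.tensor(len(ids))


def wrap_build_sample(example):
    try:
        return build_sample(example)
    except Exception as e:           # log the bad record and keep streaming
        logger.warning([e, example])
        return None, None


def collate_fn(inputs):
    # Drop failed samples; assumes at least one good sample per batch.
    inputs = [pair for pair in inputs if pair[0] is not None]
    ids, lengths = zip(*inputs)
    return {'input_ids': pad_sequence(ids, batch_first=True, padding_value=0),
            'lengths': torch.stack(lengths)}


# Hypothetical shard pattern; WebDataset is an IterableDataset, so DataLoader
# workers each consume a disjoint subset of shards.
dataset = wds.WebDataset('shards/train-{000000..000009}.tar')
dataset = dataset.shuffle(1000).decode().map(wrap_build_sample)
loader = DataLoader(dataset, num_workers=8, batch_size=32, collate_fn=collate_fn)

Compared with the removed .to_tuple('json').batched(...) path, batching in the DataLoader means sample construction and collation run in the worker processes rather than the training loop, which is the practical motivation for the rewrite of __iter__ above.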