
Commit

add logo
Vimos committed Mar 19, 2022
1 parent 5470932 commit 9250afe
Showing 3 changed files with 14 additions and 16 deletions.
2 changes: 1 addition & 1 deletion gpt2_pretrain.py
@@ -47,7 +47,7 @@ def main():
     parser = HfArgumentParser((MMModelArguments, MMDataTrainingArguments, MMTrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if training_args.local_rank in [-1, 0]:
+    if training_args.process_index in [-1, 0]:
         from transformers4ime.utils.logger import TensorboardLogger
         TB_LOGGER = TensorboardLogger()
         TB_LOGGER.create(training_args.logging_dir)
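Aside from the logo, this commit swaps the rank guard in gpt2_pretrain.py from local_rank to process_index: process_index is the global rank across all processes, whereas local_rank is per node, so the new check sets up TensorBoard logging on a single process even in multi-node runs. A minimal sketch of the pattern, assuming a stock transformers TrainingArguments and using torch's SummaryWriter in place of the repo's TensorboardLogger:

from transformers import HfArgumentParser, TrainingArguments
from torch.utils.tensorboard import SummaryWriter

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()

# Only the main process creates a writer; process_index is 0 there
# (the -1 case mirrors the old local_rank convention for non-distributed runs).
if training_args.process_index in [-1, 0]:
    tb_writer = SummaryWriter(log_dir=training_args.logging_dir)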
Binary file modified ime_logo.png
28 changes: 13 additions & 15 deletions src/transformers4ime/data/loaders/text_only.py
@@ -9,6 +9,7 @@
 import webdataset as wds
 from more_itertools import unzip
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
 from transformers import PretrainedConfig
 
 from transformers4ime.data.arguments import MMModelArguments, MMTrainingArguments, MMDataTrainingArguments
@@ -25,10 +26,10 @@ def __init__(self, tokenizer, model_args: MMModelArguments, training_args: MMTra
         super().__init__(tokenizer, model_args, training_args, data_args, config)
         self.shards = self.get_shards(self.data_args.train_text_only_files)
         self.batch_size = self.training_args.text_only_per_device_train_batch_size
-        self.max_len = self.data_args.text_only_block_size
-        self.n_ctx = self.config.n_ctx
+        self.n_ctx = self.data_args.text_only_block_size
 
     def build_sample(self, example):
+        example = example['json']
         context_ids = self.tokenizer(example['content'], add_special_tokens=False)['input_ids']
         if len(context_ids) + 2 > self.n_ctx:
             context_ids = context_ids[:self.n_ctx - 2]
@@ -68,20 +69,17 @@ def collate_fn(inputs):
                  'labels': label_ids}
         return batch
 
-    def convert_to_features(self, data):
-        d_lists = collections.defaultdict(list)
-        # truth_ids, input_ids, segment_ids, attention_masks = [], [], [], []
-        inputs = [self.build_sample(item) for item in data]
-        return self.collate_fn(inputs)
+    def wrap_build_sample(self, example):
+        try:
+            return self.build_sample(example)
+        except Exception as e:
+            logger.warning([e, example])
+            return [None] * 4
 
     def __iter__(self):
         assert len(self.shards) >= self.training_args.world_size  # guarantee at least one shard for each device
         logging.info(f"Constructing data loader for image text: {len(self.shards)}")
-        dataset = (
-            wds.WebDataset(self.shards)
-            .shuffle(1000)
-            .decode()
-            .to_tuple("json", )
-        )
-        for d, in dataset.batched(self.batch_size):
-            yield self.convert_to_features(d)
+        dataset = wds.WebDataset(self.shards).shuffle(1000).decode().map(self.wrap_build_sample)
+        for batch in DataLoader(dataset, num_workers=8, batch_size=self.batch_size,
+                                collate_fn=self.collate_fn):
+            yield batch
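
The text_only.py refactor replaces the hand-rolled to_tuple/batched/convert_to_features path with webdataset's map plus a standard torch DataLoader, and wraps per-sample building in a try/except so a malformed record is logged rather than crashing a worker. A self-contained sketch of the same pattern, assuming samples carry a JSON payload with a 'content' field; the shard pattern, the trivial build_sample, and the list-based collate_fn are placeholders rather than the repo's actual pipeline:

import logging

import webdataset as wds
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


def build_sample(sample):
    # A decoded webdataset record is keyed by file extension; the payload sits under "json".
    example = sample["json"]
    return example["content"]  # placeholder: the real loader tokenizes and truncates here


def safe_build_sample(sample):
    # Mirrors wrap_build_sample: log the failure and return a sentinel so one bad
    # record does not take down a DataLoader worker.
    try:
        return build_sample(sample)
    except Exception as e:
        logger.warning([e, sample])
        return None


def collate_fn(items):
    # Placeholder collate: drop sentinels and keep the rest as a plain list.
    return [x for x in items if x is not None]


# Hypothetical shard pattern; the repo derives its shards from train_text_only_files.
dataset = (
    wds.WebDataset("shards/text-{000000..000009}.tar")
    .shuffle(1000)
    .decode()
    .map(safe_build_sample)
)
loader = DataLoader(dataset, num_workers=8, batch_size=32, collate_fn=collate_fn)

for batch in loader:
    ...  # feed the batch to the model

When running with several DataLoader workers, how shards are divided between workers depends on webdataset's splitter settings; the sketch leaves those at their defaults.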
