Skip to content

Commit

Permalink
fix data loader
Browse files (browse the repository at this point in the history)
  • Loading branch information
Vimos committed Mar 18, 2022
1 parent 2666df5 commit 5470932
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 1 addition & 2 deletions pinyingpt_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from transformers4ime.utils.misc import NoOp

logger = logging.getLogger(__name__)
- BUFSIZE = 40960000


def update_lr(optimizer, lr):
Expand Down Expand Up @@ -49,7 +48,7 @@ def main():
parser = HfArgumentParser((MMModelArguments, MMDataTrainingArguments, MMTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

- if training_args.local_rank in [-1, 0]:
+ if training_args.process_index in [-1, 0]:
from transformers4ime.utils.logger import TensorboardLogger
TB_LOGGER = TensorboardLogger()
TB_LOGGER.create(training_args.logging_dir)
Expand Down
3 changes: 3 additions & 0 deletions src/transformers4ime/data/loaders/text_pinyin.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def build_sample(self, example):
post_context_ids = in_context_ids[start_idx:end_idx]
pinyin_ids = in_pinyin_ids[start_idx:end_idx]

+ if len(pinyin_ids) * 2 > self.max_len:
+     raise ValueError(f"Pinyin too long: {(start_idx, end_idx)}")
+
if len(pre_context_ids) > self.max_len - len(pinyin_ids) * 2:
pre_context_start = len(pre_context_ids) - (self.max_len - len(pinyin_ids) * 2)
pre_context_ids = pre_context_ids[pre_context_start:]
Expand Down

0 comments on commit 5470932

Please sign in to comment.