diff --git a/pinyingpt_pretrain.py b/pinyingpt_pretrain.py index 153001d..51fd174 100644 --- a/pinyingpt_pretrain.py +++ b/pinyingpt_pretrain.py @@ -18,7 +18,6 @@ from transformers4ime.utils.misc import NoOp logger = logging.getLogger(__name__) -BUFSIZE = 40960000 def update_lr(optimizer, lr): @@ -49,7 +48,7 @@ def main(): parser = HfArgumentParser((MMModelArguments, MMDataTrainingArguments, MMTrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if training_args.local_rank in [-1, 0]: + if training_args.process_index in [-1, 0]: from transformers4ime.utils.logger import TensorboardLogger TB_LOGGER = TensorboardLogger() TB_LOGGER.create(training_args.logging_dir) diff --git a/src/transformers4ime/data/loaders/text_pinyin.py b/src/transformers4ime/data/loaders/text_pinyin.py index 8800aba..3db9e4e 100644 --- a/src/transformers4ime/data/loaders/text_pinyin.py +++ b/src/transformers4ime/data/loaders/text_pinyin.py @@ -186,6 +186,9 @@ def build_sample(self, example): post_context_ids = in_context_ids[start_idx:end_idx] pinyin_ids = in_pinyin_ids[start_idx:end_idx] + if len(pinyin_ids) * 2 > self.max_len: + raise ValueError(f"Pinyin too long: {(start_idx, end_idx)}") + if len(pre_context_ids) > self.max_len - len(pinyin_ids) * 2: pre_context_start = len(pre_context_ids) - (self.max_len - len(pinyin_ids) * 2) pre_context_ids = pre_context_ids[pre_context_start:]