From 2219c2d926aeb468a089292a721dde62d25af78f Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 8 Apr 2024 08:46:15 +0100
Subject: [PATCH] move post_init to top

---
 litgpt/data/text_files.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/litgpt/data/text_files.py b/litgpt/data/text_files.py
index eca65bf535..0333dd0ae3 100644
--- a/litgpt/data/text_files.py
+++ b/litgpt/data/text_files.py
@@ -36,11 +36,6 @@ class TextFiles(DataModule):
     batch_size: int = field(default=1, init=False, repr=False)
     max_seq_length: int = field(default=-1, init=False, repr=False)
 
-    def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
-        self.tokenizer = tokenizer
-        self.batch_size = batch_size
-        self.max_seq_length = max_seq_length + 1  # Increase by one because we need the next token as well
-
     def __post_init__(self) -> None:
         self.out_path_train = self.train_data_path / "train"
         if self.val_data_path is None:
@@ -48,6 +43,11 @@ def __post_init__(self) -> None:
         else:
             self.out_path_val = Path(self.val_data_path) / "val"
 
+    def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
+        self.tokenizer = tokenizer
+        self.batch_size = batch_size
+        self.max_seq_length = max_seq_length + 1  # Increase by one because we need the next token as well
+
     def prepare_data(self) -> None:
         from litdata import optimize
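
Note: the "+ 1" in connect() exists because causal language-model training needs, for each input position, the following token as its prediction target. Below is a minimal sketch of that shift, assuming the standard next-token setup; the variable names are illustrative only and not taken from litgpt:

    import torch

    # A caller asks for sequences of length max_seq_length...
    max_seq_length = 4
    # ...so the data module fetches one extra token per sample
    # (hence storing max_seq_length + 1 in connect()).
    chunk = torch.arange(max_seq_length + 1)
    # Inputs and targets are the same chunk, shifted by one position.
    inputs, targets = chunk[:-1], chunk[1:]
    assert len(inputs) == len(targets) == max_seq_length

Both tensors end up with the requested length, so downstream code can treat max_seq_length as the effective training sequence length while the loader quietly fetches one extra token per sample.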