diff --git a/litgpt/data/text_files.py b/litgpt/data/text_files.py index eca65bf535..0333dd0ae3 100644 --- a/litgpt/data/text_files.py +++ b/litgpt/data/text_files.py @@ -36,11 +36,6 @@ class TextFiles(DataModule): batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) - def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: - self.tokenizer = tokenizer - self.batch_size = batch_size - self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well - def __post_init__(self) -> None: self.out_path_train = self.train_data_path / "train" if self.val_data_path is None: @@ -48,6 +43,11 @@ def __post_init__(self) -> None: else: self.out_path_val = Path(self.val_data_path) / "val" + def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: + self.tokenizer = tokenizer + self.batch_size = batch_size + self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + def prepare_data(self) -> None: from litdata import optimize