From 05dda4cbc0d4df8fe1cf01e9ccb9734634f99717 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Apr 2024 22:10:47 +0200 Subject: [PATCH] fix bug --- litgpt/data/json_data.py | 5 +++++ tests/data/test_json.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/litgpt/data/json_data.py b/litgpt/data/json_data.py index 541678b93f..a40096486d 100644 --- a/litgpt/data/json_data.py +++ b/litgpt/data/json_data.py @@ -42,6 +42,11 @@ class JSON(DataModule): val_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self): + if self.json_path.is_file() and self.val_split_fraction is None: + raise ValueError( + "If `json_path` is a file, you must set `val_split_fraction` to a value between 0 and 1 to split the" + " data into train and validation sets." + ) if self.json_path.is_dir() and self.val_split_fraction is not None: raise ValueError( "If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and" diff --git a/tests/data/test_json.py b/tests/data/test_json.py index 10d246e892..100c74302b 100644 --- a/tests/data/test_json.py +++ b/tests/data/test_json.py @@ -81,6 +81,9 @@ def test_json_input_validation(tmp_path): with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"): data.setup() + with pytest.raises(ValueError, match="you must set `val_split_fraction` to a value between 0 and 1"): + JSON(tmp_path / "train.json", val_split_fraction=None) + @pytest.mark.parametrize("as_jsonl", [False, True]) def test_json_with_splits(as_jsonl, tmp_path, mock_tokenizer):