diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py index 07793440..4bbad3d5 100644 --- a/wtpsplit/evaluation/intrinsic.py +++ b/wtpsplit/evaluation/intrinsic.py @@ -18,8 +18,21 @@ @dataclass class Args: model_path: str + # eval data in the format: + # { + # "": { + # "sentence": { + # "": { + # "meta": { + # "train_data": ["train sentence 1", "train sentence 2"] + # }, + # "data": ["test sentence 1", "test sentence 2"] + # } + # } + # } + # } eval_data_path: str = "data/eval_new.pth" - valid_text_path: str = "data/sentence/valid.parquet" + valid_text_path: str = None#"data/sentence/valid.parquet" device: str = "cuda" block_size: int = 512 stride: int = 64 @@ -27,7 +40,7 @@ class Args: include_langs: List[str] = None -def load_or_compute_logits(args, model, eval_data, valid_data, max_n_train_sentences=10_000): +def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000): logits_path = Constants.CACHE_DIR / (model.config.mixture_name + "_logits.h5") with h5py.File(logits_path, "a") as f, torch.no_grad(): @@ -41,7 +54,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data, max_n_train_sente lang_group = f[lang_code] # valid data - if "valid" not in lang_group: + if valid_data is not None and "valid" not in lang_group: sentences = [sample["text"].strip() for sample in valid_data if sample["lang"] == lang_code] assert len(sentences) > 0 @@ -110,8 +123,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data, max_n_train_sente (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses() eval_data = torch.load(args.eval_data_path) - valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train") - + if args.valid_text_path is not None: + valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train") + else: + valid_data = None + model = AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device) # first, logits for everything.