Skip to content

Commit

Permalink
add eval data description, make valid data optional
Browse files (browse the repository at this point in the history)
  • Loading branch information
bminixhofer authored Nov 30, 2023
1 parent 54909cf commit 48ab11a
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,29 @@
@dataclass
class Args:
    """Command-line arguments for the intrinsic evaluation script.

    Instances are populated by ``HfArgumentParser`` from the command line.
    """

    # Path or identifier handed to
    # ``AutoModelForTokenClassification.from_pretrained``.
    model_path: str
    # Path to the eval data (read with ``torch.load``), in the format:
    # {
    #     "<lang_code>": {
    #         "sentence": {
    #             "<dataset_name>": {
    #                 "meta": {
    #                     "train_data": ["train sentence 1", "train sentence 2"]
    #                 },
    #                 "data": ["test sentence 1", "test sentence 2"]
    #             }
    #         }
    #     }
    # }
    eval_data_path: str = "data/eval_new.pth"
    # Optional parquet file with validation text, e.g.
    # "data/sentence/valid.parquet". When left as None, no validation data is
    # loaded and the "valid" logits are not computed.
    # NOTE(review): annotated as `str` to match `include_langs` below; the
    # value may be None.
    valid_text_path: str = None
    # Device the model is moved to.
    device: str = "cuda"
    # NOTE(review): block_size / stride / batch_size are presumably forwarded
    # to the logit extraction routine -- their use is not visible here.
    block_size: int = 512
    stride: int = 64
    batch_size: int = 32
    # NOTE(review): presumably restricts evaluation to these language codes
    # (None = all languages) -- confirm against the evaluation loop.
    include_langs: List[str] = None


def load_or_compute_logits(args, model, eval_data, valid_data, max_n_train_sentences=10_000):
def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000):
logits_path = Constants.CACHE_DIR / (model.config.mixture_name + "_logits.h5")

with h5py.File(logits_path, "a") as f, torch.no_grad():
Expand All @@ -41,7 +54,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data, max_n_train_sente
lang_group = f[lang_code]

# valid data
if "valid" not in lang_group:
if valid_data is not None and "valid" not in lang_group:
sentences = [sample["text"].strip() for sample in valid_data if sample["lang"] == lang_code]
assert len(sentences) > 0

Expand Down Expand Up @@ -110,8 +123,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data, max_n_train_sente
(args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()

eval_data = torch.load(args.eval_data_path)
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")

if args.valid_text_path is not None:
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
else:
valid_data = None

model = AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device)

# first, logits for everything.
Expand Down

0 comments on commit 48ab11a

Please sign in to comment.