Skip to content

Commit

Permalink
initial push
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Dec 19, 2023
1 parent 7fccec7 commit b637cf3
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ external
*.egg-info
notebooks
final_outputs
.cache*
.cache*
data_subset/**
24 changes: 24 additions & 0 deletions commands.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# .bashrc
export PATH=$PATH:~/.local/bin

export XRT_TPU_CONFIG="localservice;0;localhost:51011"

export XLA_USE_BF16=0

export TPU_NUM_DEVICES=8

export HF_DATASETS_CACHE=/dev/shm/cache

# data
gcloud auth login
gsutil -m cp -r gs://trc-transfer-data/sentence/data/eval.pth data/

# cleanup
pkill -e python3
(until no more)
or
watch -n1 pkill -e python3

# for debugging:

os.environ["PJRT_DEVICE"] = "None"
11 changes: 11 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
transformers==4.29.2
accelerate==0.19.0
datasets
pysbd
wandb
h5py
nltk
spacy
ersatz
iso-639
scikit-learn==1.2.2
1 change: 1 addition & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 $HOME/transformers/examples/pytorch/xla_spawn.py --num_cores ${TPU_NUM_DEVICES} wtpsplit/train/train.py $1
11 changes: 10 additions & 1 deletion wtpsplit/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import subprocess
import unicodedata
import os

import numpy as np
import regex as re
Expand Down Expand Up @@ -69,7 +70,7 @@ def train_mixture(lang_code, original_train_x, train_y, n_subsample=None, featur

train_x = train_x.float()

clf = linear_model.LogisticRegression(max_iter=10_000)
clf = linear_model.LogisticRegression(max_iter=10_000, random_state=0)
clf.fit(train_x, train_y)
preds = clf.predict_proba(train_x)[:, 1]

Expand Down Expand Up @@ -229,6 +230,14 @@ def ersatz_sentencize(
):
if lang_code not in ERSATZ_LANGUAGES:
raise LanguageError(f"ersatz does not support {lang_code}")

# check if infile parent dir exists, if not, create it
if not os.path.exists(os.path.dirname(infile)):
os.makedirs(os.path.dirname(infile))
# check if outfile parent dir exists, if not, create it
if not os.path.exists(os.path.dirname(outfile)):
os.makedirs(os.path.dirname(outfile))

open(infile, "w").write(text)

subprocess.check_output(
Expand Down
6 changes: 4 additions & 2 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class Args:
# }
# }
# }
eval_data_path: str = "data/eval_new.pth"
eval_data_path: str = "data/eval.pth"
valid_text_path: str = None#"data/sentence/valid.parquet"
device: str = "cuda"
device: str = "xla:1"
block_size: int = 512
stride: int = 64
batch_size: int = 32
Expand Down Expand Up @@ -70,6 +70,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
block_size=args.block_size,
batch_size=args.batch_size,
pad_last_batch=True,
verbose=True,
)[0]
lang_group.create_dataset("valid", data=valid_logits)

Expand All @@ -92,6 +93,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
block_size=args.block_size,
batch_size=args.batch_size,
pad_last_batch=True,
verbose=True,
)[0]
test_labels = get_labels(lang_code, test_sentences, after_space=False)

Expand Down
1 change: 1 addition & 0 deletions wtpsplit/train/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def evaluate_sentence(
stride=stride,
block_size=block_size,
batch_size=batch_size,
verbose=True,
)[0]

true_end_indices = np.cumsum(np.array([len(s) for s in sentences])) + np.arange(len(sentences)) * len(separator)
Expand Down

0 comments on commit b637cf3

Please sign in to comment.