
Commit

more fixes
markus583 committed Dec 27, 2023
1 parent c5d9f66 commit e5251c2
Showing 6 changed files with 142 additions and 29 deletions.
4 changes: 3 additions & 1 deletion configs/xlmr_stratify_0.1_3layers.json
@@ -37,5 +37,7 @@
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
- "num_hidden_layers": 3
+ "num_hidden_layers": 3,
+ "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+ "log_level": "info"
}
3 changes: 2 additions & 1 deletion configs/xlmr_stratify_0.1_3layers_nounks.json
@@ -38,5 +38,6 @@
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
- "custom_punctuation_file": "punctuation_xlmr.txt"
+ "custom_punctuation_file": "punctuation_xlmr.txt",
+ "log_level": "info"
}
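The two configs now differ only in which punctuation list they point at: the default config uses punctuation_xlmr_unk.txt (which keeps a single <unk> placeholder, added below), while the _nounks variant keeps the UNK-free punctuation_xlmr.txt. A minimal sketch of reading the two new keys, independent of how the training script actually consumes the config (the snippet and its relative path are illustrative, not the repository's entry point):

import json
import logging

# Illustrative only: load one of the configs and apply the two new keys.
with open("configs/xlmr_stratify_0.1_3layers.json", encoding="utf-8") as f:
    cfg = json.load(f)

logging.basicConfig(level=getattr(logging, cfg["log_level"].upper()))  # "info" -> logging.INFO
print(cfg["custom_punctuation_file"])  # punctuation_xlmr_unk.txt here, punctuation_xlmr.txt in the _nounks config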
17 changes: 15 additions & 2 deletions utils/remove_unks.py
@@ -8,7 +8,7 @@ def get_subword_label_dict(label_args, tokenizer):

n_unks = 0
# Map auxiliary characters to token IDs with labels
- for i, c in enumerate(label_args.auxiliary_chars):
+ for i, c in enumerate(Constants.PUNCTUATION_CHARS):
token_id = tokenizer.convert_tokens_to_ids(c)
label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
# TODO: remove UNKs?
@@ -33,7 +33,7 @@ def get_subword_label_dict(label_args, tokenizer):
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})

- label_dict = get_subword_label_dict(LabelArgs(custom_punctuation_file='punctuation_xlmr.txt'), tokenizer)
+ label_dict = get_subword_label_dict(LabelArgs(), tokenizer)
print(len(label_dict))

def write_punctuation_file():
@@ -42,8 +42,21 @@ def write_punctuation_file():
token_id = tokenizer.convert_tokens_to_ids(char)
if token_id != tokenizer.unk_token_id:
file.write(char + '\n')

+ def write_punctuation_file_unk():
+ added_unk = False
+ with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr_unk.txt"), 'w', encoding='utf-8') as file:
+ for char in Constants.PUNCTUATION_CHARS:
+ token_id = tokenizer.convert_tokens_to_ids(char)
+ if token_id != tokenizer.unk_token_id:
+ file.write(char + '\n')
+ elif not added_unk:
+ print("added unk")
+ file.write('<unk>\n')
+ added_unk = True
+
write_punctuation_file()
+ write_punctuation_file_unk()

label_args_default = LabelArgs()
print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
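write_punctuation_file_unk keeps every punctuation character that XLM-R can represent and collapses all remaining ones into a single <unk> line, so out-of-vocabulary punctuation still gets an auxiliary label via tokenizer.unk_token_id. A minimal sketch of how such a file could be read back into a token-ID-to-label mapping (the helper name load_punctuation_labels and the AUX_OFFSET value are illustrative assumptions, not the repository's API):

from transformers import AutoTokenizer

AUX_OFFSET = 1  # assumed value for illustration; wtpsplit defines Constants.AUX_OFFSET

def load_punctuation_labels(path, tokenizer):
    # Hypothetical reader: map each listed character (or the single '<unk>' placeholder) to a label.
    label_dict = {}
    with open(path, encoding="utf-8") as f:
        for i, char in enumerate(line.rstrip("\n") for line in f):
            token_id = tokenizer.convert_tokens_to_ids(char)
            label_dict[token_id] = 1 + AUX_OFFSET + i
    return label_dict

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
label_dict = load_punctuation_labels("wtpsplit/data/punctuation_xlmr_unk.txt", tokenizer)
print(len(label_dict), tokenizer.unk_token_id in label_dict)  # the '<unk>' line maps to unk_token_id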
109 changes: 109 additions & 0 deletions wtpsplit/data/punctuation_xlmr_unk.txt
@@ -0,0 +1,109 @@
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
¡
£
¤
§
<unk>
©
«
¬
®
°
±
·
»
¿
÷
՛
՝
՞
։
־
׳
،
؛
؟
۔
💘
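The new file mirrors punctuation_xlmr.txt but substitutes a single <unk> placeholder where the first character unknown to XLM-R would have appeared (only part of the 109-line file is shown above). A quick, illustrative sanity check; the relative path is an assumption about where the file sits in a checkout:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
with open("wtpsplit/data/punctuation_xlmr_unk.txt", encoding="utf-8") as f:
    chars = [line.rstrip("\n") for line in f]

# Every entry except the explicit '<unk>' placeholder should map to a real token ID.
unk_entries = [c for c in chars if tokenizer.convert_tokens_to_ids(c) == tokenizer.unk_token_id]
assert unk_entries == ["<unk>"], unk_entries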
27 changes: 7 additions & 20 deletions wtpsplit/train/train.py
@@ -238,17 +238,6 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
tokenizer=tokenizer if args.use_subwords else None,
)

- if input_ids[0] != tokenizer.cls_token_id:
- logger.warn(input_ids)
- logger.warn(len(input_ids))
- logger.warn(tokenizer.cls_token_id)
- # raise ValueError("CLS token not first token")
- if input_ids[-1] != tokenizer.sep_token_id:
- logger.warn(input_ids)
- logger.warn(len(input_ids))
- logger.warn(tokenizer.sep_token_id)
- # raise ValueError("SEP token not last token")

if len(input_ids) > args.block_size:
if tokenizer:
# always include CLS
@@ -264,7 +253,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
# always include SEP
if input_ids[-1] != tokenizer.sep_token_id:
# also insert PAD token as long as len < block_size
- while len(input_ids) <= args.block_size - 1:
+ while len(input_ids) < args.block_size - 1:
input_ids = input_ids + [tokenizer.pad_token_id]
labels = labels + [0]
input_ids = input_ids + [tokenizer.sep_token_id]
@@ -273,34 +262,33 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
start = np.random.randint(0, len(input_ids) - args.block_size)
input_ids = input_ids[start : start + args.block_size]
labels = labels[start : start + args.block_size]
elif len(input_ids) != args.block_size and args.use_subwords:
if len(input_ids) != args.block_size and tokenizer:
del input_ids[-1]
del labels[-1]
while len(input_ids) <= args.block_size - 1:
while len(input_ids) < args.block_size - 1:
# insert pad token at second-to-last position
logger.debug("second", len(input_ids))
logger.warn("second", len(input_ids))
input_ids = input_ids + [tokenizer.pad_token_id]
labels = labels + [0]
input_ids = input_ids + [tokenizer.sep_token_id]
labels = labels + [0]

if len(input_ids) != args.block_size:
logger.warn(len(input_ids))
input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
if input_ids[-1] != tokenizer.sep_token_id:
logger.warn(input_ids)
logger.warn(tokenizer.sep_token_id)
logger.warn(labels)
# raise ValueError("SEP token not last token")
if input_ids[0] != tokenizer.cls_token_id:
logger.warn(input_ids)
logger.warn(tokenizer.cls_token_id)
logger.warn(labels)
# raise ValueError("CLS token not first token")
# FIXME: check this - why does it occur in train split?
if (input_ids == tokenizer.cls_token_id).sum() != 1:
logger.warn(input_ids)
logger.warn(tokenizer.cls_token_id)
logger.warn(labels)
# raise ValueError("CLS token not unique")

position_ids = torch.arange(len(input_ids), dtype=torch.long)
@@ -657,7 +645,7 @@ def maybe_pad(text):
num_workers=args.preprocessing_num_workers,
include_languages=args.include_languages,
shuffle=args.shuffle,
- split="train",
+ split="valid",
)
logger.info(f"Train dataset has {len(train_dataset)} examples.")
@@ -671,7 +659,6 @@ def maybe_pad(text):
logger.info(f"Sample {index} of the training set: {sample}.")
if tokenizer:
logger.info(tokenizer.decode(sample["input_ids"]))
- logger.info()
count += 1

# dataset we use is in cached now
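The padding change in collate_fn (the "<=" comparison against args.block_size - 1 becomes "<") removes an off-by-one: padding until the sequence reached block_size and then appending the SEP token produced block_size + 1 tokens. A toy reproduction of the two behaviours, with made-up token IDs and block size:

BLOCK_SIZE = 8          # stand-in for args.block_size
PAD_ID, SEP_ID = 1, 2   # placeholder token IDs

def pad_then_append_sep(ids, strict):
    ids = list(ids)
    # strict=False mimics the old "<=" condition, strict=True the new "<" condition
    while (len(ids) < BLOCK_SIZE - 1) if strict else (len(ids) <= BLOCK_SIZE - 1):
        ids.append(PAD_ID)
    ids.append(SEP_ID)
    return ids

short = [0, 5, 6]  # a truncated sequence that lost its SEP token
print(len(pad_then_append_sep(short, strict=False)))  # 9: one over BLOCK_SIZE (old behaviour)
print(len(pad_then_append_sep(short, strict=True)))   # 8: exactly BLOCK_SIZE (new behaviour)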
11 changes: 6 additions & 5 deletions wtpsplit/utils.py
@@ -86,7 +86,7 @@ def __post_init__(self):
def get_label_dict(label_args):
label_dict = {}

- for i, c in enumerate(label_args.auxiliary_chars):
+ for i, c in enumerate(Constants.PUNCTUATION_CHARS):
label_dict[ord(c)] = 1 + Constants.AUX_OFFSET + i

for c in label_args.newline_chars:
@@ -99,10 +99,11 @@ def get_subword_label_dict(label_args, tokenizer):

n_unks = 0
# Map auxiliary characters to token IDs with labels
- for i, c in enumerate(label_args.auxiliary_chars):
+ logger.warn(f"Using {Constants.PUNCTUATION_CHARS} auxiliary characters.")
+ for i, c in enumerate(Constants.PUNCTUATION_CHARS):
token_id = tokenizer.convert_tokens_to_ids(c)
label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
- logger.info(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
+ logger.warn(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
if token_id == tokenizer.unk_token_id:
n_unks += 1

@@ -112,8 +113,8 @@ def get_subword_label_dict(label_args, tokenizer):
for c in label_args.newline_chars:
token_id = tokenizer.convert_tokens_to_ids(c)
label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
- logger.info(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
- logger.info(r"{}".format(tokenizer.decode([token_id])))
+ logger.warn(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
+ logger.warn(r"{}".format(tokenizer.decode([token_id])))

return label_dict

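Both get_label_dict and get_subword_label_dict now iterate Constants.PUNCTUATION_CHARS instead of label_args.auxiliary_chars, so the auxiliary label space always covers the full punctuation list. A self-contained sketch of the resulting label layout (the constant values and the tiny character list are illustrative assumptions, not wtpsplit's actual Constants):

NEWLINE_INDEX = 0   # assumed for illustration
AUX_OFFSET = 1      # assumed for illustration
PUNCTUATION_CHARS = ["!", ",", ".", "?"]  # tiny stand-in for the full list

# 0 = no boundary, 1 + NEWLINE_INDEX = newline boundary,
# 1 + AUX_OFFSET + i = auxiliary class for the i-th punctuation character
label_dict = {ord(c): 1 + AUX_OFFSET + i for i, c in enumerate(PUNCTUATION_CHARS)}
label_dict[ord("\n")] = 1 + NEWLINE_INDEX
print(label_dict)  # {33: 2, 44: 3, 46: 4, 63: 5, 10: 1}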
