Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update adaptive tests #81

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions tests/_self-adaptive/contextpart.costs.expected
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
22.19
2.75
6.56
0.29
47.07
15.31
4.52
0.11
24.44
2.20
28.06
6.56
2.21881247
2.21806955
0.46859169
0.46825710
2.76864004
2.76791930
0.56501710
0.56451356
2.71565413
2.71488881
3.11743879
3.11609936
10 changes: 5 additions & 5 deletions tests/_self-adaptive/contextpart.expected
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
klicken und ziehen
die linke Maustaste für einen Rechtsklick gedrückt halten .
Sie können rechts@@ klicken , indem Sie die linke Maustaste gedrückt halten .
drücken Sie die linke Maustaste und klicken Sie mit der rechten Maustaste .
Sie können mit Rechtsklick die linke Maustaste gedrückt halten .
wechseln Sie &@@ lt@@ ; gu@@ i > Sim@@ ulated Secondary Click &@@ lt@@ ; / gu@@ i > weiter .
warum sollte ich meine E-Mail-@@ Konten oder sozialen Medien zu Ihrem Desktop hinzufügen ?
warum fügen Sie Ihre E-Mail oder Social Media auf Ihren Desktop ?
warum sollte ich ein Konto hinzufügen ?
Anwendungen entfernen , die Sie nicht mehr benötigen
entfernen Sie die Software , die Sie nicht mehr verwenden .
prüfen Sie Ihre Sicherung
die von Ihnen verwendete Software kann in der Regel recht schnell wieder hergestellt werden .
zurück eure wichtigen Akten
was ist die " Super " -Taste ?
was ist der " Super " -@@ Schlüssel ?
44 changes: 22 additions & 22 deletions tests/_self-adaptive/costs.expected
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
1.26
0.01
22.19
2.75
6.56
0.29
77.35
52.49
47.07
15.31
4.52
0.11
24.44
2.20
7.48
0.13
38.95
8.38
13.98
0.36
28.06
6.56
0.31558633
0.31486320
2.21881247
2.21806955
0.46859169
0.46825710
4.07114267
4.07016468
2.76864004
2.76791930
0.56501710
0.56451356
2.71565413
2.71488881
1.87033439
1.86898220
1.77039230
1.76994097
2.79638267
2.79453158
3.11743879
3.11609936
188 changes: 188 additions & 0 deletions tests/_self-adaptive/gen-costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#!/usr/bin/env python3

import argparse as ag
import os
import sys
import glob
import subprocess as sp
import re
import tempfile
import shutil
import dataclasses
from dataclasses import dataclass

@dataclasses.dataclass
class MarianTrainConfig:
    """Settings shared by the Marian training and decoding invocations."""

    # Directory that contains the `marian` and `marian-decoder` executables.
    marian_dir: str
    # Path of the model file handed to `-m`.
    model: str
    # Value handed to `--type`.
    model_type: str
    # The two vocabulary paths, each handed to its own `-v` option.
    vocab1: str
    vocab2: str
    # Values for `--dim-vocabs`; the option is omitted when either is None.
    dim_vocab1: int
    dim_vocab2: int
    # Value for `--dim-emb`; the option is omitted when this is None.
    dim_emb: int
    # Value for `--after-epochs` (used for training only).
    epochs: int


def eprint(*args, **kwargs):
    """Print to standard error; accepts the same arguments as ``print``."""
    print(*args, **kwargs, file=sys.stderr)

def main():
    """Parse the command line, run Marian per input line, and write the results.

    The parsed options are bundled into a ``MarianTrainConfig``; the training
    sets and the input file are then processed by ``iterate_over_inputs`` and
    the collected costs/translations are written to the optional output files.
    """
    parser = ag.ArgumentParser()
    parser.add_argument('-t', '--train-sets',
                        type=ag.FileType('r', encoding='utf-8'), nargs=2, required=True)
    parser.add_argument('-i', '--input', type=ag.FileType('r', encoding='utf-8'), required=True)
    parser.add_argument('-m', '--model', required=True)
    parser.add_argument('-v', '--vocabs', nargs=2, required=True)
    parser.add_argument('--dim-vocabs', nargs=2, type=int, required=False)
    parser.add_argument('--dim-emb', type=int, required=False)
    parser.add_argument('-e', '--epochs', type=int, required=True)
    parser.add_argument('--type', required=True)
    parser.add_argument('--marian-dir', required=True)
    parser.add_argument('--output-costs', type=ag.FileType('w', encoding='utf-8'))
    parser.add_argument('--output-transl', type=ag.FileType('w', encoding='utf-8'))

    args = parser.parse_args()

    sfile, tfile = args.train_sets
    vocab1, vocab2 = args.vocabs
    dvocab1, dvocab2 = args.dim_vocabs if args.dim_vocabs is not None else (None, None)
    # Keyword arguments keep each value tied to its field; the positional form
    # previously passed the first vocabulary dimension for both fields.
    config = MarianTrainConfig(
        marian_dir=args.marian_dir,
        model=args.model,
        model_type=args.type,
        vocab1=vocab1,
        vocab2=vocab2,
        dim_vocab1=dvocab1,
        dim_vocab2=dvocab2,
        dim_emb=args.dim_emb,
        epochs=args.epochs,
    )
    eprint(config)
    costs_and_translations = iterate_over_inputs(config, sfile, tfile, args.input)
    output_costs_and_translations(costs_and_translations, args.output_costs, args.output_transl)


def training_file_generator(source, target):
    """Split parallel training data into one pair of temp files per block.

    ``source`` and ``target`` are iterated in lockstep. A pair of lines that
    are both empty (just a newline) ends the current block; every other pair
    is written to the block's temporary files.

    Yields:
        ``(contains_sentences, sfile, tfile)`` for each block, where the two
        files are already closed ``NamedTemporaryFile`` objects created with
        ``delete=False`` (the consumer is responsible for removing them) and
        ``contains_sentences`` is False for a block that had no sentence pairs.
    """
    # Files of the block currently being written; None between blocks.
    sfile = tfile = None
    contains_sentences = False
    for sline, tline in zip(source, target):
        if sfile is None:
            sfile = tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False)
            tfile = tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False)
            eprint(f"Created temp files for training data: {sfile.name}, {tfile.name}")
            contains_sentences = False

        if sline != "\n" or tline != "\n":
            eprint(sline.rstrip())
            eprint(tline.rstrip())
            sfile.write(sline)
            tfile.write(tline)
            contains_sentences = True
        else:
            sfile.close()
            tfile.close()
            yield (contains_sentences, sfile, tfile)
            sfile = tfile = None

    # The last non-empty set of sentences need not be delimited by an empty
    # line. Only a block that is still open is flushed here, so data that does
    # end with a delimiter no longer yields the previous block a second time.
    if sfile is not None:
        sfile.close()
        tfile.close()
        yield (contains_sentences, sfile, tfile)

def iterate_over_inputs(config, source, target, inputs):
    """Process every input line together with its block of training data.

    Each line of ``inputs`` is paired with the next block produced by
    ``training_file_generator(source, target)``. When the block contains
    sentences, ``run_marian`` trains on it and translates the line; otherwise
    the line is only translated and the cost list is empty. The block's
    temporary files are removed after each iteration.

    Returns:
        A list with one ``(costs, translations)`` tuple per processed line.
    """
    all_costs_and_translations = []
    # Bound up front so the cleanup below is valid even when the loop body
    # never runs (empty input or empty training data).
    sfile = tfile = None
    try:
        for input_line, (contains_sentences, sfile, tfile) in zip(inputs, training_file_generator(source, target)):
            if contains_sentences:
                costs_and_translations = run_marian(sfile.name, tfile.name, input_line, config)
            else:
                eprint("No context provided, skipping training")
                translations = translate_marian(input_line, config)
                costs_and_translations = ([], translations)
            all_costs_and_translations.append(costs_and_translations)
            os.remove(sfile.name)
            os.remove(tfile.name)
    finally:
        # Anything still on disk here was left behind by an interrupted
        # iteration; remove it and report that cleanup was needed.
        for f in (sfile, tfile):
            if f is not None and os.path.exists(f.name):
                if not f.closed:
                    f.close()
                os.remove(f.name)
                eprint(f"ERROR: Needed cleanup for {f.name}")
    return all_costs_and_translations

def run_marian(sfile, tfile, input_line, config):
    """Train a throw-away copy of the model, then translate with that copy.

    Args:
        sfile: Path of the source-side training file.
        tfile: Path of the target-side training file.
        input_line: Text passed to the decoder on standard input.
        config: ``MarianTrainConfig``; its ``model`` is copied, never modified.

    Returns:
        ``(costs, translations)`` as returned by ``train_marian`` and
        ``translate_marian``.
    """
    temp_model_path = create_temp_model_copy(config.model)
    adapted_config = dataclasses.replace(config, model=temp_model_path)
    try:
        costs = train_marian(sfile, tfile, adapted_config)
        # Translate with the freshly trained copy. Decoding with the original
        # model would make the training step irrelevant to the translations.
        translations = translate_marian(input_line, adapted_config)
    finally:
        # Remove the copy and any sibling files sharing its path prefix, even
        # when training or decoding raised.
        for path in glob.glob(f"{glob.escape(temp_model_path)}*"):
            eprint(f"Removing model file: {path}")
            os.remove(path)
    return (costs, translations)

def create_temp_model_copy(model):
    """Copy ``model`` to a new temporary ``.npz`` file and return its path."""
    handle, temp_path = tempfile.mkstemp(suffix='.npz')
    eprint(f"Created temp file for model: {temp_path}")
    # Only the path is needed; the descriptor from mkstemp is closed right away.
    os.close(handle)
    shutil.copyfile(model, temp_path)
    return temp_path

def train_marian(sfile, tfile, config):
    """Run the ``marian`` trainer on one pair of files and return its costs.

    The trainer's captured stdout and stderr are echoed to standard error;
    the costs are extracted from its stderr with ``extract_costs``.
    """
    command = [
        f"{config.marian_dir}/marian",
        '-m', config.model,
        '--disp-freq', '1',
        '--optimizer', 'sgd',
        '--type', config.model_type,
        '-v', config.vocab1,
        '-v', config.vocab2,
        '--after-epochs', str(config.epochs),
        '--mini-batch', '1',
        '-t', sfile, tfile,
    ]
    # The dimension options are passed only when they were configured.
    if config.dim_emb is not None:
        command.extend(['--dim-emb', str(config.dim_emb)])
    if not (config.dim_vocab1 is None or config.dim_vocab2 is None):
        command.extend(['--dim-vocabs', str(config.dim_vocab1), str(config.dim_vocab2)])
    result = sp.run(command, capture_output=True, text=True)

    for label, text in (("STDOUT:", result.stdout), ("STDERR:", result.stderr)):
        eprint(label)
        eprint(text)
    extracted = extract_costs(result.stderr)
    eprint("COSTS:")
    eprint(extracted)
    return extracted

def extract_costs(output_log):
    """Return the cost values found in a training log, in order of appearance.

    Args:
        output_log: The complete log text; it is examined line by line.

    Returns:
        A list of strings, one per line that contains ``Ep.`` followed later
        by `` Cost <number> `` and ``: Time``. The numbers are returned exactly
        as written in the log (they are not converted to ``float``).
    """
    # Raw string: in a plain literal "\." is an invalid escape sequence, which
    # newer Python versions warn about.
    pattern = re.compile(r'Ep\..* Cost ([-e0-9.]+) .*: Time')
    matches = (pattern.search(line) for line in output_log.splitlines())
    return [match.group(1) for match in matches if match is not None]

def translate_marian(input_line, config):
    """Run ``marian-decoder`` on ``input_line`` and return its output lines.

    ``input_line`` is sent to the decoder's standard input. The decoder's
    captured stdout and stderr are echoed to standard error, and stdout is
    returned split into lines.
    """
    command = [
        f"{config.marian_dir}/marian-decoder",
        '-m', config.model,
        '--type', config.model_type,
        '-v', config.vocab1,
        '-v', config.vocab2,
    ]
    # The dimension options are passed only when they were configured.
    if config.dim_emb is not None:
        command.extend(['--dim-emb', str(config.dim_emb)])
    if not (config.dim_vocab1 is None or config.dim_vocab2 is None):
        command.extend(['--dim-vocabs', str(config.dim_vocab1), str(config.dim_vocab2)])
    result = sp.run(command, input=input_line, capture_output=True, text=True)

    eprint(f"Translate input: {input_line}")
    for label, text in (("STDOUT:", result.stdout), ("STDERR:", result.stderr)):
        eprint(label)
        eprint(text)
    output_lines = result.stdout.splitlines()
    eprint(output_lines)
    return output_lines

def output_costs_and_translations(costs_and_translations, output_costs, output_transl):
    """Write the collected costs and translations, one item per line.

    Args:
        costs_and_translations: Iterable of ``(costs, translations)`` pairs.
        output_costs: Writable text file for the costs, or None to skip them.
        output_transl: Writable text file for the translations, or None to
            skip them.
    """
    eprint("COSTS AND TRANSLATIONS:")
    eprint(costs_and_translations)

    if output_costs is not None:
        flat_costs = []
        for costs, _ in costs_and_translations:
            flat_costs.extend(costs)
        output_costs.writelines(cost + '\n' for cost in flat_costs)
    if output_transl is not None:
        flat_translations = []
        for _, translations in costs_and_translations:
            flat_translations.extend(translations)
        output_transl.writelines(line + '\n' for line in flat_translations)


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion tests/_self-adaptive/oracle.bleu.expected
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BLEU = 79.14, 96.0/90.9/85.7/81.8 (BP=0.895, ratio=0.900, hyp_len=99, ref_len=110)
BLEU = 28.81, 51.4/33.7/24.1/17.1 (BP=0.991, ratio=0.991, hyp_len=109, ref_len=110)
18 changes: 9 additions & 9 deletions tests/_self-adaptive/oracle.expected
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
klicken und ziehen
die linke Maustaste für einen Rechtsklick gedrückt halten .
Sie können rechts@@ klicken , indem Sie die linke Maustaste gedrückt halten .
aktivieren Sie &@@ lt@@ ; gu@@ i > .
warum sollte ich meine E-Mail-@@ Konten oder sozialen Medien zu Ihrem Desktop hinzufügen ?
drücken Sie die linke Maustaste und klicken Sie mit der rechten Maustaste .
Sie können mit Rechtsklick die linke Maustaste gedrückt halten .
wechseln Sie &@@ lt@@ ; gu@@ i > Sim@@ ulated Secondary Click &@@ lt@@ ; / gu@@ i > weiter .
warum fügen Sie Ihre E-Mail oder Social Media auf Ihren Desktop ?
warum sollte ich ein Konto hinzufügen ?
Anwendungen entfernen , die Sie nicht mehr benötigen
ihre Sicherung überprüfen
die Anwendungen , die Sie nutzen , können durch Neu@@ installation nach einem schwerwiegenden Rechner@@ problem meist schnell wiederhergestellt werden .
sichern Ihrer wichtigen Dateien
was ist die " Super " -Taste ?
entfernen Sie die Software , die Sie nicht mehr verwenden .
prüfen Sie Ihre Sicherung
die von Ihnen verwendete Software kann in der Regel recht schnell wieder hergestellt werden .
zurück eure wichtigen Akten
was ist der " Super " -@@ Schlüssel ?
65 changes: 65 additions & 0 deletions tests/_self-adaptive/regenerate-expected-outputs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/bin/bash
# Regenerates the *.expected files of the self-adaptive tests with gen-costs.py.
# Run it from the directory that contains gen-costs.py and the ubuntu.* data.
#
# Paths can be overridden through the environment:
#   MRT_MODELS  models directory            (default: ../../models)
#   MRT_TOOLS   tools directory             (default: ../../tools)
#   MRT_MARIAN  directory with the marian and marian-decoder binaries
#               (default: $HOME/prog/cpp/marian-adaptive/build)
set -euo pipefail

MRT_MODELS=${MRT_MODELS:-../../models}
MRT_TOOLS=${MRT_TOOLS:-../../tools}
MARIAN_DIR=${MRT_MARIAN:-$HOME/prog/cpp/marian-adaptive/build}

MODELS=$MRT_MODELS/wmt16_systems/en-de

echo "### Generating files for the oracle tests"
./gen-costs.py \
  -t ubuntu.oracle_2s1e.{src,ref} \
  -m "$MODELS"/model.npz \
  --type amun \
  --dim-vocabs 85000 85000 \
  --dim-emb 500 \
  -v "$MODELS"/vocab.{en,de}.json \
  -e 1 \
  --marian-dir "$MARIAN_DIR" \
  -i ubuntu.src \
  --output-costs costs.expected \
  --output-transl oracle.expected

# Generate BLEU
"$MRT_TOOLS"/moses-scripts/scripts/generic/multi-bleu.perl -lc ubuntu.ref < oracle.expected > oracle.bleu.expected

echo -e "\n\n### Generating files for the partial context tests"
./gen-costs.py \
  -t ubuntu.contextpart.{src,ref} \
  -m "$MODELS"/model.npz \
  --type amun \
  --dim-vocabs 85000 85000 \
  --dim-emb 500 \
  -v "$MODELS"/vocab.{en,de}.json \
  -e 1 \
  --marian-dir "$MARIAN_DIR" \
  -i ubuntu.src \
  --output-costs contextpart.costs.expected \
  --output-transl contextpart.expected


echo -e "\n\n### Generating files for the no context tests"
./gen-costs.py \
  -t ubuntu.nocontext.{src,ref} \
  -m "$MODELS"/model.npz \
  --type amun \
  --dim-vocabs 85000 85000 \
  --dim-emb 500 \
  -v "$MODELS"/vocab.{en,de}.json \
  -e 1 \
  --marian-dir "$MARIAN_DIR" \
  -i ubuntu.src \
  --output-transl nocontext.expected


echo -e "\n\n### Generating files for the transformer partial context tests"
./gen-costs.py \
  -t ubuntu.contextpart.{src,ref} \
  -m "$MRT_MODELS"/transformer/model.npz \
  --type transformer \
  -v "$MRT_MODELS"/transformer/vocab.ende.yml{,} \
  -e 1 \
  --marian-dir "$MARIAN_DIR" \
  -i ubuntu.src \
  --output-costs transformer.contextpart.costs.expected \
  --output-transl transformer.contextpart.expected
24 changes: 24 additions & 0 deletions tests/_self-adaptive/test_context_partial_transformer.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash

# Self-adaptive test with partial context using the transformer model.
# All artifacts of this test carry "transformer" in their file name so they do
# not collide with the non-transformer partial-context test.

# Exit on error
set -e

# Remove the log of a previous run (the same file is passed to --log below)
rm -f contextpart.transformer.log

# Run Marian
$MRT_MARIAN/marian-adaptive \
  -m $MRT_MODELS/transformer/model.npz \
  -v $MRT_MODELS/transformer/vocab.ende.yml -v $MRT_MODELS/transformer/vocab.ende.yml \
  --after-epochs 1 \
  -t ubuntu.contextpart.src ubuntu.contextpart.ref --log contextpart.transformer.log < ubuntu.src > contextpart.transformer.out

# Check outputs: compare what this run produced with the transformer expectations
$MRT_TOOLS/diff.sh contextpart.transformer.out transformer.contextpart.expected > contextpart.transformer.diff

# Check costs: extract them from this run's log, not from another test's log
$MRT_TOOLS/extract-costs.sh < contextpart.transformer.log > contextpart.costs.transformer.out
$MRT_TOOLS/diff-nums.py -p 0.01 contextpart.costs.transformer.out transformer.contextpart.costs.expected -o contextpart.costs.transformer.diff

# Exit with success code
exit 0
Loading