-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_embeddings.py
86 lines (74 loc) · 4.15 KB
/
train_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import gensim
import csv
import io
import sys
import numpy as np
import gzip
import os
import argparse
import logging
from sklearn.manifold import TSNE
from gensim.models.poincare import PoincareModel, PoincareRelations
from gensim.test.utils import datapath
from gensim.models.word2vec import LineSentence
from data_loader import read_all_data, compound_operator, preprocess_com, preprocess_wordnet
import pickle
import itertools
from nltk.corpus import wordnet as wn
import pandas
def main():
    """Parse the command-line options and start the selected training mode.

    Options:
        --mode:     which embeddings to train. Note that the default,
                    'preload', is not one of the listed choices (argparse
                    does not check defaults against ``choices``) and matches
                    none of the branches in ``run``, so invoking the script
                    without ``--mode`` trains nothing.
        --language: language of the taxonomy data (EN, FR, NL or IT).
    """
    parser = argparse.ArgumentParser(description="Embeddings for Taxonomy")
    parser.add_argument(
        '--mode',
        type=str,
        default='preload',
        choices=["train_poincare_wordnet", "train_poincare_custom", "train_word2vec"],
        help="Mode of the system.",
    )
    parser.add_argument(
        '--language',
        type=str,
        default='EN',
        choices=["EN", "FR", "NL", "IT"],
        # The original help text here was a copy of the --mode help.
        help="Language of the taxonomy data.",
    )
    args = parser.parse_args()
    print("Mode: ", args.mode)
    run(args.mode, args.language)
# Domains whose gold taxonomies are loaded, in the original loading order.
_DOMAINS = ("science", "environment", "food")


def _collect_vocabulary(*gold_sets):
    """Return the set of lower-cased terms found in fields 0 and 1 of every relation."""
    vocabulary = set()
    for gold in gold_sets:
        for relation in gold:
            vocabulary.add(relation[0].lower())
            vocabulary.add(relation[1].lower())
    return vocabulary


def _train_poincare(relations_path, output_prefix, dim=50, epochs=400):
    """Train a Poincare model on a relations file and save it as ``<output_prefix>_<dim>``."""
    poincare_rel = PoincareRelations(relations_path)
    model = PoincareModel(poincare_rel, size=dim)
    print("Starting Training...")
    model.train(epochs=epochs)
    model.save(output_prefix + "_" + str(dim))


def run(mode, language):
    """Train and save the embeddings selected by ``mode``.

    Args:
        mode: one of "train_poincare_custom", "train_poincare_wordnet" or
            "train_word2vec". Any other value (including the command-line
            default "preload") does nothing.
        language: language code used to locate the data files. Only the
            custom Poincare mode uses it for paths; the WordNet mode
            requires it to be 'EN'.

    Raises:
        AssertionError: if the custom relations file has 10 or fewer lines,
            or if the WordNet mode is requested for a language other than
            'EN'. Raised explicitly so the checks survive ``python -O``.
    """
    if mode == "train_poincare_custom":
        # NOTE(review): the gold data is not used for training in this mode
        # (the original built a vocabulary from it and never read it). The
        # loads are kept in case read_all_data validates the data files or
        # has other side effects -- confirm before removing them.
        for domain in _DOMAINS:
            read_all_data(domain=domain, language=language)

        relations = "data/" + language + "/poincare_common_and_domains_" + language + ".tsv"
        # Count lines with a context manager so the handle is closed again.
        with open(relations, 'r') as handle:
            relation_count = sum(1 for _ in handle)
        if relation_count <= 10:
            raise AssertionError("Not enough relations to train embeddings. Aborting ...")

        _train_poincare(relations, "embeddings/poincare_common_and_domains_5_3_" + language)

    elif mode == 'train_poincare_wordnet':
        if language != 'EN':
            raise AssertionError("Wordnet consists only of English nouns")

        gold_sets = []
        for domain in _DOMAINS:
            gold, _ = read_all_data(domain=domain)
            gold_sets.append(gold)
        vocabulary = _collect_vocabulary(*gold_sets)

        # preprocess_wordnet is given the closure file and the vocabulary;
        # presumably it writes the *_filtered.tsv read below -- verify in
        # data_loader.
        preprocess_wordnet('data/EN/noun_closure.tsv', vocabulary)
        _train_poincare('data/EN/noun_closure_filtered.tsv', "embeddings/wordnet_filtered")

    elif mode == "train_word2vec":
        gold_sets = []
        for domain in _DOMAINS:
            gold, _ = read_all_data(domain)
            gold_sets.append(gold)
        vocabulary = _collect_vocabulary(*gold_sets)

        # NOTE(review): read_input is neither defined nor imported in the
        # visible part of this file -- confirm where it comes from. The
        # corpus path is also absolute ("/data/..."), unlike the relative
        # "data/..." paths used by the other modes -- confirm it is intended.
        documents = list(read_input("/data/EN/wikipedia_utf8_filtered_20pageviews.csv", vocabulary))
        model = gensim.models.Word2Vec(documents, size=300, window=10, min_count=2, workers=10)
        # Additional training pass on top of whatever the constructor did
        # with `documents`; kept exactly as in the original.
        model.train(documents, total_examples=len(documents), epochs=30)
        print("Finished building word2vec model")
        model.save("embeddings/own_embeddings_w2v")
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()