Skip to content

Commit

Permalink
bug: uncorrectly use get_batch func, change:not use this function now.
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-jin committed Jul 17, 2021
1 parent 48ba358 commit 9f98db0
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions train_embeddings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,14 @@ def get_batch(self, er_vocab, er_vocab_pairs, idx):

def evaluate(self, model, data):
model.eval()
hits = []
hits = [[] for _ in range(10)]
ranks = []
for i in range(10):
hits.append([])

test_data_idxs = self.get_data_idxs(data)
er_vocab = self.get_er_vocab(self.get_data_idxs(d.data))
er_vocab = self.get_er_vocab(test_data_idxs)

print("Number of data points: %d" % len(test_data_idxs))
for i in tqdm(range(0, len(test_data_idxs), self.batch_size)):
data_batch, _ = self.get_batch(er_vocab, test_data_idxs, i)
data_batch = np.array(test_data_idxs[i: i+self.batch_size])
e1_idx = torch.tensor(data_batch[:,0])
r_idx = torch.tensor(data_batch[:,1])
e2_idx = torch.tensor(data_batch[:,2])
Expand Down

0 comments on commit 9f98db0

Please sign in to comment.