Commit

Update run_text2text_infer.py (#373)
Eric8932 authored Aug 24, 2023
1 parent 805ab4c · commit ba8f95a
Showing 1 changed file with 7 additions and 8 deletions.
inference/run_text2text_infer.py (7 additions, 8 deletions)
@@ -28,7 +28,7 @@ def read_dataset(args, path):
                 for i, column_name in enumerate(line.rstrip("\r\n").split("\t")):
                     columns[column_name] = i
                 continue
-            line = line.rstrip("\r\n").split('\t')
+            line = line.rstrip("\r\n").split("\t")
 
             if "text_b" in columns:
                 text = line[columns["text_a"]] + SEP_TOKEN + line[columns["text_b"]]
@@ -91,28 +91,27 @@ def main():
     print("The number of prediction instances: ", instances_num)
 
     model.eval()
 
     with open(args.prediction_path, mode="w", encoding="utf-8") as f:
         f.write("label")
         f.write("\n")
         for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
             src_batch = src_batch.to(args.device)
             seg_batch = seg_batch.to(args.device)
             tgt_in_batch = torch.zeros(src_batch.size()[0], 1, dtype = torch.long, device = args.device)
-            for j in range(tgt_in_batch.size()[0]):
+            tgt_seg_batch = torch.ones(tgt_in_batch.size()[0], 1, dtype = torch.long, device = args.device)
+            current_batch_size = tgt_in_batch.size()[0]
+            for j in range(current_batch_size):
                 tgt_in_batch[j][-1] = args.tokenizer.vocab.get(CLS_TOKEN)
-
             with torch.no_grad():
-                memory_bank = model(src_batch, None, seg_batch, only_use_encoder=True)
-
+                memory_bank = model(src_batch, None, seg_batch, tgt_seg_batch, only_use_encoder=True)
             for _ in range(args.tgt_seq_length):
                 with torch.no_grad():
-                    outputs = model(src_batch, (tgt_in_batch, None, src_batch), None, memory_bank=memory_bank)
-
+                    outputs = model(src_batch, (tgt_in_batch, None, src_batch), None, tgt_seg_batch, memory_bank=memory_bank)
                 next_token_logits = outputs[:, -1]
                 next_tokens = torch.argmax(next_token_logits, dim=1).unsqueeze(1)
                 tgt_in_batch = torch.cat([tgt_in_batch, next_tokens], dim=1)
-
+                tgt_seg_batch = torch.ones(tgt_in_batch.size()[0], tgt_in_batch.size()[1], dtype=torch.long, device=args.device)
             for j in range(len(outputs)):
                 f.write("".join([args.tokenizer.inv_vocab[token_id.item()] for token_id in tgt_in_batch[j][1:]])
                         .split(SEP_TOKEN)[0])
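Note on the second hunk: the change threads a target segment tensor (tgt_seg_batch) through both the encoder-only forward pass and every decoding step, and rebuilds it to match the growing target after each generated token. The sketch below is a minimal, self-contained illustration of that greedy-decoding bookkeeping; stub_decoder, greedy_decode, CLS_ID, VOCAB_SIZE, and TGT_SEQ_LENGTH are hypothetical stand-ins and not the repository's API, and only the tensor handling mirrors the diff above.

import torch

CLS_ID = 101          # hypothetical id of the CLS/start token
VOCAB_SIZE = 1000     # hypothetical vocabulary size
TGT_SEQ_LENGTH = 8    # number of greedy decoding steps

def stub_decoder(src, tgt_in, tgt_seg, memory_bank):
    # Hypothetical stand-in for the model's decoder call: returns logits of
    # shape (batch, tgt_len, vocab), which is all the loop below relies on.
    return torch.randn(tgt_in.size(0), tgt_in.size(1), VOCAB_SIZE)

def greedy_decode(src_batch):
    batch_size = src_batch.size(0)
    # Pretend encoder output; in the script this comes from the
    # only_use_encoder=True forward pass.
    memory_bank = torch.randn(batch_size, src_batch.size(1), 16)
    # Start every target with the CLS token; the segment tensor starts as one column of ones.
    tgt_in = torch.full((batch_size, 1), CLS_ID, dtype=torch.long)
    tgt_seg = torch.ones(batch_size, 1, dtype=torch.long)
    for _ in range(TGT_SEQ_LENGTH):
        with torch.no_grad():
            outputs = stub_decoder(src_batch, tgt_in, tgt_seg, memory_bank)
        # Pick the most likely next token from the last position and append it.
        next_tokens = torch.argmax(outputs[:, -1], dim=1).unsqueeze(1)
        tgt_in = torch.cat([tgt_in, next_tokens], dim=1)
        # Rebuild the segment tensor so it always matches the target's length,
        # mirroring the tgt_seg_batch update added in this commit.
        tgt_seg = torch.ones(tgt_in.size(0), tgt_in.size(1), dtype=torch.long)
    return tgt_in

print(greedy_decode(torch.randint(0, VOCAB_SIZE, (2, 5))).shape)  # torch.Size([2, 9])

Recreating tgt_seg from scratch each step is cheap relative to the decoder call and keeps its shape trivially in sync with tgt_in, which is the same trade-off the updated script makes.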
