A fast, batched Bi-RNN(GRU) encoder & attention decoder implementation in PyTorch

This code was written for PyTorch 0.2. By the time PyTorch released its 1.0 version, there were already plenty of excellent seq2seq learning packages built on PyTorch, such as OpenNMT and AllenNLP; you can learn from their source code.

Usage: please refer to the official PyTorch tutorial on attention-based RNN machine translation, except that this implementation handles batched inputs and uses a slightly different attention mechanism.
The illustrations linked below (and the shape sketch that follows them) help clarify the formula-level differences between the implementations.

PyTorch tutorial's attention decoder illustration:
http://pytorch.org/tutorials/_images/decoder-network.png
PyTorch official seq2seq machine translation tutorial:
http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
Bahdanau attention illustration:
http://images2015.cnblogs.com/blog/670089/201610/670089-20161012111504671-910168246.png
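To make "batched inputs" concrete, here is a minimal, runnable shape sketch using a plain bidirectional nn.GRU; the dimensions and variable names are illustrative assumptions, not this repository's actual API.

```python
import torch
import torch.nn as nn

# Illustrative dimensions only; none of these names come from the repository.
max_src_len, batch_size, emb_size, hidden_size, vocab_size = 10, 4, 64, 128, 5000

src_batch = torch.randint(0, vocab_size, (max_src_len, batch_size))  # (T, B) token ids

embedding = nn.Embedding(vocab_size, emb_size)
encoder_gru = nn.GRU(emb_size, hidden_size, bidirectional=True)

embedded = embedding(src_batch)                          # (T, B, emb_size)
encoder_outputs, encoder_hidden = encoder_gru(embedded)
print(encoder_outputs.shape)   # (T, B, 2 * hidden_size): one vector per source position
print(encoder_hidden.shape)    # (2, B, hidden_size): final state of each direction

# A batched attention decoder then produces one target time step for the whole
# batch per call, attending over encoder_outputs, instead of looping over the
# batch dimension in Python.
```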

The PyTorch tutorial's attention decoder feeds the current "word_embedding" into the attention-weight computation, while in the original paper the weights are supposed to be computed from the "encoder_outputs" (together with the previous decoder state). This repository implements the original attention decoder as described in the paper.
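For concreteness, here is a minimal sketch of Bahdanau-style additive attention whose scores are computed from the previous decoder state and the encoder outputs (rather than from the current word embedding). The module and variable names are illustrative assumptions, not necessarily the classes defined in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    """Bahdanau-style attention: score(s_{t-1}, h_j) = v^T tanh(W [s_{t-1}; h_j])."""
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden:  (B, H)    previous decoder state s_{t-1}
        # encoder_outputs: (T, B, H) encoder output h_j for every source position
        T = encoder_outputs.size(0)
        s = decoder_hidden.unsqueeze(0).expand(T, -1, -1)                # (T, B, H)
        energy = self.v(torch.tanh(
            self.attn(torch.cat([s, encoder_outputs], dim=2))))          # (T, B, 1)
        return F.softmax(energy.squeeze(2).t(), dim=1)                   # (B, T) weights

dec_hidden = torch.randn(4, 128)        # previous decoder state for a batch of 4
enc_outputs = torch.randn(10, 4, 128)   # 10 source positions
attn = AdditiveAttention(hidden_size=128)
weights = attn(dec_hidden, enc_outputs)                                  # (4, 10)
context = torch.bmm(weights.unsqueeze(1), enc_outputs.transpose(0, 1))   # (4, 1, 128)
```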

Update: a dynamic encoder has been added; it does not require the inputs in a batch to be sorted by length.
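The sketch below shows one common way to implement such a dynamic encoder: sort the batch by length inside the forward pass (as pack_padded_sequence expects) and restore the caller's original order afterwards. It is a hypothetical illustration of the technique, not the repository's actual encoder class.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class DynamicEncoder(nn.Module):
    """Bi-GRU encoder that accepts a padded batch whose lengths are in any order.
    (Illustrative sketch; not the repository's actual class.)"""
    def __init__(self, vocab_size, emb_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.gru = nn.GRU(emb_size, hidden_size, bidirectional=True)

    def forward(self, src, lengths):
        # src: (T, B) token ids; lengths: (B,) true lengths, not necessarily sorted
        sorted_len, sort_idx = lengths.sort(descending=True)
        unsort_idx = sort_idx.argsort()

        embedded = self.embedding(src[:, sort_idx])              # reorder the batch dim
        packed = pack_padded_sequence(embedded, sorted_len.tolist())
        packed_out, hidden = self.gru(packed)
        outputs, _ = pad_packed_sequence(packed_out)              # (T, B, 2H)

        # restore the caller's original batch order
        return outputs[:, unsort_idx], hidden[:, unsort_idx]

enc = DynamicEncoder(vocab_size=5000, emb_size=64, hidden_size=128)
src = torch.randint(0, 5000, (10, 4))
out, h = enc(src, torch.tensor([7, 10, 5, 8]))   # lengths need not be pre-sorted
```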


Speed up with batched tensor manipulation

PyTorch supports fetching and assigning individual tensor elements at run time, but doing so is slow, especially on a GPU. In the practical-pytorch tutorial (https://github.com/spro/practical-pytorch), attention values are assigned element-wise; that is perfectly correct (and intuitive given the formulas in the paper), but slow on our GPU. We therefore re-implemented a genuinely batched tensor-manipulation version, which achieves more than a 10x speed-up.
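The core of the change, as a hedged sketch (names and shapes are illustrative, not the repository's exact code): compute all batch x time attention energies in one pass over a (T, B, 2H) tensor and form the context vectors with a single torch.bmm, instead of filling one scalar score at a time in a Python loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, H = 32, 25, 256
dec_hidden = torch.randn(B, H)       # previous decoder state, one per batch element
enc_outputs = torch.randn(T, B, H)   # encoder outputs for every source position
attn = nn.Linear(2 * H, H)
v = torch.randn(H)

# Element-wise version (tutorial style): fill one scalar score at a time.
scores_loop = torch.zeros(B, T)
for b in range(B):
    for t in range(T):
        energy = torch.tanh(attn(torch.cat([dec_hidden[b], enc_outputs[t, b]], dim=0)))
        scores_loop[b, t] = v.dot(energy)

# Batched version: one pass over a (T, B, 2H) tensor, then a batched matmul
# for the contexts -- no Python loop over batch elements or time steps.
s = dec_hidden.unsqueeze(0).expand(T, B, H)
energy = torch.tanh(attn(torch.cat([s, enc_outputs], dim=2)))             # (T, B, H)
scores_batched = energy.matmul(v).t()                                     # (B, T)
weights = F.softmax(scores_batched, dim=1)
context = torch.bmm(weights.unsqueeze(1), enc_outputs.transpose(0, 1))    # (B, 1, H)

assert torch.allclose(scores_loop, scores_batched, atol=1e-5)
```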

This code has worked well in personal projects.
