Efficient LSTM cell in Torch
```lua
--[[
Efficient LSTM in Torch using the nngraph library. This code was optimized
by Justin Johnson (@jcjohnson) based on the trick of batching up the
LSTM GEMMs, as also seen in my efficient Python LSTM gist.
--]]
require 'nn'
require 'nngraph'

local LSTM = {}

-- Builds one LSTM time step as an nngraph module.
-- Inputs:  x (batch x input_size), prev_c and prev_h (batch x rnn_size).
-- Outputs: next_c and next_h (batch x rnn_size).
function LSTM.fast_lstm(input_size, rnn_size)
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- Pre-activations of all four gates come from one Linear (one GEMM)
  -- per input, instead of four separate Linears per input.
  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local all_input_sums = nn.CAddTable()({i2h, h2h})

  -- The first 3 * rnn_size columns (dim 2, so a batch dimension is assumed)
  -- hold the input, forget and output gates; one Sigmoid covers all three.
  local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
  sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
  local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)

  -- The last rnn_size columns are the candidate cell values, squashed by Tanh.
  local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
  in_transform = nn.Tanh()(in_transform)

  -- next_c = forget_gate * prev_c + in_gate * in_transform
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  -- next_h = out_gate * tanh(next_c)
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM
```
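The batching trick itself does not depend on Torch: one matrix product per input yields all 4 * H gate pre-activations, which are then sliced by offset using the same layout as `fast_lstm` (input gate, forget gate, output gate, candidate). The plain-Lua sketch below reproduces a single LSTM step for one unbatched example under that assumption. The names `matvec`, `lstm_step`, `Wx`, `Wh` and `b` are hypothetical and not part of the gist, and a single bias vector `b` stands in for the two `nn.Linear` biases, which simply add together.

```lua
-- Plain-Lua sketch of one fused-gate LSTM step (hypothetical helpers, no Torch).
-- Vectors are Lua arrays; matrices are arrays of rows.

local function sigmoid(z)
  return 1 / (1 + math.exp(-z))
end

-- Numerically stable tanh (math.tanh is deprecated in Lua 5.3 and later).
local function tanh(z)
  if z >= 0 then
    local e = math.exp(-2 * z)
    return (1 - e) / (1 + e)
  else
    local e = math.exp(2 * z)
    return (e - 1) / (e + 1)
  end
end

-- y = M * v for an (m x n) matrix M and a length-n vector v.
local function matvec(M, v)
  local y = {}
  for r = 1, #M do
    local s = 0
    for k = 1, #v do
      s = s + M[r][k] * v[k]
    end
    y[r] = s
  end
  return y
end

-- One LSTM step with fused weights: Wx is (4H x D), Wh is (4H x H), b has 4H entries.
local function lstm_step(x, prev_c, prev_h, Wx, Wh, b)
  local H = #prev_c
  -- One product per input gives all 4H gate pre-activations at once.
  local xs, hs = matvec(Wx, x), matvec(Wh, prev_h)
  local sums = {}
  for k = 1, 4 * H do
    sums[k] = xs[k] + hs[k] + b[k]
  end

  -- Slice the fused vector with the same offsets as the nn.Narrow calls.
  local next_c, next_h = {}, {}
  for j = 1, H do
    local i = sigmoid(sums[j])          -- input gate
    local f = sigmoid(sums[H + j])      -- forget gate
    local o = sigmoid(sums[2 * H + j])  -- output gate
    local g = tanh(sums[3 * H + j])     -- candidate cell value
    next_c[j] = f * prev_c[j] + i * g
    next_h[j] = o * tanh(next_c[j])
  end
  return next_c, next_h
end

-- Usage: H = 2 hidden units, D = 1 input, all weights and biases zero.
-- Every pre-activation is 0, so each gate is 0.5 and the candidate is 0,
-- which makes next_c = 0.5 * prev_c.
local Wx, Wh = {}, {}
for r = 1, 8 do
  Wx[r] = {0}
  Wh[r] = {0, 0}
end
local b = {0, 0, 0, 0, 0, 0, 0, 0}
local c1, h1 = lstm_step({1}, {1, 2}, {0, 0}, Wx, Wh, b)
print(c1[1], c1[2], h1[1], h1[2])
```

In Torch itself, the module returned by `LSTM.fast_lstm` is called with a table of inputs, `cell:forward({x, prev_c, prev_h})`, and returns the table `{next_c, next_h}`; the batched version replaces each `matvec` with one matrix multiply over the whole batch.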