Skip to content

Instantly share code, notes, and snippets.

@flaport
Last active September 13, 2018 15:59
Show Gist options
  • Select an option

  • Save flaport/437728ca8e373e8f47563b573d92cbfc to your computer and use it in GitHub Desktop.

Select an option

Save flaport/437728ca8e373e8f47563b573d92cbfc to your computer and use it in GitHub Desktop.
Simple PyTorch RNN
class RNN(torch.nn.Module):
    """Simple pure-PyTorch RNN (Elman) implementation.

    Computes ``state_t = tanh(W_ih x_t + W_hh state_{t-1} + b_hh)`` for every
    time step and returns both the per-step hidden states and the final state.

    Args:
        input_size: number of features in each input vector.
        hidden_size: number of features in the hidden state.
        batch_first: if True, inputs/outputs are laid out as
            (batch, seq, feature); otherwise (seq, batch, feature).
    """

    def __init__(self, input_size, hidden_size, batch_first=False):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        # The input projection carries no bias; the hidden projection holds the
        # single bias term (equivalent to the standard RNN with merged biases).
        self.input_layer = torch.nn.Linear(input_size, hidden_size, bias=False)
        self.hidden_layer = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.activation = torch.nn.Tanh()

    def forward(self, input, state=None):
        """Run the RNN over a full sequence.

        Args:
            input: tensor of shape (seq, batch, input_size), or
                (batch, seq, input_size) when ``batch_first`` is True.
            state: optional initial hidden state of shape (batch, hidden_size);
                defaults to zeros.

        Returns:
            (output, state): the stacked per-step hidden states (same layout
            as the input, with the feature dim equal to hidden_size) and the
            final hidden state.
        """
        input_shape = input.shape
        # Project every time step at once by flattening the leading dims,
        # then restore the original leading shape.
        projected = self.input_layer(
            input.view(-1, input_shape[-1])
        ).view(*input_shape[:-1], -1)
        # Time axis is dim 1 for batch-first layout, dim 0 otherwise; the same
        # axis is used both to split the steps and to restack the outputs.
        time_dim = 1 if self.batch_first else 0
        steps = torch.unbind(projected, time_dim)
        if state is None:
            # Zero initial state; no_grad keeps it out of the autograd graph.
            with torch.no_grad():
                state = torch.zeros_like(steps[0])
        output = []
        for step in steps:
            # Recurrence: state_t = tanh(x_t_projected + W_hh state_{t-1} + b)
            state = self.activation(step + self.hidden_layer(state))
            output.append(state)
        return torch.stack(output, time_dim), state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment