Skip to content

Instantly share code, notes, and snippets.

@AshNguyen
Created May 9, 2020 10:43
Show Gist options
  • Select an option

  • Save AshNguyen/fc3bbe881b65e7fc119683b97e58ac71 to your computer and use it in GitHub Desktop.

Select an option

Save AshNguyen/fc3bbe881b65e7fc119683b97e58ac71 to your computer and use it in GitHub Desktop.
Simple variational autoencoder (VAE) in PyTorch
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    The encoder maps each flattened input sample through one ReLU hidden layer
    to the mean ``mu`` and log-variance ``logvar`` of a Gaussian over
    ``n_latent`` dimensions. The decoder maps a latent vector through one ReLU
    hidden layer back to ``input_dim`` values squashed into (0, 1) by a sigmoid.

    The defaults (``input_dim=784``, ``hidden_dim=400``) reproduce the original
    hard-coded architecture (784 = 28 * 28 flattened pixels). Layer attribute
    names ``fc1``..``fc4`` are unchanged, so existing ``state_dict`` checkpoints
    still load.

    Args:
        n_latent: Dimensionality of the latent space.
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
    """

    def __init__(self, n_latent: int, *, input_dim: int = 784, hidden_dim: int = 400) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, n_latent)   # encoder head -> mu
        self.fc22 = nn.Linear(hidden_dim, n_latent)   # encoder head -> logvar
        self.fc3 = nn.Linear(n_latent, hidden_dim)    # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)   # decoder output layer

    def encode(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``.

        ``x`` must already be flattened so its last dimension is ``input_dim``.
        """
        hidden = F.relu(self.fc1(x))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling through this deterministic transform of ``eps`` keeps the
        result differentiable w.r.t. ``mu`` and ``logvar``. Noise is drawn on
        every call, including in eval mode.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2) -> sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors ``z`` to reconstructions with values in (0, 1)."""
        hidden = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Encode, sample and decode ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``input_dim`` (e.g. ``(batch, 1, 28, 28)`` with the defaults). It is
        flattened to ``(-1, input_dim)``.

        Returns:
            ``(reconstruction, mu, logvar)``. ``reconstruction`` has shape
            ``(batch, input_dim)``; ``mu`` and ``logvar`` have shape
            ``(batch, n_latent)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced batches, are accepted instead of raising.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment