@sudo-panda
Created September 22, 2021 10:13
This gist contains a small neural network written in Python with JAX.
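In brief: each output is a single sigmoid neuron (one weight per input plus a trailing bias term), trained one example at a time by gradient descent via jax.grad and jax.vmap on a synthetic, linearly separable two-class dataset, while jax.profiler records a trace for TensorBoard.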
#!/usr/bin/env python3
import random

import jax.numpy as jnp
import jax.profiler as jpr
import jax.random as jrn
from jax import grad, vmap
def activation_func(x: jnp.ndarray) -> jnp.ndarray:
    # Sigmoid activation.
    return 1.0 / (1.0 + jnp.exp(-x))
def neuron_output(weight: jnp.ndarray, inp: jnp.ndarray) -> float:
    # Weighted sum of the inputs; the last element of `weight` is the bias.
    # zip() stops at the end of `inp`, so the bias weight is skipped in the loop.
    neuron_out = 0.0
    for w, i in zip(weight, inp):
        neuron_out = neuron_out + w * i
    return neuron_out + weight[-1]
def loss_func(expected_output: float, weight: jnp.ndarray, inp: jnp.ndarray) -> float:
    # Squared error between the target and the activated neuron output.
    activated_neuron_output = activation_func(neuron_output(weight, inp))
    return (expected_output - activated_neuron_output) ** 2
class Network:
    def __init__(self, input_neurons: int, output_neurons: int, learning_rate: float):
        assert input_neurons > 0
        assert output_neurons > 0
        self.input_dim: int = input_neurons
        self.output_dim: int = output_neurons
        self.lr: float = learning_rate
        self.output_values = []
        self.input_values: jnp.ndarray = jnp.array([])
        # Gradient of the loss w.r.t. the weights, and of the activation w.r.t. its input.
        self.error_backpropagate = grad(loss_func, argnums=1)
        self.af_backpropagate = grad(activation_func, argnums=0)
        # One row of weights per output neuron; the extra column is the bias.
        self.weights: jnp.ndarray = jrn.uniform(jrn.PRNGKey(0), shape=(self.output_dim, self.input_dim + 1))
    def feed_forward(self, inp: jnp.ndarray) -> list:
        assert self.input_dim == len(inp)
        self.input_values = inp
        # Activation of every output neuron for this input.
        self.output_values = [activation_func(neuron_output(w, inp)) for w in self.weights]
        return self.output_values
    def calculate_weights(self, exp_value, local_weights):
        # Gradient step for one output neuron (mapped over neurons via vmap).
        activation_backprop = self.af_backpropagate(exp_value)
        network_backprop = self.error_backpropagate(exp_value, local_weights, self.input_values)
        return local_weights - self.lr * network_backprop * activation_backprop
    def back_propagate(self, exp_values: jnp.ndarray):
        assert len(exp_values) == len(self.output_values)
        self.weights = vmap(self.calculate_weights)(exp_values, self.weights)
    def train(self, train_data: jnp.ndarray, nr_epochs: int = 10000):
        for i in range(nr_epochs):
            # Pick a random training example (randrange avoids the off-by-one
            # of randint, whose upper bound is inclusive).
            ind: int = random.randrange(len(train_data))
            self.feed_forward([train_data[ind][0], train_data[ind][1]])
            self.back_propagate(jnp.asarray([train_data[ind][2]]))
            if i % 10 == 0:
                print(f"iteration {i + 1}: weight update: {self.weights[0][0]}")
    def test(self, test_data: jnp.ndarray):
        acc: float = 0.0
        print(f"after weights: {self.weights[0][0]}")
        for data in test_data:
            output = self.feed_forward([data[0], data[1]])
            # Count a hit when the thresholded prediction matches the label.
            if output[0] > 0.5 and data[2] == 1.0:
                acc += 1
            elif output[0] < 0.5 and data[2] == 0.0:
                acc += 1
        print(f"Accuracy: {acc / len(test_data)}")
def main():
    jpr.start_trace("/tmp/tensorboard")
    network: Network = Network(2, 1, 0.15)
    # Use distinct PRNG keys so the noise is independent of x1; reusing
    # PRNGKey(0) for every draw would make all the uniform samples identical.
    k1, k2, k3, k4 = jrn.split(jrn.PRNGKey(0), 4)
    # Class 1: points above the line x2 = -(5/7) * x1.
    x1 = jrn.uniform(k1, shape=(50, 1))
    x2 = (-(5 / 7) * x1) + jrn.uniform(k2, shape=(50, 1)) + 0.0001
    x3 = jnp.ones((50, 1))
    train_data = jnp.concatenate([x1, x2, x3], axis=1)
    # Class 0: points below the line.
    x1 = jrn.uniform(k3, shape=(50, 1))
    x2 = (-(5 / 7) * x1) - jrn.uniform(k4, shape=(50, 1)) - 0.0001
    x3 = jnp.zeros((50, 1))
    x = jnp.concatenate([x1, x2, x3], axis=1)
    train_data = jnp.concatenate([train_data, x], axis=0)
    # Debug: print the generated dataset.
    # for data in train_data:
    #     print(f"{data[0]}, {data[1]}, {data[2]}")
    network.train(train_data, 100)
    network.test(train_data)
    jpr.stop_trace()
if __name__ == "__main__":
    main()
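A minimal usage sketch of the Network class, assuming the script above is saved as net.py (the filename is hypothetical):

#!/usr/bin/env python3
# Quick check of the Network class from net.py (hypothetical filename).
import jax.numpy as jnp
from net import Network

net = Network(2, 1, 0.15)                      # 2 inputs, 1 sigmoid output
out = net.feed_forward(jnp.array([0.3, 0.6]))  # activations lie in (0, 1)
print(out)

The profiler trace written to /tmp/tensorboard can then be opened in TensorBoard; this assumes TensorBoard and the tensorboard-plugin-profile package are installed.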