Created September 22, 2021 10:13
Save sudo-panda/44f61d8db98ecc5f3c1d8ac82c881a62 to your computer and use it in GitHub Desktop.
This gist contains a small neural network written in Python with JAX.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| from math import exp | |
| import jax.numpy as jnp | |
| import jax.random as jrn | |
| import jax.profiler as jpr | |
| import random | |
| from jax import grad, vmap | |
| from typing import List | |
def activation_func(x: jnp.ndarray) -> jnp.ndarray:
    """Apply the logistic sigmoid element-wise.

    Args:
        x: Pre-activation value(s).

    Returns:
        ``1 / (1 + exp(-x))``, with the same shape as ``x``.
    """
    denominator = jnp.exp(-x) + 1.0
    return 1.0 / denominator
def neuron_output(weight: jnp.ndarray, inp: jnp.ndarray) -> jnp.ndarray:
    """Compute the affine pre-activation of a single neuron.

    The last entry of ``weight`` is the bias; the leading entries are the
    per-input weights.

    Args:
        weight: Vector of ``len(inp) + 1`` values (input weights, then bias).
        inp: Input vector (array or sequence of scalars).

    Returns:
        Scalar ``dot(weight[:-1], inp) + weight[-1]``.
        A length mismatch between ``weight[:-1]`` and ``inp`` raises an error
        in ``jnp.dot`` instead of being silently truncated.
    """
    weight = jnp.asarray(weight)
    inp = jnp.asarray(inp)
    # Slice the bias off explicitly rather than relying on zip() truncation,
    # which would fold the bias into the weighted sum if inp were too long.
    return jnp.dot(weight[:-1], inp) + weight[-1]
def loss_func(expected_output: float, weight: jnp.ndarray, inp: jnp.ndarray) -> float:
    """Squared error between a target and one neuron's sigmoid output.

    Args:
        expected_output: Target value for the neuron.
        weight: Neuron weights with the bias last (as consumed by
            ``neuron_output``).
        inp: Input vector fed to the neuron.

    Returns:
        ``(expected_output - activation_func(neuron_output(weight, inp))) ** 2``.
    """
    prediction = activation_func(neuron_output(weight, inp))
    residual = expected_output - prediction
    return residual ** 2
class Network:
    """Single-layer, fully connected sigmoid network trained by per-sample SGD.

    Each of the ``output_dim`` neurons owns one row of ``self.weights``:
    ``input_dim`` input weights followed by a trailing bias.
    """

    def __init__(self, input_neurons: int, output_neurons: int, learning_rate: float):
        """Create the network with weights drawn from ``jrn.uniform``.

        Args:
            input_neurons: Number of inputs per sample (must be > 0).
            output_neurons: Number of output neurons (must be > 0).
            learning_rate: Step size used by ``back_propagate``.
        """
        assert (input_neurons > 0)
        assert (output_neurons > 0)
        self.input_dim: int = input_neurons
        self.output_dim: int = output_neurons
        self.lr: float = learning_rate
        # Cached by feed_forward; back_propagate reads them.
        self.output_values: jnp.ndarray = jnp.array([])
        self.input_values: jnp.ndarray = jnp.array([])
        # d(loss)/d(weights) for one neuron. loss_func applies activation_func,
        # so this gradient already contains the sigmoid derivative.
        self.error_backpropagate = grad(loss_func, argnums=1)
        # Kept for backward compatibility; the weight update does not need it.
        self.af_backpropagate = grad(activation_func, argnums=0)
        # One row per output neuron: input weights followed by the bias.
        self.weights: jnp.ndarray = jrn.uniform(
            jrn.PRNGKey(0), shape=(self.output_dim, self.input_dim + 1)
        )

    def _split_sample(self, row):
        """Split a data row into (inputs, targets) using input/output dims."""
        inputs = row[: self.input_dim]
        targets = row[self.input_dim: self.input_dim + self.output_dim]
        return inputs, targets

    def feed_forward(self, inp: jnp.ndarray) -> jnp.ndarray:
        """Run one sample through the network and cache inputs and outputs.

        Args:
            inp: Input vector of length ``input_dim``.

        Returns:
            Array of shape ``(output_dim,)`` holding each neuron's sigmoid
            output (also stored in ``self.output_values``).
        """
        assert (self.input_dim == len(inp))
        self.input_values = inp
        self.output_values = jnp.stack(
            [activation_func(neuron_output(weight, inp)) for weight in self.weights]
        )
        return self.output_values

    def calculate_weights(self, exp_value, local_weights):
        """Return one neuron's weights after a single gradient-descent step.

        Args:
            exp_value: Target output for this neuron.
            local_weights: This neuron's current weights (bias last).

        Returns:
            ``local_weights - lr * d(loss)/d(local_weights)`` evaluated at the
            inputs cached by the last ``feed_forward`` call.
        """
        # error_backpropagate is the complete gradient of the squared error;
        # the chain rule through the sigmoid is already included, so no
        # separate activation-derivative factor is applied here.
        network_backprop = self.error_backpropagate(exp_value, local_weights, self.input_values)
        return local_weights - self.lr * network_backprop

    def back_propagate(self, exp_values: jnp.ndarray):
        """Update every neuron's weights toward the given targets.

        Must follow a ``feed_forward`` call, whose cached inputs are used
        to compute the gradient.

        Args:
            exp_values: Targets, one per output neuron.
        """
        exp_values = jnp.asarray(exp_values)
        assert (len(exp_values) == len(self.output_values))
        # vmap pairs each target with its neuron's weight row.
        self.weights = vmap(self.calculate_weights)(exp_values, self.weights)

    def train(self, train_data: jnp.ndarray, nr_epochs: int = 10000):
        """Train with SGD on rows drawn uniformly at random (with replacement).

        Args:
            train_data: 2-D array; each row holds ``input_dim`` inputs followed
                by ``output_dim`` targets.
            nr_epochs: Number of single-sample updates to perform.
        """
        for i in range(nr_epochs):
            # randrange excludes the upper bound. randint(0, len) could return
            # len(train_data), which JAX indexing silently clamps to the last row.
            ind: int = random.randrange(len(train_data))
            inputs, targets = self._split_sample(train_data[ind])
            self.feed_forward(inputs)
            self.back_propagate(targets)
            if i % 10 == 0:
                print(f"{i+1} iteration: weight update: {self.weights[0][0]}")

    def test(self, train_data: jnp.ndarray):
        """Print the classification accuracy of output 0 over ``train_data``.

        A row counts as correct when output 0 is above 0.5 for a target of
        1.0, or below 0.5 for a target of 0.0.

        Args:
            train_data: Rows laid out as for ``train``.
        """
        acc: float = 0.0
        print(f"after weights: {self.weights[0][0]}")
        for data in train_data:
            inputs, targets = self._split_sample(data)
            output: jnp.ndarray = self.feed_forward(inputs)
            if output[0] > 0.5 and targets[0] == 1.0:
                acc += 1
            elif output[0] < 0.5 and targets[0] == 0.0:
                acc += 1
        print(f"Accuracy: {acc / len(train_data)}")
def _make_linear_split_data(key, n_per_class: int) -> jnp.ndarray:
    """Build a linearly separable 2-D dataset of ``2 * n_per_class`` rows.

    Each row is ``[x1, x2, label]``. Label-1 rows lie above the line
    ``x2 = -(5/7) * x1`` and label-0 rows lie below it. Rows are offset from
    the line by a ``jrn.uniform`` draw plus a 1e-4 margin.
    """
    # Use a distinct subkey for every draw. Reusing one PRNGKey makes JAX
    # return identical samples, so x1 and the offsets would be the same values.
    k_x1_pos, k_off_pos, k_x1_neg, k_off_neg = jrn.split(key, 4)

    def half(k_x1, k_off, sign: float, label: float) -> jnp.ndarray:
        # One class: random x1, x2 pushed to one side of the line.
        x1 = jrn.uniform(k_x1, shape=(n_per_class, 1))
        offset = jrn.uniform(k_off, shape=(n_per_class, 1)) + 0.0001
        x2 = -(5 / 7) * x1 + sign * offset
        labels = jnp.full((n_per_class, 1), label)
        return jnp.concatenate([x1, x2, labels], axis=1)

    positives = half(k_x1_pos, k_off_pos, 1.0, 1.0)
    negatives = half(k_x1_neg, k_off_neg, -1.0, 0.0)
    return jnp.concatenate([positives, negatives], axis=0)


def main():
    """Train and evaluate a 2-input, 1-output network on synthetic data.

    A JAX profiler trace is written to ``/tmp/tensorboard``.
    """
    jpr.start_trace("/tmp/tensorboard")
    try:
        network: Network = Network(2, 1, 0.15)
        train_data = _make_linear_split_data(jrn.PRNGKey(0), 50)
        network.train(train_data, 100)
        network.test(train_data)
    finally:
        # Close the trace even if training or evaluation raises.
        jpr.stop_trace()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.