INF201/MNIST/main.py

108 lines
3.4 KiB
Python
Raw Normal View History

2023-10-30 13:11:48 +00:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 08:23:56 2023
@author: Mohamad Mohannad al Kawadri (mohamad.mohannad.al.kawadri@nmbu.no), Trygve Børte Nomeland (trygve.borte.nomeland@nmbu.no)
"""
from abc import ABC, abstractmethod
import numpy as np
from copy import deepcopy
from torchvision import datasets, transforms
class Network:
    """Chain of layers whose weights and biases are loaded from files.

    Every entry of ``layers`` is called as
    ``layer(x, W_file=..., b_file=...)`` and the returned object's
    ``run()`` method supplies the input for the next entry.
    """

    def __init__(self, layers, W_file_list, b_file_list):
        """Store the layer factories and their parameter file names.

        Args:
            layers: Layer classes (or compatible callables), in the order
                they are applied.
            W_file_list: Weight file for each layer, in the same order.
            b_file_list: Bias file for each layer, in the same order.
        """
        self.layers = layers
        self.W_file_list = W_file_list
        self.b_file_list = b_file_list
        # No input exists at construction time (it is passed to run() and
        # evaluate()), so do not bind the builtin ``input`` function here.
        self.x = None

    def run(self, x):
        """Propagate ``x`` through all layers and return the last output.

        The three lists are walked in lockstep with ``zip``, so iteration
        stops at the shortest of them.
        """
        result = x
        for layer_cls, W_file, b_file in zip(
            self.layers, self.W_file_list, self.b_file_list
        ):
            # Each layer receives its own copy so it cannot modify the
            # previous layer's output (or the caller's input) in place.
            current = layer_cls(deepcopy(result), W_file=W_file, b_file=b_file)
            result = current.run()
        return result

    def evaluate(self, x, expected_value):
        """Return True if the largest output sits at index ``expected_value``.

        On ties the first maximal index is used.
        """
        output = np.asarray(self.run(x))
        return int(np.argmax(output)) == expected_value
class Layer(ABC):
    """Abstract base class for one network layer.

    The constructor loads the weights and bias through ``read`` and keeps
    the input; subclasses implement ``run`` to produce the layer output.
    Deriving from ``ABC`` is what makes ``@abstractmethod`` effective:
    ``Layer`` itself can no longer be instantiated, only subclasses that
    define ``run``.
    """

    def __init__(self, x, W_file, b_file):
        """Keep the input and load the layer parameters.

        Args:
            x: Input handed to this layer.
            W_file: File passed to ``read`` for the weights.
            b_file: File passed to ``read`` for the bias.
        """
        self.x = x
        params = read(W_file, b_file)
        self.W = params.get('W')
        self.b = params.get('b')

    @abstractmethod
    def run(self):
        """Return this layer's output; must be provided by subclasses."""
class SigmaLayer(Layer):
    """Layer whose output is produced by the module-level ``layer`` function."""

    def run(self):
        """Return ``layer`` applied to the stored weights, input and bias."""
        weights, inputs, bias = self.W, self.x, self.b
        return layer(weights, inputs, bias)
class ReluLayer(Layer):
    """Layer whose output is produced by the module-level ``relu_layer`` function."""

    def run(self):
        """Return ``relu_layer`` applied to the stored weights, input and bias."""
        weights, inputs, bias = self.W, self.x, self.b
        return relu_layer(weights, inputs, bias)
def read(W_file, b_file):
    """Load weights and bias from text files with ``np.loadtxt``.

    Args:
        W_file: File name (or file-like object) holding the weights.
        b_file: File name (or file-like object) holding the bias.

    Returns:
        A dict with the loaded arrays under the keys ``'W'`` and ``'b'``.
    """
    weights = np.loadtxt(W_file)
    bias = np.loadtxt(b_file)
    return {'W': weights, 'b': bias}
# define activation function
def sigma(y):
    """Scalar activation: return ``y`` when it is positive, otherwise 0."""
    return y if y > 0 else 0


def sigma_vec(y):
    """Apply the ``sigma`` activation element-wise to an array.

    Implemented with ``np.maximum`` instead of ``np.vectorize(sigma)``.
    ``np.vectorize`` infers the output dtype from the result for the first
    element, so whenever that element was non-positive (integer ``0``
    returned) all remaining activations were truncated to integers. It
    also refuses empty inputs. ``np.maximum`` keeps the input dtype and
    handles empty arrays. NaN entries propagate as NaN.
    """
    return np.maximum(y, 0)
def relu_scalar(x):
    """Scalar ReLU: return ``x`` when it is positive, otherwise 0."""
    return x if x > 0 else 0


def relu(x):
    """Apply ReLU element-wise to an array.

    Implemented with ``np.maximum`` instead of ``np.vectorize(relu_scalar)``.
    ``np.vectorize`` infers the output dtype from the result for the first
    element, so a non-positive first entry (integer ``0`` returned) made
    every other value get truncated to an integer. It also refuses empty
    inputs. ``np.maximum`` keeps the input dtype and handles empty arrays.
    NaN entries propagate as NaN.
    """
    return np.maximum(x, 0)
# define layer function for given weight matrix, input and bias
def layer(W, x, b):
    """Return ``sigma_vec`` applied to the affine map ``W @ x + b``.

    Args:
        W: Weight matrix.
        x: Layer input.
        b: Bias added to ``W @ x``.
    """
    pre_activation = W @ x + b
    return sigma_vec(pre_activation)
def relu_layer(W, x, b):
    """Return the ReLU activation of the affine map ``W @ x + b``.

    The rectification is done directly with ``np.maximum`` rather than
    through the vectorized ``sigma`` helper. This keeps the function true
    to its name and preserves the floating-point dtype: the
    ``np.vectorize``-based helper truncated the activations to integers
    whenever the first pre-activation was non-positive.

    Args:
        W: Weight matrix.
        x: Layer input.
        b: Bias added to ``W @ x``.
    """
    pre_activation = W @ x + b
    return np.maximum(pre_activation, 0)
# Function from example file "read.py"
def get_mnist():
    """Return the torchvision MNIST training split (``train=True``).

    Images are transformed with ``transforms.ToTensor()``. The data lives
    under ``./data`` and ``download=True`` is passed to torchvision.
    """
    to_tensor = transforms.ToTensor()
    return datasets.MNIST(
        root='./data',
        train=True,
        transform=to_tensor,
        download=True,
    )
# Function from example file "read.py"
def return_image(image_index, mnist_dataset):
    """Fetch one sample from the dataset as NumPy data.

    Args:
        image_index: Index of the sample in ``mnist_dataset``.
        mnist_dataset: Indexable dataset yielding ``(image, label)`` pairs,
            where ``image`` is a tensor.

    Returns:
        A tuple ``(flat, image_matrix, label)``: the first channel of the
        image flattened to one dimension, the same channel unflattened,
        and the sample's label.
    """
    sample = mnist_dataset[image_index]
    image, label = sample
    # Grayscale image, so only the first channel (index 0) is used.
    first_channel = image[0]
    image_matrix = first_channel.detach().numpy()
    flat = image_matrix.reshape(-1)
    return flat, image_matrix, label
def evalualte_on_mnist(image_index, expected_value):
    """Check the network's prediction for one MNIST image.

    Loads the dataset, builds the three-layer ReLU network from the
    ``W_*.txt`` / ``b_*.txt`` files and returns the result of
    ``Network.evaluate`` for the flattened image and ``expected_value``.
    """
    dataset = get_mnist()
    # Only the flattened vector is needed; the matrix and label are unused.
    x, _matrix, _label = return_image(image_index, dataset)
    layer_types = [ReluLayer, ReluLayer, ReluLayer]
    weight_files = ['W_1.txt', 'W_2.txt', 'W_3.txt']
    bias_files = ['b_1.txt', 'b_2.txt', 'b_3.txt']
    network = Network(layer_types, weight_files, bias_files)
    return network.evaluate(x, expected_value)
def run_on_mnist(image_index):
    """Return the raw network output for one MNIST image.

    Loads the dataset, builds the three-layer ReLU network from the
    ``W_*.txt`` / ``b_*.txt`` files and returns ``Network.run`` applied to
    the flattened image.
    """
    dataset = get_mnist()
    # Only the flattened vector is needed; the matrix and label are unused.
    x, _matrix, _label = return_image(image_index, dataset)
    layer_types = [ReluLayer, ReluLayer, ReluLayer]
    weight_files = ['W_1.txt', 'W_2.txt', 'W_3.txt']
    bias_files = ['b_1.txt', 'b_2.txt', 'b_3.txt']
    network = Network(layer_types, weight_files, bias_files)
    return network.run(x)
def main():
    """Print the evaluation result for four sample images and one raw output."""
    is_four = evalualte_on_mnist(19961, 4)
    print(f'Check if network works on image 19961 (number 4): {is_four}')

    is_nine = evalualte_on_mnist(10003, 9)
    print(f'Check if network works on image 10003 (number 9): {is_nine}')

    is_two = evalualte_on_mnist(117, 2)
    print(f'Check if network works on image 117 (number 2): {is_two}')

    is_three = evalualte_on_mnist(1145, 3)
    print(f'Check if network works on image 1145 (number 3): {is_three}')

    raw_output = run_on_mnist(19961)
    print(f'Values image 19961 (number 4): {raw_output}')


# Run the demo only when the file is executed as a script.
if __name__ == '__main__':
    main()