Ferdig med oppgave 6: MNIST modell
This commit is contained in:
parent
29bbee4165
commit
fa9f0faf62
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1 @@
|
|||
-0.3120380938053131 0.057060789316892624 0.5383008718490601 -0.4334142804145813 0.20545503497123718 0.8229865431785583 0.26446419954299927 1.3929160833358765 0.40466466546058655 -0.06923668831586838
|
|
@ -0,0 +1,108 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Thu Sep 28 08:23:56 2023
|
||||
|
||||
@author: Mohamad Mohannad al Kawadri (mohamad.mohannad.al.kawadri@nmbu.no), Trygve Børte Nomeland (trygve.borte.nomeland@nmbu.no)
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
class Network:
|
||||
def __init__(self, layers, W_file_list, b_file_list):
|
||||
self.layers = layers
|
||||
self.W_file_list = W_file_list
|
||||
self.b_file_list = b_file_list
|
||||
self.x = input
|
||||
|
||||
def run(self, x):
|
||||
result = x
|
||||
for n, W_file, b_file in zip(self.layers, self.W_file_list, self.b_file_list):
|
||||
y = deepcopy(result)
|
||||
l = n(y, W_file = W_file, b_file = b_file)
|
||||
result = l.run()
|
||||
return result
|
||||
|
||||
def evaluate(self, x, expected_value):
|
||||
result = list(self.run(x))
|
||||
max_value_index = result.index(max(result))
|
||||
return int(max_value_index) == expected_value
|
||||
|
||||
class Layer:
|
||||
def __init__(self, x, W_file, b_file):
|
||||
self.x = x
|
||||
files = read(W_file, b_file)
|
||||
self.W = files.get('W')
|
||||
self.b = files.get('b')
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
class SigmaLayer(Layer):
|
||||
def run(self):
|
||||
return layer(self.W, self.x, self.b)
|
||||
|
||||
class ReluLayer(Layer):
|
||||
def run(self):
|
||||
return relu_layer(self.W, self.x, self.b)
|
||||
|
||||
def read(W_file, b_file):
|
||||
return {'W': np.loadtxt(W_file), 'b': np.loadtxt(b_file)}
|
||||
|
||||
# define activation function
|
||||
def sigma(y):
|
||||
if y > 0:
|
||||
return y
|
||||
else:
|
||||
return 0
|
||||
sigma_vec = np.vectorize(sigma)
|
||||
|
||||
def relu_scalar(x):
|
||||
if x > 0:
|
||||
return x
|
||||
else:
|
||||
return 0
|
||||
relu = np.vectorize(relu_scalar)
|
||||
|
||||
|
||||
# define layer function for given weight matrix, input and bias
|
||||
def layer(W, x, b):
|
||||
return sigma_vec(W @ x + b)
|
||||
|
||||
def relu_layer(W, x, b):
|
||||
return sigma_vec(W @ x + b)
|
||||
|
||||
# Function from example file "read.py"
|
||||
def get_mnist():
|
||||
return datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
|
||||
|
||||
# Function from example file "read.py"
|
||||
def return_image(image_index, mnist_dataset):
|
||||
image, label = mnist_dataset[image_index]
|
||||
image_matrix = image[0].detach().numpy() # Grayscale image, so we select the first channel (index 0)
|
||||
return image_matrix.reshape(image_matrix.size), image_matrix, label
|
||||
|
||||
def evalualte_on_mnist(image_index, expected_value):
|
||||
mnist_dataset = get_mnist()
|
||||
x, image, label = return_image(image_index, mnist_dataset)
|
||||
network = Network([ReluLayer, ReluLayer, ReluLayer], ['W_1.txt', 'W_2.txt', 'W_3.txt'], ['b_1.txt', 'b_2.txt', 'b_3.txt'])
|
||||
return network.evaluate(x, expected_value)
|
||||
|
||||
def run_on_mnist(image_index):
|
||||
mnist_dataset = get_mnist()
|
||||
x, image, label = return_image(image_index, mnist_dataset)
|
||||
network = Network([ReluLayer, ReluLayer, ReluLayer], ['W_1.txt', 'W_2.txt', 'W_3.txt'], ['b_1.txt', 'b_2.txt', 'b_3.txt'])
|
||||
return network.run(x)
|
||||
|
||||
def main():
|
||||
print(f'Check if network works on image 19961 (number 4): {evalualte_on_mnist(19961, 4)}')
|
||||
print(f'Check if network works on image 10003 (number 9): {evalualte_on_mnist(10003, 9)}')
|
||||
print(f'Check if network works on image 117 (number 2): {evalualte_on_mnist(117, 2)}')
|
||||
print(f'Check if network works on image 1145 (number 3): {evalualte_on_mnist(1145, 3)}')
|
||||
print(f'Values image 19961 (number 4): {run_on_mnist(19961)}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue