Ferdig med oppgave 6: MNIST modell
This commit is contained in:
parent
29bbee4165
commit
fa9f0faf62
512
MNIST/W_1.txt
Normal file
512
MNIST/W_1.txt
Normal file
File diff suppressed because one or more lines are too long
256
MNIST/W_2.txt
Normal file
256
MNIST/W_2.txt
Normal file
File diff suppressed because one or more lines are too long
10
MNIST/W_3.txt
Normal file
10
MNIST/W_3.txt
Normal file
File diff suppressed because one or more lines are too long
1
MNIST/b_1.txt
Normal file
1
MNIST/b_1.txt
Normal file
File diff suppressed because one or more lines are too long
1
MNIST/b_2.txt
Normal file
1
MNIST/b_2.txt
Normal file
File diff suppressed because one or more lines are too long
1
MNIST/b_3.txt
Normal file
1
MNIST/b_3.txt
Normal file
@ -0,0 +1 @@
|
||||
-0.3120380938053131 0.057060789316892624 0.5383008718490601 -0.4334142804145813 0.20545503497123718 0.8229865431785583 0.26446419954299927 1.3929160833358765 0.40466466546058655 -0.06923668831586838
|
108
MNIST/main.py
Normal file
108
MNIST/main.py
Normal file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Thu Sep 28 08:23:56 2023
|
||||
|
||||
@author: Mohamad Mohannad al Kawadri (mohamad.mohannad.al.kawadri@nmbu.no), Trygve Børte Nomeland (trygve.borte.nomeland@nmbu.no)
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
class Network:
|
||||
def __init__(self, layers, W_file_list, b_file_list):
|
||||
self.layers = layers
|
||||
self.W_file_list = W_file_list
|
||||
self.b_file_list = b_file_list
|
||||
self.x = input
|
||||
|
||||
def run(self, x):
|
||||
result = x
|
||||
for n, W_file, b_file in zip(self.layers, self.W_file_list, self.b_file_list):
|
||||
y = deepcopy(result)
|
||||
l = n(y, W_file = W_file, b_file = b_file)
|
||||
result = l.run()
|
||||
return result
|
||||
|
||||
def evaluate(self, x, expected_value):
|
||||
result = list(self.run(x))
|
||||
max_value_index = result.index(max(result))
|
||||
return int(max_value_index) == expected_value
|
||||
|
||||
class Layer:
|
||||
def __init__(self, x, W_file, b_file):
|
||||
self.x = x
|
||||
files = read(W_file, b_file)
|
||||
self.W = files.get('W')
|
||||
self.b = files.get('b')
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
class SigmaLayer(Layer):
|
||||
def run(self):
|
||||
return layer(self.W, self.x, self.b)
|
||||
|
||||
class ReluLayer(Layer):
|
||||
def run(self):
|
||||
return relu_layer(self.W, self.x, self.b)
|
||||
|
||||
def read(W_file, b_file):
|
||||
return {'W': np.loadtxt(W_file), 'b': np.loadtxt(b_file)}
|
||||
|
||||
# define activation function
|
||||
def sigma(y):
|
||||
if y > 0:
|
||||
return y
|
||||
else:
|
||||
return 0
|
||||
sigma_vec = np.vectorize(sigma)
|
||||
|
||||
def relu_scalar(x):
|
||||
if x > 0:
|
||||
return x
|
||||
else:
|
||||
return 0
|
||||
relu = np.vectorize(relu_scalar)
|
||||
|
||||
|
||||
# define layer function for given weight matrix, input and bias
|
||||
def layer(W, x, b):
|
||||
return sigma_vec(W @ x + b)
|
||||
|
||||
def relu_layer(W, x, b):
|
||||
return sigma_vec(W @ x + b)
|
||||
|
||||
# Function from example file "read.py"
|
||||
def get_mnist():
|
||||
return datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
|
||||
|
||||
# Function from example file "read.py"
|
||||
def return_image(image_index, mnist_dataset):
|
||||
image, label = mnist_dataset[image_index]
|
||||
image_matrix = image[0].detach().numpy() # Grayscale image, so we select the first channel (index 0)
|
||||
return image_matrix.reshape(image_matrix.size), image_matrix, label
|
||||
|
||||
def evalualte_on_mnist(image_index, expected_value):
|
||||
mnist_dataset = get_mnist()
|
||||
x, image, label = return_image(image_index, mnist_dataset)
|
||||
network = Network([ReluLayer, ReluLayer, ReluLayer], ['W_1.txt', 'W_2.txt', 'W_3.txt'], ['b_1.txt', 'b_2.txt', 'b_3.txt'])
|
||||
return network.evaluate(x, expected_value)
|
||||
|
||||
def run_on_mnist(image_index):
|
||||
mnist_dataset = get_mnist()
|
||||
x, image, label = return_image(image_index, mnist_dataset)
|
||||
network = Network([ReluLayer, ReluLayer, ReluLayer], ['W_1.txt', 'W_2.txt', 'W_3.txt'], ['b_1.txt', 'b_2.txt', 'b_3.txt'])
|
||||
return network.run(x)
|
||||
|
||||
def main():
|
||||
print(f'Check if network works on image 19961 (number 4): {evalualte_on_mnist(19961, 4)}')
|
||||
print(f'Check if network works on image 10003 (number 9): {evalualte_on_mnist(10003, 9)}')
|
||||
print(f'Check if network works on image 117 (number 2): {evalualte_on_mnist(117, 2)}')
|
||||
print(f'Check if network works on image 1145 (number 3): {evalualte_on_mnist(1145, 3)}')
|
||||
print(f'Values image 19961 (number 4): {run_on_mnist(19961)}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user