Ferdig med oppgave 6: MNIST modell
This commit is contained in:
		
							parent
							
								
									29bbee4165
								
							
						
					
					
						commit
						fa9f0faf62
					
				
							
								
								
									
										512
									
								
								MNIST/W_1.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										512
									
								
								MNIST/W_1.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										256
									
								
								MNIST/W_2.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										256
									
								
								MNIST/W_2.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										10
									
								
								MNIST/W_3.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								MNIST/W_3.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										1
									
								
								MNIST/b_1.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								MNIST/b_1.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										1
									
								
								MNIST/b_2.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								MNIST/b_2.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										1
									
								
								MNIST/b_3.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								MNIST/b_3.txt
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					-0.3120380938053131	0.057060789316892624	0.5383008718490601	-0.4334142804145813	0.20545503497123718	0.8229865431785583	0.26446419954299927	1.3929160833358765	0.40466466546058655	-0.06923668831586838
 | 
				
			||||||
							
								
								
									
										108
									
								
								MNIST/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								MNIST/main.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,108 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3
 | 
				
			||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					Created on Thu Sep 28 08:23:56 2023
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@author: Mohamad Mohannad al Kawadri (mohamad.mohannad.al.kawadri@nmbu.no), Trygve Børte Nomeland (trygve.borte.nomeland@nmbu.no)
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					from abc import ABC, abstractmethod
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from copy import deepcopy
 | 
				
			||||||
 | 
					from torchvision import datasets, transforms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Network:
 | 
				
			||||||
 | 
					    def __init__(self, layers, W_file_list, b_file_list):
 | 
				
			||||||
 | 
					        self.layers = layers
 | 
				
			||||||
 | 
					        self.W_file_list = W_file_list
 | 
				
			||||||
 | 
					        self.b_file_list = b_file_list
 | 
				
			||||||
 | 
					        self.x = input
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run(self, x):
 | 
				
			||||||
 | 
					        result = x
 | 
				
			||||||
 | 
					        for n, W_file, b_file in zip(self.layers, self.W_file_list, self.b_file_list):
 | 
				
			||||||
 | 
					            y = deepcopy(result)
 | 
				
			||||||
 | 
					            l = n(y, W_file = W_file, b_file = b_file)
 | 
				
			||||||
 | 
					            result = l.run()
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def evaluate(self, x, expected_value):
 | 
				
			||||||
 | 
					        result = list(self.run(x))
 | 
				
			||||||
 | 
					        max_value_index = result.index(max(result))
 | 
				
			||||||
 | 
					        return int(max_value_index) == expected_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Layer:
 | 
				
			||||||
 | 
					    def __init__(self, x, W_file, b_file):
 | 
				
			||||||
 | 
					        self.x = x
 | 
				
			||||||
 | 
					        files = read(W_file, b_file)
 | 
				
			||||||
 | 
					        self.W = files.get('W')
 | 
				
			||||||
 | 
					        self.b = files.get('b')
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def run(self):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SigmaLayer(Layer):
 | 
				
			||||||
 | 
					    def run(self):
 | 
				
			||||||
 | 
					        return layer(self.W,  self.x, self.b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ReluLayer(Layer):
 | 
				
			||||||
 | 
					    def run(self):
 | 
				
			||||||
 | 
					        return relu_layer(self.W,  self.x, self.b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def read(W_file, b_file):
 | 
				
			||||||
 | 
					    return {'W': np.loadtxt(W_file), 'b': np.loadtxt(b_file)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# define activation function
 | 
				
			||||||
 | 
					def sigma(y):
 | 
				
			||||||
 | 
					    if y > 0:
 | 
				
			||||||
 | 
					        return y
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					sigma_vec = np.vectorize(sigma)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def relu_scalar(x):
 | 
				
			||||||
 | 
					    if x > 0:
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					relu = np.vectorize(relu_scalar)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# define layer function for given weight matrix, input and bias
 | 
				
			||||||
 | 
					def layer(W, x, b):
 | 
				
			||||||
 | 
					    return sigma_vec(W @ x + b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def relu_layer(W, x, b):
 | 
				
			||||||
 | 
					    return sigma_vec(W @ x + b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Function from example file "read.py"
 | 
				
			||||||
 | 
					def get_mnist():
 | 
				
			||||||
 | 
					    return datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Function from example file "read.py"
 | 
				
			||||||
 | 
					def return_image(image_index, mnist_dataset):
 | 
				
			||||||
 | 
					    image, label = mnist_dataset[image_index]
 | 
				
			||||||
 | 
					    image_matrix = image[0].detach().numpy()  # Grayscale image, so we select the first channel (index 0)
 | 
				
			||||||
 | 
					    return image_matrix.reshape(image_matrix.size), image_matrix, label
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def evalualte_on_mnist(image_index, expected_value):
 | 
				
			||||||
 | 
					    mnist_dataset = get_mnist()
 | 
				
			||||||
 | 
					    x, image, label = return_image(image_index, mnist_dataset)
 | 
				
			||||||
 | 
					    network = Network([ReluLayer, ReluLayer, ReluLayer], ['W_1.txt', 'W_2.txt', 'W_3.txt'], ['b_1.txt', 'b_2.txt', 'b_3.txt'])
 | 
				
			||||||
 | 
					    return network.evaluate(x, expected_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def run_on_mnist(image_index):
 | 
				
			||||||
 | 
					    mnist_dataset = get_mnist()
 | 
				
			||||||
 | 
					    x, image, label = return_image(image_index, mnist_dataset)
 | 
				
			||||||
 | 
					    network = Network([ReluLayer, ReluLayer, ReluLayer], ['W_1.txt', 'W_2.txt', 'W_3.txt'], ['b_1.txt', 'b_2.txt', 'b_3.txt'])
 | 
				
			||||||
 | 
					    return network.run(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main():
 | 
				
			||||||
 | 
					    print(f'Check if network works on image 19961 (number 4): {evalualte_on_mnist(19961, 4)}')
 | 
				
			||||||
 | 
					    print(f'Check if network works on image 10003 (number 9): {evalualte_on_mnist(10003, 9)}')
 | 
				
			||||||
 | 
					    print(f'Check if network works on image 117 (number 2): {evalualte_on_mnist(117, 2)}')
 | 
				
			||||||
 | 
					    print(f'Check if network works on image 1145 (number 3): {evalualte_on_mnist(1145, 3)}')
 | 
				
			||||||
 | 
					    print(f'Values image 19961 (number 4): {run_on_mnist(19961)}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    main()
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user