2023-10-25 12:12:43 +00:00
|
|
|
import numpy as np
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
class Network:
    """Feed-forward network assembled from a sequence of layer factories.

    Each entry of ``layers`` is called as ``factory(x, W_file=..., b_file=...)``
    and must return an object with a ``run()`` method; that method's result
    becomes the input of the next layer.
    """

    def __init__(self, layers, W_file_list, b_file_list, x=None):
        """Store the layer factories, their parameter files and the input.

        Args:
            layers: layer factories (e.g. ``Layer``), applied in order.
            W_file_list: weight file for each layer, in the same order.
            b_file_list: bias file for each layer, in the same order.
            x: optional input vector. When omitted, a random vector of
                ``n_inputs`` values drawn by ``np.random.rand`` is used
                (the original behavior).

        Raises:
            ValueError: if the three sequences differ in length; ``zip`` in
                ``run`` would otherwise silently drop the unmatched layers.
        """
        if not (len(layers) == len(W_file_list) == len(b_file_list)):
            raise ValueError(
                'layers, W_file_list and b_file_list must have the same '
                'length, got {}, {} and {}'.format(
                    len(layers), len(W_file_list), len(b_file_list)))

        self.layers = layers
        self.W_file_list = W_file_list
        self.b_file_list = b_file_list

        # Sizes of the default 784-512-256-10 architecture. They are not read
        # by run(); the actual shapes come from the loaded weight files.
        self.n_layers = 4
        self.n_inputs = 784
        self.n_outputs = 10
        self.n = [self.n_inputs, 512, 256, self.n_outputs]

        self.x = np.random.rand(self.n_inputs) if x is None else np.asarray(x)

    def run(self):
        """Propagate ``self.x`` through every layer and return the last output."""
        result = self.x
        for make_layer, W_file, b_file in zip(self.layers, self.W_file_list, self.b_file_list):
            # Give each layer its own copy so a layer that mutates its input
            # cannot corrupt self.x or the previous layer's result.
            result = make_layer(deepcopy(result), W_file=W_file, b_file=b_file).run()
        return result
|
2023-10-25 12:12:43 +00:00
|
|
|
class Layer:
    """A fully connected layer whose parameters are loaded from two text files.

    ``run()`` hands the stored weights, input and biases to ``layer``.
    """

    def __init__(self, x, W_file, b_file):
        """Keep the input ``x`` and load ``W``/``b`` through ``read``."""
        self.x = x
        params = read(W_file, b_file)
        self.W, self.b = params.get('W'), params.get('b')

    def run(self):
        """Return this layer's output for the stored input."""
        weights, inputs, biases = self.W, self.x, self.b
        return layer(weights, inputs, biases)
|
2023-10-26 07:04:31 +00:00
|
|
|
|
|
|
|
def read(W_file, b_file):
    """Load one layer's weights and biases from whitespace-delimited text.

    Args:
        W_file: path or file object holding the weight matrix; each line of
            the file is one row of the matrix.
        b_file: path or file object holding the bias values.

    Returns:
        dict with ``'W'`` (always a 2-D array) and ``'b'`` (always a 1-D array).
    """
    # ndmin preserves degenerate shapes. By default loadtxt squeezes a
    # single-row or single-column weight file to 1-D, which makes ``W @ x``
    # collapse to a scalar or fail. It also squeezes a single-value bias
    # file to 0-D.
    return {'W': np.loadtxt(W_file, ndmin=2), 'b': np.loadtxt(b_file, ndmin=1)}
|
2023-10-25 12:12:43 +00:00
|
|
|
|
|
|
|
# Activation function: ReLU, applied element-wise via sigma_vec.
def sigma(y):
    """Return ``y`` if it is positive, otherwise ``0`` (scalar ReLU).

    Anything that is not ``> 0`` maps to 0, including NaN, because
    ``nan > 0`` is False.
    """
    if y > 0:
        return y
    return 0


# otypes=[float] pins the output dtype. Without it, np.vectorize infers the
# dtype from the result on the first element. An input whose first value is
# non-positive (sigma returns the int 0) would then produce an integer array
# and truncate every positive activation. Setting otypes also allows
# size-0 inputs.
sigma_vec = np.vectorize(sigma, otypes=[float])
|
|
|
|
|
|
|
|
# Layer function: affine map W @ x + b followed by the element-wise activation.
def layer(W, x, b):
    """Return ``sigma_vec(W @ x + b)`` for one fully connected layer.

    Args:
        W: weight matrix.
        x: input vector.
        b: bias vector.
    """
    pre_activation = W @ x
    pre_activation = pre_activation + b
    return sigma_vec(pre_activation)
|
|
|
|
|
|
|
|
def main():
    """Build the three-layer network from its text files and print its output."""
    weight_files = ['W_1.txt', 'W_2.txt', 'W_3.txt']
    bias_files = ['b_1.txt', 'b_2.txt', 'b_3.txt']
    network = Network([Layer] * 3, weight_files, bias_files)
    output = network.run()
    print(output)
|
2023-10-25 12:12:43 +00:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()
|