Skip to content

Commit 6e3954a

Browse files
committed
Evaluation function
1 parent 38c3b44 commit 6e3954a

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
.vscode/
22
__pycache__/
3+
4+
data/mnist_train.csv

main.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66

77

8-
def make_neural_network(layer_sizes, layer_activations, learning_rate=0.2, low=-2, high=2):
8+
def make_neural_network(layer_sizes, layer_activations, learning_rate=0.05, low=-2, high=2):
99

1010
# Initialize typed layer sizes list.
1111
typed_layer_sizes = typed.List()
@@ -118,6 +118,18 @@ def train(self, test_input_data, test_desired_output_data, validate_input_data,
118118
current_mse = self.calculate_MSE(validate_input_data, validate_output_data)
119119
return epochs, current_mse
120120

121+
def evaluate(self, input_data, desired_output_data):
122+
corrects, wrongs = 0, 0
123+
for i in range(len(input_data)):
124+
output = self.calculate_output(input_data[i])
125+
output_max = output.argmax()
126+
desired_output_max = desired_output_data[i].argmax()
127+
if output_max == desired_output_max:
128+
corrects += 1
129+
else:
130+
wrongs += 1
131+
return corrects / (corrects + wrongs)
132+
121133
def print_weights_and_biases(self):
122134
print(self.weights)
123135
print(self.biases)
@@ -135,7 +147,7 @@ def print_weights_and_biases(self):
135147
end_time = time.time_ns()
136148
print("Compile time:", (end_time-begin_time) / 1e9)
137149

138-
for i in range(3):
150+
for i in range(10):
139151

140152
random_seed = np.random.randint(10, 1010)
141153
np.random.seed(random_seed)
@@ -151,4 +163,6 @@ def print_weights_and_biases(self):
151163

152164
train_mse = nn.calculate_MSE(train_input, train_output)
153165
test_mse = nn.calculate_MSE(test_input, test_output)
154-
print("Seed:", random_seed, "Epochs:", epochs, "Time:", (end_time-begin_time)/1e9, "Tr:", train_mse, "V:", current_mse, "T:", test_mse)
166+
167+
accuracy_test = nn.evaluate(test_input, test_output)
168+
print("Seed:", random_seed, "Epochs:", epochs, "Time:", (end_time-begin_time)/1e9, "Accuracy:", accuracy_test, "Tr:", train_mse, "V:", current_mse, "T:", test_mse)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy