Skip to content

Commit 2285ba0

Browse files
rossbarbsipocz
andauthored
Fix bug and improve computation / display of metrics for MNIST tutorial (numpy#189)
* BUG: Fix incorrect variable in computing eval metrics. * ENH: Replace list comps with vectorization. * ENH: Use dicts and condense plotting. * Update content/tutorial-deep-learning-on-mnist.md Co-authored-by: Brigitta Sipőcz <b.sipocz@gmail.com>
1 parent 73ebebf commit 2285ba0

File tree

1 file changed

+19
-27
lines changed

1 file changed

+19
-27
lines changed

content/tutorial-deep-learning-on-mnist.md

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -561,39 +561,31 @@ The training process may take many minutes, depending on a number of factors, su
561561
After executing the cell above, you can visualize the training and test set errors and accuracy for an instance of this training process.
562562

563563
```{code-cell}
564+
epoch_range = np.arange(epochs) + 1 # Starting from 1
565+
564566
# The training set metrics.
565-
y_training_error = [
566-
store_training_loss[i] / float(len(training_images))
567-
for i in range(len(store_training_loss))
568-
]
569-
x_training_error = range(1, len(store_training_loss) + 1)
570-
y_training_accuracy = [
571-
store_training_accurate_pred[i] / float(len(training_images))
572-
for i in range(len(store_training_accurate_pred))
573-
]
574-
x_training_accuracy = range(1, len(store_training_accurate_pred) + 1)
567+
training_metrics = {
568+
"accuracy": np.asarray(store_training_accurate_pred) / len(training_images),
569+
"error": np.asarray(store_training_loss) / len(training_images),
570+
}
575571
576572
# The test set metrics.
577-
y_test_error = [
578-
store_test_loss[i] / float(len(test_images)) for i in range(len(store_test_loss))
579-
]
580-
x_test_error = range(1, len(store_test_loss) + 1)
581-
y_test_accuracy = [
582-
store_training_accurate_pred[i] / float(len(training_images))
583-
for i in range(len(store_training_accurate_pred))
584-
]
585-
x_test_accuracy = range(1, len(store_test_accurate_pred) + 1)
573+
test_metrics = {
574+
"accuracy": np.asarray(store_test_accurate_pred) / len(test_images),
575+
"error": np.asarray(store_test_loss) / len(test_images),
576+
}
586577
587578
# Display the plots.
588579
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
589-
axes[0].set_title("Training set error, accuracy")
590-
axes[0].plot(x_training_accuracy, y_training_accuracy, label="Training set accuracy")
591-
axes[0].plot(x_training_error, y_training_error, label="Training set error")
592-
axes[0].set_xlabel("Epochs")
593-
axes[1].set_title("Test set error, accuracy")
594-
axes[1].plot(x_test_accuracy, y_test_accuracy, label="Test set accuracy")
595-
axes[1].plot(x_test_error, y_test_error, label="Test set error")
596-
axes[1].set_xlabel("Epochs")
580+
for ax, metrics, title in zip(
581+
axes, (training_metrics, test_metrics), ("Training set", "Test set")
582+
):
583+
# Plot the metrics
584+
for metric, values in metrics.items():
585+
ax.plot(epoch_range, values, label=metric.capitalize())
586+
ax.set_title(title)
587+
ax.set_xlabel("Epochs")
588+
ax.legend()
597589
plt.show()
598590
```
599591

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy