Skip to content

Commit 671a651

Browse files
antmarakisnorvig
authored andcommitted
Update MNIST Functions for Fashion (aimacode#646)
* Update notebook.py * Update notebook.py
1 parent aa1a31f commit 671a651

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

notebook.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2):
9595
# MNIST
9696

9797

98-
def load_MNIST(path="aima-data/MNIST"):
98+
def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
9999
import os, struct
100100
import array
101101
import numpy as np
102102
from collections import Counter
103103

104+
if fashion:
105+
path = "aima-data/MNIST/Fashion"
106+
104107
plt.rcParams.update(plt.rcParamsDefault)
105108
plt.rcParams['figure.figsize'] = (10.0, 8.0)
106109
plt.rcParams['image.interpolation'] = 'nearest'
@@ -143,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST"):
143146
return(train_img, train_lbl, test_img, test_lbl)
144147

145148

146-
def show_MNIST(labels, images, samples=8):
147-
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
149+
digit_classes = [str(i) for i in range(10)]
150+
fashion_classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
151+
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
152+
153+
154+
def show_MNIST(labels, images, samples=8, fashion=False):
155+
if not fashion:
156+
classes = digit_classes
157+
else:
158+
classes = fashion_classes
159+
148160
num_classes = len(classes)
149161

150162
for y, cls in enumerate(classes):
@@ -161,13 +173,19 @@ def show_MNIST(labels, images, samples=8):
161173
plt.show()
162174

163175

164-
def show_ave_MNIST(labels, images):
165-
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
176+
def show_ave_MNIST(labels, images, fashion=False):
177+
if not fashion:
178+
item_type = "Digit"
179+
classes = digit_classes
180+
else:
181+
item_type = "Apparel"
182+
classes = fashion_classes
183+
166184
num_classes = len(classes)
167185

168186
for y, cls in enumerate(classes):
169187
idxs = np.nonzero([i == y for i in labels])
170-
print("Digit", y, ":", len(idxs[0]), "images.")
188+
print(item_type, y, ":", len(idxs[0]), "images.")
171189

172190
ave_img = np.mean(np.vstack([images[i] for i in idxs[0]]), axis = 0)
173191
#print(ave_img.shape)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy