Skip to content

Commit bbdbc20

Browse files
committed
cleaned classification_test.py a bit!
1 parent d0b5217 commit bbdbc20

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ImageNet/training_scripts/imagenet_training/classification_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
parser.add_argument('--model', '-m', metavar='MODEL', default='simpnet', help='model architecture (default: simpnet)')
1414
parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset')
1515
parser.add_argument('--weights', default='', type=str, metavar='PATH', help='path to model weights (default: none)')
16+
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
1617
parser.add_argument('--jit', action='store_true', default=False, help='convert the model to jit before doing classification!')
1718
parser.add_argument('--netscale', type=float, default=1.0, help='scale of the net (default 1.0)')
1819
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
@@ -25,23 +26,22 @@
2526
model = create_model(
2627
args.model,
2728
num_classes=args.num_classes,
29+
pretrained=args.pretrained,
2830
checkpoint_path=args.weights,
2931
scale=args.netscale,
3032
network_idx = args.netidx,
3133
mode = args.netmode,
3234
)
35+
model.eval()
3336

34-
# print('Restoring model state from checkpoint...')
35-
# model_weights = torch.load(args.weights, map_location='cpu')
36-
# model.load_state_dict(model_weights)
37-
# model.eval()
37+
if not args.pretrained and not args.weights:
38+
print(f'WARNING: No pretrained weights specified! (pretrained is False and there is no checkpoint specified!)')
3839

3940
if args.jit:
4041
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
4142
model = torch.jit.trace(model, dummy_input)
4243

4344
config = resolve_data_config({}, model=model)
44-
print(f'config: {config}')
4545
transform = create_transform(**config)
4646

4747
filename = "./misc_files/dog.jpg"
@@ -53,13 +53,14 @@
5353
with torch.no_grad():
5454
out = model(tensor)
5555
probabilities = torch.nn.functional.softmax(out[0], dim=0)
56-
print(probabilities.shape) # prints: torch.Size([1000])
56+
print(f'{probabilities.shape}') # prints: torch.Size([1000])
5757

5858
filename="./misc_files/imagenet_classes.txt"
5959
with open(filename, "r") as f:
6060
categories = [s.strip() for s in f.readlines()]
6161

6262
# Print top categories per image
63+
print(f'Top categories:')
6364
top5_prob, top5_catid = torch.topk(probabilities, 5)
6465
for i in range(top5_prob.size(0)):
6566
print(categories[top5_catid[i]], top5_prob[i].item())

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy