Skip to content

Commit e218c9f

Browse files
committed
fix seemless fintetuning for timm simplenet
1 parent 26ec25d commit e218c9f

File tree

1 file changed

+12
-9
lines changed
  • ImageNet/training_scripts/imagenet_training/timm/models

1 file changed

+12
-9
lines changed

ImageNet/training_scripts/imagenet_training/timm/models/simplenet.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def get_classifier(self):
338338

339339
def reset_classifier(self, num_classes: int):
340340
self.num_classes = num_classes
341-
self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][1] * self.scale), num_classes)
341+
self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][0] * self.scale), num_classes)
342342

343343
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
344344
return self.features(x)
@@ -367,15 +367,18 @@ def _gen_simplenet(
367367
) -> SimpleNet:
368368

369369
model_args = dict(
370-
num_classes=num_classes,
371-
in_chans=in_chans,
372-
scale=scale,
373-
network_idx=network_idx,
374-
mode=mode,
375-
drop_rates=drop_rates,
376-
**kwargs,
370+
in_chans=in_chans, scale=scale, network_idx=network_idx, mode=mode, drop_rates=drop_rates, **kwargs,
377371
)
372+
# to allow for seemless finetuning, remove the num_classes
373+
# and load the model intact, we apply the changes afterward!
374+
if "num_classes" in kwargs:
375+
kwargs.pop("num_classes")
378376
model = build_model_with_cfg(SimpleNet, model_variant, pretrained, **model_args)
377+
# if the num_classes is different than imagenet's, it
378+
# means its going to be finetuned, so only create a
379+
# new classifier after the whole model is loaded!
380+
if num_classes != 1000:
381+
model.reset_classifier(num_classes)
379382
return model
380383

381384

@@ -436,7 +439,7 @@ def remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
436439
Returns:
437440
Dict[str,Any]: cleaned kwargs
438441
"""
439-
model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode","drop_rate"]}
442+
model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode", "drop_rate"]}
440443
return model_args
441444

442445

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy