Skip to content

Commit 06effd8

Browse files
committed
update avg_checkpoints.py to include existing averaged wights as well
1 parent 07c9b21 commit 06effd8

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

ImageNet/training_scripts/imagenet_training/avg_checkpoints.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import glob
1818
import hashlib
1919
from timm.models.helpers import load_state_dict
20+
from validate import validate
2021

2122
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
2223
parser.add_argument('--input', default='', type=str, metavar='PATH', help='path to base input folder containing checkpoints')
@@ -25,7 +26,7 @@
2526
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', help='Force not using ema version of weights (if present)')
2627
parser.add_argument('--no-sort', dest='no_sort', action='store_true', help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
2728
parser.add_argument('-n', type=int, default=10, metavar='N', help='Number of checkpoints to average')
28-
29+
parser.add_argument('--avg_weights', default='', type=str, metavar='PATH',help='avg fmodel filepath')
2930

3031
def checkpoint_metric(checkpoint_path):
3132
if not checkpoint_path or not os.path.isfile(checkpoint_path):
@@ -50,24 +51,38 @@ def main():
5051
args.sort = not args.no_sort
5152

5253
if os.path.exists(args.output):
53-
print("Error: Output filename ({}) already exists.".format(args.output))
54-
exit(1)
54+
with open(args.output, 'rb') as f:
55+
sha_hash = hashlib.sha256(f.read()).hexdigest()
56+
print(f'{args.output}')
57+
name,ext = os.path.splitext(args.output)
58+
new_name = f'{name}_{str(sha_hash)[-10:]}{ext}'
59+
os.rename(args.output, new_name)
60+
print(f'renamed "{args.output}" to "{new_name}"')
5561

5662
pattern = args.input
5763
if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
5864
pattern += os.path.sep
5965
pattern += args.filter
6066
checkpoints = glob.glob(pattern, recursive=True)
61-
67+
print(f'checkpoints: {checkpoints}')
68+
6269
if args.sort:
6370
checkpoint_metrics = []
6471
for c in checkpoints:
6572
metric = checkpoint_metric(c)
6673
if metric is not None:
6774
checkpoint_metrics.append((metric, c))
75+
76+
if args.avg_weights:
77+
if os.path.exists(args.avg_weights):
78+
checkpoint = torch.load(args.avg_weights, map_location='cpu')
79+
acc = float(args.avg_weights.split('_')[-1].split('.pth')[0])
80+
checkpoint_metrics.append((acc, args.avg_weights))
81+
else:
82+
print(f'FILE DOESNT EXIST!')
6883
checkpoint_metrics = list(sorted(checkpoint_metrics))
6984
checkpoint_metrics = checkpoint_metrics[-args.n:]
70-
print("Selected checkpoints:")
85+
print(f"Selected checkpoints:'({len(checkpoint_metrics)})'")
7186
[print(m, c) for m, c in checkpoint_metrics]
7287
avg_checkpoints = [c for m, c in checkpoint_metrics]
7388
else:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy