Better plain ViT baselines for ImageNet-1k
https://github.com/google-research/big_vision
It is commonly accepted that the Vision Transformer model requires sophisticated regularization techniques to excel at ImageNet-1k scale data. Surprisingly, we find this is not the case and standard data augmentation is sufficient. This note presents a few minor modifications to the original Vision Transformer (ViT) vanilla training setting that dramatically improve the performance of plain ViT models. Notably, 90 epochs of training surpass 76% top-1 accuracy in under seven hours on a TPUv3-8, similar to the classic ResNet50 baseline, and 300 epochs of training reach 80% in less than one day.

Figure 1. Comparison of the ViT model from this note to state-of-the-art ViT and ResNet models. The left plot demonstrates how performance depends on the total number of epochs, while the right plot uses TPUv3-8 wallclock time to measure compute. We observe that our simple setting is highly competitive, even with the canonical ResNet-50 setups. (Curves: ViT (improved), ViT (original), ViT (AugReg), ViT (DeiT), ViT (DeiT III), ResNet50 (orig), ResNet50 (BiT), R50 (strikes back); x-axes: 90/150/300 epochs of training and 6h30m/10h50m/21h40m of TPUv3-8 wallclock training time; y-axis: ImageNet top-1 accuracy.)

1. Introduction

The ViT paper [4] focused solely on the aspect of large-scale pre-training, where ViT models outshine well tuned ResNet [6] (BiT [8]) models. The addition of results when pre-training only on ImageNet-1k was an afterthought, mostly to ablate the effect of data scale. Nevertheless, ImageNet-1k remains a key testbed in computer vision research, and it is highly beneficial to have as simple and effective a baseline as possible.

Thus, coupled with the release of the big_vision codebase used to develop ViT [4], MLP-Mixer [14], ViT-G [19], LiT [20], and a variety of other research projects, we now provide a new baseline that stays true to the original ViT's simplicity while reaching results competitive with similar approaches [15, 17] and the concurrent [16], which also strives for simplification.

2. Experimental setup

We focus on the ViT-S/16 variant introduced by [15], as we believe it provides a good tradeoff between iteration velocity on commonly available hardware and final accuracy. However, when more compute and data is available, we highly recommend iterating with ViT-B/32 or ViT-B/16 instead [12, 19], and note that increasing the patch size is almost equivalent to reducing the image resolution (e.g., B/32 at 224px yields the same 7×7 grid of patch tokens as B/16 at 112px).

All experiments use "inception crop" [13] at 224px² resolution, random horizontal flips, RandAugment [3], and Mixup augmentations. We train on the first 99% of the training data, and keep 1% for minival to encourage the community to stop selecting design choices on the validation (de-facto test) set. The full setup is shown in Appendix A.
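For concreteness, the 99%/1% split can be expressed directly as TFDS split strings; the snippet below is a minimal sketch using tensorflow_datasets (which the big_vision input pipeline builds on), independent of the full config in Appendix A.

import tensorflow_datasets as tfds

# Train on the first 99% of the ImageNet-1k training set; hold out the
# last 1% as "minival", so the official validation set stays untouched.
train_ds = tfds.load('imagenet2012', split='train[:99%]')
minival_ds = tfds.load('imagenet2012', split='train[99%:]')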
Table 1. Ablation of our trivial modifications (top-1 accuracy after 90ep, 150ep, and 300ep of training).

Acknowledgements. … report, as well as the Google Brain team for a supportive research environment.
[14] Ilya O. Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, et al. MLP-Mixer: An all-MLP architecture for vision. Advances in Neural Information Processing Systems, 34, 2021.
[15] Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. In International Conference on Machine Learning (ICML), 2021.
[16] Hugo Touvron, Matthieu Cord, and Hervé Jégou. DeiT III: Revenge of the ViT. CoRR, abs/2204.07118, 2022.
[17] Ross Wightman, Hugo Touvron, and Hervé Jégou. ResNet strikes back: An improved training procedure in timm. CoRR, abs/2110.00476, 2021.
[18] Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Seong Joon Oh, Youngjoon Yoo, and Junsuk Choe. CutMix: Regularization strategy to train strong classifiers with localizable features. In International Conference on Computer Vision (ICCV), 2019.
[19] Xiaohua Zhai, Alexander Kolesnikov, Neil Houlsby, and Lucas Beyer. Scaling vision transformers. In Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
[20] Xiaohua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, Daniel Keysers, Alexander Kolesnikov, and Lucas Beyer. LiT: Zero-shot transfer with locked-image text tuning. In Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
[21] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. In International Conference on Learning Representations (ICLR), 2018.

A. big_vision experiment configuration

import ml_collections as mlc

def get_config():
  config = mlc.ConfigDict()

  config.dataset = 'imagenet2012'
  config.train_split = 'train[:99%]'  # last 1% is held out as minival
  config.cache_raw = True
  config.shuffle_buffer_size = 250_000
  config.num_classes = 1000
  config.loss = 'softmax_xent'
  config.batch_size = 1024
  config.num_epochs = 90

  # Preprocessing: inception crop, random flip, RandAugment, one-hot labels.
  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = (
      'decode_jpeg_and_inception_crop(224)' +
      '|flip_lr|randaug(2,10)' +
      pp_common.format(lbl='label')
  )
  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common

  config.log_training_steps = 50
  config.log_eval_steps = 1000
  config.checkpoint_steps = 1000

  # Model section
  config.model_name = 'vit'
  config.model = dict(
      variant='S/16',
      rep_size=True,       # hidden 'pre-logits' layer in the head
      pool_type='gap',     # global average pooling instead of a class token
      posemb='sincos2d',   # fixed 2D sin-cos position embeddings
  )

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')
  config.lr = 0.001
  config.wd = 0.0001
  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
  config.mixup = dict(p=0.2, fold_in=None)

  # Eval section
  config.evals = [
      ('minival', 'classification'),
      ('val', 'classification'),
      ('real', 'classification'),
      ('v2', 'classification'),
  ]
  eval_common = dict(
      pp_fn=pp_eval.format(lbl='label'),
      loss_name=config.loss,
      log_steps=1000,
  )

  config.minival = dict(**eval_common)
  config.minival.dataset = 'imagenet2012'
  config.minival.split = 'train[99%:]'
  config.minival.prefix = 'minival_'

  config.val = dict(**eval_common)
  config.val.dataset = 'imagenet2012'
  config.val.split = 'validation'
  config.val.prefix = 'val_'

  config.real = dict(**eval_common)
  config.real.dataset = 'imagenet2012_real'
  config.real.split = 'validation'
  config.real.pp_fn = pp_eval.format(lbl='real_label')
  config.real.prefix = 'real_'

  config.v2 = dict(**eval_common)
  config.v2.dataset = 'imagenet_v2'
  config.v2.split = 'test'
  config.v2.prefix = 'v2_'

  return config

Listing 1. Full recommended config.
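Two entries in the model dict above are the main departures from the vanilla ViT setting: pool_type='gap' replaces the class token with global average pooling over patch tokens, and posemb='sincos2d' uses fixed 2D sin-cos position embeddings instead of learned ones. The following is a NumPy sketch of the standard sin-cos construction; the helper name posemb_sincos_2d and the exact normalization follow common practice rather than quoting big_vision's own (JAX) implementation.

import numpy as np

def posemb_sincos_2d(h, w, dim, temperature=10_000.0):
  # Fixed 2D sin-cos position embeddings of shape (h*w, dim).
  assert dim % 4 == 0, 'dim must be divisible by 4'
  y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
  omega = np.arange(dim // 4) / (dim // 4 - 1)
  omega = 1.0 / (temperature ** omega)
  y = y.reshape(-1)[:, None] * omega[None, :]
  x = x.reshape(-1)[:, None] * omega[None, :]
  return np.concatenate([np.sin(x), np.cos(x), np.sin(y), np.cos(y)], axis=1)

# ViT-S/16 at 224px: a 14x14 grid of patch tokens, embedding width 384.
pe = posemb_sincos_2d(14, 14, 384)  # shape (196, 384)

Since the embedding is deterministic, it adds no learned parameters.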
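The config also enables Mixup [21] via config.mixup = dict(p=0.2, fold_in=None). Below is a minimal NumPy sketch of the idea; it assumes, as in common implementations, that p parameterizes the Beta(p, p) distribution for the mixing weight and that labels are already one-hot (as produced by the onehot(...) preprocessing op).

import numpy as np

def mixup(images, labels, p=0.2, seed=0):
  # Draw one mixing weight per batch and blend the batch with a shuffled
  # copy of itself; labels are blended with the same weight.
  rng = np.random.default_rng(seed)
  lam = rng.beta(p, p)
  perm = rng.permutation(len(images))
  mixed_images = lam * images + (1.0 - lam) * images[perm]
  mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
  return mixed_images, mixed_labels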