Combined Scaling For Zero-Shot Transfer Learning
Editor: To be assigned.
Abstract
We present a combined scaling method – named BASIC – that achieves 85.7% top-1 accuracy on the
ImageNet ILSVRC-2012 validation set without learning from any labeled ImageNet example. This accuracy
surpasses the best published similar models – CLIP and ALIGN – by 9.3%. Our BASIC model also shows
significant improvements in robustness benchmarks. For instance, on 5 test sets with natural distribution
shifts such as ImageNet-{A,R,V2,Sketch} and ObjectNet, our model achieves 84.3% top-1 average accuracy,
only a small drop from its original ImageNet accuracy.
To achieve these results, we scale up the contrastive learning framework of CLIP and ALIGN in three
dimensions: data size, model size, and batch size. Our dataset has 6.6B noisy image-text pairs, which is 4x
larger than ALIGN, and 16x larger than CLIP. Our largest model has 3B weights, which is 3.75x larger in
parameters and 8x larger in FLOPs than ALIGN and CLIP. Finally, our batch size is 65536, which is 2x more
than CLIP and 4x more than ALIGN.
We encountered two main challenges in applying the scaling rules of BASIC. First, implementing the
combined scaling rules of BASIC runs into the limited memory of accelerators, such as GPUs
and TPUs. To overcome the memory limit, we propose two simple methods which make use of gradient
checkpointing and model parallelism. Second, while increasing the dataset size and the model size has been
the de facto method to improve the performance of deep learning models like BASIC, the effect of a large
contrastive batch size on such contrastive-trained image-text models is not well-understood. To shed light
on the benefits of large contrastive batch sizes, we develop a theoretical framework which shows that larger
contrastive batch sizes lead to smaller generalization gaps for image-text models such as BASIC.
Contents
1 Introduction
2 Related Work
9 Experiments
  9.1 Training details
  9.2 Results on Image Classification Benchmarks
  9.3 Results on Robustness Benchmarks
10 Ablation Study
  10.1 The Importance of Batch Size Scaling
  10.2 Data Scaling, Model Scaling, and Pretraining
11 Limitations
12 Conclusion
A Model sizes
E Computational Cost
G Failure Analysis
  G.1 The benchmarks where BASIC fails
  G.2 Example failure cases
H Proofs
  H.1 General case
    H.1.1 Analyzing sup_{x ∈ X} R_B(H^x_{F,G,e}) and R_m(H_{F,G,ℓ̂_B})
    H.1.2 Combining all together for the general case
  H.2 Bounding E_{y,σ} sup_{G ∈ G} ‖(1/B) Σ_{i=1}^B σ_i G(y_i)‖_2 and Σ_{k=1}^D (R_m(F_k) + R_m(G_k)) for the special case with deep neural networks
  H.3 Combining all together for the special case with deep neural networks
1. Introduction
The recent advances in multimodal training approaches such as CLIP (Radford et al., 2021) and ALIGN (Jia
et al., 2021) have the potential to eliminate the need for collecting labeled training data for every new
application. Using natural language as a weak supervision signal, CLIP and ALIGN achieve the impressive
top-1 accuracy of 76.2% and 76.4% on ImageNet ILSVRC-2012 without learning from any labeled ImageNet
data. In addition to the promising accuracy on ImageNet, the so-called “zero-shot” models in CLIP and
ALIGN demonstrate two important properties. First, these models are versatile, as they can be directly
deployed on many downstream tasks without task-specific data for finetuning. Second, CLIP and ALIGN
models are more robust than traditional classifiers. Robustness evaluations on benchmarks with natural
distribution shifts (Hendrycks et al., 2021b,a; Recht et al., 2019; Barbu et al., 2019; Wang et al., 2019) show
that the accuracy of models like CLIP and ALIGN typically drops less than 10%, while the accuracy of
supervised and semi-supervised models might drop as much as 40% (Taori et al., 2020; Szegedy et al., 2013).
Despite their versatility and robustness, the best models from CLIP and ALIGN are still not as competitive
as supervised and semi-supervised models when enough labeled data is available, which can limit their
potential applications. For example, the best CLIP and ALIGN models have an accuracy around 76% on
ImageNet, which is only comparable with a supervised ResNet-50 (He et al., 2015), and significantly worse
than the state-of-the-art supervised training on ImageNet (without extra data: 87.1% (Yuan et al., 2021), and
with extra data: 90.88% (Dai et al., 2021)). Therefore, narrowing the gap from these models to supervised
and semi-supervised models would make the image-text contrastive learning approach in CLIP and ALIGN a
viable alternative for image classification.
In this paper, we develop significantly better image classifiers that leverage the image-text contrastive
learning approaches like CLIP and ALIGN at a much larger scale. In particular, we scale up the contrastive
learning framework of CLIP (Radford et al., 2021) and ALIGN (Jia et al., 2021) in 3 dimensions: dataset size,
model size, and batch size. For the data, we expand the ALIGN dataset (Jia et al., 2021) from 1.7B noisy
image-text pairs to 6.6B pairs, i.e., almost 4x larger. For the models, we choose CoAtNet, an architecture
with higher learning capacity (Dai et al., 2021), and scale it to 3B parameters, i.e., 3.75x more weights and
8x more FLOPs than the largest models in CLIP and ALIGN. For the batch size, we use 65536 contrastive
learning examples per minibatch, i.e., 2x more than CLIP and 4x more than ALIGN.
Overview of our implementation. The fundamental bottleneck of training large models at larger batch
sizes is the limited memory of deep learning accelerators such as GPUs and TPUs. We propose two
approaches that allow practitioners to overcome such memory limits.
Our first approach (Section 4) makes use of micro-batch pipelining (Huang et al., 2019) and gradient
accumulation (GradAccum) (Ott et al., 2018; Zhai et al., 2021). Our second approach (Section 5) utilizes the
model parallelism scheme of Single-Program Multi-Data (SPMD) (Lepikhin et al., 2020; Xu et al., 2021) to
distribute the weights of certain layers in our networks onto different devices. While our SPMD approach
is faster than our pipelining approach, and can deliver exact computations, the SPMD approach requires
more manual designs to scale to arbitrarily large contrastive batch sizes, and hence, is less general than the
pipelining approach.
Both our pipelining approach and our SPMD approach make use of gradient checkpointing (Chen et al.,
2016), which is also called rematerialization in certain literature (Kumar et al., 2019; Jain et al., 2020).
The idea behind rematerialization is to discard certain intermediate values in the forward pass of a neural
network to save memory, and then recompute – i.e., rematerialize – these values only when they are needed
for gradient computation in the network’s backward pass.
Overview of our theoretical insights. While the benefits of large datasets and large models for deep
learning models have become established knowledge, the benefits of large batch size are less well-understood
in the context of relatively new image-text contrastive models. To understand such benefits, we develop a
theoretical analysis of the image-text contrastive learning framework of CLIP and ALIGN. Our analysis
establishes that using a larger contrastive batch size in CLIP and ALIGN’s framework leads to a smaller
generalization gap of the resulting models.
ALIGN (Jia et al., 2021) CLIP (Radford et al., 2021) BASIC (ours)
ImageNet 76.4 76.2 85.7 (+9.3)
ImageNet-A 75.8 77.2 85.6 (+8.4)
ImageNet-R 92.2 88.9 95.7 (+3.5)
ImageNet-V2 70.1 70.1 80.6 (+10.5)
ImageNet-Sketch 64.8 60.2 76.1 (+11.3)
ObjectNet 72.2 72.3 82.3 (+10.1)
Average 74.5 74.2 84.3 (+10.1)
Table 1: Highlights of our key results. Shown are the top-1 accuracy of our method, BASIC, and similar baselines –
CLIP and ALIGN – on ImageNet and other robustness test sets. None of these models have learned from any labeled
training example in ImageNet. On average, BASIC surpasses the baselines by a significant 10.1 percentage points.
Overview of our empirical results. Our proposed method, called BASIC, for Batch, Data and Model
SIze Combined Scaling, achieves drastic improvements over CLIP and ALIGN models. For instance, on
ImageNet, the largest BASIC model achieves 85.7% top-1 accuracy, without learning from any labeled
example in the ImageNet training set. This result surpasses similar models in CLIP and ALIGN by 9.3%. This
BASIC model also shows significant improvements on robustness benchmarks. For instance, on 5 test sets
with natural distribution shifts such as ImageNet-{A,R,V2,Sketch} and ObjectNet, the model achieves an
average of 83.7% top-1 accuracy, only a small drop from its original ImageNet accuracy (see Table 1). When
tested against CLIP on the other 17 image classification benchmarks, e.g., CIFAR, Caltech101, and Flowers,
BASIC outperforms CLIP on 13 out of these 17 benchmarks.
2. Related Work
Large-scale pretraining and the contrastive loss. As computer vision models grow in their size and
capacity, many weakly-supervised and self-supervised pretraining methods have been proposed to learn
good visual representations. On one hand, pretraining with a classification loss on large weakly-labeled
datasets such as Instagram hashtags or JFT can produce significant gains on downstream tasks such as
ImageNet (Joulin et al., 2016; Mahajan et al., 2018; Kolesnikov et al., 2020; Dosovitskiy et al., 2021; Sun
et al., 2017; Zhai et al., 2021). On the other hand, self-supervised methods which leverage existing structures
in unlabeled data to train models have been developed. A promising development in self-supervised learning
is the contrastive loss, with representative works like CPC (van den Oord et al., 2018), SimCLR (Chen et al.,
2020a,b) and MoCo (He et al., 2020; Chen et al., 2020c). In this paper, we scale up the contrastive learning
framework, which we will revisit in detail in Section 3.
Contrastive-learned image-text models. Unlike the single-modal contrastive approaches mentioned in
the previous paragraph, our work leverages data from two modalities: image and text. Using images with
accompanying text is related to the literature on image-captioning models, such as (Vinyals et al., 2015;
Karpathy and Fei-Fei, 2015; Xu et al., 2015; Joulin et al., 2016; Li et al., 2017; Sariyildiz et al., 2020; Zhang
et al., 2020; Desai and Johnson, 2021). While learning to generate captions from images can induce good
visual representations, it is not the goal of this paper. Instead, this paper focuses on establishing the ability of
models to classify images based on textual descriptions. This focus makes our work closely related to the
recent work of image-text models such as CLIP (Radford et al., 2021) and ALIGN (Jia et al., 2021). Similar
to CLIP and ALIGN, our work also learns the mapping between images and texts, which is related to many
previous works, such as (Hironobu et al., 1999; Weston et al., 2010; Socher and Fei-Fei, 2010; Socher et al.,
2013; Hodosh et al., 2013; Frome et al., 2013; Norouzi et al., 2013; Kiros et al., 2014; Socher et al., 2014;
Akata et al., 2015b,a; Nam et al., 2017; Faghri et al., 2017; Li et al., 2019; Liu et al., 2019; Lu et al., 2019;
Messina et al., 2020; Chen et al., 2020d; Huang et al., 2020; Chen et al., 2021).
Differences between our work and zero-shot learning. Early works on zero-shot vision models date
back to the 2000s, e.g., (Larochelle et al., 2008; Zhang et al., 2017; Xian et al., 2016, 2017; Schönfeld et al.,
2019). In these works, the term “zero-shot” refers to the ability of models to “generalize to classes or tasks
for which no training data are available and only a description of the classes or tasks are provided”. Under
such a definition, BASIC models – as well as the recent work that BASIC is based on, such as CLIP (Radford
et al., 2021) and ALIGN (Jia et al., 2021) – are not “zero-shot learned” models. This is because the data
curating procedures of BASIC, CLIP, and ALIGN can expose certain class names to their models, albeit not
intentionally. For instance, when an image of a golden retriever dog is crawled from the internet, the image
could come from a file named my_golden_retriever.jpg which was uploaded by a user. If a model
in BASIC, CLIP, or ALIGN learns to associate the content of such an image with the text sequence “my
golden retriever” as parsed from the image’s file name, and then the model uses the knowledge from such
association at test time, then the model is not zero-shot. Despite not being zero-shot, models from BASIC,
CLIP, and ALIGN retain their claimed benefits on versatility and robustness.
Zero-shot transfer learning Instead of zero-shot learning, CLIP and ALIGN are known to conduct zero-
shot transfer learning (Radford et al., 2021; Jia et al., 2021; Zhai et al., 2022). Zero-shot transfer learning
differs significantly from zero-shot learning. Unlike zero-shot learning, it permits relevant supervised
information during pretraining, while it allows no supervised examples during the transfer protocol; i.e.,
zero-shot transfer learning skips the finetuning stage completely and performs the downstream task based
only on a text description of the target classes. For example, see (Radford et al., 2021; Jia et al., 2021; Zhai
et al., 2022) for more details on this terminology.
Data, model and batch scaling. Scaling has proven to be a powerful tool to boost the efficacy of vision
model pretraining. There are three dimensions one can scale on. The simplest dimension is data. Indeed,
recent efforts have shown that the more data we train on, the better the models become (Joulin et al., 2016;
Mahajan et al., 2018; Kolesnikov et al., 2020; Dosovitskiy et al., 2021; Sun et al., 2017). The second
dimension is the model size, with representative works such as EfficientNet, VITs and related works (Tan
and Le, 2019, 2021; Tan et al., 2020; Dosovitskiy et al., 2021; Zhai et al., 2021; Bello et al., 2021). Lastly,
scaling up the batch size is also key to improving model effectiveness (Goyal et al., 2017), especially
for the contrastive loss (Chen et al., 2020a; Tian et al., 2020; Jia et al., 2021; Radford et al., 2021). Our work
is inspired by the power of scaling, and pushes the limits in all the dimensions.
3. Background on Image-text Contrastive Learning and Zero-shot Transfer Learning
In this section, we revisit the contrastive training framework for parallel image-text data, as introduced by
CLIP (Radford et al., 2021) and ALIGN (Jia et al., 2021). In doing so, we define the notation that will be
used throughout the remainder of this paper.
Let x ∈ X be an arbitrary image and y ∈ Y be an arbitrary text sequence. The image-text contrastive
training framework (Radford et al., 2021; Jia et al., 2021) trains an image encoder F and a text encoder G to
map x and y into a D-dimensional unit sphere, i.e., F (x), G(y) ∈ SD . The desiderata of these encoders is
that images and text sequences of similar semantics should be mapped to nearby points in the latent space,
while those with different semantics should be mapped to distant points in the space. To train F and G
to achieve such desiderata, a minibatch gradient training procedure is used. At each step of this training
procedure, F and G receive B image-text pairs, i.e., (x_i, y_i) for i = 1, 2, ..., B. Based on the embeddings
computed by F and G, a similarity matrix A is computed, where $A_{i,j} = F(x_i)^\top G(y_j) / \tau$. Here, τ is
called the softmax temperature, which serves to steepen or dampen the softmax distributions in the rows and
columns of A. From this similarity matrix A, two softmax-like losses are computed based on the rows and
the columns of A:

$$\text{RowLoss}_B = -\frac{1}{B} \sum_{i=1}^{B} \log \frac{\exp(A_{i,i})}{\sum_{k=1}^{B} \exp(A_{i,k})} \qquad (1)$$

$$\text{ColumnLoss}_B = -\frac{1}{B} \sum_{j=1}^{B} \log \frac{\exp(A_{j,j})}{\sum_{k=1}^{B} \exp(A_{k,j})} \qquad (2)$$

$$\text{ContrastiveLoss}_B = \frac{\text{RowLoss}_B + \text{ColumnLoss}_B}{2} \qquad (3)$$
Minimizing ContrastiveLossB encourages the entries on the diagonal of A to be large while the entries
elsewhere to be small. Equivalently, images and text sequences from the same pair in the minibatch, i.e. xi
and y_i, will be embedded into nearby points, while those from different pairs, i.e., x_i and y_{j≠i}, will be
embedded into distant points. The resulting encoders F and G thus achieve the desiderata of the contrastive
learning framework.
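For concreteness, the following is a minimal JAX sketch of Equations 1-3. The function and argument names (contrastive_loss, image_emb, text_emb, tau) are illustrative and are not taken from our training code.

```python
import jax.numpy as jnp
from jax.nn import log_softmax

def contrastive_loss(image_emb, text_emb, tau=0.1):
    """image_emb, text_emb: [B, D] unit-normalized outputs of F and G."""
    logits = image_emb @ text_emb.T / tau                          # similarity matrix A, shape [B, B]
    diag = jnp.arange(logits.shape[0])                             # matching pairs lie on the diagonal
    row_loss = -log_softmax(logits, axis=1)[diag, diag].mean()     # Equation (1)
    col_loss = -log_softmax(logits, axis=0)[diag, diag].mean()     # Equation (2)
    return 0.5 * (row_loss + col_loss)                             # Equation (3)
```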
Well-known techniques for reducing this memory burden are gradient accumulation (GradAccum) (Ott et al., 2018; Zhai et al., 2021), re-materialization (or
gradient checkpointing) (Griewank and Walther, 2000; Chen et al., 2016) and model parallelism (Shazeer
et al., 2018; Huang et al., 2019; Lepikhin et al., 2020). Note that all three techniques are orthogonal and
complementary to each other. Next in section 4.2, we present an approach based on pipelining model
parallelism and gradient accumulation.
Vanilla GradAccum. Consider training a model weight vector θ to minimize a loss function L. For a
batch of B examples {e_1, e_2, ..., e_B}, let g_i be the gradient of L with respect to θ computed on example
e_i, i.e., $g_i = \nabla_\theta L(\theta; e_i)$. In the standard minibatch setting, we update θ with the average batch gradient
$\bar{g} = \frac{1}{B}\sum_{i=1}^{B} g_i$. When our accelerator memory can only hold M ≪ B examples, GradAccum splits the
batch of B examples into smaller batches with at most M examples, called microbatches, then computes the
gradients of the microbatches, and averages them.

We now analyze the steps of GradAccum. For simplicity, assume that M evenly divides B, and that
the i-th microbatch consists of the examples e_j with (i − 1)M + 1 ≤ j ≤ iM. With this assumption, the
GradAccum procedure first initializes a zero vector ḡ of the same size as θ. Then, sequentially for each
microbatch i, the microbatch gradient $c_i = \frac{1}{M}\sum_{j=(i-1)M+1}^{iM} g_j$ is added to ḡ. In the end, ḡ holds the
correct minibatch gradient, up to a normalization constant K = B/M.
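The JAX sketch below illustrates this vanilla GradAccum procedure. It assumes a loss function loss_fn(theta, microbatch) that returns the mean loss over a microbatch; all names are illustrative.

```python
import jax
import jax.numpy as jnp

def grad_accum(loss_fn, theta, batch, micro_size):
    """Return the batch-average gradient g-bar, computed in B/M microbatch steps."""
    num_micro = batch.shape[0] // micro_size                        # K = B / M
    g_bar = jax.tree_util.tree_map(jnp.zeros_like, theta)           # accumulator for g-bar
    grad_fn = jax.grad(loss_fn)                                     # gradient w.r.t. theta
    for i in range(num_micro):
        micro = batch[i * micro_size:(i + 1) * micro_size]
        c_i = grad_fn(theta, micro)                                 # microbatch gradient c_i
        g_bar = jax.tree_util.tree_map(lambda g, c: g + c / num_micro, g_bar, c_i)
    return g_bar
```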
GradAccum cannot be naively applied to contrastive learning. There are two properties that make
GradAccum not applicable to contrastive learning. First, in order to evaluate the loss ContrastiveLossB in
Equation 3, we need all entries of the similarity matrix A. Hence, we cannot rely only on examples in every
microbatch i-th to compute the microbatch gradients ci ’s. Second, GradAccum must allocate memory for
the cumulative gradient ḡ.1 As ḡ has as many elements as θ, its memory grows as we scale up the networks
F and G. This growth becomes a more pronounced issue as we scale up our models. For reference, our
largest model has 3B weights, occupying roughly 11GB of accelerator memory. Spending another 11GB on
ḡ, while possible, defeats the purpose of saving memory in GradAccum. In the remainder of this subsection,
we discuss how to modify GradAccum so that we can use it to scale up contrastive learning.
1. It is worth noting that this is a common issue with GradAccum and is not specific to contrastive learning.
Inputs:
• Networks F, G with a weight vector θ = [θ_F, θ_G].
• A minibatch of B (image, text) pairs {(x_i, y_i)}_{i=1}^B. (B is the contrastive batch size.)
• Microbatch size M, where M is assumed to evenly divide B. (M is the largest in-memory batch size.)

Yields: Gradients ∇_θ ContrastiveLoss_B for the B/M microbatches of the minibatch. (The loss is computed as in Equation 3.)

1   Allocate embedding matrices X, Y ∈ R^{D×B}. (D is the embedding size.)  [Memory: Θ(BD)]
2   For i = 1 to B/M do:  (Sequentially compute the embeddings for microbatches of images and text sequences, not saving the activations of F and G.)
3       Let J ← {j : (i − 1)M + 1 ≤ j ≤ iM}
4       X_{:,J} ← F(x_J)  [Memory: Θ(M · Mem(F))]
5       Y_{:,J} ← G(y_J)  [Memory: Θ(M · Mem(G))]
6   A ← X^⊤ · Y / τ  (A ∈ R^{B×B} is the similarity matrix.)
7   RowLoss_B ← −(1/B) Σ_{i=1}^B log [ exp(A_{i,i}) / Σ_{k=1}^B exp(A_{i,k}) ]  (The contrastive loss in Equation 3.)  [Memory: Θ(B²)]
8   ColumnLoss_B ← −(1/B) Σ_{j=1}^B log [ exp(A_{j,j}) / Σ_{k=1}^B exp(A_{k,j}) ]
9   ContrastiveLoss_B ← (RowLoss_B + ColumnLoss_B) / 2

Algorithm 1: Pseudo code of our gradient accumulation process for the contrastive loss. Here Mem(F), Mem(G)
denote the memory required for a pass of the networks F, G. As shown in our memory analysis, at the cost
of repeating one forward pass for F, G (lines 13-16), our procedure's peak memory footprint is dominated by
Θ(M · max{Mem(F), Mem(G)}).
Note that our algorithm can be flexibly modified to work with different microbatch sizes M for the image network
F and the text network G. This flexibility allows for more efficient computations, e.g., when one network is
smaller than the other and thus can operate with larger microbatches.
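A minimal sketch of the embedding-filling loop (lines 2-5 of Algorithm 1) is shown below. The encoder callables F and G are assumed, and jax.lax.stop_gradient stands in for "compute the forward pass without saving activations".

```python
import jax
import jax.numpy as jnp

def fill_embeddings(F, G, images, texts, micro_size):
    """Fill the embedding matrices microbatch by microbatch, without saved activations."""
    x_chunks, y_chunks = [], []
    for start in range(0, images.shape[0], micro_size):
        j = slice(start, start + micro_size)
        x_chunks.append(jax.lax.stop_gradient(F(images[j])))   # X[:, J] <- F(x_J)
        y_chunks.append(jax.lax.stop_gradient(G(texts[j])))    # Y[:, J] <- G(y_J)
    return jnp.concatenate(x_chunks), jnp.concatenate(y_chunks)
```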
Accumulating the microbatch gradients. Algorithm 1 yields a stream of microbatch gradients c1 , ..., cB/M ,
which need to be accumulated, i.e., averaged, into ḡ to perform the batch weight update. As discussed, we want
to avoid allocating extra memory for ḡ. To do this, we need two assumptions about our training implementa-
tion. Our first assumption is that we use an optimizer which involves gradient moments (Nesterov, 1983; Tiele-
man and Hinton, 2012; Kingma and Ba, 2015; Loshchilov and Hutter, 2019; Shazeer and Stern, 2018). This
assumption motivates our idea to avoid allocating ḡ: since the optimizer already allocates the memory for gra-
dient moments, typically called slots, we will directly accumulate the microbatch gradients ci ’s into these slots.
We illustrate this idea with Adam (Kingma and Ba, 2015), a popular optimizer that involves two gradient
moments. At training step t, Adam receives the averaged minibatch gradient ḡ and makes the following
updates to its gradient moments v1 and v2 :
$$\bar{g} = \frac{1}{B} \sum_{i=1}^{B} g_i = \frac{1}{\underbrace{B/M}_{K}} \sum_{i=1}^{B/M} c_i$$

$$v_1^{(t)} = \beta_1 v_1^{(t-1)} + (1 - \beta_1)\,\bar{g}$$

$$v_2^{(t)} = \beta_2 v_2^{(t-1)} + (1 - \beta_2)\,\bar{g}^2$$
Accumulating the microbatch gradients c_i's into v_1 is straightforward. We can simply modify v_1's single update
with ḡ into K = B/M updates as follows:

$$v_1 \leftarrow k_i v_1 + (1 - \beta_1)\,c_i, \quad \text{where } k_i = \begin{cases} \beta_1 & \text{if } i = 1 \\ 1/K & \text{otherwise} \end{cases}$$
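The sketch below illustrates this idea with an arithmetically equivalent variant of the update that folds the 1/K normalization into the gradient term; it is an illustration, not the exact code of our optimizer.

```python
import jax

def fold_into_first_moment(v1, c_i, i, K, beta1=0.9):
    """Fold the i-th microbatch gradient c_i (a pytree) into Adam's first moment v1
    without ever materializing g-bar. Here i is 0-indexed; after K calls,
    v1 == beta1 * v1_old + (1 - beta1) * mean_i(c_i)."""
    decay = beta1 if i == 0 else 1.0          # apply the moment decay only once
    return jax.tree_util.tree_map(lambda v, c: decay * v + (1.0 - beta1) * c / K, v1, c_i)
```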
Accumulating the c_i's into the second moment v_2 is more subtle because its update involves the square ḡ², which differs from the average of the squared microbatch gradients by the variance of the c_i's,
which we can estimate. Indeed, since each c_i is the mean of the M per-example gradients g_j in the i-th
microbatch, we can treat each c_i as the sample mean of M observations drawn from a random variable
g ∼ Uniform{g_1, ..., g_B}. This treatment allows us to use the familiar identity:
$$\mathrm{Var}[c_i] = \mathrm{Var}\!\left[\frac{1}{M} \sum_{j=(i-1)M+1}^{iM} g_j\right] = \frac{\mathrm{Var}[g]}{M} \qquad (4)$$
Therefore, to estimate Var[c_i], we only need to estimate Var[g]. For this, we make the second assumption
about our training: that we use a data parallelism setting with R replicas. Under this assumption, each
microbatch gradient c_i is obtained from an all-reduce operation over R replicas, each of which processes
M/R examples. Once again, treating these per-device gradients d_1, ..., d_R as sample means of M/R
observations of g, we can apply Identity 4 to obtain Var[d] = Var[g]/(M/R). This treatment
allows us to perform GradAccum without allocating ḡ.
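A small sketch of this estimator, with illustrative names, is given below; it follows from Identity 4 that Var[c_i] = Var[d] / R.

```python
import jax.numpy as jnp

def estimate_var_ci(per_replica_grads):
    """per_replica_grads: [R, P] array stacking the R flattened per-device gradients
    d_1..d_R, each the mean over M/R examples."""
    R = per_replica_grads.shape[0]
    var_d = jnp.var(per_replica_grads, axis=0)   # empirical estimate of Var[d], per parameter
    return var_d / R                             # estimate of Var[c_i] = Var[g] / M
```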
Our SPMD method not only delivers exact gradient computations, unlike pipelining with GradAccum, but it also has a better latency per training step. However, as we shall see in
Section 5.1 and Section 5.2, our SPMD method requires several manual designs, which make it less generic
than pipelining and GradAccum.
As model sizes grow, model weights occupy a significant part of accelerator memory. In modern optimizers
for deep learning models, such as Adam (Kingma and Ba, 2015), RMSprop (Tieleman and Hinton, 2012),
and AdamW (Loshchilov and Hutter, 2019), every weight tensor is additionally accompanied by the first and
second gradient moments, hence tripling its memory footprint. Furthermore, in the vanilla data parallelism
training, all these weights are replicated to all accelerators. In our experiments with a relatively large model
size, roughly 4GB of accelerator memory is occupied by these weights and their gradient moments, which is
significant for the typical 16GB memory in an accelerator in 2022, such as a Google TPU core or an Nvidia
RTX 3080 GPU.
Here, we split the weight tensors in our encoder networks, i.e., F and G in Section 3, across multiple
accelerator cores, and only combine these tensors when the whole tensor is needed to perform certain
computations. Note that upon splitting a weight tensor across multiple cores, we also split its first and second
gradient moments in a similar way. Figure 1 illustrates our weight sharding strategy on the 2D convolution
operation which is prevalent in image encoder models.
Figure 1: An illustrative example of our model parallelism design. Shown is a 2D convolution operation with a 3x3
kernel sharded across 4 cores. Gray cells represent the tensor values that are mirrored across all cores, while cells of other
colors represent per-core independent tensor values. The convolution's input is a tensor of shape [N, H, W, i] which is
sharded along its first dimension, so that each core processes a tensor of shape [N/4, H, W, i]. The convolution's
kernel is a tensor of shape [3, 3, i, o], but each core only stores one shard of the kernel, which has size [3, 3, i/4, o].
Before convolving, every core receives the kernel shards from all other cores and concatenates the shards, forming the
complete kernel of size [3, 3, i, o]. After convolving, the complete kernel is discarded from all cores' memory.
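The following is a minimal sketch of the idea in Figure 1, written with JAX's GSPMD-style sharding API rather than our actual training code. The shapes, names, and device mesh are illustrative, and the toy example assumes the device count divides the batch (8) and input-channel (64) sizes.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over whatever cores are available (4 cores in Figure 1).
mesh = Mesh(np.array(jax.devices()), axis_names=("cores",))

# Kernel [3, 3, i, o]: each core stores one shard along the input-channel axis i.
kernel = jax.device_put(jnp.zeros((3, 3, 64, 128)),
                        NamedSharding(mesh, P(None, None, "cores", None)))
# Input [N, H, W, i]: sharded along the batch axis N.
x = jax.device_put(jnp.zeros((8, 32, 32, 64)),
                   NamedSharding(mesh, P("cores", None, None, None)))

@jax.jit
def conv(x, kernel):
    # The SPMD partitioner gathers the kernel shards into the complete [3, 3, i, o]
    # kernel before convolving, then discards the gathered copy, as in Figure 1.
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

y = conv(x, kernel)  # output of shape [N, H, W, o], sharded along N
```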
Our approach is based on the Single-Program Multiple-Data (SPMD) technique, which has been suc-
cessfully applied to train large language models in previous works such as in Xu et al. (2021); Lepikhin
et al. (2020). In the SPMD technique, we define a computational graph which represents our entire training
program. This computational graph is compiled once, and then is replicated identically to all computational
cores to run the training program. While all of our computational cores run an identical program, they are
allowed to receive different inputs and hence can produce different outputs. These inputs and outputs can be
organized in certain ways to define arbitrarily complex model parallelism strategies. Next, we describe how
we apply the SPMD technique on our model weights only.
Figure 2: Generic rematerialization map for the blocks in our CoAtNet models. Left: in the Mobile Inverse Convolution
blocks (MBConv), all batch normalization layers and activation layers, as well as all layers in the squeeze and excitation
steps, are rematerialized. Right: in Transformer blocks, only the layer normalization layers and activation layers are
rematerialized.
Our training program typically runs on a cluster of 2048 TPUv3 cores. We partition these 2048 cores into
2048/R replicas, each of which uses R cores. The value of R governs how the weights of our image encoder
F and our text encoder G are stored in the memory of our 2048 cores. In particular, all weight tensors in
the networks F and G are split into R equal parts, each of which lives on one of the R cores in a replica. Note that
since we have 2048/R replicas, the weights of our image and text encoders are still replicated 2048/R
times. For instance, cores 1, 2, ..., R can each store 1/R of the weight tensors, and then cores
R+1, R+2, ..., 2R store an identical copy of these tensors. Thus, using fewer replicas and more cores
per replica leads to better memory utilization, at a higher overhead for cross-core communication. We
empirically find that using 512 replicas and 4 cores per replica offers a good balance.
It is important to note that we only apply SPMD on our model weights, and not on any other steps of our
computations. This means that if our training program receives an input batch of B examples, then these
B examples are distributed equally to all our 2048 cores. In other words, each of our 2048 cores processes
B/2048 examples, regardless of the value of R. We find that this design choice disentangles our weight
sharding strategy from the rematerialization strategy, as described next in Section 5.2.
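The sketch below illustrates this layout as a two-dimensional device mesh with a "replica" axis and a "model" axis; the 512x4 arrangement is the one we use on 2048 cores, while the code builds whatever mesh the available devices allow. It is an illustration, not our training code.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

cores_per_replica = 1   # R; we use R = 4 on 2048 TPUv3 cores (512 replicas x 4 cores)
devices = np.array(jax.devices()).reshape(-1, cores_per_replica)
mesh = Mesh(devices, axis_names=("replica", "model"))

# Weight tensors: split along the "model" axis within a replica, replicated across replicas.
weight_sharding = NamedSharding(mesh, P("model", None))
# Examples: split across all cores, i.e. over both mesh axes, so every core gets B/2048 examples.
batch_sharding = NamedSharding(mesh, P(("replica", "model"), None))
```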
5.2 Rematerialization
The technique of rematerialization, also widely known as gradient checkpointing (Chen et al., 2016), saves
accelerator memory while training neural networks. It works by not saving certain values from a network's
forward pass, and recomputing them in the backward pass only when they are needed for a particular
calculation. For instance, if our image encoder F, as discussed in Section 3, has 100 layers, a rematerialization
program can decide that after the forward pass, only the values of layers 10, 20, ..., 90 are kept in an
accelerator's memory, while the values of other layers are discarded. If all layers of F consume similar
memory, this rematerialization program reduces the memory cost by 10 times, at the trade-off that the
values of the unsaved layers in the forward pass, such as layer 21 or layer 72, have to be recomputed in
the backward pass.
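As an illustration, the JAX sketch below rematerializes the intermediate activations of each block with jax.checkpoint; the block definition is a stand-in, not one of our encoder blocks.

```python
import jax
import jax.numpy as jnp

def block(x, w):
    # Stand-in for one encoder block; not one of our actual CoAtNet blocks.
    return jax.nn.gelu(x @ w)

def forward(x, weights):
    for w in weights:
        # jax.checkpoint discards the activations inside `block` after the forward
        # pass and recomputes them when they are needed in the backward pass.
        x = jax.checkpoint(block)(x, w)
    return jnp.sum(x)

weights = [0.01 * jnp.ones((16, 16)) for _ in range(4)]
grads = jax.grad(forward, argnums=1)(jnp.ones((8, 16)), weights)
```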
We select which layers to rematerialize in our image encoder F and our text encoder G based on a
simple heuristic. Ideally, we want to rematerialize the layers that are fast to recompute but consume
more memory. Since we utilize weight sharding, as described in Section 5.1, the computations that involve
weights are slower than normal because of their overhead time for cross-core communications. As such,
we keep almost all layers that involve weights, such as convolution, attention, and dense feed-forwards,
in our accelerator’s memory. In contrast, layers that do not involve weights, such as activation functions,
batch normalization, and layer normalization, are all rematerialized. Figure 2 illustrates our rematerialization
strategy for all three block types in our image and text encoders: the mobile-inverse convolutional block (Tan
and Le, 2019; Dai et al., 2021), the attention block, and the feed-forward blocks (Vaswani et al., 2017).
We find this design choice beneficial, because in modern encoder architectures, every layer that involves
weights is typically followed by a normalization layer or an activation layer. As such, our design allows more
than half of the encoder’s activation values to be removed from the accelerator’s memory after each forward
pass, while leaving only the light computational steps to be repeated in each backward pass. We empirically
find that, with weight sharding and rematerialization, each of our forward-backward passes is 1.4 times slower than
the vanilla implementation at the same batch size.
Exceptions to our heuristics. Certain parts of our encoders do not follow these general heuristics, as we
find that doing so saves a certain amount of time at the cost of using a little extra memory. Here, we describe
these exceptions:
1. All weights in batch normalization and layer normalization in our models, including the β’s and γ’s
and the moving average statistics of batch normalization, are not sharded. Instead, these weights
are replicated to all computational cores to avoid cross-cores communications, because they are one
dimensional vectors which do not occupy much memory.
2. All computations in the Squeeze-and-Excitation blocks (SE; (Hu et al., 2018)) of our models are remateri-
alized, including the convolutions. This is because these SE blocks only involve 1x1 convolutions with
reduced internal channels, making them less costly to recompute. In addition, all the weights of these
1x1 convolutions are replicated to all of our cores, because they have a small memory footprint but are
reused in the backward pass for rematerialization.
Model    Method                 Batch size (B)   FWD (ms)   BWD (ms)   Total (ms)   Memory (GB)
Small    Data parallelism       2^16             28.6       81.2       128.5        4.2
Small    Data parallelism       2^17             65.8       144.5      230.3        6.4
Small    Data parallelism       2^18             131.2      282.9      428.0        9.9
Small    Data parallelism       2^19             261.0      559.4      842.4        14.8
Small    Data parallelism       2^20             −          −          −            OOM
Small    Pipeline & GradAccum   2^20             301.7      1497.4     1983.1       11.2
Small    SPMD                   2^20             271.8      1311.8     1631.6       15.2
Medium   Data parallelism       2^16             129.5      270.8      437.6        10.7
Medium   Data parallelism       2^17             237.7      536.4      806.1        13.8
Medium   Data parallelism       2^18             −          −          −            OOM
Medium   Pipeline & GradAccum   2^18             597.7      1677.1     2677.3       12.6
Medium   SPMD                   2^18             599.3      1407.2     2018.6       12.1
Medium   Pipeline & GradAccum   2^19             1393.1     4087.8     5601.2       12.8
Medium   SPMD                   2^19             1402.7     3141.1     4912.3       14.1
Medium   Pipeline & GradAccum   2^20             3014.3     6129.9     9781.4       12.6
Medium   SPMD                   2^20             2909.7     4912.3     8361.1       15.4
Table 2: Comparison between our SPMD programs and our Pipelining & GradAccum programs. Shown are the step
times and memory footprints of our small-sized and medium-sized models at different batch sizes. With the same model
size and batch size, our SPMD design results in a larger device memory footprint than Pipelining & GradAccum, but
SPMD is faster.
Our SPMD approach has a forward time comparable to the run time of pipelining's forward pass, but our backward time is often a lot faster. For instance, in our
largest setting, with the medium-sized model and the contrastive batch size B = 2^20, our backward time is
more than 1.2 seconds faster than that of the pipelining approach, amounting to about 10% of the total step
time. Additionally, our strategy also has a faster total step time, perhaps because we do not need to spend extra
time accumulating the microbatch gradients like the pipelining approach does.
Finally, we note that as B grows larger, the SPMD approach typically occupies more accelerator memory
than the pipelining approach does. This is because in the pipelining approach, increasing the contrastive batch
size B only leads to more microbatches, but does not change the microbatch size, and so the accelerator's
memory remains constant. As such, the pipelining approach is still applicable if B grows larger than 2^20, but
the SPMD strategy would have to be redesigned, e.g., by rematerializing a larger portion of our image and
text encoders.
where we multiply by B to scale the loss correctly in the regime of B → ∞. That is, since the unnormalized loss
goes to zero ($\hat{\ell}_B(x, y)/B \to 0$) as the batch size approaches infinity (B → ∞), analyzing the unnormalized
version of the loss $\hat{\ell}_B(x, y)/B$ can mistakenly predict benefits of the large contrastive batch size. We avoid
this with the normalization. Similarly, we define the normalized testing loss $\bar{\ell}_M(x, y)$, which uses the M test samples $\bar{y}_1, \dots, \bar{y}_M$ in place of the B training samples.

Therefore, $\hat{\ell}_B(\hat{x}_i, \hat{y}_i)$ is minimized during training for training points $(\hat{x}_i, \hat{y}_i)$, while we want to minimize
$\bar{\ell}_M(x, y)$ to make a prediction at a new point x. This leads to the question of the generalization from training
to unseen data, which can be studied by analyzing an upper bound on the quantity
$$\mathbb{E}_{x,y}[\bar{\ell}_M(x, y)] - \hat{\mathbb{E}}_S[\hat{\ell}_B(\hat{x}, \hat{y})],$$
where $\hat{\mathbb{E}}_S[\hat{\ell}_B(\hat{x}, \hat{y})]$ is the empirical training loss over a dataset $S = ((\hat{x}_i, \hat{y}_i))_{i=1}^m$ of size m. To analyze
this in a statistical setting, we define a vector $v \in \mathbb{R}^D$ by
$$v_i = F(x)_i - \frac{\sum_{k=1}^{B} \exp(F(x)^\top G(\hat{y}_k))\, G(\hat{y}_k)_i}{\sum_{k=1}^{B} \exp(F(x)^\top G(\hat{y}_k))} \quad \text{for all } i \in \{1, \dots, D\},$$
and assume that $\hat{y}_1, \hat{y}_2, \dots, \hat{y}_B \overset{iid}{\sim} p_y$, $\bar{y}_1, \bar{y}_2, \dots, \bar{y}_M \overset{iid}{\sim} p_y$, $\exp(F(x)^\top G(y)) \le c_1$,
$\hat{\ell}_B(x, y) \le c_2$, $\|F(x)\|_2 \le c_3$, $\frac{\exp(F(x)^\top G(y))}{\frac{1}{B}\sum_{k=1}^{B} \exp(F(x)^\top G(\hat{y}_k))} \le c_4$, $\|F(x_i)\|_2 \le c_5$, $\|v\|_2 \le c_6$, $\|y\|_2 \le c_7$, and
$\|x\|_2 \le c_8$, with probability one. Moreover, by defining $\gamma(x) = \mathbb{E}_{\bar{y}}[\exp(F(x)^\top G(\bar{y}))] - \frac{1}{B}\sum_{k=1}^{B} \exp(F(x)^\top G(\hat{y}_k))$,
we assume that γ is $c_9$-Lipschitz; i.e., $|\gamma(x) - \gamma(x')| \le c_9 \|x - x'\|_2$ for all $x \in \mathcal{X} \subseteq \mathbb{R}^\kappa$. To provide
concrete insights, we consider the special case where F and G are deep neural networks of the forms
$$G(y) = (\omega_L \circ \sigma_{L-1} \circ \omega_{L-1} \circ \cdots \circ \sigma_1 \circ \omega_1)(y), \qquad F(x) = (\omega'_{L'} \circ \sigma'_{L'-1} \circ \omega'_{L'-1} \circ \cdots \circ \sigma'_1 \circ \omega'_1)(x),$$
where $\omega_l(q) = W_l q$ represents the linear transformation and $\sigma_l$ is an element-wise nonlinear activation
function. Similarly, $\omega'_l(q) = W'_l q$ and $\sigma'_l$ is an element-wise activation function.
The following theorem provides an insight into the role of the contrastive batch size in closing the accuracy
gap from contrastive models to their supervised counterparts:

Theorem 1 Suppose that the activation functions $\sigma_l$ and $\sigma'_l$ are 1-Lipschitz and positive homogeneous for
all $l \in [L-1]$. Let $\mathcal{G} = \{y \mapsto G(y) : (\forall l \in [L-1])[\|W_l\|_F \le M_l] \wedge \|(W_L)_k\|_F \le M_{L,k}\}$ and
$\mathcal{F} = \{x \mapsto F(x) : (\forall l \in [L'-1])[\|W'_l\|_F \le M'_l] \wedge \|(W'_{L'})_k\|_F \le M'_{L',k}\}$, where $(W_l)_k$ is the k-th row of
$W_l$. Then, for any δ > 0, with probability at least 1 − δ, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
$$\mathbb{E}_{x,y}[\bar{\ell}_M(x, y)] - \hat{\mathbb{E}}_S[\hat{\ell}_B(x, y)] \le \frac{Q_1}{\sqrt{m}} + \frac{Q_2}{\sqrt{2B}} + c_2\sqrt{\frac{\ln(2/\delta)}{2m}},$$
where $Q_1 = 2\sqrt{2}\sqrt{c_4^2 c_5^2 + c_6^2}\,(\tilde{Q}_{1,1} + \tilde{Q}_{1,2})$ and $Q_2 = c_1 \mathbb{E}_{x,y}[A(x, y)](\tilde{Q}_{2,1} + \tilde{Q}_{2,2})$, with
$$\tilde{Q}_{1,1} = c_7\left(\sqrt{2\log(2)L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\left(\sum_{k=1}^{D} M_{L,k}\right),$$
$$\tilde{Q}_{1,2} = c_8\left(\sqrt{2\log(2)L'} + 1\right)\left(\prod_{l=1}^{L'-1} M'_l\right)\left(\sum_{k=1}^{D} M'_{L',k}\right),$$
$$\tilde{Q}_{2,1} = 2\sqrt{2}\,c_8 c_9 + c_1\sqrt{\kappa \ln(\sqrt{\kappa}B/\delta)},$$
$$\tilde{Q}_{2,2} = 2\sqrt{2}\,c_3 c_7\left(\sqrt{2\log(2)L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\sqrt{\sum_{k=1}^{D} M_{L,k}^2}.$$
Theorem 2 Let $\mathcal{F}$ be a set of maps $x \mapsto F(x)$ and $\mathcal{G}$ be a set of maps $y \mapsto G(y)$. Then, for any δ > 0, with
probability at least 1 − δ, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
$$\mathbb{E}_{x,y}[\bar{\ell}_M(x, y)] - \hat{\mathbb{E}}_S[\hat{\ell}_B(x, y)] \le \frac{C_1}{\sqrt{2B}} + c_2\sqrt{\frac{\ln(2/\delta)}{2m}} + C_2 \sum_{k=1}^{D} \left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right) + C_3 \tilde{R}_B(\mathcal{G}),$$
where $R_m(\mathcal{H}) := \mathbb{E}_{S,\xi}\left[\sup_{h \in \mathcal{H}} \frac{1}{m}\sum_{i=1}^{m} \xi_i h(x_i, y_i)\right]$, $\mathcal{F}_k = \{x \mapsto F(x)_k : F \in \mathcal{F}\}$, $\mathcal{G}_k = \{y \mapsto G(y)_k : G \in \mathcal{G}\}$,
$\tilde{R}_B(\mathcal{G}) = \mathbb{E}_{y,\xi}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \xi_i G(y_i)\right\|_2\right]$, $C_1 = c_1 \mathbb{E}_{x,y}[A(x, y)]\,\tilde{Q}_{2,1}$, $C_2 = 2\sqrt{2}\sqrt{c_4^2 c_5^2 + c_6^2}$, and
$C_3 = 2 c_1 c_3 \mathbb{E}_{x,y}[A(x, y)]$. Here, $\xi_1, \dots, \xi_m$ are independent uniform random variables taking values in $\{-1, 1\}$.
Typically, $R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k) = O(\frac{1}{\sqrt{m}})$ and $\tilde{R}_B(\mathcal{G}) = O(\frac{1}{\sqrt{B}})$ in terms of m and B, as illustrated in the proof of
Theorem 1 for standard deep neural networks. Therefore, it is desirable to use a large contrastive batch size,
which motivates our new algorithm in the previous section.
Our pretraining labeled dataset, which mostly consists of natural images, has very few digit images. Meanwhile, our
noisy image-text dataset has plenty of instances that can teach a model certain optical character recognition skills.
As will be shown in Section 9, our best experimental results are achieved using a hybrid procedure. First,
we pretrain the image encoder on a large labeled dataset, then fix its weights and train the text encoder using
the contrastive loss on our image-text dataset. Finally, we finetune both image and text encoders, using our
GradAccum technique when needed. In Section 10, we present ablation studies to analyze the effects of
pretraining, finetuning, and other alternative training procedures.
9. Experiments
9.1 Training details
Labeled data for pretraining. For pretraining (Section 8), we use the JFT dataset. This dataset has been
used in previous publications (Zhai et al., 2021; Dosovitskiy et al., 2021; Kolesnikov et al., 2020), but it has
been constantly expanded. The JFT version used in our experiments has 5B images, each of which can be
associated to one or multiple labels out of 29K possible classes.
Data filtering. A problem with training on large auto-curated datasets like ALIGN and JFT is that these
datasets might unintentionally contain examples from our test sets. To avoid such contaminations, we filter
all instances in our training data that have a structural similarity index (SSIM (Wang et al., 2004)) of at least
0.5 with any image from our evaluation benchmarks.
Optimizer. We train our models with our own optimizer called AdaFactorW, adapted from two existing
ones: AdaFactor (Shazeer and Stern, 2018) and AdamW (Loshchilov and Hutter, 2019). Specifically, we
factorize our second gradient moments like AdaFactor, and decouple the weight decay from all moments
like AdamW. To further save memory, we follow Zhai et al. (2021) and store the first gradient moments
in bfloat16. We observe, however, that while we can store these moments in bfloat16, we need to
convert them into float32 prior to computing our weight updates to avoid numerical instability.
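The sketch below illustrates the bfloat16 storage idea only, not our exact optimizer code: the first moment is stored in bfloat16 but cast to float32 before any arithmetic.

```python
import jax.numpy as jnp

def update_first_moment(m_bf16, grad, beta1=0.9):
    """Keep the first moment in bfloat16, but do the EMA update in float32."""
    m = m_bf16.astype(jnp.float32) * beta1 + (1.0 - beta1) * grad.astype(jnp.float32)
    return m.astype(jnp.bfloat16), m   # bfloat16 copy to store, float32 copy for the weight update
```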
The 17 benchmark datasets: Caltech101, RESISC45, CIFAR100, VOC2007, CIFAR10, ImageNet, EuroSAT, Food101, Birdsnap, IIIT-Pets, SUN397, UCF101, Flowers, MNIST, STL10, PCam, DTD.

CLIP ResNet-50:        32.6 82.1 75.6 41.6 41.7 41.1 65.9 81.1 59.6 66.6 85.4 57.6 54.2 94.3 59.6 63.6 82.1
BASIC-S:               38.6 91.6 86.4 57.8 54.3 29.1 76.8 86.0 71.9 32.4 93.2 54.3 53.5 96.7 67.3 65.5 83.4
  (vs. ResNet-50):     (+6.0) (+9.5) (+10.8) (+16.2) (+12.6) (-12.0) (+10.9) (+4.9) (+12.3) (-34.2) (+7.8) (-3.3) (-0.7) (+2.4) (+7.7) (+1.9) (+1.3)
CLIP ViT-B/16:         39.1 89.3 91.6 68.7 46.0 54.1 70.4 89.2 68.6 56.0 88.9 48.1 65.5 98.2 65.2 69.8 83.9
BASIC-M:               49.4 94.2 94.8 72.2 60.2 39.5 86.0 92.3 81.5 33.6 95.3 58.3 65.4 99.3 72.9 77.4 84.2
  (vs. ViT-B/16):      (+10.3) (+4.9) (+3.2) (+3.5) (+14.2) (-14.6) (+15.6) (+3.1) (+12.9) (-22.4) (+6.4) (+10.2) (-0.1) (+1.1) (+7.7) (+7.6) (+0.3)
CLIP ViT-L/14-336:     49.5 92.8 95.7 77.5 55.7 59.6 78.3 93.8 76.2 88.3 93.5 63.0 71.7 99.4 68.4 76.9 84.3
BASIC-L:               59.2 94.7 97.5 82.3 64.6 51.0 91.2 95.1 85.7 40.3 97.9 59.6 72.7 99.6 76.2 84.8 84.6
  (vs. ViT-L/14-336):  (+9.7) (+1.9) (+1.8) (+4.8) (+8.9) (-8.6) (+13.1) (+1.3) (+9.5) (-48.0) (+4.4) (-3.4) (+1.0) (+0.2) (+7.8) (+7.9) (+0.3)
Table 3: Performances of BASIC and CLIP models (Radford et al., 2021) on 17 image classification benchmarks. The
first two blocks compare models of similar numbers of weights and FLOPs. The last block compares the largest CLIP
and BASIC models.
Other hyperparameters. For all experiments, we train and evaluate with the image resolution of 224x224.
While we can increase this resolution to gain performance (Tan and Le, 2019, 2021; Touvron et al., 2019;
Radford et al., 2021; Jia et al., 2021), we choose not to do this and instead, reserve our computational resources
for scaling up our model and our batch size. All of our other hyper-parameters can be found in Appendix B.
9.2 Results on Image Classification Benchmarks
We first present the zero-shot transfer performance of our BASIC models. We compare our models BASIC-
{S,M,L} to CLIP models with similar computational budgets (Radford et al., 2021) on 17 natural image
classification datasets. Details about these datasets can be found in Appendix C.
Zero-shot transfer models require textual prompts, which we take from CLIP (Radford et al., 2021) for
consistent comparison. We suspect that using prompts which are tuned for our models can further improve
our results as shown in (Lester et al., 2021), because the text sequences in our training data have a different
distribution from the text sequences in CLIP.
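For reference, zero-shot transfer classification with prompts reduces to a similarity lookup between one image embedding and the prompted class embeddings, as in the following sketch; the embeddings and prompt strings are assumed to come from our encoders F and G.

```python
import jax.numpy as jnp

def zero_shot_classify(image_emb, class_text_emb):
    """image_emb: [D] unit-norm image embedding from F.
    class_text_emb: [num_classes, D] unit-norm embeddings of prompted class names
    (e.g., 'a photo of a {class name}') from G."""
    scores = class_text_emb @ image_emb      # cosine similarity to every class prompt
    return jnp.argmax(scores)                # index of the predicted class
```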
Table 3 shows the comparison. From the table, it can be seen that BASIC models conclusively outperform
CLIP models of the same computational budgets. Specifically, BASIC models demonstrate higher accuracy
than CLIP models on 13 out of 17 datasets. On the Oxford IIIT Pets dataset, BASIC-L achieves 97.9% mean
per-class recall, which sets a new state-of-the-art, despite having never seen any training images from the
dataset. On the other hand, BASIC models have low accuracy on EuroSAT, MNIST, and PCam. MNIST is
where BASIC models perform worst, with the highest accuracy being only 40.3%. We discuss these failure
cases further in Section 11.
9.3 Results on Robustness Benchmarks

More accurate zero-shot transfer models are also more robust. We evaluate BASIC-{S,M,L} models
from Section 9.2 on 5 robustness benchmarks derived from ImageNet: ImageNet-A (Hendrycks et al., 2021b),
ImageNet-R (Hendrycks et al., 2021a), ImageNet-V2 (Recht et al., 2019), ImageNet-Sketch (Wang et al.,
2019), and ObjectNet (Barbu et al., 2019). These benchmarks have images in all or a subset of the 1000
ImageNet classes, but their inputs are selected from certain natural distribution shifts, which can cause
ImageNet-trained models to make many more mistakes. Our numerical results are highlighted in Table 1
from Section 1. To visualize the data trend, in Figure 3, we plot the accuracy of zero-shot models – BASIC,
CLIP (Radford et al., 2021), and ALIGN (Jia et al., 2021) – and of 200 ImageNet-trained models collected
by Taori et al. (2020).
The data points from our BASIC models extend the prediction from CLIP: zero-shot transfer models have
a higher effective robustness (Radford et al., 2021; Taori et al., 2020), i.e. they have higher robustness than
ImageNet-trained models with the same ImageNet accuracy. To extrapolate from this trend, we fit a logistic
curve (red dashes) to the zero-shot accuracy and robustness of zero-shot transfer models. The plot shows that
this line meets the ideal robustness line at about 91% on the x-coordinate. In other words, our plot predicts
[Figure 3 plot: average robustness accuracy (top-1, %) vs. ImageNet top-1 accuracy (%), with an ideal robustness line (y = x) and a logistic fit to the zero-shot transfer models.]
Figure 3: Top-1 accuracy on ImageNet vs. average top-1 accuracy on 5 robustness benchmarks. Zero-shot models
(red stars and yellow rhombuses) have significantly higher effective robustness (Taori et al., 2020) compared to
ImageNet-trained models (blue dots).
[Figure 4 panels: per-benchmark top-1 accuracies, including ImageNet-R (top-1), ObjectNet (top-1), and Average (top-1).]
Figure 4: Top-1 accuracy of BASIC models on ImageNet and on 5 robustness benchmarks. In all cases, as the BASIC
models are trained on more ImageNet labeled data (1%, 10%, 20%, and 50%), their ImageNet accuracy significantly
increases, but their accuracy on the robustness benchmarks increases much less, or decreases.
that a model which achieves about 91% zero-shot accuracy on ImageNet, i.e., just slightly better than the
state-of-the-art ImageNet-trained model (Dai et al., 2021), will also achieve the ideal robustness.
ImageNet-finetuned models are less robust. We now study the effect of ImageNet’s labeled data on
our models. We take the converged BASIC-{S,M,L} checkpoints from Section 9.2 and continue to train
them on 1%, 10%, 20%, and 50% of ImageNet’s labeled examples. Note that we continue training these
checkpoints using the contrastive loss, where the names of ImageNet classes are utilized as text sequences
accompanying their images. This is different from CLIP’s linear probing approach, which we do not perform
to avoid potential confounding factors in our study, e.g., linear classifiers might behave differently from
our zero-shot transfer classifiers. We then compare the accuracy of these finetuned models on ImageNet and
on the 5 robustness benchmarks. The results are visualized in Figure 4.
The figure shows a clear trend: as our models learn from more labeled ImageNet data, they become more
accurate on ImageNet, but these gains do not carry over to the robustness benchmarks. Specifically, with
the exception of ImageNet-V2, for which the accuracy of the finetuned models stays the same (for BASIC-L)
or slightly increases (for BASIC-M), on all other robustness benchmarks the finetuned models suffer from
significant performance drops. In the extreme case, a 3% accuracy gain on ImageNet leads to an 8.3% accuracy
drop on ImageNet-R.
What makes our finetuned models less robust? A quick glance at our results might lead to the superficial
conclusion that our models have overfit, as our finetuning sets are a lot smaller than ALIGN and JFT. However,
this overfitting theory does not explain the trend observed in Figure 4: training on more labeled ImageNet
data makes our models less robust. We hope our observation invites further causal analysis on the effects of
ImageNet’s labeled data.
10. Ablation Study

10.1 The Importance of Batch Size Scaling

To demonstrate the role of large batch sizes, we conduct several controlled experiments for BASIC-S and
BASIC-M on ALIGN. For both BASIC-S and BASIC-M, we fix all hyperparameters as shown in Table 6,
but vary the batch size and the number of training steps. Models that are trained with larger batch sizes are
trained with fewer steps to guarantee that they “see” the same number of examples. Table 4 presents the
ImageNet top-1 zero-shot accuracy of all models at the end of their training, and Figure 5 visualizes their
entire validation accuracy curves.
Batch size   Steps   BASIC-S   BASIC-M
4096         800K    55.6      64.8
8192         400K    57.6      67.7
16384        200K    58.8      69.4
32768        100K    59.3      70.1

Table 4: Top-1 ImageNet accuracy at the end of training for our BASIC-{S,M} models trained with different batch sizes and numbers of training steps. All models are trained for the same number of epochs, but models trained with larger batch sizes have higher accuracy.

Figure 5: ImageNet held-out validation accuracy curves (top-1 % vs. number of examples seen) for BASIC-S and BASIC-M trained with batch sizes of 4k, 8k, 16k, and 32k. Models with smaller batch sizes are trained for more steps to ensure a fair comparison. The comparison shows that despite seeing the same number of training examples, models with larger batch sizes reach higher performances than models with more training steps. Figure best viewed in color.
Table 4 and Figure 5 both suggest that training for more steps cannot equalize the benefit of large batch
sizes. This phenomenon is consistent with the observation from SimCLR (Chen et al., 2020a,b): large batch
sizes help contrastive learning. SimCLR observes that the benefit of large batch sizes saturates at 8192. In
contrast, our results in Table 4 and Figure 5 show that larger batch sizes continue to benefit our models until
32768, and even until 65536 as in Section 9.2. We suspect that the benefits of large batch sizes do not
saturate because our dataset size and model size are both larger than those of SimCLR, e.g. ALIGN with 1.7B
examples compared to ImageNet with 1M examples, and BASIC-{S, M} compared to ResNet-{50,101,152}.
This comparison suggests the benefits of our method – combined scaling.
10.2 Data Scaling, Model Scaling, and Pretraining
Figure 6: Break-down of the contributions of data scaling and model scaling for BASIC-S and BASIC-M. Shown are the
ImageNet top-1 accuracy of our BASIC-{S,M} models under different training settings. Models trained from scratch
on ALIGN+JFT have almost the same performance as models pretrained on JFT and then finetuned on ALIGN or on
ALIGN+JFT. Models that are pretrained and then have both their image and text encoders finetuned reach the highest
accuracy. Figure best viewed in color.
We now study the benefits of other scaling dimensions, data and model scaling, on the quality of our
models. We also study pretraining as an alternate training procedure to contrastive learning. We train
BASIC-{S,M} models in 6 different settings and plot their final top-1 ImageNet accuracy in Figure 6. Below,
we compare and analyze the settings.
First, BASIC-S and BASIC-M respectively gain 5.3% and 5.8% accuracy when we expand the contrastive
training dataset from ALIGN to ALIGN+JFT. These gains, albeit large, are smaller than the gain by enlarging
the model size, e.g., 11.7% when going from BASIC-S to BASIC-M.
Next, we study the effects of pretraining image encoders on JFT. As can be seen from Figure 6, models
whose image encoders are pretrained on JFT and whose text encoders are subsequently trained on ALIGN,
i.e., the red bars, have similar performance to models trained from scratch on ALIGN+JFT, i.e., the blue
bars. Their similar accuracy suggests that the training losses – softmax cross-entropy or contrastive – have a
much smaller effect than the datasets. In other words, when given the same dataset, the image encoders in
BASIC models learn to become equally good, regardless of their loss functions.
To our surprise, training the text encoders for JFT-pretrained image encoders on ALIGN+JFT gains 1%
for BASIC-S and 1.8% for BASIC-L, compared to training these text encoders on ALIGN. We suspect that
these gains come from better representations for the textual prompts, since the models trained on ALIGN+JFT
also see the textual prompts, which consist of clean JFT class names. However, this speculation needs a more
thorough study.
Finally, we take a converged model whose image encoder is pretrained on JFT and whose
text encoder is trained on ALIGN+JFT, and we continue to train both its image encoder and text encoder at
a small learning rate. This extra training phase gains us 1.4% ImageNet accuracy for BASIC-S, 0.6% for
BASIC-M, and 0.4% for BASIC-L (not shown in this section).
11. Limitations
Despite the strong results of our zero-shot transfer classifiers, especially on natural image classification tasks,
they inevitably have shortcomings. In this section, we discuss the problems that we find with our BASIC
models.
Zero-shot transfer models do not perform well on test sets that are underrepresented in the training
datasets. We emphasize the failures of BASIC on the test sets where BASIC models are much worse
than CLIP models: EuroSAT, MNIST, and PatchCamelyon (PCam) (see Table 3 from Section 9.2). Here, we
summarize that BASIC models fail on MNIST and PCam because our training datasets ALIGN and JFT
have relatively few images of handwritten digits and of lymph nodes, which are the domains of these test sets.
Compared to MNIST and PCam, BASIC models do better on EuroSAT, which consists of satellite land images,
but their accuracy is lower than that of CLIP models. This is because the class names for these satellite
images are not very descriptive to BASIC models. More analysis of these failures is in Appendix G.
Zero-shot transfer requires prompt engineering. In this paper, we use the prompts from CLIP (Radford
et al., 2021) to make our results comparable to previous works. In Appendix G, we present examples which
show that prompts that are badly chosen or adversarially chosen can hurt the accuracy of zero-shot transfer
models by flipping their predictions. These examples suggest that prompt engineering is an important research
topic for making zero-shot transfer models robust and reliable, but the topic is out of the scope of this paper.
Combined scaling is expensive. As reported in Appendix E, the hardware and training time for our models
are not small. Despite the training cost, we can use the models in this paper without any finetuning, and hence
avoid the finetuning cost. We hope that future research can reduce our models’ training expense, e.g., larger
accelerator memory could remove the need for the extra re-materialization steps.
12. Conclusion
Zero-shot transfer learning represents a new paradigm where pretrained models can be used directly for
downstream applications without collecting any application-specific data. However, in order to become
practical for real-world applications, zero-shot transfer models need to bridge the accuracy gap to supervised
and semi-supervised models.
In this paper, we presented combined scaling techniques that significantly boost the performance of
zero-shot transfer models. We show that scaling up each of the data size, the model size, and the batch size
improves the final model's accuracy and robustness. To overcome the memory limit arising from combined
scaling, we devise a simple gradient accumulation method based on re-materialization.
References
Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay
Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: A system for large-scale machine learning. In OSDI,
2016.
Zeynep Akata, Florent Perronnin, Zaid Harchaoui, and Cordelia Schmid. Label-embedding for image classification.
IEEE TPAMI, 2015a.
Zeynep Akata, Scott Reed, Daniel Walter, Honglak Lee, and Bernt Schiele. Evaluation of output embeddings for
fine-grained image classification. In CVPR, 2015b.
Andrei Barbu, David Mayo, Julian Alverio, William Luo, Christopher Wang, Dan Gutfreund, Josh Tenenbaum, and
Boris Katz. Objectnet: A large-scale bias-controlled dataset for pushing the limits of object recognition models. In
NeurIPS, 2019.
Peter L Bartlett and Shahar Mendelson. Rademacher and gaussian complexities: Risk bounds and structural results.
Journal of Machine Learning Research, 3(Nov):463–482, 2002.
Irwan Bello, William Fedus, Xianzhi Du, Ekin D Cubuk, Aravind Srinivas, Tsung-Yi Lin, Jonathon Shlens, and Barret
Zoph. Revisiting resnets: Improved training and scaling strategies. In NeurIPS, 2021.
Thomas Berg, Jiongxin Liu, Seung Woo Lee, Michelle L Alexander, David W Jacobs, and Peter N Belhumeur. Birdsnap:
Large-scale fine-grained visual categorization of birds. CVPR, 2014.
Lukas Bossard, Matthieu Guillaumin, and Luc Van Gool. Food-101–mining discriminative components with random
forests. ECCV, 2014.
Andy Brock, Soham De, Samuel L Smith, and Karen Simonyan. High-performance large-scale image recognition
without normalization. In International Conference on Machine Learning, pages 1059–1071. PMLR, 2021.
Jiacheng Chen, Hexiang Hu, Hao Wu, Yuning Jiang, and Changhu Wang. Learning the best pooling strategy for visual
semantic embedding. In CVPR, 2021.
Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. Arxiv
1604.06174, 2016.
Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton. A simple framework for contrastive
learning of visual representations. In ICML, 2020a.
Ting Chen, Simon Kornblith, Kevin Swersky, Mohammad Norouzi, and Geoffrey Hinton. Big self-supervised models
are strong semi-supervised learners. In NeurIPS, 2020b.
Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He. Improved baselines with momentum contrastive learning.
ArXiv 2003.04297, 2020c.
Yen-Chun Chen, Linjie Li, Licheng Yu, Ahmed El Kholy, Faisal Ahmed, Zhe Gan, Yu Cheng, and Jingjing Liu. Uniter:
Universal image-text representation learning. In ECCV, 2020d.
Gong Cheng, Junwei Han, and Xiaoqiang Lu. Remote sensing image scene classification: Benchmark and state of the
art. Proceedings of the IEEE, 2017.
Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, and Andrea Vedaldi. Describing textures in the
wild. In CVPR, 2014.
Adam Coates, Andrew Ng, and Honglak Lee. An Analysis of Single Layer Networks in Unsupervised Feature Learning.
In AISTATS, 2011.
Zihang Dai, Hanxiao Liu, Quoc V. Le, and Mingxing Tan. Coatnet: Marrying convolution and attention for all data
sizes. In NeurIPS, 2021.
Karan Desai and Justin Johnson. Virtex: Learning visual representations from textual annotations. In CVPR, 2021.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional
transformers for language understanding. In NAACL, 2018.
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner,
Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image
is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.
M. Everingham, L. Van Gool, C. K. I. Williams, J. Winn, and A. Zisserman. The PASCAL Visual Object Classes
Challenge 2007 (VOC2007) Results.
Fartash Faghri, David J Fleet, Jamie Ryan Kiros, and Sanja Fidler. Vse++: Improving visual-semantic embeddings with
hard negatives. BMVC, 2017.
Li Fei-Fei, Fergus Rob, and Pietro Perona. Learning generative visual models from few training examples: An
incremental bayesian approach tested on 101 object categories. In CVPR, 2004.
Andrea Frome, Greg Corrado, Jonathon Shlens, Samy Bengio, Jeffrey Dean, Marc’Aurelio Ranzato, and Tomas
Mikolov. Devise: A deep visual-semantic embedding model. In Advances in Neural Information Processing Systems,
2013.
Noah Golowich, Alexander Rakhlin, and Ohad Shamir. Size-independent sample complexity of neural networks. In
Conference On Learning Theory, pages 297–299. PMLR, 2018.
Priya Goyal, Piotr Dollár, Ross B. Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch,
Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: training imagenet in 1 hour. Arxiv 1706.02677,
2017.
Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of checkpointing for the reverse or
adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1):19–45,
2000.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR,
2015.
Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross B. Girshick. Momentum contrast for unsupervised visual
representation learning. In CVPR, 2020.
Patrick Helber, Benjamin Bischke, Andreas Dengel, and Damian Borth. Introducing eurosat: A novel dataset and
deep learning benchmark for land use and land cover classification. In IEEE International Geoscience and Remote
Sensing Symposium, 2018.
Dan Hendrycks, Steven Basart, Norman Mu, Saurav Kadavath, Frank Wang, Evan Dorundo, Rahul Desai, Tyler Zhu,
Samyak Parajuli, Mike Guo, Dawn Song, Jacob Steinhardt, and Justin Gilmer. The many faces of robustness: A
critical analysis of out-of-distribution generalization. ICCV, 2021a.
Dan Hendrycks, Kevin Zhao, Steven Basart, Jacob Steinhardt, and Dawn Song. Natural adversarial examples. CVPR,
2021b.
Yasuhide Mori, Hironobu Takahashi, and Ryuichi Oka. Image-to-word transformation based on dividing and
vector quantizing images with words. In Citeseer, 1999.
Micah Hodosh, Peter Young, and Julia Hockenmaier. Framing image description as a ranking task: Data, models and
evaluation metrics. Journal of Artificial Intelligence Research, 2013.
Jie Hu, Li Shen, and Gang Sun. Squeeze-and-excitation networks. In Proceedings of the IEEE conference on computer
vision and pattern recognition, pages 7132–7141, 2018.
Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra, and Kilian Weinberger. Deep networks with stochastic depth. In
BMVC, 2017.
Yanping Huang, Youlong Cheng, Ankur Bapna, Orhan Firat, Dehao Chen, Mia Chen, HyoukJoong Lee, Jiquan Ngiam,
Quoc V Le, Yonghui Wu, et al. Gpipe: Efficient training of giant neural networks using pipeline parallelism.
Advances in neural information processing systems, 32:103–112, 2019.
Zhicheng Huang, Zhaoyang Zeng, Bei Liu, Dongmei Fu, and Jianlong Fu. Pixel-bert: Aligning image pixels with text
by deep multi-modal transformers. Arxiv 2004.00849, 2020.
Paras Jain, Ajay Jain, Aniruddha Nrusimha, Amir Gholami, Pieter Abbeel, Joseph Gonzalez, Kurt Keutzer, and Ion
Stoica. Checkmate: Breaking the memory wall with optimal tensor rematerialization. Proceedings of Machine
Learning and Systems, 2:497–511, 2020.
Chao Jia, Yinfei Yang, Ye Xia, Yi-Ting Chen, Zarana Parekh, Hieu Pham, Quoc V. Le, Yunhsuan Sung, Zhen Li, and
Tom Duerig. Scaling up visual and vision-language representation learning with noisy text supervision. In ICML,
2021.
Armand Joulin, Laurens Van Der Maaten, Allan Jabri, and Nicolas Vasilache. Learning visual features from large
weakly supervised data. In ECCV, 2016.
Norman P. Jouppi, Cliff Young, Nishant Patil, David A. Patterson, Gaurav Agrawal, Raminder Bajwa, Sarah Bates,
Suresh Bhatia, Nan Boden, Al Borchers, Rick Boyle, Pierre-luc Cantin, Clifford Chao, Chris Clark, Jeremy Coriell,
Mike Daley, Matt Dau, Jeffrey Dean, Ben Gelb, Tara Vazir Ghaemmaghami, Rajendra Gottipati, William Gulland,
Robert Hagmann, Richard C. Ho, Doug Hogberg, John Hu, Robert Hundt, Dan Hurt, Julian Ibarz, Aaron Jaffey,
Alek Jaworski, Alexander Kaplan, Harshit Khaitan, Andy Koch, Naveen Kumar, Steve Lacy, James Laudon, James
Law, Diemthu Le, Chris Leary, Zhuyuan Liu, Kyle Lucke, Alan Lundin, Gordon MacKean, Adriana Maggiore,
Maire Mahony, Kieran Miller, Rahul Nagarajan, Ravi Narayanaswami, Ray Ni, Kathy Nix, Thomas Norrie, Mark
Omernick, Narayana Penukonda, Andy Phelps, Jonathan Ross, Amir Salek, Emad Samadiani, Chris Severn, Gregory
Sizikov, Matthew Snelham, Jed Souter, Dan Steinberg, Andy Swing, Mercedes Tan, Gregory Thorson, Bo Tian,
Horia Toma, Erick Tuttle, et al. In-datacenter performance analysis of a tensor processing unit. Arxiv 1704.04760,
2017.
Andrej Karpathy and Li Fei-Fei. Deep visual-semantic alignments for generating image descriptions. In CVPR, 2015.
Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
Ryan Kiros, Ruslan Salakhutdinov, and Richard S Zemel. Unifying visual-semantic embeddings with multimodal
neural language models. Arxiv 1411.2539, 2014.
Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, and Neil Houlsby.
Big transfer (bit): General visual representation learning. In ECCV, 2020.
Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009.
Taku Kudo and John Richardson. Sentencepiece: A simple and language independent subword tokenizer and detokenizer
for neural text processing. In EMNLP, 2018.
Ravi Kumar, Manish Purohit, Zoya Svitkina, Erik Vee, and Joshua Wang. Efficient rematerialization for deep networks.
Advances in Neural Information Processing Systems, 32, 2019.
Hugo Larochelle, Dumitru Erhan, and Yoshua Bengio. Zero-data learning of new tasks. In AAAI, 2008.
Yann LeCun, Corinna Cortes, and CJ Burges. Mnist handwritten digit database. ATT Labs [Online]. Available:
http://yann.lecun.com/exdb/mnist, 2, 2010.
Dmitry Lepikhin, HyoukJoong Lee, Yuanzhong Xu, Dehao Chen, Orhan Firat, Yanping Huang, Maxim Krikun, Noam
Shazeer, and Zhifeng Chen. Gshard: Scaling giant models with conditional computation and automatic sharding.
Arxiv 2006.16668, 2020.
Brian Lester, Rami Al-Rfou, and Noah Constant. The power of scale for parameter-efficient prompt tuning. In EMNLP,
2021.
Ang Li, Allan Jabri, Armand Joulin, and Laurens van der Maaten. Learning visual n-grams from web data. In ICCV,
2017.
Kunpeng Li, Yulun Zhang, Kai Li, Yuanyuan Li, and Yun Fu. Visual semantic reasoning for image-text matching. In
ICCV, 2019.
Fenglin Liu, Yuanxin Liu, Xuancheng Ren, Xiaodong He, and Xu Sun. Aligning visual regions and textual concepts for
semantic-grounded image representations. Arxiv 1905.06139, 2019.
Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In ICLR, 2019.
Jiasen Lu, Dhruv Batra, Devi Parikh, and Stefan Lee. Vilbert: Pretraining task-agnostic visiolinguistic representations
for vision-and-language tasks. NeurIPS, 2019.
Dhruv Mahajan, Ross B. Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe,
and Laurens van der Maaten. Exploring the limits of weakly supervised pretraining. In ECCV, 2018.
Andreas Maurer. A vector-contraction inequality for rademacher complexities. In International Conference on
Algorithmic Learning Theory, pages 3–17. Springer, 2016.
Nicola Messina, Giuseppe Amato, Andrea Esuli, Fabrizio Falchi, Claudio Gennaro, and Stéphane Marchand-Maillet.
Fine-grained visual textual alignment for cross-modal retrieval using transformer encoders. ACM Transactions on
Multimedia Computing, Communications, and Applications, 2020.
Mehryar Mohri, Afshin Rostamizadeh, and Ameet Talwalkar. Foundations of machine learning. MIT press, 2012.
Hyeonseob Nam, Jung-Woo Ha, and Jeonghee Kim. Dual attention networks for multimodal reasoning and matching.
In CVPR, 2017.
Yurii E. Nesterov. A method for solving the convex programming problem with convergence rate O(1/k^2). Soviet
Mathematics Doklady, 1983.
Maria-Elena Nilsback and Andrew Zisserman. Automated flower classification over a large number of classes. In
Indian Conference on Computer Vision, Graphics and Image Processing, 2008.
Mohammad Norouzi, Tomas Mikolov, Samy Bengio, Yoram Singer, Jonathon Shlens, Andrea Frome, Greg S Corrado,
and Jeffrey Dean. Zero-shot learning by convex combination of semantic embeddings. Arxiv 1312.5650, 2013.
Myle Ott, Sergey Edunov, David Grangier, and Michael Auli. Scaling neural machine translation. In Workshop in
Neural Machine Translation, 2018.
Omkar M Parkhi, Andrea Vedaldi, Andrew Zisserman, and CV Jawahar. Cats and dogs. CVPR, 2012.
Hieu Pham, Zihang Dai, Qizhe Xie, Minh-Thang Luong, and Quoc V. Le. Meta pseudo labels. In CVPR, 2021.
Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda
Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, and Ilya Sutskever. Learning transferable visual models
from natural language supervision. In ICML, 2021.
Benjamin Recht, Rebecca Roelofs, Ludwig Schmidt, and Vaishaal Shankar. Do imagenet classifiers generalize to
imagenet? In ICML, 2019.
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition
Challenge. IJCV, 2009.
Mert Bulent Sariyildiz, Julien Perez, and Diane Larlus. Learning visual representations with caption annotations. In
ECCV, 2020.
Edgar Schönfeld, Sayna Ebrahimi, Samarth Sinha, Trevor Darrell, and Zeynep Akata. Generalized zero- and few-shot
learning via aligned variational autoencoders. In CVPR, 2019.
Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge
university press, 2014.
Noam Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with sublinear memory cost. Arxiv 1804.04235,
2018.
Noam Shazeer, Youlong Cheng, Niki Parmar, Dustin Tran, Ashish Vaswani, Penporn Koanantakool, Peter Hawkins,
HyoukJoong Lee, Mingsheng Hong, Cliff Young, et al. Mesh-tensorflow: Deep learning for supercomputers. Arxiv
1811.02084, 2018.
Richard Socher and Li Fei-Fei. Connecting modalities: Semi-supervised segmentation and annotation of images using
unaligned text corpora. In CVPR, 2010.
Richard Socher, Milind Ganjoo, Hamsa Sridhar, Osbert Bastani, Christopher D Manning, and Andrew Y Ng. Zero-shot
learning through cross-modal transfer. In Advances in Neural Information Processing Systems, 2013.
Richard Socher, Andrej Karpathy, Quoc V. Le, Christopher D. Manning, and Andrew Y. Ng. Grounded compositional
semantics for finding and describing images with sentences. TACL, 2014.
Khurram Soomro, Amir Roshan Zamir, and Mubarak Shah. UCF101: A dataset of 101 human actions classes from
videos in the wild. Arxiv 1212.0402, 2012.
Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple
way to prevent neural networks from overfitting. In JMLR, 2014.
Chen Sun, Abhinav Shrivastava, Saurabh Singh, and Abhinav Gupta. Revisiting unreasonable effectiveness of data in
deep learning era. In ICCV, 2017.
Christian Szegedy, Wojciech Zaremba, Ilya Sutskever, Joan Bruna, Dumitru Erhan, Ian Goodfellow, and Rob Fergus.
Intriguing properties of neural networks. Arxiv 1312.6199, 2013.
Mingxing Tan and Quoc V. Le. Efficientnet: Rethinking model scaling for convolutional neural networks. In ICML,
2019.
Mingxing Tan and Quoc V. Le. Efficientnetv2: Smaller models and faster training. In ICML, 2021.
Mingxing Tan, Ruoming Pang, and Quoc V. Le. Efficientdet: Scalable and efficient object detection. In CVPR, 2020.
Rohan Taori, Achal Dave, Vaishaal Shankar, Nicholas Carlini, Benjamin Recht, and Ludwig Schmidt. Measuring
robustness to natural distribution shifts in image classification. In NeurIPS, 2020.
Yonglong Tian, Chen Sun, Ben Poole, Dilip Krishnan, Cordelia Schmid, and Phillip Isola. What makes for good views
for contrastive learning? In NeurIPS, 2020.
Tijmen Tieleman and Geoffrey Hinton. RmsProp: Divide the gradient by a running average of its recent magnitude.
COURSERA: Neural Networks for Machine Learning, 2012.
Hugo Touvron, Andrea Vedaldi, Matthijs Douze, and Hervé Jégou. Fixing the train-test resolution discrepancy. In
NeurIPS, 2019.
Aäron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. Arxiv
1807.03748, 2018.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia
Polosukhin. Attention is all you need. In Advances in neural information processing systems, 2017.
Bastiaan S Veeling, Jasper Linmans, Jim Winkens, Taco Cohen, and Max Welling. Rotation equivariant CNNs for
digital pathology. Medical Image Computing and Computer Assisted Intervention, 2018.
Oriol Vinyals, Alexander Toshev, Samy Bengio, and Dumitru Erhan. Show and tell: A neural image caption generator.
In CVPR, 2015.
Haohan Wang, Songwei Ge, Zachary Lipton, and Eric P Xing. Learning robust global representations by penalizing
local predictive power. In NeurIPS, 2019.
Zhou Wang, Alan C Bovik, Hamid R Sheikh, and Eero P Simoncelli. Image quality assessment: from error visibility to
structural similarity. IEEE transactions on image processing, 13(4):600–612, 2004.
Jason Weston, Samy Bengio, and Nicolas Usunier. Large scale image annotation: learning to rank with joint word-image
embeddings. Machine Learning, 2010.
Yongqin Xian, Zeynep Akata, Gaurav Sharma, Quynh N. Nguyen, Matthias Hein, and Bernt Schiele. Latent embeddings
for zero-shot classification. In CVPR, 2016.
Yongqin Xian, Bernt Schiele, and Zeynep Akata. Zero-shot learning - the good, the bad and the ugly. In CVPR, 2017.
Jianxiong Xiao, James Hays, Krista A. Ehinger, Aude Oliva, and Antonio Torralba. Sun database: Large-scale scene
recognition from abbey to zoo. In CVPR, 2010.
Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V Le. Self-training with noisy student improves imagenet
classification. In CVPR, 2020.
Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhudinov, Rich Zemel, and Yoshua
Bengio. Show, attend and tell: Neural image caption generation with visual attention. In ICML, 2015.
Yuanzhong Xu, HyoukJoong Lee, Dehao Chen, Blake Hechtman, Yanping Huang, Rahul Joshi, Maxim Krikun, Dmitry
Lepikhin, Andy Ly, Marcello Maggioni, et al. Gspmd: General and scalable parallelization for ml computation
graphs. arXiv preprint arXiv:2105.04663, 2021.
Li Yuan, Qibin Hou, Zihang Jiang, Jiashi Feng, and Shuicheng Yan. Volo: Vision outlooker for visual recognition.
ArXiv 2106.13112, 2021.
Xiaohua Zhai, Alexander Kolesnikov, Neil Houlsby, and Lucas Beyer. Scaling vision transformers. Arxiv 2106.04560,
2021.
Xiaohua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, Daniel Keysers, Alexander Kolesnikov, and Lucas Beyer.
Lit: Zero-shot transfer with locked-image text tuning. In Proceedings of the IEEE/CVF Conference on Computer
Vision and Pattern Recognition, pages 18123–18133, 2022.
Li Zhang, Tao Xiang, and Shaogang Gong. Learning a deep embedding model for zero-shot learning. In CVPR, 2017.
Yuhao Zhang, Hang Jiang, Yasuhide Miura, Christopher D Manning, and Curtis P Langlotz. Contrastive learning of
medical visual representations from paired images and text. Arxiv 2010.00747, 2020.
A. Model sizes
In our preliminary experiments, we experimented with different model sizes. Table 5 presents the final model
sizes, which offer the best compute-to-performance trade-off and which we use throughout the paper.
Table 5: Model sizes. For the image models, all specifications can be found from the model names in Dai et al. (2021).
No regularization. Other than the decoupled weight decay in AdaFactorW, we do not use any other regularization
technique. In fact, we find that with BASIC-S and BASIC-M, if we add other forms of regularization such as
stochastic depth (Huang et al., 2017) or dropout (Srivastava et al., 2014), our ImageNet top-1 accuracy drops
substantially. This suggests that our datasets are so large that regularization techniques do more harm than good,
by making the optimization of our models more difficult.
Another important effect of not using regularization in our training framework is to keep the re-materialization
steps in Section 4.2 consistent. If we apply random perturbations to our forward passes, e.g., by skipping layers
as in stochastic depth or by setting random values to zero, then the two forward passes used for re-materialization
(see Lines 2-5 and 11-14 in Algorithm 1) will compute different values. While we could treat such differences as
a form of regularization noise, our early experiments show that with dropout-like regularization, our training loss
stays relatively large throughout the course of training. This observation suggests that the noise causes optimization
difficulty for our models, so we opt not to use any dropout-like regularization.
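The following toy NumPy sketch, assuming a small two-layer network, illustrates the consistency issue: a deterministic forward pass reproduces its activations exactly when re-computed, while a dropout-style pass with a fresh random mask generally does not.

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(64, 32)), rng.normal(size=(32, 16))
x = rng.normal(size=64)

def forward(x, drop_rate=0.0, drop_rng=None):
    """Toy two-layer forward pass with optional dropout on the hidden layer."""
    h = np.maximum(W1.T @ x, 0.0)                     # ReLU hidden activations
    if drop_rate > 0.0:
        mask = (drop_rng.random(h.shape) >= drop_rate) / (1.0 - drop_rate)
        h = h * mask                                  # inverted dropout
    return W2.T @ h

# Deterministic network: a re-materialization pass reproduces the first pass.
out_a = forward(x)
out_b = forward(x)                                    # re-computed forward pass
print(np.allclose(out_a, out_b))                      # True

# Dropout-style network: a second pass with a fresh mask disagrees, so the
# re-computed activations no longer match the ones the loss was computed on.
out_c = forward(x, 0.5, np.random.default_rng(1))
out_d = forward(x, 0.5, np.random.default_rng(2))
print(np.allclose(out_c, out_d))                      # False (with high probability)
```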
                          BASIC-S                       BASIC-{M,L}
                          Pretraining   Contrastive     Pretraining   Contrastive
Optimizer                 AdaFactorW    AdaFactorW      AdaFactorW    AdaFactorW
Batch size                16384         65536           16384         65536
Training steps            500K          500K            1.2M          500K
Warm-up steps             25K           25K             25K           25K
Max learning rate         1e-3          1e-3            4e-4          2.5e-4
Min learning rate         1e-5          1e-5            2e-5          1e-5
Learning decay schedule   Cosine        Cosine          Linear        Cosine
Weight decay              0.005         0.0025          0.01          0.0025
C. Evaluation Datasets Details
Here, we present the details of the datasets which we use to evaluate our BASIC models in Section 9.2.
It is worth noting that not all of these datasets use accuracy as the performance metric. This is because some
datasets have a certain level of imbalance between their classes, as well as other properties that make accuracy
a poor metric for them. For instance, Caltech-101 has a class called “Background” which refers to any image
that does not belong to its predefined 101 classes; one certainly cannot come up with a textual description for
this “class”. As such, Caltech-101 is evaluated using mean per-class recall. Details about the other datasets are
in Table 7.
Table 7: Details of the datasets used in this paper to evaluate BASIC models. The evaluation results are presented in
Table 1 and Table 3.
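As a reference point for this metric, here is a minimal NumPy sketch of mean per-class recall, assuming integer label arrays; the toy labels below are illustrative.

```python
import numpy as np

def mean_per_class_recall(y_true, y_pred, num_classes):
    """Average, over classes, of the fraction of each class's examples that
    are predicted correctly (classes with no examples are skipped)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = []
    for c in range(num_classes):
        mask = (y_true == c)
        if mask.any():
            recalls.append(np.mean(y_pred[mask] == c))
    return float(np.mean(recalls))

# With a 9:1 class imbalance, plain accuracy rewards predicting the majority
# class, while mean per-class recall does not.
y_true = np.array([0] * 90 + [1] * 10)
y_pred = np.zeros(100, dtype=int)                 # always predict class 0
print(np.mean(y_pred == y_true))                  # accuracy = 0.9
print(mean_per_class_recall(y_true, y_pred, 2))   # (1.0 + 0.0) / 2 = 0.5
```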
ImageNet-V2 (Recht et al., 2019). This dataset is collected in a process that closely follows the process used to
collect and annotate the images in the standard ILSVRC-2012 validation set, which is typically referred to
as “ImageNet” in the literature (and in our paper as well). As such, gains observed on ImageNet often transfer
to ImageNet-V2. Recent works such as EfficientNets (Tan and Le, 2019, 2021) and ViT (Dosovitskiy et al.,
2021) also demonstrate a similar trend. For our experiment in Section 9.3, BASIC-M’s robustness accuracy
improves along with its ImageNet accuracy, following this trend. However, BASIC-L’s robustness does not.
We suspect this is because BASIC-L’s learning capacity is larger than that of BASIC-M, so BASIC-L
picks up more “spurious” patterns from ImageNet, making it less robust than BASIC-M.
ImageNet-R (Hendrycks et al., 2021a). ImageNet-R is a special robustness dataset in our study: not only our
BASIC models but also the other zero-shot models – CLIP and ALIGN – are more accurate on ImageNet-R than
they are on ImageNet (see Table 1). These data points alone would suggest that ImageNet-R is somewhat
easier than ImageNet, until we look at the significant accuracy drops of other methods on ImageNet-R. For
instance, Noisy Student (Xie et al., 2020) and Meta Pseudo Labels (Pham et al., 2021) respectively achieve
only 74.9% and 72.7% accuracy on ImageNet-R, despite their accuracies of 88.4% and 90.2% on ImageNet
ILSVRC-2012. The reason for this discrepancy is that ImageNet-R is collected by selecting images of
ImageNet classes from visual art pieces, such as paintings, cartoons, graffiti, origami, and sculptures. These
art pieces are often displayed in a clean environment, free of noise such as multiple classes per image, which
makes the images easier to recognize. As such, BASIC, CLIP, and ALIGN all perform better on ImageNet-R.
However, ImageNet-R images have a drastically different distribution from the ImageNet labeled training
images, as they are respectively art images and natural images. This is why ImageNet-trained models display
a much lower accuracy on ImageNet-R, compared to zero-shot models.
The case of ObjectNet (Barbu et al., 2019). From Table 1, it can be seen that BASIC’s improvement over
ALIGN and CLIP on ObjectNet is significantly lower than its improvements on the other benchmarks, i.e., 6.6%
compared to more than 8% (except for ImageNet-R, where the accuracy of all models is saturated at over 90%).
We find that the reason is that, even though ObjectNet has images from the same classes as ImageNet, these
objects turn out to have their own, more descriptive names, e.g., the class name “chairs” in ImageNet could be
“chairs by [viewpoint]” or “chairs with [background]”. As we later show in Appendix G, using different class
names and prompts can affect our results. This effect has also been observed in CLIP (Radford et al., 2021).
Here, we take the same class names and prompts as for ImageNet and use them for ObjectNet. We suspect that
using ObjectNet-specific class names and prompts could improve our result.
E. Computational Cost
All of our models are implemented in TensorFlow (Abadi et al., 2016) and trained on Tensor Processing Units
(TPUs (Jouppi et al., 2017)). Our BASIC-S and BASIC-M models are trained on TPUv3 chips, while our
BASIC-L models are trained on TPUv4 chips. These TPUv4 chips in their MegaCore mode offer 32GB of
memory, of which our BASIC-L models use 30.1GB. We note that oftentimes a small portion of TPU memory
needs to be reserved for low-level infra systems, so our BASIC-L models essentially saturate the accelerators
with the largest memory currently available. Given this memory usage, we use Algorithm 1 with the microbatch
size M = 8192 and the batch size N = 65536 to train this model. Table 8 summarizes the training cost for
each phase of our BASIC-{S,M,L} models as in Section 9.2.
            Pretraining             Text Encoder            Text & Image Encoders
Model       Type     Cores×Days     Type     Cores×Days     Type     Cores×Days
BASIC-S     TPUv3    0.4K           TPUv3    0.9K           TPUv3    0.3K
BASIC-M     TPUv3    1.7K           TPUv3    3.9K           TPUv3    1.2K
BASIC-L     TPUv4    6.9K           TPUv4    1.0K           TPUv4    0.8K
Table 8: Computational usages to train our models. Cores×Days is the product of the number of cores and the
number of training days used to train the models. For instance, using 2048 TPU cores for 1 day equals 2048
Cores×Days, i.e., 2.048K. We use this metric because our jobs sometimes run on different numbers of TPUs due
to limited availability.
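The microbatching setup described above (M = 8192 out of N = 65536) can be illustrated with a small forward-only sketch. This is not the paper's Algorithm 1; it is a minimal NumPy illustration, assuming toy linear encoders, of how a contrastive batch of N pairs can be embedded microbatch-by-microbatch while the softmax contrastive loss is still computed over the full N × N similarity matrix. In the full method, the backward pass would additionally re-materialize each microbatch's activations to propagate the embedding gradients through the encoders.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d = 64, 8, 32                     # contrastive batch, microbatch, embed dim
W_img = rng.normal(size=(100, d))       # toy "image encoder" (a single linear map)
W_txt = rng.normal(size=(100, d))       # toy "text encoder"
images = rng.normal(size=(N, 100))
texts = rng.normal(size=(N, 100))

def embed(batch, W):
    z = batch @ W
    return z / np.linalg.norm(z, axis=1, keepdims=True)

# Embed the large batch in microbatches, so only M examples' worth of encoder
# activations are materialized at a time; the (N, d) embeddings are small
# enough to keep around for the full-batch loss.
img_emb = np.concatenate([embed(images[i:i + M], W_img) for i in range(0, N, M)])
txt_emb = np.concatenate([embed(texts[i:i + M], W_txt) for i in range(0, N, M)])

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

# Full-batch softmax contrastive loss on the N x N similarity matrix: each
# image should match its own text (rows) and each text its own image (columns).
tau = 0.07
logits = img_emb @ txt_emb.T / tau
diag = np.arange(N)
loss = 0.5 * (-log_softmax(logits, axis=1)[diag, diag].mean()
              - log_softmax(logits, axis=0)[diag, diag].mean())
print(loss)
```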
[Figure: selected zero-shot examples. Each panel lists the candidate prompts for one image together with the model's assigned probabilities; ✓ marks the prompt that matches the image.]
- Eggs with mixed expressions. (✓) 0.944; Emojis with mixed expressions. 0.032; Happy eggs. 0.022; Sad eggs. 0.002; Happy emojis. <1e-3; Sad emojis. <1e-3
- One plus one equals three. (✓) 0.468; One plus one equals one. 0.264; One plus one equals two. 0.240; One minus one equals three. 0.014; One minus one equals two. <1e-3; One minus one equals one. <1e-3
- Chemistry equation on a whiteboard. (✓) 0.958; Math equation on a whiteboard. 0.024; Physics equation on a whiteboard. 0.012; Chemistry equation on a paper. 0.005; Math equation on a paper. <1e-3; Physics equation on a paper. <1e-3
- Cosplayed pikachu. (✓) 0.764; Cosplayed charmander. 0.219; Real pikachu. 0.012; Cosplayed eevee. 0.006; Real charmander. <1e-3; Real eevee. <1e-3
- An alarm clock that reads 7:00 pm. (✓) 0.288; An alarm clock that reads 10:00 am. 0.183; An alarm clock that reads 12:00 pm. 0.167; An alarm clock that reads 4:00 pm. 0.162; An alarm clock that reads 7:00 am. 0.133; An alarm clock that reads 2:00 pm. 0.067
- A shirari dog in cold weather. (✓) 0.961; A shirari dog in warm weather. 0.038; A corgi dog in cold weather. <1e-3; A shiba inu dog in cold weather. <1e-3; A corgi dog in warm weather. <1e-3; A shiba inu dog in warm weather. <1e-3
G. Failure Analysis
Most machine learning models fail in certain tests. It is important to identify such failure cases, to understand
their causes, and, if possible, to come up with fixes. Here, we first look at the test benchmarks in Table 3
from Section 9.2 where BASIC models perform worse than CLIP models. We identify the cause of the failures
of BASIC models and recommend certain fixes that can improve their performance. Then, in Section G.2,
we present some erroneous behaviors of BASIC models via selected examples. These examples reveal some
weaknesses of zero-shot transfer models, and invite future research to improve them.
Patch Camelyon (PCam). PCam is perhaps the most sensitive dataset among the three benchmarks where
BASIC-L performs poorly. This dataset consists of images extracted from histopathologic scans of lymph
node sections, and models are asked to make a binary prediction – whether an input image contains a cancerous
lymph node or not. For such an important task, the top-1 accuracies of both BASIC-L (59.6%) and CLIP
(63.0%) are far below the bar for practical deployment. We remark that PCam is a binary classification task,
so the accuracies of BASIC-L and CLIP are just slightly above random guessing. Their poor performance,
however, is quite understandable: classifying lymph nodes requires much more specific training compared
to classifying common natural images. As our training data are weakly crawled and automatically curated
from the internet, without any emphasis on medical images, our BASIC-L model cannot learn enough to
perform well on PCam. We suspect the same holds for CLIP, as its data collection and curation process is
comparable to ours. Finally, the low accuracy of CLIP and BASIC models on PCam confirms that despite
the benefits of zero-shot transfer models, they are not ready to be deployed to tasks that require in-domain
expertise, e.g., medical knowledge.
EuroSAT. This dataset consists of satellite images of certain types of land. Models are asked to classify each
input image into one of 10 given land types, which can be seen in Figure 8. The failure of BASIC-L on
EuroSAT is an example of the importance of prompt engineering in zero-shot transfer learning for image-text
models. In Figure 8, we show that by changing the dataset’s class names and the model’s set of prompts into
words and phrases that have essentially the same meaning to humans, we can improve the accuracy of BASIC-L
from 51.0% to 55.7%. We do not further explore changes in class names and prompts to improve BASIC-L’s
performance on EuroSAT, as they belong to a topic different from the focus of this paper – combined scaling.
However, our findings on EuroSAT suggest that contrastive image-text models do not really “understand” texts.
This is perhaps because of the low quality of the texts in our training data, unlike the millions of words from
books and articles in the training data of NLP models such as BERT (Devlin et al., 2018).
MNIST. MNIST is a classical computer vision dataset for handwritten digit classification. Simple models
can achieve more than 99.5% accuracy, and yet BASIC-L achieves a humble 40.3% accuracy. Unlike the case
of PCam, where there is not enough relevant training data in our dataset, for MNIST we find that the ALIGN
dataset has a fair amount of images that contain digits, either handwritten or printed. This means that the image
encoder of BASIC-L has seen digit figures, and suggests that the failures might be more attributable to the text
encoder, similar to the case of EuroSAT. In Figure 9, we show the confusion matrices of BASIC-L with three
sets of class names: using the digits such as {‘0’, ‘1’, ...}, using the texts such as {‘one’, ‘two’, ...}, and using
both such as {‘0 or zero’, ‘1 or one’, ...}. Unfortunately, we cannot improve BASIC-L’s accuracy on MNIST
like we did for EuroSAT: BASIC-L’s accuracy is low in all three cases, but the confusion matrices are visibly
different: BASIC-L ‘thinks’ that many digits look like ‘3’ with the digit-only class names, but that many digits
look like ‘1 or one’ with the digit-and-text class names. Again, humans who understand language would not
make these mistakes. We think these mistakes constitute a new type of robustness failure, which we hope will
invite further research.
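As a small illustration of this evaluation protocol, the sketch below computes zero-shot predictions and a confusion matrix for several class-name sets. The random embeddings are placeholders; with a real model, only the class text embeddings would change between the three runs.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

def zero_shot_predict(image_embs, class_text_embs):
    """Nearest class by cosine similarity; both inputs are L2-normalized,
    image_embs: (n, d), class_text_embs: (num_classes, d)."""
    return np.argmax(image_embs @ class_text_embs.T, axis=1)

# Toy stand-ins: random unit embeddings for 100 test images and for the three
# MNIST class-name sets ("0", "zero", "0 or zero").
rng = np.random.default_rng(0)
norm = lambda z: z / np.linalg.norm(z, axis=1, keepdims=True)
image_embs = norm(rng.normal(size=(100, 64)))
y_true = rng.integers(0, 10, size=100)
for name_set in ["digit only", "text only", "digit and text"]:
    class_text_embs = norm(rng.normal(size=(10, 64)))
    y_pred = zero_shot_predict(image_embs, class_text_embs)
    cm = confusion_matrix(y_true, y_pred, 10)
    print(name_set, "accuracy:", cm.trace() / 100)
```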
[Figure 8 panel: confusion matrix of BASIC-L on EuroSAT with CLIP’s prompts & class names (top-1 acc = 51.0%). Class labels: annual crop land; forest; brushland or shrubland; highway or road; industrial buildings or commercial buildings; pasture land; permanent crop land; residential buildings or homes or apartments; river; lake or sea.]
Figure 8: Confusion matrices of BASIC-L on the EuroSAT classification dataset (Helber et al., 2018). Shown are the
confusion matrices obtained from zero-shot transfer of BASIC-L using the prompts and class names from CLIP,
compared to the same model using prompts and class names that we tuned. The zero-shot top-1 accuracy with our
prompts and class names is 4.7% higher, and the confusion matrix illustrates this by showing more concentration on
the diagonal.
[Figure 9 panels: three 10×10 confusion matrices of BASIC-L on MNIST, titled “Digit only”, “Text only”, and “Digit and Text”, with classes 0-9 on both axes and a shared color scale from 0 to 1000.]
Figure 9: Confusion matrices of BASIC-L’s predictions on MNIST. Digit only: we use the class names {“0”, “1”, ...,
“9”}; Text only: {“one”, “two”, ..., “nine”}; Digit and Text: {“0 or zero”, “1 or one”, ..., “9 or nine”}. The model has
vastly different confusion matrices for different class names, suggesting that it does not understand the meaning of these
strings, but instead simply learns to match their embeddings.
[Figure 10 panels: candidate prompts with the model’s probabilities; ✓ marks a correct top prediction and ✗ an incorrect one.]
- More than 6 kittens in total. (✗) 0.472; More than 4 kittens in total. 0.342; More than 2 kittens in total. 0.186; More than 6 puppies in total. <1e-3; More than 2 puppies in total. <1e-3; More than 4 puppies in total. <1e-3
- No strawberries found in the photo. (✗) 0.393; No blueberries found in the photo. 0.304; No bananas found in the photo. 0.297; No coconuts found in the photo. 0.003; No pineapples found in the photo. 0.002; No oranges found in the photo. 0.001
- Closed road. (✗) 0.716; Slippery road. 0.170; Intersection. 0.076; Stop. 0.034; Sharp left. 0.003; Sharp right. 0.002
- Traffic sign indicating intersection. (✓) 0.927; Traffic sign indicating closed road. 0.027; Traffic sign indicating sharp left. 0.027; Traffic sign indicating sharp right. 0.016; Traffic sign indicating slippery road. 0.003; Traffic sign indicating stop. <1e-3
- red light saber. (✓) 0.992; blue light saber. 0.005; red led light. 0.002; red neon light. <1e-3; blue led light. <1e-3; blue neon light. <1e-3
- blue light saber. (✓) 0.990; red light saber. 0.004; blue led light. 0.003; blue neon light. 0.003; red led light. <1e-3; red neon light. <1e-3
- a blue light saber on the left and a red light saber on the right. (✗) 0.297; a blue light saber on the right and a red light saber on the left. 0.235; a red light saber on the left and a blue light saber on the right. 0.205; a red light saber on the right and a blue light saber on the left. 0.173; a red light saber to the right of a blue light saber. 0.054; a red light saber to the left of a blue light saber. 0.0342
Figure 10: Selected failure cases of BASIC-L on unseen images. (1) The first block indicates that the model is not
precise at object counting and does not handle negation in the prompts well, possibly due to the nature of our training data.
(2) The middle block shows two examples indicating that prompt engineering can play a critical role in providing the
model with sufficient context to produce the desired output. (3) The last block shows that the model does not have a
sense of left and right, which is a relic of the random left-right flips of images that we apply during training.
H. Proofs
In this appendix, we complete the proofs of Theorem 1 and Theorem 2, gradually analyzing the gap from
the general case to the special case.
Lemma 3 Let $\mathcal{F}$ be a set of maps $x \mapsto F(x)$ and $\mathcal{G}$ be a set of maps $y \mapsto G(y)$. Then, for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
where $\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e} = \{y \mapsto \exp(F(x)^{\top} G(y)) : F \in \mathcal{F}, G \in \mathcal{G}\}$, $\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B} = \{(x, y) \mapsto \hat{\ell}_B(x, y) : F \in \mathcal{F}, G \in \mathcal{G}\}$, and
$$A(x, y) = \frac{\exp(F(x)^{\top} G(y))}{\frac{1}{B}\sum_{k=1}^{B} \exp(F(x)^{\top} G(\hat{y}_k)) \; \mathbb{E}_{\bar{y}}[\exp(F(x)^{\top} G(\bar{y}))]}.$$
Here, $R_m(\mathcal{H}) := \mathbb{E}_{S,\sigma}\left[\sup_{h \in \mathcal{H}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i h(x_i, y_i)\right]$, where $\sigma_1, \dots, \sigma_m$ are independent uniform random variables taking values in $\{-1, +1\}$.
$$\mathbb{E}_{x,y}[\bar{\ell}_M(x, y)] - \hat{\mathbb{E}}_S[\hat{\ell}_B(x, y)] = \mathbb{E}_{x,y}[\bar{\ell}_M(x, y)] - \mathbb{E}_{x,y}[\hat{\ell}_B(x, y)] + \mathbb{E}_{x,y}[\hat{\ell}_B(x, y)] - \hat{\mathbb{E}}_S[\hat{\ell}_B(x, y)] \quad (6)$$
$$= \mathbb{E}_{x,y}[\bar{\ell}_M(x, y) - \hat{\ell}_B(x, y)] + \left(\mathbb{E}_{x,y}[\hat{\ell}_B(x, y)] - \hat{\mathbb{E}}_S[\hat{\ell}_B(x, y)]\right). \quad (7)$$
For the inside of the expectation in the first term, we can write it as
Using Lemma 4 (see below) with the assumption that $\hat{y}_1, \hat{y}_2, \dots, \hat{y}_B \overset{\text{iid}}{\sim} p_y$ and $\exp(F(x)^{\top} G(y)) \le c_1$ with probability one, we have that for any $\delta > 0$ and $x \in \mathcal{X}$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
$$\mathbb{E}_{\bar{y}}[\exp(F(x)^{\top} G(\bar{y}))] - \frac{1}{B}\sum_{k=1}^{B} \exp(F(x)^{\top} G(\hat{y}_k)) \le 2 \sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) + c_1 \sqrt{\frac{\ln(1/\delta)}{2B}}, \quad (12)$$
where
$$\sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) = \sup_{x \in \mathcal{X}} \mathbb{E}_{y,\sigma}\left[\sup_{h_x \in \mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}} \frac{1}{B}\sum_{i=1}^{B} \sigma_i h_x(y_i)\right],$$
and
$$\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e} = \{y \mapsto \exp(F(x)^{\top} G(y)) : F \in \mathcal{F}, G \in \mathcal{G}\}.$$
Let us choose the metric space $(\mathcal{M}, d)$ to be the Euclidean space $\mathbb{R}^{\kappa}$ with the Euclidean metric. That is, we have the $r$-covering of $\mathcal{X}$ with Euclidean balls of radius $r$, with the $r$-covering number
$$\mathcal{N}_C(r, \mathcal{X}) \le (2 c_8 \sqrt{\kappa} / r)^{\kappa}.$$
Thus, by setting $r = \frac{2 c_8}{\sqrt{B}}$,
$$\mathcal{N}_C(r, \mathcal{X}) \le (\sqrt{\kappa B})^{\kappa}.$$
Using these,
$$\sup_{x \in \mathcal{X}} \gamma(x) = \inf_{c \in C} \sup_{x \in \mathcal{X}} \, \gamma(x) - \gamma(c) + \gamma(c) \le \inf_{c \in C} \sup_{x \in \mathcal{X}} |\gamma(x) - \gamma(c)| + \sup_{c \in C} \gamma(c) \le r c_9 + \sup_{c \in C} \gamma(c) = \frac{2 c_8 c_9}{\sqrt{B}} + \sup_{c \in C} \gamma(c).$$
Here, using equation 12 with union bounds, we have that for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
$$\sup_{c \in C} \gamma(c) \le 2 \sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) + c_1 \sqrt{\frac{\ln((\sqrt{\kappa B})^{\kappa}/\delta)}{2B}} \le 2 \sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) + c_1 \sqrt{\frac{\kappa \ln(\sqrt{\kappa B}/\delta)}{2B}}.$$
Therefore, for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
$$\sup_{x \in \mathcal{X}} \gamma(x) \le 2 \sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) + \frac{2 c_8 c_9}{\sqrt{B}} + c_1 \sqrt{\frac{\kappa \ln(\sqrt{\kappa B}/\delta)}{2B}} \quad (13)$$
$$= 2 \sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) + \frac{1}{\sqrt{2B}}\left(2\sqrt{2}\, c_8 c_9 + c_1 \sqrt{\kappa \ln(\sqrt{\kappa B}/\delta)}\right).$$
Combining equations (11) and (13) with a union bound, we have that for any $\delta > 0$, with probability at least $1 - \delta$,
$$\le \frac{\exp(F(x)^{\top} G(y)) \left( \frac{1}{\sqrt{2B}}\left(2\sqrt{2}\, c_8 c_9 + c_1 \sqrt{\kappa \ln(\sqrt{\kappa B}/\delta)}\right) + 2 \sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) \right)}{\frac{1}{B}\sum_{k=1}^{B} \exp(F(x)^{\top} G(\hat{y}_k)) \; \mathbb{E}_{\bar{y}}[\exp(F(x)^{\top} G(\bar{y}))]}. \quad (16)$$
By defining
$$A(x, y) = \frac{\exp(F(x)^{\top} G(y))}{\frac{1}{B}\sum_{k=1}^{B} \exp(F(x)^{\top} G(\hat{y}_k)) \; \mathbb{E}_{\bar{y}}[\exp(F(x)^{\top} G(\bar{y}))]},$$
we have that for any $\delta > 0$, with probability at least $1 - \delta$,
For the second term, using Lemma 4 with the assumption that $\hat{\ell}_B(x, y) \le c_2$ for $(x, y) \sim p_{(x,y)}$, we have that for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
where $\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B} = \{(x, y) \mapsto \hat{\ell}_B(x, y) : F \in \mathcal{F}, G \in \mathcal{G}\}$. Combining equations (7), (17), and (18) with a union bound, we have that for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
The proof of Lemma 3 partially builds on Lemma 4 below. Lemma 4 is a direct application of previous results (Bartlett and Mendelson, 2002; Mohri et al., 2012; Shalev-Shwartz and Ben-David, 2014) to our problem. We provide a proof of Lemma 4 by slightly modifying the proof of a previous work (Mohri et al., 2012, Theorem 3.1) for completeness (the proof utilizes the nonnegativity of $h$ to obtain a slightly tighter bound than Theorem 26.5 of Shalev-Shwartz and Ben-David, 2014):
Lemma 4 Let $\mathcal{H}$ be a set of maps $z \mapsto h(z)$ such that $h(z) \in [0, \lambda]$ for all $z$ in its domain. Then, for any $\delta > 0$, with probability at least $1 - \delta$ over a draw of $m$ i.i.d. samples $(z_i)_{i=1}^{m}$, the following holds for all maps $h \in \mathcal{H}$:
$$\mathbb{E}_z[h(z)] \le \frac{1}{m}\sum_{i=1}^{m} h(z_i) + 2 R_m(\mathcal{H}) + \lambda \sqrt{\frac{\ln(1/\delta)}{2m}}, \quad (20)$$
where $R_m(\mathcal{H}) := \mathbb{E}_{(z_1,\dots,z_m),\sigma}\left[\sup_{h \in \mathcal{H}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i h(z_i)\right]$ and $\sigma_1, \dots, \sigma_m$ are independent uniform random variables taking values in $\{-1, +1\}$.
To apply McDiarmid's inequality to $\varphi(S)$, we compute an upper bound on $|\varphi(S) - \varphi(S')|$, where $S$ and $S'$ are two test datasets differing by exactly one point at an arbitrary index $i_0$; i.e., $S_i = S'_i$ for all $i \ne i_0$ and $S_{i_0} \ne S'_{i_0}$. Then,
$$\varphi(S') - \varphi(S) \le \sup_{h \in \mathcal{H}} \frac{h(z_{i_0}) - h(z'_{i_0})}{m} \le \frac{\lambda}{m}. \quad (22)$$
$$\mathbb{E}_S[\varphi(S)] \quad (24)$$
$$= \mathbb{E}_S\left[\sup_{h \in \mathcal{H}} \mathbb{E}_{S'}\left[\frac{1}{m}\sum_{i=1}^{m} h(z'_i)\right] - \frac{1}{m}\sum_{i=1}^{m} h(z_i)\right] \quad (25)$$
$$\le \mathbb{E}_{S,S'}\left[\sup_{h \in \mathcal{H}} \frac{1}{m}\sum_{i=1}^{m} \left(h(z'_i) - h(z_i)\right)\right] \quad (26)$$
$$\le \mathbb{E}_{\xi,S,S'}\left[\sup_{h \in \mathcal{H}} \frac{1}{m}\sum_{i=1}^{m} \xi_i \left(h(z'_i) - h(z_i)\right)\right] \quad (27)$$
$$\le 2\,\mathbb{E}_{\xi,S}\left[\sup_{h \in \mathcal{H}} \frac{1}{m}\sum_{i=1}^{m} \xi_i h(z_i)\right] = 2 R_m(\mathcal{H}), \quad (28)$$
where the first line follows from the definitions of each term, the second line uses Jensen's inequality and the convexity of the supremum, and the third line follows since, for each $\xi_i \in \{-1, +1\}$, the distribution of each term $\xi_i(h(z'_i) - h(z_i))$ is the distribution of $h(z'_i) - h(z_i)$, because $S$ and $S'$ are drawn i.i.d. from the same distribution. The fourth line uses the subadditivity of the supremum.
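To make the quantity $R_m(\mathcal{H})$ concrete, the following NumPy sketch estimates the empirical Rademacher complexity of the unit-ball linear class $\{z \mapsto \langle w, z\rangle : \|w\|_2 \le 1\}$ by Monte Carlo; for this class the inner supremum has the closed form $\|\frac{1}{m}\sum_i \sigma_i z_i\|_2$. The data distribution and dimensions are illustrative assumptions.

```python
import numpy as np

def empirical_rademacher_linear(Z, num_trials=2000, seed=0):
    """Monte Carlo estimate of R_m(H) for H = {z -> <w, z> : ||w||_2 <= 1}
    on a fixed sample Z of shape (m, d). For this class,
    sup_w (1/m) sum_i sigma_i <w, z_i> = || (1/m) sum_i sigma_i z_i ||_2."""
    rng = np.random.default_rng(seed)
    m = Z.shape[0]
    vals = []
    for _ in range(num_trials):
        sigma = rng.choice([-1.0, 1.0], size=m)
        vals.append(np.linalg.norm(sigma @ Z / m))
    return float(np.mean(vals))

# The estimate shrinks roughly like 1/sqrt(m), mirroring the role of the
# Rademacher terms in the bounds above.
rng = np.random.default_rng(1)
for m in [64, 256, 1024]:
    Z = rng.normal(size=(m, 16))
    print(m, empirical_rademacher_linear(Z))
```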
H.1.1 Analyzing $\sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e})$ and $R_m(\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B})$
Lemma 5 Let $\mathcal{F}$ be a set of maps $x \mapsto F(x)$ and $\mathcal{G}$ be a set of maps $y \mapsto G(y)$. Then,
$$\sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) \le c_1 c_3\, \mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right].$$
Proof Since the derivative of the exponential function $\exp(q)$ is $\exp(q)$ and we assume $\exp(F(x)^{\top} G(y)) \le c_1$, the exponential function on the bounded domain where $\exp(F(x)^{\top} G(y)) \le c_1$ has Lipschitz constant $c_1$.
Therefore,
$$\mathbb{E}_{y,\sigma}\left[\sup_{h_x \in \mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}} \frac{1}{B}\sum_{i=1}^{B} \sigma_i h_x(y_i)\right] = \mathbb{E}_{y,\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \frac{1}{B}\sum_{i=1}^{B} \sigma_i \exp(F(x)^{\top} G(y_i))\right]$$
$$\le c_1\, \mathbb{E}_{y,\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \frac{1}{B}\sum_{i=1}^{B} \sigma_i F(x)^{\top} G(y_i)\right]$$
$$= \frac{c_1}{B}\, \mathbb{E}_{y,\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} F(x)^{\top} \sum_{i=1}^{B} \sigma_i G(y_i)\right]$$
$$\le \frac{c_1}{B}\, \mathbb{E}_{y,\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \|F(x)\|_2 \left\|\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right]$$
$$\le \frac{c_1}{B} \sup_{F \in \mathcal{F}} \|F(x)\|_2\, \mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right]$$
$$= c_1 \sup_{F \in \mathcal{F}} \|F(x)\|_2\, \mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right].$$
Therefore,
$$\sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) = \sup_{x \in \mathcal{X}} \mathbb{E}_{y,\sigma}\left[\sup_{h_x \in \mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}} \frac{1}{B}\sum_{i=1}^{B} \sigma_i h_x(y_i)\right] \le c_1 \left(\sup_{x \in \mathcal{X}, F \in \mathcal{F}} \|F(x)\|_2\right) \mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right].$$
Lemma 6 Let $\mathcal{F}$ be a set of maps $x \mapsto F(x)$ and $\mathcal{G}$ be a set of maps $y \mapsto G(y)$. Then,
$$R_m(\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B}) \le \sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2} \sum_{k=1}^{D} \left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right).$$
Using the definitions,
$$R_m(\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B}) = \mathbb{E}_{(x,y),\sigma}\left[\sup_{h \in \mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i h(x_i, y_i)\right] = \mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i \hat{\ell}_B(x_i, y_i)\right] = \mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i \frac{B \exp(F(x_i)^{\top} G(y_i))}{\sum_{k=1}^{B} \exp(F(x_i)^{\top} G(\hat{y}_k))}\right].$$
Define
$$h(p, q) = \frac{B \exp(p^{\top} q)}{\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))}.$$
Then,
$$R_m(\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B}) = \mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i h(F(x_i), G(y_i))\right].$$
Moreover,
$$\frac{\partial h(p, q)}{\partial p} = \frac{B \exp(p^{\top} q)}{\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))}\, q^{\top} - \frac{B \exp(p^{\top} q)}{\left(\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))\right)^2} \left(\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))\, G(\hat{y}_k)^{\top}\right),$$
$$\frac{\partial h(p, q)}{\partial q} = \frac{B \exp(p^{\top} q)}{\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))}\, p^{\top}.$$
Therefore,
$$\|\nabla h(p, q)\|_2^2 = \left(\frac{\exp(p^{\top} q)}{\frac{1}{B}\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))}\right)^2 \left(\sum_{i=1}^{D} \left(q_i - \frac{\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))\, G(\hat{y}_k)_i}{\sum_{k=1}^{B} \exp(p^{\top} G(\hat{y}_k))}\right)^2 + \sum_{i=1}^{D} p_i^2\right).$$
Thus,
$$\|\nabla h(p, q)\|_2 \le c_4 \sqrt{c_5^2 + c_6^2}.$$
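As a sanity check on the partial derivatives above, the following NumPy sketch compares them against central finite differences for randomly drawn toy vectors; the fixed vectors g_k stand in for G(ŷ_k), and the specific dimensions and scales are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d = 8, 5
g = rng.normal(size=(B, d)) * 0.1        # stands in for G(y_hat_k), k = 1..B
p, q = rng.normal(size=d) * 0.1, rng.normal(size=d) * 0.1

def h(p, q):
    return B * np.exp(p @ q) / np.exp(g @ p).sum()

def grad_h(p, q):
    """Analytic gradients matching the expressions in the proof."""
    s = np.exp(g @ p).sum()
    dp = (B * np.exp(p @ q) / s) * q - (B * np.exp(p @ q) / s**2) * (np.exp(g @ p) @ g)
    dq = (B * np.exp(p @ q) / s) * p
    return dp, dq

def finite_diff(f, v, eps=1e-6):
    out = np.zeros_like(v)
    for i in range(v.size):
        e = np.zeros_like(v)
        e[i] = eps
        out[i] = (f(v + e) - f(v - e)) / (2 * eps)
    return out

dp, dq = grad_h(p, q)
print(np.allclose(dp, finite_diff(lambda v: h(v, q), p), atol=1e-5))  # True
print(np.allclose(dq, finite_diff(lambda v: h(p, v), q), atol=1e-5))  # True
```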
Using a vector-contraction inequality, i.e., Corollary 4 of Maurer (2016) with the additional expectation on both sides of the inequality, we have that
$$R_m(\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B}) = \mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i h(F(x_i), G(y_i))\right]$$
$$\le \frac{\sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2}}{m}\, \mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}, G \in \mathcal{G}} \left(\sum_{i=1}^{m}\sum_{k=1}^{D} \sigma_{ik} F(x_i)_k + \sum_{i=1}^{m}\sum_{j=1}^{D} \sigma_{ij} G(y_i)_j\right)\right]$$
$$\le \frac{\sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2}}{m}\, \mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}} \sum_{i=1}^{m}\sum_{k=1}^{D} \sigma_{ik} F(x_i)_k + \sup_{G \in \mathcal{G}} \sum_{i=1}^{m}\sum_{j=1}^{D} \sigma_{ij} G(y_i)_j\right]$$
$$= \frac{\sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2}}{m}\left(\mathbb{E}_{(x,y),\sigma}\left[\sup_{F \in \mathcal{F}} \sum_{i=1}^{m}\sum_{k=1}^{D} \sigma_{ik} F(x_i)_k\right] + \mathbb{E}_{(x,y),\sigma}\left[\sup_{G \in \mathcal{G}} \sum_{i=1}^{m}\sum_{j=1}^{D} \sigma_{ij} G(y_i)_j\right]\right)$$
$$\le \frac{\sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2}}{m}\left(\sum_{k=1}^{D} \mathbb{E}_{x,\sigma}\left[\sup_{f \in \mathcal{F}_k} \sum_{i=1}^{m} \sigma_i f(x_i)\right] + \sum_{k=1}^{D} \mathbb{E}_{y,\sigma}\left[\sup_{g \in \mathcal{G}_k} \sum_{i=1}^{m} \sigma_i g(y_i)\right]\right)$$
$$= \sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2} \left(\sum_{k=1}^{D} R_m(\mathcal{F}_k) + \sum_{k=1}^{D} R_m(\mathcal{G}_k)\right).$$
Proof [Proof of Theorem 2] From Lemma 3, we have that for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
The result then follows by applying Lemmas 5 and 6.
H.2 Bounding $\mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right]$ and $\sum_{k=1}^{D}\left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right)$ for the special case with deep neural networks
We now want to bound $\mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right]$ and $\sum_{k=1}^{D}\left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right)$ in the case where $\mathcal{F}$ and $\mathcal{G}$ represent deep neural networks. We consider standard deep neural networks of the form $G(y) = W_L(\sigma_{L-1} \circ \omega_{L-1} \circ \sigma_{L-2} \circ \cdots \circ \sigma_1 \circ \omega_1)(y)$, with $\omega_l$ the linear map defined by the weight matrix $W_l$, and analogously for $F$.
Proof Since
$$\left\|\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2 = \left\|\sum_{i=1}^{B} \sigma_i W_L(\sigma_{L-1} \circ \omega_{L-1} \circ \sigma_{L-2} \circ \cdots \circ \sigma_1 \circ \omega_1)(y_i)\right\|_2 \le \|W_L\|_F \left\|\sum_{i=1}^{B} \sigma_i (\sigma_{L-1} \circ \omega_{L-1} \circ \sigma_{L-2} \circ \cdots \circ \sigma_1 \circ \omega_1)(y_i)\right\|_2,$$
the proof steps of Theorem 1 of Golowich et al. (2018) work to bound $\mathbb{E}_{\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right]$. Therefore, using the proof of Theorem 1 of Golowich et al. (2018),
$$\mathbb{E}_{\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right] \le \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L} M_l\right)}{\sqrt{B}}.$$
Lemma 8 Suppose that the function $\sigma'_l$ is 1-Lipschitz and positive homogeneous for all $l \in [L'-1]$ and $\|x\| \le c_8$ for all $x \in \mathcal{X}$. Let $\mathcal{F} = \{x \mapsto F(x) : (\forall l \in [L'-1])[\|W'_l\|_F \le M'_l] \wedge \|(W'_{L'})_k\|_F \le M'_{L',k}\}$, where $(W'_l)_k$ is the $k$-th row of $W'_l$. Suppose that the function $\sigma_l$ is 1-Lipschitz and positive homogeneous for all $l \in [L-1]$ and $\|y\| \le c_7$ for all $y \in \mathcal{Y}$. Let $\mathcal{G} = \{y \mapsto G(y) : (\forall l \in [L-1])[\|W_l\|_F \le M_l] \wedge \|(W_L)_k\|_F \le M_{L,k}\}$, where $(W_l)_k$ is the $k$-th row of $W_l$. Then,
$$\sum_{k=1}^{D}\left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right) \le \frac{c_8\left(\sqrt{2\log(2) L'} + 1\right)\left(\prod_{l=1}^{L'-1} M'_l\right)\sum_{k=1}^{D} M'_{L',k}}{\sqrt{m}} + \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\sum_{k=1}^{D} M_{L,k}}{\sqrt{m}}$$
and
$$\mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right] \le \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\sqrt{\sum_{k=1}^{D} M_{L,k}^2}}{\sqrt{B}}.$$
Proof From Theorem 1 of Golowich et al. (2018), we have that
$$R_m(\mathcal{F}_k) \le \frac{c_8\left(\sqrt{2\log(2) L'} + 1\right)\left(\prod_{l=1}^{L'-1} M'_l\right) M'_{L',k}}{\sqrt{m}}$$
and
$$R_m(\mathcal{G}_k) \le \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right) M_{L,k}}{\sqrt{m}}.$$
Thus,
$$\sum_{k=1}^{D}\left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right) \le \sum_{k=1}^{D}\left(\frac{c_8\left(\sqrt{2\log(2) L'} + 1\right)\left(\prod_{l=1}^{L'-1} M'_l\right) M'_{L',k}}{\sqrt{m}} + \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right) M_{L,k}}{\sqrt{m}}\right).$$
This proves the first statement. For the second statement, since $\|(W_L)_k\|_F \le M_{L,k}$, we have that
$$\|W_L\|_F^2 = \sum_{k=1}^{D} \|(W_L)_k\|_F^2 \le \sum_{k=1}^{D} M_{L,k}^2.$$
This implies that $\|W_L\|_F \le \sqrt{\sum_{k=1}^{D} M_{L,k}^2}$. Thus, using Lemma 7,
$$\mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right] \le \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\sqrt{\sum_{k=1}^{D} M_{L,k}^2}}{\sqrt{B}}.$$
H.3 Combining all together for the special case with deep neural networks
We now combine the above lemmas to complete the proof of Theorem 1 for the special case with deep neural
networks:
Proof [Proof of Theorem 1] From Lemma 3, for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$:
Moreover, by Lemmas 5 and 6,
$$\sup_{x \in \mathcal{X}} R_B(\mathcal{H}^{x}_{\mathcal{F},\mathcal{G},e}) \le c_1 c_3\, \mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right]$$
and
$$R_m(\mathcal{H}_{\mathcal{F},\mathcal{G},\hat{\ell}_B}) \le \sqrt{2}\, c_4 \sqrt{c_5^2 + c_6^2} \sum_{k=1}^{D} \left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right).$$
Finally, using Lemma 8 for the particular $\mathcal{F}$ and $\mathcal{G}$ with deep neural networks, we have that
$$\sum_{k=1}^{D}\left(R_m(\mathcal{F}_k) + R_m(\mathcal{G}_k)\right) \le \frac{c_8\left(\sqrt{2\log(2) L'} + 1\right)\left(\prod_{l=1}^{L'-1} M'_l\right)\sum_{k=1}^{D} M'_{L',k}}{\sqrt{m}} + \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\sum_{k=1}^{D} M_{L,k}}{\sqrt{m}}$$
and
$$\mathbb{E}_{y,\sigma}\left[\sup_{G \in \mathcal{G}} \left\|\frac{1}{B}\sum_{i=1}^{B} \sigma_i G(y_i)\right\|_2\right] \le \frac{c_7\left(\sqrt{2\log(2) L} + 1\right)\left(\prod_{l=1}^{L-1} M_l\right)\sqrt{\sum_{k=1}^{D} M_{L,k}^2}}{\sqrt{B}}.$$
Combining those, we have that for any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $F \in \mathcal{F}$ and $G \in \mathcal{G}$: