Tensor Product Attention Is All You Need
Abstract
1 Introduction
Large language models (LLMs) have revolutionized natural language processing, demonstrating exceptional performance across tasks (Brown et al., 2020; Chowdhery et al., 2023; Touvron et al., 2023; Bubeck et al., 2023). As these models evolve, their ability to process longer contexts becomes increasingly important for sophisticated applications such as document analysis, complex reasoning, and code completion. However, managing longer sequences during inference poses significant computational and memory challenges, particularly due to the storage of key-value (KV) caches (Zhang et al., 2023c; Liu et al., 2024c). Because KV cache memory grows linearly with sequence length, the maximum context window is limited by practical hardware constraints.
A variety of solutions have been explored to address this memory bottleneck. Some approaches compress or selectively prune cached states through sparse attention patterns (Child et al., 2019) or token eviction strategies (Zhang et al., 2023c; Xiao et al., 2024; Ribar et al., 2024), though such methods risk discarding tokens that may later prove important. Other work proposes off-chip storage of key-value states (He & Zhai, 2024), at the expense of increased I/O latency. Attention variants like multi-query attention (MQA) (Shazeer, 2019) and grouped-query attention (GQA) (Ainslie et al., 2023) reduce per-token cache requirements by sharing keys and values across heads, but often compromise flexibility or require significant architectural modifications. Meanwhile, low-rank weight factorization methods such as LoRA (Hu et al., 2022) effectively reduce fine-tuning memory, yet do not address the KV cache overhead that dominates runtime. The recently introduced Multi-head Latent Attention (MLA) in DeepSeek-V2 (Liu et al., 2024a) caches compressed key-value representations, but it needs additional position-encoded parameters per head because its compression cannot be combined efficiently with Rotary Position Embedding (RoPE) (Su et al., 2024b).
∗ Equal contribution; ⋄ Tech lead; † Corresponding author.
[Figure 1 diagram: six Linear layers produce the factors A_Q, B_Q, A_K, B_K, A_V, B_V; RoPE is applied to B_Q and B_K; scaled tensor products form Q, K, V; scaled dot-product attention; Concat; Linear.]
Figure 1: Tensor Product Attention (TPA) in the Tensor ProducT ATTenTion Transformer (T6). Unlike multi-head attention, in each layer the hidden state first passes through separate linear layers that produce the latent factor matrices A and B for the query, key, and value. RoPE is additionally applied to B_Q and B_K. The multi-head query, key, and value vectors are then obtained as tensor products of A(·) and B(·). Finally, the output of TPA is produced by scaled dot-product attention followed by a linear projection of the concatenated head outputs.
To overcome the limitations of existing approaches, we introduce Tensor Product Attention (TPA), illustrated in Figure 1, a novel architecture that uses higher-order tensors to factorize queries (Q), keys (K), and values (V) during attention computation. By dynamically factorizing activations rather than static weights (as LoRA does), TPA constructs low-rank, contextual representations that substantially reduce KV cache memory usage while improving representational capacity. In practice, TPA can reduce the memory overhead by an order of magnitude compared to standard multi-head attention (MHA), while achieving lower pretraining validation loss (perplexity) and improved downstream performance.
A key advantage of TPA is its native compatibility with rotary positional embeddings (RoPE) (Su
et al., 2024b), enabling a straightforward drop-in replacement for multi-head attention (MHA) layers
in modern LLM architectures such as LLaMA (Touvron et al., 2023) and Gemma (Team et al., 2024).
Our primary contributions are summarized as follows:
• We propose Tensor Product Attention (TPA), a mechanism that factorizes Q, K, and V activations using contextual tensor decompositions to achieve a 10× or greater reduction in inference-time KV cache size relative to the standard attention mechanism (Vaswani et al., 2017), with improved performance compared to previous methods such as MHA, MQA, GQA, and MLA. In addition, we unify existing attention mechanisms by showing that MHA, MQA, and GQA all arise naturally as non-contextual variants of TPA.
• We propose the Tensor ProducT ATTenTion Transformer (T6), a new TPA-based model architecture for sequence modeling. In language modeling experiments, T6 consistently improves validation perplexity and downstream evaluation performance with a reduced KV cache size.
• We show that TPA integrates seamlessly with RoPE (Su et al., 2024b), facilitating easy adoption in popular foundation model architectures such as LLaMA and Gemma.
[Figure 2 plots, Large Model, FineWeb-edu100B: (a) Training Loss and (b) Validation Loss versus training tokens (B), 0 to 50B, for MHA, MQA, GQA, MLA, TPA-KVonly, and TPA.]
Figure 2: Training loss and validation loss of pretraining large-size (773M) models with different attention mechanisms on the FineWeb-Edu-100B dataset.
2 Background
In this section, we review several classical forms of attention: Scaled Dot-Product Attention, Multi-
Head Attention (MHA) (Vaswani et al., 2017), Multi-Query Attention (MQA) (Shazeer, 2019),
and Grouped Query Attention (GQA) (Ainslie et al., 2023), as well as Rotary Position Embedding
(RoPE, Su et al. (2024b)). We also introduce a recent method called Multi-head Latent Attention
(MLA) used in DeepSeek-V2 (Liu et al., 2024a) and DeepSeek-V3 (Liu et al., 2024b).
Notations. We use bold uppercase letters (e.g., X, Q) for matrices, bold lowercase (e.g., a, b) for vectors, and italic uppercase (e.g., $W_i^Q$) for learnable parameter matrices. We denote by [n] the set {1, . . . , n} for some positive integer n. We use ⊤ to denote the transpose of a vector or a matrix. Let $d_{\text{model}}$ be the embedding dimension, h the number of attention heads, $d_h$ the dimension per head, $x_t \in \mathbb{R}^{d_{\text{model}}}$ the input for the t-th token at a given attention layer, $X \in \mathbb{R}^{T \times d_{\text{model}}}$ the input embeddings for T tokens, and $Q, K, V \in \mathbb{R}^{T \times h \times d_h}$ the queries, keys, and values of h heads for T tokens. With a slight abuse of notation, $Q_i, K_i, V_i \in \mathbb{R}^{T \times d_h}$ denote the i-th head of the queries, keys, and values, and $Q_t, K_t, V_t \in \mathbb{R}^{h \times d_h}$ denote the heads of the query, key, and value for the t-th token.
Throughout the paper, $W^Q, W^K, W^V$ denote projection matrices for queries, keys, and values, respectively. In multi-head attention, each head is associated with its own set of $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$, where $d_k$ is typically set to $d_h$, the dimension of each head.⁵ Similarly, we have an output projection matrix $W^O \in \mathbb{R}^{(h \cdot d_h) \times d_{\text{model}}}$. For methods like MQA and GQA, some of these are shared or partially shared across heads, but their shapes remain consistent.
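We define the tensor product of two vectors as follows: for vectors $a \in \mathbb{R}^m$, $b \in \mathbb{R}^n$, the tensor product of a and b is
$$a \otimes b = C \in \mathbb{R}^{m \times n}, \quad \text{with } C_{ij} = a_i b_j,$$
where $a_i$ and $b_j$ are the i-th and j-th elements of a and b, respectively, and $C_{ij}$ is the (i, j)-th entry of C. We also define the vectorization of a matrix $C \in \mathbb{R}^{m \times n}$ by
$$\operatorname{vec}(C) = d \in \mathbb{R}^{mn}, \quad \text{with } d_{i \cdot n + j} = C_{ij},$$
where $d_{i \cdot n + j}$ is the $(i \cdot n + j)$-th element of d, with indices counted from 0 in this formula (i.e., row-major order).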
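A minimal NumPy sketch of the two definitions, using NumPy's 0-based indexing: `np.outer` builds $C_{ij} = a_i b_j$, and a row-major (C-order) flatten realizes $d_{i \cdot n + j} = C_{ij}$. The helper names `tensor_product` and `vec` are illustrative, not from the paper.

```python
import numpy as np

def tensor_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return C = a ⊗ b with C[i, j] = a[i] * b[j] (shape m x n)."""
    return np.outer(a, b)

def vec(C: np.ndarray) -> np.ndarray:
    """Row-major vectorization: d[i * n + j] = C[i, j] (0-based indices)."""
    return C.reshape(-1, order="C")

a = np.array([1.0, 2.0, 3.0])
b = np.array([10.0, 20.0])
C = tensor_product(a, b)   # shape (3, 2)
d = vec(C)                 # shape (6,)
n = b.shape[0]
# Check the indexing rule from the definition for every (i, j).
for i in range(a.shape[0]):
    for j in range(n):
        assert d[i * n + j] == C[i, j] == a[i] * b[j]
```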
Scaled dot-product attention (Vaswani et al., 2017) determines how to focus on different parts of an input sequence by comparing queries (Q) and keys (K). It produces a weighted combination of the values (V). Formally, the attention output is
$$\operatorname{Attention}(Q, K, V) = \operatorname{Softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V,$$
where each of Q, K, V is an $(n \times d_k)$ matrix for n tokens and key dimension $d_k$. The division by $\sqrt{d_k}$ stabilizes training by controlling the scale of the inner products.
⁵ Often, one sets $h \times d_h = d_{\text{model}}$, so each head has query/key/value dimension $d_h$.
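The attention formula can be sketched in a few lines of NumPy. This is an illustrative implementation assuming a single head and no causal mask (the formula as written has none); `softmax` and `attention` are hypothetical helper names.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Softmax(Q K^T / sqrt(d_k)) V for (n, d_k) inputs; no causal mask."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # (n, n) scaled inner products
    weights = softmax(scores, axis=-1)    # each row sums to 1
    return weights @ V

rng = np.random.default_rng(0)
n, d_k = 4, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
out = attention(Q, K, V)
```

With all-zero queries every score is equal, so each output row reduces to the mean of the value rows, which makes a convenient sanity check.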
Multi-Head Attention (MHA) extends scaled dot-product attention by dividing the model's internal representation into several heads. Each head learns different projections for queries, keys, and values, allowing the model to attend to different types of information. For each token embedding $x_t \in \mathbb{R}^{d_{\text{model}}}$, MHA computes each head i as follows:
$$Q_{t,i} = (W_i^Q)^\top x_t \in \mathbb{R}^{d_h}, \quad K_{t,i} = (W_i^K)^\top x_t \in \mathbb{R}^{d_h}, \quad V_{t,i} = (W_i^V)^\top x_t \in \mathbb{R}^{d_h},$$
$$\text{head}_i = \operatorname{Attention}(Q_i, K_i, V_i),$$
where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_h}$ are learnable projection matrices for the i-th head and $Q_i, K_i, V_i \in \mathbb{R}^{T \times d_h}$. After computing each head's attention, the outputs are concatenated and mapped back to the original dimension via another matrix $W^O \in \mathbb{R}^{h d_h \times d_{\text{model}}}$:
$$\operatorname{MHA}(Q, K, V) = \operatorname{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O.$$
MHA can capture a rich set of dependencies while each head focuses on different subspaces.
Multi-Query Attention (MQA) (Shazeer, 2019) significantly reduces memory usage by sharing keys and values across heads, while still preserving unique query projections. For a sequence of embeddings $X \in \mathbb{R}^{T \times d_{\text{model}}}$,
$$Q_i = X W_i^Q, \quad K_{\text{shared}} = X W^K_{\text{shared}}, \quad V_{\text{shared}} = X W^V_{\text{shared}}.$$
Hence, each head i only has a distinct query $Q_i \in \mathbb{R}^{T \times d_k}$, but shares the same key $K_{\text{shared}} \in \mathbb{R}^{T \times d_k}$ and value $V_{\text{shared}} \in \mathbb{R}^{T \times d_k}$. In practice, this means
$$W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W^K_{\text{shared}}, W^V_{\text{shared}} \in \mathbb{R}^{d_{\text{model}} \times d_k}.$$
The resulting MQA operation is
$$\operatorname{MQA}(X) = \operatorname{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O, \quad \text{where } \text{head}_i = \operatorname{Attention}(Q_i, K_{\text{shared}}, V_{\text{shared}}).$$
By sharing these key and value projections, MQA cuts down on memory usage (especially for the key-value cache in autoregressive inference) but loses some expressivity, since all heads must rely on the same key/value representations.
Grouped Query Attention (GQA) (Ainslie et al., 2023) generalizes MHA and MQA by grouping heads. Specifically, we partition the h total heads into G groups. Each group has a single set of keys and values, but each individual head within that group still retains its own query projection. Formally, if g(i) maps a head $i \in [h]$ to its group index $g \in [G]$, then
$$K_{g(i)} = X W^K_{g(i)}, \quad V_{g(i)} = X W^V_{g(i)}, \quad Q_i = X W_i^Q,$$
and
$$\text{head}_i = \operatorname{Attention}(Q_i, K_{g(i)}, V_{g(i)}).$$
Again, $W_g^K, W_g^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$ for each group g, and $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$ for each head i. The complete output is again a concatenation of all heads:
$$\operatorname{GQA}(X) = \operatorname{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O.$$
By adjusting G between 1 and h, GQA interpolates between sharing all key/value projections across heads (G = 1, i.e., MQA) and having one set of projections per head (G = h, i.e., MHA).
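To make the interpolation concrete, here is a NumPy sketch of GQA under stated assumptions: no causal mask, head i uses the i-th block of $d_k$ columns of a stacked `WQ`, and heads are assigned to groups in contiguous blocks via $g(i) = \lfloor iG/h \rfloor$ (the text does not fix a particular map). The function `gqa` is a hypothetical helper. Tying every MHA head to one shared key/value projection reproduces the $G = 1$ (MQA) output exactly.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gqa(X, WQ, WK, WV, WO, h: int, G: int) -> np.ndarray:
    """Grouped-query attention without a causal mask.

    X: (T, d_model). WQ: (d_model, h * d_k), head i uses columns i*d_k:(i+1)*d_k.
    WK, WV: (d_model, G * d_k), one block of d_k columns per group.
    WO: (h * d_k, d_model). G = h gives MHA; G = 1 gives MQA.
    """
    T = X.shape[0]
    d_k = WQ.shape[1] // h
    Q = (X @ WQ).reshape(T, h, d_k)
    K = (X @ WK).reshape(T, G, d_k)   # only G key heads need to be cached
    V = (X @ WV).reshape(T, G, d_k)
    heads = []
    for i in range(h):
        g = i * G // h                 # assumed group map g(i): contiguous blocks of heads
        scores = Q[:, i] @ K[:, g].T / np.sqrt(d_k)
        heads.append(softmax(scores) @ V[:, g])
    return np.concatenate(heads, axis=-1) @ WO

rng = np.random.default_rng(0)
T, d_model, h, d_k = 5, 16, 4, 4
X = rng.standard_normal((T, d_model))
WQ = rng.standard_normal((d_model, h * d_k))
WO = rng.standard_normal((h * d_k, d_model))
WK1, WV1 = rng.standard_normal((d_model, d_k)), rng.standard_normal((d_model, d_k))

out_mqa = gqa(X, WQ, WK1, WV1, WO, h, G=1)
# MHA with every head's K/V projection set to the shared one must match MQA.
out_mha_tied = gqa(X, WQ, np.tile(WK1, (1, h)), np.tile(WV1, (1, h)), WO, h, G=h)
```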
2.5 Rotary Position Embedding (RoPE)
Many recent LLMs use rotary position embedding (RoPE; Su et al., 2024b) to encode positional information in the query/key vectors. Specifically, let $\operatorname{RoPE}_t$ denote the rotation operator $T_t \in \mathbb{R}^{d_h \times d_h}$ corresponding to the t-th position. $T_t$ is a block-diagonal matrix whose diagonal blocks are
$$\begin{pmatrix} \cos(t\theta_j) & -\sin(t\theta_j) \\ \sin(t\theta_j) & \cos(t\theta_j) \end{pmatrix}, \quad j \in \{1, \ldots, d_h/2\},$$
where $\{\theta_j\}$ are pre-defined frequency parameters, e.g., $\theta_j = 1/10000^{2j/d_h}$. Then we define
$$\operatorname{RoPE}(Q_t) \triangleq Q_t T_t, \quad \text{where } Q_t \in \mathbb{R}^{h \times d_h}.$$
A fundamental property is that
$$T_t T_s^\top = T_{t-s}, \tag{2.1}$$
which ensures that relative positions (t − s) are preserved, thereby providing a form of translation invariance in the rotary position embedding.
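The following NumPy sketch builds $T_t$ from the definition above (with $\theta_j = 1/10000^{2j/d_h}$, $j = 1, \ldots, d_h/2$, exactly as written) and checks property (2.1) numerically. It uses the row-vector convention $\operatorname{RoPE}(q) = q T_t$ from the text; `rope_matrix` is an illustrative name, and practical RoPE kernels usually apply the rotation element-wise rather than materializing $T_t$.

```python
import numpy as np

def rope_matrix(t: int, d_h: int, base: float = 10000.0) -> np.ndarray:
    """Block-diagonal T_t with 2x2 rotation blocks of angle t * theta_j."""
    T = np.zeros((d_h, d_h))
    for j in range(1, d_h // 2 + 1):
        theta = 1.0 / base ** (2 * j / d_h)   # theta_j as written in the text
        c, s = np.cos(t * theta), np.sin(t * theta)
        k = 2 * (j - 1)                        # top-left index of block j
        T[k:k + 2, k:k + 2] = [[c, -s], [s, c]]
    return T

d_h = 8
T7, T3 = rope_matrix(7, d_h), rope_matrix(3, d_h)
# Property (2.1): T_t T_s^T = T_{t-s}.
assert np.allclose(T7 @ T3.T, rope_matrix(4, d_h))

# Consequence: with RoPE(q) = q T_t, scores depend only on the offset t - s.
rng = np.random.default_rng(0)
q, k = rng.standard_normal(d_h), rng.standard_normal(d_h)
score_a = (q @ rope_matrix(7, d_h)) @ (k @ rope_matrix(3, d_h))
score_b = (q @ rope_matrix(12, d_h)) @ (k @ rope_matrix(8, d_h))
```

Both scores use the offset 4, so they coincide even though the absolute positions differ.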
Below, we briefly outline the Multi-head Latent Attention (MLA) approach used by DeepSeek-V2 (Liu et al., 2024a) and DeepSeek-V3 (Liu et al., 2024b). MLA introduces a low-rank compression of the keys and values to reduce the key-value (KV) caching cost at inference:
$$\begin{aligned}
C^{KV} &= X W^{DKV}, && (W^{DKV} \in \mathbb{R}^{d_{\text{model}} \times d_c}),\\
\operatorname{Concat}(K_1^C, K_2^C, \ldots, K_h^C) = K^C &= C^{KV} W^{UK}, && (W^{UK} \in \mathbb{R}^{d_c \times d_h h}),\\
K^R &= \operatorname{RoPE}(X W^{KR}), && (W^{KR} \in \mathbb{R}^{d_{\text{model}} \times d_h^R}),\\
K_i &= \operatorname{Concat}(K_i^C, K^R),\\
\operatorname{Concat}(V_1^C, V_2^C, \ldots, V_h^C) = V^C &= C^{KV} W^{UV}, && (W^{UV} \in \mathbb{R}^{d_c \times d_h h}),
\end{aligned}$$
where $C^{KV} \in \mathbb{R}^{T \times d_c}$ is the compressed KV latent (with $d_c \ll d_h h$), and RoPE(·) represents the RoPE transform applied to the separate key embeddings $K^R$ of dimension $d_h^R$. Thus, only $C^{KV}$ and $K^R$ need to be cached, reducing KV memory usage while largely preserving performance compared to standard MHA (Vaswani et al., 2017).
MLA also compresses the queries, lowering their training-time memory footprint:
$$\begin{aligned}
C^Q &= X W^{DQ}, && (W^{DQ} \in \mathbb{R}^{d_{\text{model}} \times d_c'}),\\
\operatorname{Concat}(Q_1^C, Q_2^C, \ldots, Q_h^C) = Q^C &= C^Q W^{UQ}, && (W^{UQ} \in \mathbb{R}^{d_c' \times d_h h}),\\
\operatorname{Concat}(Q_1^R, Q_2^R, \ldots, Q_h^R) = Q^R &= \operatorname{RoPE}(C^Q W^{QR}), && (W^{QR} \in \mathbb{R}^{d_c' \times d_h^R h}),\\
Q &= \operatorname{Concat}(Q^C, Q^R).
\end{aligned}$$
Here, $C^Q \in \mathbb{R}^{T \times d_c'}$ (with $d_c' \ll d_h h$) is the compressed query latent. As above, $W^{DQ}$, $W^{UQ}$, and $W^{QR}$ connect these lower-dimensional query latents back to h heads of dimension $d_h + d_h^R$.
Given compressed queries, keys, and values, the final attention output for the t-th token is
$$O_i = \operatorname{Softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_h + d_h^R}}\right) V_i^C, \qquad U = \operatorname{Concat}(O_1, O_2, \ldots, O_h)\, W^O,$$
where $W_i^{(\cdot)}$ is the i-th head of the original weight, and $[W_i^{DQ} W_i^{UQ} (W_i^{UK})^\top]$ can be precomputed for faster decoding. However, as noted by Su (2024), this precomputation fails once RoPE is taken into account. Since RoPE can be viewed as multiplication by a block-diagonal matrix $T_t \in \mathbb{R}^{d_h \times d_h}$ (see Section 2.5) with the property (2.1) that $T_t T_s^\top = T_{t-s}$, we have
$$\begin{aligned}
q_{t,i}^\top k_{s,i} &= \left[T_t^\top (W_i^{UQ})^\top (W_i^{DQ})^\top x_t\right]^\top \left[T_s^\top (W_i^{UK})^\top c_s^{KV}\right]\\
&= x_t^\top \left[W_i^{DQ} W_i^{UQ} T_{t-s} (W_i^{UK})^\top\right] c_s^{KV}.
\end{aligned} \tag{2.3}$$
Unlike (2.2), acceleration by precomputing $[W_i^{DQ} W_i^{UQ} T_{t-s} (W_i^{UK})^\top]$ fails because this matrix varies across (t, s) position pairs. Therefore, MLA adds the additional $k_t^R$ part, of relatively small size, for RoPE compatibility. In Section 3.2, we show that TPA resolves this RoPE incompatibility by using the tensor product.
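To make the failure of weight absorption concrete, the sketch below checks (2.3) numerically for a single head with toy random weights (all sizes and variable names are hypothetical). The merged matrix reproduces the RoPE score, but it changes with the offset $t - s$, so it cannot be precomputed once as in the RoPE-free case.

```python
import numpy as np

def rope_matrix(t: int, d: int, base: float = 10000.0) -> np.ndarray:
    """Block-diagonal rotation T_t, as defined in Section 2.5."""
    T = np.zeros((d, d))
    for j in range(1, d // 2 + 1):
        theta = 1.0 / base ** (2 * j / d)
        c, s = np.cos(t * theta), np.sin(t * theta)
        k = 2 * (j - 1)
        T[k:k + 2, k:k + 2] = [[c, -s], [s, c]]
    return T

rng = np.random.default_rng(0)
d_model, d_c_q, d_c, d_h = 16, 6, 6, 8        # toy sizes, chosen arbitrarily
W_DQ = rng.standard_normal((d_model, d_c_q))  # one head's W_i^{DQ}
W_UQ = rng.standard_normal((d_c_q, d_h))      # W_i^{UQ}
W_UK = rng.standard_normal((d_c, d_h))        # W_i^{UK}
x_t = rng.standard_normal(d_model)
c_s = rng.standard_normal(d_c)                # cached latent c_s^{KV}

def score_explicit(t: int, s: int) -> float:
    """q^T k with RoPE applied to the up-projected query and key, as in (2.3)."""
    q = rope_matrix(t, d_h).T @ W_UQ.T @ W_DQ.T @ x_t
    k = rope_matrix(s, d_h).T @ W_UK.T @ c_s
    return q @ k

def absorbed_matrix(t: int, s: int) -> np.ndarray:
    """W^{DQ} W^{UQ} T_{t-s} (W^{UK})^T: depends on the offset t - s."""
    return W_DQ @ W_UQ @ rope_matrix(t - s, d_h) @ W_UK.T
```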
Let $x_t \in \mathbb{R}^{d_{\text{model}}}$ for $t = 1, \ldots, T$ be the hidden-state vector corresponding to the t-th token in a sequence of length T. A typical multi-head attention block has h heads, each of dimension $d_h$, satisfying $d_{\text{model}} = h \times d_h$. Standard attention projects the entire sequence into three tensors, $Q, K, V \in \mathbb{R}^{T \times h \times d_h}$, where $Q_t, K_t, V_t \in \mathbb{R}^{h \times d_h}$ denote the slices for the t-th token.
Contextual Factorization (CF). Instead of forming each head's query, key, or value via a single linear map, TPA factorizes each of $Q_t, K_t, V_t$ into a sum of (contextual) tensor products whose ranks are $R_Q$, $R_K$, and $R_V$, respectively; these ranks may differ. Specifically, for each token t, with a slight abuse of notation, we define:
$$Q_t = \frac{1}{R_Q} \sum_{r=1}^{R_Q} a_r^Q(x_t) \otimes b_r^Q(x_t), \quad a_r^Q(x_t) \in \mathbb{R}^h,\ b_r^Q(x_t) \in \mathbb{R}^{d_h}, \tag{3.1}$$
$$K_t = \frac{1}{R_K} \sum_{r=1}^{R_K} a_r^K(x_t) \otimes b_r^K(x_t), \quad a_r^K(x_t) \in \mathbb{R}^h,\ b_r^K(x_t) \in \mathbb{R}^{d_h}, \tag{3.2}$$
$$V_t = \frac{1}{R_V} \sum_{r=1}^{R_V} a_r^V(x_t) \otimes b_r^V(x_t), \quad a_r^V(x_t) \in \mathbb{R}^h,\ b_r^V(x_t) \in \mathbb{R}^{d_h}. \tag{3.3}$$
Hence, for queries, each tensor product $a_r^Q(x_t) \otimes b_r^Q(x_t)$ maps $\mathbb{R}^h \times \mathbb{R}^{d_h} \to \mathbb{R}^{h \times d_h}$, and these terms add up to form the query slice $Q_t \in \mathbb{R}^{h \times d_h}$. Analogous definitions apply to the key slice $K_t$ and the value slice $V_t$.
Latent Factor Maps. Each factor in the tensor product depends on the token's hidden state $x_t$. For example, for queries, we can write
$$a_r^Q(x_t) = W_r^{a^Q} x_t \in \mathbb{R}^h, \quad b_r^Q(x_t) = W_r^{b^Q} x_t \in \mathbb{R}^{d_h},$$
and similarly for keys and values.
One often merges the rank index into a single output dimension. For instance, for queries:
$$a^Q(x_t) = W^{a^Q} x_t \in \mathbb{R}^{R_Q \cdot h}, \quad b^Q(x_t) = W^{b^Q} x_t \in \mathbb{R}^{R_Q \cdot d_h},$$
which are then reshaped into $A^Q(x_t) \in \mathbb{R}^{R_Q \times h}$ and $B^Q(x_t) \in \mathbb{R}^{R_Q \times d_h}$. Summing over the rank index and scaling by $1/R_Q$ yields
$$Q_t = \frac{1}{R_Q} A^Q(x_t)^\top B^Q(x_t) \in \mathbb{R}^{h \times d_h}.$$
Repeating this for all tokens reconstitutes $Q \in \mathbb{R}^{T \times h \times d_h}$. Similar procedures can be applied to obtain K and V with ranks $R_K$ and $R_V$, respectively.
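A small NumPy sketch of the merged factor maps and the $1/R_Q$-scaled product for one token, with toy sizes. We assume the merged output is laid out rank-major (rows $rh, \ldots, (r+1)h - 1$ of $W^{a^Q}$ form $W_r^{a^Q}$), which is one reasonable reshape convention rather than a detail stated in the text; `tpa_query` is hypothetical. The check confirms that the matrix form equals the explicit sum of rank-1 tensor products in (3.1), so $Q_t$ has rank at most $R_Q$.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_h, R_Q = 16, 4, 8, 2      # toy sizes (hypothetical)

# Merged latent-factor maps: a^Q(x) = W^{a^Q} x in R^{R_Q h}, b^Q(x) = W^{b^Q} x in R^{R_Q d_h}.
W_aQ = rng.standard_normal((R_Q * h, d_model))
W_bQ = rng.standard_normal((R_Q * d_h, d_model))

def tpa_query(x_t: np.ndarray) -> np.ndarray:
    """Q_t = (1/R_Q) A^Q(x_t)^T B^Q(x_t), shape (h, d_h)."""
    A = (W_aQ @ x_t).reshape(R_Q, h)      # A^Q(x_t): one row per rank
    B = (W_bQ @ x_t).reshape(R_Q, d_h)    # B^Q(x_t): one row per rank
    return A.T @ B / R_Q

x_t = rng.standard_normal(d_model)
Q_t = tpa_query(x_t)

# Same result written as the explicit sum of rank-1 tensor products in (3.1).
A = (W_aQ @ x_t).reshape(R_Q, h)
B = (W_bQ @ x_t).reshape(R_Q, d_h)
Q_sum = sum(np.outer(A[r], B[r]) for r in range(R_Q)) / R_Q
```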
Scaled Dot-Product Attention. Once Q, K, V are factorized, multi-head attention proceeds as in standard Transformers. For each head $i \in \{1, \ldots, h\}$:
$$\text{head}_i = \operatorname{Softmax}\!\left(\frac{1}{\sqrt{d_h}} Q_i K_i^\top\right) V_i, \tag{3.4}$$
where $Q_i, K_i, V_i \in \mathbb{R}^{T \times d_h}$ are the slices along the head dimension. Concatenating these h heads along the last dimension yields an $\mathbb{R}^{T \times (h \cdot d_h)}$ tensor, which is projected back to $\mathbb{R}^{T \times d_{\text{model}}}$ by an output weight matrix $W^O \in \mathbb{R}^{(h \cdot d_h) \times d_{\text{model}}}$:
$$\operatorname{TPA}(Q, K, V) = \operatorname{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O. \tag{3.5}$$
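Putting (3.1) through (3.5) together, a whole-sequence TPA layer can be sketched in NumPy as below. This is an illustrative reference, not the authors' implementation: it omits RoPE, batching, and dropout, adds a causal mask (an assumption suited to decoder-only language modeling, not part of (3.4) as written), and uses hypothetical names such as `tpa_attention` and the `params` keys.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def factorize(X, W_a, W_b, h, d_h, R):
    """Per-token TPA factors for a whole sequence; returns (T, h, d_h)."""
    T = X.shape[0]
    A = (X @ W_a.T).reshape(T, R, h)      # A(x_t) for every token
    B = (X @ W_b.T).reshape(T, R, d_h)    # B(x_t) for every token
    return np.einsum("trh,trd->thd", A, B) / R   # (1/R) A^T B per token

def tpa_attention(X, params, h, d_h, ranks, causal=True):
    Q = factorize(X, params["aQ"], params["bQ"], h, d_h, ranks[0])
    K = factorize(X, params["aK"], params["bK"], h, d_h, ranks[1])
    V = factorize(X, params["aV"], params["bV"], h, d_h, ranks[2])
    T = X.shape[0]
    heads = []
    for i in range(h):                     # (3.4) for each head i
        scores = Q[:, i] @ K[:, i].T / np.sqrt(d_h)
        if causal:                         # assumption: decoder-style masking
            scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
        heads.append(softmax(scores) @ V[:, i])
    return np.concatenate(heads, axis=-1) @ params["O"]   # (3.5)

rng = np.random.default_rng(0)
d_model, h, d_h, ranks = 32, 4, 8, (6, 2, 2)   # toy sizes (hypothetical)
params = {}
for name, R in zip("QKV", ranks):
    params["a" + name] = rng.standard_normal((R * h, d_model)) / np.sqrt(d_model)
    params["b" + name] = rng.standard_normal((R * d_h, d_model)) / np.sqrt(d_model)
params["O"] = rng.standard_normal((h * d_h, d_model)) / np.sqrt(h * d_h)
X = rng.standard_normal((5, d_model))
Y = tpa_attention(X, params, h, d_h, ranks)
```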
Parameter Initialization. We initialize the weight matrices $W_r^{a^Q}, W_r^{a^K}, W_r^{a^V}, W_r^{b^Q}, W_r^{b^K}, W_r^{b^V}$ using Xavier initialization (Glorot & Bengio, 2010). Specifically, each entry of the weight matrix is drawn from a uniform distribution with bounds $[-\sqrt{6/(n_{\text{in}} + n_{\text{out}})}, \sqrt{6/(n_{\text{in}} + n_{\text{out}})}]$, where $n_{\text{in}}$ and $n_{\text{out}}$ are the input and output dimensions of the respective weight matrices. This initialization strategy helps maintain the variance of activations and gradients across the network.
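A sketch of the Xavier uniform rule in NumPy. Whether the bound is computed per rank slice $W_r^{a^Q}$ ($n_{\text{out}} = h$) or on the merged matrix ($n_{\text{out}} = R_Q h$) is not specified here; the example assumes the merged matrix, and `xavier_uniform` is a hypothetical helper. The variance of $U[-b, b]$ is $b^2/3 = 2/(n_{\text{in}} + n_{\text{out}})$, which is the variance Xavier initialization targets.

```python
import numpy as np

def xavier_uniform(n_out: int, n_in: int, rng: np.random.Generator) -> np.ndarray:
    """Sample an (n_out, n_in) matrix from U[-sqrt(6/(n_in+n_out)), +sqrt(6/(n_in+n_out))]."""
    bound = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-bound, bound, size=(n_out, n_in))

rng = np.random.default_rng(0)
d_model, h, R = 512, 8, 2                       # hypothetical sizes
W_aQ = xavier_uniform(R * h, d_model, rng)     # merged W^{a^Q}, shape (R h, d_model)
```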
In a typical workflow of adding RoPE to standard multi-head attention, one first computes $Q_t, K_s \in \mathbb{R}^{h \times d_h}$ of the t-th token and s-th token and then applies
$$Q_t \mapsto \tilde{Q}_t = \operatorname{RoPE}_t(Q_t), \quad K_s \mapsto \tilde{K}_s = \operatorname{RoPE}_s(K_s).$$
Direct Integration. A useful optimization is to integrate RoPE directly into the TPA factorization. For example, one can pre-rotate the token-dimension factors:
$$\tilde{B}^K(x_t) \longleftarrow \operatorname{RoPE}_t\big(B^K(x_t)\big), \tag{3.6}$$
yielding a pre-rotated key representation:
$$\tilde{K}_t = \frac{1}{R_K} \sum_{r=1}^{R_K} a_r^K(x_t) \otimes \operatorname{RoPE}_t\big(b_r^K(x_t)\big) = \frac{1}{R_K} A^K(x_t)^\top \operatorname{RoPE}_t\big(B^K(x_t)\big).$$
Thus, each $K_t$ is already rotated before caching, removing the need for explicit rotation at decoding time and accelerating autoregressive inference. Depending on hardware and performance requirements, one can also adopt different RoPE integration approaches for training and inference.
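Theorem 1 (RoPE's Compatibility with TPA). Let $Q_t$ be factorized by TPA as
$$Q_t = \frac{1}{R_Q} A^Q(x_t)^\top B^Q(x_t) \in \mathbb{R}^{h \times d_h},$$
where $A^Q(x_t) \in \mathbb{R}^{R_Q \times h}$ and $B^Q(x_t) \in \mathbb{R}^{R_Q \times d_h}$. Then we have
$$\operatorname{RoPE}(Q_t) = \frac{1}{R_Q} A^Q(x_t)^\top \tilde{B}^Q(x_t), \quad \text{where } \tilde{B}^Q(x_t) = \operatorname{RoPE}_t\big(B^Q(x_t)\big). \tag{3.7}$$
In addition, assume $Q_t$ and $K_s$ are factorized by TPA and then rotated by $\operatorname{RoPE}_t$ and $\operatorname{RoPE}_s$, respectively. Let $\tilde{Q}_t = \operatorname{RoPE}_t(Q_t)$ and $\tilde{K}_s = \operatorname{RoPE}_s(K_s)$. Then we have
$$\operatorname{RoPE}_{t-s}(Q_t)\, K_s^\top = \tilde{Q}_t \tilde{K}_s^\top.$$
Focusing on individual heads i, the above matrix equality implies
$$\operatorname{RoPE}_{t-s}(q_{t,i})^\top k_{s,i} = \tilde{q}_{t,i}^\top \tilde{k}_{s,i},$$
where $q_{t,i} \in \mathbb{R}^{d_h}$ is the i-th query head of the t-th token, $k_{s,i} \in \mathbb{R}^{d_h}$ is the i-th key head of the s-th token, and
$$\tilde{q}_{t,i} = \operatorname{RoPE}(q_{t,i}) = T_t q_{t,i} \in \mathbb{R}^{d_h}, \quad \tilde{k}_{s,i} = \operatorname{RoPE}(k_{s,i}) = T_s k_{s,i} \in \mathbb{R}^{d_h}.$$
Theorem 1 indicates that TPA does not break RoPE's relative translational property. We prove Theorem 1 in Appendix A. In short, $\operatorname{RoPE}_t$ acts as a block-diagonal orthogonal transform (i.e., the matrix $T_t$) on $B^Q(x_t)$. Consequently, $A^Q(x_t)$ remains unchanged, while each row of $B^Q(x_t)$ (one factor $b_r^Q(x_t) \in \mathbb{R}^{d_h}$ per rank) is rotated appropriately, preserving the TPA structure.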
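Theorem 1 is easy to check numerically. The NumPy sketch below uses random factors, the row-wise convention $\operatorname{RoPE}_t(M) = M T_t$ (matching $\operatorname{RoPE}(Q_t) = Q_t T_t$ in Section 2.5), and hypothetical helper names. It verifies (3.7) and the relative-position identity, i.e., that caching pre-rotated factors as in (3.6) gives the same attention scores as rotating the assembled $Q_t$ by $T_{t-s}$.

```python
import numpy as np

def rope_matrix(t: int, d: int, base: float = 10000.0) -> np.ndarray:
    """Block-diagonal rotation T_t, as defined in Section 2.5."""
    T = np.zeros((d, d))
    for j in range(1, d // 2 + 1):
        theta = 1.0 / base ** (2 * j / d)
        c, s = np.cos(t * theta), np.sin(t * theta)
        k = 2 * (j - 1)
        T[k:k + 2, k:k + 2] = [[c, -s], [s, c]]
    return T

def rope(M: np.ndarray, t: int) -> np.ndarray:
    """RoPE_t applied row-wise: M T_t (convention RoPE(Q_t) = Q_t T_t)."""
    return M @ rope_matrix(t, M.shape[-1])

rng = np.random.default_rng(0)
h, d_h, R_Q, R_K = 4, 8, 3, 2
A_Q, B_Q = rng.standard_normal((R_Q, h)), rng.standard_normal((R_Q, d_h))
A_K, B_K = rng.standard_normal((R_K, h)), rng.standard_normal((R_K, d_h))
t, s = 9, 4
Q_t = A_Q.T @ B_Q / R_Q
K_s = A_K.T @ B_K / R_K

# (3.7): rotating the factor B^Q equals rotating the assembled Q_t.
lhs = rope(Q_t, t)
rhs = A_Q.T @ rope(B_Q, t) / R_Q

# Relative property: RoPE_{t-s}(Q_t) K_s^T = Q~_t K~_s^T.
Q_tilde = A_Q.T @ rope(B_Q, t) / R_Q      # pre-rotated query factors
K_tilde = A_K.T @ rope(B_K, s) / R_K      # pre-rotated key factors, as cached via (3.6)
rel_lhs = rope(Q_t, t - s) @ K_s.T
rel_rhs = Q_tilde @ K_tilde.T
```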
3.3 KV Caching and Memory Reduction
In autoregressive decoding, standard attention caches $K_t, V_t \in \mathbb{R}^{h \times d_h}$ for each past token t. This accumulates to $\mathbb{R}^{T \times h \times d_h}$ for keys and $\mathbb{R}^{T \times h \times d_h}$ for values, i.e., $2Thd_h$ entries in total.
TPA Factorized KV Caching. Instead of storing the full $K_t$ and $V_t$, TPA stores only their factors. Specifically, we keep
$$A^K(x_t),\ \tilde{B}^K(x_t) \quad \text{and} \quad A^V(x_t),\ B^V(x_t),$$
Table 1: Comparison of different attention mechanisms. Here, $R_Q$, $R_K$, and $R_V$ denote the ranks for queries, keys, and values in TPA, respectively. Variants of TPA, such as TPA (KVonly), TPA (Non-contextual A), and TPA (Non-contextual B), are detailed in Section 3.5. For MLA, $d_h^R$ and $d_h$ are the dimensions of the RoPE and non-RoPE parts; $d_c'$ and $d_c$ are the dimensions of the compressed vectors for query and key-value, respectively.
Method | KV Cache | # Parameters | # Query Heads | # KV Heads
MHA | $2hd_h$ | $4d_{\text{model}}^2$ | h | h
MQA | $2d_h$ | $(2 + 2/h)\,d_{\text{model}}^2$ | h | 1
GQA | $2gd_h$ | $(2 + 2g/h)\,d_{\text{model}}^2$ | h | g
MLA | $d_c + d_h^R$ | $d_c'(d_{\text{model}} + hd_h + hd_h^R) + d_{\text{model}} d_h^R + d_c(d_{\text{model}} + 2hd_h)$ | h | h
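As a rough worked example, the sketch below counts cached entries per token per layer. The MHA, MQA, GQA, and MLA formulas are taken from the KV Cache column of Table 1; the TPA count $(R_K + R_V)(h + d_h)$ is derived from the shapes of the cached factors $A^K, \tilde{B}^K, A^V, B^V$ listed in Section 3.3, since the TPA rows of the table do not appear in this excerpt. The configuration ($h = 32$, $d_h = 128$, $R_K = R_V = 2$) and the helper name are hypothetical.

```python
def kv_cache_per_token(method: str, h: int, d_h: int, g: int = 1,
                       d_c: int = 0, d_h_rope: int = 0,
                       R_K: int = 1, R_V: int = 1) -> int:
    """Cached numbers per token per layer.

    MHA/MQA/GQA/MLA follow the KV-cache column of Table 1. For TPA, the
    count is the size of A^K (R_K x h), B~^K (R_K x d_h), A^V (R_V x h),
    and B^V (R_V x d_h), i.e. (R_K + R_V)(h + d_h).
    """
    if method == "MHA":
        return 2 * h * d_h
    if method == "MQA":
        return 2 * d_h
    if method == "GQA":
        return 2 * g * d_h
    if method == "MLA":
        return d_c + d_h_rope
    if method == "TPA":
        return (R_K + R_V) * (h + d_h)
    raise ValueError(f"unknown method: {method}")

# Hypothetical configuration, for illustration only.
h, d_h = 32, 128
mha = kv_cache_per_token("MHA", h, d_h)                  # 2 * 32 * 128
tpa = kv_cache_per_token("TPA", h, d_h, R_K=2, R_V=2)    # (2 + 2) * (32 + 128)
print(mha, tpa, mha / tpa)                               # prints: 8192 640 12.8
```

For this configuration, TPA stores 640 numbers per token instead of 8192, a 12.8× reduction, which is in line with the order-of-magnitude savings discussed in the introduction.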
$$a_i^Q = R_Q e_i \in \mathbb{R}^h \quad (e_i \in \mathbb{R}^h \text{ is the } i\text{-th standard basis vector}), \tag{3.8}$$
so that $e_i \otimes \cdot$ corresponds to the i-th head of $Q_t$.
(b) Contextual token factors. Define
$$b_i^Q(x_t) = (W_i^Q)^\top x_t \in \mathbb{R}^{d_h}, \tag{3.9}$$
where $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_h}$ is the per-head query projection defined before; hence $b_i^Q(x_t)$ depends on $x_t$.
Substituting (3.8)–(3.9) into (3.1) with $R_Q = h$ gives
$$Q_t = \sum_{i=1}^{h} e_i \otimes \big((W_i^Q)^\top x_t\big) \in \mathbb{R}^{h \times d_h}. \tag{3.10}$$
Each term $e_i \otimes (W_i^Q)^\top x_t$ in (3.10) contributes only to the i-th row, reconstituting the usual MHA form of $Q_t$. Analogous constructions hold for $K_t$ and $V_t$ using $W_i^K$ and $W_i^V$. Thus, MHA is a non-contextual, full-rank variant of TPA.
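The reduction of TPA to MHA can be checked directly. In this NumPy sketch (toy sizes, hypothetical variable names), setting $R_Q = h$, $A^Q = R_Q I_h$ per (3.8), and $b_i^Q(x_t) = (W_i^Q)^\top x_t$ per (3.9) reproduces the per-head MHA query slice exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_h = 16, 4, 8
W_Q = rng.standard_normal((h, d_model, d_h))    # W_i^Q for each head i
x_t = rng.standard_normal(d_model)

# Standard MHA slice: row i is (W_i^Q)^T x_t.
Q_mha = np.stack([W_Q[i].T @ x_t for i in range(h)])     # (h, d_h)

# TPA with R_Q = h, non-contextual a_i = R_Q e_i (3.8) and b_i(x_t) = (W_i^Q)^T x_t (3.9).
R_Q = h
A = R_Q * np.eye(h)                                      # row i is a_i^Q = R_Q e_i
B = np.stack([W_Q[i].T @ x_t for i in range(h)])         # row i is b_i^Q(x_t)
Q_tpa = A.T @ B / R_Q                                    # (3.1) in matrix form
```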
TPA with Non-contextual A. More broadly, TPA can use non-contextual head-dimension factors $a_r^Q, a_r^K, a_r^V \in \mathbb{R}^h$ (i.e., independent of $x_t$), while allowing $b_r^Q(x_t), b_r^K(x_t), b_r^V(x_t)$ to remain context-dependent. Then, for keys:
$$K_t = \frac{1}{R_K} \sum_{r=1}^{R_K} a_r^K \otimes b_r^K(x_t),$$
and similarly for queries/values. This reduces per-token computation and can be effective when head-dimension relationships are relatively stable across all tokens.
MQA and GQA as Non-Contextual TPA. Multi-Query Attention (MQA) (Shazeer, 2019) and Grouped Query Attention (GQA) (Ainslie et al., 2023) also emerge naturally from TPA by restricting the head-dimension factors to be non-contextual and low-rank:
• MQA as Rank-1 TPA. In MQA, all heads share a single set of keys/values, corresponding to $R_K = R_V = 1$ along the head dimension. Concretely, setting
$$K_t = (1, \ldots, 1)^\top \otimes b^K(x_t), \quad V_t = (1, \ldots, 1)^\top \otimes b^V(x_t)$$
forces every head to use the same key and value vectors. Each head retains a distinct query projection, matching the MQA design.
• GQA as Grouped Rank-1 TPA. GQA partitions h heads into G groups, each sharing keys/values within that group. In TPA form, each group g has a dedicated non-contextual factor pair $a_g^K, a_g^V \in \mathbb{R}^h$, which acts as a "mask" for the heads in that group. Varying G from 1 to h interpolates from MQA to standard MHA.
Hence, by constraining TPA's head-dimension factors to be constant masks (one for MQA; multiple for GQA), these popular variants are recovered as special cases.
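A NumPy sketch of both special cases for the keys (values are analogous). For GQA we assume $R_K = G$ and scale each group mask by $G$ so that the $1/R_K$ factor in (3.2) cancels; the text only says the factors act as masks, so this scaling, the contiguous head-to-group map, and all variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_h, G = 16, 8, 4, 2
x_t = rng.standard_normal(d_model)

# MQA: rank-1 key with a^K = (1, ..., 1), so every head row equals b^K(x_t).
W_bK = rng.standard_normal((d_h, d_model))
K_mqa = np.outer(np.ones(h), W_bK @ x_t)

# GQA: one non-contextual mask a_g per group, scaled by R_K = G to cancel the 1/R_K in (3.2).
W_bK_groups = rng.standard_normal((G, d_h, d_model))   # one key projection per group
group_of = np.arange(h) * G // h                        # assumed contiguous head-to-group map
K_gqa = np.zeros((h, d_h))
for g in range(G):
    a_g = G * (group_of == g).astype(float)            # head-dimension mask for group g
    K_gqa += np.outer(a_g, W_bK_groups[g] @ x_t) / G
```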
TPA with Non-contextual B. Conversely, one may fix the token-dimension factors $b_r^Q, b_r^K, b_r^V \in \mathbb{R}^{d_h}$ as learned parameters, while allowing $a_r^Q(x_t), a_r^K(x_t), a_r^V(x_t)$ to adapt to $x_t$. For keys:
$$K_t = \frac{1}{R_K} \sum_{r=1}^{R_K} a_r^K(x_t) \otimes b_r^K,$$
and similarly for queries/values. This arrangement is effective if the token-dimension structure remains mostly uniform across the sequence, while the head-dimension factors capture context.
TPA KV Only. One can preserve a standard query mapping,
$$Q_t = W^Q x_t \in \mathbb{R}^{h \times d_h},$$
and factorize only the keys and values. This leaves the query projection as the original linear transformation while reducing memory usage via factorized KV caching.
TPA KV with Shared B. Another variant is to share the token-dimension factors of keys and values,
$$b_r^K(x_t) = b_r^V(x_t),$$
lowering parameter counts and the KV cache footprint. While this constrains K and V to be formed from the same token basis, it can still perform well and provide additional memory savings.
Nonlinear Head Factors. Rather than applying purely linear mappings to the head-dimension factors $a_r^Q, a_r^K, a_r^V$, one may introduce element-wise nonlinearities such as σ(·) or softmax(·). This effectively yields a Mixture of Heads Attention (MoH Attention), where each component becomes a learned mixture weight modulated by the nonlinearity.
Discussion. These variants illustrate TPA's versatility in balancing memory cost, computational overhead, and representation power. By choosing which dimensions (heads or tokens) remain contextual and adjusting the ranks $(R_Q, R_K, R_V)$, TPA unifies multiple existing attention mechanisms, such as MHA, MQA, and GQA, under one framework, while potentially reducing the KV cache size by an order of magnitude during autoregressive inference.
We propose a new architecture called the Tensor ProducT ATTenTion Transformer (T6), which uses our Tensor Product Attention (TPA) in place of standard MHA (multi-head attention) or GQA (grouped-query attention). Building upon the query, key, and value tensors $Q, K, V \in \mathbb{R}^{T \times h \times d_h}$ defined in Section 3.1, T6 adopts the overall architecture of LLaMA (Touvron et al., 2023) while replacing the self-attention block with our TPA-based version. The feed-forward network (FFN) adopts a SwiGLU layer, as in Shazeer (2020) and Touvron et al. (2023).
TPA QKV Factorization. Let each token's hidden-state vector be $x_t \in \mathbb{R}^{d_{\text{model}}}$. Following Section 3.1, we project the entire sequence into three tensors $Q, K, V \in \mathbb{R}^{T \times h \times d_h}$, where $Q_t, K_t, V_t \in \mathbb{R}^{h \times d_h}$ denote the slices for the t-th token. The factor components $a_r^Q(x_t), b_r^Q(x_t), a_r^K(x_t), b_r^K(x_t), a_r^V(x_t), b_r^V(x_t)$ are produced by linear transformations on $x_t$. For instance, letting $W_r^{a^Q} \in \mathbb{R}^{h \times d_{\text{model}}}$ and $W_r^{b^Q} \in \mathbb{R}^{d_h \times d_{\text{model}}}$, we have
$$a_r^Q(x_t) = W_r^{a^Q} x_t, \quad b_r^Q(x_t) = W_r^{b^Q} x_t.$$
In practice, we merge all ranks r into a single output dimension, reshape, and sum over the rank indices; see Section 3.1 for details. The factorization for K and V follows the same pattern.
Rotary Positional Embedding (RoPE). As discussed in Section 3.2, RoPE (Su et al., 2024b) is applied to Q and K. Within TPA, we pre-rotate the factors $b_t^Q(x_t)$ and $b_s^K(x_s)$ directly, so that each $K_s$ is already rotated prior to caching; see (3.6) and Theorem 1.
Attention Step and Output Projection. Once we have Q, K, V factorized per token with RoPE
applied on Q and K, the attention step proceeds for each head i ∈ {1, . . . , h} using (3.4). Finally,
concatenating these h heads and then projecting them back using an output weight matrix gives the
final attention result, as shown in (3.5).
SwiGLU Feed-Forward Network. Following Shazeer (2020) and Touvron et al. (2023), T6 uses a SwiGLU-based feed-forward network (FFN):
$$\operatorname{FFN}(x) = \big(\sigma(x W_1) \odot (x W_2)\big) W_3,$$
where σ is the SiLU (a.k.a. swish) nonlinearity, ⊙ is the element-wise product, and $W_1, W_2, W_3$ are learnable parameters. Other activation functions can also be used.
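For reference, a NumPy sketch of the SwiGLU FFN with row-vector inputs, matching the formula above. The hidden width `d_ff` and the helper names are hypothetical; the paper's actual FFN sizes are not assumed here.

```python
import numpy as np

def silu(z: np.ndarray) -> np.ndarray:
    """SiLU / swish: z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x: np.ndarray, W1: np.ndarray, W2: np.ndarray, W3: np.ndarray) -> np.ndarray:
    """FFN(x) = (SiLU(x W1) ⊙ (x W2)) W3 for inputs x of shape (T, d_model)."""
    return (silu(x @ W1) * (x @ W2)) @ W3

rng = np.random.default_rng(0)
d_model, d_ff = 16, 64        # hypothetical hidden width
W1 = rng.standard_normal((d_model, d_ff))
W2 = rng.standard_normal((d_model, d_ff))
W3 = rng.standard_normal((d_ff, d_model))
x = rng.standard_normal((5, d_model))
y = swiglu_ffn(x, W1, W2, W3)
```

The gate branch `silu(x @ W1)` modulates the linear branch `x @ W2` element-wise before the down-projection `W3`.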
Overall T6 Block Structure. Putting everything together, one T6 block consists of:
$$x \leftarrow x + \operatorname{TPA}\big(\operatorname{RMSNorm}(x)\big),$$
$$x \leftarrow x + \operatorname{SwiGLU\text{-}FFN}\big(\operatorname{RMSNorm}(x)\big).$$
We place normalization layers (e.g., RMSNorm) before each sub-layer. Stacking L such blocks yields a T6 model with L layers.
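The pre-norm block structure can be sketched as follows. `attn` and `ffn` are placeholders for the TPA layer and the SwiGLU FFN, RMSNorm uses a learned per-channel gain and an `eps` of 1e-6 (an assumed value), and all names are hypothetical. With zero-output sub-layers, each block reduces to the identity, since both sub-layers only add to the residual stream.

```python
import numpy as np
from typing import Callable

def rms_norm(x: np.ndarray, gain: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm over the last dimension: x / rms(x) * gain."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gain

def t6_block(x: np.ndarray, attn: Callable, ffn: Callable,
             g_attn: np.ndarray, g_ffn: np.ndarray) -> np.ndarray:
    """Pre-norm residual block: attention sub-layer, then feed-forward sub-layer."""
    x = x + attn(rms_norm(x, g_attn))
    x = x + ffn(rms_norm(x, g_ffn))
    return x

def t6_stack(x: np.ndarray, layers) -> np.ndarray:
    """Apply L blocks in sequence; each entry is (attn, ffn, g_attn, g_ffn)."""
    for attn, ffn, g_attn, g_ffn in layers:
        x = t6_block(x, attn, ffn, g_attn, g_ffn)
    return x

def zero_sublayer(z: np.ndarray) -> np.ndarray:
    """Stand-in sub-layer that contributes nothing to the residual stream."""
    return np.zeros_like(z)

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 16))
ones = np.ones(16)
y = t6_stack(x, [(zero_sublayer, zero_sublayer, ones, ones)] * 3)   # L = 3 blocks
```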
4 Experiments
4.1 Language Modeling Tasks
All experiments reported in this paper are implemented on the nanoGPT code base (Karpathy, 2022), using the FineWeb-Edu 100B dataset (Lozhkov et al., 2024). The dataset contains 100 billion tokens for training and 0.1 billion tokens for validation. We compare T6 against the baseline Llama architecture (Touvron et al., 2023) with SwiGLU activation (Shazeer, 2020) and RoPE embeddings (Su et al., 2024a), as well as Llama variants that replace Multi-Head Attention (MHA; Vaswani et al., 2017) with Multi-Query Attention (MQA; Shazeer, 2019), Grouped Query Attention (GQA; Ainslie et al., 2023), or Multi-head Latent Attention (MLA; Liu et al., 2024a). In our experiments, the number of heads h is adjusted for each attention mechanism so that all attention mechanisms have the same number of parameters as standard MHA, which has $4d_{\text{model}}^2$ parameters per attention layer. We train models at three scales: small (124M parameters), medium (353M), and large (773M). Details on architecture hyperparameters and training hardware appear in Appendix B.1.
Training Setup. We follow the nanoGPT training configuration. In particular, we use the AdamW (Loshchilov, 2017) optimizer with $(\beta_1, \beta_2) = (0.9, 0.95)$, a weight decay of 0.1, and gradient clipping at 1.0. As in nanoGPT, the learning rate is managed by a cosine annealing scheduler (Loshchilov & Hutter, 2016) with 2,000 warmup steps, and the total global batch size is 480. For the small, medium, and large models, we set maximum learning rates of $6 \times 10^{-4}$, $3 \times 10^{-4}$, and $2 \times 10^{-4}$, respectively, and minimum learning rates of $3 \times 10^{-5}$, $3 \times 10^{-5}$, and $1 \times 10^{-5}$, respectively.
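A sketch of the learning-rate schedule described above. The text fixes the 2,000 warmup steps and the maximum/minimum rates; the linear shape of the warmup and `total_steps` (the step at which the cosine reaches the minimum) are assumptions for illustration, and `lr_at` is a hypothetical helper rather than nanoGPT's own function.

```python
import math

def lr_at(step: int, max_lr: float, min_lr: float,
          warmup_steps: int = 2000, total_steps: int = 100_000) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr (held after total_steps)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)   # in [0, 1)
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)

# Maximum / minimum learning rates per model size, from the training setup above.
LR_CONFIG = {"small": (6e-4, 3e-5), "medium": (3e-4, 3e-5), "large": (2e-4, 1e-5)}
max_lr, min_lr = LR_CONFIG["large"]
schedule = [lr_at(s, max_lr, min_lr) for s in (0, 1999, 2000, 51_000, 100_000)]
```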
Training & Validation Curves. Figures 2 and 3 compare training and validation loss curves for the large (773M) and medium (353M) models on FineWeb-Edu-100B. Overall, TPA (red curves) and its simpler variant TPA-KVonly (pink curves) converge as fast as or faster than the baselines (MHA, MQA, GQA, MLA) while also achieving visibly lower final losses. For instance, in Figure 2(b), TPA and TPA-KVonly remain below the MHA baseline in validation loss at nearly all training stages. Meanwhile, Multi-head Latent Attention (MLA) (Liu et al., 2024a) (blue curves) generally trains more slowly and reaches higher losses.
Validation Perplexity. Figure 4 shows the validation perplexities of the medium- and large-scale
models. Mirroring the loss curves, TPA and TPA-KVonly steadily outperform MHA, MQA, GQA,
and MLA over the course of training. By the end of pretraining (around 49B tokens), TPA-based
approaches achieve the lowest perplexities in most configurations.
Downstream Evaluation. We evaluate zero-shot and two-shot performance on standard benchmarks, including ARC (Yadav et al., 2019), BoolQ (Clark et al., 2019), HellaSwag (Zellers et al., 2019), OBQA (Mihaylov et al., 2018), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2020), MMLU (Hendrycks et al., 2021), and SciQ, using the lm-evaluation-harness codebase (Gao et al., 2024). For ARC-E, ARC-C, HellaSwag, OBQA, PIQA, and SciQ, we report normalized accuracy; for the other tasks, we report standard accuracy. Tables 8–9 in the appendix present results for small models, Tables 2–3 for medium models, and Tables 4–5 for large models.
For the medium-size (353M) models (Tables 2–3), TPA generally ties or outperforms all competing methods, achieving, for example, an average of 50.88% in zero-shot mode versus MHA's 49.44%, MQA's 49.29%, and MLA's 48.96% (Table 2). With two-shot prompts, TPA again leads with 53.04% average accuracy (Table 3). A similar trend appears for the large-size (773M) models (Tables 4–5), where TPA-KVonly attains the highest average (53.52% zero-shot, 55.33% two-shot), closely followed by full TPA.
Our experiments confirm that TPA consistently matches or exceeds the performance of established
attention mechanisms (MHA, MQA, GQA, MLA) across medium and large model scales. The
fully factorized TPA excels on mid-scale models, while TPA-KVonly can rival or surpass it at
larger scales. In both cases, factorizing the attention activations shrinks autoregressive KV cache
requirements by up to 5×–10×, thus enabling much longer context windows under fixed memory
budgets. In summary, tensor product attention provides a flexible, memory-efficient alternative to
standard multi-head attention, advancing the scalability of modern language models.
5 Related Work
[Figure 3 plots, Medium Model, FineWeb-edu100B: (a) Training Loss and (b) Validation Loss versus training tokens (B), 0 to 50B, for MHA, MQA, GQA, MLA, TPA-KVonly, and TPA.]
Figure 3: The training loss and validation loss of medium-size (353M) models with different attention mechanisms on the FineWeb-Edu 100B dataset.
[Figure 4 plots, FineWeb-edu100B: (a) Validation Perplexity of Medium Models and (b) Validation Perplexity of Large Models versus training tokens (B), 0 to 50B, for MHA, MQA, GQA, MLA, TPA-KVonly, and TPA.]
Figure 4: The validation perplexity of medium-size (353M) models and large-size (773M) models with different attention mechanisms on the FineWeb-Edu 100B dataset.
Table 2: The evaluation results of medium models with different attention mechanisms pretrained
using the FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each
column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 56.52 29.27 58.84 44.06 35.00 68.44 51.07 25.35 76.40 49.44
MQA 55.68 28.24 60.86 44.17 35.20 68.66 52.72 25.14 72.90 49.29
GQA 54.88 29.61 56.36 43.77 35.20 68.82 52.57 25.41 74.80 49.05
MLA 55.30 29.27 58.96 41.92 35.40 67.25 51.78 25.20 75.60 48.96
TPA-KVonly 57.11 30.03 61.25 44.83 34.60 69.04 54.54 23.35 74.60 49.93
TPA 59.30 31.91 60.98 45.57 34.60 69.48 53.91 24.93 77.20 50.88
et al., 2021; Zhang et al., 2023b; Sun et al., 2023; Zhang et al., 2024). To reduce memory usage and
work around the memory-bandwidth bottleneck of incremental decoding, Shazeer (2019) proposed
Multi-Query Attention (MQA), in which all query heads share a single key head and value head. To
address the quality degradation and training instability of MQA, Grouped-Query Attention (GQA)
(Ainslie et al., 2023) divides the query heads into several groups, with each group sharing a single
key head and value head. More recently, DeepSeek-V2 (Liu et al., 2024a) introduced Multi-head
Latent Attention (MLA), which outperforms MHA while reducing the inference-time KV cache by
caching a shared low-rank latent representation of keys and values. In contrast to the approaches
above, TPA applies a low-rank tensor product to compute the queries, keys, and values, where the
Table 3: The evaluation results of medium models with different attention mechanisms pre-trained
using the FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each
column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 64.44 32.85 59.05 44.18 33.20 68.72 50.12 26.01 87.40 51.77
MQA 64.27 32.94 57.71 44.36 31.80 68.01 51.70 25.99 86.00 51.42
GQA 61.70 32.17 52.81 43.99 33.80 68.50 53.35 24.44 86.40 50.80
MLA 62.75 30.80 59.17 42.02 34.80 67.08 52.41 26.11 84.80 51.10
TPA-KVonly 65.99 33.70 57.49 44.47 34.20 69.53 53.28 24.23 86.50 52.15
TPA 66.54 34.47 58.96 45.35 33.00 69.21 53.99 24.51 91.30 53.04
Table 4: The evaluation results of large models with different attention mechanisms pre-trained
using the FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each
column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 59.93 33.62 61.93 50.63 36.00 71.06 55.41 22.87 81.20 52.52
MQA 60.73 33.62 57.34 50.09 37.00 69.97 55.49 25.30 79.60 52.13
GQA 61.66 34.30 58.72 49.85 38.40 71.16 53.75 25.23 77.60 52.30
MLA 60.73 31.57 61.74 48.96 35.40 69.59 55.09 26.39 76.70 51.80
TPA-KVonly 63.26 34.13 61.96 50.66 37.20 72.09 55.25 26.06 81.10 53.52
TPA 63.22 35.58 60.03 51.26 36.80 71.44 55.56 24.77 79.60 53.14
Table 5: The evaluation results of large models with different attention mechanisms pre-trained
using the FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each
column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 67.85 36.35 59.82 50.22 35.00 70.67 53.35 23.92 91.10 54.25
MQA 68.86 36.09 53.79 50.50 37.00 70.89 54.70 25.01 88.00 53.87
GQA 69.15 36.09 58.84 50.29 36.20 70.73 54.22 26.08 90.00 54.62
MLA 68.56 35.41 60.12 49.18 38.00 69.21 55.25 25.29 88.20 54.36
TPA-KVonly 71.34 37.71 59.76 51.10 36.00 71.49 54.62 25.83 90.10 55.33
TPA 70.41 37.71 60.06 51.30 34.00 71.06 54.54 25.79 90.30 55.02
cached representations of keys and values are much smaller than those in MHA, yielding a greater
reduction in KV cache memory at inference time.
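The head-sharing schemes discussed above (MHA, MQA, GQA) differ mainly in how query heads are mapped onto cached key/value heads. The sketch below illustrates that mapping; it assumes, as many open implementations do but as this paper does not state, that GQA groups are contiguous blocks of query heads. The function names are hypothetical.

```python
def kv_head_for_query(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Index of the key/value head that query head `q_head` attends with.

    n_kv_heads == n_q_heads     -> MHA: every query head has its own K/V head.
    n_kv_heads == 1             -> MQA: all query heads share one K/V head.
    1 < n_kv_heads < n_q_heads  -> GQA: groups of query heads share a K/V head.
    Contiguous grouping is an assumption of this sketch.
    """
    if n_kv_heads < 1 or n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a positive multiple of n_kv_heads")
    if not 0 <= q_head < n_q_heads:
        raise ValueError("q_head out of range")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


def query_groups(n_q_heads: int, n_kv_heads: int) -> list[list[int]]:
    """For each K/V head, list the query heads that share it."""
    groups: list[list[int]] = [[] for _ in range(n_kv_heads)]
    for q in range(n_q_heads):
        groups[kv_head_for_query(q, n_q_heads, n_kv_heads)].append(q)
    return groups


if __name__ == "__main__":
    print(query_groups(8, 2))  # GQA: [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(query_groups(8, 1))  # MQA: [[0, 1, 2, 3, 4, 5, 6, 7]]
```

Because only `n_kv_heads` key/value heads are cached per layer, the cache shrinks by a factor of `n_q_heads / n_kv_heads` relative to MHA; this is the saving MQA and GQA trade against per-head flexibility.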
Low-Rank Factorizations. Low-rank approximations have been applied to compress model parameters
and reduce complexity. Examples include LoRA (Hu et al., 2022), which factorizes weight updates
during fine-tuning, and its derivatives for other training scenarios, such as efficient pretraining
(ReLoRA (Lialin et al., 2023), MoRA (Jiang et al., 2024)), long-context training (LongLoRA (Chen
et al., 2024), SinkLoRA (Zhang, 2024)), and continual training (InfLoRA (Liang & Li, 2024),
GS-LoRA (Zhao et al., 2024), I-LoRA (Ren et al., 2024)). These approaches typically produce static
low-rank expansions that do not explicitly depend on the input context. Malladi et al. (2023) and
Zeng & Lee (2024) provided theoretical analyses of the expressiveness of low-rank approximation.
For initializing the factorization matrices, OLoRA (Büyükakyüz, 2024) applies a QR decomposition
to the pretrained weights to improve language-model performance, while LoLDU (Shi et al., 2024)
uses an LDU decomposition to accelerate LoRA training. Moreover, AdaLoRA (Zhang et al., 2023a)
parameterizes the weight update in an SVD-like form and uses importance scores to adjust the rank
dynamically. TPA, by contrast, constructs Q, K, and V as contextually factorized tensors, enabling
dynamic adaptation.
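To make the contrast with contextual factorization concrete, here is a minimal pure-Python sketch of a LoRA-style forward pass. It assumes a row-vector convention, y = xW + (α/r)·(xA)B with A of shape d_in × r and B of shape r × d_out (the LoRA paper writes the same update as BA acting on column vectors); the helper names are hypothetical. The point is that the update (α/r)·AB is fixed after training and identical for every input, so it can be merged into W, whereas TPA's factors are recomputed from each token's hidden state.

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive dense matrix product a @ b."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
            for row in a]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Elementwise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def lora_forward(x: Matrix, w: Matrix, a: Matrix, b: Matrix, alpha: float) -> Matrix:
    """y = x W + (alpha / r) * (x A) B, row-vector convention (assumption of this sketch)."""
    r = len(b)  # rank = number of rows of B
    return add(matmul(x, w), matmul(matmul(x, a), b), scale=alpha / r)


def merged_weight(w: Matrix, a: Matrix, b: Matrix, alpha: float) -> Matrix:
    """Fold the static, input-independent update into the base weight."""
    return add(w, matmul(a, b), scale=alpha / len(b))


if __name__ == "__main__":
    x = [[1.0, 2.0]]
    w = [[1.0, 0.0], [0.0, 1.0]]
    a = [[1.0], [0.0]]  # d_in x r with r = 1
    b = [[0.0, 3.0]]    # r x d_out
    print(lora_forward(x, w, a, b, alpha=1.0))           # [[1.0, 5.0]]
    print(matmul(x, merged_weight(w, a, b, alpha=1.0)))  # [[1.0, 5.0]]
```

Both calls print the same result, which is exactly what "static" means here: the low-rank term does not depend on `x`.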
KV Cache Optimization. During autoregressive inference, a Transformer would otherwise recompute
the key and value tensors of all previous tokens at every decoding step. To improve efficiency, these
tensors can instead be cached in memory and reused in subsequent decoding steps, a technique known
as the KV cache, first proposed by Ott et al. (2019). However, the KV cache consumes additional
memory and can increase latency due to memory-bandwidth limitations (Adnan et al., 2024). Previous
studies have therefore explored diverse approaches to mitigate these issues, including KV cache
eviction to discard less significant tokens (Zhang et al., 2023c; Xiao et al., 2024; Cai et al., 2024;
Adnan et al., 2024), dynamic sparse attention among selected keys and values (Ribar et al., 2024;
Tang et al., 2024; Singhania et al., 2024), KV cache offloading to CPU (He & Zhai, 2024; Lee et al.,
2024; Sun et al., 2024), and quantization of the KV cache (Xiao et al., 2023; Liu et al., 2024c;
Hooper et al., 2024). Besides these methods, it is also effective to reduce the amount of KV cache
per token, through approaches such as reducing the number of KV heads (Shazeer, 2019; Ainslie
et al., 2023), cross-layer KV reuse (Xiao et al., 2019; Mu et al., 2024; Wu et al., 2024), and
low-rank KV representations (Saxena et al., 2024). Different from the methods above, TPA reduces
the size of the KV cache by using tensor-decomposed keys and values.
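The caching idea, together with the simplest form of eviction, can be sketched as a single-head cache in pure Python. This is illustrative only and not the implementation of any cited system; the class and function names are hypothetical, and the eviction rule (always keep the first `n_sink` tokens plus the most recent ones) is a simplified take on the attention-sink idea of Xiao et al. (2024).

```python
import math
from dataclasses import dataclass, field
from typing import Optional

Vector = list[float]


@dataclass
class KVCache:
    """Single-head key/value cache for autoregressive decoding (illustrative)."""
    max_tokens: Optional[int] = None  # None: keep every token (plain KV cache)
    n_sink: int = 0                   # earliest tokens that are never evicted
    keys: list[Vector] = field(default_factory=list)
    values: list[Vector] = field(default_factory=list)
    positions: list[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        if self.max_tokens is not None and self.max_tokens <= self.n_sink:
            raise ValueError("max_tokens must exceed n_sink")

    def append(self, pos: int, k: Vector, v: Vector) -> None:
        """Store the new token's key/value once instead of recomputing them later."""
        self.keys.append(k)
        self.values.append(v)
        self.positions.append(pos)
        if self.max_tokens is not None:
            while len(self.keys) > self.max_tokens:
                # Simplified eviction: drop the oldest token that is not a sink token.
                del self.keys[self.n_sink]
                del self.values[self.n_sink]
                del self.positions[self.n_sink]


def attend(q: Vector, cache: KVCache) -> Vector:
    """Scaled dot-product attention of one query over all cached keys/values."""
    if not cache.keys:
        raise ValueError("cache is empty")
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in cache.keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(cache.values[0])
    return [sum(w * v[j] for w, v in zip(weights, cache.values)) / z for j in range(dim)]
```

With `max_tokens=None` the cache grows linearly with sequence length, which is the memory cost that eviction, offloading, quantization, and per-token compression such as TPA each attack from a different direction.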
6 Conclusion
We introduced Tensor Product Attention (TPA), which factorizes query, key, and value matrices
into rank-R tensor products dependent on the token’s hidden state. Storing only the factorized
key/value components during autoregressive decoding substantially reduces the KV cache size
while improving performance compared with MHA, MQA, GQA, and MLA. The approach is fully
compatible with RoPE (and can store pre-rotated keys). Variants of TPA include factorizing only
the keys and values or sharing basis vectors across tokens. Overall, TPA offers a powerful mechanism
for compressing KV storage while improving model performance, thereby enabling longer contexts
under constrained memory.
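As a rough illustration of the factorization summarized above, the sketch below rebuilds one token's per-head key matrix from cached rank-R factors. It assumes factor shapes a_r ∈ ℝ^h (over heads) and b_r ∈ ℝ^{d_h} (over the head dimension) and a 1/R normalization, K = (1/R)·Σ_r a_r ⊗ b_r; these are assumptions of this sketch, the function names are hypothetical, and the factors are taken as given rather than produced by the learned projections of the hidden state.

```python
Vector = list[float]
Matrix = list[list[float]]


def tpa_head_key(a_factors: list[Vector], b_factors: list[Vector], head: int) -> Vector:
    """Key of one head: (1/R) * sum_r a_r[head] * b_r, a vector of length d_h.

    The 1/R normalization is an assumption of this sketch.
    """
    rank = len(a_factors)
    if rank == 0 or rank != len(b_factors):
        raise ValueError("need the same, non-zero number of a- and b-factors")
    d_h = len(b_factors[0])
    return [sum(a[head] * b[j] for a, b in zip(a_factors, b_factors)) / rank
            for j in range(d_h)]


def tpa_materialize(a_factors: list[Vector], b_factors: list[Vector]) -> Matrix:
    """Full (h x d_h) key matrix for one token, rebuilt from its R factor pairs."""
    if not a_factors:
        raise ValueError("need at least one factor pair")
    n_heads = len(a_factors[0])
    return [tpa_head_key(a_factors, b_factors, i) for i in range(n_heads)]


if __name__ == "__main__":
    # Rank 1, h = 2 heads, d_h = 3: the key matrix is an outer product.
    print(tpa_materialize([[1.0, 2.0]], [[3.0, 4.0, 5.0]]))
    # [[3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]
```

Only the R short vectors of each kind need to be cached per token; a head's key can be rebuilt on demand with `tpa_head_key`, so the full h × d_h matrix never has to be stored.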
References
Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant Nair, Ilya Soloveychik, and Purushotham
Kamath. Keyformer: KV cache reduction through key tokens selection for efficient
generative inference. Proceedings of Machine Learning and Systems, 6:114–127, 2024.
Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit
Sanghai. GQA: training generalized multi-query transformer models from multi-head check-
points. In Houda Bouamor, Juan Pino, and Kalika Bali (eds.), Proceedings of the 2023 Conference
on Empirical Methods in Natural Language Processing, EMNLP 2023, Singapore, December 6-
10, 2023, pp. 4895–4901. Association for Computational Linguistics, 2023. doi: 10.18653/V1/
2023.EMNLP-MAIN.298. URL https://doi.org/10.18653/v1/2023.emnlp-main.298.
Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. PIQA: reasoning
about physical commonsense in natural language. In The Thirty-Fourth AAAI Conference on
Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelli-
gence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial
Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pp. 7432–7439. AAAI Press,
2020.
Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal,
Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are
few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Ka-
mar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, et al. Sparks of artificial general
intelligence: Early experiments with GPT-4. arXiv preprint arXiv:2303.12712, 2023.
Kerim Büyükakyüz. OLoRA: Orthonormal low-rank adaptation of large language models. arXiv
preprint arXiv:2406.01775, 2024.
Zefan Cai, Yichi Zhang, Bofei Gao, Yuliang Liu, Tianyu Liu, Keming Lu, Wayne Xiong, Yue Dong,
Baobao Chang, Junjie Hu, et al. PyramidKV: Dynamic KV cache compression based on pyramidal
information funneling. arXiv preprint arXiv:2406.02069, 2024.
Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, and Jiaya Jia. LongLoRA:
Efficient fine-tuning of long-context large language models. In The Twelfth International
Conference on Learning Representations, ICLR 2024, Vienna, Austria, May 7-11, 2024, 2024.
Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse
transformers. arXiv preprint arXiv:1904.10509, 2019.
Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea
Gane, Tamás Sarlós, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser,
David Benjamin Belanger, Lucy J. Colwell, and Adrian Weller. Rethinking attention with per-
formers. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event,
Austria, May 3-7, 2021, 2021.
Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam
Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, Parker Schuh,
Kensen Shi, Sasha Tsvyashchenko, Joshua Maynez, Abhishek Rao, Parker Barnes, Yi Tay, Noam
Shazeer, Vinodkumar Prabhakaran, Emily Reif, Nan Du, Ben Hutchinson, Reiner Pope, James
Bradbury, Jacob Austin, Michael Isard, Guy Gur-Ari, Pengcheng Yin, Toju Duke, Anselm Lev-
skaya, Sanjay Ghemawat, Sunipa Dev, Henryk Michalewski, Xavier Garcia, Vedant Misra, Kevin
Robinson, Liam Fedus, Denny Zhou, Daphne Ippolito, David Luan, Hyeontaek Lim, Barret
Zoph, Alexander Spiridonov, Ryan Sepassi, David Dohan, Shivani Agrawal, Mark Omernick,
Andrew M. Dai, Thanumalayan Sankaranarayana Pillai, Marie Pellat, Aitor Lewkowycz, Erica
Moreira, Rewon Child, Oleksandr Polozov, Katherine Lee, Zongwei Zhou, Xuezhi Wang, Bren-
nan Saeta, Mark Diaz, Orhan Firat, Michele Catasta, Jason Wei, Kathy Meier-Hellstern, Douglas
Eck, Jeff Dean, Slav Petrov, and Noah Fiedel. PaLM: Scaling language modeling with pathways.
J. Mach. Learn. Res., 24:240:1–240:113, 2023.
Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina
Toutanova. BoolQ: Exploring the surprising difficulty of natural yes/no questions. In Jill Burstein,
Christy Doran, and Thamar Solorio (eds.), Proceedings of the 2019 Conference of the North Amer-
ican Chapter of the Association for Computational Linguistics: Human Language Technologies,
NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers),
pp. 2924–2936. Association for Computational Linguistics, 2019.
Leo Gao, Jonathan Tow, Baber Abbasi, Stella Biderman, Sid Black, Anthony DiPofi, Charles Fos-
ter, Laurence Golding, Jeffrey Hsu, Alain Le Noac’h, Haonan Li, Kyle McDonell, Niklas Muen-
nighoff, Chris Ociepa, Jason Phang, Laria Reynolds, Hailey Schoelkopf, Aviya Skowron, Lintang
Sutawika, Eric Tang, Anish Thite, Ben Wang, Kevin Wang, and Andy Zou. A framework for few-
shot language model evaluation, 07 2024. URL https://zenodo.org/records/12608602.
Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural
networks. In Proceedings of the thirteenth international conference on artificial intelligence and
statistics, pp. 249–256. JMLR Workshop and Conference Proceedings, 2010.
Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, and Amir Zandieh.
HyperAttention: Long-context attention in near-linear time. In The Twelfth International
Conference on Learning Representations, ICLR 2024, Vienna, Austria, May 7-11, 2024.
OpenReview.net, 2024. URL https://openreview.net/forum?id=Eh0Od2BJIM.
Jiaao He and Jidong Zhai. FastDecode: High-throughput GPU-efficient LLM serving using
heterogeneous pipelines. arXiv preprint arXiv:2403.11421, 2024.
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob
Steinhardt. Measuring massive multitask language understanding. In 9th International Confer-
ence on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021, 2021.
Coleman Hooper, Sehoon Kim, Hiva Mohammadzadeh, Michael W Mahoney, Yakun Sophia Shao,
Kurt Keutzer, and Amir Gholami. KVQuant: Towards 10 million context length LLM inference with
KV cache quantization. arXiv preprint arXiv:2401.18079, 2024.
Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang,
and Weizhu Chen. LoRA: Low-rank adaptation of large language models. In The Tenth
International Conference on Learning Representations, ICLR 2022, Virtual Event, April 25-29, 2022,
2022.
Ting Jiang, Shaohan Huang, Shengyue Luo, Zihan Zhang, Haizhen Huang, Furu Wei, Weiwei Deng,
Feng Sun, Qi Zhang, Deqing Wang, et al. MoRA: High-rank updating for parameter-efficient
fine-tuning. arXiv preprint arXiv:2405.12130, 2024.
Andrej Karpathy. NanoGPT. https://github.com/karpathy/nanoGPT, 2022.
Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are
RNNs: Fast autoregressive transformers with linear attention. In International Conference on
Machine Learning, pp. 5156–5165. PMLR, 2020.
Wonbeom Lee, Jungi Lee, Junghwan Seo, and Jaewoong Sim. InfiniGen: Efficient generative
inference of large language models with dynamic KV cache management. In 18th USENIX
Symposium on Operating Systems Design and Implementation (OSDI 24), pp. 155–172, 2024.
Xiaoyu Li, Yingyu Liang, Zhenmei Shi, and Zhao Song. A tighter complexity analysis of SparseGPT.
arXiv preprint arXiv:2408.12151, 2024.
Vladislav Lialin, Sherin Muckatira, Namrata Shivagunde, and Anna Rumshisky. ReLoRA:
High-rank training through low-rank updates. In The Twelfth International Conference on Learning
Representations, 2023.
Yan-Shuo Liang and Wu-Jun Li. InfLoRA: Interference-free low-rank adaptation for continual
learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition,
pp. 23638–23647, 2024.
Yingyu Liang, Heshan Liu, Zhenmei Shi, Zhao Song, Zhuoyan Xu, and Junze Yin. Conv-basis: A
new paradigm for efficient attention inference and gradient computation in transformers. arXiv
preprint arXiv:2405.05219, 2024a.
Yingyu Liang, Jiangxuan Long, Zhenmei Shi, Zhao Song, and Yufa Zhou. Beyond linear approxi-
mations: A novel pruning approach for attention matrix. arXiv preprint arXiv:2410.11261, 2024b.
Aixin Liu, Bei Feng, Bin Wang, Bingxuan Wang, Bo Liu, Chenggang Zhao, Chengqi Deng, Chong
Ruan, Damai Dai, Daya Guo, et al. DeepSeek-V2: A strong, economical, and efficient
mixture-of-experts language model. arXiv preprint arXiv:2405.04434, 2024a.
Aixin Liu, Bei Feng, Bing Xue, Bingxuan Wang, Bochao Wu, Chengda Lu, Chenggang Zhao,
Chengqi Deng, Chenyu Zhang, Chong Ruan, et al. DeepSeek-V3 technical report. arXiv preprint
arXiv:2412.19437, 2024b.
Zirui Liu, Jiayi Yuan, Hongye Jin, Shaochen Zhong, Zhaozhuo Xu, Vladimir Braverman, Beidi
Chen, and Xia Hu. KIVI: A tuning-free asymmetric 2bit quantization for KV cache. In Forty-first
International Conference on Machine Learning, ICML 2024, Vienna, Austria, July 21-27, 2024,
2024c.
I Loshchilov. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
Ilya Loshchilov and Frank Hutter. SGDR: Stochastic gradient descent with warm restarts. arXiv
preprint arXiv:1608.03983, 2016.
Anton Lozhkov, Loubna Ben Allal, Leandro von Werra, and Thomas Wolf. FineWeb-Edu: the
finest collection of educational content, 2024. URL
https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu.
Sadhika Malladi, Alexander Wettig, Dingli Yu, Danqi Chen, and Sanjeev Arora. A kernel-based
view of language model fine-tuning. In International Conference on Machine Learning, pp.
23610–23641. PMLR, 2023.
Todor Mihaylov, Peter Clark, Tushar Khot, and Ashish Sabharwal. Can a suit of armor conduct
electricity? A new dataset for open book question answering. In Ellen Riloff, David Chiang,
Julia Hockenmaier, and Jun’ichi Tsujii (eds.), Proceedings of the 2018 Conference on Empirical
Methods in Natural Language Processing, Brussels, Belgium, October 31 - November 4, 2018,
pp. 2381–2391. Association for Computational Linguistics, 2018. doi: 10.18653/V1/D18-1260.
URL https://doi.org/10.18653/v1/d18-1260.
Yongyu Mu, Yuzhang Wu, Yuchun Fan, Chenglong Wang, Hengyu Li, Qiaozhi He, Murun Yang,
Tong Xiao, and Jingbo Zhu. Cross-layer attention sharing for large language models. arXiv
preprint arXiv:2408.01890, 2024.
Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier,
and Michael Auli. fairseq: A fast, extensible toolkit for sequence modeling. In Waleed Am-
mar, Annie Louis, and Nasrin Mostafazadeh (eds.), Proceedings of the 2019 Conference of the
North American Chapter of the Association for Computational Linguistics: Human Language
Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Demonstrations, pp.
48–53. Association for Computational Linguistics, 2019.
Weijieying Ren, Xinlong Li, Lei Wang, Tianxiang Zhao, and Wei Qin. Analyzing and reducing
catastrophic forgetting in parameter efficient tuning. arXiv preprint arXiv:2402.18865, 2024.
Luka Ribar, Ivan Chelombiev, Luke Hudlass-Galley, Charlie Blake, Carlo Luschi, and Douglas Orr.
SparQ Attention: Bandwidth-efficient LLM inference. In Forty-first International Conference on
Machine Learning, ICML 2024, Vienna, Austria, July 21-27, 2024, 2024.
Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. WinoGrande: An
adversarial Winograd schema challenge at scale. In The Thirty-Fourth AAAI Conference on
Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence
Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial In-
telligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pp. 8732–8740. AAAI Press,
2020.
Utkarsh Saxena, Gobinda Saha, Sakshi Choudhary, and Kaushik Roy. Eigen attention: Attention
in low-rank space for KV cache compression. In Yaser Al-Onaizan, Mohit Bansal, and Yun-
Nung Chen (eds.), Findings of the Association for Computational Linguistics: EMNLP 2024,
Miami, Florida, USA, November 12-16, 2024, pp. 15332–15344. Association for Computational
Linguistics, 2024.
Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight
programmers. In International Conference on Machine Learning, pp. 9355–9366. PMLR, 2021.
Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint
arXiv:1911.02150, 2019.
Noam Shazeer. Glu variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.
Yiming Shi, Jiwei Wei, Yujia Wu, Ran Ran, Chengwei Sun, Shiyuan He, and Yang Yang. LoLDU:
Low-rank adaptation via lower-diag-upper decomposition for parameter-efficient fine-tuning.
arXiv preprint arXiv:2410.13618, 2024.
Zhenmei Shi, Jiefeng Chen, Kunyang Li, Jayaram Raghuram, Xi Wu, Yingyu Liang, and Somesh
Jha. The trade-off between universality and label efficiency of representations from contrastive
learning. In The Eleventh International Conference on Learning Representations, ICLR 2023,
Kigali, Rwanda, May 1-5, 2023, 2023.
Prajwal Singhania, Siddharth Singh, Shwai He, Soheil Feizi, and Abhinav Bhatele. Loki: Low-rank
keys for efficient sparse attention. arXiv preprint arXiv:2406.02542, 2024.
Jianlin Su. The extreme pull between cache and effect: From MHA, MQA, GQA to MLA.
https://spaces.ac.cn/archives/10091, May 2024.
Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. RoFormer:
Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024a.
Jianlin Su, Murtadha H. M. Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. RoFormer:
Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024b.
Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie
Chi, and Beidi Chen. ShadowKV: KV cache in shadows for high-throughput long-context LLM
inference. arXiv preprint arXiv:2410.21465, 2024.
Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and
Furu Wei. Retentive network: A successor to transformer for large language models. arXiv
preprint arXiv:2307.08621, 2023.
Jiaming Tang, Yilong Zhao, Kan Zhu, Guangxuan Xiao, Baris Kasikci, and Song Han. QUEST:
query-aware sparsity for efficient long-context LLM inference. In Forty-first International Con-
ference on Machine Learning, ICML 2024, Vienna, Austria, July 21-27, 2024, 2024.
Gemma Team, Thomas Mesnard, Cassidy Hardin, Robert Dadashi, Surya Bhupatiraju, Shreya
Pathak, Laurent Sifre, Morgane Rivière, Mihir Sanjay Kale, Juliette Love, et al. Gemma: Open
models based on Gemini research and technology. arXiv preprint arXiv:2403.08295, 2024.
Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée
Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. LLaMA: Open and
efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
Yao-Hung Hubert Tsai, Shaojie Bai, Makoto Yamada, Louis-Philippe Morency, and Ruslan
Salakhutdinov. Transformer dissection: An unified understanding for transformer’s attention via
the lens of kernel. In Proceedings of the 2019 Conference on Empirical Methods in Natural Lan-
guage Processing and the 9th International Joint Conference on Natural Language Processing
(EMNLP-IJCNLP), pp. 4344–4353, 2019.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural informa-
tion processing systems, 30, 2017.
You Wu, Haoyi Wu, and Kewei Tu. A systematic study of cross-layer KV sharing for efficient LLM
inference. arXiv preprint arXiv:2410.14442, 2024.
Guangxuan Xiao, Ji Lin, Mickael Seznec, Hao Wu, Julien Demouth, and Song Han. SmoothQuant:
Accurate and efficient post-training quantization for large language models. In International
Conference on Machine Learning, pp. 38087–38099. PMLR, 2023.
Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. Efficient streaming
language models with attention sinks. In The Twelfth International Conference on Learning Rep-
resentations, ICLR 2024, Vienna, Austria, May 7-11, 2024, 2024.
Tong Xiao, Yinqiao Li, Jingbo Zhu, Zhengtao Yu, and Tongran Liu. Sharing attention weights
for fast transformer. In Sarit Kraus (ed.), Proceedings of the Twenty-Eighth International Joint
Conference on Artificial Intelligence, IJCAI 2019, Macao, China, August 10-16, 2019, pp. 5292–
5298. ijcai.org, 2019.
Vikas Yadav, Steven Bethard, and Mihai Surdeanu. Quick and (not so) dirty: Unsupervised se-
lection of justification sentences for multi-hop question answering. In Kentaro Inui, Jing Jiang,
Vincent Ng, and Xiaojun Wan (eds.), Proceedings of the 2019 Conference on Empirical Methods
in Natural Language Processing and the 9th International Joint Conference on Natural Lan-
guage Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3-7, 2019, pp. 2578–
2589. Association for Computational Linguistics, 2019. doi: 10.18653/V1/D19-1260. URL
https://doi.org/10.18653/v1/D19-1260.
Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. HellaSwag: Can a
machine really finish your sentence? In Anna Korhonen, David R. Traum, and Lluís Màrquez
(eds.), Proceedings of the 57th Conference of the Association for Computational Linguistics,
ACL 2019, Florence, Italy, July 28- August 2, 2019, Volume 1: Long Papers, pp. 4791–
4800. Association for Computational Linguistics, 2019. doi: 10.18653/V1/P19-1472. URL
https://doi.org/10.18653/v1/p19-1472.
Yuchen Zeng and Kangwook Lee. The expressive power of low-rank adaptation. In The Twelfth
International Conference on Learning Representations, ICLR 2024, Vienna, Austria, May 7-11,
2024, 2024.
Hengyu Zhang. SinkLoRA: Enhanced efficiency and chat capabilities for long-context large language
models. arXiv preprint arXiv:2406.05678, 2024.
Michael Zhang, Kush Bhatia, Hermann Kumbong, and Christopher Ré. The hedgehog & the porcu-
pine: Expressive linear attentions with softmax mimicry. In The Twelfth International Conference
on Learning Representations, ICLR 2024, Vienna, Austria, May 7-11, 2024, 2024.
Qingru Zhang, Minshuo Chen, Alexander Bukharin, Pengcheng He, Yu Cheng, Weizhu Chen, and
Tuo Zhao. Adaptive budget allocation for parameter-efficient fine-tuning. In The Eleventh Inter-
national Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023.
OpenReview.net, 2023a.
Ruiqi Zhang, Spencer Frei, and Peter L Bartlett. Trained transformers learn linear models in-context.
arXiv preprint arXiv:2306.09927, 2023b.
Zhenyu Zhang, Ying Sheng, Tianyi Zhou, Tianlong Chen, Lianmin Zheng, Ruisi Cai, Zhao Song,
Yuandong Tian, Christopher Ré, Clark Barrett, et al. H2O: Heavy-hitter oracle for efficient
generative inference of large language models. Advances in Neural Information Processing Systems,
36:34661–34710, 2023c.
Hongbo Zhao, Bolin Ni, Junsong Fan, Yuxi Wang, Yuntao Chen, Gaofeng Meng, and Zhaoxiang
Zhang. Continual forgetting for pre-trained vision models. In Proceedings of the IEEE/CVF
Conference on Computer Vision and Pattern Recognition, pp. 28631–28642, 2024.
Appendix
A Proofs of Theorems
B More on Experiments
B.1 Experimental Settings
B.2 Additional Experimental Results
B.3 Ablation Studies on Learning Rates
A Proofs of Theorems
Proof of Theorem 1.
B More on Experiments
B.1 Experimental Settings
We list the main architecture hyper-parameters and training devices in Table 6. We fix $d_h = 64$
for all models. In addition, we set the number of KV heads to 2 for GQA models, $d_h^R = 32$ for
MLA models, and $R_K = R_V = 2$, $R_Q = 6$ for TPA and TPA-KVonly models. Other
hyper-parameters are listed in Table 7.
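With the settings above, the per-token, per-layer KV-cache footprint of each method can be estimated with simple arithmetic. This sketch assumes that MHA caches 2·h·d_h numbers, GQA caches 2·g·d_h for g KV heads (MQA is g = 1), and TPA caches (R_K + R_V)(h + d_h) factor entries, with TPA-KVonly caching the same factors as TPA. The head count h is not given in this section (it belongs to Table 6), so h = 16 below is a hypothetical value; MLA is omitted because its latent dimension is not listed here. The function name is hypothetical.

```python
def kv_floats_per_token(method: str, n_heads: int, d_head: int = 64,
                        n_kv_heads: int = 2, r_k: int = 2, r_v: int = 2) -> int:
    """Cached numbers per token per layer (keys + values), under this sketch's assumptions.

    Defaults follow B.1: d_h = 64, 2 KV heads for GQA, R_K = R_V = 2 for TPA.
    """
    if method == "MHA":
        return 2 * n_heads * d_head
    if method == "MQA":
        return 2 * d_head
    if method == "GQA":
        return 2 * n_kv_heads * d_head
    if method in ("TPA", "TPA-KVonly"):
        # Each rank-1 term caches one head factor (length h) and one dim factor (length d_h).
        return (r_k + r_v) * (n_heads + d_head)
    raise ValueError(f"unknown method: {method}")


if __name__ == "__main__":
    h = 16  # hypothetical head count, not taken from Table 6
    mha = kv_floats_per_token("MHA", h)
    for m in ("MHA", "MQA", "GQA", "TPA"):
        size = kv_floats_per_token(m, h)
        print(f"{m:4s} {size:5d} floats, MHA/{m} = {mha / size:.1f}")
```

Under this hypothetical h = 16, TPA's cache is 6.4× smaller than MHA's but larger than MQA's or GQA's. Because MHA scales with h·d_h while TPA scales with h + d_h, the ratio grows as the number of heads increases.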
Table 6: The architecture hyper-parameters and training devices of models. Abbreviations: BS. =
Batch Size, GAS. = Gradient Accumulation Steps.
We display the evaluation results for small-size (124M) models in Tables 8-9.
Table 8: The evaluation results of small models with different attention mechanisms pre-trained
using the FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The best scores in each column
are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 50.63 26.96 59.39 36.18 32.00 64.96 51.85 23.40 70.30 46.19
MQA 49.62 25.34 55.72 35.94 31.40 64.85 51.30 23.37 68.70 45.14
GQA 48.70 25.68 56.15 35.58 31.40 64.91 51.62 23.12 68.20 45.04
MLA 49.66 26.45 61.22 33.94 32.40 62.73 50.43 23.29 71.50 45.74
TPA-KVonly 51.05 26.54 57.25 36.77 32.60 65.02 50.91 23.64 69.70 45.94
TPA 51.26 27.39 57.00 36.68 32.80 64.47 49.72 24.61 72.00 46.21
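The Avg. column in these tables appears to be the unweighted mean of the nine benchmark scores. A small consistency check, assuming that definition, is sketched below; the function names are hypothetical. Since both the individual scores and the reported average are rounded to two decimals, a tolerance of 0.01 absorbs the combined rounding error.

```python
from statistics import fmean

# Column order of the evaluation tables; Avg. is assumed to be their unweighted mean.
BENCHMARKS = ("ARC-E", "ARC-C", "BoolQ", "HellaSw.", "OBQA", "PIQA", "W.G.", "MMLU", "SciQ")


def _validate(scores: list[float]) -> None:
    if len(scores) != len(BENCHMARKS):
        raise ValueError(f"expected {len(BENCHMARKS)} scores, got {len(scores)}")


def row_average(scores: list[float]) -> float:
    """Unweighted mean of one table row, rounded to two decimals."""
    _validate(scores)
    return round(fmean(scores), 2)


def avg_is_consistent(scores: list[float], reported: float, tol: float = 0.01) -> bool:
    """True if `reported` matches the recomputed mean up to rounding error.

    tol = 0.01 allows 0.005 for rounding the individual scores plus 0.005
    for rounding the reported average.
    """
    _validate(scores)
    return abs(fmean(scores) - reported) <= tol + 1e-9


if __name__ == "__main__":
    tpa_small_0shot = [51.26, 27.39, 57.00, 36.68, 32.80, 64.47, 49.72, 24.61, 72.00]
    print(row_average(tpa_small_0shot))  # 46.21 (Table 8, TPA row)
```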
We also run a parallel set of experiments on medium models with a learning rate of $6 \times 10^{-4}$;
the training loss, validation loss, and validation perplexity curves are shown in Figure 5. Tables 10–11
report the performance of these models on the benchmarks described in Section 4. At this learning
rate, TPA again achieves the highest average score in both the 0-shot and 2-shot evaluations, and
TPA-KVonly remains competitive, ranking second in the 2-shot setting. This suggests that the
advantage of TPA is not specific to a single learning rate.
Table 9: The evaluation results of small models with different attention mechanisms pre-trained
using the FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The best scores in each column are bolded.
Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 57.66 28.24 57.28 36.43 29.60 64.09 51.14 26.57 82.00 48.11
MQA 53.79 26.35 44.95 34.18 28.80 62.79 52.01 25.91 78.10 45.21
GQA 55.01 25.94 55.72 35.68 31.80 65.29 51.93 25.27 77.80 47.16
MLA 52.78 26.19 57.25 33.19 29.60 63.98 50.43 24.90 76.00 46.04
TPA-KVonly 54.25 27.90 57.06 36.36 31.80 64.31 53.59 26.18 79.20 47.85
TPA 57.53 28.07 56.33 36.49 31.80 64.36 51.14 25.92 79.70 47.93
[Figure 5 plot: (a) Training Loss, (b) Validation Loss, and (c) Validation Perplexity versus training tokens (B) for the medium models, with curves for the compared attention mechanisms, including TPA-KVonly and TPA.]
Figure 5: The training loss, validation loss and validation perplexity of medium-size (353M) models
with learning rate $6 \times 10^{-4}$ and different attention mechanisms on the FineWeb-Edu 100B dataset.
Table 10: The evaluation results of medium models (learning rate $6 \times 10^{-4}$) with different attention
mechanisms pre-trained using the FineWeb-Edu 100B dataset (0-shot with lm-evaluation-harness). The
best scores in each column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 59.51 29.52 59.60 45.68 34.20 68.82 53.43 23.33 76.90 50.11
MQA 57.62 31.91 59.45 45.69 35.40 69.31 53.51 26.47 74.60 50.44
GQA 28.67 31.48 58.29 45.45 35.20 68.50 54.46 24.58 76.50 47.01
MLA 57.49 29.44 59.97 44.09 25.77 68.66 53.04 25.77 76.40 48.96
TPA-KVonly 58.01 30.12 58.01 45.95 35.60 69.10 53.12 25.39 75.10 50.04
TPA 58.38 31.57 59.39 46.83 37.00 70.02 54.06 25.52 79.90 51.41
Table 11: The evaluation results of medium models (learning rate $6 \times 10^{-4}$) with different attention
mechanisms pre-trained using the FineWeb-Edu 100B dataset (2-shot with lm-evaluation-harness). The
best scores in each column are bolded. Abbreviations: HellaSw. = HellaSwag, W.G. = WinoGrande.
Method ARC-E ARC-C BoolQ HellaSw. OBQA PIQA W.G. MMLU SciQ Avg.
MHA 64.73 32.42 58.29 45.89 34.20 68.50 53.20 25.86 88.00 52.34
MQA 64.98 33.62 55.02 45.81 34.00 69.59 53.43 24.30 85.20 51.77
GQA 65.24 33.19 56.54 45.41 34.80 69.04 55.72 24.73 87.90 52.51
MLA 63.80 31.06 58.50 44.19 35.40 68.44 51.62 25.22 88.50 51.86
TPA-KVonly 64.69 32.34 59.48 46.23 35.40 70.08 54.06 25.64 86.30 52.69
TPA 67.97 34.56 57.22 46.87 34.60 69.91 52.01 25.07 89.90 53.12