
Masked Attention is All You Need for Graphs

David Buterez 1, Jon Paul Janet 2, Dino Oglic 3, Pietro Liò 1

1 Department of Computer Science and Technology, University of Cambridge, Cambridge, UK; 2 Molecular AI, BioPharmaceuticals R&D, AstraZeneca, Gothenburg, Sweden; 3 Centre for AI, BioPharmaceuticals R&D, AstraZeneca, Cambridge, UK

arXiv:2402.10793v1 [cs.LG] 16 Feb 2024

Abstract

Graph neural networks (GNNs) and variations of the message passing algorithm are the predominant means for learning on graphs, largely due to their flexibility, speed, and satisfactory performance. The design of powerful and general purpose GNNs, however, requires significant research efforts and often relies on handcrafted, carefully-chosen message passing operators. Motivated by this, we propose a remarkably simple alternative for learning on graphs that relies exclusively on attention. Graphs are represented as node or edge sets and their connectivity is enforced by masking the attention weight matrix, effectively creating custom attention patterns for each graph. Despite its simplicity, masked attention for graphs (MAG) has state-of-the-art performance on long-range tasks and outperforms strong message passing baselines and much more involved attention-based methods on over 55 node and graph-level tasks. We also show significantly better transfer learning capabilities compared to GNNs and comparable or better time and memory scaling. MAG has sub-linear memory scaling in the number of nodes or edges, enabling learning on dense graphs and future-proofing the approach.

1. Introduction

The field of geometric deep learning (GDL) seeks to describe, understand, and even "unify" deep learning strategies for data structures such as sets, grids, and graphs by leveraging the fundamental concepts of symmetry and invariance.

A remarkably successful application of GDL is learning on graphs, abstractions that represent relationships between items of a set, and which naturally describe real-world phenomena such as social, biological, or transportation networks, as well as objects like molecules. In particular, deep learning for molecules has the potential to accelerate and even revolutionise fields such as drug discovery and materials science, and it has been one of the main factors behind the accelerated development of GDL, as well as one of its earliest applications (Duvenaud et al., 2015; Kearnes et al., 2016; Gilmer et al., 2017). General purpose learning on graphs is typically specified as an instance of message passing, an iterative algorithm where one must define a message function which aggregates information from a given node's neighbourhood, as well as a node (possibly also edge) update function to incorporate the encoded messages.

Despite the overall success and wide adoption of GNNs, several fundamental problems have been highlighted over time. Firstly, although the message passing framework is highly customisable through user-defined, learnable message and node update functions, the design of novel layers is a difficult research problem, where improvements take years to achieve and often rely on hand-crafted operators. This is particularly the case for general purpose GNNs that do not exploit additional input modalities such as atomic coordinates. For example, principal neighbourhood aggregation (PNA) (Corso et al., 2020) is regarded as one of the most powerful message passing layers, but it is built using a collection of hand-picked neighbourhood aggregation functions, it requires a dataset degree histogram which must be pre-computed prior to learning, and it further uses hand-picked degree scalers. Another example is given by Graph Attention Networks (GAT) (Veličković et al., 2018), one of the most popular graph layers and one of the earliest efforts to combine attention with GNNs. It has since been shown that the original formulation of GAT has limited expressive power, and that a simple reordering of the operations can improve performance (Brody et al., 2022).

Secondly, the nature of message passing imposes certain limitations which have shaped most of the GNN literature. One of the most prominent examples is the readout function used to combine node-level representations into a single graph-level representation, which is required to be permutation invariant with respect to the node order. Thus, the default choice for most GNNs remains a simple, non-learnable function such as sum, mean, or max, despite the potentially limited expressivity. Recently, it has been shown that breaking this constraint can lead to improved performance, and might be acceptable in certain scenarios such as learning on molecules, where the inputs can be presented in a canonical order (Buterez et al., 2022).

Furthermore, expressive readout functions that are based on Set Transformers (Lee et al., 2019) have been shown to consistently improve the performance of GNNs regardless of the underlying message passing algorithm (Buterez et al., 2022; 2023b).

Thirdly, the majority of GNN architectures are plagued by the well-known oversmoothing and oversquashing problems that limit the depth of GNNs, as well as difficulties in modelling long-range relationships, all of which are consequences of aggregating information from exponentially larger neighbourhoods (Alon & Yahav, 2021). The proposed solutions typically take the form of message regularisation schemes (Godwin et al., 2022; Zhao & Akoglu, 2020; Cai et al., 2021). However, there is generally no consensus on the right architectural choices for deep GNNs. Separately, standard GNNs have also shown limitations in terms of transfer learning and strategies such as pre-training and fine-tuning, as opposed to other families of neural networks such as language models (Hu et al., 2020). For certain types of data and tasks, non-standard GNNs that leverage attention-based readouts are currently the only way to effectively perform transfer learning (Buterez et al., 2024; 2023a).

Perhaps due to the progress that message passing neural networks have enabled on a wide range of tasks, alternative paradigms for learning on graphs are relatively underdeveloped. The attention mechanism (Vaswani et al., 2017) is one of the main sources of innovation within graph learning, either by directly incorporating attention within message passing (Veličković et al., 2018; Brody et al., 2022), by formulating graph learning as a language processing task (Ying et al., 2021; Kim et al., 2022), or by combining standard GNN layers with an attention mechanism (Rampášek et al., 2022; Shirzad et al., 2023; Buterez et al., 2022; 2023b).

Here, we propose a novel graph learning framework characterised by its simplicity and lack of message passing layers, instead being based on a classical attention mechanism with masking to create custom attention patterns. Compared to some of the previous works that have formulated graph learning as a language modelling task, we take a different approach and consider graph learning as a learning task on sets, where the graph connectivity (i.e., the adjacency matrix) is enforced by masking the pairwise attention weight matrix and allowing only values that correspond to graph connections. We term this architecture Masked Attention for Graphs (MAG) and demonstrate that it can be customised to propagate information across nodes (MAGN) or edges (MAGE). It is general purpose, in the sense that it only relies on the graph structure and possibly node and edge features, and it is not restricted to a particular domain such as chemistry. The simplicity of the architecture is further demonstrated by the fact that MAG does not use positional or similar encodings, does not encode graph structures as tokens or other language (sequence) specific concepts, and does not require any pre-computations.

Despite its simplicity, MAG generally outperforms strong message passing baselines and much more involved attention-based algorithms. Our empirical evaluation covers long-range molecular benchmarks and over 55 tasks from different domains such as quantum mechanics, molecular docking, physical chemistry, biophysics, bioinformatics, computer vision, social networks, function call graphs, and synthetic graphs. This emphasises the fact that the carefully-selected and hand-crafted nature of most message passing algorithms can be easily superseded by attention itself, without the need to explicitly define any operator. Beyond benchmarking, we also explore a recent research direction that showed the transformative benefits of transfer learning in drug discovery and quantum mechanics (Buterez et al., 2024). Here, we leverage a newly-published, refined version of the QM9 dataset at a higher level of theory (Fediai et al., 2023), and show that MAG is a viable and well-performing transfer learning strategy, while GNNs are limited in this regard.

MAG is arguably one of the most straightforward ways to apply attention to graphs. Thanks to modern implementations of exact attention, MAG scales sub-linearly in terms of memory with the number of nodes (MAGN) or edges (MAGE), as it largely relies on (masked) self-attention. Although current libraries are not optimised for masked attention and many optimisations are possible (see Appendix A), both the training time and the memory consumption are competitive with or even better than plain message passing.

2. Related Work

Graph neural networks with adaptive readouts – A recent research trend consists of following standard GNN layers with an attention-based pooling (readout) function, for example using the Set Transformer (Buterez et al., 2022; 2024; 2023b) or standard Transformers (Jain et al., 2021). The change to a more expressive readout function has provided consistent uplifts in most supervised learning tasks, and has enabled easy and effective transfer learning for molecules through a pre-training and fine-tuning workflow.

Transformers for graphs – One of the most popular approaches is the direct application of standard Transformers designed for language modelling to graphs. Graphormer (Ying et al., 2021) achieves this through an involved and computation-heavy suite of pre-processing steps, involving a centrality, spatial, and edge encoding. Graphormer models the mean readout function from GNNs through a special "virtual" node that is added to the graph and connected to every other node. Another approach within this paradigm is the Tokenized Graph Transformer (TokenGT) (Kim et al., 2022), which treats all the nodes and edges as independent tokens.

[Figure 1: an input graph is represented either as a node feature set (n1, n2, ..., nN) with its node adjacency matrix (node masking), or as an edge feature set (e1, e2, ..., eM) with its edge adjacency matrix (edge masking). The encoder alternates Self Attention Blocks with Masked Self Attention Blocks; the decoder applies Pooling by Multihead Attention followed by a Self Attention Block.]
Figure 1. Overview of the Masked Attention for Graphs (MAG) architecture. An important aspect is the choice of processing the node
feature set (MAGN) or the edge feature set (MAGE), and the appropriate masking algorithm. Conventional self attention blocks (S) can be
alternated with masked variants (M) as desired (here, a choice of SMSM is depicted). The decoder is only required for graph-level tasks.

To adapt sequence learning to the graph domain, TokenGT encodes the graph information using node identifiers derived from orthogonal random features or Laplacian eigenvectors, and learnable type identifiers for nodes and edges. Such models rely on embedding layers and are thus limited to integer node or edge features. Moreover, the high number of sequence-based architectures (e.g. Big Bird (Zaheer et al., 2020), Performer (Choromanski et al., 2021), etc.) represents a large unexplored territory for graphs. Spectral Attention Networks (Kreuzer et al., 2021) use a computationally-expensive positional encoding based on the Laplacian and two complementary attention mechanisms over nodes. A different philosophy is that of a message passing and Transformer hybrid. Such an approach is taken by the GraphGPS framework (Rampášek et al., 2022), which alternates message passing layers with Transformer layers. Like previous works, GraphGPS puts a large emphasis on different types of encodings, proposing and analysing positional and structural encodings, further divided into local, global, and relative encodings. Exphormer (Shirzad et al., 2023) is an evolution of GraphGPS which adds virtual global nodes and sparse attention based on expander graphs. While effective, such frameworks still rely on message passing and are thus not purely attention-based solutions. Limitations include dependence on approximations (Performer for both, expander graphs for Exphormer), decreased performance when encodings (GraphGPS) or special nodes (Exphormer) are removed, and scalability (GraphGPS). Notably, Exphormer is the first work to consider custom attention patterns for graphs in the form of node-level neighbourhoods.

3. Preliminaries

Graphs – A graph is a tuple G = (V, E), where V represents the set of nodes, E ⊆ V × V is the set of edges, N_n = |V|, and N_e = |E|. Nodes are associated with feature vectors x_u of dimension d_n for all nodes u ∈ V, and d_e-dimensional edge features e_uv are associated with all edges (u, v) ∈ E. The node features are collected as rows in a matrix X ∈ R^(N_n × d_n), and similarly the edge features are collected into E ∈ R^(N_e × d_e). The graph connectivity can be represented as an adjacency matrix A, where A_uv = 1 if (u, v) ∈ E and A_uv = 0 otherwise, although an edge list representation is often more practical. GNNs also define the neighbourhood of a node u ∈ V as N_u = {v | (u, v) ∈ E ∨ (v, u) ∈ E}. Many GNNs can be described as a form of message passing, which takes the general form

    m_u^{t+1} = \sum_{v \in N_u} \mathrm{message}_t(x_u^t, x_v^t, e_{uv}^t),
    x_u^{t+1} = \mathrm{update}_t(x_u^t, m_u^{t+1})

(Gilmer et al., 2017), where message_t and update_t are the message and node update functions, respectively, and t is the current time step. As described above, message passing is highly customisable, for example through novel definitions of the message and update functions, their inputs, or even additional update steps such as the graph-level state embedding and update proposed in Chen et al. (2021) for multi-fidelity applications. For graph-level prediction tasks, a readout or pooling function ⊕ must be used, aggregating all the learnt node representations: x_G = \bigoplus_{u \in V} x_u.
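For concreteness, one message passing step of this general form can be sketched in PyTorch as follows. This is a minimal illustration with sum aggregation, a linear message function, and a GRU update; it is not the operator used by any particular baseline in this paper.

    import torch

    class MessagePassingStep(torch.nn.Module):
        def __init__(self, node_dim: int, edge_dim: int):
            super().__init__()
            # message_t: acts on (x_u, x_v, e_uv); update_t: combines x_u with m_u
            self.message = torch.nn.Linear(2 * node_dim + edge_dim, node_dim)
            self.update = torch.nn.GRUCell(node_dim, node_dim)

        def forward(self, x, edge_index, edge_attr):
            # x: (N_n, d_n), edge_index: (2, N_e), edge_attr: (N_e, d_e)
            src, dst = edge_index
            msg = self.message(torch.cat([x[dst], x[src], edge_attr], dim=-1))
            # Sum the messages over each node's neighbourhood
            agg = torch.zeros_like(x).index_add_(0, dst, msg)
            # x_u^{t+1} = update_t(x_u^t, m_u^{t+1})
            return self.update(agg, x)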

Set Transformer – The Set Transformer is an encoder-decoder attention-based architecture for learning on sets. It leverages the scaled dot product and multihead attention mechanisms proposed in Vaswani et al. (2017) to define multihead attention blocks (MABs), self attention blocks (SABs), and pooling by multihead attention (PMA) blocks. We do not recapitulate the definition of multihead attention and assume instead that a function Multihead(Q, K, V) is available (for query, key, and value matrices, respectively). The original Set Transformer is given by ST(X) = Decoder(Encoder(X)) with:

    MAB(X, Y) = H + Linear_φ(H),              (1)
    H = X + Multihead(X, Y, Y),               (2)
    SAB(X) = MAB(X, X),                       (3)
    Encoder(X) = SAB^n(X),                    (4)
    PMA_k(Z) = MAB(S_k, Linear_φ(Z)),         (5)
    Decoder(Z) = Linear_φ(SAB^n(PMA_k(Z))),   (6)

where Linear_φ is a linear layer followed by an activation function φ, SAB^n(·) represents n subsequent applications of a SAB, and S_k is a tensor of k learnable seed vectors that are randomly initialised (PMA_k outputs k vectors) (Buterez et al., 2023b). Notably, the original Set Transformer is designed exclusively for learning on sets, uses layer normalisation (LN) in a post-LN fashion (Xiong et al., 2020a), and does not use any form of positional or structural encodings.

Efficient attention – Recently, Flash attention has enabled exact attention with linear memory scaling and faster training and inference (up to 10x) thanks to more efficient utilisation of the architecture of modern graphics processing units (GPUs) (Dao et al., 2022; Dao, 2023). For self-attention, an exact attention implementation has also been developed that scales with the square root of the sequence length in terms of memory and has a run time comparable to standard implementations (Rabe & Staats, 2021).

4. Masked attention for graphs

We formulate graph learning as a learning problem on sets. The main learning mechanism consists of applying attention directly to the node or edge feature matrix by means of SABs. We propose masking the pairwise attention weight matrix according to the node or edge adjacency information in order to incorporate the graph structure. We call this approach Masked Attention for Graphs (MAG).

Inputs – MAG supports two main ways of information propagation: (1) on nodes (MAGN), more specifically on the node feature matrix X, or (2) on edges (MAGE), i.e. on the edge feature matrix E.

Masking – We extend the Set Transformer with masked equivalents of the MAB and SAB (MSAB). Practically, this means supplying the new blocks with a mask tensor of shape B × N_n × N_n for MAGN and B × N_e × N_e for MAGE, where B is the batch size. In a naive implementation, masking simply means replacing the targeted values of the scaled QK^T product with negative infinity (or a very large negative value, for stability) before the softmax (see Appendix F for the equation). The correct mask is different for each batch and must thus be computed on the fly. For MAGN, the mask allows only adjacent nodes; in other words, the N_n × N_n portion of the mask corresponds directly to the node adjacency matrix (Algorithm 1). For MAGE, the mask operates on the set of edges and must allow only edges that share a common node. While this computation is not as trivial as the MAGN mask, both kinds of masks can be computed efficiently using exclusively tensor operations (Algorithm 2).

Algorithm 1: Node masking algorithm, where the inputs follow PyTorch Geometric conventions.

    from torch_geometric.utils import unbatch_edge_index

    function node_mask(b_ei, b_map, B, M):
        # b_ei is the batched edge index
        # b_map maps nodes to graphs
        # B and M are the batch size and the mask size, respectively
        mask ← torch.full(size=(B, M, M), fill_value=False)
        graph_idx ← b_map.index_select(0, b_ei[0, :])
        eis ← unbatch_edge_index(b_ei, b_map)
        ei ← torch.cat(eis, dim=1)
        mask[graph_idx, ei[0, :], ei[1, :]] ← True
        return ~mask

Architecture – At a high level, a complete MAG model takes the form of an encoder in which MSAB and SAB blocks are alternated as desired, followed by a PMA-based decoder (Figure 1). Compared to the Set Transformer, we have also adapted MAG to use a pre-LN architecture with layer or batch normalisation, and to optionally include multi-layer perceptrons (MLPs) after multihead attention. For graph-level tasks, the PMA module of the Set Transformer acts as an equivalent to the readout function in GNNs, but fully based on attention, while for node-level tasks PMA is not required.
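As a concrete illustration of the masking step described above (not the exact MAG implementation, which uses multihead attention with learned query/key/value projections), a single-head masked self-attention can be sketched in PyTorch as follows. The boolean mask follows the convention of Algorithms 1 and 2, i.e. True marks pairs that must not attend to each other:

    import math
    import torch

    def masked_self_attention(x, mask):
        # x: (B, N, d) padded node or edge features
        # mask: (B, N, N) boolean, True where attention is NOT allowed
        d = x.size(-1)
        scores = x @ x.transpose(-2, -1) / math.sqrt(d)  # scaled QK^T (here Q = K = V = x)
        scores = scores.masked_fill(mask, -1e9)          # large negative value for stability
        weights = torch.softmax(scores, dim=-1)
        return weights @ x

In the full model, the same masking is applied inside the multihead attention of each MSAB, while unmasked SABs and the PMA decoder use standard attention.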

Algorithm 2: Edge masking algorithm, where the inputs follow PyTorch Geometric conventions. T is the transpose. Helper functions are explained in Appendix B.

    from mag import consecutive, first_unique_index

    function edge_adjacency(b_ei):
        E ← b_ei.size(1)
        source_nodes ← b_ei[0]
        target_nodes ← b_ei[1]
        # unsqueeze and expand
        exp_src ← source_nodes.unsqueeze(1).expand((-1, E))
        exp_trg ← target_nodes.unsqueeze(1).expand((-1, E))
        src_adj ← exp_src == T(exp_src)
        trg_adj ← exp_trg == T(exp_trg)
        cross ← (exp_src == T(exp_trg)) logical_or (exp_trg == T(exp_src))
        return (src_adj logical_or trg_adj logical_or cross)

    function edge_mask(b_ei, b_map, B, M):
        mask ← torch.full(size=(B, M, M), fill_value=False)
        edge_to_graph ← b_map.index_select(0, b_ei[0, :])
        edge_adj ← edge_adjacency(b_ei)
        ei_to_original ← consecutive(first_unique_index(edge_to_graph), b_ei.size(1))
        edges ← edge_adj.nonzero()
        graph_idx ← edge_to_graph.index_select(0, edges[:, 0])
        coord_1 ← ei_to_original.index_select(0, edges[:, 0])
        coord_2 ← ei_to_original.index_select(0, edges[:, 1])
        mask[graph_idx, coord_1, coord_2] ← True
        return ~mask
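The edge_adjacency step of the pseudocode above can be written directly as runnable PyTorch; the following sketch (with a tiny worked example) is illustrative and may differ from the actual MAG code base:

    import torch

    def edge_adjacency(edge_index: torch.Tensor) -> torch.Tensor:
        # edge_index: (2, E); returns an (E, E) boolean matrix that is True
        # where two edges share at least one endpoint.
        src, trg = edge_index[0], edge_index[1]
        E = edge_index.size(1)
        exp_src = src.unsqueeze(1).expand(-1, E)
        exp_trg = trg.unsqueeze(1).expand(-1, E)
        src_adj = exp_src == exp_src.T
        trg_adj = exp_trg == exp_trg.T
        cross = (exp_src == exp_trg.T) | (exp_trg == exp_src.T)
        return src_adj | trg_adj | cross

    # Path 0-1-2 (edges e0 = (0, 1) and e1 = (1, 2), sharing node 1) plus an
    # isolated edge e2 = (3, 4):
    ei = torch.tensor([[0, 1, 3], [1, 2, 4]])
    print(edge_adjacency(ei))
    # tensor([[ True,  True, False],
    #         [ True,  True, False],
    #         [False, False,  True]])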

Table 1. Test set mean absolute error (MAE) or average precision (AP) for two long-range molecular benchmarks. All results except for MAGE are extracted from Tönshoff et al. (2023). The number of layers for PEPT-STRUCT and PEPT-FUNC, respectively, is given as (·/·).
Dataset GCN (6/6) GIN (10/8) GraphGPS (8/6) Exphormer (4/8) MAGE (3/4)
PEPT-STRUCT (MAE ↓) 0.2460 ± 0.0007 0.2473 ± 0.0017 0.2509 ± 0.0014 0.2481 ± 0.0007 0.2453 ± 0.0003
PEPT-FUNC (AP ↑) 0.6860 ± 0.0050 0.6621 ± 0.0067 0.6534 ± 0.0091 0.6527 ± 0.0043 0.6863 ± 0.0044

Table 2. Test set Matthews correlation coefficient (MCC) for 3 node-level classification tasks, presented as mean ± standard deviation
from 5 different runs. The highest mean values are highlighted in bold.
Dataset GCN GAT GATv2 GIN PNA MAGN
PPI ↑ 0.47 ± 0.02 0.82 ± 0.03 0.72 ± 0.39 0.35 ± 0.02 0.83 ± 0.01 0.99 ± 0.00
CITESEER ↑ 0.30 ± 0.04 0.19 ± 0.03 0.21 ± 0.05 0.23 ± 0.03 0.01 ± 0.02 0.54 ± 0.02
CORA ↑ 0.48 ± 0.06 0.37 ± 0.01 0.43 ± 0.04 0.39 ± 0.02 0.04 ± 0.05 0.70 ± 0.02

5. Experiments

Our evaluation encompasses: (1) an extensive suite of over 55 benchmarks from various domains, where the goal is to compare MAG with representative message passing networks (GCN, GAT, GATv2, GIN, and PNA) and Transformers for graphs (Graphormer, TokenGT). Due to space limitations, we cannot present the results for all methods in the main part of the paper, so we select the most representative ones based on the task, with the rest presented in the Appendix. Graphormer and TokenGT are applicable only to a subset of tasks due to limitations discussed in Related Work. They also fail for many datasets due to very high memory requirements (CPU and/or GPU). We cover long-range tasks (Section 5.1), node-level tasks (Section 5.2), and graph-level tasks (Section 5.3). (2) A transfer learning performance evaluation (Section 5.4) based on a new, refined variant of QM9 and following the recent framework proposed by Buterez et al. (2024). (3) The time and memory characteristics of all discussed methods (Section 5.5).

5.1. Long-range tasks

Graph learning with Transformers has traditionally been evaluated on long-range graph benchmarks (LRGB) (Dwivedi et al., 2022). However, it was recently shown that simple GNNs outperform most attention-based methods (Tönshoff et al., 2023). Nonetheless, we evaluated MAGE on two long-range molecular tasks (Table 1) and conclude that it outperforms GraphGPS, Exphormer, GCN, and GIN. Despite using half the number of layers of the other methods or fewer, MAGE matches the 2nd model on the PEPT-STRUCT leaderboard and is within the top 5 for PEPT-FUNC (as of January 2024). Remarkably, MAG is the only top method that is exclusively based on attention (i.e. no message passing or hybrid), does not use any positional or similar encoding, and is general purpose (i.e. not specifically built for molecules). We also did not use hyperparameter optimisation or sophisticated schedulers (Appendix C).

5.2. Node-level tasks

Generally, node-level tasks take the form of citation networks of different sizes, and are not as varied as graph-level problems. Nonetheless, they represent an interesting case for MAG as PMA is not needed. Here, we selected 3 representative datasets: PPI, CITESEER, and CORA. In particular, MAGN is the most natural choice as it works over node representations. Our results indicate that MAGN is the best performing method by a large margin (Table 2). Graphormer and TokenGT are not available for node-level classification, and they would not work due to the large graph sizes.

5.3. Graph-level tasks

Graph-level tasks are generally more varied, as they often originate from different domains and require a readout function for GNNs and a PMA module for MAG. Here, we do not focus on the differences between readouts and instead choose a reasonably strong baseline (4-layer GNNs and mean readout; see Appendix C for the experimental setup). A more granular evaluation of readouts has been done by Buterez et al. (2022), concluding that the difference between readouts is small for most tasks.

QM9 is a quantum mechanics dataset consisting of 133,885 small organic molecules and 19 regression targets given by quantum properties (Ramakrishnan et al., 2014). We report results for all 19 targets in Table 3, with GCN and TokenGT separately in Appendix D due to limited space. We observe that for 15 out of 19 properties, MAGE is the best performing method. For the dipole moment (µ), PNA is better by a slight margin, while for HOMO, LUMO, and the HOMO-LUMO gap (∆ϵ), Graphormer is stronger. We expect attention-based methods to be the best suited for intensive and localised properties like the HOMO and LUMO, which agrees with recent literature on this topic (Buterez et al., 2023b). The fact that MAGE is not as competitive on these tasks could be explained by MAGE converging more quickly than other methods and getting stuck in local minima, especially for HOMO/LUMO, which take a long time to converge for most methods. We report the results of altering the number and order of attention blocks in Appendix I.

DOCKSTRING (García-Ortegón et al., 2022) is a recent drug discovery data collection consisting of molecular docking scores for 260,155 small molecules and 5 high-quality targets from different protein families that were selected as a regression benchmark, with different levels of difficulty: PARP1 (enzyme, easy), F2 (protease, easy to medium), KIT (kinase, medium), ESR2 (nuclear receptor, hard), and PGR (nuclear receptor, hard). The tasks are expected to be challenging, as the docking score depends on the 3D structure of the ligand–target complex. Furthermore, a cluster split into train (221,274 molecules) and test (38,881 molecules) sets is provided, ensuring a more difficult and meaningful benchmark. We have selected a random subset of 19,993 molecules from the original train set as a validation set (the test set is not modified). We report results for the 5 targets in Table 3 and observe that MAGE is the strongest method for 4 of the tasks. Remarkably, MAGE matches or outperforms the strongest methods in the original manuscript (Attentive FP, a GNN based on attention; Xiong et al., 2020b) despite using 20,000 fewer training molecules. Moreover, while the other methods are competitive for the 4 easy or medium difficulty tasks, MAGE outperforms the others by a large margin for the most difficult target (PGR).

We further extend our evaluation with a collection of datasets that covers multiple domains such as bioinformatics, physical chemistry, computer vision, social networks, function call graphs, and synthetic graphs (Tables 3 to 5). We observe that MAGE and MAGN are competitive with the other methods and in most cases better.

5.4. Transfer learning

We follow the recipe recently outlined for drug discovery and quantum mechanics by Buterez et al. (2024). Here, we leverage a recently-published refined version of the QM9 HOMO and LUMO energies (Fediai et al., 2023), which provides alternative DFT calculations based on the correlation-consistent basis set aug-cc-DZVP and the PBE functional, as well as calculations at the (more accurate) eigenvalue-self-consistent GW level of theory. The transfer learning setup consists of randomly selected training, validation, and test sets of 25K, 5K, and respectively 10K molecules with GW calculations (from the total of 133,885). As outlined by Buterez et al. (2024), transfer learning can be performed in transductive or inductive setups. In the transductive case, test set DFT-level measurements are used for pre-training, while in the inductive setting they are not. Here, we perform transfer learning by pre-training a model on the DFT target for a fixed number of epochs and then fine-tuning it on the subset of 25K GW calculations. In the transductive case, pre-training occurs on the full set of 133K DFT calculations, while in the inductive case the DFT test set values are removed (note that the evaluation is done on the test set GW values). The results (Table 6) indicate that MAGE improves thanks to transfer learning from DFT by 45% (HOMO) and 53% (LUMO) in the challenging inductive case, with 10 to 20-fold improvements in the transductive case, while GNNs improve only by a modest amount.
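The protocol above (pre-train on the DFT-level target, then fine-tune the same weights on the 25K GW subset) can be sketched as follows; the model, data loaders, loss, and epoch counts are placeholders rather than the exact configuration used in the paper:

    import torch

    def pretrain_then_finetune(model, dft_loader, gw_loader,
                               pretrain_epochs=100, finetune_epochs=100):
        # Stage 1: pre-train on the (larger) DFT-labelled set.
        # Stage 2: fine-tune the same weights on the GW-labelled subset.
        loss_fn = torch.nn.MSELoss()
        for loader, epochs in [(dft_loader, pretrain_epochs), (gw_loader, finetune_epochs)]:
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
            for _ in range(epochs):
                for batch, target in loader:
                    optimizer.zero_grad()
                    loss = loss_fn(model(batch), target)
                    loss.backward()
                    optimizer.step()
        return model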
5.5. Time and memory utilisation

In MAG, the most computation-intensive component is the encoder. All encoder blocks perform (masked) self attention, which can be efficiently implemented with O(√N) memory complexity (with N = N_n for MAGN and N = N_e for MAGE). The time complexity is the same as for standard attention. The decoder PMA encodes cross-attention between the full set outputted by the encoder and a set of k learnable vectors, benefitting from Flash attention (linear memory and time scaling). A decoder with a single PMA block and k = 1, as used here, is even more efficient. The mask tensor requires B × N² memory (see Appendix A for a solution). However, it does not require gradients, and MAGE runs with up to N_e ≈ 30,000 edges on a consumer GPU with 24 GB of memory. Indeed, we report competitive time and memory utilisation (Table 7).
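For reference, a hedged sketch of invoking memory-efficient exact attention with a per-graph mask through PyTorch 2's scaled_dot_product_attention is shown below; the shapes are arbitrary and whether a fused kernel is actually selected depends on the PyTorch version and backend:

    import torch
    import torch.nn.functional as F

    B, H, N, d = 4, 8, 256, 64
    q = k = v = torch.randn(B, H, N, d)
    # Boolean mask broadcast over heads. Note that for scaled_dot_product_attention
    # True means "allowed to attend", the opposite of the convention in Algorithms 1 and 2.
    allowed = torch.rand(B, 1, N, N) > 0.5
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)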

Table 3. Test set root mean squared error (RMSE, standard for quantum mechanics) for QM9, and R² for the rest of the tasks, presented as mean ± standard deviation from 5 different runs. The lowest (QM9) and highest (rest) mean values are highlighted in bold. Only one of Graphormer/TokenGT was chosen for spacing reasons, based on competitiveness and lack of out-of-memory errors (OOM). Any results not displayed here (e.g. GCN, TokenGT) are available in Appendix D.
Property GAT GATv2 GIN PNA Graphormer MAGE
QM9:
µ ↓ 0.61 ± 0.00 0.61 ± 0.01 0.61 ± 0.00 0.57 ± 0.00 0.63 ± 0.00 0.61 ± 0.02
α ↓ 2.66 ± 0.25 1.86 ± 0.28 1.18 ± 0.10 1.00 ± 0.02 0.57 ± 0.16 0.48 ± 0.01
ϵHOMO ↓ 0.12 ± 0.00 0.12 ± 0.00 0.12 ± 0.00 0.11 ± 0.00 0.10 ± 0.00 0.11 ± 0.00
ϵLUMO ↓ 0.14 ± 0.00 0.13 ± 0.00 0.13 ± 0.00 0.11 ± 0.00 0.11 ± 0.00 0.13 ± 0.00
∆ϵ ↓ 0.18 ± 0.00 0.17 ± 0.01 0.19 ± 0.00 0.16 ± 0.00 0.16 ± 0.01 0.19 ± 0.00
⟨R2 ⟩ ↓ 53.03 ± 4.06 47.81 ± 3.50 36.14 ± 0.18 35.17 ± 0.35 31.65 ± 1.79 28.57 ± 0.26
ZPVE ↓ 0.24 ± 0.01 0.14 ± 0.02 0.08 ± 0.01 0.06 ± 0.00 0.06 ± 0.02 0.03 ± 0.00
U0 ↓ 609.25 ± 68.91 329.56 ± 30.08 143.71 ± 15.27 100.49 ± 5.43 31.06 ± 11.45 9.93 ± 2.90
U ↓ 579.75 ± 78.19 319.37 ± 11.63 160.75 ± 46.21 100.41 ± 4.27 28.30 ± 18.81 10.05 ± 3.18
H ↓ 558.89 ± 71.13 310.95 ± 20.79 166.10 ± 41.69 102.85 ± 4.13 32.18 ± 11.43 11.05 ± 6.37
G ↓ 580.03 ± 68.41 320.53 ± 36.14 166.66 ± 54.19 102.43 ± 3.36 19.90 ± 11.13 10.69 ± 6.34
cV ↓ 0.90 ± 0.16 0.64 ± 0.02 0.53 ± 0.04 0.39 ± 0.01 0.47 ± 0.08 0.17 ± 0.00
U0ATOM ↓ 3.96 ± 0.61 1.77 ± 0.22 1.19 ± 0.06 0.96 ± 0.03 0.45 ± 0.05 0.24 ± 0.01
U ATOM ↓ 4.53 ± 0.33 1.92 ± 0.25 1.19 ± 0.05 0.97 ± 0.02 0.47 ± 0.09 0.24 ± 0.01
H ATOM ↓ 3.78 ± 0.70 1.81 ± 0.24 1.15 ± 0.03 0.96 ± 0.03 0.42 ± 0.03 0.25 ± 0.00
GATOM ↓ 3.50 ± 0.67 1.73 ± 0.11 1.07 ± 0.04 0.85 ± 0.02 0.38 ± 0.03 0.22 ± 0.02
A ↓ 13.45 ± 4.20 5.70 ± 5.79 3.83 ± 4.08 27.99 ± 38.02 1.61 ± 0.19 0.75 ± 0.11
B ↓ 0.23 ± 0.01 0.25 ± 0.03 0.23 ± 0.02 0.26 ± 0.02 0.12 ± 0.03 0.08 ± 0.01
C ↓ 0.18 ± 0.03 0.22 ± 0.03 0.22 ± 0.07 0.21 ± 0.02 0.12 ± 0.02 0.05 ± 0.01
MoleculeNet:
FREESOLV ↑ 0.95 ± 0.01 0.94 ± 0.03 0.72 ± 0.43 0.39 ± 0.51 0.92 ± 0.01 0.96 ± 0.00
LIPO ↑ 0.78 ± 0.01 0.78 ± 0.01 0.78 ± 0.01 0.80 ± 0.01 OOM 0.71 ± 0.01
ESOL ↑ 0.86 ± 0.01 0.85 ± 0.01 0.89 ± 0.01 0.88 ± 0.01 0.91 ± 0.01 0.93 ± 0.01
DOCKSTRING:
Dataset GAT GATv2 GIN PNA TokenGT MAGE
ESR2 ↑ 0.57 ± 0.01 0.58 ± 0.01 0.59 ± 0.01 0.61 ± 0.00 0.48 ± 0.01 0.63 ± 0.00
F2 ↑ 0.79 ± 0.02 0.83 ± 0.01 0.85 ± 0.00 0.85 ± 0.00 0.77 ± 0.00 0.88 ± 0.00
KIT ↑ 0.80 ± 0.00 0.80 ± 0.00 0.80 ± 0.01 0.82 ± 0.00 0.66 ± 0.02 0.80 ± 0.00
PARP1 ↑ 0.79 ± 0.04 0.86 ± 0.01 0.88 ± 0.00 0.88 ± 0.00 0.79 ± 0.01 0.91 ± 0.00
PGR ↑ 0.50 ± 0.02 0.50 ± 0.03 0.56 ± 0.02 0.50 ± 0.06 0.50 ± 0.00 0.68 ± 0.00

Table 4. Test set Matthews correlation coefficient (MCC) for 3 graph-level classification tasks from MoleculeNet, presented as mean ±
standard deviation from 5 different runs. The highest mean values are highlighted in bold. GCN and TokenGT are available in Appendix D.
Dataset GAT GATv2 GIN PNA Graphormer MAGE
BBBP ↑ 0.71 ± 0.04 0.72 ± 0.04 0.71 ± 0.02 0.72 ± 0.03 OOM 0.76 ± 0.05
BACE ↑ 0.59 ± 0.02 0.60 ± 0.01 0.56 ± 0.05 0.45 ± 0.25 0.09 ± 0.07 0.65 ± 0.02
HIV ↑ 0.38 ± 0.02 0.38 ± 0.07 0.35 ± 0.05 0.38 ± 0.05 OOM 0.43 ± 0.02

6. Discussion

We presented an end-to-end approach that leverages attention in a novel way for learning on graphs and demonstrated its effectiveness relative to message passing and more involved attention-based methods. Our approach is end-to-end in the sense that it replaces both message passing and readout/pooling functions with attention mechanisms. The former is facilitated by modularly stacking self attention and masked attention blocks (with node or edge-based masking); in the future, node- and edge-masked blocks might also be interspersed. Both node and edge masked attention mechanisms can be implemented using a few lines of code (Algorithms 1 and 2).

Despite its simplicity, MAG consistently outperforms message passing baselines and more involved Transformer-based methods, supports transfer learning through pre-training and fine-tuning (which does not work well with classical GNNs), and scales favourably in terms of time and memory. Still, several software limitations impede a more efficient MAG implementation (see Appendix A). Compared to Transformers, we remarked several distinguishing features. MAG performs as well as presented here without any sophisticated learning rate scheduler (warm-up, cosine annealing, etc.) or optimiser. Interestingly, for about half the datasets MAG performed better with batch normalisation instead of layer normalisation (more in Appendix G). Also, MAG does not use any form of positional encoding, which is against current trends for GNNs and Transformers. Future research might investigate ways to combine node and edge processing, devise efficient attention implementations for graphs, study the importance of positional encodings, integrate successful concepts from other domains such as sparse expert models (Fedus et al., 2022), or study multi-modality.

Table 5. Test set Matthews correlation coefficient (MCC) for graph-level classification tasks from various domains, presented as mean ±
standard deviation from 5 runs. All models use MAGE, except the last 6 (MAGN). The highest mean values are highlighted in bold.
Dataset GCN GAT GATv2 GIN PNA MAG
MALNET TINY ↑ 0.85 ± 0.01 0.87 ± 0.01 0.88 ± 0.01 0.90 ± 0.01 0.89 ± 0.01 0.91 ± 0.01
CV:
MNIST ↑ 0.85 ± 0.00 0.97 ± 0.00 0.97 ± 0.00 0.88 ± 0.01 0.98 ± 0.00 0.97 ± 0.00
CIFAR10 ↑ 0.46 ± 0.00 0.63 ± 0.01 0.63 ± 0.01 0.44 ± 0.02 0.65 ± 0.02 0.63 ± 0.01
BIOINF.:
ENZYMES ↑ 0.61 ± 0.03 0.70 ± 0.04 0.72 ± 0.03 0.67 ± 0.04 0.73 ± 0.02 0.71 ± 0.03
PROTEINS ↑ 0.21 ± 0.10 0.41 ± 0.04 0.47 ± 0.05 0.48 ± 0.08 0.44 ± 0.10 0.55 ± 0.06
DD ↑ 0.49 ± 0.02 0.46 ± 0.05 0.43 ± 0.04 0.39 ± 0.07 0.40 ± 0.10 0.57 ± 0.04
SYNTHETIC:
SYNTHETIC ↑ 1.00 ± 0.00 1.00 ± 0.00 1.00 ± 0.00 0.99 ± 0.03 1.00 ± 0.00 1.00 ± 0.00
SYNTHETIC N. ↑ −0.01 ± 0.05 0.56 ± 0.12 0.79 ± 0.11 0.47 ± 0.14 0.91 ± 0.03 0.95 ± 0.03
SYNTHIE ↑ 0.87 ± 0.04 0.26 ± 0.07 0.30 ± 0.07 0.60 ± 0.07 0.84 ± 0.08 0.92 ± 0.05
TRIANGLES ↑ 0.17 ± 0.01 0.15 ± 0.02 0.14 ± 0.02 0.15 ± 0.01 0.07 ± 0.02 0.22 ± 0.05
COLORS-3 ↑ 0.30 ± 0.01 0.22 ± 0.02 0.25 ± 0.01 0.29 ± 0.02 0.38 ± 0.01 0.75 ± 0.01
SOCIAL:
IMDB-BINARY ↑ 0.59 ± 0.04 0.52 ± 0.05 0.51 ± 0.05 0.54 ± 0.05 0.50 ± 0.05 0.62 ± 0.06
IMDB-MULTI ↑ 0.18 ± 0.03 0.20 ± 0.02 0.17 ± 0.01 0.19 ± 0.02 0.20 ± 0.02 0.21 ± 0.03
REDDIT-BINARY ↑ 0.54 ± 0.02 0.53 ± 0.05 0.35 ± 0.20 0.49 ± 0.04 0.53 ± 0.07 0.61 ± 0.04
REDDIT-M-5K ↑ 0.34 ± 0.01 0.31 ± 0.01 0.29 ± 0.02 0.31 ± 0.01 0.34 ± 0.02 0.35 ± 0.01
REDDIT-M-12K ↑ 0.34 ± 0.01 0.31 ± 0.01 0.26 ± 0.02 0.33 ± 0.00 0.35 ± 0.01 0.37 ± 0.03
TWITCH EGOS ↑ 0.37 ± 0.00 0.38 ± 0.00 0.38 ± 0.00 0.38 ± 0.00 0.39 ± 0.00 0.39 ± 0.00
REDDIT THR. ↑ 0.56 ± 0.00 0.57 ± 0.00 0.57 ± 0.00 0.57 ± 0.00 0.57 ± 0.00 0.57 ± 0.00
GITHUB STAR. ↑ 0.27 ± 0.01 0.30 ± 0.01 0.21 ± 0.11 0.28 ± 0.03 0.36 ± 0.02 0.31 ± 0.02

Table 6. Transfer learning performance (RMSE) on QM 9 for HOMO and LUMO, presented as mean ± standard deviation from 5 different
runs on test sets. The lowest mean values are highlighted in bold.

Task Strategy GCN GAT GATv2 GIN PNA MAGE


GW 0.23 ± 0.004 0.21 ± 0.002 0.21 ± 0.002 0.21 ± 0.001 0.19 ± 0.001 0.14 ± 0.001
HOMO ↓ Trans. 0.21 ± 0.001 0.16 ± 0.000 0.15 ± 0.000 0.15 ± 0.000 0.14 ± 0.000 0.01 ± 0.000
Ind. 0.21 ± 0.001 0.18 ± 0.000 0.18 ± 0.000 0.19 ± 0.000 0.17 ± 0.000 0.09 ± 0.000
GW 0.20 ± 0.001 0.19 ± 0.002 0.19 ± 0.002 0.19 ± 0.001 0.18 ± 0.003 0.12 ± 0.001
LUMO ↓ Trans. 0.21 ± 0.001 0.16 ± 0.000 0.17 ± 0.000 0.17 ± 0.000 0.16 ± 0.000 0.01 ± 0.000
Ind. 0.20 ± 0.000 0.17 ± 0.000 0.17 ± 0.000 0.17 ± 0.001 0.17 ± 0.001 0.08 ± 0.000

Table 7. Average training time per epoch (s) and used memory (GB) for all methods, presented as mean ± std. from 5 epochs and with the
number of parameters (#). For QM 9 (103,542 train items), the maximum number of nodes/edges per graph is 29 (MAGN), respectively 56
(MAGE). DD (942 train items) has 5,748 maximum nodes and 28,534 maximum edges. All algorithms use bfloat16 mixed training. (*)
Parameters for masked blocks are counted as for normal self-attention, despite a very large number of connections being dropped.
Method # QM 9 DD
Time (s) Memory (GB) Time (s) Memory (GB)
GCN 158K 13.74 ± 0.47 0.10 ± 0.00 2.20 ± 0.11 0.32 ± 0.02
GAT 10.1M 25.60 ± 0.35 1.14 ± 0.07 4.05 ± 0.08 3.61 ± 0.55
GATv2 20.1M 28.82 ± 0.65 1.36 ± 0.09 5.31 ± 0.04 5.51 ± 0.67
GIN 433K 14.95 ± 1.24 0.12 ± 0.00 2.13 ± 0.03 0.28 ± 0.06
PNA 6.9M 66.62 ± 1.33 2.53 ± 0.20 12.05 ± 0.05 6.77 ± 0.97
Graphormer 22.3M 186.60 ± 4.50 1.95 ± 0.00 OOM
TokenGT 8.1M 30.76 ± 0.26 1.21 ± 0.00 OOM
MAGN (naive) 8.3M* 16.46 ± 0.71 0.31 ± 0.00 OOM
MAGN (mem-efficient) 8.3M* 15.38 ± 0.70 0.27 ± 0.00 13.53 ± 0.03 1.19 ± 0.00
MAGE (naive) 8.6M* 26.11 ± 0.31 0.61 ± 0.00 OOM
MAGE (mem-efficient) 8.6M* 22.58 ± 0.49 0.44 ± 0.00 348.22 ± 1.32 21.90 ± 1.19


References

Alon, U. and Yahav, E. On the bottleneck of graph neural networks and its practical implications. In International Conference on Learning Representations, 2021.

Brody, S., Alon, U., and Yahav, E. How attentive are graph attention networks? In International Conference on Learning Representations, 2022.

Buterez, D., Janet, J. P., Kiddle, S. J., Oglic, D., and Liò, P. Graph neural networks with adaptive readouts. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022.

Buterez, D., Janet, J. P., Kiddle, S. J., and Liò, P. MF-PCBA: Multifidelity high-throughput screening benchmarks for drug discovery and machine learning. Journal of Chemical Information and Modeling, 63(9):2667–2678, 2023a. doi: 10.1021/acs.jcim.2c01569. PMID: 37058588.

Buterez, D., Janet, J. P., Kiddle, S. J., Oglic, D., and Liò, P. Modelling local and general quantum mechanical properties with attention-based pooling. Communications Chemistry, 6(1):262, Nov 2023b. ISSN 2399-3669. doi: 10.1038/s42004-023-01045-7.

Buterez, D., Janet, J. P., Kiddle, S., Oglic, D., and Liò, P. Transfer learning with graph neural networks for improved molecular property prediction in the multi-fidelity setting. ChemRxiv, 2024. doi: 10.26434/chemrxiv-2022-dsbm5-v3.

Cai, T., Luo, S., Xu, K., He, D., Liu, T.-Y., and Wang, L. GraphNorm: A principled approach to accelerating graph neural network training. In International Conference on Machine Learning, 2021.

Chen, C., Zuo, Y., Ye, W., Li, X., and Ong, S. P. Learning properties of ordered and disordered materials from multi-fidelity data. Nature Computational Science, 1(1):46–53, 2021. ISSN 2662-8457.

Choromanski, K. M., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J. Q., Mohiuddin, A., Kaiser, L., Belanger, D. B., Colwell, L. J., and Weller, A. Rethinking attention with Performers. In International Conference on Learning Representations, 2021.

Corso, G., Cavalleri, L., Beaini, D., Liò, P., and Veličković, P. Principal neighbourhood aggregation for graph nets. In Proceedings of the 34th International Conference on Neural Information Processing Systems, NIPS'20, Red Hook, NY, USA, 2020. Curran Associates Inc. ISBN 9781713829546.

Dao, T. FlashAttention-2: Faster attention with better parallelism and work partitioning, 2023.

Dao, T., Fu, D. Y., Ermon, S., Rudra, A., and Ré, C. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing Systems, 2022.

Duvenaud, D., Maclaurin, D., Aguilera-Iparraguirre, J., Gómez-Bombarelli, R., Hirzel, T., Aspuru-Guzik, A., and Adams, R. P. Convolutional networks on graphs for learning molecular fingerprints. In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 2, NIPS'15, pp. 2224–2232, Cambridge, MA, USA, 2015. MIT Press.

Dwivedi, V. P., Rampášek, L., Galkin, M., Parviz, A., Wolf, G., Luu, A. T., and Beaini, D. Long range graph benchmark. In Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track, 2022.

Fediai, A., Reiser, P., Peña, J. E. O., Friederich, P., and Wenzel, W. Accurate GW frontier orbital energies of 134 kilo molecules. Scientific Data, 10(1):581, Sep 2023. ISSN 2052-4463. doi: 10.1038/s41597-023-02486-4.

Fedus, W., Dean, J., and Zoph, B. A review of sparse expert models in deep learning, 2022.

García-Ortegón, M., Simm, G. N. C., Tripp, A. J., Hernández-Lobato, J. M., Bender, A., and Bacallado, S. DOCKSTRING: Easy molecular docking yields better benchmarks for ligand design. Journal of Chemical Information and Modeling, 62(15):3486–3502, 2022. doi: 10.1021/acs.jcim.1c01334. PMID: 35849793.

Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. Neural message passing for quantum chemistry. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, ICML'17, pp. 1263–1272. JMLR.org, 2017.

Godwin, J., Schaarschmidt, M., Gaunt, A. L., Sanchez-Gonzalez, A., Rubanova, Y., Veličković, P., Kirkpatrick, J., and Battaglia, P. Simple GNN regularisation for 3D molecular property prediction and beyond. In International Conference on Learning Representations, 2022.

Hu, W., Liu, B., Gomes, J., Zitnik, M., Liang, P., Pande, V., and Leskovec, J. Strategies for pre-training graph neural networks. In International Conference on Learning Representations, 2020.

Jain, P., Wu, Z., Wright, M. A., Mirhoseini, A., Gonzalez, J. E., and Stoica, I. Representing long-range context for graph neural networks with global attention. In Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, 2021.

Kearnes, S., McCloskey, K., Berndl, M., Pande, V., and Riley, P. Molecular graph convolutions: moving beyond fingerprints. Journal of Computer-Aided Molecular Design, 30(8):595–608, Aug 2016. ISSN 1573-4951. doi: 10.1007/s10822-016-9938-8.

Kim, J., Nguyen, D. T., Min, S., Cho, S., Lee, M., Lee, H., and Hong, S. Pure transformers are powerful graph learners. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022.

Kreuzer, D., Beaini, D., Hamilton, W. L., Létourneau, V., and Tossou, P. Rethinking graph transformers with spectral attention. In Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, 2021.

Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., and Teh, Y. W. Set Transformer: A framework for attention-based permutation-invariant neural networks. In Proceedings of the 36th International Conference on Machine Learning, pp. 3744–3753, 2019.

Lefaudeux, B., Massa, F., Liskovich, D., Xiong, W., Caggiano, V., Naren, S., Xu, M., Hu, J., Tintore, M., Zhang, S., Labatut, P., and Haziza, D. xFormers: A modular and hackable transformer modelling library. https://github.com/facebookresearch/xformers, 2022.

Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.

Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Kopf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., Bai, J., and Chintala, S. PyTorch: An imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, pp. 8024–8035. Curran Associates, Inc., 2019.

Rabe, M. N. and Staats, C. Self-attention does not need O(n²) memory. CoRR, abs/2112.05682, 2021.

Ramakrishnan, R., Dral, P. O., Rupp, M., and von Lilienfeld, O. A. Quantum chemistry structures and properties of 134 kilo molecules. Scientific Data, 1(1):140022, Aug 2014. ISSN 2052-4463. doi: 10.1038/sdata.2014.22.

Rampášek, L., Galkin, M., Dwivedi, V. P., Luu, A. T., Wolf, G., and Beaini, D. Recipe for a general, powerful, scalable graph transformer. In Advances in Neural Information Processing Systems, volume 35, 2022.

Shazeer, N. GLU variants improve transformer. CoRR, abs/2002.05202, 2020.

Shirzad, H., Velingker, A., Venkatachalam, B., Sutherland, D. J., and Sinop, A. K. Exphormer: Scaling graph transformers with expander graphs, 2023.

Tönshoff, J., Ritzert, M., Rosenbluth, E., and Grohe, M. Where did the gap go? Reassessing the long-range graph benchmark. In The Second Learning on Graphs Conference, 2023.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.

Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., and Bengio, Y. Graph attention networks. In International Conference on Learning Representations, 2018.

Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., Cistac, P., Rault, T., Louf, R., Funtowicz, M., Davison, J., Shleifer, S., von Platen, P., Ma, C., Jernite, Y., Plu, J., Xu, C., Le Scao, T., Gugger, S., Drame, M., Lhoest, Q., and Rush, A. Transformers: State-of-the-art natural language processing. In Liu, Q. and Schlangen, D. (eds.), Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pp. 38–45, Online, October 2020. Association for Computational Linguistics. doi: 10.18653/v1/2020.emnlp-demos.6.

Xiong, R., Yang, Y., He, D., Zheng, K., Zheng, S., Xing, C., Zhang, H., Lan, Y., Wang, L., and Liu, T.-Y. On layer normalization in the transformer architecture. In Proceedings of the 37th International Conference on Machine Learning, ICML'20. JMLR.org, 2020a.

Xiong, Z., Wang, D., Liu, X., Zhong, F., Wan, X., Li, X., Li, Z., Luo, X., Chen, K., Jiang, H., and Zheng, M. Pushing the boundaries of molecular representation for drug discovery with the graph attention mechanism. Journal of Medicinal Chemistry, 63(16):8749–8760, 2020b. doi: 10.1021/acs.jmedchem.9b00959. PMID: 31408336.

Ying, C., Cai, T., Luo, S., Zheng, S., Ke, G., He, D., Shen, Y., and Liu, T.-Y. Do transformers really perform badly for graph representation? In Thirty-Fifth Conference on Neural Information Processing Systems, 2021.

Zaheer, M., Guruganesh, G., Dubey, K. A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., and Ahmed, A. Big Bird: Transformers for longer sequences. In Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., and Lin, H. (eds.), Advances in Neural Information Processing Systems, volume 33, pp. 17283–17297. Curran Associates, Inc., 2020.

Zhao, L. and Akoglu, L. PairNorm: Tackling oversmoothing in GNNs. In International Conference on Learning Representations, 2020.

A. Limitations

In terms of limitations, we highlight that the available libraries are not optimised for masking or custom attention patterns. This is most evident for very dense graphs (tens of thousands of edges or more). Memory-efficient and Flash attention are available natively in PyTorch (Paszke et al., 2019) starting from version 2, as well as in the xFormers library (Lefaudeux et al., 2022). More specifically, we have tested at least 5 different implementations of MAG: (1) leveraging the MultiheadAttention module from PyTorch, (2) leveraging the MultiHeadDispatch module from xFormers, (3) a manual implementation of multihead attention relying on PyTorch's scaled_dot_product_attention function, (4) a manual implementation of multihead attention relying on xFormers' memory_efficient_attention, and (5) a naive implementation. Options (1)–(4) can all make use of efficient and fast implementations. However, we have observed performance differences between the 4 implementations, as well as compared to a naive implementation. This behaviour is likely due to the different low-level kernel implementations. Moreover, Flash attention does not support custom attention masks, as there is little interest in such functionality from a language modelling perspective.

Although the masks can be computed efficiently during training, all frameworks require the last two dimensions of the input mask tensor to be of shape (N_n, N_n) for nodes or (N_e, N_e) for edges, effectively squaring the number of nodes or edges. However, the mask tensors are very sparse, and a sparse tensor alternative could greatly reduce the memory consumption for large and dense graphs. Such an option exists for the PyTorch native attention, but it is currently broken.

Another possible optimisation would be to use nested (ragged) tensors to represent graphs, since padding is currently necessary to ensure identical dimensions for attention. A prototype nested tensor attention is available in PyTorch; however, not all the required operations are supported, and converting between normal and nested tensors is slow.

For all implementations, it is required that the mask tensor is repeated by the number of attention heads (e.g. 8 or 16). However, a notable bottleneck is encountered for the MultiheadAttention and MultiHeadDispatch variants described above, which require that the repeats happen in the batch dimension, i.e. they require 3D mask tensors of shape (B × H, N, N), where H is the number of heads. The other two efficient implementations require a 4D mask instead, i.e. (B, H, N, N), where one can use PyTorch's expand function instead of repeat. The expand alternative does not use any additional memory, while repeat requires ×H memory. Note that it is not possible to reshape the 4D tensor created using expand without using additional memory.
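A small sketch of the difference between the two mask layouts discussed above (sizes are arbitrary):

    import torch

    B, H, N = 32, 8, 128
    mask = torch.rand(B, N, N) > 0.5                # one boolean mask per graph in the batch

    # 4D layout (B, H, N, N): expand returns a view, so no additional memory is used
    mask_4d = mask.unsqueeze(1).expand(B, H, N, N)

    # 3D layout (B * H, N, N): the per-head copies must be materialised
    # (roughly H times the memory of the original mask)
    mask_3d = mask.repeat_interleave(H, dim=0)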
B. Helper functions

consecutive is a helper function that generates consecutive numbers starting from 0, with lengths specified by its tensor argument as the differences between adjacent elements, and a second integer argument used for the last length computation, e.g. consecutive([1, 4, 6], 10) = [0, 1, 2, 0, 1, 0, 1, 2, 3]. first_unique_index finds the first occurrence of each unique element in the tensor (sorted), e.g. first_unique_index([3, 2, 3, 4, 2]) = [1, 0, 3].
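The two helpers are only described by example above; one possible PyTorch implementation consistent with those examples (not necessarily the one used in the MAG code base, and requiring a recent PyTorch for scatter_reduce) is:

    import torch

    def first_unique_index(x: torch.Tensor) -> torch.Tensor:
        # Position of the first occurrence of each unique (sorted) value of x.
        unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
        positions = torch.arange(x.numel(), device=x.device)
        first = torch.full((unique.numel(),), x.numel(), dtype=torch.long, device=x.device)
        return first.scatter_reduce(0, inverse, positions, reduce="amin")

    def consecutive(boundaries: torch.Tensor, total: int) -> torch.Tensor:
        # Concatenated ranges [0, ..., len_i - 1], where len_i is the difference between
        # adjacent boundary values and the last length is total - boundaries[-1].
        lengths = torch.diff(boundaries, append=torch.tensor([total], device=boundaries.device))
        seg_starts = torch.cat([torch.zeros(1, dtype=torch.long, device=boundaries.device),
                                lengths.cumsum(0)[:-1]])
        idx = torch.arange(int(lengths.sum()), device=boundaries.device)
        return idx - torch.repeat_interleave(seg_starts, lengths)

    print(consecutive(torch.tensor([1, 4, 6]), 10))           # tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
    print(first_unique_index(torch.tensor([3, 2, 3, 4, 2])))  # tensor([1, 0, 3])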
C. Experimental setup

We follow a simple and universal experimental protocol to ensure that it is possible to compare the results of different methods and to evaluate a large number of datasets with high throughput. We chose a number of reasonable hyperparameters and settings for all methods, regardless of their nature (GNN or attention-based). This includes the AdamW optimiser (Loshchilov & Hutter, 2019), learning rate (0.0001), batch size (128), 32-bit training (without mixed precision), early stopping with a patience of 30 epochs (100 for very small datasets such as FREESOLV), and gradient clipping (set to the default value of 0.5). Furthermore, we used a simple learning rate scheduler that halved the learning rate if no improvement was encountered for 15 epochs (half the early stopping patience).

For GNNs, we used 4 graph layers for all algorithms (GCN, GIN, GAT, GATv2, PNA) and the mean readout function, a node dimension of 64, and a hidden dimension for the graph layers of 256. All layers use batch normalisation. Settings specific to some algorithms were also given by reasonable defaults, such as 8 attention heads for GAT(v2) and 5 towers for PNA. In certain instances, such as small datasets with dense graphs, the defaults selected above can lead to out-of-memory errors even for GNNs, in particular for the more computationally-intensive algorithms such as GAT(v2) or PNA. In such cases, we lower the settings that were the most likely cause, such as the hidden dimension or batch size, to the next (lower) power of two. In some cases, mixed precision training (using the bfloat16 type) was performed as an alternative if Ampere-class GPUs were available.

For Graphormer and TokenGT, we leverage the huggingface (Wolf et al., 2020) implementation. Our selection of reasonable defaults includes 3 layers, a hidden and embedding size of 512, and 8 attention heads, in addition to the defaults mentioned earlier such as the learning rate and batch size. Graphormer in particular can be difficult to train, and in such cases we reduce the complexity of the model using the same strategies as above. However, note that many out-of-memory errors for this family of models are not due to GPU memory, but RAM; we attempted to use up to 256 GB, but conceded if it did not work.
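As a compact illustration of the shared optimisation settings described at the start of this appendix (the model object is a placeholder, and the exact clipping call is one of several equivalent ways to apply the stated value):

    import torch

    model = torch.nn.Linear(16, 1)  # placeholder for any of the evaluated models
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Halve the learning rate after 15 epochs without validation improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                           factor=0.5, patience=15)
    # Before each optimiser step, gradients are clipped to 0.5:
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)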

For MAG, the same suite of "general" defaults such as the learning rate and batch size apply. We also generally follow the same configuration for all datasets. However, since MAG is a new architecture with many unknowns, we generally evaluate a small number of variations for each dataset and select the best one according to the validation metrics. The variations typically involve choosing MAGN or MAGE, batch or layer normalisation, the number (3 or 4) and order of self-attention and masked self-attention blocks (e.g. SMM, MSMM, etc.), the hidden size (256 or 512), and the number of attention heads (8 or 16). We generally prefer MAGE in all situations and only consider MAGN for datasets where MAGE would take a very long time to run or requires heavy modifications to the default parameters. This is because MAGE can naturally incorporate edge features and almost always performs better. We have also found SwiGLU to often be better than plain MLPs (Shazeer, 2020).

Graphormer (and, to a lesser extent, TokenGT) has very expensive pre-processing steps which require up to hundreds of GBs of storage space to cache intermediary results. The alternative would be to not use caching; however, this means that everything must be stored in memory, resulting in almost immediate crashing. As a further complication, for large datasets such as DOCKSTRING, Graphormer would run for a few epochs but spontaneously crash, most likely due to high memory utilisation during training. Combined with the fact that one epoch took several hours, we have included Graphormer results only for a minority of datasets.

Hyperparameter optimisation

Other than the basic filtering described above for MAG, we did not use any techniques for tuning. In particular, we did not perform hyperparameter optimisation and have not tuned aspects of the networks such as the optimiser, learning rate, batch size, dropout, hidden dimensions, etc. We acknowledge that we are not using optimal parameters for the majority of models. However, this is also true for MAG, and tuning every model presented here would dramatically increase the time and resource utilisation (as well as the financial costs associated with it), defeating the purpose of presenting a simple yet effective alternative to GNNs.

Evaluation

Generally, for a self-contained evaluation we split all datasets using a random 80%, 10%, 10% split for train, validation, and test. The same data splits are used for the different evaluated algorithms. Some datasets, such as MNIST or DOCKSTRING, are provided with existing train, test, and optionally validation splits. If such splits are available through PyTorch Geometric or from the authors (such as DOCKSTRING), we use them and we do not perform our own random splits. For all datasets and models, we provide results from 5 different runs (seeds).

D. Additional results

The missing GCN, TokenGT, and Graphormer results for some of the datasets presented in the main text are presented below in Tables 8 and 9. These complete the results from Tables 3 and 4.

Table 8. Test set root mean squared error (RMSE) for QM9 and R2 for the others, presented as mean ± standard deviation from 5 different runs. The first two blocks report GCN and TokenGT; the DOCKSTRING block reports GCN and Graphormer.

Property       GCN               TokenGT
QM9
  µ            0.67 ± 0.01       1.00 ± 0.00
  α            3.51 ± 0.18       2.14 ± 0.07
  ϵHOMO        0.14 ± 0.00       0.26 ± 0.01
  ϵLUMO        0.15 ± 0.00       0.42 ± 0.01
  ∆ϵ           0.21 ± 0.00       0.56 ± 0.01
  ⟨R2⟩         67.74 ± 5.72      177.70 ± 6.39
  ZPVE         0.23 ± 0.00       0.14 ± 0.00
  U0           573.32 ± 47.41    228.49 ± 198.10
  U            594.55 ± 57.81    228.49 ± 198.10
  H            575.10 ± 48.66    228.49 ± 198.10
  G            594.07 ± 57.00    228.49 ± 198.10
  cV           1.34 ± 0.26       1.05 ± 0.03
  U0 ATOM      3.99 ± 0.04       2.02 ± 0.30
  U ATOM       4.19 ± 0.33       2.22 ± 0.19
  H ATOM       4.05 ± 0.02       2.16 ± 0.10
  G ATOM       3.70 ± 0.02       1.91 ± 0.14
  A            1.17 ± 0.11       0.01 ± 0.01
  B            0.28 ± 0.02       0.40 ± 5.80
  C            0.24 ± 0.02       0.35 ± 2.36
MolNet
  FreeSolv     0.34 ± 0.51       0.86 ± 0.02
  Lipo         0.71 ± 0.01       OOM
  ESOL         0.86 ± 0.01       0.78 ± 0.01

Property       GCN               Graphormer
DOCKSTRING
  ESR2         0.53 ± 0.01       OOM
  F2           0.78 ± 0.00       OOM
  KIT          0.76 ± 0.00       OOM
  PARP1        0.81 ± 0.00       OOM
  PGR          0.36 ± 0.01       OOM

Table 9. Test set Matthews correlation coefficient (MCC) for GCN and TokenGT for 3 graph-level classification tasks from MoleculeNet, presented as mean ± standard deviation, over 5 runs.

Dataset        GCN               TokenGT
BACE           0.35 ± 0.31       0.19 ± 0.09
BBBP           0.68 ± 0.02       0.31 ± 0.09
HIV            0.39 ± 0.03       OOM
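As a concrete reading of the evaluation protocol described in the Evaluation paragraph above, the sketch below shows one way to build a seeded random 80%/10%/10% split; the helper name split_80_10_10 and the use of torch.utils.data.random_split are illustrative assumptions rather than the exact splitting code used for the experiments.

import torch
from torch.utils.data import random_split

def split_80_10_10(dataset, seed: int = 0):
    """Random train/validation/test split in proportions 0.8/0.1/0.1.

    `dataset` can be any indexable dataset (e.g. a PyTorch Geometric
    dataset). Reusing the same generator seed across models means every
    evaluated algorithm sees identical splits.
    """
    n = len(dataset)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    n_test = n - n_train - n_val  # remainder, so the sizes always sum to n
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)

# Example usage: train_set, val_set, test_set = split_80_10_10(dataset, seed=0)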

E. Dataset statistics

We present a summary of all the used datasets, along with their size and the maximum number of nodes and edges encountered in a graph in the dataset (Table 10). The last two are important as they determine the shape of the mask and of the inputs for the attention blocks. Technically, we require that the maximum number of nodes/edges is determined per batch and the tensors to be padded accordingly. This per-batch maximum is lower than the dataset maximum for most batches. However, certain operations such as layer norm., if performed over the last two dimensions, require a constant value. To enable this, we use the dataset maximum.

Table 10. Summary of used datasets, their size, and the max. number of nodes (N) and edges (E) seen in a graph in the dataset.

Dataset               Size      N        E
LRGB
  PEPT-STRUCT         15 535    444      928
  PEPT-FUNC           15 535    444      928
Node
  PPI                 24        3 480    106 754
  CORA                1         2 708    10 556
  CITESEER            1         3 327    9 104
MoleculeNet
  QM9                 133 885   29       56
  FREESOLV            642       44       92
  LIPO                4 200     216      438
  ESOL                1 128     119      252
  BBBP                2 039     269      562
  BACE                1 513     184      376
  HIV                 41 127    438      882
DOCKSTRING            260 060   164      342
MALNET TINY           5 000     4 994    20 096
CV
  MNIST               70 000    75       600
  CIFAR10             60 000    150      1 200
BioInfo
  ENZYMES             600       126      298
  PROTEINS            1 113     620      2 098
  DD                  1 178     5 748    28 534
Synthetic
  SYNTHETIC           300       100      392
  SYNTHETIC NEW       300       100      396
  SYNTHIE             400       100      424
  TRIANGLES           45 000    100      396
  COLORS-3            10 500    200      794
Social
  IMDB-BINARY         1 000     136      2 498
  IMDB-MULTI          1 500     89       2 934
  REDDIT-BINARY       2 000     3 782    8 142
  REDDIT-M-5K         4 999     3 648    9 566
  REDDIT-M-12K        11 929    3 782    10 342
  TWITCH EGOS         127 094   52       1 572
  REDDIT THREADS      203 088   97       370
  GITHUB STAR.        12 725    957      9 336

F. Masked attention equation

Masked attention can be thought of as a custom attention pattern, which for graphs was described succinctly by (Shirzad et al., 2023). An adaptation for masking would be:

MaskedSelfAttention(X, M) = \sum_{j=1}^{h} W_O^j \, W_V^j X_M \, \sigma\!\left( \left( W_K^j X_M \right)^\top \left( W_Q^j X_M \right) \right)    (7)

where h is the number of attention heads, X is the input features matrix, M is the custom attention pattern/mask, which restricts the attention to a subset of elements of X, denoted by X_M, W_Q^j, W_K^j, W_V^j, W_O^j are weight matrices corresponding to queries, keys, values, and outputs, respectively, and σ is the softmax function.

G. Layer vs batch normalisation

Contrary to standard Transformers and current trends, we have found that simply replacing the layer normalisation (LN) operation within MAG with batch normalisation (BN) can dramatically improve performance on a number of datasets. LN was our default initial choice and we have found that it works better for the standard QM9 dataset and its properties (Table 3). For the other datasets and tasks (excluding the 19 QM9 properties but including the QM9 GW tasks from Table 6 and the PEPT-STRUCT and PEPT-FN datasets), although we have not exhaustively tested LN vs BN models, we have generally observed that LN is preferable for 20 tasks and BN for 17. More specifically, BN was preferable for the DOCKSTRING properties PARP1, ESR2, and PGR, as well as for the datasets CIFAR10, COLORS-3, DD, ESOL, LIPO, PEPT-STRUCT, PEPT-FN, PROTEINS, REDDIT THREADS, REDDIT-MULTI-12K, SYNTHIE, TRIANGLES, ENZYMES, and TWITCH EGOS. Apart from the fact that most molecular datasets tend to do better with LN, there is no obvious indication of which normalisation technique might be preferable for certain dataset types.

H. Experimental platform

Representative versions of the software used as part of this paper include Python 3.11.6, PyTorch version 2.1.1 with CUDA 11.8, PyTorch Geometric 2.4.0, PyTorch Lightning 1.9.5, huggingface transformers version 4.35.2, and xFormers version 0.0.23. We have also tested our code with CUDA ≥ 12.0. It is worth noting that attention masking and efficient implementations of attention are early features that are advancing quickly. This means that their behaviour might change unexpectedly and there might be bugs. For example, PyTorch 2.1.1 recently fixed a bug that concerned non-contiguous custom attention masks in the scaled dot product attention function.
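To connect Equation (7) with the scaled dot product attention function mentioned above, the following minimal sketch applies a per-graph boolean mask inside multi-head self-attention over padded node features; the shapes, the random placeholder connectivity, and the added self-loops are assumptions for illustration, not the exact MAG implementation.

import torch
import torch.nn.functional as F

B, N, D, H = 2, 16, 64, 8              # batch, padded nodes, features, heads
head_dim = D // H

x = torch.randn(B, N, D)               # padded node (or edge) features
adj = torch.rand(B, N, N) > 0.5        # placeholder connectivity; True = attend
adj |= torch.eye(N, dtype=torch.bool)  # self-loops so no row is fully masked

w_q, w_k, w_v, w_o = [torch.nn.Linear(D, D) for _ in range(4)]

def to_heads(t):
    # (B, N, D) -> (B, H, N, head_dim)
    return t.view(B, N, H, head_dim).transpose(1, 2)

q, k, v = to_heads(w_q(x)), to_heads(w_k(x)), to_heads(w_v(x))

# Broadcast the per-graph mask over the head dimension as a zero-copy view.
attn_mask = adj.unsqueeze(1).expand(B, H, N, N)

# Positions where the mask is False are excluded from the softmax, which
# realises the custom attention pattern of Equation (7).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
out = w_o(out.transpose(1, 2).reshape(B, N, D))  # merge heads, output proj.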

In terms of hardware, the GPUs used include an NVIDIA RTX 3090 with 24GB VRAM, NVIDIA V100 with 16GB or 32GB of VRAM, and NVIDIA A100 with 40GB of VRAM. Recent, efficient implementations of attention are optimised for the newest GPU architectures, generally starting from Ampere (RTX 3090 and A100). However, while slower, it is possible to run memory efficient attention on V100 GPUs.

I. Multiple configurations of MAGE

We summarise the results of running MAGE on the QM9 property α (alpha) in Table 11. Compared to the main results of Table 3, we used an early stopping patience of 10 epochs and trained for a maximum of 150 epochs. The configurations are given as a sequence of masked attention (M) or self-attention (S) blocks, followed by a pooling by multihead attention block (P). While performance is slightly lower than Table 3 due to the training changes, we notice that models that use exclusively self-attention blocks and no masked blocks are worse than models that predominantly rely on masked attention, as expected. It is also possible to use self-attention blocks after pooling by multihead attention. However, that would greatly increase the number of possible configurations and we do not evaluate this option here.
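As an illustration of how a configuration string such as MMSP could be mapped to a block sequence, the following hypothetical sketch builds a module list from the string; the three block classes are simplified stand-ins (e.g. the pooling block attends from a single learned seed vector) and not the actual MAGE modules.

import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    """Unmasked multi-head self-attention over the node/edge set ('S')."""
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, mask=None):
        return self.attn(x, x, x, attn_mask=mask, need_weights=False)[0]

class MaskedSelfAttentionBlock(SelfAttentionBlock):
    """Same block, but intended to receive the graph-derived mask ('M')."""

class PoolingByMultiheadAttention(nn.Module):
    """Attention pooling of the set onto a single learned seed vector ('P')."""
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.seed = nn.Parameter(torch.randn(1, 1, dim))
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, mask=None):
        seed = self.seed.expand(x.size(0), -1, -1)
        return self.attn(seed, x, x, need_weights=False)[0]

def build_blocks(config: str, dim: int = 512, heads: int = 8) -> nn.ModuleList:
    registry = {"M": MaskedSelfAttentionBlock,
                "S": SelfAttentionBlock,
                "P": PoolingByMultiheadAttention}
    return nn.ModuleList(registry[c](dim, heads) for c in config)

blocks = build_blocks("MMSP")  # three attention blocks followed by pooling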

Table 11. Test set RMSE for MAGE configurations for α (QM9), presented as mean ± standard deviation from 5 different runs.
Configuration MAGE
MMMSP 0.499 ± 0.012
MMSSP 0.501 ± 0.007
MMSP 0.503 ± 0.010
MMSMP 0.506 ± 0.003
MSMMP 0.508 ± 0.010
MMMMP 0.509 ± 0.007
MSSMP 0.509 ± 0.004
SMSSP 0.513 ± 0.012
SMMSP 0.515 ± 0.008
MSSSP 0.517 ± 0.014
MSMP 0.518 ± 0.008
MSMSP 0.519 ± 0.010
SMSMP 0.520 ± 0.010
SMMMP 0.522 ± 0.010
MMMP 0.522 ± 0.003
SSMMP 0.525 ± 0.012
MSP 0.525 ± 0.008
MSSP 0.527 ± 0.015
SMMP 0.527 ± 0.017
SSMSP 0.528 ± 0.009
SMSP 0.530 ± 0.005
SSSMP 0.539 ± 0.014
MMP 0.545 ± 0.017
SSMP 0.546 ± 0.010
MP 0.561 ± 0.007
SMP 0.567 ± 0.011
SSSSP 0.590 ± 0.014
SSSP 0.604 ± 0.013
SSP 0.642 ± 0.013
SP 0.665 ± 0.026
