Chap7 GNN (20240229) - DL4H Practioner Guide
“output” — 2024/3/4 — 6:48 — page 132 — #135
indicates node connections: if node i and node j are connected in the graph, A[i, j] = 1,
otherwise 0. Additionally, we denote the edge set of the graph by E. For example,
(i, j) ∈ E if A[i, j] = 1, which means that there is an edge between nodes i and j. Note
that GNNs mostly work with undirected graphs, i.e., A[i, j] = 1 implies A[j, i] = 1,
and thus A is symmetric.
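To make the definition concrete, the following minimal sketch builds the adjacency matrix of an undirected graph from an edge list. The helper name `build_adjacency` and the three-node toy graph are illustrative assumptions, not part of the chapter's code.

```python
def build_adjacency(num_nodes, edges):
    """Return the N x N adjacency matrix (list of lists) of an undirected graph."""
    A = [[0] * num_nodes for _ in range(num_nodes)]
    for i, j in edges:
        A[i][j] = 1
        A[j][i] = 1  # undirected: mirror every edge
    return A


# Toy graph (an assumption for illustration): node 0 is linked to nodes 1 and 2.
edges = [(0, 1), (0, 2)]
A = build_adjacency(3, edges)

# A is symmetric because every edge is stored in both directions.
is_symmetric = all(A[i][j] == A[j][i] for i in range(3) for j in range(3))
```

Storing each edge in both directions is what makes A symmetric; for a directed graph one would keep only `A[i][j] = 1`.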
Example. For certain prediction tasks, such as forecasting COVID-19 cases
in every US county, traditional DNN models might treat each county as an
independent training sample and build a multi-layer model to predict the counts based
on county-level features. In contrast, GCN models can leverage the spatio-temporal
connections between counties. They first connect nearby counties to form an
adjacency graph A, since such counties are geographically similar and thus likely to
share similar COVID responses. Then, GCN models leverage the graph structure A to
aggregate the features of nearby counties to forecast the target.
Formally, one graph convolution layer can be defined as
$$H^{(t)} = \mathrm{ReLU}\left(\tilde{A} H^{(t-1)} W^{(t)}\right), \tag{7.1}$$
where the initial hidden embeddings, $H^{(0)} = X$, are the county-level features. $\tilde{A}$
represents the normalized adjacency matrix (discussed below) and $W^{(t)}$ represents
the layer-wise parameter matrix ($t$ is the layer index). In comparison, the layer-wise
propagation of a simple DNN can be represented as $H^{(t)} = \mathrm{ReLU}(H^{(t-1)} W^{(t)})$, and
the difference is the multiplication by the adjacency matrix. Essentially, a GNN adds
dependency between different counties when making the predictions.
Normalizing the graph adjacency matrix

For numerical stability, graph learning algorithms usually require a normalized
adjacency matrix. There are two common ways of normalizing the adjacency matrix:
random walk normalization and symmetric normalization.
The matrix A is symmetric by definition. Researchers often add a self-loop that
connects each node with itself to improve the numerical stability of graph operations.
This is achieved by setting A[i, i] = 1, ∀ i, which is equivalent to adding an identity
matrix to the adjacency matrix:
$$A \leftarrow A + I. \tag{7.2}$$
Random Walk Adjacency Normalization: The first type is random walk
normalization, which first calculates the degree of each node, resulting in a matrix
$D \in \mathbb{N}^{N \times N}$. D is a diagonal matrix, and each diagonal element is the
corresponding row sum of the self-looped adjacency matrix:
$$d_{ii} = \sum_{j=1}^{N} a_{ij}. \tag{7.3}$$
Then, the random walk normalization will normalize A over each row, making the
row sum become 1, which aligns with the concept of random walk transition matrix
in Markov processes (discussed in Chapter 9). Hence the name random walk
normalized adjacency $\tilde{A}_{rw}$:
$$\tilde{A}_{rw} = D^{-1} A. \tag{7.4}$$
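Example: Consider a simple graph with the following adjacency matrix:
$$A = \begin{pmatrix} 0 & 1 & 1 \\ 1 & 0 & 0 \\ 1 & 0 & 0 \end{pmatrix}$$
Adding the self-loops gives:
$$A = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 1 & 0 \\ 1 & 0 & 1 \end{pmatrix}$$
Calculating the degree matrix D, we have:
$$D = \begin{pmatrix} 3 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{pmatrix}$$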
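$$D^{-\frac{1}{2}} A D^{-\frac{1}{2}} = \begin{pmatrix} \frac{1}{3} & \sqrt{\frac{1}{6}} & \sqrt{\frac{1}{6}} \\ \sqrt{\frac{1}{6}} & \frac{1}{2} & 0 \\ \sqrt{\frac{1}{6}} & 0 & \frac{1}{2} \end{pmatrix}$$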
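The three-node example can be checked numerically. The sketch below applies the self-loop of Equation (7.2), the degrees of Equation (7.3), and the random walk normalization of Equation (7.4); for symmetric normalization it assumes the standard form $D^{-1/2} A D^{-1/2}$. The variable names are illustrative.

```python
import math

A = [[0, 1, 1],
     [1, 0, 0],
     [1, 0, 0]]
n = len(A)

# Add self-loops: A <- A + I (Equation 7.2).
A_loop = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]

# Degrees are the row sums of the self-looped matrix (Equation 7.3).
deg = [sum(row) for row in A_loop]

# Random walk normalization D^{-1} A (Equation 7.4): divide each row by its degree.
A_rw = [[A_loop[i][j] / deg[i] for j in range(n)] for i in range(n)]

# Symmetric normalization (assumed form D^{-1/2} A D^{-1/2}):
# divide entry (i, j) by sqrt(d_i * d_j).
A_sym = [[A_loop[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
         for i in range(n)]
```

Every row of `A_rw` sums to 1, as a transition matrix should, while `A_sym` stays symmetric but its rows no longer sum to 1.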
In fact, Equation (7.1) can also be rewritten as a node-level update formula (with
self-loops and random walk normalization applied to the adjacency matrix),
$$Z^{(t)} = H^{(t-1)} W^{(t)}, \tag{7.6}$$
$$\alpha_{uv}^{(t)} = \frac{1}{|\mathcal{N}(u)| + 1}, \quad \forall (u, v) \in E, \tag{7.7}$$
$$h_u^{(t)} = \mathrm{ReLU}\left(\sum_{v \in \mathcal{N}(u)} \alpha_{uv}^{(t)} z_v^{(t)}\right). \tag{7.8}$$
• Equation (7.8) takes the weighted sum over the neighborhood set (including the
node u itself) and applies the ReLU function to add non-linearity.
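As a dependency-free illustration of Equations (7.6) to (7.8), the sketch below performs one node-level GCN update with plain Python lists. The function name `gcn_node_update`, the identity weight matrix, and the toy embeddings are assumptions chosen so the result is easy to verify by hand; this is not the chapter's PyTorch implementation.

```python
def relu(x):
    return max(0.0, x)


def gcn_node_update(H, W, neighbors):
    """One GCN layer in node-level form.

    H: list of N embeddings (each a list of length d_in)
    W: d_in x d_out weight matrix (list of lists)
    neighbors: dict mapping node u -> list of its neighbors (without u itself)
    """
    d_out = len(W[0])
    # Equation (7.6): z_v = h_v W for every node v.
    Z = [[sum(h[k] * W[k][c] for k in range(len(h))) for c in range(d_out)]
         for h in H]

    H_new = []
    for u in range(len(H)):
        group = neighbors[u] + [u]              # neighborhood plus the self-loop
        alpha = 1.0 / (len(neighbors[u]) + 1)   # Equation (7.7): equal weights
        # Equation (7.8): weighted sum followed by ReLU.
        H_new.append([relu(sum(alpha * Z[v][c] for v in group))
                      for c in range(d_out)])
    return H_new


# Toy inputs (assumptions for illustration): 3 nodes, 2 features, identity weights.
H0 = [[1.0, 0.0], [0.0, 1.0], [2.0, -4.0]]
W1 = [[1.0, 0.0], [0.0, 1.0]]
neighbors = {0: [1, 2], 1: [0], 2: [0]}
H1 = gcn_node_update(H0, W1, neighbors)
```

With identity weights each new embedding is simply the ReLU of the neighborhood average, e.g. node 1 averages its own embedding with that of node 0.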
The PyTorch implementation of the GCN model is given below. Readers may also
refer to this repository 1 , where we provide a notebook on applying the GCN
model to Zachary's karate club dataset 2 .
import torch.nn as nn


class GraphConvolutionLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConvolutionLayer, self).__init__()
        # linear map applied to the node features
        self.transform = nn.Linear(in_features, out_features)


class GCN(nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super(GCN, self).__init__()
1 https://github.com/sunlabuiuc/pyhealth-book/tree/main/chap7-GNN/notebook
2 https://en.wikipedia.org/wiki/Zachary's_karate_club
$$h_u^{(t)} = \mathrm{ReLU}\left(\sum_{v \in \mathcal{N}(u)} \alpha_{uv}^{(t)} z_v^{(t)}\right). \tag{7.12}$$
Here, both the layer-wise weight matrix $W^{(t)}$ and the attention weight $\theta^{(t)}$ are
parameters.
• Equation (7.10) calculates the attention score using a non-linear neural network (the
third approach in Section 5.2.3). Basically, we concatenate the embeddings of
two connected nodes, u and v, and then apply a linear transformation with a
LeakyReLU activation.
• Equation (7.11) applies the Softmax activation to the attention scores, ensuring that
the normalized attention scores sum to 1 over the neighborhood of node u.
In GCN, these two steps are merged into one with equal weights within the same
neighborhood.
• Equation (7.12) similarly takes the weighted sum over the neighborhood set
(including the node u itself) and applies the ReLU function to add non-linearity.
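The two attention steps described in the bullets above can be sketched without any framework. The function `attention_weights`, the LeakyReLU slope of 0.2, and the toy values of `Z` and `theta` are assumptions for illustration; the score is taken to be a LeakyReLU of a linear map of the concatenation $[z_u \,\|\, z_v]$, followed by a softmax over the neighborhood of u (including u itself).

```python
import math


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def attention_weights(u, Z, neighbors, theta):
    """Attention of node u over its neighborhood (including u itself).

    Z: transformed embeddings z_v; theta: attention vector of length 2 * d.
    """
    group = neighbors[u] + [u]
    # Score: linear map of the concatenation [z_u || z_v], then LeakyReLU.
    scores = [leaky_relu(sum(t * x for t, x in zip(theta, Z[u] + Z[v])))
              for v in group]
    # Softmax over the neighborhood (shifted by the max for numerical stability).
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return {v: e / total for v, e in zip(group, exps)}


Z = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighbors = {0: [1, 2], 1: [0], 2: [0]}
theta = [0.5, -0.5, 1.0, 0.0]   # toy attention parameters (an assumption)
alpha_0 = attention_weights(0, Z, neighbors, theta)

# With theta = 0 every score is equal, so the weights reduce to GCN's
# uniform 1 / (|N(u)| + 1).
alpha_uniform = attention_weights(0, Z, neighbors, [0.0] * 4)
```

The second call shows the relation to GCN noted above: when the attention parameters carry no information, the softmax returns equal weights within the neighborhood.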
The PyTorch implementation of GAT is presented below, and we provide a notebook
showing its application to the Karate dataset in this repository 3 .
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # step 2
3 https://github.com/sunlabuiuc/pyhealth-book/tree/main/chap7-GNN/notebook
        # step 3
        # use attention to reweight the nodes and sum them up
        h = torch.matmul(attention, h)
        return h


class GAT(nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes=4):
        super(GAT, self).__init__()
        self.layer1 = GraphAttentionLayer(num_features, hidden_dim)
        self.layer2 = GraphAttentionLayer(hidden_dim, num_classes)
The main difference lies in the GraphAttentionLayer class, which defines a
single graph attention layer. It takes the number of input features and the number
of output features as parameters. Inside the constructor (__init__), it initializes
the parameters of the linear transformations (self.W and self.a) that will be applied
to the input features. Additionally, it initializes a leaky ReLU activation function
(self.leakyrelu). The forward method takes an adjacency matrix (adj) and
input features (X) as input. It first applies a linear transformation to the input features
(h = torch.mm(X, self.W), step 1). Then, it calculates attention scores using
a learned attention mechanism (step 2): it concatenates the embeddings of node pairs,
calculates attention scores, applies a mask to the attention matrix, and applies a softmax
function to obtain the attention weights. Finally, it reweights the nodes based on the
attention weights (step 3).
• Equation (7.13) takes the previous-layer embeddings from the neighborhood set of
node u and aggregates them into a message $m_u^{(t)}$. For example, the aggregation
function in GCN is an equally weighted sum, while the aggregation in GAT is an
attention-weighted sum.
• Then, Equation (7.14) takes the message $m_u^{(t)}$ and the previous embedding of
node u as two inputs and updates the current embedding with an update function.
For example, the update function in both GCN and GAT adds the previous
embedding to the message and then applies a non-linear transformation to obtain
the updated embedding $h_u^{(t)}$.
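The message and update steps can be written as two interchangeable functions. In the sketch below, `mpnn_layer`, `sum_message`, and `add_relu_update` are hypothetical names, and weight matrices and normalization are omitted for brevity; it shows an equally weighted sum as the message and "add, then ReLU" as the update.

```python
def mpnn_layer(H, neighbors, message_fn, update_fn):
    """One message-passing step: build a message per node, then update it."""
    H_new = []
    for u, h_u in enumerate(H):
        m_u = message_fn([H[v] for v in neighbors[u]])  # aggregate the neighbors
        H_new.append(update_fn(h_u, m_u))               # combine with own embedding
    return H_new


def sum_message(neighbor_embeddings):
    """Element-wise sum of the neighbor embeddings (equal weights)."""
    return [sum(col) for col in zip(*neighbor_embeddings)]


def add_relu_update(h_u, m_u):
    """Add the previous embedding to the message, then apply ReLU."""
    return [max(0.0, a + b) for a, b in zip(h_u, m_u)]


# Toy graph (an assumption): node 0 is linked to nodes 1 and 2.
H0 = [[1.0, -2.0], [0.5, 0.5], [-3.0, 1.0]]
neighbors = {0: [1, 2], 1: [0], 2: [0]}
H1 = mpnn_layer(H0, neighbors, sum_message, add_relu_update)
```

Swapping `sum_message` for an attention-weighted sum would give a GAT-style layer without touching `mpnn_layer`, which is the sense in which MPNN is the general form.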
PyTorch implementations of various MPNN models can be found in this repository 4 .
Below, we show one type of MPNN implementation from a recent drug recommendation
paper [75] (refer to Equation 7 and Equation 8 of the paper), which uses MPNN for
drug molecule graph representation. More concrete applications on the Karate dataset
can be found in this repository 5 .
class MessagePassingLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(MessagePassingLayer, self).__init__()
        self.message_passing = nn.Linear(2 * in_features, out_features)
        self.read_out = nn.Linear(out_features, out_features)
4 https://github.com/priba/nmp_qc
5 https://github.com/sunlabuiuc/pyhealth-book/tree/main/chap7-GNN/notebook
class MPNN(nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes=4):
        super(MPNN, self).__init__()
        self.layer1 = MessagePassingLayer(num_features, hidden_dim)
        self.layer2 = MessagePassingLayer(hidden_dim, num_classes)
Essentially, GCN, GAT, MPNN, and their variants are fundamental tools for learning
graph node embeddings, and the embedding vectors can be used for various
purposes, such as node classification, node value regression, edge classification, and
graph classification.
Deep learning on graphs has made much exciting progress in both practical deploy-
ments and various application domains. This section will cover some advancements in
graph neural network research.
[103]), or hybrid graphs for the application of GNNs, and the nodes in the graphs
can be homogeneous [104] (such as molecular graphs with atoms as nodes),
heterogeneous [81] (such as disease-drug bipartite graphs) or multi-relational [105]
(EHR knowledge graphs contain patients, drugs, diagnosis codes, etc.).
Second, with these diverse graphs, GNN models have benefited applications in many
different problem settings, such as graph-to-sequence [106], graph-to-tree [107], and
graph-to-graph translation [108], as well as various NLP tasks, such as natural
language generation [103], question answering [109], information extraction [110],
and knowledge graph reasoning [111].
shown that 3CLpro is essential for the viral life cycle, making it a critical pathogenic
component of SARS-CoV. The AID1706 SARS Cov 3CL dataset 6 records tens of
thousands of molecules that could potentially be active against 3CLpro.
We load this dataset from the DeepPurpose package 7 , which contains 26640 molecules
represented as SMILES strings, each associated with an "active or not" binary
label. For convenience, we only use the first 1000 SMILES strings to build the training
and test datasets.
from DeepPurpose.dataset import *
Let us dive deeper and show what the dataset looks like:
# We look at the first 10 SMILES strings
print(X_drugs[:10])
"""
['CC1=C(SC(=N1)NC(=O)COC2=CC=CC=C2OC)C' 'CC1=CC=C(C=C1)C(=O)NCCCN2CCOCC2'
 'CSC1=CC=C(C=C1)C(=O)NC2CCSC3=CC=CC=C23'
 'CCOC(=O)N1CCC(CC1)N2CC34C=CC(O3)C(C4C2=O)C(=O)NC5=CC=C(C=C5)C'
 'CC1=CC(=NN1C(=O)C2=CC(=CC(=C2)[N+](=O)[O-])[N+](=O)[O-])C'
 'CC1=CC=C(C=C1)C(=O)CSC2=NN=C(N2CC3=CC=CO3)CNC4=C(C=C(C=C4)C)C'
 'CC(C1=CC(=C(C=C1)Cl)Cl)NC(=O)CCl'
 'CCOC(=O)CN1CC23C=CC(O2)C(C3C1=O)C(=O)NC4=CC5=C(C=C4)OCO5'
 'COC(=O)C1=CC=C(C=C1)COC(=O)C2=CC(=C(N=C2)Cl)Cl'
 'C1=CC=C2C(=C1)C=C(C(=O)O2)C3=C(C=C(C=C3)NC(=O)CC4=CC=C(C=C4)Cl)Cl']
"""
First, we combine all SMILES strings and labels into a pandas.DataFrame and split the
data into training and test portions of 80% and 20%.
# split the dataset into 80%: 20%

print(train)

"""
     SMILES                                              Label
0    CC1=C(SC(=N1)NC(=O)COC2=CC=CC=C2OC)C                0
1    CC1=CC=C(C=C1)C(=O)NCCCN2CCOCC2                     0
2    CSC1=CC=C(C=C1)C(=O)NC2CCSC3=CC=CC=C23              0
3    CCOC(=O)N1CCC(CC1)N2CC34C=CC(O3)C(C4C2=O)C(=O)...   0
4    CC1=CC(=NN1C(=O)C2=CC(=CC(=C2)[N+](=O)[O-])[N+...   0
..   ...                                                 ...
795  COC1=CC=CC=C1N(CC(=O)NC2=CC(=C(C=C2)Cl)C(=O)OC...   1
796  COC1=C(C=C(C=C1Cl)C(=O)NN)Cl                        1
797  C1CC(C2=CC=CC=C2C1)NC(=O)CCC(=O)N3CCN(CC3)S(=O...   0
798  CC(=O)NC1=CC=C(C=C1)N(C(C2=CC=C(C=C2)OC)C(=O)N...   1
799  CC(C)N(CC1=CC=CC=C1)CC(COC2=CC=CC3=C2C(=CN3)CC...   0
"""
To process the SMILES strings, we use the rdkit package 8 . The main function below
is the create_dataset function, which takes the training or test datasets in and loops
over the SMILES strings. Within the for-loop, smiles is the SMILES string of the
current molecule, and property is the label.
def create_dataset(data_in, radius=2):
    dataset = []

    for smiles, property in data_in.values:
        try:
            """Create each data with the above defined functions."""
            mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
            atoms = create_atoms(mol, atom_dict)
            molecular_size = len(atoms)
            i_jbond_dict = create_ijbonddict(mol, bond_dict)
            fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict,
                                                fingerprint_dict, edge_dict)
            adjacency = Chem.GetAdjacencyMatrix(mol)
8 https://www.rdkit.org/
        except Exception:
            # skip molecules that cannot be processed
            pass

    return dataset

def create_atoms(mol, atom_dict):
    """Transform the atom types in a molecule (e.g., H, C, and O)
    into the indices (e.g., H=0, C=1, and O=2).
    Note that each atom index considers the aromaticity.
    """
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict

        for _ in range(radius):

            """Update each node ID considering its neighboring nodes and edges.
            The updated node IDs are the fingerprint IDs.
            """
            nodes_ = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                nodes_.append(fingerprint_dict[fingerprint])

            """Also update each edge ID considering
            its two nodes on both sides.
            """
            i_jedge_dict_ = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    i_jedge_dict_[i].append((j, edge))

            nodes = nodes_
            i_jedge_dict = i_jedge_dict_

    return np.array(nodes)
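The node relabeling inside the loop above can be traced on a tiny example. The sketch below runs one round (radius = 1) on a hypothetical chain of four identical atoms; it assumes, as a labeled guess, that `fingerprint_dict` is a `defaultdict` that hands out a fresh integer ID for every unseen key, since its definition is not shown in the listing.

```python
from collections import defaultdict

# Assumption: the lookup table assigns a new integer ID to every unseen key.
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))

# Toy molecule (hypothetical): a chain 0-1-2-3 of identical atoms (type ID 0)
# joined by a single bond type (ID 0). Each entry maps i -> [(j, bond ID), ...].
nodes = [0, 0, 0, 0]
i_jedge_dict = {0: [(1, 0)], 1: [(0, 0), (2, 0)], 2: [(1, 0), (3, 0)], 3: [(2, 0)]}

# One relabeling round: a node's new ID depends on its own ID and on the
# sorted collection of (neighbor ID, edge ID) pairs around it.
new_nodes = []
for i, j_edge in i_jedge_dict.items():
    neighbors = [(nodes[j], edge) for j, edge in j_edge]
    fingerprint = (nodes[i], tuple(sorted(neighbors)))
    new_nodes.append(fingerprint_dict[fingerprint])
```

After one round the two end atoms share one fingerprint ID and the two interior atoms share another, so atoms of the same type are now distinguished by their local environment.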
To this end, let us look at the summary of the first two molecules in the training set,
following the example below, which includes:
• The atom index of the molecule (in tensor form), which will be used to index the
learnable atom embeddings in the later molecule graph neural network class.
• Adjacency matrix of the molecular graph (in the tensor form).
• Molecule size (such as 36).
• Molecule property label (such as 0).
"""
[(tensor([17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 30, 29, 28,
          31, 32, 33, 34, 34, 34, 35, 36, 36, 37, 37, 37, 37, 38, 38, 38, 34, 34, 34]),
  tensor([[0., 1., 0., ..., 0., 0., 0.],
          [1., 0., 1., ..., 0., 0., 0.],
          [0., 1., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]]), 36, 0),

 (tensor([46, 47, 48, 48, 49, 48, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
          58, 57, 34, 34, 34, 37, 37, 37, 37, 60, 61, 61, 62, 62, 61, 61, 61, 61, 36,
          36, 36, 36, 61, 61]),
  tensor([[0., 1., 0., ..., 0., 0., 0.],
          [1., 0., 1., ..., 0., 0., 0.],
          [0., 1., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]]), 41, 0)]
"""
potential numerical issues while un-normalized adjacency works fine here) to aggregate
the neighborhood information to be the current hidden node embedding. The output
graph embedding is calculated by summing over all the node embeddings and then we
apply a final linear layer to obtain the predicted probability of molecule property.
In the implementation, the self.forward() function of the model takes two steps:
(i) use the GNN with the molecule adjacency matrix to get molecular graph embeddings;
(ii) apply a property prediction layer on top of the graph embedding to predict the
binary property label. The self.gnn() part is the core of the MolecularGNN class, which
implements batch processing for molecule graphs of different sizes using padding.
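The padding idea can be illustrated with plain lists: the adjacency matrices of several molecules are placed along the diagonal of one large matrix, so that no edge links atoms of different molecules and a single matrix product processes the whole batch. The helper `pad_block_diagonal` and the two toy molecules are assumptions for illustration.

```python
def pad_block_diagonal(matrices, pad_value=0):
    """Place the matrices along the diagonal of one large matrix."""
    M = sum(len(m) for m in matrices)
    N = sum(len(m[0]) for m in matrices)
    out = [[pad_value] * N for _ in range(M)]
    i = j = 0
    for mat in matrices:
        for r, row in enumerate(mat):
            out[i + r][j:j + len(row)] = row   # copy one row into its block
        i += len(mat)
        j += len(mat[0])
    return out


A1 = [[0, 1], [1, 0]]                    # adjacency of a 2-atom molecule
A2 = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]   # adjacency of a 3-atom chain
batch_adj = pad_block_diagonal([A1, A2])
```

The result is a 5 x 5 matrix whose off-diagonal blocks are all zeros, which is why neighborhood aggregation never mixes information between molecules in the same batch.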
class MolecularGNN(nn.Module):
    """
    based on https://github.com/masashitsubaki/molecularGNN_smiles
    """
    def __init__(self, N_fingerprints, dim, layer_gnn_hidden):
        super(MolecularGNN, self).__init__()
        # learnable atom initial features
        self.embed_fingerprint = nn.Embedding(N_fingerprints, dim)
        # gnn layers (will be used together with the adj in self.gnn)
        self.W_fingerprint = nn.ModuleList([nn.Linear(dim, dim)
                                            for _ in range(layer_gnn_hidden)])
        # final prediction layers
        self.W_property = nn.Linear(dim, 1)

    def pad(self, matrices, pad_value):
        """Pad the list of matrices
        with a pad_value (e.g., 0) for batch processing.
        For example, given a list of matrices [A, B, C],
        we obtain a new matrix [A00, 0B0, 00C],
        where 0 is the zero (i.e., pad value) matrix.
        """
        shapes = [m.shape for m in matrices]
        M, N = sum([s[0] for s in shapes]), sum([s[1] for s in shapes])
        zeros = torch.FloatTensor(np.zeros((M, N)))
        pad_matrices = pad_value + zeros
        i, j = 0, 0
        for k, matrix in enumerate(matrices):
            m, n = shapes[k]
            pad_matrices[i:i+m, j:j+n] = matrix
            i += m
            j += n
        return pad_matrices
    model.train()
    for i in range(0, len(dataset_train), batch_size):
9 This is because by combining the operations into one layer, we take advantage of the log-sum-exp trick
for numerical stability.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        predicted += pred
        groundtruth += label

"""
--- epoch: 0 ---, train loss: 39.3298280239, test AUROC: 0.5515151515151515
--- epoch: 1 ---, train loss: 24.1569204330, test AUROC: 0.6504513140027158
--- epoch: 2 ---, train loss: 18.3505403399, test AUROC: 0.7955907021327583
--- epoch: 3 ---, train loss: 15.5090635120, test AUROC: 0.8324640937174036
--- epoch: 4 ---, train loss: 12.8106079697, test AUROC: 0.8664265706282513
--- epoch: 5 ---, train loss: 11.1026673614, test AUROC: 0.9002656363197294
--- epoch: 6 ---, train loss: 9.69138415157, test AUROC: 0.9347860791826309
--- epoch: 7 ---, train loss: 8.68385870754, test AUROC: 0.9528820856254485
--- epoch: 8 ---, train loss: 7.63521204888, test AUROC: 0.9483418367346939
--- epoch: 9 ---, train loss: 6.85020373761, test AUROC: 0.969544766004943
--- epoch: 10 ---, train loss: 6.0492331385, test AUROC: 0.9728867623604466
--- epoch: 11 ---, train loss: 5.4066562131, test AUROC: 0.989516129032258
--- epoch: 12 ---, train loss: 4.9491942748, test AUROC: 0.9864766964501328
--- epoch: 13 ---, train loss: 4.4005933329, test AUROC: 0.9948178266762338
--- epoch: 14 ---, train loss: 3.9376147910, test AUROC: 0.9946519795657727
"""
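The test AUROC reported in the log can be understood as the probability that a randomly chosen active molecule receives a higher score than a randomly chosen inactive one. The function below is a plain-Python sketch of that definition with made-up scores; it is not the code that produced the log, which in practice would typically call a library routine.

```python
def auroc(labels, scores):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly;
    ties count as one half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


groundtruth = [0, 0, 1, 1]
predicted = [0.1, 0.4, 0.35, 0.8]   # toy scores (an assumption)
score = auroc(groundtruth, predicted)
```

Because AUROC depends only on the ranking of the scores, it is unaffected by class imbalance in the way plain accuracy is, which suits datasets with few active molecules.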
Readers who want to practice on their own can find a complete notebook in this
public repository 10 .
10 https://github.com/sunlabuiuc/pyhealth-book/tree/main/chap7-GNN/notebook
In this section, we revisit the ChestXray image classification challenge
using Graph Neural Networks (GNNs). In previous sections, Convolutional Neural
Network (CNN) models were employed for ChestXray image classification, treating
each patient's image data as an independent sample. GNNs offer a different
approach that leverages demographic similarities to potentially enhance
predictions. It operates under the premise that patients with similar
demographic features may exhibit similar labels in their X-ray images.
The approach involves constructing a graph in which each node represents a single
patient (we assume each patient has only a single X-ray image). Crucially,
the connections between nodes in this graph are established based on the similarity of
patient demographics. Consequently, during training, the embeddings of neighboring
nodes are used to update the embedding of the current node. To handle
cases where numerous images are connected to a single image, a neighborhood
sampling strategy, akin to the one proposed in GraphSAGE by Hamilton et al. [76], is
employed. This strategy bounds the number of neighbors aggregated per node, so
that node embeddings can be updated in a scalable way.
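A minimal sketch of the neighborhood sampling idea follows. The function `sample_neighbors`, the toy patient graph, and the convention that a size of -1 keeps every neighbor are illustrative assumptions rather than the pyhealth implementation.

```python
import random


def sample_neighbors(adj_list, node, size, rng):
    """Return at most `size` neighbors of `node`; size = -1 keeps all of them."""
    nbrs = adj_list[node]
    if size == -1 or len(nbrs) <= size:
        return list(nbrs)
    return rng.sample(nbrs, size)


# Hypothetical patient graph: patient 0 shares demographics with many others.
adj_list = {0: [1, 2, 3, 4, 5, 6], 1: [0], 2: [0, 3], 3: [0, 2]}
rng = random.Random(0)

train_nbrs = sample_neighbors(adj_list, 0, size=3, rng=rng)   # capped at 3
eval_nbrs = sample_neighbors(adj_list, 0, size=-1, rng=rng)   # full neighborhood
```

Capping the neighborhood during training keeps the cost per node bounded even for highly connected patients, while evaluation can still use every neighbor.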
We will use pyhealth modules to implement the whole pipeline.
path to the ChestXray as the argument, and everything else is handled by the API
already.
from pyhealth.datasets import COVID19CXRDataset

root = "/srv/local/data/COVID-19_Radiography_Dataset"
base_dataset = COVID19CXRDataset(root)
sample_dataset = base_dataset.set_task()
Similar to Section 4.3, we add one more sample transformation step to further clean up
the image data, which involves enforcing channel consistency, resizing, and normalizing
the images. The transformation is made possible by the featurizer design in the module.
For simplicity, we do not show the data transformation for other pipelines.
from torchvision import transforms

transform = transforms.Compose([
    transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304])
])

def encode(sample):
    sample["path"] = transform(sample["path"])
    return sample

sample_dataset.set_transform(encode)
After data transformation, as usual, we split the data into training, validation, and
test sets by 70% : 10% : 20%. This time, we use another data splitter, which splits
by sample, meaning that we do not care whether the data from the same patient end up
in the same set (training, validation, or test).
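What a sample-level split does can be sketched in a few lines: shuffle the sample indices and cut them into three consecutive parts. The function `split_indices` and its `seed` argument are hypothetical and only illustrate the idea; they do not reproduce the pyhealth splitter.

```python
import random


def split_indices(num_samples, ratios, seed=0):
    """Shuffle sample indices and cut them into train/val/test parts."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    n_train = round(num_samples * ratios[0])
    n_val = round(num_samples * ratios[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]   # the remainder goes to the test set
    return train, val, test


train_index, val_index, test_index = split_indices(1000, [0.7, 0.1, 0.2])
```

Because the shuffle ignores patient identity, samples from one patient may land in different sets; here that is acceptable since each patient is assumed to contribute a single image.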
from pyhealth.datasets import split_by_sample

# Get Index of train, valid, test set
train_index, val_index, test_index = split_by_sample(
    dataset=sample_dataset,
    ratios=[0.7, 0.1, 0.2],
    get_index=True
)
model = Graph_TorchvisionModel(
    dataset=sample_dataset,
    feature_keys=["path"],
    label_key="label",
    mode="multiclass",
    model_name="resnet18",
    model_config={},
    gnn_config={"input_dim": 256, "hidden_dim": 128},
)
A new step in this pipeline is building an image graph structure.
The demographic information of each patient is already stored in sample_dataset,
and our initialized model builds a graph based on this information, so we pass
sample_dataset as an argument.
Next, we use the neighborhood sampler to simplify the graph and sparsify the
connections, following the GraphSAGE model [76]. After doing so, the training,
validation, and test sets are prepared for training and evaluation.
from pyhealth.sampler import NeighborSampler

graph = model.build_graph(sample_dataset, random=True)

# We sample all edges connected to target node for validation and test (Sizes = [-1, -1])
valid_dataloader = NeighborSampler(sample_dataset, graph["edge_index"],
                                   node_idx=val_index, sizes=[-1, -1],
                                   batch_size=64, shuffle=False,
                                   num_workers=12)
test_dataloader = NeighborSampler(sample_dataset, graph["edge_index"],
                                  node_idx=test_index, sizes=[-1, -1],
                                  batch_size=64, shuffle=False,
                                  num_workers=12)
After training the model for 10 epochs, we can evaluate the model performance on
the test set.
print(resnet_trainer.evaluate(test_dataloader))
"""
{'accuracy': 0.4786590097780537, 'f1_macro': 0.1618557783142647,
 'f1_micro': 0.4786590097780537, 'loss': 1.256981566770753}
"""
11 https://github.com/sunlabuiuc/pyhealth-book/blob/main/chap6/notebook/graph_torchvision_model.ipynb
7.5 Takeaways
• Graph Data: Many real-world data are irregular and are represented as graphs
(in contrast to images and time series, which are regular grid-like data).
Specifically, a graph is a data structure that consists of a set of nodes and a set
of edges connecting these nodes.
• Common Graph Neural Networks (GNN): Graph convolutional networks (GCN),
graph attention networks (GAT), and message passing neural networks (MPNN) are
common GNN models. They differ in that GCN treats all edges equally, GAT
learns the edge weights with attention modules, and MPNN allows flexible
aggregation mechanisms.
• Advancements in GNN Training: Several modern GNN training techniques are
practically useful, such as neighborhood sampling, distributed graph
partitioning and training, and heterogeneous graph modeling.
• GNN Applications in CV and NLP: In computer vision and natural language
processing, researchers construct different types of graph structures from data
and then apply a GNN model to learn node or graph embeddings.
• Molecule Property Prediction Example: GNNs are powerful in learning molecule
graph structures, treating atoms as nodes and bonds as edges, for predicting
molecular properties.
• PyHealth GNN code examples: GNNs can be applied to similarity graphs
constructed from patient demographics, which can further enhance a predictive
model, such as one for ChestXray image classification.
Questions
• What are the key elements of a graph? What is the adjacency matrix of a graph?
What is the degree matrix of a graph?
• Compare the two different ways of normalizing the adjacency matrix.
• In this section, we discuss undirected graphs. However, the concepts of adjacency
matrix, degree matrix, and normalized adjacency matrix can be generalized to
directed graphs as well. Could you specify these concepts for the graph shown in
Figure 7.2?
• What are graph neural networks? What are the key differences between GNNs and
DNNs? Could you explain the difference between a GCN and a multi-layer perceptron
(both using ReLU as the activation)?
• Could you elaborate on the difference between GCN and GAT? Why is MPNN a
general form? Could you use the MPNN equations (i.e., the Message and Update
functions) to formulate the GCN and GAT networks?
• Why are GNNs suitable for molecule property prediction? Could you choose your own
favorite molecule graph and your favorite GNN model and explain how to use
this GNN model to model the molecule graph structure?
• In Section 7.4, we connect the graph by patient demographic features. Could you
build a new X-ray image graph based on pixel-level image similarity and re-implement
the whole pipeline? Show the results and explain why the two different
graphs lead to different final prediction results.
[1] C. C. Tappert, “Who is the father of deep learning?” in 2019 International Conference
on Computational Science and Computational Intelligence (CSCI), 2019, pp. 343–348.
[2] A. E. Johnson, L. Bulgarelli, L. Shen, A. Gayles, A. Shammout, S. Horng, T. J. Pollard,
S. Hao, B. Moody, B. Gow et al., “Mimic-iv, a freely accessible electronic health record
dataset,” Scientific data, vol. 10, no. 1, p. 1, 2023.
[3] C. Yang, Z. Wu, P. Jiang, Z. Lin, J. Gao, B. P. Danek, and J. Sun, “Pyhealth: A deep
learning toolkit for healthcare applications,” in Proceedings of the 29th ACM SIGKDD
Conference on Knowledge Discovery and Data Mining, 2023, pp. 5788–5789.
[4] C. Yang, M. B. Westover, and J. Sun, “ManyDG: Many-domain generalization
for healthcare applications,” in The Eleventh International Conference on Learning
Representations, 2023. [Online]. Available: https://openreview.net/forum?id=lcSfirnflpW
[5] E. Choi, Z. Xu, Y. Li, M. W. Dusenberry, G. Flores, Y. Xue, and A. M. Dai, “Learning the
graphical structure of electronic health records with graph convolutional transformer,”
2020.
[6] C. Xiao, T. Ma, A. B. Dieng, D. M. Blei, and F. Wang, “Readmission prediction via deep
contextual embedding of clinical concepts,” PloS one, vol. 13, no. 4, p. e0195024, 2018.
[7] E. Choi, M. T. Bahadori, J. Sun, J. Kulas, A. Schuetz, and W. Stewart, “Retain: An
interpretable predictive model for healthcare using reverse time attention mechanism,”
Advances in neural information processing systems, vol. 29, 2016.
[8] S. Hochreiter and J. Schmidhuber, “Long short-term memory,” Neural computation,
vol. 9, no. 8, pp. 1735–1780, 1997.
[9] J. Chung, C. Gulcehre, K. Cho, and Y. Bengio, “Empirical evaluation of gated recurrent
neural networks on sequence modeling,” arXiv preprint arXiv:1412.3555, 2014.
[10] E. Choi, A. Schuetz, W. F. Stewart, and J. Sun, “Medical concept representation learning
from electronic health records and its application on heart failure prediction,” arXiv
preprint arXiv:1602.03686, 2016.
[11] S. Mallya, M. Overhage, N. Srivastava, T. Arai, and C. Erdman, “Effectiveness of lstms
in predicting congestive heart failure onset,” arXiv preprint arXiv:1902.02443, 2019.
[12] G. Maragatham and S. Devi, “Lstm model for prediction of heart failure in big data,”
Journal of medical systems, vol. 43, pp. 1–13, 2019.
[13] K. Fukushima, “Neocognitron: A self-organizing neural network model for a mechanism
of pattern recognition unaffected by shift in position,” Biological Cybernetics, vol. 36,
no. 4, pp. 193–202, Apr 1980. [Online]. Available: https://doi.org/10.1007/BF00344251
[14] Y. LeCun, B. Boser, J. Denker, D. Henderson, R. Howard, W. Hubbard, and L. Jackel,
“Handwritten digit recognition with a back-propagation network,” Advances in neural
information processing systems, vol. 2, 1989.
[29] Beam search strategies for neural machine translation. Association for Computational
Linguistics, 2017.
[30] D. Bahdanau, K. Cho, and Y. Bengio, “Neural machine translation by jointly learning to
align and translate,” arXiv preprint arXiv:1409.0473, 2014.
[31] M.-T. Luong, H. Pham, and C. D. Manning, “Effective approaches to attention-based
neural machine translation,” 2015.
[32] L. Cui, S. Biswal, L. M. Glass, G. Lever, J. Sun, and C. Xiao, “CONAN: Complementary
pattern augmentation for rare disease detection,” AAAI, vol. 34, no. 01, pp. 614–621, Apr.
2020.
[33] Z. Yang, A. Mitra, W. Liu, D. Berlowitz, and H. Yu, “TransformEHR: transformer-
based encoder-decoder generative model to enhance prediction of disease outcomes using
electronic health records,” Nat. Commun., vol. 14, no. 1, p. 7857, Nov. 2023.
[34] B. Theodorou, C. Xiao, and J. Sun, “Synthesize high-dimensional longitudinal electronic
health records via hierarchical autoregressive language model,” Nat. Commun., vol. 14,
no. 1, p. 5305, Aug. 2023.
[35] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez,
Ł. Kaiser, and I. Polosukhin, “Attention is all you need,” in Advances in Neural
Information Processing Systems, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach,
R. Fergus, S. Vishwanathan, and R. Garnett, Eds., vol. 30. Curran Associates,
Inc., 2017. [Online]. Available: https://proceedings.neurips.cc/paper_files/paper/2017/
file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
[36] J. Devlin, M. Chang, K. Lee, and K. Toutanova, “BERT: pre-training of deep bidirectional
transformers for language understanding,” in Proceedings of the 2019 Conference of
the North American Chapter of the Association for Computational Linguistics: Human
Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019,
Volume 1 (Long and Short Papers), J. Burstein, C. Doran, and T. Solorio, Eds.
Association for Computational Linguistics, 2019, pp. 4171–4186. [Online]. Available:
https://doi.org/10.18653/v1/n19-1423
[37] E. Alsentzer, J. Murphy, W. Boag, W.-H. Weng, D. Jindi, T. Naumann, and
M. McDermott, “Publicly available clinical BERT embeddings,” in Proceedings of
the 2nd Clinical Natural Language Processing Workshop, A. Rumshisky, K. Roberts,
S. Bethard, and T. Naumann, Eds. Minneapolis, Minnesota, USA: Association
for Computational Linguistics, Jun. 2019, pp. 72–78. [Online]. Available:
https://aclanthology.org/W19-1909
[38] N. Kitaev, Ł. Kaiser, and A. Levskaya, “Reformer: The efficient transformer,” arXiv
preprint arXiv:2001.04451, 2020.
[39] I. Beltagy, M. E. Peters, and A. Cohan, “Longformer: The long-document transformer,”
arXiv preprint arXiv:2004.05150, 2020.
[40] Z. Lan, M. Chen, S. Goodman, K. Gimpel, P. Sharma, and R. Soricut, “Albert:
A lite bert for self-supervised learning of language representations,” arXiv preprint
arXiv:1909.11942, 2019.
[41] S. Wang, B. Z. Li, M. Khabsa, H. Fang, and H. Ma, “Linformer: Self-attention with linear
complexity,” arXiv preprint arXiv:2006.04768, 2020.
[42] K. Choromanski, V. Likhosherstov, D. Dohan, X. Song, A. Gane, T. Sarlos, P. Hawkins,
J. Davis, A. Mohiuddin, L. Kaiser et al., “Rethinking attention with performers,” arXiv
preprint arXiv:2009.14794, 2020.
[91] H. Xu, C. Jiang, X. Liang, and Z. Li, “Spatial-aware graph relation network for large-scale
object detection,” in Proceedings of the IEEE/CVF Conference on Computer Vision and
Pattern Recognition, 2019, pp. 9298–9307.
[92] K. Han, Y. Wang, J. Guo, Y. Tang, and E. Wu, “Vision gnn: An image is worth graph
of nodes,” Advances in Neural Information Processing Systems, vol. 35, pp. 8291–8303,
2022.
[93] S. Yan, Y. Xiong, and D. Lin, “Spatial temporal graph convolutional networks for skeleton-based
action recognition,” in Proceedings of the AAAI conference on artificial intelligence,
vol. 32, no. 1, 2018.
[94] G. Brasó and L. Leal-Taixé, “Learning a neural solver for multiple object tracking,” in
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition,
2020, pp. 6247–6257.
[95] A. S. Heinsfeld, A. R. Franco, R. C. Craddock, A. Buchweitz, and F. Meneguzzi, “Identification
of autism spectrum disorder using deep learning and the abide dataset,” NeuroImage:
Clinical, vol. 17, pp. 16–23, 2018.
[96] B.-H. Kim, J. C. Ye, and J.-J. Kim, “Learning dynamic graph representation of brain
connectome with spatio-temporal attention,” Advances in Neural Information Processing
Systems, vol. 34, pp. 4314–4327, 2021.
[97] L. Wu, Y. Chen, K. Shen, X. Guo, H. Gao, S. Li, J. Pei, B. Long et al., “Graph neural networks
for natural language processing: A survey,” Foundations and Trends® in Machine
Learning, vol. 16, no. 2, pp. 119–328, 2023.
[98] C. Zhang, Q. Li, and D. Song, “Aspect-based sentiment classification with aspect-specific
graph convolutional networks,” in Proceedings of the 2019 Conference on Empirical
Methods in Natural Language Processing and the 9th International Joint Conference on
Natural Language Processing (EMNLP-IJCNLP), K. Inui, J. Jiang, V. Ng, and X. Wan,
Eds. Hong Kong, China: Association for Computational Linguistics, Nov. 2019, pp.
4568–4578. [Online]. Available: https://aclanthology.org/D19-1464
[99] K. Xu, L. Wu, Z. Wang, Y. Feng, M. Witbrock, and V. Sheinin, “Graph2seq: Graph to sequence
learning with attention-based neural networks,” arXiv preprint arXiv:1804.00823,
2018.
[100] T. Wang, X. Wan, and H. Jin, “AMR-to-text generation with graph transformer,” Transactions
of the Association for Computational Linguistics, vol. 8, 2020.
[101] M. Xu, L. Li, D. Wong, Q. Liu, L. S. Chao et al., “Document graph for neural machine
translation,” arXiv preprint arXiv:2012.03477, 2020.
[102] Y. Chen, L. Wu, and M. J. Zaki, “Graphflow: exploiting conversation flow with graph neural
networks for conversational machine comprehension,” in Proceedings of the Twenty-Ninth
International Joint Conference on Artificial Intelligence, ser. IJCAI’20, 2021.
[103] D. Cai and W. Lam, “Graph transformer for graph-to-sequence learning,” in Proceedings
of the AAAI conference on artificial intelligence, vol. 34, no. 05, 2020, pp. 7464–7471.
[104] C. Zheng and P. Kordjamshidi, “SRLGRN: Semantic role labeling graph reasoning network,”
in Proceedings of the 2020 Conference on Empirical Methods in Natural Language
Processing (EMNLP), B. Webber, T. Cohn, Y. He, and Y. Liu, Eds. Online: Association
for Computational Linguistics, Nov. 2020.
[105] Z. Guo, Y. Zhang, Z. Teng, and W. Lu, “Densely connected graph convolutional networks
for graph-to-sequence learning,” Transactions of the Association for Computational Linguistics,
vol. 7, pp. 297–312, 2019.
[106] L. Song, Y. Zhang, Z. Wang, and D. Gildea, “A graph-to-sequence model for AMR-to-text
generation,” in Proceedings of the 56th Annual Meeting of the Association for
Computational Linguistics (Volume 1: Long Papers), I. Gurevych and Y. Miyao, Eds.
Melbourne, Australia: Association for Computational Linguistics, Jul. 2018.
[107] S. Li, L. Wu, S. Feng, F. Xu, F. Xu, and S. Zhong, “Graph-to-tree neural networks
for learning structured input-output translation with applications to semantic parsing
and math word problem,” in Findings of the Association for Computational Linguistics:
EMNLP 2020, T. Cohn, Y. He, and Y. Liu, Eds. Online: Association for Computational
Linguistics, Nov. 2020.
[108] Q. Fu, L. Song, W. Du, and Y. Zhang, “End-to-end AMR coreference resolution,” in
Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics
and the 11th International Joint Conference on Natural Language Processing (Volume
1: Long Papers), C. Zong, F. Xia, W. Li, and R. Navigli, Eds. Online: Association for
Computational Linguistics, Aug. 2021.
[109] J. Han, B. Cheng, and X. Wang, “Open domain question answering based on text enhanced
knowledge graph with hyperedge infusion,” in Findings of the Association for
Computational Linguistics: EMNLP 2020, T. Cohn, Y. He, and Y. Liu, Eds. Online:
Association for Computational Linguistics, Nov. 2020.
[110] Y. Luo and H. Zhao, “Bipartite flat-graph network for nested named entity recognition,”
in Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics,
D. Jurafsky, J. Chai, N. Schluter, and J. Tetreault, Eds. Online: Association for
Computational Linguistics, Jul. 2020.
[111] P. Kapanipathi, V. Thost, S. S. Patel, S. Whitehead, I. Abdelaziz, A. Balakrishnan,
M. Chang, K. Fadnis, C. Gunasekara, B. Makni et al., “Infusing knowledge into the
textual entailment task using graph convolutional networks,” in Proceedings of the AAAI
Conference on Artificial Intelligence, vol. 34, no. 05, 2020, pp. 8074–8081.
[112] O. Wieder, S. Kohlbacher, M. Kuenemann, A. Garon, P. Ducrot, T. Seidel, and T. Langer,
“A compact review of molecular property prediction with graph neural networks,” Drug
Discovery Today: Technologies, vol. 37, pp. 1–12, 2020.
[113] D. Jiang, Z. Wu, C.-Y. Hsieh, G. Chen, B. Liao, Z. Wang, C. Shen, D. Cao, J. Wu, and
T. Hou, “Could graph neural networks learn better molecular representation for drug
discovery? a comparison study of descriptor-based and graph-based models,” Journal of
cheminformatics, vol. 13, no. 1, pp. 1–23, 2021.
[114] J. Lim, S. Ryu, K. Park, Y. J. Choe, J. Ham, and W. Y. Kim, “Predicting drug–target
interaction using a novel graph neural network with 3d structure-embedded graph representation,”
Journal of chemical information and modeling, vol. 59, no. 9, pp. 3981–3988,
2019.
[115] I. Ghebrehiwet, N. Zaki, R. Damseh, and M. S. Mohamad, “Revolutionizing personalized
medicine with generative ai: A systematic review,” 2024.
[116] W. H. Pinaya, M. S. Graham, E. Kerfoot, P.-D. Tudosiu, J. Dafflon, V. Fernandez,
P. Sanchez, J. Wolleb, P. F. da Costa, A. Patel et al., “Generative ai for medical imaging:
extending the monai framework,” arXiv preprint arXiv:2307.15208, 2023.
[117] X. Zeng, F. Wang, Y. Luo, S.-g. Kang, J. Tang, F. C. Lightstone, E. F. Fang, W. Cornell,
R. Nussinov, and F. Cheng, “Deep generative molecular design reshapes drug discovery,”
Cell Reports Medicine, 2022.
[118] T. Das, Z. Wang, and J. Sun, “Twin: Personalized clinical trial digital twin generation,” in
Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data
Mining, 2023, pp. 402–413.
[119] Z. Wang, C. Gao, L. M. Glass, and J. Sun, “Artificial intelligence for in silico clinical
trials: A review,” arXiv preprint arXiv:2209.09023, 2022.
[120] K. Yu, Y. Wang, Y. Cai, C. Xiao, E. Zhao, L. Glass, and J. Sun, “Rare disease detection
by sequence modeling with generative adversarial networks,” arXiv preprint
arXiv:1907.01022, 2019.
[121] B. Yelmen, A. Decelle, L. Ongaro, D. Marnetto, C. Tallec, F. Montinaro, C. Furtlehner,
L. Pagani, and F. Jay, “Creating artificial human genomes using generative neural networks,”
PLoS genetics, vol. 17, no. 2, p. e1009303, 2021.
[122] B. Theodorou, C. Xiao, and J. Sun, “Synthesize high-dimensional longitudinal electronic
health records via hierarchical autoregressive language model,” Nature communications,
vol. 14, no. 1, p. 5305, 2023.
[123] Z. Wang, Q. She, A. F. Smeaton, T. E. Ward, and G. Healy, “Synthetic-neuroscore: Using
a neuro-ai interface for evaluating generative adversarial networks,” Neurocomputing,
vol. 405, pp. 26–36, 2020.
[124] T. Golany, K. Radinsky, and D. Freedman, “Simgans: Simulator-based generative adversarial
networks for ecg synthesis to improve deep ecg classification,” in International
Conference on Machine Learning. PMLR, 2020, pp. 3597–3606.
[125] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville,
and Y. Bengio, “Generative adversarial nets,” Advances in neural information processing
systems, vol. 27, 2014.
[126] ——, “Generative adversarial networks,” Communications of the ACM, vol. 63, no. 11,
pp. 139–144, 2020.
[127] D. P. Kingma and M. Welling, “Auto-encoding variational bayes,” arXiv preprint
arXiv:1312.6114, 2013.
[128] G. E. Hinton and R. R. Salakhutdinov, “Reducing the dimensionality of data with neural
networks,” science, vol. 313, no. 5786, pp. 504–507, 2006.
[129] P. Vincent, H. Larochelle, Y. Bengio, and P.-A. Manzagol, “Extracting and composing
robust features with denoising autoencoders,” in Proceedings of the 25th international
conference on Machine learning, 2008, pp. 1096–1103.
[130] L. Weng, “From autoencoder to beta-vae,” lilianweng.github.io, 2018. [Online].
Available: https://lilianweng.github.io/posts/2018-08-12-vae/
[131] A. Makhzani and B. Frey, “K-sparse autoencoders,” arXiv preprint arXiv:1312.5663,
2013.
[132] M. Arjovsky, S. Chintala, and L. Bottou, “Wasserstein generative adversarial networks,”
in International conference on machine learning. PMLR, 2017, pp. 214–223.
[133] J.-Y. Zhu, T. Park, P. Isola, and A. A. Efros, “Unpaired image-to-image translation using
cycle-consistent adversarial networks,” in Proceedings of the IEEE international
conference on computer vision, 2017, pp. 2223–2232.
[134] I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and
A. Lerchner, “beta-vae: Learning basic visual concepts with a constrained variational
framework,” in International conference on learning representations, 2016.
[135] P. Dhariwal and A. Nichol, “Diffusion models beat gans on image synthesis,” Advances
in neural information processing systems, vol. 34, pp. 8780–8794, 2021.
[136] J. Ho, C. Saharia, W. Chan, D. J. Fleet, M. Norouzi, and T. Salimans, “Cascaded diffusion
models for high fidelity image generation,” The Journal of Machine Learning Research,
vol. 23, no. 1, pp. 2249–2281, 2022.
[137] L. Yang, Z. Zhang, Y. Song, S. Hong, R. Xu, Y. Zhao, W. Zhang, B. Cui, and M.-H.
Yang, “Diffusion models: A comprehensive survey of methods and applications,” ACM
Computing Surveys, vol. 56, no. 4, pp. 1–39, 2023.
[138] J. Sohl-Dickstein, E. A. Weiss, N. Maheswaranathan, and S. Ganguli, “Deep unsupervised
learning using nonequilibrium thermodynamics,” 2015.
[139] Y. Song and S. Ermon, “Generative modeling by estimating gradients of the data distribution,”
2020.
[140] J. Ho, A. Jain, and P. Abbeel, “Denoising diffusion probabilistic models,” 2020.
[141] M. Welling and Y. W. Teh, “Bayesian learning via stochastic gradient langevin dynamics,”
in Proceedings of the 28th international conference on machine learning (ICML-11),
2011, pp. 681–688.
[142] L. Weng, “What are diffusion models?” lilianweng.github.io, Jul 2021. [Online].
Available: https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
[143] R. Bellman, “The theory of dynamic programming,” Bulletin of the American Mathematical
Society, vol. 60, no. 6, pp. 503–515, 1954.
[144] ——, “Dynamic programming,” Science, vol. 153, no. 3731, pp. 34–37, 1966.
[145] V. Mnih, K. Kavukcuoglu, D. Silver, A. Graves, I. Antonoglou, D. Wierstra, and M. Riedmiller,
“Playing atari with deep reinforcement learning,” arXiv preprint arXiv:1312.5602,
2013.
[146] D. Silver, G. Lever, N. Heess, T. Degris, D. Wierstra, and M. Riedmiller, “Deterministic
policy gradient algorithms,” in International conference on machine learning. PMLR,
2014, pp. 387–395.
[147] R. S. Sutton, D. McAllester, S. Singh, and Y. Mansour, “Policy gradient methods for
reinforcement learning with function approximation,” Advances in neural information
processing systems, vol. 12, 1999.
[148] V. Mnih, A. P. Badia, M. Mirza, A. Graves, T. Lillicrap, T. Harley, D. Silver, and
K. Kavukcuoglu, “Asynchronous methods for deep reinforcement learning,” in International
conference on machine learning. PMLR, 2016, pp. 1928–1937.
[149] J. Schulman, F. Wolski, P. Dhariwal, A. Radford, and O. Klimov, “Proximal policy
optimization algorithms,” arXiv preprint arXiv:1707.06347, 2017.
[150] T. P. Lillicrap, J. J. Hunt, A. Pritzel, N. Heess, T. Erez, Y. Tassa, D. Silver, and
D. Wierstra, “Continuous control with deep reinforcement learning,” arXiv preprint
arXiv:1509.02971, 2015.
[151] J. Schulman, S. Levine, P. Abbeel, M. Jordan, and P. Moritz, “Trust region policy optimization,”
in International conference on machine learning. PMLR, 2015, pp. 1889–1897.
[152] S. Fujimoto, H. Hoof, and D. Meger, “Addressing function approximation error in actor-critic
methods,” in International conference on machine learning. PMLR, 2018, pp.
1587–1596.
[153] T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine, “Soft actor-critic: Off-policy maximum
entropy deep reinforcement learning with a stochastic actor,” in International conference
on machine learning. PMLR, 2018, pp. 1861–1870.
[154] G.-Q. Zhang, L. Cui, R. Mueller, S. Tao, M. Kim, M. Rueschman, S. Mariani, D. Mobley,
and S. Redline, “The national sleep research resource: towards a sleep data commons,”
Journal of the American Medical Informatics Association, vol. 25, no. 10, pp. 1351–1358,
2018.
[155] S. F. Quan, B. V. Howard, C. Iber, J. P. Kiley, F. J. Nieto, G. T. O’Connor, D. M. Rapoport,
S. Redline, J. Robbins, J. M. Samet et al., “The sleep heart health study: design, rationale,
and methods,” Sleep, vol. 20, no. 12, pp. 1077–1085, 1997.