Molecules as graphs
In this work, we use the molecular graph representation defined as follows. A molecule is represented by an undirected graph \({\mathcal {G}}=({\mathcal {V}},{\mathcal {E}})\) with up to m nodes, where \({\mathcal {V}}\) and \({\mathcal {E}}\) denote the set of nodes and the set of edges, respectively. The node vectors \({\mathbf {v}}^i \in {\mathcal {V}}\) and edge vectors \({\mathbf {e}}^{i,j} \in {\mathcal {E}}\) are associated with the heavy atoms of the molecule and the bonds between them, respectively. Note that \({\mathbf {e}}^{i,j}={\mathbf {e}}^{j,i}\) because the graph is undirected. For the i-th atom, \({\mathbf {v}}^i=(v^{i,1},\ldots ,v^{i,p})\) is a p-dimensional vector formed by concatenating three one-hot vectors indicating the atom type, formal charge, and number of explicit hydrogens. The dimension p depends on the dataset used. For the bond between the i-th and j-th atoms, \({\mathbf {e}}^{i,j}=(e^{i,j,1},\ldots ,e^{i,j,q})\) is a q-dimensional one-hot vector indicating the bond type. We kekulize the molecule for simplicity so that the only bond types to consider are single, double, triple, and none, hence \(q=4\). Additionally, the properties of the molecule are represented as a property vector \({\mathbf {y}}=(y^1,\ldots ,y^l)\).
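For concreteness, the sketch below shows one way such node and edge vectors could be built with RDKit. The vocabularies (ATOM_TYPES, FORMAL_CHARGES, EXPLICIT_HS) are illustrative placeholders that would in practice be derived from the dataset, and rows beyond the molecule's heavy atoms are left as zero padding.

```python
# Minimal sketch: one-hot graph featurization of a kekulized molecule (placeholder vocabularies).
import numpy as np
from rdkit import Chem

ATOM_TYPES = ["C", "N", "O", "F"]          # dataset-dependent; placeholder vocabulary
FORMAL_CHARGES = [-1, 0, 1]                # placeholder
EXPLICIT_HS = [0, 1, 2, 3]                 # placeholder
BOND_TYPES = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE]  # index 3 = "none"

def featurize(smiles, m):
    """Return the node matrix V (m x p) and edge tensor E (m x m x 4)."""
    mol = Chem.MolFromSmiles(smiles)
    Chem.Kekulize(mol, clearAromaticFlags=True)       # kekulize so bonds are single/double/triple
    p = len(ATOM_TYPES) + len(FORMAL_CHARGES) + len(EXPLICIT_HS)
    V = np.zeros((m, p))                              # rows for absent atoms stay all-zero (padding)
    E = np.zeros((m, m, 4))
    E[..., 3] = 1.0                                   # initialize every atom pair to the "none" bond type
    for atom in mol.GetAtoms():
        i = atom.GetIdx()
        V[i, ATOM_TYPES.index(atom.GetSymbol())] = 1.0
        V[i, len(ATOM_TYPES) + FORMAL_CHARGES.index(atom.GetFormalCharge())] = 1.0
        V[i, len(ATOM_TYPES) + len(FORMAL_CHARGES) + EXPLICIT_HS.index(atom.GetNumExplicitHs())] = 1.0
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        b = BOND_TYPES.index(bond.GetBondType())
        E[i, j], E[j, i] = 0.0, 0.0                   # undirected: e^{i,j} = e^{j,i}
        E[i, j, b] = E[j, i, b] = 1.0
    return V, E
```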
Graph variational autoencoder
We construct a conditional version of the graph VAE [34] in a non-autoregressive manner. It seeks to find the generative distribution of \({\mathcal {G}}\) conditioned on a latent vector \({\mathbf {z}}\) and a property vector \({\mathbf {y}}\), parameterized by \(\theta \) and denoted as \(p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})\). The prior distributions of \({\mathbf {z}}\) and \({\mathbf {y}}\) are assumed to be \(p({\mathbf {z}})={\mathcal {N}}({\mathbf {z}}|{\mathbf {0}},{\mathbf {I}})\) and \(p({\mathbf {y}})={\mathcal {N}}({\mathbf {y}}|\varvec{\mu }_{\mathbf {y}},\varvec{\Sigma }_{\mathbf {y}})\), respectively. To address the intractability of the posterior distribution \(p_\theta ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})\), we introduce an approximate posterior distribution \(q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})={\mathcal {N}}({\mathbf {z}}|\varvec{\mu }_{\mathbf {z}}({\mathcal {G}},{\mathbf {y}}), \text {diag}(\varvec{\sigma }^2_{\mathbf {z}}({\mathcal {G}},{\mathbf {y}})))\), a normal distribution parameterized by \(\phi \).
The distributions \(q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})\) and \(p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})\) are called the encoder and decoder of the VAE, respectively. For the encoder, we use a message passing neural network (MPNN) [27], a variant of a graph neural network that operates directly on graphs of different sizes and is invariant to graph isomorphism. The encoder takes \({\mathcal {G}}\) and \({\mathbf {y}}\) and outputs the mean vector \(\varvec{\mu }_{\mathbf {z}}({\mathcal {G}},{\mathbf {y}})\) and variance vector \(\varvec{\sigma }^2_{\mathbf {z}}({\mathcal {G}},{\mathbf {y}})\), from which \({\mathbf {z}}\) is sampled via the reparameterization trick as \(\varvec{\mu }_{\mathbf {z}}({\mathcal {G}},{\mathbf {y}})+\varvec{\epsilon } \odot \varvec{\sigma }_{\mathbf {z}}({\mathcal {G}},{\mathbf {y}})\) with \(\varvec{\epsilon } \sim {\mathcal {N}}({\mathbf {0}},{\mathbf {I}})\). The decoder is modeled as a fully-connected neural network that outputs \(mp+m(m-1)q/2\) values at once from \({\mathbf {z}}\) and \({\mathbf {y}}\) with node-wise and edge-wise softmax activation. The output values form a probabilistic graph \(g({\mathbf {z}},{\mathbf {y}})\) composed of m nodes and \(m(m-1)/2\) edges.
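As a minimal illustration (not the exact architecture used in this work), the sketch below shows the reparameterization step and a fully-connected decoder head that emits the \(mp+m(m-1)q/2\) values and applies node-wise and edge-wise softmax. The PyTorch framework, the hidden width, and the two-layer structure are assumptions; the MPNN encoder is abstracted away.

```python
# Sketch of the reparameterization step and the decoder head (shapes taken from the text).
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphDecoder(nn.Module):
    def __init__(self, z_dim, y_dim, m, p, q=4, hidden=512):
        super().__init__()
        self.m, self.p, self.q = m, p, q
        n_edges = m * (m - 1) // 2
        self.net = nn.Sequential(
            nn.Linear(z_dim + y_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, m * p + n_edges * q),   # all mp + m(m-1)q/2 values at once
        )

    def forward(self, z, y):
        out = self.net(torch.cat([z, y], dim=-1))
        node_logits = out[:, : self.m * self.p].view(-1, self.m, self.p)
        edge_logits = out[:, self.m * self.p :].view(-1, self.m * (self.m - 1) // 2, self.q)
        # node-wise and edge-wise softmax give the probabilistic graph g(z, y)
        return F.softmax(node_logits, dim=-1), F.softmax(edge_logits, dim=-1)

def reparameterize(mu, var):
    # z = mu + eps * sigma, with eps ~ N(0, I); var is the variance vector sigma^2 from the encoder
    eps = torch.randn_like(mu)
    return mu + eps * var.sqrt()
```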
The original learning objective of the VAE is given with respect to the parameters \(\phi \) and \(\theta \) as:
$$\begin{aligned} {\mathcal {L}}_\text {VAE} (\phi ,\theta ) = {\mathbb {E}}_{{\mathbf {z}} \sim q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})} \left[ -\log p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}}) \right] + {\mathcal {D}}_\text {KL}(q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})|| p({\mathbf {z}}) ), \end{aligned}$$
(1)
where the first and second terms on the right-hand side are regarded as the reconstruction loss and the regularization loss, respectively. Owing to graph isomorphism, calculating the reconstruction loss requires a graph matching procedure that compares an input graph with its probabilistic reconstruction, which is computationally expensive. For example, the max-pooling matching algorithm has a computational complexity of \({\mathcal {O}}(m^4)\) [28].
To make learning more efficient, we introduce an approximate graph matching procedure that alleviates the computational burden of the reconstruction loss. Additionally, we incorporate reinforcement learning and auxiliary property prediction into the training of the VAE to further improve generation performance. Details of the learning objectives used in this work are presented in the following subsections.
Approximate graph matching
The reconstruction loss \({\mathbb {E}}_{{\mathbf {z}} \sim q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})} \left[ -\log p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}}) \right] \) involves comparing an original input graph \({\mathcal {G}}=({\mathcal {V}},{\mathcal {E}})\) with its reconstruction by the VAE. In this subsection, we denote the reconstruction of \({\mathcal {G}}\) as a probabilistic graph \(\widetilde{{\mathcal {G}}}=(\widetilde{{\mathcal {V}}},\widetilde{{\mathcal {E}}})\) with node vectors \(\widetilde{{\mathbf {v}}}^i \in \widetilde{{\mathcal {V}}}\) and edge vectors \(\widetilde{{\mathbf {e}}}^{i,j} \in \widetilde{{\mathcal {E}}}\). Because the reconstruction loss must be invariant to graph isomorphism, a graph matching procedure that seeks the best possible matching between the two graphs is needed. To avoid this expensive computation, we propose to use approximate graph matching. The main idea is to approximate the distance between \({\mathcal {G}}\) and \(\widetilde{{\mathcal {G}}}\) by comparing the numbers of atom types, bond types, atom-bond pair types, and atom-bond-atom pair types.
Assuming that each edge vector is represented as the four-dimensional vector \({\mathbf {e}}^{i,j}=(e^{i,j(\texttt {single})},e^{i,j(\texttt {double})},e^{i,j(\texttt {triple})},e^{i,j(\texttt {none})})\), the reconstruction loss is approximated as follows:
$$\begin{aligned}&{\mathbb {E}}_{{\mathbf {z}} \sim q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})} \left[ -\log p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}}) \right] \simeq \left\| \left( \sum _i{{\mathbf {v}}^i}\right) - \left( \sum _i{\widetilde{{\mathbf {v}}}^i}\right) \right\| ^2 \nonumber \\&\quad + \left\| \left( \sum _{i,j}{{\mathbf {e}}^{i,j}}\right) - \left( \sum _{i,j}{\widetilde{{\mathbf {e}}}^{i,j}}\right) \right\| ^2\nonumber \\&\quad + \sum _{\texttt {b}\in {\{\texttt {single},\texttt {double},\texttt {triple}\}}} \left\| \left( \sum _{i,j}{e^{i,j (\texttt {b})}{\mathbf {v}}^i }\right) - \left( \sum _{i,j}{{\widetilde{e}}^{i,j (\texttt {b})}\widetilde{{\mathbf {v}}}^i }\right) \right\| ^2\nonumber \\&\quad + \sum _{\texttt {b}\in {\{\texttt {single},\texttt {double},\texttt {triple}\}}} \left\| \left( \sum _{i,j}{e^{i,j (\texttt {b})}{\mathbf {v}}^i {{\mathbf {v}}^{jT}}}\right) - \left( \sum _{i,j}{{\widetilde{e}}^{i,j (\texttt {b})}\widetilde{{\mathbf {v}}}^i {\widetilde{{\mathbf {v}}}^{jT}}}\right) \right\| ^2, \end{aligned}$$
(2)
where \({{\mathbf {v}}}^i \in {{\mathcal {V}}}\), \({{\mathbf {e}}}^{i,j} \in {{\mathcal {E}}}\), \(\widetilde{{\mathbf {v}}}^i \in \widetilde{{\mathcal {V}}}\), and \(\widetilde{{\mathbf {e}}}^{i,j} \in \widetilde{{\mathcal {E}}}\). When calculating the approximate reconstruction loss, we discard the non-atom and non-bond types from the vectors. The first, second, third, and fourth terms on the right-hand side correspond to comparisons of the numbers of atom types, bond types, atom-bond pair types, and atom-bond-atom pair types, respectively. Because each term sums over all nodes and edges in a graph, it is independent of node ordering and thus invariant to graph isomorphism. All the operations in the above equation are differentiable.
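A possible dense-tensor implementation of Eq. (2) is sketched below, assuming a node matrix of shape \(m \times p\) and an edge tensor of shape \(m \times m \times 4\) whose last channel is the none bond type; a padding ("no atom") channel in the node vectors, if present, would be dropped in the same way.

```python
# Sketch of the approximate reconstruction loss in Eq. (2); "none" bond channel discarded.
import torch

def approx_recon_loss(V, E, V_tilde, E_tilde):
    """V, V_tilde: (m, p) node one-hots / probabilities.
    E, E_tilde: (m, m, 4) edge one-hots / probabilities; channel 3 is the "none" type."""
    Eb, Eb_tilde = E[..., :3], E_tilde[..., :3]                 # keep single/double/triple only

    # (1) counts of atom types
    loss = ((V.sum(0) - V_tilde.sum(0)) ** 2).sum()
    # (2) counts of bond types
    loss += ((Eb.sum((0, 1)) - Eb_tilde.sum((0, 1))) ** 2).sum()
    # (3) counts of atom-bond pair types: sum_{i,j} e^{ij(b)} v^i for each bond type b
    ab = torch.einsum('ijb,ip->bp', Eb, V)
    ab_tilde = torch.einsum('ijb,ip->bp', Eb_tilde, V_tilde)
    loss += ((ab - ab_tilde) ** 2).sum()
    # (4) counts of atom-bond-atom pair types: sum_{i,j} e^{ij(b)} v^i v^{jT}
    aba = torch.einsum('ijb,ip,jq->bpq', Eb, V, V)
    aba_tilde = torch.einsum('ijb,ip,jq->bpq', Eb_tilde, V_tilde, V_tilde)
    loss += ((aba - aba_tilde) ** 2).sum()
    return loss
```

Every operation above is differentiable, so the loss can be backpropagated through the decoder's probabilistic outputs.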
Reinforcement learning
We further improve the VAE via reinforcement learning with the aim of generating chemically valid molecules. We adopt a deterministic policy gradient framework in which the decoder of the VAE is regarded as a policy network that takes the two vectors \({\mathbf {z}}\) and \({\mathbf {y}}\) as its state input and outputs a probabilistic graph as an action. The reward for the action is the chemical validity of the probabilistic graph, which is evaluated using an external reward function R. The reward function returns 1 if the probabilistic graph can be decoded into a chemically valid molecular graph and 0 otherwise. In this work, the chemical validity of a molecular graph \({\mathcal {G}}\) is evaluated via a sanitization check. With this reward function, the policy network learns to generate probabilistic graphs whose expected reward approaches 1.
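A minimal sketch of such a reward function is given below: the discretized graph is assembled into an RDKit molecule and passed through RDKit's sanitization. Formal charges and explicit hydrogens are omitted for brevity, and the vocabularies are the same illustrative placeholders as in the earlier featurization sketch.

```python
# Sketch of the external reward R: build an RDKit molecule from a discretized graph and sanitize it.
from rdkit import Chem

ATOM_TYPES = ["C", "N", "O", "F"]   # placeholder vocabulary, as before
BOND_TYPES = [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE]

def reward(node_labels, edge_labels):
    """node_labels: list of atom-type indices (padding removed);
    edge_labels[i][j]: bond-type index, where 3 means no bond."""
    mol = Chem.RWMol()
    for a in node_labels:
        mol.AddAtom(Chem.Atom(ATOM_TYPES[a]))
    n = len(node_labels)
    for i in range(n):
        for j in range(i + 1, n):
            if edge_labels[i][j] < 3:
                mol.AddBond(i, j, BOND_TYPES[edge_labels[i][j]])
    try:
        Chem.SanitizeMol(mol)       # raises an exception if the molecule is chemically invalid
        return 1.0
    except Exception:
        return 0.0
```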
We wish to optimize the molecular graph generation of the VAE to maximize the external reward function R. Because the reward function is non-differentiable, it cannot be incorporated directly into the learning procedure. We therefore build a reward network r that approximates the reward function R. The reward network is modeled as an MPNN with a sigmoid output. It takes a probabilistic graph as input and predicts its actual reward value. Because the reward network is differentiable, gradients can be backpropagated through it to the VAE. We train the VAE to generate probabilistic graphs that maximize the output of the reward network.
For the learning objective, we derive two additional losses \({\mathcal {L}}_\text {RL} (\phi ,\theta )\) and \({{\mathcal {L}}_\text {RL} (r)}\) as:
$$\begin{aligned} {{\mathcal {L}}_\text {RL} (\phi ,\theta )} & = {} {\mathbb {E}}_{{\mathbf {z}} \sim q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})}{\mathbb {E}}_{{\mathcal {G}} \sim p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})}{[-\log r({\mathcal {G}}) ]} \nonumber \\ & \quad + {\mathbb {E}}_{{\mathbf {z}} \sim p({\mathbf {z}}), {\mathbf {y}} \sim p({\mathbf {y}})}{\mathbb {E}}_{{\mathcal {G}} \sim p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})}{[-\log r({\mathcal {G}})]}; \end{aligned}$$
(3)
$$\begin{aligned} {{\mathcal {L}}_\text {RL} (r)} &= {} {-R({\mathcal {G}})\log r({\mathcal {G}})-(1-R({\mathcal {G}}))\log (1-r({\mathcal {G}}))}\nonumber \\ & \quad +{\mathbb {E}}_{{\mathbf {z}} \sim q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})}{\mathbb {E}}_{{\mathcal {G}} \sim p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})}{[- R({\mathcal {G}}) \log r({\mathcal {G}}) - (1-R({\mathcal {G}})) \log (1-r({\mathcal {G}})) ]}\nonumber \\ & \quad + {\mathbb {E}}_{{\mathbf {z}} \sim p({\mathbf {z}}), {\mathbf {y}} \sim p({\mathbf {y}})}{\mathbb {E}}_{{\mathcal {G}} \sim p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})}{[- R({\mathcal {G}}) \log r({\mathcal {G}}) - (1-R({\mathcal {G}})) \log (1-r({\mathcal {G}})) ]}. \end{aligned}$$
(4)
The VAE is trained to minimize \({\mathcal {L}}_\text {RL} (\phi ,\theta )\), while the reward network r is trained to minimize \({{\mathcal {L}}_\text {RL} (r)}\).
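One possible translation of Eqs. (3)-(4) into code is sketched below. It assumes a decoder and a reward network that both operate on probabilistic graphs, feeds the decoder's probabilistic outputs directly into r (which keeps the computation differentiable), and omits the first, data-graph term of Eq. (4) for brevity; R_recon and R_prior denote the external rewards of the decoded graphs.

```python
# Sketch of the RL losses in Eqs. (3)-(4); reward_net is an MPNN with a sigmoid output.
import torch
import torch.nn.functional as F

def rl_losses(decoder, reward_net, z_post, y, z_prior, y_prior, R_recon, R_prior, eps=1e-8):
    """z_post ~ q_phi(z|G, y); z_prior ~ p(z), y_prior ~ p(y);
    R_recon, R_prior: external 0/1 rewards of the corresponding decoded graphs."""
    node_rec, edge_rec = decoder(z_post, y)            # probabilistic graph, reconstruction path
    node_pri, edge_pri = decoder(z_prior, y_prior)     # probabilistic graph decoded from the priors

    # Eq. (3): train the decoder so that r scores its outputs as valid
    loss_vae = -(torch.log(reward_net(node_rec, edge_rec) + eps).mean()
                 + torch.log(reward_net(node_pri, edge_pri) + eps).mean())

    # Eq. (4): fit r to the external reward R with binary cross-entropy; decoder outputs are
    # detached so this loss updates only r. (The data-graph term of Eq. (4) is omitted here.)
    loss_r = (F.binary_cross_entropy(reward_net(node_rec.detach(), edge_rec.detach()), R_recon)
              + F.binary_cross_entropy(reward_net(node_pri.detach(), edge_pri.detach()), R_prior))
    return loss_vae, loss_r
```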
It should be noted that we can impose extra domain-specific constraints regarding structures or properties on the external reward function R for constrained generation. For example, we can define a blacklist of undesired substructures and make the value of the reward function 0 when its input contains any substructure in the blacklist. This prevents generated molecules from having undesired substructures.
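As an illustration, such a constrained reward could be implemented as below; the SMARTS patterns in the blacklist are only examples, not the blacklist used in this work.

```python
# Sketch: constrained reward that also rejects blacklisted substructures (illustrative SMARTS).
from rdkit import Chem

BLACKLIST = [Chem.MolFromSmarts(s) for s in ["[N+](=O)[O-]", "C(=O)Cl"]]  # e.g. nitro, acyl chloride

def constrained_reward(mol):
    """mol: sanitized RDKit molecule or None; returns 0 if invalid or blacklisted, 1 otherwise."""
    if mol is None:
        return 0.0
    if any(mol.HasSubstructMatch(pat) for pat in BLACKLIST):
        return 0.0
    return 1.0
```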
Auxiliary property prediction
Augmenting a generative model with side information is known to improve the quality of generated samples as well as the stability of model training [19, 35]. We incorporate auxiliary property prediction into the VAE learning to generate probabilistic graphs that correspond to desired properties and to diversify the generated outcomes. We build a predictor network f as an MPNN with l linear outputs. It learns from the training dataset to predict the property vector \({\mathbf {y}}\) of a given graph. Because gradients can be backpropagated through the predictor network to the VAE, we train the VAE to generate probabilistic graphs whose property vectors \({\mathbf {y}}\) can be recovered by the predictor network.
For learning with auxiliary property prediction, we derive two additional losses \({\mathcal {L}}_\text {Y} (\phi ,\theta )\) and \({{\mathcal {L}}_\text {Y} (f)}\) as:
$$\begin{aligned} {\mathcal {L}}_\text {Y} (\phi ,\theta ) & = {} {\mathbb {E}}_{{\mathbf {z}} \sim q_\phi ({\mathbf {z}}|{\mathcal {G}},{\mathbf {y}})}{\mathbb {E}}_{{\mathcal {G}} \sim p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})}{ \left[ R({\mathcal {G}}) \cdot ||{\mathbf {y}}-f({\mathcal {G}})||^2 \right] }\nonumber \\ & \quad + {\mathbb {E}}_{{\mathbf {z}} \sim p({\mathbf {z}}), {\mathbf {y}} \sim p({\mathbf {y}})}{\mathbb {E}}_{{\mathcal {G}} \sim p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})}{\left[ R({\mathcal {G}}) \cdot ||{\mathbf {y}}-f({\mathcal {G}})||^2 \right] }; \end{aligned}$$
(5)
$$\begin{aligned} {\mathcal {L}}_\text {Y} (f) &= {} ||{\mathbf {y}}-f({\mathcal {G}})||^2. \end{aligned}$$
(6)
Only the probabilistic graphs that are deemed valid by the external reward function R contribute to the first loss \({\mathcal {L}}_\text {Y} (\phi ,\theta )\). The VAE is trained to minimize \({\mathcal {L}}_\text {Y} (\phi ,\theta )\), while the predictor network f is trained to minimize \({{\mathcal {L}}_\text {Y} (f)}\).
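A minimal sketch of Eqs. (5)-(6) might look as follows; the predictor f is assumed to take probabilistic graphs, graph_recon/graph_prior/graph_data are (node, edge) tensor pairs for the reconstructed, prior-decoded, and training graphs, and R_recon/R_prior are the 0/1 validity rewards used as masks.

```python
# Sketch of the auxiliary property losses in Eqs. (5)-(6); predictor is an MPNN with l linear outputs.
import torch
import torch.nn.functional as F

def property_losses(predictor, graph_recon, graph_prior, graph_data, y, y_prior, R_recon, R_prior):
    """R_recon, R_prior: 0/1 validity rewards used to mask invalid probabilistic graphs."""
    # Eq. (5): only graphs deemed valid by R contribute to the decoder's property loss
    loss_vae = (R_recon * ((y - predictor(*graph_recon)) ** 2).sum(-1)).mean() \
             + (R_prior * ((y_prior - predictor(*graph_prior)) ** 2).sum(-1)).mean()
    # Eq. (6): the predictor itself is fit on the training graphs and their true properties
    loss_f = F.mse_loss(predictor(*graph_data), y)
    return loss_vae, loss_f
```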
Learning from data
The proposed model is composed of four main components: the encoder network \(q_\phi \), the decoder network \(p_\theta \), the reward network r, and the predictor network f. The full learning objective combines the vanilla VAE objective (1), whose reconstruction term is approximated by (2), with the objectives for reinforcement learning (3–4) and auxiliary property prediction (5–6). The objective functions for the VAE (encoder and decoder) and for the auxiliary networks (reward and predictor) are \({\mathcal {J}}_1\) and \({\mathcal {J}}_2\), respectively, given as:
$$\begin{aligned} {\mathcal {J}}_1 (\phi ,\theta ) &= {} \sum _{({\mathcal {G}}_t,{\mathbf {y}}_t)\sim {p}_\text {data}}\left[ {\mathcal {L}}_\text {VAE} (\phi ,\theta )+ \beta _1 \cdot {\mathcal {L}}_\text {RL} (\phi ,\theta )+ \beta _2 \cdot {\mathcal {L}}_\text {Y} (\phi ,\theta )\right] ; \end{aligned}$$
(7)
$$\begin{aligned} {\mathcal {J}}_2 (r,f) &= {} \sum _{({\mathcal {G}}_t,{\mathbf {y}}_t)\sim {p}_\text {data}}\left[ \beta _1 \cdot {\mathcal {L}}_\text {RL} (r)+ \beta _2 \cdot {\mathcal {L}}_\text {Y} (f)\right] , \end{aligned}$$
(8)
where \(\beta _1\) and \(\beta _2\) are hyperparameters that control the trade-off between different learning objectives.
Given an empirical data distribution \(p_\text {data}({\mathcal {G}},{\mathbf {y}})\), we train the entire model by minimizing the two objective functions \({\mathcal {J}}_1(\phi , \theta )\) and \({\mathcal {J}}_2(r,f)\) simultaneously. In each iteration, a training batch X is sampled from the data distribution. The VAE parameters \(\phi \) and \(\theta \) are updated via gradient descent on \({\mathcal {J}}_1(\phi , \theta )\) over X, and the reward network r and predictor network f are updated via gradient descent on \({\mathcal {J}}_2(r,f)\). Algorithm 1 presents the pseudocode of the learning procedure.
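A sketch of this alternating update is given below. The names batch_vae_loss, batch_rl_losses, and batch_prop_losses stand for assumed per-batch wrappers around Eqs. (1)-(6) that sample \({\mathbf {z}}\) and \({\mathbf {y}}\) and evaluate the external reward R; the Adam optimizers, learning rates, and epoch count are placeholders rather than the settings used in this work.

```python
# Sketch of the training loop (Algorithm 1): alternate gradient steps on J1 (Eq. 7) and J2 (Eq. 8).
import torch

def train(encoder, decoder, reward_net, predictor, loader, beta1=1.0, beta2=1.0, epochs=100):
    opt_vae = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
    opt_aux = torch.optim.Adam(list(reward_net.parameters()) + list(predictor.parameters()), lr=1e-3)
    for _ in range(epochs):
        for batch in loader:                                     # training batch X ~ p_data
            l_vae = batch_vae_loss(encoder, decoder, batch)      # Eq. (1) with the approximation (2)
            l_rl_vae, l_rl_r = batch_rl_losses(encoder, decoder, reward_net, batch)   # Eqs. (3)-(4)
            l_y_vae, l_y_f = batch_prop_losses(encoder, decoder, predictor, batch)    # Eqs. (5)-(6)

            # J1 (Eq. 7): update the encoder/decoder parameters phi, theta
            j1 = l_vae + beta1 * l_rl_vae + beta2 * l_y_vae
            opt_vae.zero_grad()
            j1.backward(retain_graph=True)   # retain the graph in case J2 reuses shared activations
            opt_vae.step()

            # J2 (Eq. 8): update the reward network r and the predictor f
            j2 = beta1 * l_rl_r + beta2 * l_y_f
            opt_aux.zero_grad()
            j2.backward()
            opt_aux.step()
```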
Molecular graph generation
Once the model is trained, we use the decoder \(p_\theta ({\mathcal {G}}|{\mathbf {z}},{\mathbf {y}})\) to generate molecular graphs. For unconditional generation, \({\mathbf {y}}_*\) is sampled from its prior distribution \(p({\mathbf {y}})\). To conditionally generate molecular graphs, \({\mathbf {y}}_*\) is sampled from a conditional distribution. For example, if the target condition is given as \(y^k=\tau \), then \({\mathbf {y}}_* \sim p({\mathbf {y}}|y^k=\tau )\). We sample \({\mathbf {z}}_*\) from \(p({\mathbf {z}})\). Given \({\mathbf {y}}_*\) and \({\mathbf {z}}_*\), the decoder returns a probabilistic output, which is decoded based on node-wise and edge-wise argmax to obtain a molecular graph \({\mathcal {G}}_*\) as
$$\begin{aligned} {\mathcal {G}}_* =\underset{{\mathcal {G}}}{\text {argmax }}p_\theta ({\mathcal {G}}|{\mathbf {z}}={\mathbf {z}}_*,{\mathbf {y}}={\mathbf {y}}_*). \end{aligned}$$
(9)
We use a simple decoding method that discretizes the probabilistic outputs into molecular graphs. Some studies have reported that post-processing the probabilistic outputs with methods such as maximum spanning tree [28] and beam search [36] can improve the validity of the generated molecular graphs.
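For illustration, conditioning the Gaussian prior \(p({\mathbf {y}})\) on \(y^k=\tau \) and decoding by argmax could be implemented as follows. The decoder interface matches the earlier sketch, and \(\varvec{\mu }_{\mathbf {y}}\) and \(\varvec{\Sigma }_{\mathbf {y}}\) are assumed to be estimated from the training data; the standard conditional-Gaussian formula is used for \(p({\mathbf {y}}|y^k=\tau )\).

```python
# Sketch: conditional generation with y^k fixed to tau, followed by node-wise/edge-wise argmax.
import numpy as np
import torch

def sample_y_given(mu_y, cov_y, k, tau):
    """Sample y ~ p(y | y^k = tau) from a multivariate normal prior N(mu_y, cov_y)."""
    idx = [i for i in range(len(mu_y)) if i != k]
    s_kk = cov_y[k, k]
    s_rk = cov_y[np.ix_(idx, [k])]                       # covariances between the other entries and y^k
    mu_cond = mu_y[idx] + (s_rk[:, 0] / s_kk) * (tau - mu_y[k])
    cov_cond = cov_y[np.ix_(idx, idx)] - (s_rk @ s_rk.T) / s_kk
    y = np.empty(len(mu_y))
    y[k] = tau
    y[idx] = np.random.multivariate_normal(mu_cond, cov_cond)
    return y

def generate(decoder, mu_y, cov_y, k, tau, z_dim):
    y = torch.tensor(sample_y_given(mu_y, cov_y, k, tau), dtype=torch.float32).unsqueeze(0)
    z = torch.randn(1, z_dim)                            # z_* ~ p(z) = N(0, I)
    node_probs, edge_probs = decoder(z, y)               # probabilistic graph g(z_*, y_*)
    node_labels = node_probs.argmax(-1)                  # node-wise argmax
    edge_labels = edge_probs.argmax(-1)                  # edge-wise argmax
    return node_labels, edge_labels                      # to be assembled into the molecular graph G_*
```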