LBANN Implementation of D-MPNN on CSD-10K Data #2361

Open · wants to merge 8 commits into base: develop
155 changes: 155 additions & 0 deletions applications/FLASK/MPNN/MPN.py
@@ -0,0 +1,155 @@
import lbann
from lbann.modules import Module, ChannelwiseFullyConnectedModule


class MPNEncoder(Module):
""" """

global_count = 0

def __init__(
self,
atom_fdim,
bond_fdim,
hidden_size,
activation_func,
max_atoms,
bias=False,
depth=3,
name=None,
):
MPNEncoder.global_count += 1
# For debugging
self.name = name if name else "MPNEncoder_{}".format(MPNEncoder.global_count)

self.atom_fdim = atom_fdim
self.bond_fdim = bond_fdim
self.max_atoms = max_atoms
self.hidden_size = hidden_size
self.bias = bias
self.depth = depth
self.activation_func = activation_func

# Channelwise fully connected layer: (*, *, bond_fdim) -> (*, *, hidden_size)
self.W_i = ChannelwiseFullyConnectedModule(
self.hidden_size,
bias=self.bias,
activation=self.activation_func,
name=self.name + "W_i",
)

        # Channelwise fully connected layer: (*, *, hidden_size) -> (*, *, hidden_size)
self.W_h = ChannelwiseFullyConnectedModule(
self.hidden_size,
bias=self.bias,
activation=self.activation_func,
name=self.name + "W_h",
)
        # Channelwise fully connected layer: (*, *, atom_fdim + hidden_size) -> (*, *, hidden_size)
self.W_o = ChannelwiseFullyConnectedModule(
self.hidden_size,
bias=True,
activation=self.activation_func,
name=self.name + "W_o",
)

def message(
self,
bond_features,
bond2atom_mapping,
atom2bond_sources_mapping,
atom2bond_target_mapping,
bond2revbond_mapping,
):
""" """
messages = self.W_i(bond_features)
for depth in range(self.depth - 1):
nei_message = lbann.Gather(messages, atom2bond_target_mapping, axis=0)

a_message = lbann.Scatter(
nei_message,
atom2bond_sources_mapping,
dims=[self.max_atoms, self.hidden_size],
axis=0,
)

bond_message = lbann.Gather(
a_message,
bond2atom_mapping,
axis=0,
name=self.name + f"_bond_messages_{depth}",
)
rev_message = lbann.Gather(
messages,
bond2revbond_mapping,
axis=0,
name=self.name + f"_rev_bond_messages_{depth}",
)

messages = lbann.Subtract(bond_message, rev_message)
messages = self.W_h(messages)

return messages

def aggregate(self, atom_messages, bond_messages, bond2atom_mapping):
""" """
a_messages = lbann.Scatter(
bond_messages,
bond2atom_mapping,
axis=0,
dims=[self.max_atoms, self.hidden_size],
)

atoms_hidden = lbann.Concatenation(
[atom_messages, a_messages], axis=1, name=self.name + "atom_hidden_concat"
)
return self.W_o(atoms_hidden)

def readout(self, atom_encoded_features, graph_mask, num_atoms):
""" """
mol_encoding = lbann.Scatter(
atom_encoded_features,
graph_mask,
name=self.name + "graph_scatter",
axis=0,
dims=[1, self.hidden_size],
)
num_atoms = lbann.Reshape(num_atoms, dims=[1, 1])

mol_encoding = lbann.Divide(
mol_encoding,
lbann.Tessellate(
num_atoms,
dims=[1, self.hidden_size],
name=self.name + "expand_num_nodes",
),
name=self.name + "_reduce",
)
return mol_encoding

def forward(
self,
atom_input_features,
bond_input_features,
atom2bond_sources_mapping,
atom2bond_target_mapping,
bond2atom_mapping,
bond2revbond_mapping,
graph_mask,
num_atoms,
):
""" """
bond_messages = self.message(
bond_input_features,
bond2atom_mapping,
atom2bond_sources_mapping,
atom2bond_target_mapping,
bond2revbond_mapping,
)

atom_encoded_features = self.aggregate(
atom_input_features, bond_messages, bond2atom_mapping
)

readout = self.readout(atom_encoded_features, graph_mask, num_atoms)
return readout
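
A minimal usage sketch for the encoder above, assuming the feature sizes and hyperparameters from `config.py` (shown later in this diff) and an LBANN activation layer class such as `lbann.Relu`. The graph input tensors (features, index mappings, mask, atom count) are assumed to already exist as LBANN layers produced by the data reader, which is not part of this file:

```
import lbann
from MPN import MPNEncoder
from config import DATASET_CONFIG, HYPERPARAMETERS_CONFIG

# Instantiate the encoder with the defaults from config.py
encoder = MPNEncoder(
    atom_fdim=DATASET_CONFIG["ATOM_FEATURES"],
    bond_fdim=DATASET_CONFIG["BOND_FEATURES"],
    hidden_size=HYPERPARAMETERS_CONFIG["HIDDEN_SIZE"],
    activation_func=lbann.Relu,
    max_atoms=DATASET_CONFIG["MAX_ATOMS"],
    depth=HYPERPARAMETERS_CONFIG["MPN_DEPTH"],
)

# The forward pass takes the graph tensors in this order (see forward() above):
# mol_encoding = encoder.forward(atom_features, bond_features,
#                                atom2bond_sources, atom2bond_targets,
#                                bond2atom, bond2revbond,
#                                graph_mask, num_atoms)
```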
79 changes: 79 additions & 0 deletions applications/FLASK/MPNN/PrepareDataset.py
@@ -0,0 +1,79 @@
from config import DATASET_CONFIG
from tqdm import tqdm
import numpy as np
from chemprop.args import TrainArgs
from chemprop.features import reset_featurization_parameters
from chemprop.data import MoleculeDataLoader, utils
import os.path as osp
import pickle


def retrieve_dual_mapping(atom2bond, ascope):
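    """Return an array of [atom_index, bond_index] pairs mapping each atom in
    the batch to its associated bonds (indices shifted down to drop the null node)."""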
atom_bond_mapping = []
    for a_start, a_size in ascope:
_a2b = atom2bond.narrow(0, a_start, a_size)
for row, possible_bonds in enumerate(_a2b):
for bond in possible_bonds:
ind = bond.item() - 1 # Shift by 1 to account for null nodes
if ind >= 0:
atom_bond_mapping.append([row, ind])
return np.array(atom_bond_mapping)


def PrepareDataset(save_file_name, target_file):
data_file = osp.join(DATASET_CONFIG["DATA_DIR"], target_file)

arguments = [
"--data_path",
data_file,
"--dataset_type",
"regression",
"--save_dir",
"./data/10k_dft_density",
]
args = TrainArgs().parse_args(arguments)
reset_featurization_parameters()
data = utils.get_data(data_file, args=args)
# Need to use the data loader as the featurization happens in the dataloader
# Only use 1 mol as in LBANN we do not do coalesced batching (yet)
dataloader = MoleculeDataLoader(data, batch_size=1)
lbann_data = []
for mol in tqdm(dataloader):
mol_data = {}

mol_data["target"] = mol.targets()[0][0]
mol_data["num_atoms"] = mol.number_of_atoms[0][0]
# Multiply by 2 for directional bonds
mol_data["num_bonds"] = mol.number_of_bonds[0][0] * 2

mol_batch = mol.batch_graph()[0]
f_atoms, f_bonds, a2b, b2a, b2revb, ascope, bscope = mol_batch.get_components(
False
)

# shift by 1 as we don't use null nodes as in the ChemProp implementation
mol_data["atom_features"] = f_atoms[1:].numpy()
mol_data["bond_features"] = f_bonds[1:].numpy()
dual_graph_mapping = retrieve_dual_mapping(a2b, ascope)

mol_data['dual_graph_atom2bond_source'] = dual_graph_mapping[:, 0]
mol_data['dual_graph_atom2bond_target'] = dual_graph_mapping[:, 1]

# subtract 1 to shift the indices
mol_data['bond_graph_source'] = b2a[1:].numpy() - 1
mol_data['bond_graph_target'] = b2revb[1:].numpy() - 1

lbann_data.append(mol_data)

save_file = osp.join(DATASET_CONFIG["DATA_DIR"], save_file_name)
with open(save_file, 'wb') as f:
pickle.dump(lbann_data, f)


def main():
PrepareDataset("10k_density_lbann.bin", "10k_dft_density_data.csv")
PrepareDataset("10k_hof_lbann.bin", "10k_dft_hof_data.csv")


if __name__ == "__main__":
main()
38 changes: 38 additions & 0 deletions applications/FLASK/MPNN/README.md
@@ -0,0 +1,38 @@
# ChemProp on LBANN

## Prepare Dataset (optional)

Follow these steps if you are not on a system with access to the pre-generated LBANN data, or if you need to regenerate the data files in a format ingestible by LBANN.

### Requirements

```
chemprop
numpy
torch
```

### Generate Data

The data generator reads from and writes to the `DATA_DIR` directory set in `config.py`. Update that entry to read from and store to a custom directory.
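
For reference, the relevant entries in `config.py` (the path below is the default committed in this PR; point `DATA_DIR` at your own copy of the CSD-10K CSV files):

```
DATASET_CONFIG: dict = {
    # ...
    "DATA_DIR": "/p/vast1/lbann/datasets/FLASK/CSD10K",  # change to your data directory
    "TARGET_FILE": "10k_dft_density_data.csv",  # or 10k_dft_hof_data.csv for heat of formation
}
```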


Generate the data by calling:


```
python PrepareDataset.py
```
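
The script writes one pickled file per target property (`10k_density_lbann.bin` and `10k_hof_lbann.bin`) into `DATA_DIR`, each holding a list of per-molecule dictionaries. A quick way to sanity-check the output (the path below is a placeholder; use your `DATA_DIR`):

```
import pickle

with open("/path/to/DATA_DIR/10k_density_lbann.bin", "rb") as f:
    lbann_data = pickle.load(f)

sample = lbann_data[0]  # one dictionary per molecule
print(sample["target"], sample["num_atoms"], sample["num_bonds"])
print(sample["atom_features"].shape, sample["bond_features"].shape)
```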

## Run the Trainer

### Hyperparameters

The hyperparameters for the model and training algorithms can be set in `config.py`.
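
For reference, the defaults shipped in `config.py` (shown in full later in this PR):

```
HYPERPARAMETERS_CONFIG: dict = {
    "HIDDEN_SIZE": 300,  # hidden bond/atom representation size
    "LR": 0.001,         # learning rate
    "BATCH_SIZE": 64,
    "EPOCH": 100,
    "MPN_DEPTH": 3,      # message-passing depth
}
```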


### Run the trainer


### Results

21 changes: 21 additions & 0 deletions applications/FLASK/MPNN/config.py
@@ -0,0 +1,21 @@
# Dataset feature defaults
# In general, don't change these unless using custom data - S.Z.

DATASET_CONFIG: dict = {
"MAX_ATOMS": 100, # The number of maximum atoms in CSD dataset
"MAX_BONDS": 224, # The number of maximum bonds in CSD dataset
"ATOM_FEATURES": 133,
"BOND_FEATURES": 147,
"DATA_DIR": "/p/vast1/lbann/datasets/FLASK/CSD10K",
"TARGET_FILE": "10k_dft_density_data.csv" # Change to 10k_dft_hof_data.csv for heat of formation
}

# Hyperparameters used to set up the trainer and MPN
# These can be changed freely
HYPERPARAMETERS_CONFIG: dict = {
"HIDDEN_SIZE": 300,
"LR": 0.001,
"BATCH_SIZE": 64,
"EPOCH": 100,
"MPN_DEPTH": 3
}