From f2b7ccb8a09a57931d90331db40bd0df96749e27 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Oct 2023 09:48:49 -0700 Subject: [PATCH 1/8] Adding lbann impl files for d-MPNN model on 10K-CSD data --- applications/FLASK/MPNN/MPN.py | 102 +++++++++++++++++++++++++++++ applications/FLASK/MPNN/README.md | 0 applications/FLASK/MPNN/config.py | 19 ++++++ applications/FLASK/MPNN/dataset.py | 0 applications/FLASK/MPNN/model.py | 29 ++++++++ applications/FLASK/MPNN/train.py | 0 6 files changed, 150 insertions(+) create mode 100644 applications/FLASK/MPNN/MPN.py create mode 100644 applications/FLASK/MPNN/README.md create mode 100644 applications/FLASK/MPNN/config.py create mode 100644 applications/FLASK/MPNN/dataset.py create mode 100644 applications/FLASK/MPNN/model.py create mode 100644 applications/FLASK/MPNN/train.py diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py new file mode 100644 index 00000000000..61409ad884e --- /dev/null +++ b/applications/FLASK/MPNN/MPN.py @@ -0,0 +1,102 @@ +import lbann +from lbann.modules import Module, ChannelwiseFullyConnectedModule + + +class MPNEncoder(Module): + """ + """ + global_count = 0 + + def __init__(self, atom_fdim, bond_fdim, hidden_size, activation_func, bias=False, depth=3, name=None): + + MPNEncoder.global_count += 1 + # For debugging + self.name = (name + if name + else 'MPNEncoder_{}'.format(MPNEncoder.global_count)) + + self.atom_fdim = atom_fdim + self.bond_fdim = bond_fdim + self.hidden_size = hidden_size + self.bias = bias + self.depth = depth + self.activation_func = activation_func + + # Channelwise fully connected layer: (*, *, bond_fdim) -> (*, *, hidden_size) + self.W_i = \ + ChannelwiseFullyConnectedModule(self.hidden_size, + bias=self.bias, + activation=self.activation_func, + name=self.name + "W_i") + + # Channelwise fully connected layer (*, *, hidden_size) -> (*, *, hidden_size)) + self.W_h = \ + ChannelwiseFullyConnectedModule(self.hidden_size, + bias=self.bias, + activation=self.activation_func, + name=self.name + "W_h") + # Channelwise fully connected layer (*, *, atom_fdim + hidden_size) -> (*, *, hidden_size)) + self.W_o = \ + ChannelwiseFullyConnectedModule(self.hidden_size, + bias=True, + activation=self.activation_func, + name=self.name + "W_o") + + + def message(self, bond_features, + bond2atom_mapping, + atom2bond_mapping, + bond2revbond_mapping, + MAX_ATOMS): + """ + """ + messages = self.W_i(bond_features) + for depth in range(self.depth - 1): + a_message = lbann.Scatter(messages, bond2atom_mapping) + bond_message = lbann.Gather(a_message, atom2bond_mapping) + rev_message = lbann.Gather(messages, bond2revbond_mapping) + messages = lbann.SubtractOperator(bond_message, rev_message) + + messages = self.W_h(messages) + return messages + + + def aggregate(self, atom_messages, bond_messages, bond2atom_mapping): + """ + """ + a_messages = lbann.Scatter(bond_messages, bond2atom_mapping) + atoms_hidden = lbann.Concatentate([atom_messages, atom_messages], + dim=0) + return self.W_o(atoms_hidden) + + + def readout(self, atom_encoded_features, graph_mask, num_atoms, max_atoms): + """ + """ + mol_encoding = lbann.Scatter(atom_encoded_features, + graph_mask, + name=self.name + "graph_scatter") + mol_encoding = lbann.DivideOperator(mol_encoding, lbann.Tessallate(num_atoms, + dims=[max_atoms, 1])) + return mol_encoding + + def forward(self, + atom_input_features, + bond_input_features, + atom2bond_mapping, + bond2atom_mapping, + bond2revbond_mapping, + graph_mask, num_atoms, max_atoms): + """ + """ + bond_messages = self.message(bond_input_features, + bond2atom_mapping, + atom2bond_mapping, + bond2revbond_mapping, + max_atoms) + + atom_encoded_features = self.aggregate(atom_input_features, bond_messages, + bond2atom_mapping) + + readout = self.readout(atom_encoded_features, graph_mask, num_atoms, max_atoms) + return readout diff --git a/applications/FLASK/MPNN/README.md b/applications/FLASK/MPNN/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/applications/FLASK/MPNN/config.py b/applications/FLASK/MPNN/config.py new file mode 100644 index 00000000000..803fef6d7fa --- /dev/null +++ b/applications/FLASK/MPNN/config.py @@ -0,0 +1,19 @@ +# Dataset feature defeaults +# In general, don't change these unless using cusom data - S.Z. + +DATASET_CONFIG = { + "MAX_ATOMS": 100, # The number of maximum atoms in CSD dataset + "MAX_BONDS": 224, # The number of maximum bonds in CSD dataset + "ATOM_FEATURES": 133, + "BOND_FEATURES" : 147 +} + +# Hyperamaters used to set up trainer and MPN +# These can be changed freely +HYPERPARAMETERS_CONFIG = { + "HIDDEN_SIZE":300, + "LR": 0.001, + "BATCH_SIZE" : 128, + "EPOCH" : 50, + "MPN_DEPTH": 3 +} diff --git a/applications/FLASK/MPNN/dataset.py b/applications/FLASK/MPNN/dataset.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py new file mode 100644 index 00000000000..468736a0a9f --- /dev/null +++ b/applications/FLASK/MPNN/model.py @@ -0,0 +1,29 @@ +import lbann +from config import DATASET_CONFIG +from MPN import MPNEncoder + + +def graph_splitter(_input): + """ + """ + split_indices = [] + start_index = 0 + + max_atoms = DATASET_CONFIG['MAX_ATOMS'], + max_bonds = DATASET_CONFIG['MAX_BONDS'], + atom_features = DATASET_CONFIG['ATOM_FEATURES'] + bond_features = DATASET_CONFIG['BOND_FEATURES'] + + f_atom_size = max_atoms * atom_features + f_bond_size = max_bonds * bond_features + + + + + return f_atoms, f_bonds, atom2bond_mapping, bond2atom_mapping,\ + bond2bond_mapping, graph_mask, num_atoms + + +def make_model(): + + diff --git a/applications/FLASK/MPNN/train.py b/applications/FLASK/MPNN/train.py new file mode 100644 index 00000000000..e69de29bb2d From 6a9eb867276383788a5b571ab8d7d4afdc3dfc92 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Oct 2023 10:42:20 -0700 Subject: [PATCH 2/8] - Fix MPN model scatter gathers dimensions - Add graph data splitting --- applications/FLASK/MPNN/MPN.py | 60 ++++++++++++++++++++++++-------- applications/FLASK/MPNN/model.py | 43 +++++++++++++++++++---- 2 files changed, 81 insertions(+), 22 deletions(-) diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py index 61409ad884e..fc2c59dc15e 100644 --- a/applications/FLASK/MPNN/MPN.py +++ b/applications/FLASK/MPNN/MPN.py @@ -45,45 +45,72 @@ def __init__(self, atom_fdim, bond_fdim, hidden_size, activation_func, bias=Fals def message(self, bond_features, bond2atom_mapping, - atom2bond_mapping, + atom2bond_sources_mapping, + atom2bond_target_mapping, bond2revbond_mapping, MAX_ATOMS): """ """ messages = self.W_i(bond_features) for depth in range(self.depth - 1): - a_message = lbann.Scatter(messages, bond2atom_mapping) - bond_message = lbann.Gather(a_message, atom2bond_mapping) - rev_message = lbann.Gather(messages, bond2revbond_mapping) - messages = lbann.SubtractOperator(bond_message, rev_message) + nei_message = lbann.Gather(messages, + atom2bond_sources_mapping, + axis=0) + + a_message = lbann.Scatter(nei_message, + atom2bond_target_mapping, + dims=[MAX_ATOMS, self.hidden_size], + axis=0) + bond_message = lbann.Gather(a_message, + bond2atom_mapping) + rev_message = lbann.Gather(messages, + bond2revbond_mapping) + + messages = lbann.SubtractOperator(bond_message, rev_message) messages = self.W_h(messages) + return messages - def aggregate(self, atom_messages, bond_messages, bond2atom_mapping): + def aggregate(self, + atom_messages, + bond_messages, + bond2atom_mapping, + NUM_ATOMS): """ """ - a_messages = lbann.Scatter(bond_messages, bond2atom_mapping) - atoms_hidden = lbann.Concatentate([atom_messages, atom_messages], + a_messages = lbann.Scatter(bond_messages, + bond2atom_mapping, + axis=0, + dims=[NUM_ATOMS, self.hidden_size]) + + atoms_hidden = lbann.Concatentate([atom_messages, a_messages], dim=0) return self.W_o(atoms_hidden) - def readout(self, atom_encoded_features, graph_mask, num_atoms, max_atoms): + def readout(self, + atom_encoded_features, + graph_mask, + num_atoms, + max_atoms): """ """ mol_encoding = lbann.Scatter(atom_encoded_features, graph_mask, name=self.name + "graph_scatter") - mol_encoding = lbann.DivideOperator(mol_encoding, lbann.Tessallate(num_atoms, - dims=[max_atoms, 1])) + mol_encoding = lbann.DivideOperator(mol_encoding, + lbann.Tessallate(num_atoms, + dims=[max_atoms, 1])) return mol_encoding + def forward(self, atom_input_features, bond_input_features, - atom2bond_mapping, + atom2bond_sources_mapping, + atom2bond_target_mapping, bond2atom_mapping, bond2revbond_mapping, graph_mask, num_atoms, max_atoms): @@ -91,12 +118,15 @@ def forward(self, """ bond_messages = self.message(bond_input_features, bond2atom_mapping, - atom2bond_mapping, + atom2bond_sources_mapping, + atom2bond_target_mapping, bond2revbond_mapping, max_atoms) - atom_encoded_features = self.aggregate(atom_input_features, bond_messages, - bond2atom_mapping) + atom_encoded_features = self.aggregate(atom_input_features, + bond_messages, + bond2atom_mapping, + num_atoms) readout = self.readout(atom_encoded_features, graph_mask, num_atoms, max_atoms) return readout diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py index 468736a0a9f..8a1fab0b7d7 100644 --- a/applications/FLASK/MPNN/model.py +++ b/applications/FLASK/MPNN/model.py @@ -6,22 +6,51 @@ def graph_splitter(_input): """ """ - split_indices = [] - start_index = 0 + split_indices = [0] + max_atoms = DATASET_CONFIG['MAX_ATOMS'], max_bonds = DATASET_CONFIG['MAX_BONDS'], atom_features = DATASET_CONFIG['ATOM_FEATURES'] bond_features = DATASET_CONFIG['BOND_FEATURES'] + indices_length = max_bonds + f_atom_size = max_atoms * atom_features + split_indices.append(f_atom_size) + f_bond_size = max_bonds * bond_features - + split_indices.append(f_bond_size) + + split_indices.append(max_bonds) + split_indices.append(max_bonds) + split_indices.append(max_bonds) + split_indices.append(max_bonds) + + split_indices.append(max_atoms) + split_indices.append(1) + + for i in range(1, len(split_indices)): + split_indices[i] = split_indices[i] + split_indices[i - 1] - - - return f_atoms, f_bonds, atom2bond_mapping, bond2atom_mapping,\ - bond2bond_mapping, graph_mask, num_atoms + graph_input = lbann.Slice(_input, axis=0, slice_points=split_indices) + f_atoms = lbann.Reshape(lbann.Identity(graph_input), dims=[max_atoms, atom_features]) + f_bonds = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds, bond_features]) + atom2bond_source_mapping = lbann.Reshape(lbann.Identity(graph_input), + dims=[max_bonds]) + atom2bond_target_mapping = lbann.Reshape(lbann.Identity(graph_input), + dims=[max_bonds]) + bond2atom_mapping = lbann.Reshape(lbann.Identity(graph_input), + dims=[max_bonds]) + bond2bond_mapping = lbann.Reshape(lbann.Identity(graph_input), + dims=[max_bonds]) + graph_mask = lbann.Reshape(lbann.Identity(graph_input), + dims=[max_atoms]) + num_atoms = lbann.Reshape(lbann.Identity(graph_input), + dims=[1]) + + return f_atoms, f_bonds, atom2bond_source_mapping, atom2bond_target_mapping, \ + bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms def make_model(): From 3262bb802edcea487da891662fd6dc0f2eecc9c9 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Oct 2023 11:03:51 -0700 Subject: [PATCH 3/8] Add readout layers --- applications/FLASK/MPNN/MPN.py | 6 +++- applications/FLASK/MPNN/model.py | 58 ++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py index fc2c59dc15e..00e7d1c85f3 100644 --- a/applications/FLASK/MPNN/MPN.py +++ b/applications/FLASK/MPNN/MPN.py @@ -7,7 +7,11 @@ class MPNEncoder(Module): """ global_count = 0 - def __init__(self, atom_fdim, bond_fdim, hidden_size, activation_func, bias=False, depth=3, name=None): + def __init__(self, + atom_fdim, + bond_fdim, + hidden_size, + activation_func, bias=False, depth=3, name=None): MPNEncoder.global_count += 1 # For debugging diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py index 8a1fab0b7d7..14e31705bc7 100644 --- a/applications/FLASK/MPNN/model.py +++ b/applications/FLASK/MPNN/model.py @@ -1,5 +1,5 @@ import lbann -from config import DATASET_CONFIG +from config import DATASET_CONFIG, HYPERPARAMETERS_CONFIG from MPN import MPNEncoder @@ -29,6 +29,8 @@ def graph_splitter(_input): split_indices.append(max_atoms) split_indices.append(1) + split_indices.append(1) + for i in range(1, len(split_indices)): split_indices[i] = split_indices[i] + split_indices[i - 1] @@ -40,19 +42,55 @@ def graph_splitter(_input): dims=[max_bonds]) atom2bond_target_mapping = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds]) - bond2atom_mapping = lbann.Reshape(lbann.Identity(graph_input), - dims=[max_bonds]) - bond2bond_mapping = lbann.Reshape(lbann.Identity(graph_input), - dims=[max_bonds]) - graph_mask = lbann.Reshape(lbann.Identity(graph_input), - dims=[max_atoms]) - num_atoms = lbann.Reshape(lbann.Identity(graph_input), - dims=[1]) + bond2atom_mapping = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds]) + bond2bond_mapping = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds]) + graph_mask = lbann.Reshape(lbann.Identity(graph_input), dims=[max_atoms]) + num_atoms = lbann.Reshape(lbann.Identity(graph_input), dims=[1]) + target = lbann.Reshape(lbann.Identity(graph_input), dims=[1]) return f_atoms, f_bonds, atom2bond_source_mapping, atom2bond_target_mapping, \ - bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms + bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms, target def make_model(): + _input = lbann.Input(data_field='samples') + + f_atoms, f_bonds, atom2bond_source_mapping, atom2bond_target_mapping, \ + bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms, target = graph_splitter(_input) + + encoder = MPNEncoder(atom_fdim=DATASET_CONFIG['ATOM_FEATURES'], + bond_fdim=DATASET_CONFIG['BOND_FEATURES'], + hidden_size=HYPERPARAMETERS_CONFIG['HIDDEN_SIZE'], + activation_func=lbann.Relu) + + encoded_vec = encoder(f_atoms, + f_bonds, + atom2bond_source_mapping, + atom2bond_target_mapping, + bond2atom_mapping, + bond2bond_mapping, + graph_mask, + num_atoms) + + # Readout layers + x = lbann.FullyConnected(encoded_vec, num_neurons=HYPERPARAMETERS_CONFIG['HIDDEN_SIZE'], + name="READOUT_Linear_1") + x = lbann.Relu(x, name="READOUT_Activation_1") + + x = lbann.FullyConnected(x, num_neurons=1, + name="READOUT_output") + + loss = lbann.MeanSquaredError(x, target) + + layers = lbann.traverse_layer_graph(_input) + training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False) + gpu_usage = lbann.CallbackGPUMemoryUsage() + timer = lbann.CallbackTimer() + callbacks = [training_output, gpu_usage, timer] + model = lbann.Model(HYPERPARAMETERS_CONFIG['EPOCH'], + layers=layers, + objective_function=loss, + callbacks=callbacks) + return model From ccf12ccb2c42a1e22afc458bfda3ef44d58bb2ae Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Oct 2023 16:41:34 -0700 Subject: [PATCH 4/8] Added training and datareader funcs --- applications/FLASK/MPNN/config.py | 2 +- applications/FLASK/MPNN/model.py | 21 +++++++++++++++++++++ applications/FLASK/MPNN/train.py | 23 +++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/applications/FLASK/MPNN/config.py b/applications/FLASK/MPNN/config.py index 803fef6d7fa..45ff4db4c30 100644 --- a/applications/FLASK/MPNN/config.py +++ b/applications/FLASK/MPNN/config.py @@ -10,7 +10,7 @@ # Hyperamaters used to set up trainer and MPN # These can be changed freely -HYPERPARAMETERS_CONFIG = { +HYPERPARAMETERS_CONFIG: dict = { "HIDDEN_SIZE":300, "LR": 0.001, "BATCH_SIZE" : 128, diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py index 14e31705bc7..d3467137a27 100644 --- a/applications/FLASK/MPNN/model.py +++ b/applications/FLASK/MPNN/model.py @@ -1,6 +1,7 @@ import lbann from config import DATASET_CONFIG, HYPERPARAMETERS_CONFIG from MPN import MPNEncoder +import os.path as osp def graph_splitter(_input): @@ -94,3 +95,23 @@ def make_model(): callbacks=callbacks) return model +def make_data_reader(classname='dataset', + sample='get_sample_func', + num_samples='num_samples_func', + sample_dims='sample_dims_func'): + data_dir = osp.dirname(osp.realpath(__file__)) + reader = lbann.reader_pb2.DataReader() + + for role in ['train', 'validation', 'test']: + _reader = reader.reader.add() + _reader.name = 'python' + _reader.role = role + _reader.shuffle = True + _reader.fraction_of_data_to_use = 1.0 + _reader.python.module = classname + _reader.python.module_dir = data_dir + _reader.python.sample_function = f"{role}_{sample}" + _reader.python.num_samples_function = f"{role}_{num_samples}" + _reader.python.sample_dims_function = f"{role}_{sample_dims}" + + return reader \ No newline at end of file diff --git a/applications/FLASK/MPNN/train.py b/applications/FLASK/MPNN/train.py index e69de29bb2d..db9fd5a472e 100644 --- a/applications/FLASK/MPNN/train.py +++ b/applications/FLASK/MPNN/train.py @@ -0,0 +1,23 @@ +import lbann +import lbann.contrib.launcher +import lbann.contrib.args +from config import HYPERPARAMETERS_CONFIG +from model import make_model, make_data_reader +import argparse + + +desc = " Training a MPNN Model using LBANN" +parser = argparse.ArgumentParser(description=desc) + +args = parser.parse_args() +kwargs = lbann.contrib.args.get_scheduler_kwargs(args) +job_name = args.job_name + +model = make_model() +data_reader = make_data_reader() +optimizer = lbann.SGD(learn_rate=HYPERPARAMETERS_CONFIG["LR"]) +trainer = lbann.Trainer(mini_batch_size=HYPERPARAMETERS_CONFIG["BATCH_SIZE"]) + +lbann.contrib.launcher.run( + trainer, model, data_reader, optimizer, job_name=job_name, **kwargs +) From 5daf734c1d0e2b595fb9fdabd84560372b71ead6 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Oct 2023 16:54:38 -0700 Subject: [PATCH 5/8] Add test time output dump callback --- applications/FLASK/MPNN/MPN.py | 258 ++++++++++++++--------------- applications/FLASK/MPNN/config.py | 2 +- applications/FLASK/MPNN/dataset.py | 2 + applications/FLASK/MPNN/model.py | 12 +- 4 files changed, 139 insertions(+), 135 deletions(-) diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py index 00e7d1c85f3..9a52f0dc33d 100644 --- a/applications/FLASK/MPNN/MPN.py +++ b/applications/FLASK/MPNN/MPN.py @@ -3,134 +3,130 @@ class MPNEncoder(Module): - """ - """ - global_count = 0 - - def __init__(self, - atom_fdim, - bond_fdim, - hidden_size, - activation_func, bias=False, depth=3, name=None): - - MPNEncoder.global_count += 1 - # For debugging - self.name = (name - if name - else 'MPNEncoder_{}'.format(MPNEncoder.global_count)) - - self.atom_fdim = atom_fdim - self.bond_fdim = bond_fdim - self.hidden_size = hidden_size - self.bias = bias - self.depth = depth - self.activation_func = activation_func - - # Channelwise fully connected layer: (*, *, bond_fdim) -> (*, *, hidden_size) - self.W_i = \ - ChannelwiseFullyConnectedModule(self.hidden_size, - bias=self.bias, - activation=self.activation_func, - name=self.name + "W_i") - - # Channelwise fully connected layer (*, *, hidden_size) -> (*, *, hidden_size)) - self.W_h = \ - ChannelwiseFullyConnectedModule(self.hidden_size, - bias=self.bias, - activation=self.activation_func, - name=self.name + "W_h") - # Channelwise fully connected layer (*, *, atom_fdim + hidden_size) -> (*, *, hidden_size)) - self.W_o = \ - ChannelwiseFullyConnectedModule(self.hidden_size, - bias=True, - activation=self.activation_func, - name=self.name + "W_o") - - - def message(self, bond_features, - bond2atom_mapping, - atom2bond_sources_mapping, - atom2bond_target_mapping, - bond2revbond_mapping, - MAX_ATOMS): - """ - """ - messages = self.W_i(bond_features) - for depth in range(self.depth - 1): - nei_message = lbann.Gather(messages, - atom2bond_sources_mapping, - axis=0) - - a_message = lbann.Scatter(nei_message, - atom2bond_target_mapping, - dims=[MAX_ATOMS, self.hidden_size], - axis=0) - - bond_message = lbann.Gather(a_message, - bond2atom_mapping) - rev_message = lbann.Gather(messages, - bond2revbond_mapping) - - messages = lbann.SubtractOperator(bond_message, rev_message) - messages = self.W_h(messages) - - return messages - - - def aggregate(self, - atom_messages, - bond_messages, - bond2atom_mapping, - NUM_ATOMS): - """ - """ - a_messages = lbann.Scatter(bond_messages, - bond2atom_mapping, - axis=0, - dims=[NUM_ATOMS, self.hidden_size]) - - atoms_hidden = lbann.Concatentate([atom_messages, a_messages], - dim=0) - return self.W_o(atoms_hidden) - - - def readout(self, - atom_encoded_features, - graph_mask, - num_atoms, - max_atoms): - """ - """ - mol_encoding = lbann.Scatter(atom_encoded_features, - graph_mask, - name=self.name + "graph_scatter") - mol_encoding = lbann.DivideOperator(mol_encoding, - lbann.Tessallate(num_atoms, - dims=[max_atoms, 1])) - return mol_encoding - - - def forward(self, - atom_input_features, - bond_input_features, - atom2bond_sources_mapping, - atom2bond_target_mapping, - bond2atom_mapping, - bond2revbond_mapping, - graph_mask, num_atoms, max_atoms): - """ - """ - bond_messages = self.message(bond_input_features, - bond2atom_mapping, - atom2bond_sources_mapping, - atom2bond_target_mapping, - bond2revbond_mapping, - max_atoms) - - atom_encoded_features = self.aggregate(atom_input_features, - bond_messages, - bond2atom_mapping, - num_atoms) - - readout = self.readout(atom_encoded_features, graph_mask, num_atoms, max_atoms) - return readout + """ """ + + global_count = 0 + + def __init__( + self, + atom_fdim, + bond_fdim, + hidden_size, + activation_func, + max_atoms, + bias=False, + depth=3, + name=None, + ): + MPNEncoder.global_count += 1 + # For debugging + self.name = name if name else "MPNEncoder_{}".format(MPNEncoder.global_count) + + self.atom_fdim = atom_fdim + self.bond_fdim = bond_fdim + self.max_atoms = max_atoms + self.hidden_size = hidden_size + self.bias = bias + self.depth = depth + self.activation_func = activation_func + + # Channelwise fully connected layer: (*, *, bond_fdim) -> (*, *, hidden_size) + self.W_i = ChannelwiseFullyConnectedModule( + self.hidden_size, + bias=self.bias, + activation=self.activation_func, + name=self.name + "W_i", + ) + + # Channelwise fully connected layer (*, *, hidden_size) -> (*, *, hidden_size)) + self.W_h = ChannelwiseFullyConnectedModule( + self.hidden_size, + bias=self.bias, + activation=self.activation_func, + name=self.name + "W_h", + ) + # Channelwise fully connected layer (*, *, atom_fdim + hidden_size) -> (*, *, hidden_size)) + self.W_o = ChannelwiseFullyConnectedModule( + self.hidden_size, + bias=True, + activation=self.activation_func, + name=self.name + "W_o", + ) + + def message( + self, + bond_features, + bond2atom_mapping, + atom2bond_sources_mapping, + atom2bond_target_mapping, + bond2revbond_mapping, + ): + """ """ + messages = self.W_i(bond_features) + for depth in range(self.depth - 1): + nei_message = lbann.Gather(messages, atom2bond_sources_mapping, axis=0) + + a_message = lbann.Scatter( + nei_message, + atom2bond_target_mapping, + dims=[self.max_atoms, self.hidden_size], + axis=0, + ) + + bond_message = lbann.Gather(a_message, bond2atom_mapping) + rev_message = lbann.Gather(messages, bond2revbond_mapping) + + messages = lbann.SubtractOperator(bond_message, rev_message) + messages = self.W_h(messages) + + return messages + + def aggregate(self, atom_messages, bond_messages, bond2atom_mapping): + """ """ + a_messages = lbann.Scatter( + bond_messages, + bond2atom_mapping, + axis=0, + dims=[self.max_atoms, self.hidden_size], + ) + + atoms_hidden = lbann.Concatentate([atom_messages, a_messages], dim=0) + return self.W_o(atoms_hidden) + + def readout(self, atom_encoded_features, graph_mask, num_atoms): + """ """ + mol_encoding = lbann.Scatter( + atom_encoded_features, graph_mask, name=self.name + "graph_scatter" + ) + mol_encoding = lbann.DivideOperator( + mol_encoding, lbann.Tessallate(num_atoms, dims=[self.max_atoms, 1]) + ) + return mol_encoding + + def forward( + self, + atom_input_features, + bond_input_features, + atom2bond_sources_mapping, + atom2bond_target_mapping, + bond2atom_mapping, + bond2revbond_mapping, + graph_mask, + num_atoms, + max_atoms, + ): + """ """ + bond_messages = self.message( + bond_input_features, + bond2atom_mapping, + atom2bond_sources_mapping, + atom2bond_target_mapping, + bond2revbond_mapping, + ) + + atom_encoded_features = self.aggregate( + atom_input_features, bond_messages, bond2atom_mapping + ) + + readout = self.readout(atom_encoded_features, graph_mask, num_atoms) + return readout diff --git a/applications/FLASK/MPNN/config.py b/applications/FLASK/MPNN/config.py index 45ff4db4c30..301290d28fa 100644 --- a/applications/FLASK/MPNN/config.py +++ b/applications/FLASK/MPNN/config.py @@ -1,7 +1,7 @@ # Dataset feature defeaults # In general, don't change these unless using cusom data - S.Z. -DATASET_CONFIG = { +DATASET_CONFIG: dict = { "MAX_ATOMS": 100, # The number of maximum atoms in CSD dataset "MAX_BONDS": 224, # The number of maximum bonds in CSD dataset "ATOM_FEATURES": 133, diff --git a/applications/FLASK/MPNN/dataset.py b/applications/FLASK/MPNN/dataset.py index e69de29bb2d..e7d00d1b046 100644 --- a/applications/FLASK/MPNN/dataset.py +++ b/applications/FLASK/MPNN/dataset.py @@ -0,0 +1,2 @@ +import pickle + diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py index d3467137a27..a7325cfc533 100644 --- a/applications/FLASK/MPNN/model.py +++ b/applications/FLASK/MPNN/model.py @@ -47,7 +47,7 @@ def graph_splitter(_input): bond2bond_mapping = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds]) graph_mask = lbann.Reshape(lbann.Identity(graph_input), dims=[max_atoms]) num_atoms = lbann.Reshape(lbann.Identity(graph_input), dims=[1]) - target = lbann.Reshape(lbann.Identity(graph_input), dims=[1]) + target = lbann.Reshape(lbann.Identity(graph_input), dims=[1], name='TARGET') return f_atoms, f_bonds, atom2bond_source_mapping, atom2bond_target_mapping, \ bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms, target @@ -61,6 +61,7 @@ def make_model(): encoder = MPNEncoder(atom_fdim=DATASET_CONFIG['ATOM_FEATURES'], bond_fdim=DATASET_CONFIG['BOND_FEATURES'], + max_atoms=DATASET_CONFIG['MAX_ATOMS'], hidden_size=HYPERPARAMETERS_CONFIG['HIDDEN_SIZE'], activation_func=lbann.Relu) @@ -79,22 +80,27 @@ def make_model(): x = lbann.Relu(x, name="READOUT_Activation_1") x = lbann.FullyConnected(x, num_neurons=1, - name="READOUT_output") + name="PREDICTION") loss = lbann.MeanSquaredError(x, target) layers = lbann.traverse_layer_graph(_input) + + # Callbacks training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False) gpu_usage = lbann.CallbackGPUMemoryUsage() timer = lbann.CallbackTimer() + predictions = lbann.CallbackDumpOutputs(['TARGET', 'PREDICTION'], + role='test') - callbacks = [training_output, gpu_usage, timer] + callbacks = [training_output, gpu_usage, timer, predictions] model = lbann.Model(HYPERPARAMETERS_CONFIG['EPOCH'], layers=layers, objective_function=loss, callbacks=callbacks) return model + def make_data_reader(classname='dataset', sample='get_sample_func', num_samples='num_samples_func', From c72e5db958cd0cc9639ecca588f802293517ae3a Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 30 Oct 2023 01:02:59 -0700 Subject: [PATCH 6/8] Running MPN on Lassen --- applications/FLASK/MPNN/MPN.py | 38 +++++-- applications/FLASK/MPNN/config.py | 10 +- applications/FLASK/MPNN/dataset.py | 103 +++++++++++++++++++ applications/FLASK/MPNN/model.py | 160 ++++++++++++++++++----------- applications/FLASK/MPNN/train.py | 2 + 5 files changed, 237 insertions(+), 76 deletions(-) diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py index 9a52f0dc33d..7f1bad9c7ea 100644 --- a/applications/FLASK/MPNN/MPN.py +++ b/applications/FLASK/MPNN/MPN.py @@ -64,19 +64,29 @@ def message( """ """ messages = self.W_i(bond_features) for depth in range(self.depth - 1): - nei_message = lbann.Gather(messages, atom2bond_sources_mapping, axis=0) + nei_message = lbann.Gather(messages, atom2bond_target_mapping, axis=0) a_message = lbann.Scatter( nei_message, - atom2bond_target_mapping, + atom2bond_sources_mapping, dims=[self.max_atoms, self.hidden_size], axis=0, ) - bond_message = lbann.Gather(a_message, bond2atom_mapping) - rev_message = lbann.Gather(messages, bond2revbond_mapping) + bond_message = lbann.Gather( + a_message, + bond2atom_mapping, + axis=0, + name=self.name + f"_bond_messages_{depth}", + ) + rev_message = lbann.Gather( + messages, + bond2revbond_mapping, + axis=0, + name=self.name + f"_rev_bond_messages_{depth}", + ) - messages = lbann.SubtractOperator(bond_message, rev_message) + messages = lbann.Subtract(bond_message, rev_message) messages = self.W_h(messages) return messages @@ -90,16 +100,25 @@ def aggregate(self, atom_messages, bond_messages, bond2atom_mapping): dims=[self.max_atoms, self.hidden_size], ) - atoms_hidden = lbann.Concatentate([atom_messages, a_messages], dim=0) + atoms_hidden = lbann.Concatenation( + [atom_messages, a_messages], axis=1, name=self.name + "atom_hidden_concat" + ) return self.W_o(atoms_hidden) def readout(self, atom_encoded_features, graph_mask, num_atoms): """ """ mol_encoding = lbann.Scatter( - atom_encoded_features, graph_mask, name=self.name + "graph_scatter" + atom_encoded_features, graph_mask, name=self.name + "graph_scatter", axis=0, + dims=[1, self.hidden_size] ) - mol_encoding = lbann.DivideOperator( - mol_encoding, lbann.Tessallate(num_atoms, dims=[self.max_atoms, 1]) + num_atoms = lbann.Reshape(num_atoms, dims=[1, 1]) + + mol_encoding = lbann.Divide( + mol_encoding, + lbann.Tessellate( + num_atoms, dims=[1, self.hidden_size], name=self.name + "expand_num_nodes" + ), + name=self.name + "_reduce", ) return mol_encoding @@ -113,7 +132,6 @@ def forward( bond2revbond_mapping, graph_mask, num_atoms, - max_atoms, ): """ """ bond_messages = self.message( diff --git a/applications/FLASK/MPNN/config.py b/applications/FLASK/MPNN/config.py index 301290d28fa..96d62a1376e 100644 --- a/applications/FLASK/MPNN/config.py +++ b/applications/FLASK/MPNN/config.py @@ -5,15 +5,17 @@ "MAX_ATOMS": 100, # The number of maximum atoms in CSD dataset "MAX_BONDS": 224, # The number of maximum bonds in CSD dataset "ATOM_FEATURES": 133, - "BOND_FEATURES" : 147 + "BOND_FEATURES": 147, + "DATA_DIR": "/p/vast1/lbann/datasets/FLASK/CSD10K", + "TARGET_FILE": "10k_dft_density_data.csv" # Change to 10k_dft_hof_data.csv for heat of formation } # Hyperamaters used to set up trainer and MPN # These can be changed freely HYPERPARAMETERS_CONFIG: dict = { - "HIDDEN_SIZE":300, + "HIDDEN_SIZE": 300, "LR": 0.001, - "BATCH_SIZE" : 128, - "EPOCH" : 50, + "BATCH_SIZE": 128, + "EPOCH": 50, "MPN_DEPTH": 3 } diff --git a/applications/FLASK/MPNN/dataset.py b/applications/FLASK/MPNN/dataset.py index e7d00d1b046..f56bded0e63 100644 --- a/applications/FLASK/MPNN/dataset.py +++ b/applications/FLASK/MPNN/dataset.py @@ -1,2 +1,105 @@ import pickle +import numpy as np + +MAX_ATOMS = 100 # The number of maximum atoms in CSD dataset +MAX_BONDS = 224 # The number of maximum bonds in CSD dataset +ATOM_FEATURES = 133 +BOND_FEATURES = 147 + +SAMPLE_SIZE = ( + (MAX_ATOMS * ATOM_FEATURES) + + (MAX_BONDS * BOND_FEATURES) + + 4 * MAX_BONDS + + MAX_ATOMS + + 2 +) + +DATA_DIR = "/p/vast1/lbann/datasets/FLASK/CSD10K/" + +with open(DATA_DIR + "10k_density_lbann.bin", "rb") as f: + data = pickle.load(f) + +train_index = np.load(DATA_DIR + "train_sample_indices.npy") +valid_index = np.load(DATA_DIR + "valid_sample_indices.npy") +test_index = np.load(DATA_DIR + "test_sample_indices.npy") + + +def padded_index_array(size, special_ignore_index=-1): + padded_indices = np.zeros(size, dtype=np.float32) + special_ignore_index + return padded_indices + + +def pad_data_sample(data): + num_atoms = data["num_atoms"] + num_bonds = data["num_bonds"] + f_atoms = np.zeros((MAX_ATOMS, ATOM_FEATURES), dtype=np.float32) + f_atoms[:num_atoms, :] = data["atom_features"] + + f_bonds = np.zeros((MAX_BONDS, BOND_FEATURES), dtype=np.float32) + + f_bonds[:num_bonds, :] = data["bond_features"] + + atom2bond_source = padded_index_array(MAX_BONDS) + atom2bond_source[:num_bonds] = data["dual_graph_atom2bond_source"] + + atom2bond_target = padded_index_array(MAX_BONDS) + atom2bond_target[:num_bonds] = data["dual_graph_atom2bond_target"] + + bond2atom_source = padded_index_array(MAX_BONDS) + bond2atom_source[:num_bonds] = data["bond_graph_source"] + bond2bond_target = padded_index_array(MAX_BONDS) + bond2bond_target[:num_bonds] = data["bond_graph_target"] + + atom_mask = padded_index_array(MAX_ATOMS) + atom_mask[:num_atoms] = np.zeros(num_atoms) + + num_atoms = np.array([num_atoms]).astype(np.float32) + target = np.array([data["target"]]).astype(np.float32) + + _data_array = [ + f_atoms.flatten(), + f_bonds.flatten(), + atom2bond_source.flatten(), + atom2bond_target.flatten(), + bond2atom_source.flatten(), + bond2bond_target.flatten(), + atom_mask.flatten(), + num_atoms.flatten(), + target.flatten(), + ] + + flattened_data_array = np.concatenate(_data_array, axis=None) + return flattened_data_array + + +def train_sample(index): + return pad_data_sample(data[train_index[index]]) + + +def validation_sample(index): + return pad_data_sample(data[valid_index[index]]) + + +def test_sample(index): + return pad_data_sample(data[test_index[index]]) + + +def train_num_samples(): + return 8164 + + +def validation_num_samples(): + return 1020 + + +def test_num_samples(): + return 1022 + + +def sample_dims(): + return (SAMPLE_SIZE,) + + +if __name__ == "__main__": + print(train_sample(2).shape, sample_dims()) diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py index a7325cfc533..a783be14d3f 100644 --- a/applications/FLASK/MPNN/model.py +++ b/applications/FLASK/MPNN/model.py @@ -5,24 +5,22 @@ def graph_splitter(_input): - """ - """ + """ """ split_indices = [0] - - max_atoms = DATASET_CONFIG['MAX_ATOMS'], - max_bonds = DATASET_CONFIG['MAX_BONDS'], - atom_features = DATASET_CONFIG['ATOM_FEATURES'] - bond_features = DATASET_CONFIG['BOND_FEATURES'] + max_atoms = DATASET_CONFIG["MAX_ATOMS"] + max_bonds = DATASET_CONFIG["MAX_BONDS"] + atom_features = DATASET_CONFIG["ATOM_FEATURES"] + bond_features = DATASET_CONFIG["BOND_FEATURES"] indices_length = max_bonds f_atom_size = max_atoms * atom_features split_indices.append(f_atom_size) - + f_bond_size = max_bonds * bond_features split_indices.append(f_bond_size) - + split_indices.append(max_bonds) split_indices.append(max_bonds) split_indices.append(max_bonds) @@ -32,55 +30,88 @@ def graph_splitter(_input): split_indices.append(1) split_indices.append(1) - for i in range(1, len(split_indices)): - split_indices[i] = split_indices[i] + split_indices[i - 1] - + split_indices[i] = split_indices[i] + split_indices[i - 1] + graph_input = lbann.Slice(_input, axis=0, slice_points=split_indices) - f_atoms = lbann.Reshape(lbann.Identity(graph_input), dims=[max_atoms, atom_features]) - f_bonds = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds, bond_features]) - atom2bond_source_mapping = lbann.Reshape(lbann.Identity(graph_input), - dims=[max_bonds]) - atom2bond_target_mapping = lbann.Reshape(lbann.Identity(graph_input), - dims=[max_bonds]) - bond2atom_mapping = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds]) - bond2bond_mapping = lbann.Reshape(lbann.Identity(graph_input), dims=[max_bonds]) + f_atoms = lbann.Reshape( + lbann.Identity(graph_input), dims=[max_atoms, atom_features] + ) + f_bonds = lbann.Reshape( + lbann.Identity(graph_input), dims=[max_bonds, bond_features] + ) + atom2bond_source_mapping = lbann.Reshape( + lbann.Identity(graph_input), dims=[indices_length] + ) + atom2bond_target_mapping = lbann.Reshape( + lbann.Identity(graph_input), dims=[indices_length] + ) + bond2atom_mapping = lbann.Reshape( + lbann.Identity(graph_input), dims=[indices_length] + ) + bond2bond_mapping = lbann.Reshape( + lbann.Identity(graph_input), dims=[indices_length] + ) graph_mask = lbann.Reshape(lbann.Identity(graph_input), dims=[max_atoms]) num_atoms = lbann.Reshape(lbann.Identity(graph_input), dims=[1]) - target = lbann.Reshape(lbann.Identity(graph_input), dims=[1], name='TARGET') - - return f_atoms, f_bonds, atom2bond_source_mapping, atom2bond_target_mapping, \ - bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms, target + target = lbann.Reshape(lbann.Identity(graph_input), dims=[1], name="TARGET") + + return ( + f_atoms, + f_bonds, + atom2bond_source_mapping, + atom2bond_target_mapping, + bond2atom_mapping, + bond2bond_mapping, + graph_mask, + num_atoms, + target, + ) def make_model(): - _input = lbann.Input(data_field='samples') - - f_atoms, f_bonds, atom2bond_source_mapping, atom2bond_target_mapping, \ - bond2atom_mapping, bond2bond_mapping, graph_mask, num_atoms, target = graph_splitter(_input) - - encoder = MPNEncoder(atom_fdim=DATASET_CONFIG['ATOM_FEATURES'], - bond_fdim=DATASET_CONFIG['BOND_FEATURES'], - max_atoms=DATASET_CONFIG['MAX_ATOMS'], - hidden_size=HYPERPARAMETERS_CONFIG['HIDDEN_SIZE'], - activation_func=lbann.Relu) - - encoded_vec = encoder(f_atoms, - f_bonds, - atom2bond_source_mapping, - atom2bond_target_mapping, - bond2atom_mapping, - bond2bond_mapping, - graph_mask, - num_atoms) + _input = lbann.Input(data_field="samples") + + ( + f_atoms, + f_bonds, + atom2bond_source_mapping, + atom2bond_target_mapping, + bond2atom_mapping, + bond2bond_mapping, + graph_mask, + num_atoms, + target, + ) = graph_splitter(_input) + + encoder = MPNEncoder( + atom_fdim=DATASET_CONFIG["ATOM_FEATURES"], + bond_fdim=DATASET_CONFIG["BOND_FEATURES"], + max_atoms=DATASET_CONFIG["MAX_ATOMS"], + hidden_size=HYPERPARAMETERS_CONFIG["HIDDEN_SIZE"], + activation_func=lbann.Relu, + ) + + encoded_vec = encoder( + f_atoms, + f_bonds, + atom2bond_source_mapping, + atom2bond_target_mapping, + bond2atom_mapping, + bond2bond_mapping, + graph_mask, + num_atoms, + ) # Readout layers - x = lbann.FullyConnected(encoded_vec, num_neurons=HYPERPARAMETERS_CONFIG['HIDDEN_SIZE'], - name="READOUT_Linear_1") + x = lbann.FullyConnected( + encoded_vec, + num_neurons=HYPERPARAMETERS_CONFIG["HIDDEN_SIZE"], + name="READOUT_Linear_1", + ) x = lbann.Relu(x, name="READOUT_Activation_1") - x = lbann.FullyConnected(x, num_neurons=1, - name="PREDICTION") + x = lbann.FullyConnected(x, num_neurons=1, name="PREDICTION") loss = lbann.MeanSquaredError(x, target) @@ -90,27 +121,32 @@ def make_model(): training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False) gpu_usage = lbann.CallbackGPUMemoryUsage() timer = lbann.CallbackTimer() - predictions = lbann.CallbackDumpOutputs(['TARGET', 'PREDICTION'], - role='test') + predictions = lbann.CallbackDumpOutputs( + layers="PREDICTION", execution_modes="test" + ) callbacks = [training_output, gpu_usage, timer, predictions] - model = lbann.Model(HYPERPARAMETERS_CONFIG['EPOCH'], - layers=layers, - objective_function=loss, - callbacks=callbacks) + model = lbann.Model( + HYPERPARAMETERS_CONFIG["EPOCH"], + layers=layers, + objective_function=loss, + callbacks=callbacks, + ) return model -def make_data_reader(classname='dataset', - sample='get_sample_func', - num_samples='num_samples_func', - sample_dims='sample_dims_func'): +def make_data_reader( + classname="dataset", + sample="sample", + num_samples="num_samples", + sample_dims="sample_dims", +): data_dir = osp.dirname(osp.realpath(__file__)) reader = lbann.reader_pb2.DataReader() - for role in ['train', 'validation', 'test']: + for role in ["train", "validation", "test"]: _reader = reader.reader.add() - _reader.name = 'python' + _reader.name = "python" _reader.role = role _reader.shuffle = True _reader.fraction_of_data_to_use = 1.0 @@ -118,6 +154,6 @@ def make_data_reader(classname='dataset', _reader.python.module_dir = data_dir _reader.python.sample_function = f"{role}_{sample}" _reader.python.num_samples_function = f"{role}_{num_samples}" - _reader.python.sample_dims_function = f"{role}_{sample_dims}" - - return reader \ No newline at end of file + _reader.python.sample_dims_function = "sample_dims" + + return reader diff --git a/applications/FLASK/MPNN/train.py b/applications/FLASK/MPNN/train.py index db9fd5a472e..be3ef9fc255 100644 --- a/applications/FLASK/MPNN/train.py +++ b/applications/FLASK/MPNN/train.py @@ -8,6 +8,8 @@ desc = " Training a MPNN Model using LBANN" parser = argparse.ArgumentParser(description=desc) +lbann.contrib.args.add_scheduler_arguments(parser, 'ChemProp') +lbann.contrib.args.add_optimizer_arguments(parser) args = parser.parse_args() kwargs = lbann.contrib.args.get_scheduler_kwargs(args) From 64cc7c1c104206743d7c3f3beb8d4c15f972b716 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 30 Oct 2023 01:03:30 -0700 Subject: [PATCH 7/8] Add data helper file that converts chemprop data to lbann data --- applications/FLASK/MPNN/PrepareDataset.py | 79 +++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 applications/FLASK/MPNN/PrepareDataset.py diff --git a/applications/FLASK/MPNN/PrepareDataset.py b/applications/FLASK/MPNN/PrepareDataset.py new file mode 100644 index 00000000000..538af6cff87 --- /dev/null +++ b/applications/FLASK/MPNN/PrepareDataset.py @@ -0,0 +1,79 @@ +from config import DATASET_CONFIG +from tqdm import tqdm +import numpy as np +from chemprop.args import TrainArgs +from chemprop.features import reset_featurization_parameters +from chemprop.data import MoleculeDataLoader, utils +import os.path as osp +import pickle + + +def retrieve_dual_mapping(atom2bond, ascope): + atom_bond_mapping = [] + for a_start, a_size in enumerate: + _a2b = atom2bond.narrow(0, a_start, a_size) + for row, possible_bonds in enumerate(_a2b): + for bond in possible_bonds: + ind = bond.item() - 1 # Shift by 1 to account for null nodes + if ind >= 0: + atom_bond_mapping.append([row, ind]) + return np.array(atom_bond_mapping) + + +def PrepareDataset(save_file_name, target_file): + data_file = osp.join(DATASET_CONFIG["DATA_DIR"], target_file) + + arguments = [ + "--data_path", + data_file, + "--dataset_type", + "regression", + "--save_dir", + "./data/10k_dft_density", + ] + args = TrainArgs().parse_args(arguments) + reset_featurization_parameters() + data = utils.get_data(data_file, args=args) + # Need to use the data loader as the featurization happens in the dataloader + # Only use 1 mol as in LBANN we do not do coalesced batching (yet) + dataloader = MoleculeDataLoader(data, batch_size=1) + lbann_data = [] + for mol in tqdm(dataloader): + mol_data = {} + + mol_data["target"] = mol.targets()[0][0] + mol_data["num_atoms"] = mol.number_of_atoms[0][0] + # Multiply by 2 for directional bonds + mol_data["num_bonds"] = mol.number_of_bonds[0][0] * 2 + + mol_batch = mol.batch_graph()[0] + f_atoms, f_bonds, a2b, b2a, b2revb, ascope, bscope = mol_batch.get_components( + False + ) + + # shift by 1 as we don't use null nodes as in the ChemProp implementation + mol_data["atom_features"] = f_atoms[1:].numpy() + mol_data["bond_features"] = f_bonds[1:].numpy() + dual_graph_mapping = retrieve_dual_mapping(a2b, ascope) + + mol_data['dual_graph_atom2bond_source'] = dual_graph_mapping[:, 0] + mol_data['dual_graph_atom2bond_target'] = dual_graph_mapping[:, 1] + + # subtract 1 to shift the indices + mol_data['bond_graph_source'] = b2a[1:].numpy() - 1 + mol_data['bond_graph_target'] = b2revb[1:].numpy() - 1 + + lbann_data.append(mol_data) + + save_file = osp.join(DATASET_CONFIG["DATA_DIR"], save_file_name) + with open(save_file, 'wb') as f: + pickle.dump(lbann_data, f) + + +def main(): + PrepareDataset("10k_density_lbann.bin", "10k_dft_density_data.csv") + PrepareDataset("10k_hof_lbann.bin", "10k_dft_hof_data.csv") + + +if __name__ == "__main__": + main() From e5080722af6538478f870568fef720ce361c7633 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Sun, 3 Dec 2023 22:09:49 -0800 Subject: [PATCH 8/8] Added some documentation for model and training files --- applications/FLASK/MPNN/MPN.py | 11 +++++--- applications/FLASK/MPNN/README.md | 38 +++++++++++++++++++++++++++ applications/FLASK/MPNN/config.py | 4 +-- applications/FLASK/MPNN/dataset.py | 13 +++++++++- applications/FLASK/MPNN/model.py | 41 +++++++++++++++++++++--------- applications/FLASK/MPNN/train.py | 2 +- 6 files changed, 90 insertions(+), 19 deletions(-) diff --git a/applications/FLASK/MPNN/MPN.py b/applications/FLASK/MPNN/MPN.py index 7f1bad9c7ea..a71e39a1aff 100644 --- a/applications/FLASK/MPNN/MPN.py +++ b/applications/FLASK/MPNN/MPN.py @@ -108,15 +108,20 @@ def aggregate(self, atom_messages, bond_messages, bond2atom_mapping): def readout(self, atom_encoded_features, graph_mask, num_atoms): """ """ mol_encoding = lbann.Scatter( - atom_encoded_features, graph_mask, name=self.name + "graph_scatter", axis=0, - dims=[1, self.hidden_size] + atom_encoded_features, + graph_mask, + name=self.name + "graph_scatter", + axis=0, + dims=[1, self.hidden_size], ) num_atoms = lbann.Reshape(num_atoms, dims=[1, 1]) mol_encoding = lbann.Divide( mol_encoding, lbann.Tessellate( - num_atoms, dims=[1, self.hidden_size], name=self.name + "expand_num_nodes" + num_atoms, + dims=[1, self.hidden_size], + name=self.name + "expand_num_nodes", ), name=self.name + "_reduce", ) diff --git a/applications/FLASK/MPNN/README.md b/applications/FLASK/MPNN/README.md index e69de29bb2d..0804ef4bddd 100644 --- a/applications/FLASK/MPNN/README.md +++ b/applications/FLASK/MPNN/README.md @@ -0,0 +1,38 @@ +# ChemProp on LBANN + +## Prepere Dataset (optional) + +If not on lbann system or required to regenerate the data file so it is ingestible on LBANN. + +### Requirements + +``` +chemprop +numpy +torch +``` + +### Generate Data + +The data generator is set to read from and write data to the `DATA_DIR` directory in `config.py`. Update that line to read and store +from a custom directory. + + +Generate the data by calling: + + +`python PrepareDataset.py +` + +## Run the Trainer + +### Hyperparameters + +The hyperparameters for the model and training algorihms can be set in `config.py`. + + +### Run the trainer + + +### Results + diff --git a/applications/FLASK/MPNN/config.py b/applications/FLASK/MPNN/config.py index 96d62a1376e..e6b5cd18fef 100644 --- a/applications/FLASK/MPNN/config.py +++ b/applications/FLASK/MPNN/config.py @@ -15,7 +15,7 @@ HYPERPARAMETERS_CONFIG: dict = { "HIDDEN_SIZE": 300, "LR": 0.001, - "BATCH_SIZE": 128, - "EPOCH": 50, + "BATCH_SIZE": 64, + "EPOCH": 100, "MPN_DEPTH": 3 } diff --git a/applications/FLASK/MPNN/dataset.py b/applications/FLASK/MPNN/dataset.py index f56bded0e63..fd55e050195 100644 --- a/applications/FLASK/MPNN/dataset.py +++ b/applications/FLASK/MPNN/dataset.py @@ -31,6 +31,15 @@ def padded_index_array(size, special_ignore_index=-1): def pad_data_sample(data): + """ + Args: + data(dict): Dictionary of data samples with fields 'num_atoms', 'num_bonds', + 'dual_graph_atom2bond_source', 'dual_graph_atom2bond_target', + 'bond_graph_source', 'bond_grap_target', and 'target' + + Returns: + (np.array) + """ num_atoms = data["num_atoms"] num_bonds = data["num_bonds"] f_atoms = np.zeros((MAX_ATOMS, ATOM_FEATURES), dtype=np.float32) @@ -55,7 +64,9 @@ def pad_data_sample(data): atom_mask[:num_atoms] = np.zeros(num_atoms) num_atoms = np.array([num_atoms]).astype(np.float32) - target = np.array([data["target"]]).astype(np.float32) + target = (np.array([data["target"]]).astype(np.float32) + 67.14776709141553) / ( + 108.13423283538837 + ) _data_array = [ f_atoms.flatten(), diff --git a/applications/FLASK/MPNN/model.py b/applications/FLASK/MPNN/model.py index a783be14d3f..5931497e37f 100644 --- a/applications/FLASK/MPNN/model.py +++ b/applications/FLASK/MPNN/model.py @@ -5,7 +5,18 @@ def graph_splitter(_input): - """ """ + """ + Args: + _input: (lbann.InputLayer) The padded, flattened graph data + return: + (lbann.Layer, lbann.Layer, lbann.Layer, lbann.Layer, lbann.Layer, + lbann.Layer, lbann.Layer, lbann.Layer, lbann.Layer) + + A 9-tuple with the input features, bond features, source atom to bond + graph mapping, target atom to bond graph mapping, bong to atom mapping, + bond to bond mapping, graph mask, number of atoms in the molecule, and + the target + """ split_indices = [0] max_atoms = DATASET_CONFIG["MAX_ATOMS"] @@ -70,6 +81,10 @@ def graph_splitter(_input): def make_model(): + """ + Returns: + (lbann.Model) LBANN model for a regression target on the CSD10K dataset + """ _input = lbann.Input(data_field="samples") ( @@ -121,11 +136,18 @@ def make_model(): training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False) gpu_usage = lbann.CallbackGPUMemoryUsage() timer = lbann.CallbackTimer() - predictions = lbann.CallbackDumpOutputs( - layers="PREDICTION", execution_modes="test" - ) - - callbacks = [training_output, gpu_usage, timer, predictions] + predictions = lbann.CallbackDumpOutputs(layers="PREDICTION", execution_modes="test") + + targets = lbann.CallbackDumpOutputs(layers="TARGET", execution_modes="test") + step_learning_rate = lbann.CallbackStepLearningRate(step=10, amt=0.9) + callbacks = [ + training_output, + gpu_usage, + timer, + predictions, + targets, + step_learning_rate, + ] model = lbann.Model( HYPERPARAMETERS_CONFIG["EPOCH"], layers=layers, @@ -135,12 +157,7 @@ def make_model(): return model -def make_data_reader( - classname="dataset", - sample="sample", - num_samples="num_samples", - sample_dims="sample_dims", -): +def make_data_reader(classname="dataset", sample="sample", num_samples="num_samples"): data_dir = osp.dirname(osp.realpath(__file__)) reader = lbann.reader_pb2.DataReader() diff --git a/applications/FLASK/MPNN/train.py b/applications/FLASK/MPNN/train.py index be3ef9fc255..21b5ab7b708 100644 --- a/applications/FLASK/MPNN/train.py +++ b/applications/FLASK/MPNN/train.py @@ -17,7 +17,7 @@ model = make_model() data_reader = make_data_reader() -optimizer = lbann.SGD(learn_rate=HYPERPARAMETERS_CONFIG["LR"]) +optimizer = lbann.Adam(learn_rate=HYPERPARAMETERS_CONFIG["LR"], beta1=0.9, beta2=0.99, eps=1e-8, adamw_weight_decay=0) trainer = lbann.Trainer(mini_batch_size=HYPERPARAMETERS_CONFIG["BATCH_SIZE"]) lbann.contrib.launcher.run(