diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..53d06d3 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,36 @@ +name: Tests + +on: + push: + branches: + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Run Ruff + run: ruff check --output-format=github . + + # - name: Install package + # run: pip install . + + # - name: Test with pytest + # run: | + # pytest diff --git a/.gitignore b/.gitignore index a61da04..eb1ab5d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ notebooks/_*.ipynb +# vscode +.vscode + # jupyter MANIFEST build diff --git a/molexpress/__init__.py b/molexpress/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/molexpress/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/molexpress/_version.py b/molexpress/_version.py deleted file mode 100644 index 6853c36..0000000 --- a/molexpress/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '0.0.0' \ No newline at end of file diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index cb49071..7743c52 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -1,78 +1,65 @@ +from __future__ import annotations + import numpy as np -from rdkit import Chem +from molexpress import types from molexpress.datasets import featurizers from molexpress.ops import chem_ops -from molexpress import types - -class MolecularGraphEncoder: +class MolecularGraphEncoder: def __init__( self, atom_featurizers: list[featurizers.Featurizer], bond_featurizers: list[featurizers.Featurizer] = None, - self_loops: bool = False, + self_loops: bool = False, ) -> None: self.node_encoder = MolecularNodeEncoder(atom_featurizers) - self.edge_encoder = MolecularEdgeEncoder( - bond_featurizers, self_loops=self_loops - ) - - def __call__( - self, - molecule: types.Molecule | types.SMILES | types.InChI - ) -> np.ndarray: + self.edge_encoder = MolecularEdgeEncoder(bond_featurizers, self_loops=self_loops) + + def __call__(self, molecule: types.Molecule | types.SMILES | types.InChI) -> np.ndarray: molecule = chem_ops.get_molecule(molecule) return {**self.node_encoder(molecule), **self.edge_encoder(molecule)} @staticmethod def _collate_fn( - data: list[tuple[types.MolecularGraph, np.ndarray]] + data: list[tuple[types.MolecularGraph, np.ndarray]], ) -> tuple[types.MolecularGraph, np.ndarray]: - - """TODO: Not sure where to implement this collate function. + """TODO: Not sure where to implement this collate function. Temporarily putting it here. Procedure: Merges list of graphs into a single disjoint graph. """ - x, y = list(zip(*data)) - - num_nodes = np.array([ - graph['node_state'].shape[0] for graph in x - ]) - + x, y = list(zip(*data)) + + num_nodes = np.array([graph["node_state"].shape[0] for graph in x]) + disjoint_graph = {} - disjoint_graph['node_state'] = np.concatenate([ - graph['node_state'] for graph in x - ]) + disjoint_graph["node_state"] = np.concatenate([graph["node_state"] for graph in x]) - if 'edge_state' in x[0]: - disjoint_graph['edge_state'] = np.concatenate([ - graph['edge_state'] for graph in x - ]) + if "edge_state" in x[0]: + disjoint_graph["edge_state"] = np.concatenate([graph["edge_state"] for graph in x]) - edge_src = np.concatenate([graph['edge_src'] for graph in x]) - edge_dst = np.concatenate([graph['edge_dst'] for graph in x]) - num_edges = np.array([graph['edge_src'].shape[0] for graph in x]) - indices = np.repeat(range(len(x)), num_edges) + edge_src = np.concatenate([graph["edge_src"] for graph in x]) + edge_dst = np.concatenate([graph["edge_dst"] for graph in x]) + num_edges = np.array([graph["edge_src"].shape[0] for graph in x]) + indices = np.repeat(range(len(x)), num_edges) edge_incr = np.concatenate([[0], num_nodes[:-1]]) edge_incr = np.take_along_axis(edge_incr, indices, axis=0) - disjoint_graph['edge_src'] = edge_src + edge_incr - disjoint_graph['edge_dst'] = edge_dst + edge_incr - disjoint_graph['graph_indicator'] = np.repeat(range(len(x)), num_nodes) + disjoint_graph["edge_src"] = edge_src + edge_incr + disjoint_graph["edge_dst"] = edge_dst + edge_incr + disjoint_graph["graph_indicator"] = np.repeat(range(len(x)), num_nodes) return disjoint_graph, np.stack(y) class Composer: - """Wraps a list of featurizers. - + While a Featurizer encodes an atom or bond based on a single property, the Composer encodes an atom or bond based on multiple properties. @@ -84,13 +71,12 @@ class Composer: def __init__(self, featurizers: list[featurizers.Featurizer]) -> None: self.featurizers = featurizers assert all( - self.featurizers[0].output_dtype == f.output_dtype - for f in self.featurizers + self.featurizers[0].output_dtype == f.output_dtype for f in self.featurizers ), "'dtype' of features need to be consistent." def __call__(self, inputs: types.Atom | types.Bond) -> np.ndarray: return np.concatenate([f(inputs) for f in self.featurizers]) - + @property def output_dim(self): return sum(f.output_dim for f in self.featurizers) @@ -98,50 +84,42 @@ def output_dim(self): @property def output_dtype(self): return self.featurizers[0].output_dtype - -class MolecularEdgeEncoder: +class MolecularEdgeEncoder: def __init__( - self, - featurizers: list[featurizers.Featurizer], - self_loops: bool = False + self, featurizers: list[featurizers.Featurizer], self_loops: bool = False ) -> None: - self.featurizer = Composer(featurizers) + self.featurizer = Composer(featurizers) self.self_loops = self_loops self.output_dim = self.featurizer.output_dim self.output_dtype = self.featurizer.output_dtype def __call__(self, molecule: types.Molecule) -> np.ndarray: - - edge_src, edge_dst = chem_ops.get_adjacency( - molecule, self_loops=self.self_loops) + edge_src, edge_dst = chem_ops.get_adjacency(molecule, self_loops=self.self_loops) if self.featurizer is None: - return {'edge_src': edge_src, 'edge_dst': edge_dst} + return {"edge_src": edge_src, "edge_dst": edge_dst} if molecule.GetNumBonds() == 0: edge_state = np.zeros( - shape=(0, self.output_dim + int(self.self_loops)), + shape=(0, self.output_dim + int(self.self_loops)), dtype=self.output_dtype ) return { - 'edge_src': edge_src, - 'edge_dst': edge_dst, - 'edge_state': edge_state + "edge_src": edge_src, + "edge_dst": edge_dst, + "edge_state": edge_state, } - + bond_encodings = [] for i, j in zip(edge_src, edge_dst): - bond = molecule.GetBondBetweenAtoms(int(i), int(j)) if bond is None: assert self.self_loops, "Found a bond to be None." - bond_encoding = np.zeros( - self.output_dim + 1, dtype=self.output_dtype - ) + bond_encoding = np.zeros(self.output_dim + 1, dtype=self.output_dtype) bond_encoding[-1] = 1 else: bond_encoding = self.featurizer(bond) @@ -151,23 +129,19 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray: bond_encodings.append(bond_encoding) return { - 'edge_src': edge_src, - 'edge_dst': edge_dst, - 'edge_state': np.stack(bond_encodings) + "edge_src": edge_src, + "edge_dst": edge_dst, + "edge_state": np.stack(bond_encodings), } - -class MolecularNodeEncoder: +class MolecularNodeEncoder: def __init__( - self, - featurizers: list[featurizers.Featurizer], + self, + featurizers: list[featurizers.Featurizer], ) -> None: - self.featurizer = Composer(featurizers) + self.featurizer = Composer(featurizers) def __call__(self, molecule: types.Molecule) -> np.ndarray: - node_encodings = np.stack([ - self.featurizer(atom) for atom in molecule.GetAtoms() - ], axis=0) - return {'node_state': np.stack(node_encodings)} - \ No newline at end of file + node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0) + return {"node_state": np.stack(node_encodings)} diff --git a/molexpress/datasets/featurizers.py b/molexpress/datasets/featurizers.py index e9e22bb..17a5a17 100644 --- a/molexpress/datasets/featurizers.py +++ b/molexpress/datasets/featurizers.py @@ -1,80 +1,151 @@ -from abc import ABC -from abc import abstractmethod +from __future__ import annotations -from rdkit.Chem import Lipinski -from rdkit.Chem import Crippen -from rdkit.Chem import rdMolDescriptors -from rdkit.Chem import rdPartialCharges +import math +from abc import ABC, abstractmethod import numpy as np -import math +from rdkit.Chem import Crippen, Lipinski, rdMolDescriptors, rdPartialCharges from molexpress import types - DEFAULT_VOCABULARY = { - 'AtomType': { - 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', - 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', - 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', - 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', - 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', - 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', - 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', - 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', - 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', - 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', - 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', - 'Rg', 'Cn' - }, - 'Hybridization': { - 's', 'sp', 'sp2', 'sp3', 'sp3d', 'sp3d2', 'unspecified' - }, - 'CIPCode': { - 'R', 'S', 'None' - }, - 'FormalCharge': { - -3, -2, -1, 0, 1, 2, 3, 4 - }, - 'TotalNumHs': { - 0, 1, 2, 3, 4 - }, - 'TotalValence': { - 0, 1, 2, 3, 4, 5, 6, 7, 8 - }, - 'NumRadicalElectrons': { - 0, 1, 2, 3 - }, - 'Degree': { - 0, 1, 2, 3, 4, 5, 6, 7, 8 - }, - 'RingSize': { - 0, 3, 4, 5, 6, 7, 8 - }, - 'BondType': { - 'single', 'double', 'triple', 'aromatic' - }, - 'Stereo': { - 'stereoe', 'stereoz', 'stereoany', 'stereonone' + "AtomType": { + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", }, + "Hybridization": {"s", "sp", "sp2", "sp3", "sp3d", "sp3d2", "unspecified"}, + "CIPCode": {"R", "S", "None"}, + "FormalCharge": {-3, -2, -1, 0, 1, 2, 3, 4}, + "TotalNumHs": {0, 1, 2, 3, 4}, + "TotalValence": {0, 1, 2, 3, 4, 5, 6, 7, 8}, + "NumRadicalElectrons": {0, 1, 2, 3}, + "Degree": {0, 1, 2, 3, 4, 5, 6, 7, 8}, + "RingSize": {0, 3, 4, 5, 6, 7, 8}, + "BondType": {"single", "double", "triple", "aromatic"}, + "Stereo": {"stereoe", "stereoz", "stereoany", "stereonone"}, } class Featurizer(ABC): - """Abstract featurizer. - + Featurizes a single atom or bond based on a single property. """ - def __init__( - self, - output_dim: int = None, - output_dtype: str = 'float32' - ) -> None: + def __init__(self, output_dim: int = None, output_dtype: str = "float32") -> None: self._output_dim = int(output_dim) if output_dim is not None else 1 self._output_dtype = output_dtype - + @abstractmethod def call(self, x: types.Atom | types.Bond) -> types.Scalar: pass @@ -89,47 +160,40 @@ def output_dtype(self) -> str: class OneHotFeaturizer(Featurizer): - """Abstract one-hot featurizer.""" def __init__( self, - vocab: list[str] | list[int] = None, + vocab: list[str] | list[int] = None, oov: bool = False, - output_dtype: str = 'float32', + output_dtype: str = "float32", ): if not vocab: vocab = DEFAULT_VOCABULARY.get(self.__class__.__name__) if vocab is None: raise ValueError("Need to supply a 'vocab'.") - - self.vocab = list(vocab) + + self.vocab = list(vocab) self.vocab.sort(key=lambda x: x if x is not None else "") self.oov = oov - super().__init__( - output_dim=len(self.vocab) + int(self.oov), - output_dtype=output_dtype - ) + super().__init__(output_dim=len(self.vocab) + int(self.oov), output_dtype=output_dtype) if self.oov: - self.vocab += [''] - + self.vocab += [""] + encodings = np.eye(self.output_dim, dtype=self.output_dtype) self.mapping = dict(zip(self.vocab, encodings)) - + def __call__(self, x: types.Atom | types.Bond) -> np.ndarray: feature = self.call(x) - encoding = self.mapping.get( - feature, None if not self.oov else self.mapping[''] - ) + encoding = self.mapping.get(feature, None if not self.oov else self.mapping[""]) if encoding is not None: return encoding return np.zeros([self.output_dim], dtype=self.output_dtype) - -class FloatFeaturizer(Featurizer): +class FloatFeaturizer(Featurizer): """Abstract scalar floating point featurizer.""" def __call__(self, x: types.Atom | types.Bond) -> np.ndarray: @@ -138,19 +202,19 @@ def __call__(self, x: types.Atom | types.Bond) -> np.ndarray: class AtomType(OneHotFeaturizer): def call(self, inputs: types.Atom) -> str: - return inputs.GetSymbol() + return inputs.GetSymbol() class Hybridization(OneHotFeaturizer): def call(self, inputs: types.Atom) -> str: return inputs.GetHybridization().name.lower() - + class CIPCode(OneHotFeaturizer): def call(self, atom: types.Atom) -> str | None: if atom.HasProp("_CIPCode"): return atom.GetProp("_CIPCode") - return 'None' + return "None" class ChiralCenter(FloatFeaturizer): @@ -260,7 +324,7 @@ class GasteigerCharge(FloatFeaturizer): def call(self, atom: types.Atom) -> float: mol = atom.GetOwningMol() rdPartialCharges.ComputeGasteigerCharges(mol) - val = atom.GetDoubleProp('_GasteigerCharge') + val = atom.GetDoubleProp("_GasteigerCharge") if val is not None and math.isfinite(val): return val return 0.0 @@ -274,7 +338,7 @@ def call(self, bond: types.Bond) -> str: class Stereo(OneHotFeaturizer): def call(self, bond: types.Bond) -> str: return bond.GetStereo().name.lower() - + class Conjugated(FloatFeaturizer): def call(self, bond: types.Bond) -> bool: @@ -284,6 +348,5 @@ def call(self, bond: types.Bond) -> bool: class Rotatable(FloatFeaturizer): def call(self, bond: types.Bond) -> bool: mol = bond.GetOwningMol() - atom_indices = tuple( - sorted([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])) + atom_indices = tuple(sorted([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])) return atom_indices in Lipinski._RotatableBonds(mol) diff --git a/molexpress/layers/__init__.py b/molexpress/layers/__init__.py index 0d79c61..d48b332 100644 --- a/molexpress/layers/__init__.py +++ b/molexpress/layers/__init__.py @@ -1,4 +1,4 @@ -from molexpress.layers.base_layer import BaseLayer -from molexpress.layers.gcn_conv import GCNConv -from molexpress.layers.gin_conv import GINConv -from molexpress.layers.readout import Readout \ No newline at end of file +from molexpress.layers.base_layer import BaseLayer as BaseLayer +from molexpress.layers.gcn_conv import GCNConv as GCNConv +from molexpress.layers.gin_conv import GINConv as GINConv +from molexpress.layers.readout import Readout as Readout diff --git a/molexpress/layers/base_layer.py b/molexpress/layers/base_layer.py index 87c7715..711cfa5 100644 --- a/molexpress/layers/base_layer.py +++ b/molexpress/layers/base_layer.py @@ -1,29 +1,28 @@ +from __future__ import annotations + import keras -from molexpress import types +from molexpress import types class BaseLayer(keras.layers.Layer): - """Base layer.""" def __init__( - self, + self, units: int, activation: keras.layers.Activation = None, use_bias: bool = True, - kernel_initializer: keras.initializers.Initializer = 'glorot_uniform', - bias_initializer: keras.initializers.Initializer = 'zeros', + kernel_initializer: keras.initializers.Initializer = "glorot_uniform", + bias_initializer: keras.initializers.Initializer = "zeros", kernel_regularizer: keras.regularizers.Regularizer = None, bias_regularizer: keras.regularizers.Regularizer = None, activity_regularizer: keras.regularizers.Regularizer = None, kernel_constraint: keras.constraints.Constraint = None, bias_constraint: keras.constraints.Constraint = None, - **kwargs + **kwargs, ) -> None: - super().__init__( - activity_regularizer=activity_regularizer, **kwargs - ) + super().__init__(activity_regularizer=activity_regularizer, **kwargs) self.units = units self.use_bias = use_bias self.activation = keras.activations.get(activation) @@ -36,87 +35,65 @@ def __init__( def get_config(self) -> dict[str, types.Any]: config = super().get_config() - config.update({ - 'units': self.units, - 'activation': keras.activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'kernel_initializer': keras.initializers.serialize( - self.kernel_initializer), - 'bias_initializer': keras.initializers.serialize( - self.bias_initializer), - 'kernel_regularizer': keras.regularizers.serialize( - self.kernel_regularizer), - 'bias_regularizer': keras.regularizers.serialize( - self.bias_regularizer), - 'activity_regularizer': keras.regularizers.serialize( - self.activity_regularizer), - 'kernel_constraint': keras.constraints.serialize( - self.kernel_constraint), - 'bias_constraint': keras.constraints.serialize( - self.bias_constraint), - }) + config.update( + { + "units": self.units, + "activation": keras.activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": keras.initializers.serialize(self.kernel_initializer), + "bias_initializer": keras.initializers.serialize(self.bias_initializer), + "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer), + "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer), + "activity_regularizer": keras.regularizers.serialize(self.activity_regularizer), + "kernel_constraint": keras.constraints.serialize(self.kernel_constraint), + "bias_constraint": keras.constraints.serialize(self.bias_constraint), + } + ) return config def compute_output_shape( - self, - input_shape: dict[str, tuple[int, ...]] + self, input_shape: dict[str, tuple[int, ...]] ) -> dict[str, tuple[int, ...]]: output_shape = input_shape - output_shape['node_state'] = ( - *input_shape['node_state'][:-1], self.units - ) - if input_shape['edge_state'] is not None: - output_shape['edge_state'] = ( - *input_shape['edge_state'][:-1], self.units - ) + output_shape["node_state"] = (*input_shape["node_state"][:-1], self.units) + if input_shape["edge_state"] is not None: + output_shape["edge_state"] = (*input_shape["edge_state"][:-1], self.units) return output_shape def add_kernel( - self, - name: str, - shape: tuple[int, ...], - dtype: str = 'float32', - **kwargs + self, name: str, shape: tuple[int, ...], dtype: str = "float32", **kwargs ) -> types.Variable: return self.add_weight( name=name, shape=shape, dtype=dtype, - **self._common_weight_kwargs('kernel'), + **self._common_weight_kwargs("kernel"), **kwargs, ) - + def add_bias( - self, - name: str, - shape: tuple[int, ...] = None, - dtype: str = 'float32', - **kwargs + self, name: str, shape: tuple[int, ...] = None, dtype: str = "float32", **kwargs ) -> types.Variable: return self.add_weight( name=name, shape=shape if shape is not None else (self.units,), dtype=dtype, - **self._common_weight_kwargs('bias'), + **self._common_weight_kwargs("bias"), **kwargs, ) - def _common_weight_kwargs( - self, - weight_type: str - ) -> dict[str, types.Any]: + def _common_weight_kwargs(self, weight_type: str) -> dict[str, types.Any]: initializer = getattr(self, f"{weight_type}_initializer", None) regularizer = getattr(self, f"{weight_type}_regularizer", None) - regularizer = None if regularizer is None else regularizer.from_config( - regularizer.get_config() + regularizer = ( + None if regularizer is None else regularizer.from_config(regularizer.get_config()) ) constraint = getattr(self, f"{weight_type}_constraint", None) - constraint = None if constraint is None else constraint.from_config( - constraint.get_config() + constraint = ( + None if constraint is None else constraint.from_config(constraint.get_config()) ) return { - 'initializer': initializer, - 'regularizer': regularizer, - 'constraint': constraint, + "initializer": initializer, + "regularizer": regularizer, + "constraint": constraint, } - \ No newline at end of file diff --git a/molexpress/layers/gcn_conv.py b/molexpress/layers/gcn_conv.py index ccd6fb6..c7be2ed 100644 --- a/molexpress/layers/gcn_conv.py +++ b/molexpress/layers/gcn_conv.py @@ -1,12 +1,11 @@ -import keras +import keras from molexpress import types -from molexpress.ops import gnn_ops from molexpress.layers.base_layer import BaseLayer +from molexpress.ops import gnn_ops class GCNConv(BaseLayer): - def __init__( self, units: int, @@ -15,8 +14,8 @@ def __init__( normalization: bool = True, skip_connection: bool = True, dropout_rate: float = 0, - kernel_initializer: keras.initializers.Initializer = 'glorot_uniform', - bias_initializer: keras.initializers.Initializer = 'zeros', + kernel_initializer: keras.initializers.Initializer = "glorot_uniform", + bias_initializer: keras.initializers.Initializer = "zeros", kernel_regularizer: keras.regularizers.Regularizer = None, bias_regularizer: keras.regularizers.Regularizer = None, activity_regularizer: keras.regularizers.Regularizer = None, @@ -36,17 +35,14 @@ def __init__( kernel_constraint=kernel_constraint, bias_constraint=bias_constraint, **kwargs, - ) + ) self.dropout_rate = dropout_rate self.skip_connection = skip_connection self.normalization = normalization - def build( - self, - input_shape: dict[str, tuple[int, ...]] - ) -> None: - node_state_shape = input_shape['node_state'] - edge_state_shape = input_shape.get('edge_state') + def build(self, input_shape: dict[str, tuple[int, ...]]) -> None: + node_state_shape = input_shape["node_state"] + edge_state_shape = input_shape.get("edge_state") node_dim = node_state_shape[-1] @@ -54,18 +50,16 @@ def build( if self._transform_skip_connection: self.skip_connect_kernel = self.add_kernel( - name='skip_connect_kernel', shape=(node_dim, self.units) + name="skip_connect_kernel", shape=(node_dim, self.units) ) - self.node_kernel = self.add_kernel( - name='node_kernel', shape=(node_dim, self.units) - ) - - self.bias = self.add_bias(name='bias') + self.node_kernel = self.add_kernel(name="node_kernel", shape=(node_dim, self.units)) + + self.bias = self.add_bias(name="bias") if edge_state_shape is not None: self.edge_kernel = self.add_kernel( - name='edge_kernel', shape=(edge_state_shape[-1], self.units) + name="edge_kernel", shape=(edge_state_shape[-1], self.units) ) if self.normalization: @@ -75,34 +69,27 @@ def build( self.dropout = keras.layers.Dropout(self.dropout_rate) def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: - x = inputs.copy() - node_state = x.pop('node_state') - edge_src = x['edge_src'] - edge_dst = x['edge_dst'] - edge_state = x.get('edge_state') - edge_weight = x.get('edge_weight') + node_state = x.pop("node_state") + edge_src = x["edge_src"] + edge_dst = x["edge_dst"] + edge_state = x.get("edge_state") + edge_weight = x.get("edge_weight") node_state_updated = gnn_ops.transform( - state=node_state, - kernel=self.node_kernel, - bias=self.bias + state=node_state, kernel=self.node_kernel, bias=self.bias ) if edge_state is not None: - edge_state = gnn_ops.transform( - state=edge_state, - kernel=self.edge_kernel, - bias=None - ) + edge_state = gnn_ops.transform(state=edge_state, kernel=self.edge_kernel, bias=None) node_state_updated = gnn_ops.aggregate( - node_state=node_state_updated, - edge_src=edge_src, - edge_dst=edge_dst, - edge_state=edge_state, - edge_weight=edge_weight + node_state=node_state_updated, + edge_src=edge_src, + edge_dst=edge_dst, + edge_state=edge_state, + edge_weight=edge_weight, ) if self.normalization: @@ -113,21 +100,21 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: if self.skip_connection: if self._transform_skip_connection: - node_state = gnn_ops.transform( - state=node_state, kernel=self.skip_connect_kernel - ) + node_state = gnn_ops.transform(state=node_state, kernel=self.skip_connect_kernel) node_state_updated += node_state if self.dropout_rate: node_state_updated = self.dropout(node_state_updated) return dict(node_state=node_state_updated, **x) - + def get_config(self) -> dict[str, types.Any]: config = super().get_config() - config.update({ - 'normalization': self.normalization, - 'skip_connection': self.skip_connection, - 'dropout_rate': self.dropout_rate - }) - return config \ No newline at end of file + config.update( + { + "normalization": self.normalization, + "skip_connection": self.skip_connection, + "dropout_rate": self.dropout_rate, + } + ) + return config diff --git a/molexpress/layers/gin_conv.py b/molexpress/layers/gin_conv.py index ab73901..599e7b2 100644 --- a/molexpress/layers/gin_conv.py +++ b/molexpress/layers/gin_conv.py @@ -1,12 +1,13 @@ -import keras +from __future__ import annotations + +import keras from molexpress import types -from molexpress.ops import gnn_ops from molexpress.layers.base_layer import BaseLayer +from molexpress.ops import gnn_ops class GINConv(BaseLayer): - def __init__( self, units: int, @@ -15,8 +16,8 @@ def __init__( normalization: bool = True, skip_connection: bool = True, dropout_rate: float = 0, - kernel_initializer: keras.initializers.Initializer = 'glorot_uniform', - bias_initializer: keras.initializers.Initializer = 'zeros', + kernel_initializer: keras.initializers.Initializer = "glorot_uniform", + bias_initializer: keras.initializers.Initializer = "zeros", kernel_regularizer: keras.regularizers.Regularizer = None, bias_regularizer: keras.regularizers.Regularizer = None, activity_regularizer: keras.regularizers.Regularizer = None, @@ -36,52 +37,43 @@ def __init__( kernel_constraint=kernel_constraint, bias_constraint=bias_constraint, **kwargs, - ) + ) self.dropout_rate = dropout_rate self.skip_connection = skip_connection self.normalization = normalization - def build( - self, - input_shape: dict[str, tuple[int, ...]] - ) -> None: - node_state_shape = input_shape['node_state'] - edge_state_shape = input_shape.get('edge_state') + def build(self, input_shape: dict[str, tuple[int, ...]]) -> None: + node_state_shape = input_shape["node_state"] + edge_state_shape = input_shape.get("edge_state") node_dim = node_state_shape[-1] if edge_state_shape is not None: edge_dim = edge_state_shape[-1] - + self._transform_node_state = node_dim != self.units if self._transform_node_state: self.special_node_kernel = self.add_kernel( - name='special_node_kernel', shape=(node_dim, self.units) + name="special_node_kernel", shape=(node_dim, self.units) ) node_dim = self.units - self.node_kernel_1 = self.add_kernel( - name='node_kernel_2', shape=(node_dim, self.units) - ) - self.node_kernel_2 = self.add_kernel( - name='node_kernel_2', shape=(node_dim, self.units) - ) + self.node_kernel_1 = self.add_kernel(name="node_kernel_2", shape=(node_dim, self.units)) + self.node_kernel_2 = self.add_kernel(name="node_kernel_2", shape=(node_dim, self.units)) if self.use_bias: - self.node_bias_1 = self.add_bias(name='node_bias_1') - self.node_bias_2 = self.add_bias(name='node_bias_2') + self.node_bias_1 = self.add_bias(name="node_bias_1") + self.node_bias_2 = self.add_bias(name="node_bias_2") self._transform_edge_state = edge_dim != node_dim if edge_state_shape is not None and self._transform_edge_state: self.special_edge_kernel = self.add_kernel( - name='special_edge_kernel', shape=(edge_dim, node_dim) + name="special_edge_kernel", shape=(edge_dim, node_dim) ) - self.epsilon = self.add_weight( - name='epsilon', shape=(), initializer='zeros' - ) - + self.epsilon = self.add_weight(name="epsilon", shape=(), initializer="zeros") + if self.normalization: self.normalize = keras.layers.BatchNormalization() @@ -89,44 +81,36 @@ def build( self.dropout = keras.layers.Dropout(self.dropout_rate) def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: - x = inputs.copy() - node_state = x.pop('node_state') - edge_src = x['edge_src'] - edge_dst = x['edge_dst'] - edge_state = x.get('edge_state') - edge_weight = x.get('edge_weight') - + node_state = x.pop("node_state") + edge_src = x["edge_src"] + edge_dst = x["edge_dst"] + edge_state = x.get("edge_state") + edge_weight = x.get("edge_weight") if edge_state is not None and self._transform_edge_state: edge_state = gnn_ops.transform( - state=edge_state, - kernel=self.special_edge_kernel, - bias=None + state=edge_state, kernel=self.special_edge_kernel, bias=None ) if self._transform_node_state: node_state = gnn_ops.transform( - state=node_state, - kernel=self.special_node_kernel, - bias=None + state=node_state, kernel=self.special_node_kernel, bias=None ) node_state_updated = gnn_ops.aggregate( - node_state=node_state, - edge_src=edge_src, - edge_dst=edge_dst, - edge_state=edge_state, - edge_weight=edge_weight + node_state=node_state, + edge_src=edge_src, + edge_dst=edge_dst, + edge_state=edge_state, + edge_weight=edge_weight, ) - + node_state_updated += (1 + self.epsilon) * node_state node_state_updated = gnn_ops.transform( - state=node_state_updated, - kernel=self.node_kernel_1, - bias=self.node_bias_1 + state=node_state_updated, kernel=self.node_kernel_1, bias=self.node_bias_1 ) if self.normalization: @@ -135,9 +119,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: node_state_updated = self.activation(node_state_updated) node_state_updated = gnn_ops.transform( - state=node_state_updated, - kernel=self.node_kernel_2, - bias=self.node_bias_2 + state=node_state_updated, kernel=self.node_kernel_2, bias=self.node_bias_2 ) if self.activation is not None: @@ -150,12 +132,14 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph: node_state_updated = self.dropout(node_state_updated) return dict(node_state=node_state_updated, **x) - + def get_config(self) -> dict[str, types.Any]: config = super().get_config() - config.update({ - 'normalization': self.normalization, - 'skip_connection': self.skip_connection, - 'dropout_rate': self.dropout_rate - }) - return config \ No newline at end of file + config.update( + { + "normalization": self.normalization, + "skip_connection": self.skip_connection, + "dropout_rate": self.dropout_rate, + } + ) + return config diff --git a/molexpress/layers/readout.py b/molexpress/layers/readout.py index 45b0a5e..59fa7d5 100644 --- a/molexpress/layers/readout.py +++ b/molexpress/layers/readout.py @@ -1,31 +1,31 @@ +from __future__ import annotations + import keras from molexpress import types -from molexpress.ops import gnn_ops +from molexpress.ops import gnn_ops class Readout(keras.layers.Layer): - - def __init__(self, mode: str = 'mean', **kwargs) -> None: + def __init__(self, mode: str = "mean", **kwargs) -> None: super().__init__(**kwargs) self.mode = mode - if self.mode == 'max': - self._readout_fn = keras.ops.segment_max - elif self.mode == 'sum': + if self.mode == "max": + self._readout_fn = keras.ops.segment_max + elif self.mode == "sum": self._readout_fn = keras.ops.segment_sum else: self._readout_fn = gnn_ops.segment_mean def build(self, input_shape: dict[str, tuple[int, ...]]) -> None: - if 'graph_indicator' not in input_shape: - raise ValueError( - "Cannot perform readout: 'graph_indicator' not found.") + if "graph_indicator" not in input_shape: + raise ValueError("Cannot perform readout: 'graph_indicator' not found.") def call(self, inputs: types.MolecularGraph) -> types.Array: - graph_indicator = keras.ops.cast(inputs['graph_indicator'], 'int32') + graph_indicator = keras.ops.cast(inputs["graph_indicator"], "int32") return self._readout_fn( - data=inputs['node_state'], + data=inputs["node_state"], segment_ids=graph_indicator, num_segments=None, - sorted=False, - ) \ No newline at end of file + sorted=False, + ) diff --git a/molexpress/ops/chem_ops.py b/molexpress/ops/chem_ops.py index 9a8ff0d..c71c810 100644 --- a/molexpress/ops/chem_ops.py +++ b/molexpress/ops/chem_ops.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import numpy as np -from rdkit import Chem +from rdkit import Chem from molexpress import types @@ -8,15 +10,14 @@ def get_molecule( molecule: types.Molecule | types.SMILES | types.InChI, catch_errors: bool = False, ) -> Chem.Mol | None: - """Generates an molecule object.""" if isinstance(molecule, Chem.Mol): return molecule - string = molecule + string = molecule - if string.startswith('InChI'): + if string.startswith("InChI"): molecule = Chem.MolFromInchi(string, sanitize=False) else: molecule = Chem.MolFromSmiles(string, sanitize=False) @@ -28,36 +29,30 @@ def get_molecule( if flag != Chem.SanitizeFlags.SANITIZE_NONE: if not catch_errors: return None - # Sanitize molecule again, without the sanitization step that caused + # Sanitize molecule again, without the sanitization step that caused # the error previously. Unrealistic molecules might pass without an error. - Chem.SanitizeMol( - molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL^flag) + Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag) - Chem.AssignStereochemistry( - molecule, cleanIt=True, force=True, flagPossibleStereoCenters=True) + Chem.AssignStereochemistry(molecule, cleanIt=True, force=True, flagPossibleStereoCenters=True) return molecule + def get_adjacency( molecule: types.Molecule, self_loops: bool = False, sparse: bool = True, - dtype: str = 'int32', + dtype: str = "int32", ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: - """Computes the (sparse) adjacency matrix of the molecule""" adjacency_matrix: np.ndarray = Chem.GetAdjacencyMatrix(molecule) if self_loops: - adjacency_matrix += np.eye( - adjacency_matrix.shape[0], dtype=adjacency_matrix.dtype - ) + adjacency_matrix += np.eye(adjacency_matrix.shape[0], dtype=adjacency_matrix.dtype) if not sparse: return adjacency_matrix.astype(dtype) edge_src, edge_dst = np.where(adjacency_matrix) return edge_src.astype(dtype), edge_dst.astype(dtype) - - diff --git a/molexpress/ops/gnn_ops.py b/molexpress/ops/gnn_ops.py index 6770471..c26c654 100644 --- a/molexpress/ops/gnn_ops.py +++ b/molexpress/ops/gnn_ops.py @@ -1,15 +1,17 @@ -import keras +from __future__ import annotations -from molexpress import types +import keras + +from molexpress import types def transform( - state, + state, kernel: types.Variable, bias: types.Variable = None, -) -> types.Array: +) -> types.Array: """Transforms node or edge states via learnable weights. - + Args: state: The current node or edge states to be updated. @@ -17,16 +19,17 @@ def transform( The learnable kernel. bias: The learnable bias. - + Returns: A transformed node state. """ state_transformed = keras.ops.matmul(state, kernel) if bias is not None: - state_transformed += bias + state_transformed += bias return state_transformed - + + def aggregate( node_state: types.Array, edge_src: types.Array, @@ -35,16 +38,16 @@ def aggregate( edge_weight: types.Array = None, ) -> types.Array: """Aggregates node states based on edges. - + Given node A with edges AB and AC, the information (states) of nodes B and C will be passed to node A. Args: - node_state: + node_state: The current state of the nodes. edge_src: The indices of the source nodes. - edge_dst: + edge_dst: The indices of the destination nodes. edge_state: Optional edge states. @@ -56,39 +59,35 @@ def aggregate( """ num_nodes = keras.ops.shape(node_state)[0] - expected_rank = 2 + expected_rank = 2 current_rank = len(keras.ops.shape(edge_src)) for _ in range(expected_rank - current_rank): edge_src = keras.ops.expand_dims(edge_src, axis=-1) edge_dst = keras.ops.expand_dims(edge_dst, axis=-1) - - node_state_src = keras.ops.take_along_axis( - node_state, edge_src, axis=0 - ) + + node_state_src = keras.ops.take_along_axis(node_state, edge_src, axis=0) if edge_weight is not None: - node_state_src *= edge_weight + node_state_src *= edge_weight if edge_state is not None: - node_state_src += edge_state + node_state_src += edge_state edge_dst = keras.ops.squeeze(edge_dst, axis=-1) node_state_updated = keras.ops.segment_sum( - data=node_state_src, - segment_ids=edge_dst, - num_segments=num_nodes, - sorted=False + data=node_state_src, segment_ids=edge_dst, num_segments=num_nodes, sorted=False ) return node_state_updated + def segment_mean( data: types.Array, segment_ids: types.Array, num_segments: int = None, - sorted: bool = False + sorted: bool = False, ) -> types.Array: """Performs a mean of data based on segment indices. - + A permutation invariant reduction of the node states to obtain an encoding of the graph. @@ -106,9 +105,6 @@ def segment_mean( New data that has been reduced. """ x = keras.ops.segment_sum( - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - sorted=sorted - ) + data=data, segment_ids=segment_ids, num_segments=num_segments, sorted=sorted + ) return x / keras.ops.cast(keras.ops.bincount(segment_ids), x.dtype)[:, None] diff --git a/molexpress/types.py b/molexpress/types.py index be12a48..3315e9b 100644 --- a/molexpress/types.py +++ b/molexpress/types.py @@ -1,12 +1,14 @@ -from typing import TypedDict -from typing import Protocol -from typing import TypeVar -from typing import Any +from __future__ import annotations -from rdkit import Chem +from typing import ( + Any, # noqa: F401 + Protocol, # noqa: F401 + TypedDict, + TypeVar, +) +from rdkit import Chem -Scalar = TypeVar("Scalar") Array = TypeVar("Array") Variable = TypeVar("Variable") @@ -14,17 +16,17 @@ DType = TypeVar("DType") Molecule = Chem.Mol -Atom = Chem.Atom +Atom = Chem.Atom Bond = Chem.Bond SMILES = TypeVar("SMILES", bound=str) -InChI = TypeVar("InChI", bound=str) +InChI = TypeVar("InChI", bound=str) class MolecularGraph(TypedDict): - node_state: Array + node_state: Array edge_src: Array - edge_dst: Array - edge_state: Array | None - edge_weight: Array | None + edge_dst: Array + edge_state: Array | None + edge_weight: Array | None graph_indicator: Array | None diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..be7ca30 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[project] +name = "molexpress" +description = "Graph Neural Networks for molecules with Keras 3." +readme = "README.md" +license = { file = "LICENSE" } +dynamic = ["version"] +authors = [ + { name = "Alexander Kensert", email = "alexander.kensert@gmail.com" }, +] +keywords = [ + "python", + "keras-3", + "machine-learning", + "deep-learning", + "graph-neural-networks", + "graph-convolutional-networks", + "graphs", + "molecules", + "chemistry", + "cheminformatics", + "bioinformatics", +] +classifiers = [ + "Programming Language :: Python :: 3", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: POSIX :: Linux", +] +requires-python = ">=3.9" +dependencies = ["rdkit>=2023.9.5", "keras>=3", "numpy"] + +[project.optional-dependencies] +dev = ["ruff", "isort"] + +[project.urls] +homepage = "https://github.com/compomics/molexpress" + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["molexpress"] + +[tool.isort] +profile = "black" + +[tool.ruff] +line-length = 99 +target-version = 'py39' + +[tool.ruff.format] +docstring-code-format = true diff --git a/setup.py b/setup.py deleted file mode 100644 index e34478d..0000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -import setuptools -import os -import sys - -def get_version(): - version_path = os.path.join(os.path.dirname(__file__), 'molexpress') - sys.path.insert(0, version_path) - from _version import __version__ as version - return version - -with open("README.md", "r") as fh: - long_description = fh.read() - -install_requires = [ - "tensorflow>=2.16.1", # Installs Keras 3 - "rdkit>=2023.9.5", - "jupyter", # Optional, but needed for the notebooks -] - -setuptools.setup( - name='molexpress', - version=get_version(), - author="Alexander Kensert", - author_email="alexander.kensert@gmail.com", - description="Graph Neural Networks with Keras 3.", - long_description=long_description, - long_description_content_type="text/markdown", - license="MIT", - url="https://github.com/compomics/molexpress", - packages=setuptools.find_packages(include=["molexpress*"]), - install_requires=install_requires, - classifiers=[ - "Programming Language :: Python :: 3", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Operating System :: POSIX :: Linux", - ], - python_requires=">=3.10.6", - keywords=[ - 'python', - 'keras-3', - 'machine-learning', - 'deep-learning', - 'graph-neural-networks', - 'graph-convolutional-networks', - 'graphs', - 'molecules', - 'chemistry', - 'cheminformatics', - 'bioinformatics', - ] -)