Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save atom hybridization #408

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@ Checklist

## Developers certificate of origin
- [ ] I certify that this contribution is covered by the MIT License [here](https://github.com/OpenFreeEnergy/openfe/blob/main/LICENSE) and the **Developer Certificate of Origin** at <https://developercertificate.org/>.

12 changes: 2 additions & 10 deletions gufe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@

from . import tokenization, visualization
from .chemicalsystem import ChemicalSystem
from .components import (
Component,
ProteinComponent,
SmallMoleculeComponent,
SolventComponent,
)
from .components import Component, ProteinComponent, SmallMoleculeComponent, SolventComponent
from .ligandnetwork import LigandNetwork
from .mapping import AtomMapper # more specific to atom based components
from .mapping import ComponentMapping # how individual Components relate
Expand All @@ -21,10 +16,7 @@
from .protocols import ProtocolDAGResult # the collected result of a DAG
from .protocols import ProtocolUnit # the individual step within a method
from .protocols import ProtocolUnitResult # the result of a single Unit
from .protocols import ( # potentially many DAGs together, giving an estimate
Context,
ProtocolResult,
)
from .protocols import Context, ProtocolResult # potentially many DAGs together, giving an estimate
from .settings import Settings
from .transformations import NonTransformation, Transformation

Expand Down
8 changes: 2 additions & 6 deletions gufe/chemicalsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,11 @@ def __init__(
self._name = name

def __repr__(self):
return (
f"{self.__class__.__name__}(name={self.name}, components={self.components})"
)
return f"{self.__class__.__name__}(name={self.name}, components={self.components})"

def _to_dict(self):
return {
"components": {
key: value for key, value in sorted(self.components.items())
},
"components": {key: value for key, value in sorted(self.components.items())},
"name": self.name,
}

Expand Down
21 changes: 5 additions & 16 deletions gufe/components/explicitmoleculecomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,12 @@ def _check_partial_charges(mol: RDKitMol) -> None:
p_chgs = np.array(mol.GetProp("atom.dprop.PartialCharge").split(), dtype=float)

if len(p_chgs) != mol.GetNumAtoms():
errmsg = (
f"Incorrect number of partial charges: {len(p_chgs)} "
f" were provided for {mol.GetNumAtoms()} atoms"
)
errmsg = f"Incorrect number of partial charges: {len(p_chgs)} " f" were provided for {mol.GetNumAtoms()} atoms"
raise ValueError(errmsg)

if (sum(p_chgs) - Chem.GetFormalCharge(mol)) > 0.01:
errmsg = (
f"Sum of partial charges {sum(p_chgs)} differs from "
f"RDKit formal charge {Chem.GetFormalCharge(mol)}"
f"Sum of partial charges {sum(p_chgs)} differs from " f"RDKit formal charge {Chem.GetFormalCharge(mol)}"
)
raise ValueError(errmsg)

Expand All @@ -81,16 +77,12 @@ def _check_partial_charges(mol: RDKitMol) -> None:
atom_charge = atom.GetDoubleProp("PartialCharge")
if not np.isclose(atom_charge, charge):
errmsg = (
f"non-equivalent partial charges between atom and "
f"molecule properties: {atom_charge} {charge}"
f"non-equivalent partial charges between atom and " f"molecule properties: {atom_charge} {charge}"
)
raise ValueError(errmsg)

if np.all(np.isclose(p_chgs, 0.0)):
wmsg = (
f"Partial charges provided all equal to "
"zero. These may be ignored by some Protocols."
)
wmsg = f"Partial charges provided all equal to " "zero. These may be ignored by some Protocols."
warnings.warn(wmsg)
else:
wmsg = (
Expand Down Expand Up @@ -121,10 +113,7 @@ def __init__(self, rdkit: RDKitMol, name: str = ""):

n_confs = len(conformers)
if n_confs > 1:
warnings.warn(
f"Molecule provided with {n_confs} conformers. "
f"Only the first will be used."
)
warnings.warn(f"Molecule provided with {n_confs} conformers. " f"Only the first will be used.")

if not any(atom.GetAtomicNum() == 1 for atom in rdkit.GetAtoms()):
warnings.warn(
Expand Down
31 changes: 8 additions & 23 deletions gufe/components/proteincomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,7 @@ def from_pdbx_file(cls, pdbx_file: str, name=""):
return cls._from_openmmPDBFile(openmm_PDBFile=openmm_PDBxFile, name=name)

@classmethod
def _from_openmmPDBFile(
cls, openmm_PDBFile: Union[PDBFile, PDBxFile], name: str = ""
):
def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], name: str = ""):
"""Converts to our internal representation (rdkit Mol)

Parameters
Expand Down Expand Up @@ -201,9 +199,7 @@ def _from_openmmPDBFile(

# Set Positions
rd_mol = editable_rdmol.GetMol()
positions = np.array(
openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3
)
positions = np.array(openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3)

for frame_id, frame in enumerate(positions):
conf = Conformer(frame_id)
Expand All @@ -218,9 +214,7 @@ def _from_openmmPDBFile(
atomic_num = a.GetAtomicNum()
atom_name = a.GetMonomerInfo().GetName()

connectivity = sum(
_BONDORDER_TO_ORDER[bond.GetBondType()] for bond in a.GetBonds()
)
connectivity = sum(_BONDORDER_TO_ORDER[bond.GetBondType()] for bond in a.GetBonds())
default_valence = periodicTable.GetDefaultValence(atomic_num)

if connectivity == 0: # ions:
Expand Down Expand Up @@ -364,9 +358,7 @@ def chainkey(m):

if (new_resid := reskey(mi)) != current_resid:
_, resname, resnum, icode = new_resid
r = top.addResidue(
name=resname, chain=c, id=str(resnum), insertionCode=icode
)
r = top.addResidue(name=resname, chain=c, id=str(resnum), insertionCode=icode)
current_resid = new_resid

a = top.addAtom(
Expand All @@ -381,9 +373,7 @@ def chainkey(m):
for bond in self._rdkit.GetBonds():
a1 = atom_lookup[bond.GetBeginAtomIdx()]
a2 = atom_lookup[bond.GetEndAtomIdx()]
top.addBond(
a1, a2, order=_BONDORDERS_RDKIT_TO_OPENMM.get(bond.GetBondType(), None)
)
top.addBond(a1, a2, order=_BONDORDERS_RDKIT_TO_OPENMM.get(bond.GetBondType(), None))

return top

Expand All @@ -405,9 +395,7 @@ def to_openmm_positions(self) -> omm_unit.Quantity:

return openmm_pos

def to_pdb_file(
self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]
) -> str:
def to_pdb_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]) -> str:
"""
serialize protein to pdb file.

Expand Down Expand Up @@ -449,9 +437,7 @@ def to_pdb_file(

return out_path

def to_pdbx_file(
self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]
) -> str:
def to_pdbx_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]) -> str:
"""
serialize protein to pdbx file.

Expand Down Expand Up @@ -529,8 +515,7 @@ def _to_dict(self) -> dict:
]

conformers = [
serialize_numpy(conf.GetPositions()) # .m_as(unit.angstrom)
for conf in self._rdkit.GetConformers()
serialize_numpy(conf.GetPositions()) for conf in self._rdkit.GetConformers() # .m_as(unit.angstrom)
]

# Result
Expand Down
20 changes: 20 additions & 0 deletions gufe/components/smallmoleculecomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@
}
_BONDSTEREO_TO_INT = {v: k for k, v in _INT_TO_BONDSTEREO.items()}

# Lookup tables used to (de)serialize an atom's hybridization as a plain int.
# following the numbering in rdkit
# The int is written as the trailing entry of each serialized atom tuple in
# ``_to_dict`` and read back from index 7 in ``_from_dict``; atom tuples that
# lack that entry are tolerated there by catching ``IndexError``.
# NOTE(review): the integer codes are persisted in serialized output, so
# presumably they must stay stable once files exist — confirm before
# renumbering or reordering entries.
_INT_TO_HYBRIDIZATION = {
0: Chem.rdchem.HybridizationType.UNSPECIFIED,
1: Chem.rdchem.HybridizationType.S,
2: Chem.rdchem.HybridizationType.SP,
3: Chem.rdchem.HybridizationType.SP2,
4: Chem.rdchem.HybridizationType.SP3,
5: Chem.rdchem.HybridizationType.SP2D,
6: Chem.rdchem.HybridizationType.SP3D,
7: Chem.rdchem.HybridizationType.SP3D2,
8: Chem.rdchem.HybridizationType.OTHER,
}
# Inverse table (HybridizationType -> int), used when serializing atoms.
_HYBRIDIZATION_TO_INT = {v: k for k, v in _INT_TO_HYBRIDIZATION.items()}


def _setprops(obj, d: dict) -> None:
# add props onto rdkit "obj" (atom/bond/mol/conformer)
Expand Down Expand Up @@ -223,6 +237,7 @@ def _to_dict(self) -> dict:
_ATOMCHIRAL_TO_INT[atom.GetChiralTag()],
atom.GetAtomMapNum(),
atom.GetPropsAsDict(includePrivate=False),
_HYBRIDIZATION_TO_INT[atom.GetHybridization()],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens in this case if this isn't set?

I.e. if you read an old file (which now skips the hybridization reading), then write it out again - what does this pick up?

)
)
output["atoms"] = atoms
Expand Down Expand Up @@ -264,6 +279,11 @@ def _from_dict(cls, d: dict):
a.SetChiralTag(_INT_TO_ATOMCHIRAL[atom[4]])
a.SetAtomMapNum(atom[5])
_setprops(a, atom[6])
try:
a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]])
except IndexError:
pass

em.AddAtom(a)

for bond in d["bonds"]:
Expand Down
13 changes: 3 additions & 10 deletions gufe/components/solventcomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,10 @@ def __init__(

self._neutralize = neutralize

if not isinstance(
ion_concentration, unit.Quantity
) or not ion_concentration.is_compatible_with(unit.molar):
raise ValueError(
f"ion_concentration must be given in units of"
f" concentration, got: {ion_concentration}"
)
if not isinstance(ion_concentration, unit.Quantity) or not ion_concentration.is_compatible_with(unit.molar):
raise ValueError(f"ion_concentration must be given in units of" f" concentration, got: {ion_concentration}")
if ion_concentration.m < 0:
raise ValueError(
f"ion_concentration must be positive, " f"got: {ion_concentration}"
)
raise ValueError(f"ion_concentration must be positive, " f"got: {ion_concentration}")

self._ion_concentration = ion_concentration

Expand Down
7 changes: 2 additions & 5 deletions gufe/custom_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ def is_openff_quantity_dict(dct):
"shape": list(obj.shape),
"bytes": obj.tobytes(),
},
from_dict=lambda dct: np.frombuffer(
dct["bytes"], dtype=np.dtype(dct["dtype"])
).reshape(dct["shape"]),
from_dict=lambda dct: np.frombuffer(dct["bytes"], dtype=np.dtype(dct["dtype"])).reshape(dct["shape"]),
)


Expand All @@ -118,8 +116,7 @@ def is_openff_quantity_dict(dct):
":is_custom:": True,
"pint_unit_registry": "openff_units",
},
from_dict=lambda dct: dct["magnitude"]
* DEFAULT_UNIT_REGISTRY.Quantity(dct["unit"]),
from_dict=lambda dct: dct["magnitude"] * DEFAULT_UNIT_REGISTRY.Quantity(dct["unit"]),
is_my_obj=lambda obj: isinstance(obj, DEFAULT_UNIT_REGISTRY.Quantity),
is_my_dict=is_openff_quantity_dict,
)
Expand Down
27 changes: 7 additions & 20 deletions gufe/ligandnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def __init__(
nodes = []

self._edges = frozenset(edges)
edge_nodes = set(
chain.from_iterable((e.componentA, e.componentB) for e in edges)
)
edge_nodes = set(chain.from_iterable((e.componentA, e.componentB) for e in edges))
self._nodes = frozenset(edge_nodes) | frozenset(nodes)
self._graph = None

Expand Down Expand Up @@ -70,9 +68,7 @@ def graph(self) -> nx.MultiDiGraph:
for node in sorted(self._nodes):
graph.add_node(node)
for edge in sorted(self._edges):
graph.add_edge(
edge.componentA, edge.componentB, object=edge, **edge.annotations
)
graph.add_edge(edge.componentA, edge.componentB, object=edge, **edge.annotations)

self._graph = nx.freeze(graph)

Expand Down Expand Up @@ -116,14 +112,10 @@ def _serializable_graph(self) -> nx.Graph:
# from here, we just build the graph
serializable_graph = nx.MultiDiGraph()
for mol, label in mol_to_label.items():
serializable_graph.add_node(
label, moldict=json.dumps(mol.to_dict(), sort_keys=True)
)
serializable_graph.add_node(label, moldict=json.dumps(mol.to_dict(), sort_keys=True))

for molA, molB, mapping, annotation in edge_data:
serializable_graph.add_edge(
molA, molB, mapping=mapping, annotations=annotation
)
serializable_graph.add_edge(molA, molB, mapping=mapping, annotations=annotation)

return serializable_graph

Expand All @@ -134,8 +126,7 @@ def _from_serializable_graph(cls, graph: nx.Graph):
This is the inverse of ``_serializable_graph``.
"""
label_to_mol = {
node: SmallMoleculeComponent.from_dict(json.loads(d))
for node, d in graph.nodes(data="moldict")
node: SmallMoleculeComponent.from_dict(json.loads(d)) for node, d in graph.nodes(data="moldict")
}

edges = [
Expand Down Expand Up @@ -242,9 +233,7 @@ def sys_from_dict(component):
"""
syscomps = {alchemical_label: component}
other_labels = set(labels) - {alchemical_label}
syscomps.update(
{label: components[label] for label in other_labels}
)
syscomps.update({label: components[label] for label in other_labels})

if autoname:
name = f"{component.name}_{leg_name}"
Expand All @@ -261,9 +250,7 @@ def sys_from_dict(component):
else:
name = ""

transformation = gufe.Transformation(
sysA, sysB, protocol, mapping=edge, name=name
)
transformation = gufe.Transformation(sysA, sysB, protocol, mapping=edge, name=name)

transformations.append(transformation)

Expand Down
4 changes: 1 addition & 3 deletions gufe/mapping/atom_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ class AtomMapper(GufeTokenizable):
"""

@abc.abstractmethod
def suggest_mappings(
self, A: gufe.Component, B: gufe.Component
) -> Iterator[AtomMapping]:
def suggest_mappings(self, A: gufe.Component, B: gufe.Component) -> Iterator[AtomMapping]:
"""Suggests possible mappings between two Components

Suggests zero or more :class:`.AtomMapping` objects, which are possible
Expand Down
24 changes: 5 additions & 19 deletions gufe/mapping/ligandatommapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,9 @@ def __init__(
nB = self.componentB.to_rdkit().GetNumAtoms()
for i, j in componentA_to_componentB.items():
if not (0 <= i < nA):
raise ValueError(
f"Got invalid index for ComponentA ({i}); " f"must be 0 <= n < {nA}"
)
raise ValueError(f"Got invalid index for ComponentA ({i}); " f"must be 0 <= n < {nA}")
if not (0 <= j < nB):
raise ValueError(
f"Got invalid index for ComponentB ({i}); " f"must be 0 <= n < {nB}"
)
raise ValueError(f"Got invalid index for ComponentB ({i}); " f"must be 0 <= n < {nB}")

self._compA_to_compB = componentA_to_componentB

Expand All @@ -89,19 +85,11 @@ def componentB_to_componentA(self) -> dict[int, int]:

@property
def componentA_unique(self):
return (
i
for i in range(self.componentA.to_rdkit().GetNumAtoms())
if i not in self._compA_to_compB
)
return (i for i in range(self.componentA.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB)

@property
def componentB_unique(self):
return (
i
for i in range(self.componentB.to_rdkit().GetNumAtoms())
if i not in self._compA_to_compB.values()
)
return (i for i in range(self.componentB.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB.values())

@property
def annotations(self):
Expand All @@ -118,9 +106,7 @@ def _to_dict(self):
"componentA": self.componentA,
"componentB": self.componentB,
"componentA_to_componentB": self._compA_to_compB,
"annotations": json.dumps(
self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder
),
"annotations": json.dumps(self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder),
}

@classmethod
Expand Down
Loading
Loading