Skip to content

Commit

Permalink
Merge pull request #76 from marjanAlbouye/master
Browse files Browse the repository at this point in the history
Add function to identify connections in a snapshot
  • Loading branch information
marjanalbooyeh authored Jan 22, 2024
2 parents 13c6fdd + 1929fd5 commit 3953734
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 0 deletions.
248 changes: 248 additions & 0 deletions cmeutils/gsd_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import warnings
from tempfile import NamedTemporaryFile

import freud
import gsd.hoomd
import hoomd
import networkx as nx
import numpy as np
from boltons.setutils import IndexedSet

from cmeutils.geometry import moit

Expand Down Expand Up @@ -374,3 +377,248 @@ def xml_to_gsd(xmlfile, gsdfile):
snap.bonds.group = bonds
newt.append(snap)
print(f"XML data written to {gsdfile}")


def identify_snapshot_connections(snapshot):
    """Add angle and dihedral groups to a snapshot based on its bonds.

    Parameters
    ----------
    snapshot : gsd.hoomd.Frame
        The snapshot whose bond groups are searched for connections.
        It is modified in place.

    Returns
    -------
    gsd.hoomd.Frame
        The same snapshot object, with angle and dihedral information
        filled in for whichever of the two were found.

    Warns
    -----
    UserWarning
        If the snapshot has no bonds; it is then returned unchanged.
    """
    if snapshot.bonds.N == 0:
        warnings.warn(
            "No bonds found in snapshot, hence, no angles or "
            "dihedrals will be identified."
        )
        return snapshot

    found = _find_connections(snapshot.bonds.group)
    # Angles are written before dihedrals; a kind with no matches leaves
    # the corresponding snapshot section untouched.
    for kind in ("angles", "dihedrals"):
        groups = found[kind]
        if groups:
            _fill_connection_info(
                snapshot=snapshot,
                connections=groups,
                type_=kind,
            )
    return snapshot


def _fill_connection_info(snapshot, connections, type_):
p_types = snapshot.particles.types
p_typeid = snapshot.particles.typeid
_connection_types = []
_connection_typeid = []
for conn in connections:
conn_sites = [p_types[p_typeid[i]] for i in conn]
sorted_conn_sites = _sort_connection_by_name(conn_sites, type_)
type = "-".join(sorted_conn_sites)
# check if type not in angle_types and types_inv not in angle_types:
if type not in _connection_types:
_connection_types.append(type)
_connection_typeid.append(
max(_connection_typeid) + 1 if _connection_typeid else 0
)
else:
_connection_typeid.append(_connection_types.index(type))

if type_ == "angles":
snapshot.angles.N = len(connections)
snapshot.angles.M = 3
snapshot.angles.group = connections
snapshot.angles.types = _connection_types
snapshot.angles.typeid = _connection_typeid
elif type_ == "dihedrals":
snapshot.dihedrals.N = len(connections)
snapshot.dihedrals.M = 4
snapshot.dihedrals.group = connections
snapshot.dihedrals.types = _connection_types
snapshot.dihedrals.typeid = _connection_typeid


# The following functions are obtained from gmso/utils/connectivity.py with
# minor modifications.
def _sort_connection_by_name(conn_sites, type_):
if type_ == "angles":
site1, site3 = sorted([conn_sites[0], conn_sites[2]])
return [site1, conn_sites[1], site3]
elif type_ == "dihedrals":
site1, site2, site3, site4 = conn_sites
if site2 > site3 or (site2 == site3 and site1 > site4):
return [site4, site3, site2, site1]
else:
return [site1, site2, site3, site4]


def _find_connections(bonds):
    """Find every angle and dihedral implied by a collection of bonds.

    Parameters
    ----------
    bonds : iterable
        Bonds, each indexable so that ``bond[0]`` and ``bond[1]`` are the
        two bonded particles.

    Returns
    -------
    dict
        ``{"angles": [...], "dihedrals": [...]}`` with the connections
        returned by ``_detect_connections`` for each kind.
    """
    bond_graph = nx.Graph()
    bond_graph.add_edges_from((bond[0], bond[1]) for bond in bonds)

    # In the line graph each bond becomes a node, and two bonds are
    # adjacent when they share a particle.
    bond_line_graph = nx.line_graph(bond_graph)

    return {
        "angles": _detect_connections(bond_line_graph, type_="angle"),
        "dihedrals": _detect_connections(bond_line_graph, type_="dihedral"),
    }


def _detect_connections(compound_line_graph, type_="angle"):
    """Find all connections of one kind in a compound's line graph.

    A small pattern graph is matched against ``compound_line_graph``: two
    adjacent line-graph nodes for an angle, a three-node path for a
    dihedral. Each match is converted back to particle nodes, duplicates
    (including reversed copies) are removed, and the result is sorted.

    Parameters
    ----------
    compound_line_graph : nx.Graph
        Line graph of the bond graph.
    type_ : str, default "angle"
        ``"angle"`` or ``"dihedral"``.

    Returns
    -------
    list of list
        Connections oriented so the first entry is smaller than the last,
        sorted by their central member(s) and then by their ends.

    Raises
    ------
    KeyError
        If ``type_`` is not one of the supported kinds.
    """
    pattern_edges = {
        "angle": ((0, 1),),
        "dihedral": ((0, 1), (1, 2)),
    }
    match_formatters = {
        "angle": _format_subgraph_angle,
        "dihedral": _format_subgraph_dihedral,
    }
    # Index order used to build the final sort key: central member(s)
    # first, then the two ends.
    sort_orders = {
        "angle": (1, 0, 2),
        "dihedral": (1, 2, 0, 3),
    }

    pattern = nx.Graph()
    pattern.add_edges_from(pattern_edges[type_])

    matcher = nx.algorithms.isomorphism.GraphMatcher(
        compound_line_graph, pattern
    )

    to_connection = match_formatters[type_]
    found = IndexedSet()
    for mapping in matcher.subgraph_isomorphisms_iter():
        found.add(to_connection(mapping))
    if found:
        found = _trim_duplicates(found)

    # Orient every connection so that its smaller end comes first.
    oriented = [
        list(conn if conn[0] < conn[-1] else conn[::-1]) for conn in found
    ]

    order = sort_orders[type_]
    return sorted(
        oriented,
        key=lambda conn: tuple(conn[i] for i in order),
    )


def _get_sorted_by_n_connections(m):
    """Rebuild the matched sub-graph and rank its nodes by neighbor count.

    Parameters
    ----------
    m : dict
        Sub-graph match whose keys are line-graph nodes; the first two
        entries of each key are used as an edge. Values are ignored.

    Returns
    -------
    tuple
        ``(nodes, graph)``: ``graph`` is the ``nx.Graph`` built from the
        keys of ``m`` and ``nodes`` lists its nodes in ascending order of
        their number of neighbors.
    """
    matched = nx.Graph()
    matched.add_edges_from((edge[0], edge[1]) for edge in m)

    def neighbor_count(node):
        return len(matched[node])

    ranked_nodes = sorted(matched.adj, key=neighbor_count)
    return ranked_nodes, matched


def _format_subgraph_angle(m):
    """Convert a line-graph match into an angle of particle nodes.

    The match is made on the compound line graph, so its keys are bonds;
    this backs out the actual nodes those bonds connect.

    Parameters
    ----------
    m : dict
        Keys are the compound line graph nodes, values are the nodes of
        the pattern sub-graph they were matched to.

    Returns
    -------
    tuple
        ``(start, middle, end)`` in order of bonding, with
        ``start <= end``.
    """
    ranked, _ = _get_sorted_by_n_connections(m)
    # The two least-connected nodes are the ends; the third is the vertex.
    start, end, middle = ranked[0], ranked[1], ranked[2]
    if end < start:
        start, end = end, start
    return (start, middle, end)


def _format_subgraph_dihedral(m):
    """Convert a line-graph match into a dihedral of particle nodes.

    The match is made on the compound line graph, so its keys are bonds;
    this backs out the actual nodes those bonds connect.

    Parameters
    ----------
    m : dict
        Keys are the compound line graph nodes, values are the nodes of
        the pattern sub-graph they were matched to.

    Returns
    -------
    tuple
        ``(start, mid1, mid2, end)`` in order of bonding.
    """
    ranked, matched = _get_sorted_by_n_connections(m)
    # The two least-connected nodes are the ends of the chain.
    start, end = ranked[0], ranked[1]
    mid1, mid2 = ranked[2], ranked[3]
    # The inner node bonded to ``start`` must come directly after it.
    if mid1 not in matched.neighbors(start):
        mid1, mid2 = mid2, mid1
    return (start, mid1, mid2, end)


def _trim_duplicates(all_matches):
    """Drop empty matches and matches that repeat an earlier one.

    A match counts as a repeat when either it or its reverse has already
    been kept. The first occurrence wins and input order is preserved.

    NOTE: imposing an ordering when the sub-graphs are formatted might make
    this pass unnecessary.

    Parameters
    ----------
    all_matches : iterable
        Sliceable, hashable matches (e.g. tuples of nodes).

    Returns
    -------
    IndexedSet
        The matches that were kept.
    """
    unique = IndexedSet()
    for candidate in all_matches:
        if not candidate:
            continue
        if candidate in unique or candidate[::-1] in unique:
            continue
        unique.add(candidate)
    return unique
Binary file added cmeutils/tests/assets/pekk-cg.gsd
Binary file not shown.
4 changes: 4 additions & 0 deletions cmeutils/tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def p3ht_gsd(self):
def p3ht_cg_gsd(self):
return path.join(asset_dir, "p3ht-cg.gsd")

@pytest.fixture
def pekk_cg_gsd(self):
    """Return the path of ``pekk-cg.gsd`` under ``asset_dir``."""
    gsd_path = path.join(asset_dir, "pekk-cg.gsd")
    return gsd_path

@pytest.fixture
def mapping(self):
    """Load ``mapping.txt`` from ``asset_dir`` as an integer array."""
    mapping_file = path.join(asset_dir, "mapping.txt")
    return np.loadtxt(mapping_file, dtype=int)
Expand Down
Loading

0 comments on commit 3953734

Please sign in to comment.