Skip to content

Commit

Permalink
fix charge array in np.concatenate call
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjonesBSU committed Oct 17, 2024
1 parent e9de935 commit 75289bb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
10 changes: 4 additions & 6 deletions gmso/external/convert_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,14 @@ def _parse_particle_information(
for site in top.sites
]
masses = np.zeros(top.n_sites)
# charges = np.zeros(top.n_sites)
charges = list()
charges = np.zeros(top.n_sites)
for idx, site in enumerate(top.sites):
masses[idx] = (
site.mass.to_value(base_units["mass"])
if site.mass
else 1 * base_units["mass"]
)
# charges[idx] = site.charge if site.charge else 0 * u.elementary_charge
charges.append(site.charge if site.charge else 0 * u.elementary_charge)
charges[idx] = site.charge if site.charge else 0 * u.elementary_charge

unique_types = sorted(list(set(types)))
typeids = np.array([unique_types.index(t) for t in types])
Expand All @@ -313,7 +311,7 @@ def _parse_particle_information(
typeids = np.concatenate((np.array([0] * n_rigid), typeids + 1))
# Update mass list and position list of Frame
for idx, _id in enumerate(rigid_ids_set):
group_indices = np.where(rigid_ids == _id)[0]
group_indices = np.where(np.array(rigid_ids) == _id)[0]
group_positions = xyz[group_indices]
group_masses = masses[group_indices]
com_xyz = np.sum(group_positions.T * group_masses, axis=1) / sum(
Expand All @@ -324,7 +322,7 @@ def _parse_particle_information(
# Append rigid center mass and xyz to front
masses = np.concatenate((rigid_masses, masses))
xyz = np.concatenate((rigid_xyz, xyz))
charges = np.concatenate((np.zeros(n_rigid), np.array(charges)))
charges = np.concatenate((np.zeros(n_rigid), charges))
rigid_id_tags = np.concatenate((np.arange(n_rigid), np.array(rigid_ids)))
else:
n_rigid = 0
Expand Down
2 changes: 1 addition & 1 deletion gmso/tests/test_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_rigid_bodies(self):
assert np.array_equal(
snapshot.particles.body[10:], np.array([1] * ethane.n_particles)
)
assert snapshot.particles.mass[0] == ethane.mass
assert np.allclose(snapshot.particles.mass[0], ethane.mass, atol=1e-2)

@pytest.mark.skipif(
int(hoomd_version[0]) < 4, reason="Unsupported features in HOOMD 3"
Expand Down

0 comments on commit 75289bb

Please sign in to comment.