Skip to content

Commit

Permalink
Fixed the inconsistency between numpy and openPMD and added Exception…
Browse files Browse the repository at this point in the history
… for trying to use OpenPMD with the wrong number of snapshots
  • Loading branch information
RandomDefaultUser committed Nov 21, 2024
1 parent a99d5a6 commit 4697216
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions mala/datahandling/data_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def __shuffle_numpy(
# If the number of new snapshots is not a divisor of the grid size,
# then we have to trim the original snapshots to size.
# The indices to be removed are selected at random.
if self.data_points_to_remove is not None:
if (
self.data_points_to_remove is not None
and np.sum(self.data_points_to_remove) > 0
):
if self.parameters.shuffling_seed is not None:
np.random.seed(idx * self.parameters.shuffling_seed)
ngrid = (
Expand Down Expand Up @@ -548,27 +551,44 @@ def shuffle_snapshots(
self.data_points_to_remove = None
if number_of_shuffled_snapshots is None:
number_of_shuffled_snapshots = self.nr_snapshots
number_of_new_snapshots = number_of_shuffled_snapshots

shuffled_gridsizes = snapshot_size_list // number_of_new_snapshots
# Currently, the openPMD interface is not feature-complete.
if np.any(
np.array(
[
snapshot.grid_dimension[0] % number_of_shuffled_snapshots
for snapshot in self.parameters.snapshot_directories_list
]
)
!= 0
):
raise ValueError(
"Shuffling from OpenPMD files currently only "
"supported if first dimension of all snapshots "
"can evenly be divided by number of snapshots. "
"Please select a different number of shuffled "
"snapshots or use the numpy interface. "
)

shuffled_gridsizes = snapshot_size_list // number_of_shuffled_snapshots

if np.any(
np.array(snapshot_size_list)
- (
(np.array(snapshot_size_list) // number_of_new_snapshots)
* number_of_new_snapshots
(np.array(snapshot_size_list) // number_of_shuffled_snapshots)
* number_of_shuffled_snapshots
)
> 0
):
number_of_data_points = int(
np.sum(shuffled_gridsizes) * number_of_new_snapshots
np.sum(shuffled_gridsizes) * number_of_shuffled_snapshots
)

self.data_points_to_remove = []
for i in range(0, self.nr_snapshots):
self.data_points_to_remove.append(
snapshot_size_list[i]
- shuffled_gridsizes[i] * number_of_new_snapshots
- shuffled_gridsizes[i] * number_of_shuffled_snapshots
)
tot_points_missing = sum(self.data_points_to_remove)

Expand All @@ -581,22 +601,22 @@ def shuffle_snapshots(
)

shuffle_dimensions = [
int(number_of_data_points / number_of_new_snapshots),
int(number_of_data_points / number_of_shuffled_snapshots),
1,
1,
]

printout(
"Data shuffler will generate",
number_of_new_snapshots,
number_of_shuffled_snapshots,
"new snapshots.",
)
printout("Shuffled snapshot dimension will be ", shuffle_dimensions)

# Prepare permutations.
permutations = []
seeds = []
for i in range(0, number_of_new_snapshots):
for i in range(0, number_of_shuffled_snapshots):
# This makes the shuffling deterministic, if specified by the user.
if self.parameters.shuffling_seed is not None:
np.random.seed(i * self.parameters.shuffling_seed)
Expand All @@ -606,7 +626,7 @@ def shuffle_snapshots(

if snapshot_type == "numpy":
self.__shuffle_numpy(
number_of_new_snapshots,
number_of_shuffled_snapshots,
shuffle_dimensions,
descriptor_save_path,
save_name,
Expand All @@ -625,7 +645,7 @@ def shuffle_snapshots(
)
self.__shuffle_openpmd(
descriptor,
number_of_new_snapshots,
number_of_shuffled_snapshots,
shuffle_dimensions,
save_name,
permutations,
Expand All @@ -641,7 +661,7 @@ def shuffle_snapshots(
)
self.__shuffle_openpmd(
target,
number_of_new_snapshots,
number_of_shuffled_snapshots,
shuffle_dimensions,
save_name,
permutations,
Expand Down

0 comments on commit 4697216

Please sign in to comment.