Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add optional data_path parameter to generate_spectra_splits #1

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions spectrae/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def cross_split_overlap(self, train, test):
"""
pass

def construct_spectra_graph(self, force_reconstruct = False):
def construct_spectra_graph(self, save_path, force_reconstruct = False):
if self.SPG is not None:
return self.SPG
elif os.path.exists(f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") and not force_reconstruct:
elif os.path.exists(f"{save_path}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") and not force_reconstruct:
print("Loading spectral property graph")
self.SPG = nx.read_gexf(f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf")
self.SPG = nx.read_gexf(f"{save_path}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf")
self.return_spectra_graph_stats()
return self.SPG
else:
Expand Down Expand Up @@ -85,10 +85,10 @@ def construct_spectra_graph(self, force_reconstruct = False):
if all_fully_connected:
raise Exception("All SPG sub components are fully connected, cannot run SPECTRA, all samples are similar to each other")

if not os.path.exists(f"{self.dataset.name}_spectral_property_graphs"):
os.makedirs(f"{self.dataset.name}_spectral_property_graphs")
if not os.path.exists(f"{save_path}_spectral_property_graphs"):
os.makedirs(f"{save_path}_spectral_property_graphs")

nx.write_gexf( self.SPG, f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf")
nx.write_gexf( self.SPG, f"{save_path}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf")

return self.SPG

Expand Down Expand Up @@ -161,34 +161,36 @@ def generate_spectra_splits(self,
number_repeats,
random_seed,
test_size = 0.2,
force_reconstruct = False):
force_reconstruct = False,
data_path = ""):

#Random seed is a list of random seeds for each number
name = self.dataset.name
self.construct_spectra_graph(force_reconstruct = force_reconstruct)
save_path = f"{data_path}{name}"
self.construct_spectra_graph(save_path = save_path, force_reconstruct = force_reconstruct)
if self.binary:
if nx.density(self.SPG) >= 0.4:
raise Exception("Density of SPG is greater than 0.4, SPECTRA will not work as your dataset is too similar to itself. Please check your dataset and SPECTRA properties.")

if not os.path.exists(f"{name}_SPECTRA_splits"):
os.makedirs(f"{name}_SPECTRA_splits")
if not os.path.exists(f"{name}_spectral_property_graphs"):
os.makedirs(f"{name}_spectral_property_graphs")
if not os.path.exists(f"{save_path}_SPECTRA_splits"):
os.makedirs(f"{save_path}_SPECTRA_splits")
if not os.path.exists(f"{save_path}_spectral_property_graphs"):
os.makedirs(f"{save_path}_spectral_property_graphs")

splits = []
for spectral_parameter in spectral_parameters:
for i in range(number_repeats):
if os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}") and not force_reconstruct:
if os.path.exists(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}") and not force_reconstruct:
print(f"Folder SP_{spectral_parameter}_{i} already exists. Skipping")
elif force_reconstruct or not os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}"):
elif force_reconstruct or not os.path.exists(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}"):
train, test, stats = self.generate_spectra_split(float(spectral_parameter), random_seed[i], test_size)
if train is not None:
if not os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}"):
os.makedirs(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}")
if not os.path.exists(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}"):
os.makedirs(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}")

pickle.dump(train, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/train.pkl", "wb"))
pickle.dump(test, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/test.pkl", "wb"))
pickle.dump(stats, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/stats.pkl", "wb"))
pickle.dump(train, open(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}/train.pkl", "wb"))
pickle.dump(test, open(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}/test.pkl", "wb"))
pickle.dump(stats, open(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}/stats.pkl", "wb"))
else:
print(f"Split for SP_{spectral_parameter}_{i} could not be generated since independent set only has one sample")

Expand Down