diff --git a/spectrae/spectra.py b/spectrae/spectra.py index 45bbae6..d19a6d2 100644 --- a/spectrae/spectra.py +++ b/spectrae/spectra.py @@ -38,12 +38,12 @@ def cross_split_overlap(self, train, test): """ pass - def construct_spectra_graph(self, force_reconstruct = False): + def construct_spectra_graph(self, save_path, force_reconstruct = False): if self.SPG is not None: return self.SPG - elif os.path.exists(f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") and not force_reconstruct: + elif os.path.exists(f"{save_path}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") and not force_reconstruct: print("Loading spectral property graph") - self.SPG = nx.read_gexf(f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") + self.SPG = nx.read_gexf(f"{save_path}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") self.return_spectra_graph_stats() return self.SPG else: @@ -85,10 +85,10 @@ def construct_spectra_graph(self, force_reconstruct = False): if all_fully_connected: raise Exception("All SPG sub components are fully connected, cannot run SPECTRA, all samples are similar to each other") - if not os.path.exists(f"{self.dataset.name}_spectral_property_graphs"): - os.makedirs(f"{self.dataset.name}_spectral_property_graphs") + if not os.path.exists(f"{save_path}_spectral_property_graphs"): + os.makedirs(f"{save_path}_spectral_property_graphs") - nx.write_gexf( self.SPG, f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") + nx.write_gexf( self.SPG, f"{save_path}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") return self.SPG @@ -161,34 +161,36 @@ def generate_spectra_splits(self, number_repeats, random_seed, test_size = 0.2, - force_reconstruct = False): + force_reconstruct = False, + data_path = ""): #Random seed is a list of random seeds for each number name = self.dataset.name - self.construct_spectra_graph(force_reconstruct = force_reconstruct) + save_path = f"{data_path}{name}" + self.construct_spectra_graph(save_path = save_path, force_reconstruct = force_reconstruct) if self.binary: if nx.density(self.SPG) >= 0.4: raise Exception("Density of SPG is greater than 0.4, SPECTRA will not work as your dataset is too similar to itself. Please check your dataset and SPECTRA properties.") - if not os.path.exists(f"{name}_SPECTRA_splits"): - os.makedirs(f"{name}_SPECTRA_splits") - if not os.path.exists(f"{name}_spectral_property_graphs"): - os.makedirs(f"{name}_spectral_property_graphs") + if not os.path.exists(f"{save_path}_SPECTRA_splits"): + os.makedirs(f"{save_path}_SPECTRA_splits") + if not os.path.exists(f"{save_path}_spectral_property_graphs"): + os.makedirs(f"{save_path}_spectral_property_graphs") splits = [] for spectral_parameter in spectral_parameters: for i in range(number_repeats): - if os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}") and not force_reconstruct: + if os.path.exists(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}") and not force_reconstruct: print(f"Folder SP_{spectral_parameter}_{i} already exists. Skipping") - elif force_reconstruct or not os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}"): + elif force_reconstruct or not os.path.exists(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}"): train, test, stats = self.generate_spectra_split(float(spectral_parameter), random_seed[i], test_size) if train is not None: - if not os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}"): - os.makedirs(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}") + if not os.path.exists(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}"): + os.makedirs(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}") - pickle.dump(train, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/train.pkl", "wb")) - pickle.dump(test, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/test.pkl", "wb")) - pickle.dump(stats, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/stats.pkl", "wb")) + pickle.dump(train, open(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}/train.pkl", "wb")) + pickle.dump(test, open(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}/test.pkl", "wb")) + pickle.dump(stats, open(f"{save_path}_SPECTRA_splits/SP_{spectral_parameter}_{i}/stats.pkl", "wb")) else: print(f"Split for SP_{spectral_parameter}_{i} could not be generated since independent set only has one sample")