Skip to content

Commit

Permalink
Updated plots
Browse files Browse the repository at this point in the history
  • Loading branch information
Hjorthmedh committed Nov 29, 2024
1 parent 7d1922c commit f4bebcf
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions snudda/plotting/plot_spike_raster_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,15 @@ def calculate_period_synchrony(self, period, neuron_id=None, time_range=None):
return vs

def plot_spike_histogram_type(self, neuron_type, time_range=None, bin_size=50e-3, fig_size=None,
fig_file=None, label_text=None, show_figure=True, n_core=None, linestyle="-",
legend_loc="best", ax=None):
fig_file=None, label_text=None, show_figure=True, n_core=None,
linestyle="-", line_colours=None, linewidth=3,
legend_loc="best", bbox_anchor=None, ax=None):

self.make_figures_directory()

plt.rcParams.update({'font.size': 24,
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.loc': legend_loc})
'ytick.labelsize': 20})

assert type(neuron_type) == list, "neuron_type should be a list of neuron types"

Expand Down Expand Up @@ -337,12 +337,21 @@ def plot_spike_histogram_type(self, neuron_type, time_range=None, bin_size=50e-3
fig = plt.figure(figsize=fig_size)
ax = fig.add_subplot()

ax.hist(x=all_spikes.values(), bins=bins, weights=weights, linewidth=3, linestyle=linestyle,
histtype="step", color=[self.get_colours(x) for x in all_spikes.keys()],
label=[f"{label_text}{x}" for x in all_spikes.keys()])
if len(all_spikes.keys()) > 1:
all_labels = [f"{label_text}{x}" for x in all_spikes.keys()]
else:
all_labels = [label_text]

if line_colours is None:
line_colours = [self.get_colours(x) for x in all_spikes.keys()]

ax.hist(x=all_spikes.values(), bins=bins, weights=weights, linewidth=linewidth, linestyle=linestyle,
histtype="step", color=line_colours,
label=all_labels)

plt.xlabel("Time (s)", fontsize=20)
plt.ylabel("Frequency (Hz)", fontsize=20)
ax.legend()
ax.legend(loc=legend_loc, bbox_to_anchor=bbox_anchor)

if fig_file:
fig_path = os.path.join(self.figure_path, fig_file)
Expand All @@ -360,7 +369,7 @@ def plot_spike_histogram_type(self, neuron_type, time_range=None, bin_size=50e-3
def plot_spike_histogram(self, population_id=None, neuron_type=None,
skip_time=0, end_time=None, fig_size=None, bin_size=50e-3,
fig_file=None, ax=None, label_text=None, show_figure=True, save_figure=True, colour=None,
linestyle="-", legend_loc="best", title=None):
linestyle="-", legend_loc="best", title=None, bbox_anchor=None):

if population_id is None:
population_id = self.snudda_load.get_neuron_population_units(return_set=True)
Expand All @@ -382,8 +391,8 @@ def plot_spike_histogram(self, population_id=None, neuron_type=None,

plt.rcParams.update({'font.size': 24,
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.loc': legend_loc})
'ytick.labelsize': 20},
)

if ax is None:
fig = plt.figure(figsize=fig_size)
Expand Down Expand Up @@ -424,7 +433,7 @@ def plot_spike_histogram(self, population_id=None, neuron_type=None,

if label_text is None:
label_text = ""

N, bins, patches = ax.hist(x=pop_spikes.values(), bins=bins, weights=weights, linewidth=3, linestyle=linestyle,
histtype="step", color=colour,
label=[f"{label_text}{x}" for x in pop_spikes.keys()])
Expand All @@ -435,14 +444,14 @@ def plot_spike_histogram(self, population_id=None, neuron_type=None,

plt.xlabel("Time (s)", fontsize=20)
plt.ylabel("Frequency (Hz)", fontsize=20)
ax.legend()
ax.legend(loc=legend_loc, bbox_to_anchor=bbox_anchor)

if title:
plt.title(title)

if fig_file is None:
fig_file = os.path.join(self.figure_path,
f"spike-frequency-pop-units{'-'.join([f'{x}' for x in pop_members.keys()])}.pdf")
f"spike-frequency-pop-units{'-'.join([f'{x}' for x in pop_members.keys()])}.png")
else:
fig_file = os.path.join(self.figure_path, fig_file)

Expand Down Expand Up @@ -654,6 +663,10 @@ def plot_firing_frequency_distribution(self, time_range=None, figure_name=None,
/ (time_range[1] - time_range[0])
for s in spikes.values()]

if not isinstance(bins, int):
if np.max(freq) > max(bins):
raise ValueError(f"Max frequency {np.max(freq)} larger than bin range specified {max(bins)}.")

colour = SnuddaPlotSpikeRaster2.get_colours(nt)
count, bin = np.histogram(freq, bins=bins)
plt.stairs(count, bin, label=nt, color=colour, linewidth=3)
Expand All @@ -665,14 +678,16 @@ def plot_firing_frequency_distribution(self, time_range=None, figure_name=None,
if title is not None:
plt.title(title)

plt.ion()
plt.show()
plt.tight_layout()

if figure_name is not None:
self.make_figures_directory()

plt.savefig(os.path.join(self.figure_path, figure_name))

plt.ion()
plt.show()

def plot_population_frequency(self, population_id, time_ranges=None):

raise NotImplementedError()
Expand Down

0 comments on commit f4bebcf

Please sign in to comment.