-
Notifications
You must be signed in to change notification settings - Fork 2
/
adjacency.py
151 lines (122 loc) · 5.86 KB
/
adjacency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# This file is part of connectome-manipulator.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024 Blue Brain Project/EPFL
"""Module for comparing connectomes based on adjacency matrices:
Structural comparison of two connectomes in terms of adjacency matrices for selected pathways
(including synapse counts per connection), as specified by the config. For each connectome,
the underlying adjacency/count matrices are computed by the :func:`compute` function and will be saved
to a data file first. The individual adjacency/count matrices, together with a difference map
between the two connectomes, are then plotted by means of the :func:`plot` function.
"""
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from connectome_manipulator.access_functions import (
get_edges_population,
get_node_ids,
get_connections,
)
def compute(circuit, sel_src=None, sel_dest=None, edges_popul_name=None, **_):
"""Extracts adjacency and count matrices from a given circuit's connectome.
Args:
circuit (bluepysnap.Circuit): Input circuit
sel_src (str/list-like/dict): Source (pre-synaptic) neuron selection
sel_dest (str/list-like/dict): Target (post-synaptic) neuron selection
edges_popul_name (str): Name of SONATA egdes population to extract data from
Returns:
dict: Dictionary containing the extracted data elements; see Notes
Note:
The returned dictionary contains the following data elements that can be selected for plotting through the structural comparison configuration file, together with a common dictionary containing additional information. Each data element is a dictionary with "data" (scipy.sparse.csc_matrix), "name" (str), and "unit" (str) items.
* "adj": Adjacency matrix
* "adj_cnt": Synaptome matrix, containing the numbers of synapses per connection
"""
# Select edge population
edges = get_edges_population(circuit, edges_popul_name)
# Select corresponding source/target nodes populations
src_nodes = edges.source
tgt_nodes = edges.target
src_node_ids = get_node_ids(src_nodes, sel_src)
tgt_node_ids = get_node_ids(tgt_nodes, sel_dest)
assert (
len(src_node_ids) > 0 and len(tgt_node_ids) > 0
), "ERROR: Empty src/tgt node selection(s)!"
# Map source/target node ids to continuous range of indices for plotting
src_gid_min = min(src_nodes.ids())
src_gid_max = max(src_nodes.ids())
tgt_gid_min = min(tgt_nodes.ids())
tgt_gid_max = max(tgt_nodes.ids())
src_plot_ids = np.full(src_gid_max - src_gid_min + 1, -1).astype(int)
src_gid_offset = src_gid_min
src_plot_ids[src_node_ids - src_gid_offset] = np.arange(len(src_node_ids))
def src_gid_to_idx(gids):
return src_plot_ids[gids - src_gid_offset]
tgt_plot_ids = np.full(tgt_gid_max - tgt_gid_min + 1, -1).astype(int)
tgt_gid_offset = tgt_gid_min
tgt_plot_ids[tgt_node_ids - tgt_gid_offset] = np.arange(len(tgt_node_ids))
def tgt_gid_to_idx(gids):
return tgt_plot_ids[gids - tgt_gid_offset]
print(
f"INFO: Creating {len(src_node_ids)}x{len(tgt_node_ids)} adjacency matrix (sel_src={sel_src}, sel_dest={sel_dest})",
flush=True,
)
conns = get_connections(edges, src_node_ids, tgt_node_ids, with_nsyn=True)
if len(conns) == 0: # No connections, creating empty matrix
count_matrix = csc_matrix((len(src_node_ids), len(tgt_node_ids)), dtype=int)
else:
count_matrix = csc_matrix(
(conns[:, 2], (src_gid_to_idx(conns[:, 0]), tgt_gid_to_idx(conns[:, 1]))),
shape=(len(src_node_ids), len(tgt_node_ids)),
dtype=int,
)
adj_matrix = count_matrix > 0
return {
"adj": {"data": adj_matrix, "name": "Adjacency", "unit": None},
"adj_cnt": {"data": count_matrix, "name": "Adjacency count", "unit": "Synapse count"},
"common": {"src_gids": src_node_ids, "tgt_gids": tgt_node_ids},
}
def plot(
res_dict, _common_dict, fig_title=None, vmin=None, vmax=None, isdiff=False, **_
): # pragma:no cover
"""Plots an adjacency/count matrix or a difference matrix.
Args:
res_dict (dict): Results dictionary, containing selected data for plotting; must contain a "data" item with a sparse matrix of type scipy.sparse.csc_matrix, as well as "name" and "unit" items containing strings.
_common_dict (dict): Common dictionary, containing additional information - Not used
fig_title (str): Optional figure title
vmin (float): Minimum plot range
vmax (float): Maximum plot range
isdiff (bool): Flag indicating that ``res_dict`` contains a difference matrix; in this case, a symmetric plot range is required and a divergent colormap will be used
"""
if isdiff: # Difference plot
assert -1 * vmin == vmax, "ERROR: Symmetric plot range required!"
cmap = "PiYG" # Symmetric (diverging) colormap
else: # Regular plot
assert vmin == 0, "ERROR: Plot range including 0 required!"
cmap = "hot_r" # Regular colormap [color at 0 should be white (not actually drawn), to match figure background!]
mat = res_dict["data"].tocoo() # Convert to COO, for easy access to row/col and data!!
col_idx = mat.data
plt.scatter(
mat.col,
mat.row,
marker=",",
s=0.1,
edgecolors="none",
alpha=0.5,
c=col_idx,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
if not res_dict["data"].dtype == bool:
cb = plt.colorbar()
cb.set_label(res_dict["unit"])
if fig_title is None:
plt.title(res_dict["name"])
else:
plt.title(fig_title)
plt.xlabel("Postsynaptic neurons")
plt.ylabel("Presynaptic neurons")
plt.axis("image")
plt.xlim((-0.5, res_dict["data"].shape[1] - 0.5))
plt.ylim((-0.5, res_dict["data"].shape[0] - 0.5))
plt.gca().invert_yaxis()