import os
import tarfile

import mlx.core as mx
import numpy as np
import requests
import scipy.sparse as sparse

"""
Preprocessing follows the same implementation as in:
https://github.com/tkipf/gcn
https://github.com/senadkurtisi/pytorch-GCN/tree/main
"""


def download_cora():
    """Downloads the cora dataset into a local cora folder."""
    url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
    extract_to = "."

    if os.path.exists(os.path.join(extract_to, "cora")):
        return

    response = requests.get(url, stream=True)
    if response.status_code == 200:
        file_path = os.path.join(extract_to, url.split("/")[-1])

        # Write the file to local disk
        with open(file_path, "wb") as file:
            file.write(response.raw.read())

        # Extract the .tgz file
        with tarfile.open(file_path, "r:gz") as tar:
            tar.extractall(path=extract_to)
        print(f"Cora dataset extracted to {extract_to}")

        os.remove(file_path)


def train_val_test_mask():
    """Splits the loaded dataset into train/validation/test sets."""
    train_set = mx.arange(140)
    validation_set = mx.arange(200, 500)
    test_set = mx.arange(500, 1500)
    return train_set, validation_set, test_set


def enumerate_labels(labels):
    """Converts the labels from their original string form
    to integers in [0, num_classes - 1].
    """
    label_map = {v: e for e, v in enumerate(set(labels))}
    labels = np.array([label_map[label] for label in labels])
    return labels


def normalize_adjacency(adj):
    """Normalizes the adjacency matrix according to the
    paper by Kipf & Welling:
    https://arxiv.org/abs/1609.02907
    """
    adj = adj + sparse.eye(adj.shape[0])
    node_degrees = np.array(adj.sum(1))
    node_degrees = np.power(node_degrees, -0.5).flatten()
    node_degrees[np.isinf(node_degrees)] = 0.0
    node_degrees[np.isnan(node_degrees)] = 0.0
    degree_matrix = sparse.diags(node_degrees, dtype=np.float32)
    adj = degree_matrix @ adj @ degree_matrix
    return adj


def load_data(config):
    """Loads the Cora graph data into MLX array format."""
    print("Loading Cora dataset...")

    # Download dataset files
    download_cora()

    # Graph nodes
    raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
    raw_node_ids = raw_nodes_data[:, 0].astype(
        "int32"
    )  # unique identifier of each node
    raw_node_labels = raw_nodes_data[:, -1]
    labels_enumerated = enumerate_labels(raw_node_labels)  # target labels as integers
    node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")
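    # (The columns between the node ID and the label are the node's bag-of-words
    # features, kept sparse until the final conversion to a dense MLX array.)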

    # Edges
    ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
    raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
    edges_ordered = np.array(
        list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
    ).reshape(raw_edges_data.shape)
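    # (The raw Cora node IDs are arbitrary paper identifiers, so each edge
    # endpoint is re-indexed above into the contiguous range [0, num_nodes - 1]
    # to be usable as a matrix index.)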

    # Adjacency matrix
    adj = sparse.coo_matrix(
        (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
        shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
        dtype=np.float32,
    )

    # Make the adjacency matrix symmetric
    adj = adj + adj.T.multiply(adj.T > adj)
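    # (The expression above adds the reverse direction of every edge only where
    # it is missing, so edges already listed in both directions are not counted
    # twice.)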
    adj = normalize_adjacency(adj)

    # Convert to mlx array
    features = mx.array(node_features.toarray(), mx.float32)
    labels = mx.array(labels_enumerated, mx.int32)
    adj = mx.array(adj.toarray())

    print("Dataset loaded.")
    return features, labels, adj
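

# A minimal usage sketch, assuming the archive extracted by download_cora()
# provides "cora/cora.content" (nodes) and "cora/cora.cites" (edges), and that
# `config` only needs `nodes_path` and `edges_path` attributes; the
# SimpleNamespace config below is hypothetical, not part of the module's API.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        nodes_path="cora/cora.content",
        edges_path="cora/cora.cites",
    )
    features, labels, adj = load_data(cfg)
    train_idx, val_idx, test_idx = train_val_test_mask()
    print(features.shape, labels.shape, adj.shape)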