// engine.ts — forked from trekhleb/micrograd-ts.
/** Optional metadata attached to a `Value` node at construction time. */
type ValueParams = {
  /** Symbol of the operation that produced the node (e.g. '+', '*', 'tanh'). */
  op?: string
  /** Free-form name for the node; not used by any computation in this class. */
  label?: string
  /** Operand nodes this node was computed from. */
  prev?: Value[]
}

/**
 * Stores a single scalar value and its gradient, as one node of a
 * dynamically built computation graph.
 *
 * Every arithmetic method returns a NEW `Value` whose `prev` lists its
 * operands and whose `backwardStep` closure pushes that node's `grad` back
 * into the operands using the operation's local derivative. Calling
 * `backward()` on an output node runs those closures in reverse topological
 * order (reverse-mode automatic differentiation).
 */
export class Value {
  /** Forward scalar value of this node. */
  data: number
  /** Accumulated gradient of the `backward()` root with respect to this node. */
  grad = 0
  /** Free-form name for the node. */
  label: string
  /** Operand nodes this node was computed from (empty for leaves). */
  prev: Value[]
  /** Operation that produced this node, or `null` for leaf values. */
  op: string | null

  constructor(data: number, params: ValueParams = {}) {
    this.data = data
    // `?.` tolerates JS callers that pass `null` explicitly.
    this.op = params?.op ?? null
    this.label = params?.label ?? ''
    this.prev = params?.prev ?? []
  }

  /** Wraps a plain number as a constant leaf node; `Value`s pass through. */
  private toVal(v: Value | number): Value {
    return typeof v === 'number' ? new Value(v) : v
  }

  /**
   * Propagates this node's `grad` into its operands' `grad`.
   * A no-op for leaves; each operation method below replaces it on the node
   * it creates with a closure implementing that operation's derivative.
   */
  private backwardStep(): void {}

  /** Returns `this + v`. Both partial derivatives are 1. */
  add(v: Value | number): Value {
    const other = this.toVal(v)
    const out = new Value(this.data + other.data, {
      prev: [this, other],
      op: '+',
    })
    out.backwardStep = () => {
      this.grad += out.grad
      other.grad += out.grad
    }
    return out
  }

  /** Returns `this - v`. Partials: 1 for `this`, -1 for `v`. */
  sub(v: Value | number): Value {
    const other = this.toVal(v)
    const out = new Value(this.data - other.data, {
      prev: [this, other],
      op: '-',
    })
    out.backwardStep = () => {
      this.grad += out.grad
      other.grad -= out.grad
    }
    return out
  }

  /** Returns `this * v`. Each operand's partial is the other operand's data. */
  mul(v: Value | number): Value {
    const other = this.toVal(v)
    const out = new Value(this.data * other.data, {
      prev: [this, other],
      op: '*',
    })
    out.backwardStep = () => {
      this.grad += other.data * out.grad
      other.grad += this.data * out.grad
    }
    return out
  }

  /** Returns `this / v`. Partials: 1/v for `this`, -this/v^2 for `v`. */
  div(v: Value | number): Value {
    const other = this.toVal(v)
    const out = new Value(this.data / other.data, {
      prev: [this, other],
      op: '/',
    })
    out.backwardStep = () => {
      this.grad += (1 / other.data) * out.grad
      other.grad += (-this.data / other.data ** 2) * out.grad
    }
    return out
  }

  /**
   * Returns `this ** other` for a constant numeric exponent.
   *
   * @throws Error if `other` is not a number (guards untyped JS callers).
   */
  pow(other: number): Value {
    if (typeof other !== 'number')
      throw new Error('Only supporting int/float powers')
    const out = new Value(this.data ** other, {
      prev: [this],
      op: '^',
    })
    out.backwardStep = () => {
      // x^0 is constant, so its derivative is 0. Skipping the general
      // formula avoids 0 * 0 ** -1 = 0 * Infinity = NaN when data is 0.
      if (other === 0) return
      this.grad += other * this.data ** (other - 1) * out.grad
    }
    return out
  }

  /** Returns `e ** this`. The derivative equals the output itself. */
  exp(): Value {
    const out = new Value(Math.exp(this.data), { prev: [this], op: 'e' })
    out.backwardStep = () => {
      this.grad += out.data * out.grad
    }
    return out
  }

  /** Returns `tanh(this)`. The derivative is 1 - tanh^2, read from `out.data`. */
  tanh(): Value {
    const out = new Value(Math.tanh(this.data), { prev: [this], op: 'tanh' })
    out.backwardStep = () => {
      this.grad += (1 - out.data ** 2) * out.grad
    }
    return out
  }

  /** Returns `max(0, this)`. Gradient passes through only where the output is > 0. */
  relu(): Value {
    const reluVal = this.data < 0 ? 0 : this.data
    const out = new Value(reluVal, { prev: [this], op: 'relu' })
    out.backwardStep = () => {
      this.grad += (out.data > 0 ? 1 : 0) * out.grad
    }
    return out
  }

  /**
   * Runs reverse-mode differentiation from this node.
   *
   * Sets `this.grad` to 1, orders every node reachable through `prev` so
   * that each node precedes its operands, then calls every `backwardStep`
   * in that order. This guarantees a node's `grad` is complete before it is
   * propagated further.
   *
   * Gradients are accumulated with `+=`, and nothing here resets `grad` on
   * the other nodes. Callers must zero them between passes.
   */
  backward(): void {
    // Post-order DFS: a node is appended to `topo` only after all of its
    // operands. An explicit stack replaces recursion so that long chains
    // (e.g. summing many terms one at a time) cannot overflow the JS call
    // stack.
    const topo: Value[] = []
    const visited = new Set<Value>()
    // Entries are [node, expanded]. `expanded` marks the second visit, when
    // every operand of `node` has already been appended to `topo`.
    const pending: Array<[Value, boolean]> = [[this, false]]
    for (let entry = pending.pop(); entry !== undefined; entry = pending.pop()) {
      const [node, expanded] = entry
      if (expanded) {
        topo.push(node)
        continue
      }
      if (visited.has(node)) continue
      visited.add(node)
      pending.push([node, true])
      for (const parent of node.prev) {
        if (!visited.has(parent)) pending.push([parent, false])
      }
    }
    topo.reverse()
    // Go one variable at a time and apply the chain rule to get its gradient.
    this.grad = 1
    for (const node of topo) node.backwardStep()
  }
}
/**
 * Shorthand factory: `v(d, p)` is exactly `new Value(d, p)`.
 *
 * @param d - Scalar data for the new node.
 * @param p - Optional node metadata (`op`, `label`, `prev`); defaults to `{}`.
 * @returns A freshly constructed `Value`.
 */
export function v(d: number, p: ValueParams = {}): Value {
  return new Value(d, p)
}