-
Notifications
You must be signed in to change notification settings - Fork 5
/
rnn_test.go
108 lines (91 loc) · 1.86 KB
/
rnn_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package rnn
import (
"github.com/gonum/matrix/mat64"
"io/ioutil"
"reflect"
"strings"
"testing"
)
// TestNetwork builds an RNN from a repeated sample sentence and checks the
// VocabSize it reports.
func TestNetwork(t *testing.T) {
	// The sample sentence contains 14 distinct characters.
	const wantVocab = 14

	corpus := strings.Repeat("Mary had a little lamb.", 10)
	net := NewRNN(corpus, "test.tmp")
	if got := net.VocabSize; got != wantVocab {
		t.Fatalf("VocabSize expected: %d, got: %d", wantVocab, got)
	}
	//net.Run()
}
// TestMultiply is a sanity check of how to multiply a mat64 matrix by a
// scalar: Scale is applied in place to a 2x2 matrix, after which the
// dimensions and each row's contents are compared to the expected values.
func TestMultiply(t *testing.T) {
	m := mat64.NewDense(2, 2, []float64{0, 1, 2, 3})
	m.Scale(2, m)

	want := [][]float64{
		{0, 2},
		{4, 6},
	}

	// Stop immediately if the shape changed; the row checks below assume 2x2.
	if rows, cols := m.Dims(); !(rows == 2 && cols == 2) {
		t.Fatalf("rows/cols changed: %d, %d", rows, cols)
	}

	for idx := 0; idx < len(want); idx++ {
		got := m.RawRowView(idx)
		if reflect.DeepEqual(got, want[idx]) {
			continue
		}
		t.Errorf("for row %d, expected: %+v, got: %+v", idx, want[idx], got)
	}
}
// TestMapInput checks that mapInput returns the expected character-to-index
// and index-to-character maps for a sample sentence.
func TestMapInput(t *testing.T) {
	// Distinct characters of the sample, listed in the index order the test
	// expects: position in this slice is the expected index.
	order := []rune{'M', 'a', 'r', 'y', ' ', 'h', 'd', 'l', 'i', 't', 'e', 'm', 'b', '.'}

	wantCTI := make(map[rune]int, len(order))
	wantITC := make(map[int]rune, len(order))
	for idx, ch := range order {
		wantCTI[ch] = idx
		wantITC[idx] = ch
	}

	gotCTI, gotITC := mapInput("Mary had a little lamb.")

	if !reflect.DeepEqual(wantCTI, gotCTI) {
		t.Fatalf("expected: %+v, got: %+v", wantCTI, gotCTI)
	}
	if !reflect.DeepEqual(wantITC, gotITC) {
		t.Fatalf("expected: %+v, got: %+v", wantITC, gotITC)
	}
}
// BenchmarkRun measures repeated calls to Run(1000) on an RNN built from the
// contents of input.txt. The benchmark fails if that file cannot be read.
func BenchmarkRun(b *testing.B) {
	inputBytes, err := ioutil.ReadFile("input.txt")
	if err != nil {
		b.Fatalf("error reading input training file: %s", err)
	}
	rnn := NewRNN(string(inputBytes), "")

	// Exclude the file read and network construction from the measurement so
	// that only the Run calls are timed.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rnn.Run(1000)
	}
}