@vzhou842/

An RNN from scratch

Python

A Recurrent Neural Network implemented from scratch (using only numpy) in Python. https://github.com/vzhou842/rnn-from-scratch

fork
loading
Files
  • main.py
  • data.py
  • requirements.txt
  • rnn.py

This Plugin Crashed!

Error: Error: must not create an existing file {"type":"CREATE_FILE","wid":"0.5746015779402525","path":"main.py","file":{"path":"main.py","content":{"asEncoding":{"base64":"IyAtLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tCiMgR2l0aHViOiBodHRwczovL2dpdGh1Yi5jb20vdnpob3U4NDIvcm5uLWZyb20tc2NyYXRjaAojIFJlYWQgdGhlIGJsb2cgcG9zdDogaHR0cHM6Ly92aWN0b3J6aG91LmNvbS9ibG9nL2ludHJvLXRvLXJubnMvCiMgLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLQppbXBvcnQgbnVtcHkgYXMgbnAKaW1wb3J0IHJhbmRvbQoKZnJvbSBybm4gaW1wb3J0IFJOTgpmcm9tIGRhdGEgaW1wb3J0IHRyYWluX2RhdGEsIHRlc3RfZGF0YQoKIyBDcmVhdGUgdGhlIHZvY2FidWxhcnkuCnZvY2FiID0gbGlzdChzZXQoW3cgZm9yIHRleHQgaW4gdHJhaW5fZGF0YS5rZXlzKCkgZm9yIHcgaW4gdGV4dC5zcGxpdCgnICcpXSkpCnZvY2FiX3NpemUgPSBsZW4odm9jYWIpCnByaW50KCclZCB1bmlxdWUgd29yZHMgZm91bmQnICUgdm9jYWJfc2l6ZSkKCiMgQXNzaWduIGluZGljZXMgdG8gZWFjaCB3b3JkLgp3b3JkX3RvX2lkeCA9IHsgdzogaSBmb3IgaSwgdyBpbiBlbnVtZXJhdGUodm9jYWIpIH0KaWR4X3RvX3dvcmQgPSB7IGk6IHcgZm9yIGksIHcgaW4gZW51bWVyYXRlKHZvY2FiKSB9CiMgcHJpbnQod29yZF90b19pZHhbJ2dvb2QnXSkKIyBwcmludChpZHhfdG9fd29yZFswXSkKCmRlZiBjcmVhdGVJbnB1dHModGV4dCk6CiAgJycnCiAgUmV0dXJucyBhbiBhcnJheSBvZiBvbmUtaG90IHZlY3RvcnMgcmVwcmVzZW50aW5nIHRoZSB3b3JkcyBpbiB0aGUgaW5wdXQgdGV4dCBzdHJpbmcuCiAgLSB0ZXh0IGlzIGEgc3RyaW5nCiAgLSBFYWNoIG9uZS1ob3QgdmVjdG9yIGhhcyBzaGFwZSAodm9jYWJfc2l6ZSwgMSkKICAnJycKICBpbnB1dHMgPSBbXQogIGZvciB3IGluIHRleHQuc3BsaXQoJyAnKToKICAgIHYgPSBucC56ZXJvcygodm9jYWJfc2l6ZSwgMSkpCiAgICB2W3dvcmRfdG9faWR4W3ddXSA9IDEKICAgIGlucHV0cy5hcHBlbmQodikKICByZXR1cm4gaW5wdXRzCgpkZWYgc29mdG1heCh4cyk6CiAgIyBBcHBsaWVzIHRoZSBTb2Z0bWF4IEZ1bmN0aW9uIHRvIHRoZSBpbnB1dCBhcnJheS4KICByZXR1cm4gbnAuZXhwKHhzKSAvIHN1bShucC5leHAoeHMpKQoKIyBJbml0aWFsaXplIG91ciBSTk4hCnJubiA9IFJOTih2b2NhYl9zaXplLCAyKQoKZGVmIHByb2Nlc3NEYXRhKGRhdGEsIGJhY2twcm9wPVRydWUpOgogICcnJwogIFJldHVybnMgdGhlIFJOTidzIGxvc3MgYW5kIGFjY3VyYWN5IGZvciB0aGUgZ2l2ZW4gZGF0YS4KICAtIGRhdGEgaXMgYSBkaWN0aW9uYXJ5IG1hcHBpbmcgdGV4dCB0byBUcnVlIG9yIEZhbHNlLgogIC0gYmFja3Byb3AgZGV0ZXJtaW5lcyBpZiB0aGUgYmFja3dhcmQgcGhhc2Ugc2hvdWxkIGJlIHJ1bi4KICAnJycKICBpdGVtcyA9IGx
pc3QoZGF0YS5pdGVtcygpKQogIHJhbmRvbS5zaHVmZmxlKGl0ZW1zKQoKICBsb3NzID0gMAogIG51bV9jb3JyZWN0ID0gMAoKICBmb3IgeCwgeSBpbiBpdGVtczoKICAgIGlucHV0cyA9IGNyZWF0ZUlucHV0cyh4KQogICAgdGFyZ2V0ID0gaW50KHkpCgogICAgIyBGb3J3YXJkCiAgICBvdXQsIF8gPSBybm4uZm9yd2FyZChpbnB1dHMpCiAgICBwcm9icyA9IHNvZnRtYXgob3V0KQoKICAgICMgQ2FsY3VsYXRlIGxvc3MgLyBhY2N1cmFjeQogICAgbG9zcyAtPSBucC5sb2cocHJvYnNbdGFyZ2V0XSkKICAgIG51bV9jb3JyZWN0ICs9IGludChucC5hcmdtYXgocHJvYnMpID09IHRhcmdldCkKCiAgICBpZiBiYWNrcHJvcDoKICAgICAgIyBCdWlsZCBkTC9keQogICAgICBkX0xfZF95ID0gcHJvYnMKICAgICAgZF9MX2RfeVt0YXJnZXRdIC09IDEKCiAgICAgICMgQmFja3dhcmQKICAgICAgcm5uLmJhY2twcm9wKGRfTF9kX3kpCgogIHJldHVybiBsb3NzIC8gbGVuKGRhdGEpLCBudW1fY29ycmVjdCAvIGxlbihkYXRhKQoKIyBUcmFpbmluZyBsb29wCmZvciBlcG9jaCBpbiByYW5nZSgxMDAwKToKICB0cmFpbl9sb3NzLCB0cmFpbl9hY2MgPSBwcm9jZXNzRGF0YSh0cmFpbl9kYXRhKQoKICBpZiBlcG9jaCAlIDEwMCA9PSA5OToKICAgIHByaW50KCctLS0gRXBvY2ggJWQnICUgKGVwb2NoICsgMSkpCiAgICBwcmludCgnVHJhaW46XHRMb3NzICUuM2YgfCBBY2N1cmFjeTogJS4zZicgJSAodHJhaW5fbG9zcywgdHJhaW5fYWNjKSkKCiAgICB0ZXN0X2xvc3MsIHRlc3RfYWNjID0gcHJvY2Vzc0RhdGEodGVzdF9kYXRhLCBiYWNrcHJvcD1GYWxzZSkKICAgIHByaW50KCdUZXN0Olx0TG9zcyAlLjNmIHwgQWNjdXJhY3k6ICUuM2YnICUgKHRlc3RfbG9zcywgdGVzdF9hY2MpKQo="},"asBuffer":null},"loaded":true}}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# -------------------------------
# Github: https://github.com/vzhou842/rnn-from-scratch
# Read the blog post: https://victorzhou.com/blog/intro-to-rnns/
# -------------------------------
import numpy as np
import random

from rnn import RNN
from data import train_data, test_data

# Build the vocabulary: every distinct word appearing in the training texts.
vocab = list({w for text in train_data.keys() for w in text.split(' ')})
vocab_size = len(vocab)
print('%d unique words found' % vocab_size)

# Two-way mapping between words and integer indices (used for one-hot encoding).
word_to_idx = {w: i for i, w in enumerate(vocab)}
idx_to_word = {i: w for i, w in enumerate(vocab)}
# print(word_to_idx['good'])
# print(idx_to_word[0])

def createInputs(text):
  '''
  Returns an array of one-hot vectors representing the words in the input text string.
  - text is a string
  - Each one-hot vector has shape (vocab_size, 1)
  '''
  def one_hot(word):
    # Column vector with a single 1 at the word's vocabulary index.
    vec = np.zeros((vocab_size, 1))
    vec[word_to_idx[word]] = 1
    return vec

  return [one_hot(w) for w in text.split(' ')]

def softmax(xs):
  '''
  Applies the Softmax function column-wise to the input array.
  - xs is a numpy array of shape (n, 1) (one column of logits).
  Returns an array of the same shape whose entries are positive and sum to 1.
  '''
  # Subtract the max logit before exponentiating. This is mathematically
  # a no-op (softmax is shift-invariant) but prevents np.exp from
  # overflowing to inf for large logits, which would yield nan outputs.
  exps = np.exp(xs - np.max(xs, axis=0, keepdims=True))
  return exps / np.sum(exps, axis=0, keepdims=True)

# Initialize our RNN!
# 2 output classes: processData uses int(y), so index 0 = False, index 1 = True.
rnn = RNN(vocab_size, 2)

def processData(data, backprop=True):
  '''
  Returns the RNN's average loss and accuracy for the given data.
  - data is a dictionary mapping text to True or False.
  - backprop determines if the backward phase should be run.
  Returns (loss, accuracy) as plain Python floats.
  '''
  items = list(data.items())
  # Shuffle so backprop sees the examples in a different order each epoch.
  random.shuffle(items)

  loss = 0
  num_correct = 0

  for x, y in items:
    inputs = createInputs(x)
    target = int(y)  # False -> 0, True -> 1

    # Forward
    out, _ = rnn.forward(inputs)
    probs = softmax(out)

    # Cross-entropy loss / accuracy.
    # probs has shape (2, 1); index the scalar probs[target, 0] directly.
    # (Using probs[target] yields a shape-(1,) array, which makes the
    # returned loss an ndarray — '%.3f' % on a size-1 array errors on
    # NumPy >= 1.25.)
    loss -= float(np.log(probs[target, 0]))
    num_correct += int(np.argmax(probs) == target)

    if backprop:
      # Build dL/dy for softmax + cross-entropy: the probabilities with 1
      # subtracted at the target index. Copy so probs itself isn't mutated.
      d_L_d_y = probs.copy()
      d_L_d_y[target] -= 1

      # Backward
      rnn.backprop(d_L_d_y)

  return loss / len(data), num_correct / len(data)

# Training loop: 1000 epochs over the training set, reporting train and
# test metrics once every 100 epochs.
for epoch in range(1, 1001):
  train_loss, train_acc = processData(train_data)

  if epoch % 100 == 0:
    print('--- Epoch %d' % epoch)
    print('Train:\tLoss %.3f | Accuracy: %.3f' % (train_loss, train_acc))

    test_loss, test_acc = processData(test_data, backprop=False)
    print('Test:\tLoss %.3f | Accuracy: %.3f' % (test_loss, test_acc))