Coupled Oscillatory Recurrent Neural Network
(coRNN): An accurate and (gradient) stable
architecture for learning long time dependencies
T. Konstantin Rusch, Siddhartha Mishra
Seminar for Applied Mathematics (SAM), Department of Mathematics
ETH Zurich
Recurrent neural networks
• RNNs have achieved tremendous success in time series problems, e.g. in speech recognition, computer vision and natural language processing
• A recurrent neural network imposes temporal structure on the network (a minimal sketch follows below):
$$h_n = F(h_{n-1}, x_n, \Theta), \qquad \forall n = 1, \dots, N$$
• Training: Back-propagation through time
• Long-term memory challenge
[Figure: RNN unrolled in time, with inputs x_1, ..., x_N feeding hidden states h_1, ..., h_N through F]
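A minimal sketch of this recurrence, taking a plain tanh cell for F (all names and sizes here are illustrative, not from the slides):

```python
import torch

# Sketch of the recurrence h_n = F(h_{n-1}, x_n, Theta) with a tanh cell for F.
n_inp, n_hid, N = 3, 16, 10
U = torch.randn(n_hid, n_inp) * 0.1   # input weights (part of Theta)
W = torch.randn(n_hid, n_hid) * 0.1   # recurrent weights (part of Theta)
b = torch.zeros(n_hid)                # bias (part of Theta)

x = torch.randn(N, n_inp)             # input sequence x_1, ..., x_N
h = torch.zeros(n_hid)                # initial hidden state h_0
for n in range(N):
    h = torch.tanh(U @ x[n] + W @ h + b)   # h_n = F(h_{n-1}, x_n, Theta)
```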
Difficulty of capturing long time dependencies
• Compute gradient of loss E:
$$\frac{\partial E}{\partial \theta} = \frac{1}{N}\sum_{n=1}^{N} \sum_{1 \le k \le n} \frac{\partial E_n}{\partial h_n}\,\frac{\partial h_n}{\partial h_k}\,\frac{\partial^{+} h_k}{\partial \theta}$$
• Possible exponential growth or decay of
$$\frac{\partial h_n}{\partial h_k} = \prod_{k < i \le n} \frac{\partial h_i}{\partial h_{i-1}}$$
• Exploding/vanishing gradient problem (EVGP) (Pascanu et al., 2013); see the scalar sketch below
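A scalar illustration of this product (my own sketch, not from the slides): for h_i = tanh(w·h_{i-1} + x_i), each Jacobian factor is w·(1 − h_i²), so around the zero state the product behaves like wⁿ and either vanishes or explodes with n:

```python
import math

# Product prod_{k<i<=n} dh_i/dh_{i-1} for the scalar tanh RNN h_i = tanh(w*h_{i-1} + x_i).
# With zero input and h_0 = 0 the state stays at 0, so each factor equals w exactly.
def jacobian_product(w, n=100, x=0.0):
    h, prod = 0.0, 1.0
    for _ in range(n):
        h = math.tanh(w * h + x)
        prod *= w * (1.0 - h * h)   # dh_i/dh_{i-1} = w * (1 - h_i^2)
    return prod

print(jacobian_product(0.9))   # ~3e-5: vanishing gradient
print(jacobian_product(1.1))   # ~1e4:  exploding gradient
```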
Established solutions for learning LTDs
Gating mechanism:
• LSTM, GRU
• Additive gradient structure
• Gradients might still explode
Constraining structure of hidden-to-hidden connections:
• unitary/orthogonal RNNs
• Enables RNNs to capture long time dependencies
• Structural constraints ⇒ might lower expressivity
Coupled oscillators: A neurobiological inspiration
• Biological neurons viewed as oscillators
• Functional circuits of the brain interpreted by networks of oscillatory neurons
(Butler and Paulson, 2015)
• Abstract this essence → RNNs based on coupled oscillators
• Stability of oscillatory dynamics: expect the EVGP to be mitigated
• Coupled oscillators access a very rich set of output states → high expressivity of the system
Coupled oscillatory RNN (coRNN)
• Second-order system of ODEs:
$$y'' = \sigma\left(W y + \mathcal{W} y' + V u + b\right) - \gamma y - \varepsilon y',$$
with hidden state $y$ and input $u$.
• First-order equivalent:
$$y' = z, \qquad z' = \sigma\left(W y + \mathcal{W} z + V u + b\right) - \gamma y - \varepsilon z.$$
• Discretize to obtain the coupled oscillatory RNN (coRNN):
$$y_n = y_{n-1} + \Delta t\, z_n, \qquad z_n = z_{n-1} + \Delta t\,\sigma\left(W y_{n-1} + \mathcal{W} z_{n-1} + V u_n + b\right) - \Delta t\,\gamma y_{n-1} - \Delta t\,\varepsilon z_n.$$
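A minimal PyTorch sketch of one step of this scheme, with σ = tanh; since the −Δt·ε·z_n term is linear in z_n, the implicit update can be solved in closed form by dividing by (1 + Δt·ε). The default values of Δt, γ, ε below are illustrative placeholders, not the tuned values from the paper:

```python
import torch
import torch.nn as nn

class CoRNNCell(nn.Module):
    """One step of the coRNN discretization above (sigma = tanh).
    Default dt, gamma, epsilon are illustrative placeholders."""
    def __init__(self, n_inp, n_hid, dt=0.01, gamma=1.0, epsilon=1.0):
        super().__init__()
        self.dt, self.gamma, self.epsilon = dt, gamma, epsilon
        self.W = nn.Linear(n_hid, n_hid, bias=False)    # W, acts on y
        self.W2 = nn.Linear(n_hid, n_hid, bias=False)   # script-W, acts on z
        self.V = nn.Linear(n_inp, n_hid)                # V, acts on u; carries bias b

    def forward(self, u, y, z):
        # z_n = z_{n-1} + dt*sigma(W y_{n-1} + script-W z_{n-1} + V u_n + b)
        #       - dt*gamma*y_{n-1} - dt*epsilon*z_n
        # The last term is linear in z_n, so solve for z_n in closed form:
        z = (z + self.dt * (torch.tanh(self.W(y) + self.W2(z) + self.V(u))
                            - self.gamma * y)) / (1.0 + self.dt * self.epsilon)
        y = y + self.dt * z                             # y_n = y_{n-1} + dt*z_n
        return y, z
```

A full model would iterate this cell over the sequence u_1, ..., u_N (starting from y_0 = z_0 = 0) and read the prediction off y_N.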
Exploding/vanishing gradient for coRNN
• Gradients for coRNN:
$$\frac{\partial E}{\partial \theta} = \frac{1}{N}\sum_{n=1}^{N} \frac{\partial E_n}{\partial \theta} = \frac{1}{N}\sum_{n=1}^{N} \sum_{1 \le k \le n} \underbrace{\frac{\partial E_n}{\partial X_n}\,\frac{\partial X_n}{\partial X_k}\,\frac{\partial^{+} X_k}{\partial \theta}}_{\partial E_n^{(k)}/\partial \theta}$$
• Assumption:
$$\max\left\{\frac{\Delta t\,(1 + \|W\|_\infty)}{1 + \Delta t},\ \frac{\Delta t\,\|\mathcal{W}\|_\infty}{1 + \Delta t}\right\} \le \Delta t^{\,r}, \qquad \frac{1}{2} \le r \le 1.
$$
• This assumption is easily met in practice (a numerical check follows below)
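A quick numerical check of this condition (a hypothetical helper, not from the paper's code), using the fact that both terms in the max share the denominator 1 + Δt:

```python
import torch

def assumption_holds(W, W2, dt, r=0.5):
    """Check max{dt*(1+||W||_inf), dt*||script-W||_inf} / (1+dt) <= dt**r."""
    W_norm = W.abs().sum(dim=1).max().item()    # matrix infinity-norm: max row sum
    W2_norm = W2.abs().sum(dim=1).max().item()
    lhs = max(dt * (1.0 + W_norm), dt * W2_norm) / (1.0 + dt)
    return lhs <= dt ** r

# For moderately scaled weights and small dt, the condition typically holds:
W, W2 = torch.randn(64, 64) * 0.1, torch.randn(64, 64) * 0.1
print(assumption_holds(W, W2, dt=0.01))   # -> True
```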
Mitigation of EVGP for coRNN
• Exploding gradient problem for coRNN:
$$\left|\frac{\partial E}{\partial \theta}\right| \le \frac{3}{2}\left(m + Y\sqrt{m}\right).$$
• Vanishing gradient problem for coRNN:
$$\frac{\partial E_n^{(k)}}{\partial \theta} = \mathcal{O}\left(c\,\delta\,\Delta t^{3/2}\right) + \mathcal{O}\left(c\,\delta(1+\delta)\,\Delta t^{5/2}\right) + \mathcal{O}\left(\Delta t^{3}\right).$$
• Orders of the estimates are independent of k
Experiment: Adding problem
• Input:
$$X = \begin{bmatrix} u_1 & \dots & u_i & \dots & u_{T/2} & \dots & u_j & \dots & u_T \\ 0 & \dots & 1 & \dots & 0 & \dots & 1 & \dots & 0 \end{bmatrix}, \qquad u_t \sim \mathcal{U}([0,1])$$
• Output: $Y = u_i + u_j$
• Goal: Beat a test MSE of 0.167, the error of a baseline that always outputs 1 (a data-generation sketch follows the figure below)
[Figure: test MSE vs. training steps (hundreds) for sequence lengths T = 500, 2000 and 5000; models: coRNN, expRNN, FastRNN, anti. RNN, tanh RNN, and the 0.167 baseline]
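A sketch of how such a batch could be generated (shapes and names are my choice): row 0 carries the uniform values, row 1 the two markers, one in each half of the sequence:

```python
import torch

def adding_batch(batch, T):
    values = torch.rand(batch, T)              # u_1, ..., u_T ~ U([0,1])
    markers = torch.zeros(batch, T)
    i = torch.randint(0, T // 2, (batch,))     # marked position in the first half
    j = torch.randint(T // 2, T, (batch,))     # marked position in the second half
    rows = torch.arange(batch)
    markers[rows, i] = 1.0
    markers[rows, j] = 1.0
    X = torch.stack([values, markers], dim=1)  # input, shape (batch, 2, T)
    Y = values[rows, i] + values[rows, j]      # target u_i + u_j
    return X, Y

X, Y = adding_batch(batch=32, T=500)
```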
Experiment: (Permuted) Sequential MNIST
• sMNIST: Flatten MNIST images along rows
• psMNIST: Randomly permute sMNIST sequences (a preprocessing sketch follows the table)
• Length T = 784
Model      sMNIST  psMNIST  # units  # params
LSTM       98.9%   92.9%    256      270k
GRU        99.1%   94.1%    256      200k
anti. RNN  98.0%   95.8%    128      10k
expRNN     98.7%   96.6%    512      137k
FastGRNN   98.7%   94.8%    128      18k
coRNN      99.3%   96.6%    128      34k
coRNN      99.4%   97.3%    256      134k
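A sketch of the preprocessing described above (variable names are illustrative): flatten each 28×28 image row by row into a length-784 sequence, and for psMNIST apply one fixed random permutation to every sequence:

```python
import torch

perm = torch.randperm(784)                 # fixed once for the whole dataset

def to_sequence(img, permute=False):
    """img: (28, 28) tensor -> (784, 1) pixel sequence."""
    seq = img.reshape(784, 1)              # row-major flatten: sMNIST
    return seq[perm] if permute else seq   # psMNIST: permuted pixel order
```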
Experiment: Human activity recognition
• Collection of 6 tracked human activities, measured by accelerometer and gyroscope on a smartphone
• HAR-2: Activities binarized to {Sitting, Laying, Walking Upstairs} and {Standing, Walking, Walking Downstairs}
Model            test accuracy  # units  # params
GRU              93.6%          75       19k
LSTM             93.7%          64       16k
FastGRNN         95.6%          80       7k
anti.sym. RNN    93.2%          120      8k
incremental RNN  96.3%          64       4k
coRNN            97.2%          64       9k
tiny coRNN       96.5%          20       1k
Experiment: IMDB sentiment analysis
• Collection of 50k movie reviews
• Sentiment analysis
• Tests expressivity of RNN
Model      test accuracy  # units  # params
LSTM       86.8%          128      220k
Skip LSTM  86.6%          128      220k
GRU        86.2%          128      164k
Skip GRU   86.6%          128      164k
ReLU GRU   84.8%          128      99k
expRNN     84.3%          256      58k
coRNN      87.4%          128      46k
Conclusion
• Neurobiologically inspired RNN
• Exploding and vanishing gradient problem mitigated
• SOTA results on LTD tasks
• coRNN as a first attempt → many possible extensions
• Plan to test coRNN on more oscillatory data
[email protected] | github.com/tk-rusch/coRNN | @tk_rusch