
Synaptic metaplasticity in binarized neural networks

Axel Laborieux1, Maxence Ernoult2, Tifenn Hirtzlin1, Damien Querlioz1

Summary

Unlike the brain, artificial neural networks, including state-of-the-art deep neural networks for computer vision, are subject to "catastrophic forgetting" [1]: they rapidly forget the previous task when trained on a new one. Neuroscience suggests that biological synapses avoid this issue through the process of synaptic consolidation and metaplasticity: the plasticity itself changes upon repeated synaptic events [2, 3]. In this work, we show that this concept of metaplasticity can be transferred to a particular type of deep neural network, binarized neural networks (BNNs) [4], to reduce catastrophic forgetting. BNNs were initially developed to allow low-energy-consumption implementations of neural networks. In these networks, synaptic weights and activations are constrained to {−1, +1}, and training is performed using hidden real-valued weights, which are discarded at test time. Our first contribution is to draw a parallel between the metaplastic states of [2] and the hidden weights inherent to BNNs. Based on this insight, we propose a simple synaptic consolidation strategy for the hidden weights. We justify it using a tractable binary optimization problem, and we show that our strategy performs almost as well as mainstream machine learning approaches to mitigating catastrophic forgetting, which minimize task-specific loss functions [5], on the task of sequentially learning pixel-permuted versions of the MNIST digit dataset. Moreover, unlike these techniques, our approach does not require task boundaries, thereby allowing us to explore a new setting where the network learns from a stream of data. When trained on data streams from Fashion MNIST or CIFAR-10, our metaplastic BNN outperforms a standard BNN and closely matches the accuracy of a network trained on the whole dataset. These results suggest that BNNs are more than a low-precision version of full-precision networks and highlight the benefits of the synergy between neuroscience and deep learning [6].

1Centre de Nanosciences et de Nanotechnologies, Université Paris-Saclay

2Mila, Université de Montréal

Hidden weights as metaplastic states

The problem of forgetting in artificial neural networks results from a dilemma: synapses need to be updated in order to learn new tasks, but also to be protected against further changes in order to preserve knowledge. In a foundational neuroscience work, Fusi et al. showed that in small Hopfield networks, catastrophic forgetting can be addressed by introducing a hidden metaplastic state that controls the plasticity of the synapse [2]. Synapses can assume only a +1 or −1 weight, with the metaplastic state modulating the difficulty for the synapse to switch. Therefore, in this scheme, repeated potentiation of a positive-weight synapse will only affect its metaplastic state and not its actual weight. Here, we remark that the way BNNs are trained is remarkably similar to this situation. In BNNs, synapses can also only assume a +1 or −1 weight, and they feature a hidden real weight (W^h), which is updated by backpropagation. The synaptic weight changes between +1 and −1 only when W^h changes sign, suggesting that W^h can be seen as a metaplastic state modulating the difficulty for the actual weight to change sign. However, standard BNNs are as prone to catastrophic forgetting as conventional neural networks. In [2], Fusi et al. showed that metaplastic changes should affect subsequent plasticity exponentially to mitigate forgetting, whereas W^h affects weight changes only linearly in BNNs. Therefore, in this work, we propose to adapt the learning process of BNNs so that the larger the magnitude of a hidden weight W^h, the more difficult it is to switch its associated binarized weight W^b = sign(W^h). Denoting U_W the update provided by the learning algorithm, we implement:

W^h ← W^h − η U_W · f_meta(m, W^h)   if U_W · W^h > 0,
W^h ← W^h − η U_W   otherwise.

As in the metaplasticity model of [2], where synaptic plasticity decreases exponentially with the metaplastic state, we choose f_meta(m, W^h) = tanh′(m · W^h) to produce an exponential decay for large metaplastic states W^h, where m is a hyperparameter that controls the consolidation.

arXiv:2101.07592v1 [cs.NE] 19 Jan 2021
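As a concrete illustration, the consolidation rule above can be sketched in a few lines of NumPy. This is a minimal sketch, not the authors' code; the function and parameter names are illustrative, and f_meta is implemented via the identity tanh′(x) = 1 − tanh(x)².

```python
import numpy as np

def metaplastic_update(w_h, u_w, eta=0.01, m=1.35):
    """One metaplastic update of the hidden weights W^h (illustrative sketch).

    u_w is the raw update from the learning algorithm. When u_w * w_h > 0,
    the update pushes W^h toward zero, i.e. toward a sign switch of
    W^b = sign(W^h), so it is damped by f_meta = tanh'(m * W^h),
    which decays exponentially for large |W^h|.
    """
    f_meta = 1.0 - np.tanh(m * w_h) ** 2          # derivative of tanh
    damp = np.where(u_w * w_h > 0, f_meta, 1.0)   # consolidate only switch-bound updates
    return w_h - eta * u_w * damp

def binarize(w_h):
    """Binarized weights W^b used in the forward pass."""
    return np.sign(w_h)
```

Note that updates pushing W^h away from zero (the `otherwise` branch) are left untouched, so consolidation only resists sign switches.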

Toy problem study

To validate the interpretation of hidden weights as metaplastic states, we first focus on a highly simplified binary optimization task that we solve in a way analogous to the BNN training process. We want the binarized weights W^b to minimize a quadratic loss L, as depicted by the color map in Fig. 1(a) in two dimensions, with W* as the global optimum. We assume that W^h is updated by loss gradients computed with the binarized weights, similarly to BNNs:

W^h_{t+1} = W^h_t − η ∂L/∂W (W^b_t).
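This update can be simulated directly. The sketch below uses an illustrative coupled quadratic loss (the values of H and w_star are assumptions for the example, not the paper's exact problem); the key feature it reproduces is that gradients are evaluated at the binarized corner sign(W^h) rather than at W^h itself.

```python
import numpy as np

# Illustrative 2-D quadratic loss L(w) = 1/2 (w - w_star)^T H (w - w_star),
# with the real-valued optimum w_star inside the unit square.
H = np.array([[2.0, 0.9],
              [0.9, 1.0]])
w_star = np.array([0.3, -0.2])

def loss(w):
    d = w - w_star
    return 0.5 * d @ H @ d

def grad(w):
    return H @ (w - w_star)

def step(w_h, eta=0.05):
    """Hidden-weight update: the gradient is computed at the binarized
    weights W^b = sign(W^h), as in BNN training."""
    return w_h - eta * grad(np.sign(w_h))

def run(w_h, eta=0.05, steps=200):
    for _ in range(steps):
        w_h = step(w_h, eta)
    return w_h
```

Because the gradient is constant on each corner of the square, the dynamics of W^h differ qualitatively from conventional gradient descent on L.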

We can show that if the infinity norm of W* is less than one, some hidden weights diverge as t → ∞. This is because W^h is updated by loss gradients computed at the corners of the square, in contrast with conventional optimization. More importantly, if we define the importance of a binarized weight as the increase ΔL of the loss when the weight is switched to the opposite value, we can prove that the speed of divergence of the hidden weight is directly linked to the importance of the binarized weight. For instance, in Fig. 1(a), W^b_x is more important than W^b_y for the optimization.
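This notion of importance can be checked numerically: take the best binary configuration, flip each binarized weight in turn, and measure the loss increase ΔL. The quadratic below uses illustrative values for H and w_star (assumptions for the example, not the paper's figures).

```python
import numpy as np
from itertools import product

H = np.array([[2.0, 0.9],
              [0.9, 1.0]])      # illustrative coupled quadratic
w_star = np.array([0.3, -0.2])  # real-valued optimum, inside the square

def loss(w):
    d = w - w_star
    return 0.5 * d @ H @ d

# Best binary configuration among the four corners of {-1, +1}^2
corners = [np.array(c) for c in product([-1.0, 1.0], repeat=2)]
best = min(corners, key=loss)

def importance(i):
    """Importance of weight i: loss increase when W^b_i is switched."""
    flipped = best.copy()
    flipped[i] *= -1.0
    return loss(flipped) - loss(best)

importances = [importance(i) for i in range(2)]
```

At the loss-minimizing corner every flip increases the loss, so each importance is non-negative, and the coordinates can be ranked by how costly they are to switch.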

Finally, we plot ΔL versus |W^h| in Fig. 1(b), (c) for higher dimensions and for a BNN trained on MNIST, and observe that the correspondence between important weights and hidden-weight divergence still holds. This justifies the idea that consolidating synapses with diverging hidden weights, as our proposal does, is a promising route for mitigating catastrophic forgetting.

Experimental results

Continual learning benchmark. We now apply our consolidation strategy to the permuted MNIST benchmark, on perceptrons with two hidden layers of varying numbers of neurons. We show in Fig. 1(d), (e) the average test accuracy as a function of the number of tasks learned so far. We observe that our technique indeed allows sequential task learning and performs almost as well as Elastic Weight Consolidation (EWC) [5] adapted to BNNs (the importance factor is computed with the binarized weights) over a wide range of hidden layer sizes when learning up to 20 tasks. We choose m = 1.35 for our method and λ_EWC = 5 · 10⁻³ for EWC.
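For reference, EWC regularizes learning of a new task by anchoring each weight to its value after the previous task, weighted by a per-weight importance estimate (a diagonal Fisher information in the original formulation [5]). A minimal sketch of the penalty term follows; the function name and the way the importance vector is obtained are illustrative, not the paper's exact adaptation to BNNs.

```python
import numpy as np

def ewc_penalty(w, w_prev, fisher, lam=5e-3):
    """EWC regularizer: (lam / 2) * sum_i F_i * (w_i - w_prev_i)^2.

    w_prev holds the weights at the end of the previous task; fisher is
    a per-weight importance estimate. This term is added to the loss of
    the new task, so important weights resist being moved.
    """
    return 0.5 * lam * np.sum(fisher * (w - w_prev) ** 2)
```

Unlike the metaplastic rule, this penalty requires knowing when one task ends and the next begins, in order to snapshot w_prev and fisher.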

Learning from a stream of data. By construction, our approach does not require updating the importance factor between two consecutive tasks. Building on this asset, we explore a new setting, which we call stream learning, where a task is learned by training on sub-parts of the full dataset sequentially, with all classes evenly distributed in each subset. We choose Fashion MNIST (FMNIST) and CIFAR-10 for our experiments. The architectures used are a perceptron with two hidden layers of 1,024 units for FMNIST and a VGG-16 convolutional architecture for CIFAR-10. We plot in Fig. 1(f), (g) the test accuracy reached by these networks when metaplasticity is used (red) or not (blue). We see that our approach comes closer to the accuracy reached when the full dataset is learned at once (straight lines) than the non-metaplastic counterpart. Overall, these results highlight the benefit of metaplasticity models from neuroscience when applied to machine learning.
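The stream-learning setting can be set up by splitting a dataset into sequential subsets while keeping every class evenly represented in each. A minimal sketch (function and parameter names are illustrative):

```python
import numpy as np

def make_stream(labels, n_subsets, seed=0):
    """Split dataset indices into n_subsets sequential chunks,
    spreading each class evenly across the chunks."""
    rng = np.random.default_rng(seed)
    subsets = [[] for _ in range(n_subsets)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)   # all examples of class c
        rng.shuffle(idx)
        for k, chunk in enumerate(np.array_split(idx, n_subsets)):
            subsets[k].extend(chunk.tolist())
    return [np.array(s) for s in subsets]
```

The network is then trained on `subsets[0]`, then `subsets[1]`, and so on, without ever revisiting earlier chunks.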

Acknowledgement

This work was supported by European Research Council Starting Grant NANOINFER (reference: 715872).

References

[1] French, R. M. Trends in Cognitive Sciences (1999).
[2] Fusi et al. Neuron (2005).
[3] Abraham, W. C. Nature Reviews Neuroscience (2008).
[4] Courbariaux et al. arXiv:1602.02830 (2016).
[5] Kirkpatrick et al. PNAS (2017).
[6] Richards et al. Nature Neuroscience (2019).
[7] Zenke et al. PMLR (2017).


Figure 1: (a) Quadratic binarized optimization in two dimensions. (b-c) Average loss increase when switching W^b versus normalized hidden weight, for the binary quadratic problem (b) and for a BNN on MNIST (c). (d-e) Permuted MNIST benchmark with our method (d) and EWC (e), where the x axis labels the number of learned tasks, with one color per network size. Fashion MNIST (f) and CIFAR-10 (g) test accuracy in the stream learning setting allowed by our approach (red) compared to a standard BNN (blue). Horizontal rules denote the full-dataset training baseline.

