
Momentum-Based Variance Reduction in Non-Convex SGD

Ashok Cutkosky
Google Research
Mountain View, CA, USA
[email protected]

Francesco Orabona
Boston University
Boston, MA, USA
[email protected]

Abstract

Variance reduction has emerged in recent years as a strong competitor to stochastic gradient descent in non-convex problems, providing the first algorithms to improve upon the convergence rate of stochastic gradient descent for finding first-order critical points. However, variance reduction techniques typically require carefully tuned learning rates and willingness to use excessively large "mega-batches" in order to achieve their improved results. We present a new algorithm, STORM, that does not require any batches and makes use of adaptive learning rates, enabling simpler implementation and less hyperparameter tuning. Our technique for removing the batches uses a variant of momentum to achieve variance reduction in non-convex optimization. On smooth losses F, STORM finds a point x with E[‖∇F(x)‖] ≤ O(1/√T + σ^{1/3}/T^{1/3}) in T iterations with σ² variance in the gradients, matching the optimal rate and without requiring knowledge of σ.

    1 Introduction

This paper addresses the classic stochastic optimization problem, in which we are given a function F : R^d → R, and wish to find x ∈ R^d such that F(x) is as small as possible. Unfortunately, our access to F is limited to a stochastic function oracle: we can obtain sample functions f(·, ξ), where ξ represents some sample variable (e.g. a minibatch index), such that E[f(·, ξ)] = F(·). Stochastic optimization problems are found throughout machine learning. For example, in supervised learning, x represents the parameters of a model (say the weights of a neural network), ξ represents an example, f(x, ξ) represents the loss on an example, and F represents the training loss of the model.

We do not assume convexity, so in general the problem of finding a true minimum of F may be NP-hard. Hence, we relax the problem to finding a critical point of F, that is, a point such that ∇F(x) = 0. Also, we assume access only to stochastic gradients evaluated on arbitrary points, rather than Hessians or other information. In this setting, the standard algorithm is stochastic gradient descent (SGD). SGD produces a sequence of iterates x_1, . . . , x_T using the recursion

x_{t+1} = x_t − η_t g_t,    (1)

where g_t = ∇f(x_t, ξ_t), f(·, ξ_1), . . . , f(·, ξ_T) are i.i.d. samples from a distribution D, and η_1, . . . , η_T ∈ R are a sequence of learning rates that must be carefully tuned to ensure good performance. Assuming the η_t are selected properly, SGD guarantees that a randomly selected iterate x_t satisfies E[‖∇F(x_t)‖] ≤ O(1/T^{1/4}) [9].
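As a concrete illustration of recursion (1), the following minimal Python sketch runs SGD on a toy stochastic least-squares objective; the toy problem, the decaying schedule η_t = η_0/√t, and all variable names are illustrative choices of ours, not part of the analysis.

    import numpy as np

    rng = np.random.default_rng(0)
    d, T, eta0 = 10, 1000, 0.1
    x_star = rng.normal(size=d)          # unknown minimizer of F
    x = np.zeros(d)

    def stochastic_grad(x):
        # gradient of f(x, xi) for F(x) = 0.5 * ||x - x_star||^2, plus noise,
        # so that E[stochastic_grad(x)] = grad F(x)
        return (x - x_star) + rng.normal(scale=0.5, size=d)

    for t in range(1, T + 1):
        g = stochastic_grad(x)           # g_t = grad f(x_t, xi_t)
        eta = eta0 / np.sqrt(t)          # a simple decaying learning-rate schedule
        x = x - eta * g                  # update (1): x_{t+1} = x_t - eta_t * g_t

    print("final ||grad F(x)||:", np.linalg.norm(x - x_star))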


Recently, variance reduction has emerged as an improved technique for finding critical points in non-convex optimization problems. Stochastic variance-reduced gradient (SVRG) algorithms also produce iterates x_1, . . . , x_T according to the update formula (1), but now g_t is a variance-reduced estimate of ∇F(x_t). Over the last few years, SVRG algorithms have improved the convergence rate to critical points of non-convex SGD from O(1/T^{1/4}) to O(1/T^{3/10}) [2, 21] to O(1/T^{1/3}) [8, 31]. Despite this improvement, SVRG has not seen as much success in practice in non-convex machine learning problems [5]. Many reasons may contribute to this phenomenon, but two potential issues we address here are SVRG's use of non-adaptive learning rates and its reliance on giant batch sizes to construct variance-reduced gradients through the use of low-noise gradients calculated at a "checkpoint". In particular, for non-convex losses, SVRG analyses typically involve carefully selecting the learning rates, the number of samples used to construct the gradient at the checkpoint points, and the frequency of update of the checkpoint points. The optimal settings balance various unknown problem parameters exactly in order to obtain improved performance, making it especially important, and especially difficult, to tune them.

In this paper, we address both of these issues. We present a new algorithm called STOchastic Recursive Momentum (STORM) that achieves variance reduction through the use of a variant of the momentum term, similar to the popular RMSProp or Adam momentum heuristics [24, 13]. Hence, our algorithm does not require a gigantic batch to compute checkpoint gradients; in fact, our algorithm does not require any batches at all because it never needs to compute a checkpoint gradient. STORM achieves the optimal convergence rate of O(1/T^{1/3}) [3], and it uses an adaptive learning rate schedule that automatically adjusts to the variance of ∇f(x_t, ξ_t). Overall, we consider our algorithm a significant qualitative departure from the usual paradigm for variance reduction, and we hope our analysis may provide insight into the value of momentum in non-convex optimization.

The rest of the paper is organized as follows. The next section discusses the related work on variance reduction and adaptive learning rates in non-convex SGD. Section 3 formally introduces our notation and assumptions. We present our basic update rule and its connection to SGD with momentum in Section 4, and our algorithm in Section 5. Finally, we present some empirical results in Section 6 and conclude with a discussion in Section 7.

    2 Related Work

Variance-reduction methods were proposed independently by three groups at the same conference: Johnson and Zhang [12], Zhang et al. [30], Mahdavi et al. [17], and Wang et al. [27]. The first application of variance-reduction methods to non-convex SGD is due to Allen-Zhu and Hazan [2]. Using variance reduction methods, Fang et al. [8] and Zhou et al. [31] have obtained much better convergence rates for critical points in non-convex SGD. These methods are very different from our approach because they require the calculation of gradients at checkpoints. In fact, in order to compute the variance-reduced gradient estimates g_t, the algorithm must periodically stop producing iterates x_t and instead generate a very large "mega-batch" of samples ξ_1, . . . , ξ_N, which is used to compute a checkpoint gradient (1/N) Σ_{i=1}^N ∇f(v, ξ_i) for an appropriate checkpoint point v. Depending on the algorithm, N may be as large as O(T), and typically no smaller than O(T^{2/3}). The only exceptions we are aware of are SARAH [18, 19] and iSARAH [20]. However, their guarantees do not improve over the ones of plain SGD, and they still require at least one checkpoint gradient. Independently and simultaneously with this work, [25] have proposed a new algorithm that does improve over SGD to match our same convergence rate, although it does still require one checkpoint gradient. Interestingly, their update formula is very similar to ours, although the analysis is rather different. We are not aware of prior works for non-convex optimization with reduced variance methods that completely avoid using giant batches.
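To make the contrast with our approach concrete, here is a hedged sketch of the checkpoint-based estimate these methods share (our illustration, not any specific published implementation): a mega-batch of N samples is spent on a low-noise gradient at a checkpoint v, and each step then corrects it with two gradients computed on the same sample.

    import numpy as np

    def checkpoint_gradient(grad_f, v, samples):
        # One pass over a large batch of N samples at the checkpoint v:
        # (1/N) * sum_i grad f(v, xi_i). This is the expensive "mega-batch" step.
        return np.mean([grad_f(v, xi) for xi in samples], axis=0)

    def svrg_style_estimate(grad_f, x_t, v, ckpt_grad, xi):
        # g_t = grad f(x_t, xi) - grad f(v, xi) + checkpoint gradient at v
        return grad_f(x_t, xi) - grad_f(v, xi) + ckpt_grad

    # Toy usage with f(x, xi) = 0.5 * ||x - xi||^2, so grad f(x, xi) = x - xi.
    grad_f = lambda x, xi: x - xi
    rng = np.random.default_rng(0)
    samples = rng.normal(size=(1000, 5))     # N = 1000 mega-batch
    v = np.zeros(5)
    mu = checkpoint_gradient(grad_f, v, samples)
    g = svrg_style_estimate(grad_f, v + 0.01, v, mu, samples[0])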

On the other hand, adaptive learning-rate schemes, which choose the values η_t in some data-dependent way so as to reduce the need for tuning the values of η_t manually, were introduced by Duchi et al. [7] and popularized by heuristic methods like RMSProp and Adam [24, 13]. In the non-convex setting, adaptive learning rates can be shown to improve the convergence rate of SGD to O(1/√T + (σ²/T)^{1/4}), where σ² is a bound on the variance of the stochastic gradients ∇f(x_t, ξ_t) [16, 28, 22]. Hence, these adaptive algorithms obtain much better convergence guarantees when the problem is "easy", and have become extremely popular in practice. In contrast, the only variance-reduced algorithm we are aware of that uses adaptive learning rates is [4], but their techniques apply only to convex losses.
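For reference, a minimal sketch of an AdaGrad-norm-style adaptive step size of the kind discussed above (our illustration; the scalar-accumulator variant shown here is one common form, and the defaults are arbitrary):

    import numpy as np

    def adagrad_norm_step(x, g, grad_sq_sum, eta=0.1, eps=1e-8):
        # Accumulate squared gradient norms and scale the step by their square root:
        # eta_t = eta / sqrt(eps + sum_{i<=t} ||g_i||^2), then x_{t+1} = x_t - eta_t * g_t.
        grad_sq_sum += float(np.dot(g, g))
        x_next = x - eta / np.sqrt(eps + grad_sq_sum) * g
        return x_next, grad_sq_sum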

    3 Notation and Assumptions

In the following, we will write vectors with bold letters and we will denote the inner product between vectors a and b by a · b.


Throughout the paper we will make the following assumptions. We assume access to a stream of independent random variables ξ_1, . . . , ξ_T ∈ Ξ and a function f such that for all t and for all x, E[f(x, ξ_t)|x] = F(x). Note that we access two gradients on the same ξ_t at two different points in each update, as in standard variance-reduced methods. In practice, ξ_t may denote an i.i.d. training example, or an index into a training set, while f(x, ξ_t) indicates the loss on that training example using the model parameters x. We assume there is some σ² that upper bounds the noise on the gradients: E[‖∇f(x, ξ_t) − ∇F(x)‖²] ≤ σ². We define F* = inf_x F(x) and we will assume that F* > −∞. We will also need some assumptions on the functions f(x, ξ_t). Define a differentiable function f : R^d → R to be G-Lipschitz iff ‖∇f(x)‖ ≤ G for all x, and f to be L-smooth iff ‖∇f(x) − ∇f(y)‖ ≤ L‖x − y‖ for all x and y. We assume that f(x, ξ_t) is differentiable and L-smooth as a function of x with probability 1. We will also assume that f(x, ξ_t) is G-Lipschitz for our adaptive analysis. We show in Appendix B that this assumption can be lifted at the expense of adaptivity to σ.
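As a concrete example of these assumptions (an illustration of ours, not used elsewhere in the paper): for the logistic loss on a sample ξ = (a, y) with y ∈ {−1, +1},

f(x, ξ) = log(1 + exp(−y a · x)),    ∇f(x, ξ) = −y a σ(−y a · x),

where σ is the sigmoid. Since σ takes values in (0, 1) and σ′ ≤ 1/4, we have ‖∇f(x, ξ)‖ ≤ ‖a‖ and ‖∇²f(x, ξ)‖ ≤ ‖a‖²/4, so f(·, ξ) is G-Lipschitz with G = ‖a‖ and L-smooth with L = ‖a‖²/4.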

    4 Momentum and Variance Reduction

Before describing our algorithm in detail, we briefly explore the connection between SGD with momentum and variance reduction.

    The stochastic gradient descent with momentum algorithm is typically implemented as

d_t = (1 − a) d_{t−1} + a ∇f(x_t, ξ_t)
x_{t+1} = x_t − η d_t,

where a is small, e.g. a = 0.1. In words, instead of using the current stochastic gradient ∇f(x_t, ξ_t) in the update of x_t, we use an exponential average of the past observed gradients.
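In code, this heuristic is just an exponential moving average of stochastic gradients driving the update; a minimal sketch (variable names and defaults are ours):

    def momentum_sgd_step(x, d_prev, grad, a=0.1, eta=0.01):
        # d_t = (1 - a) * d_{t-1} + a * grad f(x_t, xi_t)   (exponential average)
        d = (1.0 - a) * d_prev + a * grad
        # x_{t+1} = x_t - eta * d_t
        return x - eta * d, d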

While SGD with momentum and its variants have been successfully used in many machine learning applications [13], it is well known that the presence of noise in the stochastic gradients can nullify the theoretical gain of the momentum term [e.g. 29]. As a result, it is unclear how and why using momentum can be better than plain SGD. Although recent works have proved that a variant of SGD with momentum improves the non-dominant terms in the convergence rate on convex stochastic least squares problems [6, 11], it is still unclear if the actual convergence rate can be improved.

Here, we take a different route. Instead of showing that momentum in SGD works in the same way as in the noiseless case, i.e. giving accelerated rates, we show that a variant of momentum can provably reduce the variance of the gradients. In its simplest form, the variant we propose is:

d_t = (1 − a) d_{t−1} + a ∇f(x_t, ξ_t) + (1 − a)(∇f(x_t, ξ_t) − ∇f(x_{t−1}, ξ_t))    (2)
x_{t+1} = x_t − η d_t .    (3)

The only difference is that we add the term (1 − a)(∇f(x_t, ξ_t) − ∇f(x_{t−1}, ξ_t)) to the update. As in standard variance-reduced methods, we use two gradients in each step. However, we do not need to use the gradient calculated at any checkpoint points. Note that if x_t ≈ x_{t−1}, then our update becomes approximately the momentum one. These two terms will be similar as long as the algorithm is actually converging to some point, and so we can expect the algorithm to behave exactly like classic momentum SGD towards the end of the optimization process.
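The variant (2)-(3) differs from the momentum sketch above only by the correction term, which evaluates the same sample ξ_t at both x_t and x_{t−1}; a minimal sketch under the same illustrative conventions:

    def recursive_momentum_step(x, x_prev, d_prev, grad_f, xi, a=0.1, eta=0.01):
        # grad_f(x, xi) returns the stochastic gradient of f(., xi) at x.
        g_curr = grad_f(x, xi)           # grad f(x_t, xi_t)
        g_prev = grad_f(x_prev, xi)      # grad f(x_{t-1}, xi_t): same sample, previous iterate
        # Update (2): momentum plus the variance-reduction correction term.
        d = (1.0 - a) * d_prev + a * g_curr + (1.0 - a) * (g_curr - g_prev)
        # Update (3): x_{t+1} = x_t - eta * d_t
        return x - eta * d, d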

To understand why update (2)-(3) delivers a variance reduction, consider the "error in d_t", which we denote by ε_t:

ε_t := d_t − ∇F(x_t) .

This term measures the error we incur by using d_t as the update direction instead of the correct but unknown direction ∇F(x_t). The equivalent quantity in SGD is ∇f(x_t, ξ_t) − ∇F(x_t), which satisfies E[‖∇f(x_t, ξ_t) − ∇F(x_t)‖²] ≤ σ². So, if E[‖ε_t‖²] decreases over time, we have realized a variance reduction effect. The technical result that we use to show this decrease is provided in Lemma 2, but let us take a moment here to appreciate why this should be expected intuitively. Considering the update written in (2), we can obtain a recursive expression for ε_t by subtracting ∇F(x_t) from both sides:

ε_t = (1 − a) ε_{t−1} + a(∇f(x_t, ξ_t) − ∇F(x_t)) + (1 − a)(∇f(x_t, ξ_t) − ∇f(x_{t−1}, ξ_t) − (∇F(x_t) − ∇F(x_{t−1}))) .


Algorithm 1 STORM: STOchastic Recursive Momentum
1: Input: parameters k, w, c, initial point x_1
2: Sample ξ_1
3: G_1 ← ‖∇f(x_1, ξ_1)‖
4: d_1 ← ∇f(x_1, ξ_1)
5: η_0 ← k/w^{1/3}
6: for t = 1 to T do
7:   η_t ← k/(w + Σ_{i=1}^t G_i²)^{1/3}
8:   x_{t+1} ← x_t − η_t d_t
9:   a_{t+1} ← c η_t²
10:  Sample ξ_{t+1}
11:  G_{t+1} ← ‖∇f(x_{t+1}, ξ_{t+1})‖
12:  d_{t+1} ← ∇f(x_{t+1}, ξ_{t+1}) + (1 − a_{t+1})(d_t − ∇f(x_t, ξ_{t+1}))
13: end for
14: Choose x̂ uniformly at random from x_1, . . . , x_T. (In practice, set x̂ = x_T.)
15: return x̂

Now, notice that there is good reason to expect the second and third terms of the RHS above to be small: we can control a(∇f(x_t, ξ_t) − ∇F(x_t)) simply by choosing a small enough value of a, and from smoothness we expect (∇f(x_t, ξ_t) − ∇f(x_{t−1}, ξ_t)) − (∇F(x_t) − ∇F(x_{t−1})) to be of the order of O(‖x_t − x_{t−1}‖) = O(η‖d_{t−1}‖). Therefore, by choosing small enough η and a, we obtain ‖ε_t‖ ≤ (1 − a)‖ε_{t−1}‖ + Z, where Z is some small value. Thus, intuitively, ‖ε_t‖ will decrease until it reaches Z/a. This highlights a trade-off in setting η and a in order to decrease the numerator of Z/a while keeping the denominator sufficiently large. Our central challenge is showing that it is possible to achieve a favorable trade-off in which Z/a is very small, resulting in small error ε_t.
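To see why ‖ε_t‖ should settle near Z/a, one can unroll the informal recursion (an illustration of ours, treating ‖ε_t‖ ≤ (1 − a)‖ε_{t−1}‖ + Z as exact with a constant Z):

‖ε_t‖ ≤ (1 − a)^t ‖ε_0‖ + Z Σ_{i=0}^{t−1} (1 − a)^i ≤ (1 − a)^t ‖ε_0‖ + Z/a,

so the initial error is forgotten geometrically and the steady-state error is at most Z/a.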

    5 STORM: STOchastic Recursive Momentum

We now describe our stochastic optimization algorithm, which we call STOchastic Recursive Momentum (STORM). The pseudocode is in Algorithm 1. As described in the previous section, its basic update is of the form of (2) and (3). However, in order to achieve adaptivity to the noise in the gradients, both the stepsize and the momentum term will depend on the past gradients, à la AdaGrad [7].
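The following self-contained Python sketch mirrors Algorithm 1 on a toy problem. It is an illustration of ours, not a reference implementation: the toy objective, the clipping of a_{t+1} at 1, and the default values of k, w, c are all assumptions made for the example.

    import numpy as np

    def storm(grad_f, sample, x1, T, k=0.1, w=0.1, c=10.0):
        # Minimal sketch of Algorithm 1 (STORM); k, w, c defaults are illustrative.
        x = np.array(x1, dtype=float)
        xi = sample()
        d = grad_f(x, xi)                              # d_1 = grad f(x_1, xi_1)
        sum_G2 = float(np.dot(d, d))                   # G_1^2
        for t in range(1, T + 1):
            eta = k / (w + sum_G2) ** (1.0 / 3.0)      # eta_t = k / (w + sum_i G_i^2)^(1/3)
            x_prev, x = x, x - eta * d                 # x_{t+1} = x_t - eta_t * d_t
            a_next = min(c * eta ** 2, 1.0)            # a_{t+1} = c * eta_t^2 (clipped for safety)
            xi = sample()                              # fresh sample xi_{t+1}
            g_new = grad_f(x, xi)                      # grad f(x_{t+1}, xi_{t+1})
            g_old = grad_f(x_prev, xi)                 # grad f(x_t, xi_{t+1}), same sample
            d = g_new + (1.0 - a_next) * (d - g_old)   # d_{t+1}
            sum_G2 += float(np.dot(g_new, g_new))      # accumulate G_{t+1}^2
        return x                                       # in practice, the last iterate

    # Toy usage: F(x) = 0.5 ||x - x*||^2, stochastic gradient (x - x*) + noise.
    rng = np.random.default_rng(0)
    x_star = rng.normal(size=10)
    grad_f = lambda x, xi: (x - x_star) + xi
    sample = lambda: rng.normal(scale=0.5, size=10)
    x_hat = storm(grad_f, sample, x1=np.zeros(10), T=2000)
    print("||grad F(x_hat)||:", np.linalg.norm(x_hat - x_star))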

    The convergence guarantee of STORM is presented in Theorem 1 below.

Theorem 1. Under the assumptions in Section 3, for any b > 0, set k = bG^{2/3}/L, c = 28L² + G²/(7Lk³) = L²(28 + 1/(7b³)), and w = max((4Lk)³, 2G², (ck/(4L))³) = G² max((4b)³, 2, (28b + 1/(7b²))³/64). Then, STORM satisfies

E[‖∇F(x̂)‖] = E[(1/T) Σ_{t=1}^T ‖∇F(x_t)‖] ≤ (w^{1/6}√(2M) + 2M^{3/4})/√T + 2σ^{1/3}/T^{1/3},

where M = 8(F(x_1) − F*)/k + w^{1/3}σ²/(4L²k²) + (k²c²/(2L²)) ln(T + 2).
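For concreteness, instantiating the constants with b = 1 (an example of ours, not a recommended setting) gives

k = G^{2/3}/L,    c = L²(28 + 1/7) ≈ 28.1 L²,    w = G² max(64, 2, (28 + 1/7)³/64) ≈ 348 G²,

so w is determined by its third term and the initial step size is η_0 = k/w^{1/3} ≈ 1/(7L).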

In words, Theorem 1 guarantees that STORM will make the norm of the gradients converge to 0 at a rate of O(ln T/√T) if there is no noise, and in expectation at a rate of 2σ^{1/3}/T^{1/3} in the stochastic case. We remark that we achieve both rates automatically, without the need to know the noise level nor the need to tune stepsizes. Note that the rate when σ ≠ 0 matches the optimal rate [3], which was previously only obtained by SVRG-based algorithms that require a "mega-batch" [8, 31].

The dependence on G in this bound deserves some discussion: at first blush, it appears that if G → 0, the bound will go to infinity because the denominator in M goes to zero. Fortunately, this is not so: the resolution is to observe that F(x_1) − F* = O(G) and σ = O(G), so that the numerators of M actually go to zero at least as fast as the denominator. The dependence on L may be similarly non-intuitive: as L → 0, M → ∞. In this case this is actually to be expected: if L = 0, then there are no critical points (because the gradients are all the same!) and so we cannot actually find one. In general, M should be regarded as an O(log(T)) term where the constant indicates some inherent hardness level in the problem.

Finally, note that here we assumed that each f(x, ξ) is G-Lipschitz in x. Prior variance reduction results (e.g. [18, 8, 25]) do not make use of this assumption. However, we show in Appendix B that simply replacing all instances of G or G_t in the parameters of STORM with an oracle-tuned value of σ allows us to dispense with this assumption while still avoiding all checkpoint gradients.

Also note that, as in similar work on stochastic minimization of non-convex functions, Theorem 1 only bounds the gradient of a randomly selected iterate [9]. However, in practical implementations we expect the last iterate to perform equally well.

Our analysis formalizes the intuition developed in the previous section through a Lyapunov potential function. Our Lyapunov function is somewhat non-standard: for smooth non-convex functions, the Lyapunov function is typically of the form Φ_t = F(x_t), but we propose to use the function Φ_t = F(x_t) + z_t‖ε_t‖² for a time-varying z_t ∝ η_{t−1}^{−1}, where ε_t is the error in the update introduced in the previous section. The use of a time-varying z_t appears to be critical for us to avoid using any checkpoints: with constant z_t it seems that one always needs at least one checkpoint gradient. Potential functions of this form have been used to analyze momentum algorithms in order to prove asymptotic guarantees, see, e.g., Ruszczynski and Syski [23]. However, as far as we know, this use of a potential is somewhat different than most variance reduction analyses, and so may provide avenues for further development. We now proceed to the proof of Theorem 1.

    5.1 Proof of Theorem 1

First, we consider a generic SGD-style analysis. Most SGD analyses assume that the gradient estimates used by the algorithm are unbiased estimates of ∇F(x_t), but unfortunately d_t is biased. As a result, we need the following slightly different analysis. For lack of space, the proof of this Lemma and the next one are in the Appendix.

Lemma 1. Suppose η_t ≤ 1/(4L) for all t. Then

E[F(x_{t+1}) − F(x_t)] ≤ E[−(η_t/4)‖∇F(x_t)‖² + (3η_t/4)‖ε_t‖²].

The following technical observation is key to our analysis of STORM: it provides a recurrence that enables us to bound the variance of the estimates d_t.

Lemma 2. With the notation in Algorithm 1, we have

E[‖ε_t‖²/η_{t−1}] ≤ E[2c²η_{t−1}³G_t² + (1 − a_t)²(1 + 4L²η_{t−1}²)‖ε_{t−1}‖²/η_{t−1} + 4(1 − a_t)²L²η_{t−1}‖∇F(x_{t−1})‖²].

Lemma 2 exhibits a somewhat involved algebraic identity, so let us try to build some intuition for what it means and how it can help us. First, multiply both sides by η_{t−1}. Technically the expectations make this a forbidden operation, but we ignore this detail for now. Next, observe that Σ_{t=1}^T G_t² is roughly Θ(T) (since the variance prevents ‖g_t‖² from going to zero even when ‖∇F(x_t)‖ does). Therefore η_t is roughly O(1/t^{1/3}), and a_t is roughly O(1/t^{2/3}). Discarding all constants, and observing that (1 − a_t)² ≤ (1 − a_t), the above Lemma is then saying that

E[‖ε_t‖²] ≤ E[η_{t−1}⁴ + (1 − a_t)‖ε_{t−1}‖² + η_{t−1}²‖∇F(x_{t−1})‖²]
  = E[t^{−4/3} + (1 − t^{−2/3})‖ε_{t−1}‖² + t^{−2/3}‖∇F(x_{t−1})‖²].

We can use this recurrence to compute a kind of "equilibrium value" for E[‖ε_t‖²]: set E[‖ε_t‖²] = E[‖ε_{t−1}‖²] and solve to obtain that ‖ε_t‖² is O(1/t^{2/3} + ‖∇F(x_t)‖²). This in turn suggests that, whenever ‖∇F(x_t)‖² is greater than 1/t^{2/3}, the gradient estimate d_t = ∇F(x_t) + ε_t will be a very good approximation of ∇F(x_t), so that gradient descent should make very fast progress. Therefore, we expect the "equilibrium value" for ‖∇F(x_t)‖² to be O(1/T^{2/3}), since this is the point at which the estimate d_t becomes dominated by the error.
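Concretely, the "solve" above is a one-line computation (our sketch, keeping only the dominant terms): setting E[‖ε_t‖²] = E[‖ε_{t−1}‖²] =: E in the approximate recurrence and cancelling E on both sides gives

t^{−2/3} E ≤ t^{−4/3} + t^{−2/3}‖∇F(x_{t−1})‖²,    i.e.,    E ≤ t^{−2/3} + ‖∇F(x_{t−1})‖²,

which is the O(1/t^{2/3} + ‖∇F(x_t)‖²) equilibrium value quoted above.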

We formalize this intuition using a Lyapunov function of the form Φ_t = F(x_t) + z_t‖ε_t‖² in the proof of Theorem 1 below.


Proof of Theorem 1. Consider the potential Φ_t = F(x_t) + (1/(32L²η_{t−1}))‖ε_t‖². We will upper bound Φ_{t+1} − Φ_t for each t, which will allow us to bound Φ_T in terms of Φ_1 by summing over t. First, observe that since w ≥ (4Lk)³, we have η_t ≤ 1/(4L). Further, since a_{t+1} = cη_t², we have a_{t+1} ≤ ck/(4Lw^{1/3}) ≤ 1 for all t. Then, we first consider η_t^{−1}‖ε_{t+1}‖² − η_{t−1}^{−1}‖ε_t‖². Using Lemma 2, we obtain

E[η_t^{−1}‖ε_{t+1}‖² − η_{t−1}^{−1}‖ε_t‖²]
  ≤ E[2c²η_t³G_{t+1}² + (1 − a_{t+1})²(1 + 4L²η_t²)‖ε_t‖²/η_t + 4(1 − a_{t+1})²L²η_t‖∇F(x_t)‖² − ‖ε_t‖²/η_{t−1}]
  ≤ E[A_t + B_t + C_t],

where A_t := 2c²η_t³G_{t+1}², B_t := (η_t^{−1}(1 − a_{t+1})(1 + 4L²η_t²) − η_{t−1}^{−1})‖ε_t‖², and C_t := 4L²η_t‖∇F(x_t)‖².

Let us focus on the terms of this expression individually. For the first term, A_t, observe that w ≥ 2G² ≥ G² + G_{t+1}² to obtain:

Σ_{t=1}^T A_t = Σ_{t=1}^T 2c²η_t³G_{t+1}² = Σ_{t=1}^T 2k³c²G_{t+1}²/(w + Σ_{i=1}^t G_i²)
  ≤ Σ_{t=1}^T 2k³c²G_{t+1}²/(G² + Σ_{i=1}^{t+1} G_i²)
  ≤ 2k³c² ln(1 + Σ_{t=1}^{T+1} G_t²/G²)
  ≤ 2k³c² ln(T + 2),

where in the second-to-last inequality we used Lemma 4 in the Appendix.

For the second term, B_t, we have

B_t ≤ (η_t^{−1} − η_{t−1}^{−1} + η_t^{−1}(4L²η_t² − a_{t+1}))‖ε_t‖² = (η_t^{−1} − η_{t−1}^{−1} + η_t(4L² − c))‖ε_t‖² .

Let us focus on 1/η_t − 1/η_{t−1} for a minute. Using the concavity of x^{1/3}, we have (x + y)^{1/3} ≤ x^{1/3} + yx^{−2/3}/3. Therefore:

1/η_t − 1/η_{t−1} = (1/k)[(w + Σ_{i=1}^t G_i²)^{1/3} − (w + Σ_{i=1}^{t−1} G_i²)^{1/3}]
  ≤ G_t²/(3k(w + Σ_{i=1}^{t−1} G_i²)^{2/3})
  ≤ G_t²/(3k(w − G² + Σ_{i=1}^t G_i²)^{2/3})
  ≤ G_t²/(3k(w/2 + Σ_{i=1}^t G_i²)^{2/3})
  ≤ 2^{2/3}G_t²/(3k(w + Σ_{i=1}^t G_i²)^{2/3})
  = 2^{2/3}G_t²η_t²/(3k³)
  ≤ 2^{2/3}G²η_t/(12Lk³)
  ≤ G²η_t/(7Lk³),

where we have used that w ≥ (4Lk)³ to have η_t ≤ 1/(4L). Further, since c = 28L² + G²/(7Lk³), we have

η_t(4L² − c) ≤ −24L²η_t − G²η_t/(7Lk³) .

Thus, we obtain B_t ≤ −24L²η_t‖ε_t‖². Putting all this together yields:

(1/(32L²)) Σ_{t=1}^T (‖ε_{t+1}‖²/η_t − ‖ε_t‖²/η_{t−1}) ≤ (k³c²/(16L²)) ln(T + 2) + Σ_{t=1}^T [(η_t/8)‖∇F(x_t)‖² − (3η_t/4)‖ε_t‖²].    (4)

Now, we are ready to analyze the potential Φ_t. Since η_t ≤ 1/(4L), we can use Lemma 1 to obtain

E[Φ_{t+1} − Φ_t] ≤ E[−(η_t/4)‖∇F(x_t)‖² + (3η_t/4)‖ε_t‖² + ‖ε_{t+1}‖²/(32L²η_t) − ‖ε_t‖²/(32L²η_{t−1})].

Summing over t and using (4), we obtain

E[Φ_{T+1} − Φ_1] ≤ Σ_{t=1}^T E[−(η_t/4)‖∇F(x_t)‖² + (3η_t/4)‖ε_t‖² + ‖ε_{t+1}‖²/(32L²η_t) − ‖ε_t‖²/(32L²η_{t−1})]
  ≤ E[(k³c²/(16L²)) ln(T + 2) − Σ_{t=1}^T (η_t/8)‖∇F(x_t)‖²].


Reordering the terms, we have

E[Σ_{t=1}^T η_t‖∇F(x_t)‖²] ≤ E[8(Φ_1 − Φ_{T+1}) + (k³c²/(2L²)) ln(T + 2)]
  ≤ 8(F(x_1) − F*) + E[‖ε_1‖²]/(4L²η_0) + (k³c²/(2L²)) ln(T + 2)
  ≤ 8(F(x_1) − F*) + w^{1/3}σ²/(4L²k) + (k³c²/(2L²)) ln(T + 2),

where the last inequality is given by the definition of d_1 and η_0 in the algorithm.

Now, we relate E[Σ_{t=1}^T η_t‖∇F(x_t)‖²] to E[Σ_{t=1}^T ‖∇F(x_t)‖²]. First, since η_t is decreasing,

E[Σ_{t=1}^T η_t‖∇F(x_t)‖²] ≥ E[η_T Σ_{t=1}^T ‖∇F(x_t)‖²].

Now, from the Cauchy-Schwarz inequality, for any random variables A and B we have E[A²]E[B²] ≥ E[AB]². Hence, setting A = √(η_T Σ_{t=1}^T ‖∇F(x_t)‖²) and B = √(1/η_T), we obtain

E[1/η_T] E[η_T Σ_{t=1}^T ‖∇F(x_t)‖²] ≥ (E[√(Σ_{t=1}^T ‖∇F(x_t)‖²)])².

Therefore, if we set M = (1/k)[8(F(x_1) − F*) + w^{1/3}σ²/(4L²k) + (k³c²/(2L²)) ln(T + 2)], we get

(E[√(Σ_{t=1}^T ‖∇F(x_t)‖²)])² ≤ E[(8(F(x_1) − F*) + w^{1/3}σ²/(4L²k) + (k³c²/(2L²)) ln(T + 2))/η_T]
  = E[kM/η_T] ≤ E[M(w + Σ_{t=1}^T G_t²)^{1/3}].

Define ζ_t = ∇f(x_t, ξ_t) − ∇F(x_t), so that E[‖ζ_t‖²] ≤ σ². Then, we have G_t² = ‖∇F(x_t) + ζ_t‖² ≤ 2‖∇F(x_t)‖² + 2‖ζ_t‖². Plugging this in and using (a + b)^{1/3} ≤ a^{1/3} + b^{1/3}, we obtain:

(E[√(Σ_{t=1}^T ‖∇F(x_t)‖²)])² ≤ E[M(w + 2Σ_{t=1}^T ‖ζ_t‖²)^{1/3} + 2^{1/3}M(Σ_{t=1}^T ‖∇F(x_t)‖²)^{1/3}]
  ≤ M(w + 2Tσ²)^{1/3} + E[2^{1/3}M(√(Σ_{t=1}^T ‖∇F(x_t)‖²))^{2/3}]
  ≤ M(w + 2Tσ²)^{1/3} + 2^{1/3}M(E[√(Σ_{t=1}^T ‖∇F(x_t)‖²)])^{2/3},

where we have used the concavity of x ↦ x^a for all a ≤ 1 to move expectations inside the exponents. Now, define X = √(Σ_{t=1}^T ‖∇F(x_t)‖²). Then the above can be rewritten as:

(E[X])² ≤ M(w + 2Tσ²)^{1/3} + 2^{1/3}M(E[X])^{2/3} .

Note that this implies that either (E[X])² ≤ 2M(w + 2Tσ²)^{1/3}, or (E[X])² ≤ 2·2^{1/3}M(E[X])^{2/3}. Solving for E[X] in these two cases, we obtain

E[X] ≤ √(2M)(w + 2Tσ²)^{1/6} + 2M^{3/4} .

Finally, observe that by Cauchy-Schwarz we have Σ_{t=1}^T ‖∇F(x_t)‖/T ≤ X/√T, so that

E[Σ_{t=1}^T ‖∇F(x_t)‖/T] ≤ (√(2M)(w + 2Tσ²)^{1/6} + 2M^{3/4})/√T ≤ (w^{1/6}√(2M) + 2M^{3/4})/√T + 2σ^{1/3}/T^{1/3},

where we used (a + b)^{1/3} ≤ a^{1/3} + b^{1/3} in the last inequality.


Figure 1: Experiments on CIFAR-10 with ResNet-32 Network. (a) Train loss (log scale) vs. iterations; (b) train accuracy vs. iterations; (c) test accuracy vs. iterations. Each panel compares Adam, AdaGrad, and STORM over training steps.

    6 Empirical Validation

In order to confirm that our advances do indeed yield an algorithm that performs well and requires little tuning, we implemented STORM in TensorFlow [1] and tested its performance on the CIFAR-10 image recognition benchmark [14] using a ResNet model [10], as implemented by the Tensor2Tensor package [26].¹ We compare STORM to AdaGrad and Adam, which are both very popular and successful optimization algorithms. The learning rates for AdaGrad and Adam were swept over a logarithmically spaced grid. For STORM, we set w = k = 0.1 as a default² and swept c over a logarithmically spaced grid, so that all algorithms involved only one parameter to tune. No regularization was employed. We record the train loss (cross-entropy) and accuracy on both the train and test sets (see Figure 1).
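For illustration, the tuning protocol for c amounts to a standard log-spaced sweep; the sketch below is ours (the grid and the helper train_with_storm, which is assumed to train the model with the given parameters and return the final training loss, are hypothetical, not the exact setup used in our experiments):

    import numpy as np

    def sweep_c(train_with_storm, w=0.1, k=0.1):
        # train_with_storm(c, w, k) is a hypothetical helper returning final train loss.
        best_c, best_loss = None, float("inf")
        for c in np.logspace(-2, 3, num=6):        # c in {0.01, 0.1, 1, 10, 100, 1000}
            loss = train_with_storm(c=c, w=w, k=k)
            if loss < best_loss:
                best_c, best_loss = c, loss
        return best_c, best_loss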

These results show that, while STORM is only marginally better than AdaGrad on test accuracy, on both training loss and accuracy STORM appears to be somewhat faster in terms of number of iterations. We note that the convergence proof we provide actually only applies to the training loss (since we are making multiple passes over the dataset). We leave for the future the question of whether appropriate regularization can trade off STORM's better training loss performance to obtain better test performance.

    7 Conclusion

We have introduced a new variance-reduction-based algorithm, STORM, that finds critical points in stochastic, smooth, non-convex problems. Our algorithm improves upon prior algorithms by virtue of removing the need for checkpoint gradients and incorporating adaptive learning rates. These improvements mean that STORM is substantially easier to tune: it does not require choosing the size of the checkpoints, nor how often to compute the checkpoints (because there are no checkpoints), and by using adaptive learning rates the algorithm enjoys the same robustness to learning rate tuning as popular algorithms like AdaGrad or Adam. STORM obtains the optimal convergence guarantee, adapting to the level of noise in the problem without knowledge of this parameter. We verified that on CIFAR-10 with a ResNet architecture, STORM indeed seems to be optimizing the objective in fewer iterations than baseline algorithms.

Additionally, we point out that STORM's update formula is strikingly similar to the standard SGD with momentum heuristic employed in practice. To our knowledge, no theoretical result actually establishes an advantage of adding momentum to SGD in stochastic problems, creating an intriguing mystery. While our algorithm is not precisely the same as SGD with momentum, we feel that it provides strong intuitive evidence that momentum is performing some kind of variance reduction. We therefore hope that some of the analysis techniques used in this paper may provide a path towards explaining the advantages of momentum.

¹ https://github.com/google-research/google-research/tree/master/storm_optimizer
² We picked these defaults by tuning over a logarithmic grid on the much simpler MNIST dataset [15]. w and k were not tuned on CIFAR-10.


References

[1] M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G. S. Corrado, A. Davis, J. Dean, M. Devin, S. Ghemawat, I. Goodfellow, A. Harp, G. Irving, M. Isard, Y. Jia, R. Jozefowicz, L. Kaiser, M. Kudlur, J. Levenberg, D. Mané, R. Monga, S. Moore, D. Murray, C. Olah, M. Schuster, J. Shlens, B. Steiner, I. Sutskever, K. Talwar, P. Tucker, V. Vanhoucke, V. Vasudevan, F. Viégas, O. Vinyals, P. Warden, M. Wattenberg, M. Wicke, Y. Yu, and X. Zheng. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015.

[2] Z. Allen-Zhu and E. Hazan. Variance reduction for faster non-convex optimization. In International Conference on Machine Learning, pages 699–707, 2016.
[3] Y. Arjevani, Y. Carmon, J. C. Duchi, D. J. Foster, N. Srebro, and B. Woodworth. Lower bounds for non-convex stochastic optimization. arXiv preprint arXiv:1912.02365, 2019.
[4] A. Cutkosky and R. Busa-Fekete. Distributed stochastic optimization via adaptive SGD. In Advances in Neural Information Processing Systems, pages 1910–1919, 2018.
[5] A. Defazio and L. Bottou. On the ineffectiveness of variance reduced optimization for deep learning. arXiv preprint arXiv:1812.04529, 2018.
[6] A. Dieuleveut, N. Flammarion, and F. Bach. Harder, better, faster, stronger convergence rates for least-squares regression. J. Mach. Learn. Res., 18(1):3520–3570, January 2017.
[7] J. C. Duchi, E. Hazan, and Y. Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12:2121–2159, 2011.
[8] C. Fang, C. J. Li, Z. Lin, and T. Zhang. Spider: Near-optimal non-convex optimization via stochastic path-integrated differential estimator. In Advances in Neural Information Processing Systems, pages 689–699, 2018.
[9] S. Ghadimi and G. Lan. Stochastic first- and zeroth-order methods for nonconvex stochastic programming. SIAM Journal on Optimization, 23(4):2341–2368, 2013.
[10] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In Proc. of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[11] P. Jain, S. M. Kakade, R. Kidambi, P. Netrapalli, and A. Sidford. Accelerating stochastic gradient descent for least squares regression. In S. Bubeck, V. Perchet, and P. Rigollet, editors, Proceedings of the 31st Conference On Learning Theory, volume 75 of Proceedings of Machine Learning Research, pages 545–604. PMLR, 06–09 Jul 2018.
[12] R. Johnson and T. Zhang. Accelerating stochastic gradient descent using predictive variance reduction. In Advances in Neural Information Processing Systems, pages 315–323, 2013.
[13] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
[14] A. Krizhevsky. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.
[15] Y. Lecun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, November 1998.
[16] X. Li and F. Orabona. On the convergence of stochastic gradient descent with adaptive stepsizes. In Proc. of the 22nd International Conference on Artificial Intelligence and Statistics, AISTATS, 2019.
[17] M. Mahdavi, L. Zhang, and R. Jin. Mixed optimization for smooth functions. In Advances in Neural Information Processing Systems, pages 674–682, 2013.
[18] L. M. Nguyen, J. Liu, K. Scheinberg, and M. Takáč. SARAH: A novel method for machine learning problems using stochastic recursive gradient. In Proc. of the 34th International Conference on Machine Learning, volume 70, pages 2613–2621. JMLR.org, 2017.
[19] L. M. Nguyen, J. Liu, K. Scheinberg, and M. Takáč. Stochastic recursive gradient algorithm for nonconvex optimization. arXiv preprint arXiv:1705.07261, 2017.
[20] L. M. Nguyen, K. Scheinberg, and M. Takáč. Inexact SARAH algorithm for stochastic optimization. arXiv preprint arXiv:1811.10105, 2018.


[21] S. J. Reddi, A. Hefny, S. Sra, B. Poczos, and A. Smola. Stochastic variance reduction for nonconvex optimization. In International Conference on Machine Learning, pages 314–323, 2016.
[22] S. J. Reddi, S. Kale, and S. Kumar. On the convergence of Adam and beyond. In International Conference on Learning Representations, 2018.
[23] A. Ruszczynski and W. Syski. Stochastic approximation method with gradient averaging for unconstrained problems. IEEE Transactions on Automatic Control, 28(12):1097–1105, December 1983.
[24] T. Tieleman and G. Hinton. Lecture 6.5-RMSProp: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural Networks for Machine Learning, 2012.
[25] Q. Tran-Dinh, N. H. Pham, D. T. Phan, and L. M. Nguyen. Hybrid stochastic gradient descent algorithms for stochastic nonconvex optimization. arXiv preprint arXiv:1905.05920, 2019.
[26] A. Vaswani, S. Bengio, E. Brevdo, F. Chollet, A. N. Gomez, S. Gouws, L. Jones, Ł. Kaiser, N. Kalchbrenner, N. Parmar, R. Sepassi, N. Shazeer, and J. Uszkoreit. Tensor2Tensor for neural machine translation. CoRR, abs/1803.07416, 2018.
[27] C. Wang, X. Chen, A. J. Smola, and E. P. Xing. Variance reduction for stochastic gradient optimization. In C. J. C. Burges, L. Bottou, M. Welling, Z. Ghahramani, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems 26, pages 181–189. Curran Associates, Inc., 2013.
[28] R. Ward, X. Wu, and L. Bottou. AdaGrad stepsizes: Sharp convergence over nonconvex landscapes, from any initialization. arXiv preprint arXiv:1806.01811, 2018.
[29] K. Yuan, B. Ying, and A. H. Sayed. On the influence of momentum acceleration on online learning. Journal of Machine Learning Research, 17(192):1–66, 2016.
[30] L. Zhang, M. Mahdavi, and R. Jin. Linear convergence with condition number independent access of full gradients. In C. J. C. Burges, L. Bottou, M. Welling, Z. Ghahramani, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems 26, pages 980–988. Curran Associates, Inc., 2013.
[31] D. Zhou, P. Xu, and Q. Gu. Stochastic nested variance reduced gradient descent for nonconvex optimization. In Advances in Neural Information Processing Systems, pages 3921–3932, 2018.
