
Mini-Course 3: Convergence Analysis of Neural Networks

Yang Yuan

Computer Science Department, Cornell University

Deep learning is powerful

What is a neural network?

A simplified view (one layer). Missing: Convolution/BatchNorm.

Input x = (1, 2, −4)⊤

Weight W =
  1 0 0 1
  0 1 0 1
  0 0 1 1

W⊤x = (1, 2, −4, −1)⊤

ReLU(W⊤x) = (1, 2, 0, 0)⊤


Three basic types of theory questions

▶ Representation
  ▶ Can we express any function with neural networks?
  ▶ Yes, very expressive: [Hornik et al., 1989, Cybenko, 1992, Barron, 1993, Eldan and Shamir, 2015, Safran and Shamir, 2016, Lee et al., 2017]
▶ Optimization
  ▶ Are there efficient methods for finding good parameters (i.e., representations)?
▶ Generalization
  ▶ Training data is used for the optimization step.
  ▶ Does it generalize to unseen data (test data)?
  ▶ Little is known. Flat minima? [Shirish Keskar et al., 2016, Hochreiter and Schmidhuber, 1995, Chaudhari et al., 2016, Zhang et al., 2016]
▶ In practice: neural networks do great on ALL THREE! What about in theory?


Optimization

▶ In practice, SGD always finds good local minima.
  ▶ SGD: stochastic gradient descent.
  ▶ x_{t+1} = x_t − η g_t, with E[g_t] = ∇f(x_t). (A minimal sketch follows below.)
▶ Some results are negative, saying that optimization for neural networks is in general hard.
  ▶ [Šíma, 2002, Livni et al., 2014, Shamir, 2016]
▶ Others are positive, but use special algorithms (tensor decomposition, half-space intersection, etc.).
  ▶ [Janzamin et al., 2015, Zhang et al., 2015, Sedghi and Anandkumar, 2015, Goel et al., 2016]
▶ Or they make strong assumptions on the model (weights are complex numbers, learning polynomials only, weights are i.i.d. random).
  ▶ [Andoni et al., 2014, Arora et al., 2014]
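To make the update rule above concrete, here is a minimal SGD sketch in Python/NumPy; the quadratic objective, the noise scale, and all names are illustrative and not part of the talk.

```python
import numpy as np

rng = np.random.default_rng(0)

def sgd(grad_fn, x0, eta=0.1, steps=200, noise=0.01):
    """Plain SGD: x_{t+1} = x_t - eta * g_t, where E[g_t] = grad f(x_t)."""
    x = np.asarray(x0, dtype=float)
    for _ in range(steps):
        g = grad_fn(x) + noise * rng.standard_normal(x.shape)  # unbiased noisy gradient
        x = x - eta * g
    return x

# Toy objective f(x) = 0.5 * ||x||^2, so grad f(x) = x and the minimizer is the origin.
print(sgd(lambda x: x, x0=[1.0, -2.0]))  # close to [0, 0]
```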


Recent work: independent activations

Independent activation assumption
▶ The outputs of the ReLU units are assumed to be independent of the input x, and independent of each other. [Choromanska et al., 2015, Kawaguchi, 2016, Brutzkus and Globerson, 2017]

(Same example as before: x = (1, 2, −4)⊤, W⊤x = (1, 2, −4, −1)⊤, ReLU(W⊤x) = (1, 2, 0, 0)⊤ — the assumption treats these activations as "Independent!")


CNN Model in [Brutzkus and Globerson, 2017]

Recent work: guarantees for other algorithms

▶ Tensor decomposition + gradient descent converges to the ground truth for a one-hidden-layer network. [Zhong et al., 2017]
▶ Kernel methods can learn deep neural networks under an eigenvalue-decay assumption. [Goel and Klivans, 2017]

Recent work: deep linear models

Ignore the activation functions.

▶ [Saxe et al., 2013, Kawaguchi, 2016, Hardt and Ma, 2016]
▶ These models only learn a linear function.
▶ Proof idea (for a deep linear residual network):
  ▶ Loss: ‖(I + W_ℓ) ⋯ (I + W_1)x − (Ax + b)‖² (a numerical sketch of this loss follows below).
  ▶ Compute ∂f/∂W_i.
  ▶ Get a lower bound on the gradient norm: ‖∇f(W)‖²_F ≥ C (f(W) − f*).
  ▶ Show C > 0.
  ▶ So every stationary point satisfies f(W) = f*.
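As a concrete reading of the loss above, here is a minimal single-sample sketch of the deep linear residual objective; the dimensions, random data, and function name are illustrative assumptions.

```python
import numpy as np

def deep_linear_residual_loss(Ws, x, A, b):
    """||(I + W_l) ... (I + W_1) x - (A x + b)||^2 for one sample x."""
    h = x
    for W in Ws:          # apply the residual blocks in order
        h = h + W @ h     # (I + W_i) h
    r = h - (A @ x + b)
    return float(r @ r)

rng = np.random.default_rng(0)
d = 4
Ws = [0.1 * rng.standard_normal((d, d)) for _ in range(3)]   # small W_i, as in the analysis
x, A, b = rng.standard_normal(d), rng.standard_normal((d, d)), rng.standard_normal(d)
print(deep_linear_residual_loss(Ws, x, A, b))
```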


Today’s paper: convergence analysis for two-layer networks with ReLU [Li and Yuan, 2017]

Below: our thought process; it might be messy.

Our starting model

▶ A two-layer model, not deep:
    f(x, W) = ‖ReLU(W⊤x)‖₁
▶ Assume there exists a teacher network producing the labels:
    f(x, W*) = ‖ReLU(W*⊤x)‖₁
▶ Square loss (a Monte-Carlo sketch follows below):
    L(W) = E_x[(f(x, W) − f(x, W*))²]
▶ x ∼ N(0, I).
  ▶ A common assumption [Choromanska et al., 2015, Tian, 2016, Xie et al., 2017].

[Diagram: input x → W⊤x → ReLU(W⊤x) → take sum → output]
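A minimal Monte-Carlo sketch of this student/teacher setup; the dimension, sample count, and weight scales are illustrative assumptions, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x, W):
    """f(x, W) = ||ReLU(W^T x)||_1: sum of the ReLU unit outputs."""
    return np.sum(np.maximum(W.T @ x, 0.0))

def loss_estimate(W, W_star, n_samples=5000):
    """Monte-Carlo estimate of L(W) = E_x[(f(x, W) - f(x, W*))^2] with x ~ N(0, I)."""
    d = W.shape[0]
    X = rng.standard_normal((n_samples, d))
    return float(np.mean([(f(x, W) - f(x, W_star)) ** 2 for x in X]))

d = 5
W_star = 0.1 * rng.standard_normal((d, d))   # teacher weights
W = 0.1 * rng.standard_normal((d, d))        # student weights
print(loss_estimate(W, W_star))              # > 0 for a generic W != W*
```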


However.. [stuck]

[Tian, 2016] showed that, even if
▶ W is initialized symmetrically, and
▶ W* forms an orthonormal basis,
gradient descent may get stuck at saddle points.

Complicated surface, hard to analyze.

[Figure from [Tian, 2016]]


Residual network [He et al., 2016]

▶ State-of-the-art architecture (3300+ citations).
▶ Won several competitions.
▶ Very powerful after stacking multiple blocks together.
▶ Easy to train.

[Figure: a ResNet block, from [He et al., 2016]]


Adding a residual link?

▶ We modify the network
    f(x, W) = ‖ReLU(W⊤x)‖₁
  to (adding the identity)
    f(x, W) = ‖ReLU((I + W)⊤x)‖₁
▶ Same for the ground truth f(x, W*).
▶ Essentially: shift the weights by I. (A one-line code change, sketched below.)

[Diagram: input x → W⊤x → ⊕ residual link (+x) → ReLU((I + W)⊤x) → take sum → output]
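The residual modification is a one-line change to the earlier sketch; the helper name is mine, not the paper's.

```python
import numpy as np

def f_residual(x, W):
    """Residual version: f(x, W) = ||ReLU((I + W)^T x)||_1, i.e. the weights shifted by I."""
    I = np.eye(W.shape[0])
    return np.sum(np.maximum((I + W).T @ x, 0.0))
```

Replacing f with f_residual in the Monte-Carlo loss estimate above gives the model that the simulations below refer to.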



Ask Simulation: does SGD converge for this model?

Simulation says: yes.

Illustration of our key observation

[Figure: weight space showing the points O, I, I + W, and I + W*, with regions labeled "Easy for SGD" (near I), "Unknown", and "Seems hard" (near O), and an arrow labeled "Residual link".]

How to prove this?

One-point convexity.
A function f(x) is called δ-one-point strongly convex in a domain D with respect to a point x*, if for all x ∈ D,

    ⟨−∇f(x), x* − x⟩ > δ‖x* − x‖₂².

▶ A weaker condition than convexity.
▶ If the loss is one-point convex, we get closer to W* after every step, as long as the step size is small. (A numerical check of this condition appears after the illustration below.)

[Figure: 3D plot of a loss surface.]


One-point convex: an illustration

[Figure: a trajectory of iterates W₁ through W₅ approaching W*.]
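A sketch of the kind of check the simulations run: sample points in a region and test the δ-one-point-convexity inequality numerically. The toy quadratic objective and the value of δ are illustrative, not the network loss.

```python
import numpy as np

def is_one_point_convex(grad_fn, x_star, points, delta):
    """Check <-grad f(x), x* - x> > delta * ||x* - x||_2^2 at every sampled point."""
    for x in points:
        diff = x_star - x
        if np.dot(-grad_fn(x), diff) <= delta * np.dot(diff, diff):
            return False
    return True

rng = np.random.default_rng(0)
x_star = np.array([1.0, -1.0])
points = [x_star + rng.standard_normal(2) for _ in range(100)]
# f(x) = 0.5 * ||x - x*||^2 is convex, hence one-point strongly convex w.r.t. x*.
print(is_one_point_convex(lambda x: x - x_star, x_star, points, delta=0.5))  # True
```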

Ask Simulation: is it one-point convex?

Simulation says: yes.

Compute ⟨−∇L(W), W* − W⟩

−∇L(W)_j = Σ_{i=1}^d [ (π/2)(w*_i − w_i) + (π/2 − θ_{i*,j})(e_i + w*_i) − (π/2 − θ_{i,j})(e_i + w_i)
                      + (‖e_i + w*_i‖₂ sin θ_{i*,j} − ‖e_i + w_i‖₂ sin θ_{i,j})(e_j + w_j) ]

▶ e_i, w_i, w*_i are the columns of I, W, W*.
▶ θ_{i,j*}: the angle between e_i + w_i and e_j + w*_j (hard to handle).
▶ sin θ_{i,j} (hard to handle).

Taylor expansion: tedious calculation

What is θ_{i,j*}? It is the angle between w_i + e_i and w*_j + e_j. We have

cos(θ_{i,j*}) = ⟨w_i + e_i, w*_j + e_j⟩ / (‖w_i + e_i‖ ‖w*_j + e_j‖)
             = (⟨w_i, w*_j⟩ + w_{i,j} + w*_{j,i}) / (‖w_i + e_i‖ ‖w*_j + e_j‖)
             ≈ (1 − w_{i,i})(1 − w*_{j,j})(⟨w_i, w*_j⟩ + w_{i,j} + w*_{j,i})
             ≈ (1 − w_{i,i} − w*_{j,j})(⟨w_i, w*_j⟩ + w_{i,j} + w*_{j,i})
             ≈ ⟨w_i, w*_j⟩ + w_{i,j} + w*_{j,i} − w_{i,i}w_{i,j} − w_{i,i}w*_{j,i} − w*_{j,j}w_{i,j} − w*_{j,j}w*_{j,i}

Since arccos(x) ≈ π/2 − x, we have

θ_{i,j*} ≈ π/2 − (⟨w_i, w*_j⟩ + w_{i,j} + w*_{j,i}) / (‖w_i + e_i‖ ‖w*_j + e_j‖)

However, direct Taylor expansion is too loose [stuck]

▶ In order to show ⟨−∇L(W), W* − W⟩ ≥ 0, we need to assume γ ≜ max{‖W‖₂, ‖W*‖₂} ≤ O(1/d).
▶ That is a super-local region.

Ask Simulation: what is the largest γ that satisfies one-point convexity?

Simulation says: Ω(1).

[The key-observation figure again: weight space with O, I, I + W, and I + W*; "Easy for SGD" near I, "Seems hard" near O, "Unknown" in between; arrow labeled "Residual link".]


Use geometry to get tighter bounds!

▶ Denote e_i + w*_i by the vector OC, e_i + w_i by OD, and their unit vectors (e_i + w*_i)/‖e_i + w*_i‖₂ and (e_i + w_i)/‖e_i + w_i‖₂ by OA and OB respectively. Thus ‖w*_i − w_i‖₂ = ‖DC‖₂.
▶ Draw HB ∥ CD, so ‖OH‖₂ ≥ ‖OB‖₂ = ‖OA‖₂.
▶ Since △CDO ∼ △HBO, we have
    ‖CD‖₂ / ‖HB‖₂ = ‖OD‖₂ / ‖OB‖₂ = ‖OD‖₂ ≥ 1 − γ.
▶ So ‖CD‖₂ ≥ (1 − γ)‖HB‖₂.

[Figure: points O, A, B, C, D, E, F, G, H. OC: e_i + w*_i; OD: e_i + w_i; OA, OB: the corresponding unit vectors.]


Central Lemma: geometric lemma

▶ △ABO is an isosceles triangle, so ‖AG‖₂ = ‖GB‖₂.
▶ ‖HB‖₂ ≥ ‖AB‖₂ = 2‖GB‖₂.
▶ ‖GB‖₂ ≤ ‖HB‖₂/2 ≤ ‖CD‖₂ / (2(1 − γ)).

[Same figure as before; recall ‖CD‖₂ ≥ (1 − γ)‖HB‖₂.]


Central Lemma: geometric lemma

▶ △ABE ∼ △BGO.
▶ ‖AE‖₂ / ‖AB‖₂ = ‖OG‖₂ / ‖OB‖₂ = √(1 − ‖GB‖₂²) / 1.
▶ ‖AE‖₂ / ‖AB‖₂ ≥ √(1 − ‖CD‖₂² / (4(1 − γ)²)) ≥ √(1 − (γ/(1 − γ))²).

[Same figure; recall ‖GB‖₂ ≤ ‖CD‖₂ / (2(1 − γ)) and ‖w_i − w*_i‖₂ ≤ 2γ.]


Central Lemma: geometric lemma

▶ ‖CD‖₂ ≥ ‖CF‖₂.
▶ △CFO ∼ △AEO.
▶ ‖W*‖₂ ≤ γ.
▶ ‖CD‖₂ / ‖AE‖₂ ≥ ‖CF‖₂ / ‖AE‖₂ = ‖OC‖₂ / ‖OA‖₂ = ‖e_i + w*_i‖₂ ≥ 1 − γ.
▶ ‖CD‖₂ ≥ (1 − γ)‖AE‖₂ ≥ (1 − γ)√(1 − (γ/(1 − γ))²) ‖AB‖₂.
▶ So ‖AB‖₂ ≤ ‖CD‖₂ / √(1 − 2γ) = ‖w*_i − w_i‖₂ / √(1 − 2γ), since (1 − γ)√(1 − (γ/(1 − γ))²) = √(1 − 2γ).

[Same figure; recall ‖AE‖₂ / ‖AB‖₂ ≥ √(1 − (γ/(1 − γ))²).]


Central Lemma: geometric lemma

▶ |⟨OA − OB, OB⟩| = ‖BE‖₂ (here OA and OB are the unit vectors of e_i + w*_i and e_i + w_i).
▶ △ABE ∼ △GBO.
▶ ‖BE‖₂ / ‖AB‖₂ = ‖GB‖₂ / ‖BO‖₂ = ‖AB‖₂ / 2.

Hence

    |⟨OA − OB, OB⟩| = ‖AB‖₂² / 2 ≤ ‖w*_i − w_i‖₂² / (2(1 − 2γ)).

This is a very tight bound, O(γ²). (A numerical check follows below.)

[Same figure; recall ‖w_i − w*_i‖₂ ≤ 2γ and ‖AB‖₂ ≤ ‖w*_i − w_i‖₂ / √(1 − 2γ).]
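A quick numerical sanity check of the final bound, sampling w_i and w*_i with norm at most γ; the dimension, the value of γ, and the sample count are arbitrary choices of mine.

```python
import numpy as np

rng = np.random.default_rng(0)

def unit(v):
    return v / np.linalg.norm(v)

d, gamma = 5, 0.2
e = np.zeros(d); e[0] = 1.0                      # e_i: a column of the identity
worst_gap = -np.inf
for _ in range(10000):
    w = rng.standard_normal(d)
    w *= gamma * rng.random() / np.linalg.norm(w)            # ||w_i||_2 <= gamma
    w_star = rng.standard_normal(d)
    w_star *= gamma * rng.random() / np.linalg.norm(w_star)  # ||w*_i||_2 <= gamma
    lhs = abs(np.dot(unit(e + w_star) - unit(e + w), unit(e + w)))
    rhs = np.dot(w_star - w, w_star - w) / (2 * (1 - 2 * gamma))
    worst_gap = max(worst_gap, lhs - rhs)
print(worst_gap <= 0)   # True: the O(gamma^2) bound holds on every sampled pair
```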


Is it enough..?

With the tight bounds, we get: if ‖W₀‖₂, ‖W*‖₂ ≤ γ = Ω(1), then

    ⟨−∇L(W), W* − W⟩ > (0.084 − (1 + γ)g / (2(1 − 2γ))) ‖W* − W‖²_F

▶ g is a potential function: g ≜ Σ_{i=1}^d (‖e_i + w*_i‖₂ − ‖e_i + w_i‖₂).
▶ d is the input dimension.
▶ How does g affect one-point convexity (OPC)?


Ask Simulation: what if g is large? [stuck]

Simulation: [plot omitted — when g is large, the one-point convexity condition can fail.]


Ask Simulation: will g always decrease with SGD?

Simulation: Yes (for lots of instances)

g controls the dynamics!

[Figure: two panels, Phase I (PI) and Phase II (PII).]

The actual dynamics

[Figure: a trajectory W₁ → W₆ → W₁₀ relative to W*.]

Phase I: W₁ → W₆, W may go in the wrong direction. Phase II: W₆ → W₁₀, W gets closer to W* in every step by one-point convexity.

Two-phase framework

▶ Phase I: g decreases to a small value.
  ▶ Technique: analyze the dynamics of SGD.
▶ Phase II: the one-point convex region.
  ▶ Get closer to W* after every step.
  ▶ g stays small.
  ▶ Technique: compute the inner product.

Phase I: g keeps decreasing

▶ g ≜ Σ_{i=1}^d (‖e_i + w*_i‖₂ − ‖e_i + w_i‖₂).

What is Δg here?

    Δg ≈ Σ_i ⟨−Δw_i, e_i + w_i⟩ = ⟨η∇L(W), I + W⟩ ≈ ⟨η∇L(W), I⟩ = η Tr(∇L(W))

What is Tr(∇L(W))?

    Tr(∇L(W)) ≈ O(Tr(W* − W)) + O(Tr((W* − W)uu⊤)) + d·g

[Figure: e_i, the update Δw_i of w_i, and the resulting change Δ‖e_i + w_i‖₂.]


Phase I: g keeps decreasing

Observation:

    Tr(W* − W) = Σ_{i=1}^d (1 + w*_{i,i} − 1 − w_{i,i}) ≈ Σ_{i=1}^d (‖e_i + w*_i‖₂ − ‖e_i + w_i‖₂) = g

But Tr((W* − W)uu⊤) is hard to bound. Therefore, we consider the joint update rule of s ≜ (W* − W)u and g. (The recursion below is for illustration only; a numerical check follows after this slide.)

    ‖s_{t+1}‖₂ ≈ 0.9‖s_t‖₂ + 10η|g_t|
    |g_{t+1}| ≈ 0.9|g_t| + 10η‖s_t‖₂

When 10η < 0.05, ‖s_{t+1}‖₂ + |g_{t+1}| ≤ 0.95(‖s_t‖₂ + |g_t|). So |g_t| becomes very small.
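A minimal check of the illustrative coupled recursion above; the constants 0.9 and 10 are copied from the slide, and the step size is my choice so that 10η < 0.05.

```python
eta = 0.004                 # 10 * eta = 0.04 < 0.05
s, g = 1.0, 1.0             # stand-ins for ||s_0||_2 and |g_0|
total0 = s + g
for t in range(200):
    s, g = 0.9 * s + 10 * eta * g, 0.9 * g + 10 * eta * s
# The sum contracts by a factor (0.9 + 10 * eta) <= 0.95 per step.
print(s + g <= 0.95 ** 200 * total0 + 1e-12)   # True
```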

Main result

Main Theorem (informal).
If the input is drawn from a Gaussian distribution, ‖W₀‖₂, ‖W*‖₂ ≤ γ (a constant), and the step size is small, then mini-batch SGD started from W₀ reaches W* after a polynomial number of steps, in two phases.

This matches standard O(1/√d) initialization schemes (d is the input dimension); a rough numerical check follows below.
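A rough numerical reading of the remark above, under my assumption that "matches O(1/√d) initialization" means: entries of scale 1/√d give a weight matrix whose spectral norm stays bounded by a constant as d grows.

```python
import numpy as np

rng = np.random.default_rng(0)
for d in (64, 256, 1024):
    W0 = rng.standard_normal((d, d)) / np.sqrt(d)   # entries of scale O(1/sqrt(d))
    print(d, round(np.linalg.norm(W0, 2), 3))       # spectral norm stays O(1) as d grows
```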

Open questions

▶ Multiple layers?
▶ Other input distributions?
▶ Convolutional networks?
▶ A residual link that skips two layers?
▶ Identify different potential functions for other non-convex problems?

References

Andoni, A., Panigrahy, R., Valiant, G., and Zhang, L. (2014). Learning polynomials with neural networks. In ICML, pages 1908–1916.

Arora, S., Bhaskara, A., Ge, R., and Ma, T. (2014). Provable bounds for learning some deep representations. In Proceedings of the 31st International Conference on Machine Learning, ICML 2014, Beijing, China, 21-26 June 2014, pages 584–592.

Barron, A. R. (1993). Universal approximation bounds for superpositions of a sigmoidal function. IEEE Trans. Information Theory, 39(3):930–945.

Brutzkus, A. and Globerson, A. (2017). Globally optimal gradient descent for a ConvNet with Gaussian inputs. In ICML 2017.

Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. (2016). Entropy-SGD: Biasing gradient descent into wide valleys. ArXiv e-prints.

Choromanska, A., Henaff, M., Mathieu, M., Arous, G. B., and LeCun, Y. (2015). The loss surfaces of multilayer networks. In AISTATS.

Cybenko, G. (1992). Approximation by superpositions of a sigmoidal function. MCSS, 5(4):455.

Eldan, R. and Shamir, O. (2015). The power of depth for feedforward neural networks. ArXiv e-prints.

Goel, S., Kanade, V., Klivans, A. R., and Thaler, J. (2016). Reliably learning the ReLU in polynomial time. CoRR, abs/1611.10258.

Goel, S. and Klivans, A. (2017). Eigenvalue decay implies polynomial-time learnability for neural networks. In NIPS 2017.

Hardt, M. and Ma, T. (2016). Identity matters in deep learning. CoRR, abs/1611.04231.

He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep residual learning for image recognition. In CVPR, pages 770–778.

Hochreiter, S. and Schmidhuber, J. (1995). Simplifying neural nets by discovering flat minima. In Advances in Neural Information Processing Systems 7, pages 529–536. MIT Press.

Hornik, K., Stinchcombe, M. B., and White, H. (1989). Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366.

Janzamin, M., Sedghi, H., and Anandkumar, A. (2015). Beating the perils of non-convexity: Guaranteed training of neural networks using tensor methods. arXiv preprint arXiv:1506.08473.

Kawaguchi, K. (2016). Deep learning without poor local minima. In NIPS, pages 586–594.

Lee, H., Ge, R., Risteski, A., Ma, T., and Arora, S. (2017). On the ability of neural nets to express distributions. ArXiv e-prints.

Li, Y. and Yuan, Y. (2017). Convergence analysis of two-layer neural networks with ReLU activation. In NIPS 2017.

Livni, R., Shalev-Shwartz, S., and Shamir, O. (2014). On the computational efficiency of training neural networks. In NIPS, pages 855–863.

Safran, I. and Shamir, O. (2016). Depth-width tradeoffs in approximating natural functions with neural networks. ArXiv e-prints.

Saxe, A. M., McClelland, J. L., and Ganguli, S. (2013). Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. CoRR, abs/1312.6120.

Sedghi, H. and Anandkumar, A. (2015). Provable methods for training neural networks with sparse connectivity. ICLR.

Shamir, O. (2016). Distribution-specific hardness of learning neural networks. CoRR, abs/1609.01037.

Shirish Keskar, N., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. (2016). On large-batch training for deep learning: Generalization gap and sharp minima. ArXiv e-prints.

Šíma, J. (2002). Training a single sigmoidal neuron is hard. Neural Computation, 14(11):2709–2728.

Tian, Y. (2016). Symmetry-breaking convergence analysis of certain two-layered neural networks with ReLU nonlinearity. Submitted to ICLR 2017.

Xie, B., Liang, Y., and Song, L. (2017). Diversity leads to generalization in neural networks. In AISTATS.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. (2016). Understanding deep learning requires rethinking generalization. ArXiv e-prints.

Zhang, Y., Lee, J. D., Wainwright, M. J., and Jordan, M. I. (2015). Learning halfspaces and neural networks with random initialization. CoRR, abs/1511.07948.

Zhong, K., Song, Z., Jain, P., Bartlett, P. L., and Dhillon, I. S. (2017). Recovery guarantees for one-hidden-layer neural networks. In ICML 2017.
