580.691 Learning Theory Reza Shadmehr The loss function, the normal equation, cross validation, online learning, and the LMS algorithm

Page 1: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

580.691 Learning Theory

Reza Shadmehr

The loss function, the normal equation,

cross validation, online learning, and the LMS algorithm

Page 2: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Review of the linear classification problem

• Hypothesis class: we assume that what we are about to approximate is a function that belongs to some space of functions F. We don't know the true function f, but we hypothesize that it too belongs to F:

$$f \in F, \qquad f: X \to Y, \qquad \mathbf{x} \in X, \quad y \in Y$$

$$\hat{f}(\mathbf{x}) = \mathrm{sign}\left(\mathbf{w}^T \mathbf{x}\right) = \mathrm{sign}\left(w_1 x_1 + \cdots + w_d x_d\right)$$

Hypothesis: $\hat{f} \in F$

Page 3: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

• Estimation: we are given a training set of examples and labels and using some adaptation algorithm, we find

$$\left(\mathbf{x}^{(1)}, y^{(1)}\right), \left(\mathbf{x}^{(2)}, y^{(2)}\right), \ldots, \left(\mathbf{x}^{(n)}, y^{(n)}\right)$$

$$\hat{y}^{(i)} = \hat{f}\left(\mathbf{x}^{(i)}; \mathbf{w}\right) = \mathrm{sign}\left(\mathbf{w}^T \mathbf{x}^{(i)}\right) = \mathrm{sign}\left(w_1 x_1^{(i)} + \cdots + w_d x_d^{(i)}\right), \qquad \hat{f} \in F$$

• Evaluation: we measure how well our estimate generalizes to novel examples.

$$\hat{y}_{new} = \hat{f}\left(\mathbf{x}_{new}\right) \overset{?}{=} y_{new}$$

• Whenever our estimate was wrong, change the weight for each “expert”:

$$\mathbf{w}^{(i+1)} = \mathbf{w}^{(i)} + y^{(i)}\mathbf{x}^{(i)} \qquad \text{whenever} \qquad y^{(i)} \neq \mathrm{sign}\left(\mathbf{w}^{(i)T}\mathbf{x}^{(i)}\right)$$

[Figure: the weights plotted across trials; axes are trial number and pixel number.]
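The mistake-driven rule above can be sketched in a few lines of NumPy; this is an illustrative implementation, not code from the slides (the function name, the multi-pass loop, and the zero initialization are my own choices):

```python
import numpy as np

def train_linear_classifier(X, y, n_passes=10):
    """Mistake-driven updates: whenever sign(w.x) is wrong, add y * x to w.
    X is an (n, d) array of examples; y holds labels in {-1, +1}."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(n_passes):
        for i in range(n):
            if y[i] != np.sign(w @ X[i]):   # our estimate was wrong on example i
                w = w + y[i] * X[i]         # change the weight for each "expert"
    return w
```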

Page 4: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Loss function

• The loss function provides a cost for being wrong. The objective of adaptation is to minimize the loss function. In the case of the image labeling problem, we might have:

$$\mathrm{Loss}\left(y, \hat{y}\right) = \begin{cases} 0, & y = \hat{y} \\ 1, & y \neq \hat{y} \end{cases}$$

• We might want to minimize the loss over the training set:

$$\frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{y}^{(i)}\right) = \frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{f}\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right)$$

• This is a function of the parameters w and we can minimize it directly.

Find $\mathbf{w}$ that minimizes the empirical loss function:

$$\frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{f}\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right)$$
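As a small illustration, the empirical 0-1 loss of the sign classifier can be computed directly (a sketch; names are illustrative):

```python
import numpy as np

def empirical_01_loss(w, X, y):
    """Fraction of training examples on which sign(w.x) disagrees with the label."""
    y_hat = np.sign(X @ w)
    return np.mean(y_hat != y)
```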

Page 5: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

The training loss based on a few sampled examples and labels serves as a proxy for the test performance measured over the whole population.

Issues about minimization of the loss function

• Why should we minimize the loss over the training set when we are actually interested in minimizing the loss over the test set?

• We assume that each training and test example-label pair (x, y) is drawn independently and at random from the same but unknown population of examples and labels.

• We represent this population as a joint probability distribution p(x, y), so that each training/test example is a sample from this distribution:

$$\left(\mathbf{x}^{(i)}, y^{(i)}\right) \sim p\left(\mathbf{x}, y\right)$$

$$\frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{f}\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right) \;\approx\; E_{(\mathbf{x}, y)\sim p}\left[\mathrm{Loss}\left(y, \hat{f}\left(\mathbf{x}; \mathbf{w}\right)\right)\right]$$

Empirical loss (training set only) $\approx$ expected loss over all the data (training and test together).

Page 6: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Regression

• The goal is to make quantitative (real valued) predictions on the basis of a (vector of) features or attributes.

• Example: years to onset of Huntington's disease in genetically at risk individuals.

Current age   CAG repeats   Mother HD?   Father HD?   HD onset
    37            43            1            0           55
    33            51            0            1           37
    39            49            1            1           43

We need to:
– specify the hypothesized class of functions (e.g., linear)
– select how to measure prediction loss (the loss function)
– solve the resulting minimization problem

Page 7: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Regression: Hypothesized class and the loss function

Univariate regression: $f: R \to R, \qquad f(x; \mathbf{w}) = w_0 + w_1 x$

Multivariate regression: $f: R^d \to R, \qquad f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \cdots + w_d x_d$

Parameters we need to find: $\mathbf{w} = \left[w_0, w_1, \ldots, w_d\right]^T$

Loss function (squared error): $\mathrm{Loss}\left(y, \hat{y}\right) = \left(y - \hat{y}\right)^2$

Empirical loss (mean squared error):

$$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{y}^{(i)}\right) = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - f\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right)^2$$

Page 8: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Regression: estimation of parameters

• We have to minimize the empirical loss function:

$$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - f\left(x^{(i)}; \mathbf{w}\right)\right)^2 = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2$$

• This function is quadratic in terms of w. It has a minimum at some point in the w space. We find the w that minimizes the loss function by finding the conditions where the derivative of the loss function is zero.

$$\frac{dJ_n(\mathbf{w})}{dw_0} = 0, \qquad \frac{dJ_n(\mathbf{w})}{dw_1} = 0$$

Page 9: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Optimality conditions: Finding conditions that minimize the empirical loss function

$$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2$$

$$\frac{dJ_n(\mathbf{w})}{dw_1} = \frac{1}{n}\sum_{i=1}^{n} \frac{d}{dw_1}\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2 = \frac{1}{n}\sum_{i=1}^{n} 2\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)\left(-x^{(i)}\right) = 0$$

$$\frac{dJ_n(\mathbf{w})}{dw_0} = \frac{1}{n}\sum_{i=1}^{n} 2\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)\left(-1\right) = 0$$

Chain rule reminder: $p = f(u) = u^2$, so $\dfrac{dp}{dw} = \dfrac{dp}{du}\dfrac{du}{dw} = 2u\,\dfrac{du}{dw}$

Page 10: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Expected behavior of the model errors (residuals)

Error in the prediction (model residual): $\varepsilon^{(i)} = y^{(i)} - \hat{y}^{(i)} = y^{(i)} - \left(w_0 + w_1 x^{(i)}\right)$

$$\frac{dJ_n(\mathbf{w})}{dw_1} = -\frac{2}{n}\sum_{i=1}^{n} \left(y^{(i)} - \hat{y}^{(i)}\right)x^{(i)} = 0 \quad\Rightarrow\quad \sum_{i=1}^{n} \varepsilon^{(i)} x^{(i)} = 0$$

$$\frac{dJ_n(\mathbf{w})}{dw_0} = -\frac{2}{n}\sum_{i=1}^{n} \left(y^{(i)} - \hat{y}^{(i)}\right) = 0 \quad\Rightarrow\quad \sum_{i=1}^{n} \varepsilon^{(i)} = 0$$

• The prediction error should be mean zero and not contain any linear trends, i.e., be uncorrelated with any linear function of the inputs.


But there may exist some non-linear function of inputs that can account for the residuals.

$$\sum_{i=1}^{n} \left(k_0 + k_1 x^{(i)}\right)\varepsilon^{(i)} = k_0 \sum_{i=1}^{n} \varepsilon^{(i)} + k_1 \sum_{i=1}^{n} x^{(i)}\varepsilon^{(i)} = 0$$
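Solving the two optimality conditions in the univariate case gives the familiar closed form w1 = Cov(x, y)/Var(x) and w0 = mean(y) - w1 mean(x); the sketch below checks numerically that the resulting residuals are mean zero and uncorrelated with x (the simulated data are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1.5, 1.5, size=50)
y = 2.0 + 3.0 * x + rng.normal(scale=0.1, size=50)

w1 = np.cov(x, y, bias=True)[0, 1] / np.var(x)   # from sum(eps * x) = 0
w0 = y.mean() - w1 * x.mean()                    # from sum(eps) = 0

eps = y - (w0 + w1 * x)                          # model residuals
print(eps.mean(), np.mean(eps * x))              # both approximately zero
```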

Page 11: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Loss function: matrix notation

$$\hat{\mathbf{y}} = X\mathbf{w}, \qquad \mathbf{y} = \begin{bmatrix} y^{(1)} \\ \vdots \\ y^{(n)} \end{bmatrix}, \qquad X = \begin{bmatrix} 1 & x^{(1)} \\ \vdots & \vdots \\ 1 & x^{(n)} \end{bmatrix}, \qquad \mathbf{w} = \begin{bmatrix} w_0 \\ w_1 \end{bmatrix}$$

$$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2 = \frac{1}{n}\left\|\mathbf{y} - X\mathbf{w}\right\|^2 = \frac{1}{n}\left(\mathbf{y} - X\mathbf{w}\right)^T\left(\mathbf{y} - X\mathbf{w}\right)$$

"L2" norm: $\left\|\begin{bmatrix} a \\ b \\ c \end{bmatrix}\right\|^2 = a^2 + b^2 + c^2$

Page 12: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Optimality condition: minimize mean squared error

$$J_n(\mathbf{w}) = \frac{1}{n}\left(\mathbf{y} - X\mathbf{w}\right)^T\left(\mathbf{y} - X\mathbf{w}\right)$$

$$\frac{dJ_n}{d\mathbf{w}} = \frac{1}{n}\frac{d}{d\mathbf{w}}\left(\mathbf{y} - X\mathbf{w}\right)^T\left(\mathbf{y} - X\mathbf{w}\right) = 0$$

$$\frac{1}{n}\frac{d}{d\mathbf{w}}\left(\mathbf{y}^T\mathbf{y} - \mathbf{y}^T X\mathbf{w} - \mathbf{w}^T X^T\mathbf{y} + \mathbf{w}^T X^T X\mathbf{w}\right) = \frac{1}{n}\left(-2X^T\mathbf{y} + 2X^T X\mathbf{w}\right) = 0$$

$$X^T X\mathbf{w} = X^T\mathbf{y}$$

$$\mathbf{w} = \left(X^T X\right)^{-1}X^T\mathbf{y} \qquad \text{the "normal" equation}$$
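A minimal NumPy sketch of solving the normal equation, using a linear solve rather than an explicit matrix inverse (the data and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(-1.5, 1.5, size=30)
y = 1.0 + 2.5 * x + rng.normal(scale=0.2, size=30)

X = np.column_stack([np.ones_like(x), x])   # design matrix with a column of ones
w = np.linalg.solve(X.T @ X, X.T @ y)       # solves  X^T X w = X^T y
print(w)                                    # approximately [1.0, 2.5]
```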

Page 13: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

The pseudo-inverse

Known: $X$, $\mathbf{y}$; want: $\mathbf{w}$ such that $\mathbf{y} \approx X\mathbf{w}$.

Find $\hat{\mathbf{w}}$ so that it minimizes the sum of squared errors: $\left(\mathbf{y} - X\hat{\mathbf{w}}\right)^T\left(\mathbf{y} - X\hat{\mathbf{w}}\right)$

Pseudo-inverse: $X^* = \left(X^T X\right)^{-1}X^T, \qquad X^* X = \left(X^T X\right)^{-1}X^T X = I$

$$\hat{\mathbf{w}} = X^*\mathbf{y} = \left(X^T X\right)^{-1}X^T\mathbf{y}$$
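NumPy's pinv function computes a Moore-Penrose pseudo-inverse; when X^T X is invertible it coincides with (X^T X)^{-1} X^T, so the two estimates below should agree (an illustrative check with made-up numbers):

```python
import numpy as np

X = np.array([[1.0, 0.0],
              [1.0, 1.0],
              [1.0, 2.0]])                        # design matrix [1, x]
y = np.array([1.0, 3.1, 4.9])

w_pinv = np.linalg.pinv(X) @ y                    # pseudo-inverse solution
w_normal = np.linalg.solve(X.T @ X, X.T @ y)      # normal-equation solution
print(np.allclose(w_pinv, w_normal))              # True
```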

Page 14: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Review of regression

Univariate regression: $f: R \to R, \qquad f(x; \mathbf{w}) = w_0 + w_1 x$

Multivariate regression: $f: R^d \to R, \qquad f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \cdots + w_d x_d$

Parameters we need to find: $\mathbf{w} = \left[w_0, w_1, \ldots, w_d\right]^T$

Loss function (squared error): $\mathrm{Loss}\left(y, \hat{y}\right) = \left(y - \hat{y}\right)^2$

Empirical loss:

$$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{y}^{(i)}\right) = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - f\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right)^2 = \frac{1}{n}\left(\mathbf{y} - X\mathbf{w}\right)^T\left(\mathbf{y} - X\mathbf{w}\right) = \frac{1}{n}\boldsymbol{\varepsilon}^T\boldsymbol{\varepsilon}$$

$$\frac{dJ_n}{d\mathbf{w}} = -\frac{2}{n}X^T\left(\mathbf{y} - X\mathbf{w}\right) = 0 \quad\Rightarrow\quad X^T X\mathbf{w} = X^T\mathbf{y} \quad\Rightarrow\quad \mathbf{w} = \left(X^T X\right)^{-1}X^T\mathbf{y}$$

Page 15: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Regression with polynomials

• univariate regression with m-th order polynomials:

$$f: R \to R, \qquad f(x; \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_m x^m$$

$$\mathbf{w} = \left(X^T X\right)^{-1}X^T\mathbf{y}$$

$$\mathbf{y} = \begin{bmatrix} y^{(1)} \\ y^{(2)} \\ \vdots \\ y^{(n)} \end{bmatrix}, \qquad X = \begin{bmatrix} 1 & x^{(1)} & \left(x^{(1)}\right)^2 & \cdots & \left(x^{(1)}\right)^m \\ 1 & x^{(2)} & \left(x^{(2)}\right)^2 & \cdots & \left(x^{(2)}\right)^m \\ \vdots & & & & \vdots \\ 1 & x^{(n)} & \left(x^{(n)}\right)^2 & \cdots & \left(x^{(n)}\right)^m \end{bmatrix}, \qquad \mathbf{w} = \begin{bmatrix} w_0 \\ w_1 \\ \vdots \\ w_m \end{bmatrix}$$
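A sketch of the polynomial fit via the normal equation, building the design matrix above as a Vandermonde matrix (the function names are my own):

```python
import numpy as np

def fit_polynomial(x, y, m):
    """Least-squares fit of an m-th order polynomial using the normal equation."""
    X = np.vander(x, N=m + 1, increasing=True)   # columns: 1, x, x^2, ..., x^m
    return np.linalg.solve(X.T @ X, X.T @ y)

def predict_polynomial(w, x):
    """Evaluate the fitted polynomial at the points in x."""
    return np.vander(x, N=len(w), increasing=True) @ w
```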

Page 16: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

[Figure: a data set fit with polynomials of increasing order over $x \in [-1.5, 1.5]$; panels show $f(x; \mathbf{w}) = w_0 + w_1 x$, $\;f(x; \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + w_3 x^3$, $\;f(x; \mathbf{w}) = w_0 + w_1 x + \cdots + w_5 x^5$, and $\;f(x; \mathbf{w}) = w_0 + w_1 x + \cdots + w_{10} x^{10}$.]

Regression with polynomials: fit improves with increased order.

Page 17: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Over-fitting

• We want to fit the training set, but as model complexity increases, we run the risk of over-fitting.

$$\frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - \hat{f}\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right)^2 \to 0$$

[Figure: a high-order polynomial fit to the "Train set" and to the "Leave out" set, in which one data point has been removed; x ranges from -1.5 to 1.5.]

When the model order is over-fitting, leaving a single data point out of the training set can drastically change the fit.

Page 18: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Cross validation

• We want to fit the training set, but we want to also generalize correctly. To measure generalization, we leave out a data point (named the test point), fit the data, and then measure error on the test point. The average error over all possible test points is the cross validation error.

$$CV = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - f\left(\mathbf{x}^{(i)}; \mathbf{w}^{(-i)}\right)\right)^2$$

$\mathbf{w}^{(-i)}$: weights estimated from a training set that does not include the i-th data point.
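A minimal sketch of leave-one-out cross validation for polynomial models, here using NumPy's polyfit/polyval for the inner fit (the function name and loop structure are illustrative):

```python
import numpy as np

def loo_cv_error(x, y, m):
    """Leave-one-out cross-validation error for an m-th order polynomial fit."""
    n = len(x)
    errors = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i                  # training set without point i
        coeffs = np.polyfit(x[keep], y[keep], m)  # fit on the remaining points
        errors[i] = (y[i] - np.polyval(coeffs, x[i])) ** 2
    return errors.mean()
```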

Page 19: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Cross validation

• Cross validation error will often increase when the model structure is over-fitting the data.

[Figure: mean-squared error over the training set, and cross-validation error, each plotted against model order.]

(actual data was generated with a 2nd order polynomial process)

Page 20: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Batch vs. online learning algorithms

• In “batch” learning, we don’t have to make any predictions until we see all of the data. At that point, we make a model to fit all the data.

• In “online” learning, data points are given to us one at a time. We use each example pair to update our model.

$$D = \left\{\left(\mathbf{x}^{(1)}, y^{(1)}\right), \left(\mathbf{x}^{(2)}, y^{(2)}\right), \ldots, \left(\mathbf{x}^{(n)}, y^{(n)}\right)\right\}$$

Batch: $\hat{\mathbf{y}} = X\mathbf{w}, \qquad \mathbf{w} = \left(X^T X\right)^{-1}X^T\mathbf{y}$

Online:

$$\hat{y}^{(1)} = w_0^{(1)} + w_1^{(1)} x_1^{(1)} + \cdots + w_m^{(1)} x_m^{(1)} = \mathbf{w}^{(1)T}\mathbf{x}^{(1)}$$

We are given an x and with our current model we predict a y.

$$y^{(1)} - \hat{y}^{(1)}$$

The teacher tells us our error.

$$\mathbf{w}^{(2)} \leftarrow \mathbf{w}^{(1)}$$

We modify our model.

Page 21: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Online learning: the LMS algorithm

• Assume we have the model:

$$\hat{y}^{(n)} = \mathbf{w}^{(n)T}\mathbf{x}^{(n)} = w_1^{(n)} x_1^{(n)} + w_2^{(n)} x_2^{(n)} + \cdots$$

[Figure: the vectors $\mathbf{x}^{(n)}$ and $\mathbf{w}^{(n)}$ drawn in the plane, with the lengths $\hat{y}^{(n)} / \|\mathbf{x}^{(n)}\|$ and $y^{(n)} / \|\mathbf{x}^{(n)}\|$ marked along the direction of $\mathbf{x}^{(n)}$.]

When we project w onto x, we get a scalar p:

$$\mathbf{w}^{(n)T}\mathbf{x}^{(n)} = \left\|\mathbf{w}^{(n)}\right\|\left\|\mathbf{x}^{(n)}\right\|\cos\theta$$

$$p = \frac{\mathbf{w}^{(n)T}\mathbf{x}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|} = \left\|\mathbf{w}^{(n)}\right\|\cos\theta = \frac{\hat{y}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|}$$

What we want is to change w so that when we project onto x we get:

$$\frac{y^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|}$$

Anywhere along the dashed line is the solution we’re looking for.

Page 22: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

The LMS algorithm

[Figure: $\mathbf{w}^{(n)}$ is changed along the direction of $\mathbf{x}^{(n)}$ to $\mathbf{w}^{(n+1)}$, so that its projection onto $\mathbf{x}^{(n)}$ moves from $\hat{y}^{(n)} / \|\mathbf{x}^{(n)}\|$ to $y^{(n)} / \|\mathbf{x}^{(n)}\|$.]

The correction we need along the unit vector $\mathbf{x}^{(n)} / \left\|\mathbf{x}^{(n)}\right\|$ has magnitude $\dfrac{y^{(n)} - \hat{y}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|}$, so

$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \frac{y^{(n)} - \hat{y}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|}\,\frac{\mathbf{x}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|} = \mathbf{w}^{(n)} + \frac{y^{(n)} - \hat{y}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|^{2}}\,\mathbf{x}^{(n)}$$

where $1 / \left\|\mathbf{x}^{(n)}\right\|^{2}$ plays the role of the "step size".

w changes along a vector parallel to the input x in that trial with a magnitude proportional to the prediction error in that trial.

With this step size, we change w to completely account for the error in that trial.
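A minimal NumPy sketch of the LMS update described above, run over a toy stream of examples (the helper name, the learning rate, and the simulated data are my own choices):

```python
import numpy as np

def lms_step(w, x, y, eta=None):
    """One LMS step: move w along x in proportion to the prediction error.
    With eta = 1 / ||x||^2 the error on this trial is corrected completely."""
    if eta is None:
        eta = 1.0 / (x @ x)
    return w + eta * (y - w @ x) * x

rng = np.random.default_rng(2)
w_true = np.array([1.0, -2.0])
w = np.zeros(2)
for _ in range(200):                 # online learning over a stream of (x, y) pairs
    x = rng.normal(size=2)
    y = w_true @ x
    w = lms_step(w, x, y, eta=0.1)
print(w)                             # close to w_true
```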

Page 23: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

LMS algorithm attempts to minimize a squared error loss function by approximating the gradient of the loss function

Average error over all data points:

$$J = \frac{1}{2N}\sum_{n=1}^{N} \left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)^2$$

$$\frac{dJ}{dw_i} = -\frac{1}{N}\sum_{n=1}^{N} \left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)x_i^{(n)}, \qquad \frac{dJ}{d\mathbf{w}} = -\frac{1}{N}\sum_{n=1}^{N} \left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$$

Steepest descent algorithm:

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\,\frac{1}{N}\sum_{n=1}^{N} \left(y^{(n)} - \mathbf{w}^{(t)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$$

LMS: local error as a rough estimate of average error:

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\,\frac{y^{(t)} - \mathbf{w}^{(t)T}\mathbf{x}^{(t)}}{\mathbf{x}^{(t)T}\mathbf{x}^{(t)}}\,\mathbf{x}^{(t)}$$
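For comparison, a sketch of the batch steepest descent update on this loss, checked against the normal-equation solution (the step size, step count, and data are illustrative):

```python
import numpy as np

def steepest_descent(X, y, eta, n_steps=500):
    """Batch gradient descent on J = 1/(2N) * sum_n (y_n - w.x_n)^2."""
    N, d = X.shape
    w = np.zeros(d)
    for _ in range(n_steps):
        w = w + eta * (X.T @ (y - X @ w)) / N   # step along the negative gradient
    return w

rng = np.random.default_rng(3)
X = np.column_stack([np.ones(40), rng.uniform(-1, 1, size=40)])
y = X @ np.array([0.5, 2.0]) + rng.normal(scale=0.1, size=40)
print(steepest_descent(X, y, eta=0.5))
print(np.linalg.solve(X.T @ X, X.T @ y))        # the two should nearly agree
```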

Page 24: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Convergence of LMS-algorithm

$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\left(y^{(n)} - \mathbf{w}^{(n)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$$

[Figure: left, iterating over two data points $\mathbf{x}^{(n)}$ and $\mathbf{x}^{(n+1)}$; right, iterating over three data points $\mathbf{x}^{(n)}$, $\mathbf{x}^{(n+1)}$, $\mathbf{x}^{(n+2)}$, with the equilibrium point marked.]

With 3 data points, the solution will not move to a single point and stay put. It converges to a small region of the parameter space but will bounce around, as long as $\eta > 0$.

Page 25: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

It is difficult to prove “convergence” of LMS because the weights keep bouncing around. But we can prove convergence for the steepest descent algorithm and then use the fact that LMS is a stochastic approximation to it.

Convergence of LMS-algorithm

$$\mathbf{w}^{(1)} = \mathbf{w}^{(0)} + \eta\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^{(0)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)} = \mathbf{w}^{(0)} + \eta X^T\left(\mathbf{y} - X\mathbf{w}^{(0)}\right) = \left(I - \eta X^T X\right)\mathbf{w}^{(0)} + \eta X^T\mathbf{y}$$

$$\mathbf{w}^{(2)} = \left(I - \eta X^T X\right)\mathbf{w}^{(1)} + \eta X^T\mathbf{y} = \left(I - \eta X^T X\right)^2\mathbf{w}^{(0)} + \eta\left(I - \eta X^T X\right)X^T\mathbf{y} + \eta X^T\mathbf{y}$$

$$\vdots$$

$$\mathbf{w}^{(n)} = \left(I - \eta X^T X\right)^n\mathbf{w}^{(0)} + \eta\sum_{i=0}^{n-1}\left(I - \eta X^T X\right)^i X^T\mathbf{y}$$

a geometric series

Page 26: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Convergence of a geometric series of scalars

$$s_n = 1 + a + a^2 + \cdots + a^{n-1} = \sum_{i=0}^{n-1} a^i$$

$$a\,s_n = a + a^2 + \cdots + a^n = \sum_{i=1}^{n} a^i$$

$$s_n - a\,s_n = 1 - a^n \quad\Rightarrow\quad s_n = \frac{1 - a^n}{1 - a}$$

$$\lim_{n\to\infty} s_n = \frac{1}{1 - a} \quad \text{if } \left|a\right| < 1$$

Page 27: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Convergence of a geometric series of matrices

$$S_n = I + A + A^2 + \cdots + A^{n-1} = \sum_{i=0}^{n-1} A^i$$

$$A\,S_n = A + A^2 + \cdots + A^n$$

$$S_n - A\,S_n = I - A^n \quad\Rightarrow\quad S_n = \left(I - A\right)^{-1}\left(I - A^n\right)$$

$$A = Q\Lambda Q^{-1} \quad\Rightarrow\quad A^n = Q\Lambda^n Q^{-1}$$

$$\lim_{n\to\infty} A^n = 0 \quad \text{if } \left|\lambda_i\right| < 1 \text{ for all } i$$

$$\lim_{n\to\infty} S_n = \left(I - A\right)^{-1} \quad \text{if } \left|\lambda_i\right| < 1 \text{ for all } i$$

See homework for this:
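A quick numerical sanity check of the matrix result, with an A whose eigenvalues lie inside the unit circle (the matrix is arbitrary, chosen only for illustration):

```python
import numpy as np

A = np.array([[0.5, 0.2],
              [0.1, 0.3]])                              # eigenvalues ~0.57 and ~0.23
S = sum(np.linalg.matrix_power(A, i) for i in range(200))
print(np.allclose(S, np.linalg.inv(np.eye(2) - A)))     # True
```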

Page 28: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Convergence of steepest descent algorithm

$$\mathbf{w}^{(n)} = \left(I - \eta X^T X\right)^n\mathbf{w}^{(0)} + \eta\sum_{i=0}^{n-1}\left(I - \eta X^T X\right)^i X^T\mathbf{y}$$

$$\left(I - \eta X^T X\right)^n \to 0 \quad \text{if } \left|1 - \eta\lambda_i\right| < 1 \text{ for all } i, \text{ where } \lambda_i \text{ is an eigenvalue of } X^T X \quad \text{(use eigenvector decomposition)}$$

$$\sum_{i=0}^{\infty}\left(I - \eta X^T X\right)^i = \left(I - \left(I - \eta X^T X\right)\right)^{-1} = \left(\eta X^T X\right)^{-1} \quad \text{if } \left|1 - \eta\lambda_i\right| < 1 \text{ for all } i$$

$X^T X$ is positive definite, therefore $\lambda_i > 0$, and

$$\left|1 - \eta\lambda_i\right| < 1 \quad\Leftrightarrow\quad 0 < \eta\lambda_i < 2 \quad\Leftrightarrow\quad 0 < \eta < \frac{2}{\lambda_i} \text{ for all } i$$

$$\mathbf{w}^{(n)} \to \eta\left(\eta X^T X\right)^{-1}X^T\mathbf{y} = \left(X^T X\right)^{-1}X^T\mathbf{y} \quad \text{if } 0 < \eta < \frac{2}{\lambda_{max}}$$

See homework for proof of this.
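An illustrative check of the step-size bound: pick η below 2/λ_max of X^T X and the batch iteration should settle on the normal-equation solution (the data and the specific η are made up for the example):

```python
import numpy as np

rng = np.random.default_rng(4)
X = np.column_stack([np.ones(50), rng.uniform(-1.5, 1.5, size=50)])
y = X @ np.array([1.0, -0.7]) + rng.normal(scale=0.05, size=50)

lam_max = np.linalg.eigvalsh(X.T @ X).max()
eta = 1.0 / lam_max                       # any 0 < eta < 2 / lam_max works

w = np.zeros(2)
for _ in range(2000):
    w = w + eta * X.T @ (y - X @ w)       # steepest-descent step (no 1/N here)
print(np.allclose(w, np.linalg.solve(X.T @ X, X.T @ y)))   # True
```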

Page 29: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

We have shown convergence of the steepest descent algorithm to the solution of the normal equations. The LMS is a stochastic approximation to steepest descent, thus it “converges” as well, but will jump around stochastically, as long as the learning rate is greater than zero. Convergence can be reached when the learning rate is systematically made smaller on each step.

We will call changes of the learning rate “adaptive learning” and will see a principled approach to this problem when we consider Bayesian approaches to learning.

Page 30: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Summary: Linear Regression

Univariate regression: $f: R \to R, \qquad f(x; \mathbf{w}) = w_0 + w_1 x$

Multivariate regression: $f: R^d \to R, \qquad f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \cdots + w_d x_d$

Parameters we need to find: $\mathbf{w} = \left[w_0, w_1, \ldots, w_d\right]^T$

Loss function (squared error): $\mathrm{Loss}\left(y, \hat{y}\right) = \left(y - \hat{y}\right)^2$

Empirical loss:

$$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \mathrm{Loss}\left(y^{(i)}, \hat{y}^{(i)}\right) = \frac{1}{n}\sum_{i=1}^{n} \left(y^{(i)} - f\left(\mathbf{x}^{(i)}; \mathbf{w}\right)\right)^2 = \frac{1}{n}\left(\mathbf{y} - X\mathbf{w}\right)^T\left(\mathbf{y} - X\mathbf{w}\right) = \frac{1}{n}\boldsymbol{\varepsilon}^T\boldsymbol{\varepsilon}$$

$$\frac{dJ_n}{d\mathbf{w}} = -\frac{2}{n}X^T\left(\mathbf{y} - X\mathbf{w}\right) = 0 \quad\Rightarrow\quad X^T X\mathbf{w} = X^T\mathbf{y} \quad\Rightarrow\quad \mathbf{w} = \left(X^T X\right)^{-1}X^T\mathbf{y}$$

Page 31: 580.691  Learning Theory Reza Shadmehr The loss function, the normal equation,

Summary: Iterative learning

• Increased model complexity reduces error over the training data but can increase the leave-one-out cross validation error. We want a model that fits the trained data and generalizes correctly.

• LMS algorithm: w changes along a vector parallel to the input x in that trial with a magnitude proportional to the error in that trial.

$$\hat{y}^{(n)} = \mathbf{w}^{(n)T}\mathbf{x}^{(n)}$$

$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \frac{y^{(n)} - \hat{y}^{(n)}}{\left\|\mathbf{x}^{(n)}\right\|^{2}}\,\mathbf{x}^{(n)} \qquad \text{or} \qquad \mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\left(y^{(n)} - \hat{y}^{(n)}\right)\mathbf{x}^{(n)}, \quad 0 < \eta < 2\,\lambda_{max}^{-1}\!\left(X^T X\right)$$

• Steepest descent algorithm:

$$J = \frac{1}{2N}\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)^2$$

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\,\frac{1}{N}\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^{(t)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$$