580.691 Learning Theory
Reza Shadmehr
The loss function, the normal equation,
cross validation, online learning, and the LMS algorithm
Review of the linear classification problem
• Hypothesis class: we assume that what we are about to approximate is
a function that belongs to some space of functions F. We don’t know
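the true function f, but we hypothesize that it too belongs to F:
$f: X \to Y$, with $\mathbf{x} \in X$, $y \in Y$, and $f \in F$
Hypothesis: $\hat f \in F$, where $\hat f(\mathbf{x}) = \mathrm{sign}(w_1 x_1 + \cdots + w_d x_d) = \mathrm{sign}(\mathbf{w}^T\mathbf{x})$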
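• Estimation: we are given a training set of examples and labels, and using some adaptation algorithm we find the weights $\hat{\mathbf{w}}$ that define our estimate $\hat f \in F$:
$\{(\mathbf{x}^{(1)}, y^{(1)}), (\mathbf{x}^{(2)}, y^{(2)}), \ldots, (\mathbf{x}^{(n)}, y^{(n)})\}$
$\hat y^{(i)} = \hat f(\mathbf{x}^{(i)}; \hat{\mathbf{w}}) = \mathrm{sign}(\hat w_1 x_1^{(i)} + \cdots + \hat w_d x_d^{(i)}) = \mathrm{sign}(\hat{\mathbf{w}}^T\mathbf{x}^{(i)})$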
• Evaluation: we measure how well our estimate generalizes to novel examples.
$\hat y^{(new)} = \hat f(\mathbf{x}^{(new)}) \overset{?}{=} y^{(new)}$
• Whenever our estimate was wrong, change the weight for each “expert”:
$\mathbf{w}^{(i+1)} = \mathbf{w}^{(i)} + y^{(i)}\mathbf{x}^{(i)}$ whenever $y^{(i)} \neq \mathrm{sign}(\mathbf{w}^{(i)T}\mathbf{x}^{(i)})$
[Figure: weights by trial number and pixel number]
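A minimal sketch of the mistake-driven update above in plain Python. It assumes labels in {-1, +1}, treats sign(0) as -1, and uses a hypothetical toy data set whose first feature is a constant 1 (a bias input); none of these choices come from the lecture.

```python
def sign(z):
    # Assumption: sign(0) is mapped to -1 so every output is a valid label.
    return 1 if z > 0 else -1


def predict(w, x):
    """f_hat(x) = sign(w^T x)."""
    return sign(sum(wi * xi for wi, xi in zip(w, x)))


def train_perceptron(X, y, epochs=20):
    """Mistake-driven rule: w <- w + y x whenever sign(w^T x) != y."""
    w = [0.0] * len(X[0])
    for _ in range(epochs):
        mistakes = 0
        for x, label in zip(X, y):
            if predict(w, x) != label:
                w = [wi + label * xi for wi, xi in zip(w, x)]
                mistakes += 1
        if mistakes == 0:  # every training example is now labeled correctly
            break
    return w


# Hypothetical, linearly separable toy data; x[0] = 1 acts as a bias input.
X = [[1, 2.0, 1.0], [1, 1.5, 2.0], [1, -1.0, -1.5], [1, -2.0, -0.5]]
y = [1, 1, -1, -1]
w = train_perceptron(X, y)
print(w, [predict(w, x) for x in X])  # [1.0, 2.0, 1.0] [1, 1, -1, -1]
```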
Loss function
• The loss function provides a cost for being wrong. The objective of
adaptation is to minimize the loss function. In the case of the image labeling
problem, we might have:
$Loss(y, \hat y) = \begin{cases} 0, & \hat y = y \\ 1, & \hat y \neq y \end{cases}$
• We might want to minimize the loss over the training set:
$\frac{1}{n}\sum_{i=1}^{n} Loss(y^{(i)}, \hat y^{(i)}) = \frac{1}{n}\sum_{i=1}^{n} Loss\left(y^{(i)}, \hat f(\mathbf{x}^{(i)}; \mathbf{w})\right)$
• This is a function of the parameters $\mathbf{w}$ and we can minimize it directly.
Find $\mathbf{w}$ that minimizes: $\frac{1}{n}\sum_{i=1}^{n} Loss\left(y^{(i)}, \hat f(\mathbf{x}^{(i)}; \mathbf{w})\right)$
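A sketch of the empirical 0-1 loss for a linear classifier, assuming sign(0) = -1. The data set and the helper names `zero_one_loss`, `f_sign`, and `empirical_loss` are hypothetical.

```python
def zero_one_loss(y, y_hat):
    """Loss(y, y_hat) = 0 if y_hat == y, else 1."""
    return 0 if y_hat == y else 1


def f_sign(x, w):
    """Linear classifier f(x; w) = sign(w^T x); sign(0) taken as -1 (assumption)."""
    return 1 if sum(wi * xi for wi, xi in zip(w, x)) > 0 else -1


def empirical_loss(w, X, y, f=f_sign):
    """(1/n) * sum_i Loss(y_i, f(x_i; w)) over the training set."""
    return sum(zero_one_loss(yi, f(xi, w)) for xi, yi in zip(X, y)) / len(y)


X = [[1.0, 2.0], [2.0, -1.0], [-1.0, -1.0], [-2.0, 1.0]]
y = [1, 1, -1, 1]
print(empirical_loss([1.0, 0.0], X, y))  # 0.25: only the last example is wrong
```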
Empirical loss function
The training loss based on a few sampled examples and labels serves as a proxy for the test performance measured over the whole population.
Issues about minimization of the loss function
• Why should we minimize the loss over the training set when we are
actually interested in minimizing the loss over the test set?
• We assume that each training and test example-label pair (x,y) is drawn
independently and at random from the same but unknown population of
examples and labels.
• We represent this population as a joint probability distribution p(x,y) so
that each training/test example is a sample from this distribution:
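$(\mathbf{x}^{(i)}, y^{(i)}) \sim p(\mathbf{x}, y)$
Empirical loss (training set only): $\frac{1}{n}\sum_{i=1}^{n} Loss\left(y^{(i)}, \hat f(\mathbf{x}^{(i)}; \mathbf{w})\right)$
Expected loss over all the data (training and test together): $E_{(\mathbf{x}, y) \sim p}\left[Loss\left(y, \hat f(\mathbf{x}; \mathbf{w})\right)\right]$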
Regression
• The goal is to make quantitative (real valued) predictions on the basis of
a (vector of) features or attributes.
• Example: years to onset of Huntington’s disease in genetically at-risk
individuals.
Current age, CAG repeats, Mother HD?, Father HD?, HD onset
43 37 55 1 0
51 33 37 0 1
49 39 43 1 1
We need to:
– specify the hypothesized class of functions (e.g., linear)
– select how to measure prediction loss (the loss function)
– solve the resulting minimization problem
Regression: Hypothesized class and the loss function
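Univariate regression: $f: \mathbb{R} \to \mathbb{R}$, $f(x; \mathbf{w}) = w_0 + w_1 x$
Multivariate regression: $f: \mathbb{R}^d \to \mathbb{R}$, $f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \cdots + w_d x_d$
Parameters we need to find: $\mathbf{w} = [w_0, w_1, \ldots, w_d]^T$
Loss function (squared error): $Loss(y, \hat y) = (y - \hat y)^2$
Empirical loss (mean squared error): $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} Loss(y^{(i)}, \hat y^{(i)}) = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w})\right)^2$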
Regression: estimation of parameters
• We have to minimize the empirical loss function:
$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - f(x^{(i)}; \mathbf{w})\right)^2 = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2$
• This function is quadratic in terms of $\mathbf{w}$, so it has a minimum at some point in the $\mathbf{w}$ space. We find the $\mathbf{w}$ that minimizes the loss function by setting the derivative of the loss function with respect to each parameter to zero.
Optimality conditions: Finding conditions that minimize the empirical loss function
$\frac{dJ_n}{dw_0} = 0, \quad \frac{dJ_n}{dw_1} = 0$
$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2$
$\frac{dJ_n}{dw_1} = \frac{1}{n}\sum_{i=1}^{n}\frac{d}{dw_1}\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2 = \frac{1}{n}\sum_{i=1}^{n} 2\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)\left(-x^{(i)}\right) = 0$
$\frac{dJ_n}{dw_0} = \frac{1}{n}\sum_{i=1}^{n} 2\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)(-1) = 0$
Chain rule reminder: if $p = f(u) = u^2$, then $\frac{dp}{dw} = \frac{dp}{du}\frac{du}{dw} = 2u\frac{du}{dw}$
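Solving the two optimality conditions together gives a closed form for the univariate fit: $dJ_n/dw_0 = 0$ yields $w_0 = \bar y - w_1 \bar x$, and substituting this into $dJ_n/dw_1 = 0$ yields $w_1 = \sum_i (x^{(i)} - \bar x)(y^{(i)} - \bar y) / \sum_i (x^{(i)} - \bar x)^2$. A short sketch with hypothetical data, checking that both conditions hold at the solution:

```python
def fit_line(x, y):
    """Least-squares line f(x; w) = w0 + w1 x from dJ/dw0 = 0 and dJ/dw1 = 0."""
    n = len(x)
    x_bar = sum(x) / n
    y_bar = sum(y) / n
    sxy = sum((xi - x_bar) * (yi - y_bar) for xi, yi in zip(x, y))
    sxx = sum((xi - x_bar) ** 2 for xi in x)
    w1 = sxy / sxx            # slope after eliminating w0
    w0 = y_bar - w1 * x_bar   # dJ/dw0 = 0 gives w0 = y_bar - w1 * x_bar
    return w0, w1


# Hypothetical noisy data
x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.1, 2.9, 5.2, 6.8, 9.1]
w0, w1 = fit_line(x, y)
# Both optimality conditions should hold (up to rounding):
print(sum(yi - w0 - w1 * xi for xi, yi in zip(x, y)))         # ~0
print(sum((yi - w0 - w1 * xi) * xi for xi, yi in zip(x, y)))  # ~0
```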
Expected behavior of the model errors (residuals)
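Error in the prediction (model residual): $\varepsilon^{(i)} = y^{(i)} - \hat y^{(i)} = y^{(i)} - w_0 - w_1 x^{(i)}$
$\frac{dJ_n}{dw_1} = \frac{2}{n}\sum_{i=1}^{n}\varepsilon^{(i)}\left(-x^{(i)}\right) = 0 \;\Rightarrow\; \sum_{i=1}^{n}\varepsilon^{(i)} x^{(i)} = 0$
$\frac{dJ_n}{dw_0} = \frac{2}{n}\sum_{i=1}^{n}\varepsilon^{(i)}(-1) = 0 \;\Rightarrow\; \sum_{i=1}^{n}\varepsilon^{(i)} = 0$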
• The prediction error should have zero mean and contain no linear trend, i.e., it should be uncorrelated with any linear function of the inputs.
But there may exist some non-linear function of inputs that can account for the residuals.
$\sum_{i=1}^{n}\varepsilon^{(i)}\left(k_0 + k_1 x^{(i)}\right) = k_0\sum_{i=1}^{n}\varepsilon^{(i)} + k_1\sum_{i=1}^{n}\varepsilon^{(i)} x^{(i)} = 0$
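A quick numerical check of these residual properties, using a hypothetical data set generated by the non-linear process $y = x^2$: the residuals of the best-fitting line sum to zero and are orthogonal to an arbitrary linear function $k_0 + k_1 x$ (the values of $k_0, k_1$ are arbitrary choices), yet they are still correlated with the non-linear function $x^2$.

```python
def fit_line(x, y):
    """Least-squares line w0 + w1 x (closed form from the optimality conditions)."""
    n = len(x)
    x_bar, y_bar = sum(x) / n, sum(y) / n
    w1 = (sum((a - x_bar) * (b - y_bar) for a, b in zip(x, y))
          / sum((a - x_bar) ** 2 for a in x))
    return y_bar - w1 * x_bar, w1


x = [-2.0, -1.0, 0.0, 1.0, 2.0]
y = [a ** 2 for a in x]                          # a non-linear (quadratic) process
w0, w1 = fit_line(x, y)
eps = [b - w0 - w1 * a for a, b in zip(x, y)]    # residuals epsilon^(i)

k0, k1 = 0.7, -3.2                               # an arbitrary linear function
print(sum(eps))                                         # ~0
print(sum(e * (k0 + k1 * a) for e, a in zip(eps, x)))   # ~0
print(sum(e * a ** 2 for e, a in zip(eps, x)))          # 14.0, clearly non-zero
```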
Loss function: matrix notation
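$\mathbf{y} = \begin{bmatrix} y^{(1)} \\ \vdots \\ y^{(n)} \end{bmatrix}, \quad X = \begin{bmatrix} 1 & x^{(1)} \\ \vdots & \vdots \\ 1 & x^{(n)} \end{bmatrix}, \quad \mathbf{w} = \begin{bmatrix} w_0 \\ w_1 \end{bmatrix}, \quad \hat{\mathbf{y}} = X\mathbf{w}$
$J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - w_0 - w_1 x^{(i)}\right)^2 = \frac{1}{n}\left\| \mathbf{y} - X\mathbf{w} \right\|^2 = \frac{1}{n}(\mathbf{y} - X\mathbf{w})^T(\mathbf{y} - X\mathbf{w})$
“L2” norm: $\left\| [a, b, c]^T \right\|^2 = a^2 + b^2 + c^2$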
Optimality condition: minimize mean squared error
$J_n(\mathbf{w}) = \frac{1}{n}(\mathbf{y} - X\mathbf{w})^T(\mathbf{y} - X\mathbf{w}) = \frac{1}{n}\left(\mathbf{y}^T\mathbf{y} - 2\mathbf{w}^T X^T\mathbf{y} + \mathbf{w}^T X^T X\mathbf{w}\right)$
$\frac{dJ_n}{d\mathbf{w}} = \frac{1}{n}\left(-2X^T\mathbf{y} + 2X^T X\mathbf{w}\right) = 0$
$X^T X\mathbf{w} = X^T\mathbf{y}$
$\mathbf{w} = (X^T X)^{-1}X^T\mathbf{y}$ (the “normal” equation)
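A plain-Python sketch of solving the normal equation $X^T X\mathbf{w} = X^T\mathbf{y}$ by Gaussian elimination with partial pivoting (a standard linear solver chosen here for illustration). It assumes $X$ has full column rank; the data and the helpers `solve` and `normal_equation` are hypothetical. With noisy targets, the residual vector comes out orthogonal to every column of $X$, which is the matrix form of the residual conditions above.

```python
def solve(A, b):
    """Solve the square system A w = b (Gaussian elimination, partial pivoting)."""
    n = len(A)
    M = [[float(v) for v in row] + [float(bi)] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        w[r] = (M[r][n] - sum(M[r][c] * w[c] for c in range(r + 1, n))) / M[r][r]
    return w


def normal_equation(X, y):
    """Return w solving (X^T X) w = X^T y; assumes X has full column rank."""
    d = len(X[0])
    XtX = [[sum(row[i] * row[j] for row in X) for j in range(d)] for i in range(d)]
    Xty = [sum(row[i] * yi for row, yi in zip(X, y)) for i in range(d)]
    return solve(XtX, Xty)


X = [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 3.0, 3.0], [1.0, 4.0, 2.0]]
w_true = [0.5, -1.0, 2.0]
y = [sum(wi * xi for wi, xi in zip(w_true, row)) for row in X]
print(normal_equation(X, y))  # approximately [0.5, -1.0, 2.0]

y_noisy = [yi + e for yi, e in zip(y, [0.1, -0.2, 0.05, 0.1, -0.05])]
w_hat = normal_equation(X, y_noisy)
r = [yi - sum(wi * xi for wi, xi in zip(w_hat, row)) for row, yi in zip(X, y_noisy)]
print([sum(row[j] * ri for row, ri in zip(X, r)) for j in range(3)])  # X^T r ~ 0
```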
The pseudo-inverse
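known: $X$, $\mathbf{y}^*$; want: $\hat{\mathbf{w}}$ such that $\mathbf{y}^* \approx \hat{\mathbf{y}} = X\hat{\mathbf{w}}$
Find $\hat{\mathbf{w}}$ so that it minimizes the sum of squared errors: $\hat{\mathbf{w}} = (X^T X)^{-1}X^T\mathbf{y}^*$
pseudoinverse: $X^+ = (X^T X)^{-1}X^T$, so that $X^+X = (X^T X)^{-1}X^T X = I$ and $\hat{\mathbf{w}} = X^+\mathbf{y}^*$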
Review of regression
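Univariate regression: $f: \mathbb{R} \to \mathbb{R}$, $f(x; \mathbf{w}) = w_0 + w_1 x$
Multivariate regression: $f: \mathbb{R}^d \to \mathbb{R}$, $f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \cdots + w_d x_d$
Parameters we need to find: $\mathbf{w} = [w_0, w_1, \ldots, w_d]^T$
Loss function: $Loss(y, \hat y) = (y - \hat y)^2$
Empirical loss: $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} Loss(y^{(i)}, \hat y^{(i)}) = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w})\right)^2$
$J_n(\mathbf{w}) = \frac{1}{n}(\mathbf{y} - X\mathbf{w})^T(\mathbf{y} - X\mathbf{w}) = \frac{1}{n}\boldsymbol{\varepsilon}^T\boldsymbol{\varepsilon}$
$\frac{dJ_n}{d\mathbf{w}} = -\frac{2}{n}\left(X^T\mathbf{y} - X^T X\mathbf{w}\right) = 0$
$X^T X\mathbf{w} = X^T\mathbf{y}$
$\mathbf{w} = (X^T X)^{-1}X^T\mathbf{y}$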
Regression with polynomials
• univariate regression with m-th order polynomials:
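$f: \mathbb{R} \to \mathbb{R}$, $f(x; \mathbf{w}) = w_0 + w_1 x + w_2 x^2 + \cdots + w_m x^m$
$\mathbf{w} = (X^T X)^{-1}X^T\mathbf{y}$, where
$X = \begin{bmatrix} 1 & x^{(1)} & (x^{(1)})^2 & \cdots & (x^{(1)})^m \\ 1 & x^{(2)} & (x^{(2)})^2 & \cdots & (x^{(2)})^m \\ \vdots & \vdots & \vdots & & \vdots \\ 1 & x^{(n)} & (x^{(n)})^2 & \cdots & (x^{(n)})^m \end{bmatrix}, \quad \mathbf{y} = \begin{bmatrix} y^{(1)} \\ \vdots \\ y^{(n)} \end{bmatrix}, \quad \mathbf{w} = \begin{bmatrix} w_0 \\ \vdots \\ w_m \end{bmatrix}$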
Regression with polynomials: fit improves with increased order
[Figure: fits to the same data, x from -1.5 to 1.5, with $f(x; \mathbf{w}) = w_0 + w_1 x$; $w_0 + w_1 x + w_2 x^2 + w_3 x^3$; $w_0 + w_1 x + \cdots + w_5 x^5$; and $w_0 + w_1 x + \cdots + w_{10} x^{10}$]
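A sketch of polynomial regression built from the design matrix above. It assumes hypothetical data from a 2nd-order process plus fixed "noise" values, and repeats the normal-equation helpers so the block runs on its own. The training MSE can only stay the same or shrink as the order $m$ grows, because each lower-order model is a special case of the higher-order one.

```python
def solve(A, b):
    """Solve A w = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [[float(v) for v in row] + [float(bi)] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        w[r] = (M[r][n] - sum(M[r][c] * w[c] for c in range(r + 1, n))) / M[r][r]
    return w


def normal_equation(X, y):
    d = len(X[0])
    XtX = [[sum(row[i] * row[j] for row in X) for j in range(d)] for i in range(d)]
    Xty = [sum(row[i] * yi for row, yi in zip(X, y)) for i in range(d)]
    return solve(XtX, Xty)


def design_matrix(x, m):
    """Rows [1, x, x^2, ..., x^m] for an m-th order polynomial."""
    return [[xi ** k for k in range(m + 1)] for xi in x]


def train_mse(x, y, m):
    """Fit an m-th order polynomial by the normal equation; return training MSE."""
    w = normal_equation(design_matrix(x, m), y)
    preds = [sum(wk * xi ** k for k, wk in enumerate(w)) for xi in x]
    return sum((yi - p) ** 2 for yi, p in zip(y, preds)) / len(y)


# Hypothetical data: a 2nd-order process plus fixed "noise" values.
x = [-1.5 + 0.3 * i for i in range(11)]
noise = [0.3, -0.2, 0.1, 0.4, -0.5, 0.2, -0.1, 0.3, -0.4, 0.1, -0.2]
y = [1 + xi + 2 * xi ** 2 + e for xi, e in zip(x, noise)]

errors = [train_mse(x, y, m) for m in range(1, 6)]
print(errors)  # non-increasing as the order m grows
```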
Over-fitting
• We want to fit the training set, but as model complexity increases, we
run the risk of over-fitting.
$\frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - \hat f(x^{(i)}; \mathbf{w})\right)^2 \to 0$
[Figure: the fit to the full training set (“Train set”) and the fit after one data point is left out (“Leave out”), x from -1.5 to 1.5]
When the model is over-fitting, leaving a single data point out of the training set can drastically change the fit.
Cross validation
• We want to fit the training set, but we want to also generalize correctly.
To measure generalization, we leave out a data point (named the test
point), fit the data, and then measure error on the test point. The average
error over all possible test points is the cross validation error.
$CV = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w}^{(!i)})\right)^2$
$\mathbf{w}^{(!i)}$: weights estimated from a training set that does not include the i-th data point
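A sketch of leave-one-out cross validation that follows the definition of $CV$ above: for each $i$, the weights $\mathbf{w}^{(!i)}$ are refit without the $i$-th point and then tested on it. The data (a 2nd-order process plus fixed noise) and the helper names are hypothetical. For least squares the leave-one-out error cannot be smaller than the training MSE, so the two columns printed below make the gap between fitting and generalizing visible.

```python
def solve(A, b):
    """Solve A w = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [[float(v) for v in row] + [float(bi)] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        w[r] = (M[r][n] - sum(M[r][c] * w[c] for c in range(r + 1, n))) / M[r][r]
    return w


def normal_equation(X, y):
    d = len(X[0])
    XtX = [[sum(row[i] * row[j] for row in X) for j in range(d)] for i in range(d)]
    Xty = [sum(row[i] * yi for row, yi in zip(X, y)) for i in range(d)]
    return solve(XtX, Xty)


def design_matrix(x, m):
    return [[xi ** k for k in range(m + 1)] for xi in x]


def poly_predict(w, xi):
    return sum(wk * xi ** k for k, wk in enumerate(w))


def train_mse(x, y, m):
    w = normal_equation(design_matrix(x, m), y)
    return sum((yi - poly_predict(w, xi)) ** 2 for xi, yi in zip(x, y)) / len(x)


def loo_cv(x, y, m):
    """CV = (1/n) sum_i (y_i - f(x_i; w^(!i)))^2 for an m-th order polynomial."""
    total = 0.0
    for i in range(len(x)):
        x_rest, y_rest = x[:i] + x[i + 1:], y[:i] + y[i + 1:]
        w_not_i = normal_equation(design_matrix(x_rest, m), y_rest)  # w^(!i)
        total += (y[i] - poly_predict(w_not_i, x[i])) ** 2
    return total / len(x)


x = [-1.5 + 0.3 * i for i in range(11)]
noise = [0.3, -0.2, 0.1, 0.4, -0.5, 0.2, -0.1, 0.3, -0.4, 0.1, -0.2]
y = [1 + xi + 2 * xi ** 2 + e for xi, e in zip(x, noise)]
for m in range(1, 7):
    print(m, train_mse(x, y, m), loo_cv(x, y, m))
```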
[Figure: mean-squared error on the training set and cross-validation error, each plotted against model order]
Cross validation
• Cross validation error will often increase when the model structure is
over-fitting the data.
[Figure: cross-validation error versus model order (actual data were generated with a 2nd-order polynomial process)]
Batch vs. online learning algorithms
• In “batch” learning, we don’t have to make any predictions until we see
all of the data. At that point, we make a model to fit all the data.
• In “online” learning, data points are given to us one at a time. We use
each example pair to update our model.
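$D = \{(\mathbf{x}^{(1)}, y^{(1)}), (\mathbf{x}^{(2)}, y^{(2)}), \ldots, (\mathbf{x}^{(n)}, y^{(n)})\}$
Batch: $\hat{\mathbf{w}} = (X^T X)^{-1}X^T\mathbf{y}$, $\hat{\mathbf{y}} = X\hat{\mathbf{w}}$
Online:
We are given an $\mathbf{x}$ and with our current model we predict a $y$: $\hat y^{(1)} = \hat w_0^{(1)} + \hat w_1^{(1)} x_1^{(1)} + \cdots + \hat w_m^{(1)} x_m^{(1)} = \mathbf{x}^{(1)T}\hat{\mathbf{w}}^{(1)}$
The teacher tells us our error: $y^{(1)} - \hat y^{(1)}$
We modify our model: $\hat{\mathbf{w}}^{(1)} \to \hat{\mathbf{w}}^{(2)}$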
Online learning: the LMS algorithm
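• Assume we have the model: $\hat y^{(n)} = w_1^{(n)} x_1^{(n)} + w_2^{(n)} x_2^{(n)} = \mathbf{w}^{(n)T}\mathbf{x}^{(n)}$
[Figure: the weight vector $\mathbf{w}^{(n)}$ and the input $\mathbf{x}^{(n)}$, with the lengths $\hat y^{(n)}/\|\mathbf{x}^{(n)}\|$ and $y^{(n)}/\|\mathbf{x}^{(n)}\|$ marked along $\mathbf{x}^{(n)}$]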
When we project w onto x, we get a scalar p:
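$\mathbf{w}^{(n)T}\mathbf{x}^{(n)} = \|\mathbf{w}^{(n)}\|\,\|\mathbf{x}^{(n)}\|\cos\theta$
$p = \|\mathbf{w}^{(n)}\|\cos\theta = \frac{\mathbf{w}^{(n)T}\mathbf{x}^{(n)}}{\|\mathbf{x}^{(n)}\|} = \frac{\hat y^{(n)}}{\|\mathbf{x}^{(n)}\|}$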
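What we want is to change $\mathbf{w}$ so that when we project onto $\mathbf{x}^{(n)}$ we get $y^{(n)}/\|\mathbf{x}^{(n)}\|$.
Anywhere along the dashed line is the solution we’re looking for.
[Figure: $\mathbf{w}^{(n)}$, $\mathbf{x}^{(n)}$, the current length $\hat y^{(n)}/\|\mathbf{x}^{(n)}\|$, the target length $y^{(n)}/\|\mathbf{x}^{(n)}\|$, the error $\varepsilon^{(n)}$, and the updated weight vector $\mathbf{w}^{(n+1)}$]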
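The LMS algorithm
$\frac{y^{(n)} - \hat y^{(n)}}{\|\mathbf{x}^{(n)}\|}$: the distance we need to move along $\mathbf{x}^{(n)}$
$\frac{y^{(n)} - \hat y^{(n)}}{\|\mathbf{x}^{(n)}\|}\,\frac{\mathbf{x}^{(n)}}{\|\mathbf{x}^{(n)}\|} = \frac{1}{\|\mathbf{x}^{(n)}\|^2}\left(y^{(n)} - \hat y^{(n)}\right)\mathbf{x}^{(n)}$, where $\mathbf{x}^{(n)}/\|\mathbf{x}^{(n)}\|$ is the unit vector along $\mathbf{x}$
$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\,\frac{1}{\|\mathbf{x}^{(n)}\|^2}\left(y^{(n)} - \hat y^{(n)}\right)\mathbf{x}^{(n)}$, where $\eta$ is the “step size”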
w changes along a vector parallel to the input x in that trial with a magnitude proportional to the prediction error in that trial.
With step size $\eta = 1$, we change $\mathbf{w}$ to completely account for the error in that trial.
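A sketch of one LMS update in the normalized form above, with hypothetical numbers. With $\eta = 1$ the updated weights reproduce the target exactly on the current input; with $\eta = 0.5$ half of the error is removed. In both cases the change in $\mathbf{w}$ is parallel to $\mathbf{x}$. The function name `lms_step` is our own.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def lms_step(w, x, y, eta=1.0):
    """w <- w + eta * (y - w^T x) / ||x||^2 * x"""
    y_hat = dot(w, x)
    scale = eta * (y - y_hat) / dot(x, x)
    return [wi + scale * xi for wi, xi in zip(w, x)]


w = [0.5, -1.0]
x = [2.0, 1.0]
y = 3.0                                    # current prediction w^T x is 0.0
w_full = lms_step(w, x, y, eta=1.0)
w_half = lms_step(w, x, y, eta=0.5)
print(dot(w_full, x), dot(w_half, x))      # about 3.0 and 1.5
```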
LMS algorithm attempts to minimize a squared error loss function by approximating the gradient of the loss function
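Average error over all data points: $J(\mathbf{w}) = \frac{1}{N}\sum_{n=1}^{N}\frac{1}{2}\left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)^2$
$\frac{dJ}{dw_i} = -\frac{1}{N}\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)x_i^{(n)}$, i.e., $\frac{dJ}{d\mathbf{w}} = -\frac{1}{N}\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$
Steepest descent algorithm: $\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} - \eta\frac{dJ}{d\mathbf{w}} = \mathbf{w}^{(t)} + \eta\frac{1}{N}\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^{(t)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$
LMS: local error as a rough estimate of average error: $\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\frac{y^{(t)} - \mathbf{w}^{(t)T}\mathbf{x}^{(t)}}{\mathbf{x}^{(t)T}\mathbf{x}^{(t)}}\mathbf{x}^{(t)}$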
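A sketch contrasting steepest descent (the gradient of the average loss over all points) with the normalized LMS rule (one point at a time). The step sizes, iteration counts, and the noise-free data set lying exactly on $y = 1 + 2x$ are assumptions chosen for illustration; on such data both rules approach the normal-equation solution $[1, 2]$.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def steepest_descent(X, y, eta, steps):
    """w <- w + eta * (1/N) * sum_n (y_n - w^T x_n) x_n  (full gradient)."""
    N, w = len(X), [0.0] * len(X[0])
    for _ in range(steps):
        step = [0.0] * len(w)
        for x, yn in zip(X, y):
            err = yn - dot(w, x)
            step = [s + err * xi / N for s, xi in zip(step, x)]
        w = [wi + eta * si for wi, si in zip(w, step)]
    return w


def lms(X, y, eta, sweeps):
    """w <- w + eta * (y_t - w^T x_t) / (x_t^T x_t) * x_t, one point at a time."""
    w = [0.0] * len(X[0])
    for _ in range(sweeps):
        for x, yn in zip(X, y):
            err = yn - dot(w, x)
            w = [wi + eta * err * xi / dot(x, x) for wi, xi in zip(w, x)]
    return w


X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [1.0 + 2.0 * x[1] for x in X]  # noise-free data on y = 1 + 2x
w_sd = steepest_descent(X, y, eta=0.3, steps=500)
w_lms = lms(X, y, eta=0.5, sweeps=200)
print(w_sd, w_lms)  # both close to [1.0, 2.0]
```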
[Figure: LMS weight updates when iterating over two data points and over three data points; an equilibrium point is marked]
With 3 data points, the solution will not move to a single point and stay put. It converges to a small region of the parameter space but will bounce around, as long as $\eta > 0$.
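A sketch of this behavior using the normalized LMS rule with $\eta = 1$ on hypothetical data. With two points, the weights settle at the intersection of the two constraint lines $\mathbf{w}^T\mathbf{x}^{(n)} = y^{(n)}$; with three points that no single line fits, they keep cycling among three distinct points (for this data, $[0, 2/7]$, $[6/7, 8/7]$, and $[3/7, 2/7]$).

```python
def lms_path(X, y, eta=1.0, sweeps=100):
    """Normalized LMS cycling through the data; returns every iterate."""
    w = [0.0, 0.0]
    path = [w]
    for _ in range(sweeps):
        for x, yn in zip(X, y):
            err = yn - (w[0] * x[0] + w[1] * x[1])
            nsq = x[0] ** 2 + x[1] ** 2
            w = [w[0] + eta * err * x[0] / nsq, w[1] + eta * err * x[1] / nsq]
            path.append(w)
    return path


two = lms_path([[1.0, 0.0], [1.0, 1.0]], [0.0, 2.0])
three = lms_path([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]], [0.0, 2.0, 1.0])
print(two[-1])     # settles at the intersection [0, 2]
print(three[-3:])  # keeps cycling among three distinct points
```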
Convergence of LMS-algorithm
$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\left(y^{(n)} - \mathbf{w}^{(n)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$
It is difficult to prove “convergence” of LMS because the weights keep bouncing around. But we can prove convergence for the steepest descent algorithm and then use the fact that LMS is a stochastic approximation to it.
Convergence of LMS-algorithm
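$\mathbf{w}^{(1)} = \mathbf{w}^{(0)} + \eta\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^{(0)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)} = \mathbf{w}^{(0)} + \eta X^T\left(\mathbf{y} - X\mathbf{w}^{(0)}\right)$
$\mathbf{w}^{(1)} = \left(I - \eta X^T X\right)\mathbf{w}^{(0)} + \eta X^T\mathbf{y}$
$\mathbf{w}^{(2)} = \left(I - \eta X^T X\right)\mathbf{w}^{(1)} + \eta X^T\mathbf{y} = \left(I - \eta X^T X\right)^2\mathbf{w}^{(0)} + \left(I - \eta X^T X\right)\eta X^T\mathbf{y} + \eta X^T\mathbf{y}$
$\mathbf{w}^{(k)} = \left(I - \eta X^T X\right)^k\mathbf{w}^{(0)} + \sum_{i=0}^{k-1}\left(I - \eta X^T X\right)^i\eta X^T\mathbf{y}$ (a geometric series)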
Convergence of a geometric series of scalars
$s_n = \sum_{i=0}^{n-1}a^i = 1 + a + a^2 + \cdots + a^{n-1}$
$a s_n = a + a^2 + \cdots + a^n$
$s_n - a s_n = 1 - a^n \;\Rightarrow\; s_n = \frac{1 - a^n}{1 - a}$
$\lim_{n\to\infty} s_n = \frac{1}{1 - a}$ if $|a| < 1$
Convergence of a geometric series of matrices
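$S_n = \sum_{i=0}^{n-1}A^i = I + A + \cdots + A^{n-1}$
$AS_n = A + A^2 + \cdots + A^n$
$S_n - AS_n = I - A^n \;\Rightarrow\; S_n = (I - A)^{-1}(I - A^n)$
$\lim_{n\to\infty}A^n = 0$ if $|\lambda_i| < 1$ for every eigenvalue $\lambda_i$ of $A$
$\lim_{n\to\infty}S_n = (I - A)^{-1}$ if $|\lambda_i| < 1$
See homework for this: with $A = Q\Lambda Q^{-1}$, $A^n = Q\Lambda^n Q^{-1}$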
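A numerical check of the matrix result, assuming a hypothetical $2\times 2$ matrix $A$ whose eigenvalues (about 0.57 and 0.23) lie inside the unit circle: the partial sums $S_n$ approach $(I - A)^{-1}$, and $A^n$ approaches 0.

```python
def matmul(A, B):
    """Product of two 2x2 matrices."""
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def inv2(M):
    """Inverse of a 2x2 matrix."""
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    return [[M[1][1] / det, -M[0][1] / det], [-M[1][0] / det, M[0][0] / det]]


A = [[0.5, 0.2], [0.1, 0.3]]   # hypothetical; eigenvalues about 0.57 and 0.23
I = [[1.0, 0.0], [0.0, 1.0]]

S = [[0.0, 0.0], [0.0, 0.0]]   # running sum S_n = I + A + ... + A^(n-1)
P = I                          # current power A^i
for _ in range(200):
    S = [[S[i][j] + P[i][j] for j in range(2)] for i in range(2)]
    P = matmul(P, A)

limit = inv2([[I[i][j] - A[i][j] for j in range(2)] for i in range(2)])
print(S)
print(limit)  # the two agree to many decimal places
```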
Convergence of steepest descent algorithm
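$\mathbf{w}^{(k)} = \left(I - \eta X^T X\right)^k\mathbf{w}^{(0)} + \sum_{i=0}^{k-1}\left(I - \eta X^T X\right)^i\eta X^T\mathbf{y}$
$\lim_{k\to\infty}\left(I - \eta X^T X\right)^k = 0$ if $|1 - \eta\lambda_j| < 1$ for all $j$, where $\lambda_j$ is an eigenvalue of $X^T X$ (use eigenvector decomposition)
$\sum_{i=0}^{\infty}\left(I - \eta X^T X\right)^i = \left(I - \left(I - \eta X^T X\right)\right)^{-1} = \left(\eta X^T X\right)^{-1}$ if $|1 - \eta\lambda_j| < 1$
$X^T X$ is positive definite (when $X$ has full column rank), therefore $\lambda_j > 0$ and $|1 - \eta\lambda_j| < 1 \iff 0 < \eta\lambda_j < 2 \iff 0 < \eta < 2/\lambda_j$
$\lim_{k\to\infty}\mathbf{w}^{(k)} = \left(\eta X^T X\right)^{-1}\eta X^T\mathbf{y} = \left(X^T X\right)^{-1}X^T\mathbf{y}$ if $0 < \eta < 2/\lambda_j$ for all $j$
See homework for proof of this.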
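A sketch of the step-size condition on a hypothetical two-parameter problem. Steepest descent in the form $\mathbf{w} \leftarrow \mathbf{w} + \eta X^T(\mathbf{y} - X\mathbf{w})$ (no $1/N$ factor, as in the derivation) converges to the normal-equation solution for $\eta = 1.9/\lambda_{\max}$ and diverges for $\eta = 2.1/\lambda_{\max}$; the two multipliers 1.9 and 2.1 are illustrative choices.

```python
import math

X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [1.2, 2.9, 5.3, 6.8]

# X^T X and X^T y for the two-parameter model
XtX = [[sum(r[i] * r[j] for r in X) for j in range(2)] for i in range(2)]
Xty = [sum(r[i] * yn for r, yn in zip(X, y)) for i in range(2)]


def largest_eigenvalue(M):
    """Largest eigenvalue of a symmetric 2x2 matrix (closed form)."""
    a, b, d = M[0][0], M[0][1], M[1][1]
    return (a + d) / 2 + math.sqrt(((a - d) / 2) ** 2 + b * b)


def steepest_descent(eta, steps):
    """w <- w + eta * (X^T y - X^T X w), starting from w = 0."""
    w = [0.0, 0.0]
    for _ in range(steps):
        g = [Xty[i] - (XtX[i][0] * w[0] + XtX[i][1] * w[1]) for i in range(2)]
        w = [w[0] + eta * g[0], w[1] + eta * g[1]]
    return w


lam_max = largest_eigenvalue(XtX)
det = XtX[0][0] * XtX[1][1] - XtX[0][1] * XtX[1][0]
w_normal = [(XtX[1][1] * Xty[0] - XtX[0][1] * Xty[1]) / det,
            (-XtX[1][0] * Xty[0] + XtX[0][0] * Xty[1]) / det]

w_ok = steepest_descent(1.9 / lam_max, 500)   # inside the bound: converges
w_bad = steepest_descent(2.1 / lam_max, 200)  # outside the bound: diverges
print(w_normal, w_ok, w_bad)
```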
We have shown convergence of the steepest descent algorithm to the solution of the normal equations. The LMS is a stochastic approximation to steepest descent, thus it “converges” as well, but will jump around stochastically, as long as the learning rate is greater than zero. Convergence can be reached when the learning rate is systematically made smaller on each step.
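A sketch of the last point, assuming the unnormalized LMS rule $\mathbf{w} \leftarrow \mathbf{w} + \eta_t (y - \mathbf{w}^T\mathbf{x})\mathbf{x}$ and three hypothetical points that no line fits exactly. The decaying schedule $\eta_t = 0.3/(1 + t/150)$ is an illustrative choice, not one from the lecture. With a constant $\eta$ the weights keep bouncing at a distance from the least-squares solution $[0.5, 0.5]$, while the decaying rate brings them much closer.

```python
def lms(X, y, rate, sweeps):
    """Unnormalized LMS, w <- w + eta_t (y - w^T x) x, cycling through the data."""
    w = [0.0, 0.0]
    t = 0
    for _ in range(sweeps):
        for x, yn in zip(X, y):
            err = yn - (w[0] * x[0] + w[1] * x[1])
            eta = rate(t)            # learning rate used at step t
            w = [w[0] + eta * err * x[0], w[1] + eta * err * x[1]]
            t += 1
    return w


X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
y = [0.0, 2.0, 1.0]      # three points that no line fits exactly
w_star = [0.5, 0.5]      # least-squares solution (X^T X)^{-1} X^T y


def dist(w):
    return ((w[0] - w_star[0]) ** 2 + (w[1] - w_star[1]) ** 2) ** 0.5


w_const = lms(X, y, rate=lambda t: 0.3, sweeps=2000)
w_decay = lms(X, y, rate=lambda t: 0.3 / (1 + t / 150), sweeps=2000)
print(dist(w_const), dist(w_decay))  # the decaying rate ends much closer
```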
We will call changes of the learning rate “adaptive learning” and will see a principled approach to this problem when we consider Bayesian approaches to learning.
Summary: Linear Regression
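Univariate regression: $f: \mathbb{R} \to \mathbb{R}$, $f(x; \mathbf{w}) = w_0 + w_1 x$
Multivariate regression: $f: \mathbb{R}^d \to \mathbb{R}$, $f(\mathbf{x}; \mathbf{w}) = w_0 + w_1 x_1 + \cdots + w_d x_d$
Parameters we need to find: $\mathbf{w} = [w_0, w_1, \ldots, w_d]^T$
Loss function: $Loss(y, \hat y) = (y - \hat y)^2$
Empirical loss: $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} Loss(y^{(i)}, \hat y^{(i)}) = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - f(\mathbf{x}^{(i)}; \mathbf{w})\right)^2$
$J_n(\mathbf{w}) = \frac{1}{n}(\mathbf{y} - X\mathbf{w})^T(\mathbf{y} - X\mathbf{w}) = \frac{1}{n}\boldsymbol{\varepsilon}^T\boldsymbol{\varepsilon}$
$\frac{dJ_n}{d\mathbf{w}} = -\frac{2}{n}\left(X^T\mathbf{y} - X^T X\mathbf{w}\right) = 0$
$X^T X\mathbf{w} = X^T\mathbf{y}$
$\mathbf{w} = (X^T X)^{-1}X^T\mathbf{y}$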
Summary: Iterative learning
• Increased model complexity reduces error over the training data but can
increase the leave-one-out cross validation error. We want a model that
fits the training data and generalizes correctly.
• LMS algorithm:
• w changes along a vector parallel to the input x in that trial with a
magnitude proportional to the error in that trial.
• Steepest descent algorithm:
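$J(\mathbf{w}) = \frac{1}{N}\sum_{n=1}^{N}\frac{1}{2}\left(y^{(n)} - \mathbf{w}^T\mathbf{x}^{(n)}\right)^2$
$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\frac{1}{N}\sum_{n=1}^{N}\left(y^{(n)} - \mathbf{w}^{(t)T}\mathbf{x}^{(n)}\right)\mathbf{x}^{(n)}$
LMS: $\hat y^{(n)} = \mathbf{w}^{(n)T}\mathbf{x}^{(n)}$
$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\frac{1}{\|\mathbf{x}^{(n)}\|^2}\left(y^{(n)} - \hat y^{(n)}\right)\mathbf{x}^{(n)}$
$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\left(y^{(n)} - \hat y^{(n)}\right)\mathbf{x}^{(n)}$, with $0 < \eta < 2/\lambda_{\max}$, where $\lambda_{\max}$ is the largest eigenvalue of $X^T X$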