2011/08/03

Machine learning: curve fitting

#coding: utf-8                                                                  
                                                                                
import numpy as np                                                              
from pylab import *                                                             
import sys                                                                      
                                                                                
M = 9  # degree of the fitted polynomial: highest power of x used by y/estimate/rms


def y(x, wlist):
    """Evaluate the degree-``M`` polynomial with coefficients ``wlist`` at ``x``.

    ``wlist[k]`` is the coefficient of ``x ** k`` for k = 0..M; any entries
    past index ``M`` are ignored, and fewer than ``M + 1`` entries raise
    ``IndexError``.
    """
    total = wlist[0]
    power = 1
    while power <= M:
        total += wlist[power] * (x ** power)
        power += 1
    return total
                                                                                
def estimate(xlist, tlist, lam, degree=None):
    """Fit polynomial coefficients by regularised (ridge) least squares.

    Solves the normal equations ``(Phi^T Phi + lam * I) w = Phi^T t`` where
    ``Phi[n, k] = x_n ** k`` for k = 0..degree.

    Args:
        xlist: sample inputs (sequence or numpy array).
        tlist: sample targets, paired element-wise with ``xlist``.
        lam: regularisation strength, added to every diagonal entry of the
            normal matrix (including the one for the constant term).
        degree: polynomial degree; ``None`` (default) uses the module-level ``M``.

    Returns:
        1-D numpy array of ``degree + 1`` coefficients, lowest power first.

    Raises:
        numpy.linalg.LinAlgError: if the regularised normal matrix is singular.
    """
    if degree is None:
        degree = M
    # ravel() keeps the original "sum over every element" semantics for
    # inputs that are not 1-D, and np.asarray lets plain lists work too.
    x = np.asarray(xlist, dtype=float).ravel()
    t = np.asarray(tlist, dtype=float).ravel()

    # Design matrix with columns x**0, x**1, ..., x**degree.
    phi = np.vander(x, degree + 1, increasing=True)
    A = phi.T.dot(phi) + lam * np.eye(degree + 1)  # A[i, j] = sum x**(i+j) (+ lam if i == j)
    T = phi.T.dot(t)                                # T[i] = sum x**i * t
    return np.linalg.solve(A, T)
                                                                                
def rms(xlist, tlist, wlist, degree=None):
    """Return the root-mean-square error of a polynomial fit, as a string.

    E_RMS = sqrt(2 * E(w) / N) with E(w) = 1/2 * sum_n (y(x_n, w) - t_n)**2,
    which is simply sqrt(mean of squared residuals).

    Args:
        xlist: 1-D sample inputs (sequence or numpy array).
        tlist: sample targets, same length as ``xlist``.
        wlist: polynomial coefficients, lowest power first; only the first
            ``degree + 1`` entries are used.
        degree: polynomial degree; ``None`` (default) uses the module-level ``M``.

    Returns:
        ``str`` of the RMS error (the original interface returns a string).

    Raises:
        ZeroDivisionError: if ``xlist`` is empty.
        IndexError: if ``wlist`` has fewer than ``degree + 1`` entries.
    """
    if degree is None:
        degree = M
    x = np.asarray(xlist, dtype=float)
    t = np.asarray(tlist, dtype=float)
    n = len(x)
    if n == 0:
        raise ZeroDivisionError("rms() needs at least one sample")

    coeffs = np.array([wlist[j] for j in range(degree + 1)], dtype=float)
    residuals = np.vander(x, degree + 1, increasing=True).dot(coeffs) - t
    # Mean squared residual == 2 * E(w) / N; working in floats avoids the
    # Python 2 floor division the old "/ 2" form hit with integer inputs.
    return str(np.sqrt(residuals.dot(residuals) / n))

def example1(seed=None):
    """Fit a degree-M polynomial to noisy samples of sin(2*pi*x) and plot it.

    Prints the RMS error on the training points and on a freshly drawn test
    set, then plots both data sets, the noise-free curve and the fitted model.

    Args:
        seed: optional seed for numpy's global RNG so runs are reproducible;
            ``None`` (default) leaves the RNG state untouched.
    """
    if seed is not None:
        np.random.seed(seed)

    # Training data: evenly spaced inputs, targets = sin(2*pi*x) + N(0, 0.2) noise.
    n_train = 10
    xlist = np.linspace(0, 1, n_train)
    tlist = np.sin(2 * np.pi * xlist) + np.random.normal(0, 0.2, xlist.size)

    # Estimate parameters w with regularisation lambda = exp(-18), i.e. ln(lambda) = -18.
    wlist = estimate(xlist, tlist, np.exp(-18.0))
    # print(...) with a single argument behaves identically on Python 2 and 3.
    print("E_RMS for training data: %s" % rms(xlist, tlist, wlist))

    # Independent test data drawn from the same process.
    n_test = 100
    xlist2 = np.linspace(0, 1, n_test)
    tlist2 = np.sin(2 * np.pi * xlist2) + np.random.normal(0, 0.2, xlist2.size)
    print("E_RMS for test data: %s" % rms(xlist2, tlist2, wlist))

    # Dense grid for drawing the ideal curve and the fitted model.
    xs = np.linspace(0, 1, 1000)
    ideal = np.sin(2 * np.pi * xs)
    model = [y(x, wlist) for x in xs]

    plot(xlist, tlist, 'bo')    # training data: blue circles
    plot(xlist2, tlist2, 'rd')  # test data: red diamonds
    plot(xs, ideal, 'g-')       # noise-free sin(2*pi*x): green line
    plot(xs, model, 'r-')       # fitted polynomial: red line
    xlim(0.0, 1.0)
    ylim(-1.5, 1.5)
    show()


if __name__ == "__main__":
    example1()

No comments:

Post a Comment

100