Chrome Pointer

2021年12月5日 星期日

Python 梯度下降法範例教學 Gradient descent example

 

import numpy as np
import matplotlib.pyplot as plt

# 目標函數:y=x^2
def func(x): 
  return np.square(x)

# 目標函數一階導數:dy/dx=2*x  微分之後出來的值是 切線斜率, x越大 斜率越就越大, x越小 斜率就越小, 所以越往下斜率越接近0
def dfunc(x): 
  return 2 * x

def GD(x_startdfepochslr):    
    """  梯度下降法。給定起始點與目標函數的一階導函數,求在epochs次反覆運算中x的更新值
        :param x_start: x的起始點    
        :param df: 目標函數的一階導函數    
        :param epochs: 反覆運算週期    
        :param lr: 學習率    
        :return: x在每次反覆運算後的位置(包括起始點),長度為epochs+1    
     """    
    xs = np.zeros(epochs+1)    
    x = x_start    
    xs[0] = x    
    for i in range(epochs):         
        dx = df(x)       
        # v表示x要改變的幅度        
        v = - dx * lr        
        x += v        
        xs[i+1] = x    

        print ("df(x)微分 : ",dx,"\t v改變的幅度 : ",v) 
    return xs

# Main
# 起始權重
x_start = 5    
# 執行週期數
epochs = 5
# 學習率   
lr = 0.1
# 梯度下降法 
x = GD(x_start, dfunc, epochs, lr=lr) 
print (x)
# 輸出:[-5.     -2.     -0.8    -0.32   -0.128  -0.0512]

color = 'r'    
#plt.plot(line_x, line_y, c='b')    
from numpy import arange
t = arange(-6.06.00.01)
plt.plot(t, func(t), c='b')
plt.plot(x, func(x), c=color, label='lr={}'.format(lr))    
plt.scatter(x, func(x), c=color, )    
plt.legend()

plt.show()



沒有留言:

張貼留言

喜歡我的文章嗎? 喜歡的話可以留言回應我喔! ^^