當(dāng)前位置:首頁(yè) > 芯聞號(hào) > 充電吧
[導(dǎo)讀]GBDT 回歸的原理和Python 實(shí)現(xiàn)

完整實(shí)現(xiàn)代碼請(qǐng)參考github:

1. 原理篇

我們用人話而不是大段的數(shù)學(xué)公式來(lái)講講GBDT回歸是怎么一回事。

1.1 溫故知新

回歸樹(shù)是GBDT的基礎(chǔ),之前的一篇文章曾經(jīng)講過(guò)回歸樹(shù)的原理和實(shí)現(xiàn)。


1.2 預(yù)測(cè)年齡

仍然以預(yù)測(cè)同事年齡來(lái)舉例,從《回歸樹(shù)》那篇文章中我們可以知道,如果需要通過(guò)一個(gè)常量來(lái)預(yù)測(cè)同事的年齡,平均值是最佳的選擇之一。

1.3 年齡的殘差

我們不妨假設(shè)同事的年齡分別為5歲、6歲、7歲,那么同事的平均年齡就是6歲。所以我們用6歲這個(gè)常量來(lái)預(yù)測(cè)同事的年齡,即[6, 6, 6]。每個(gè)同事年齡的殘差 = 年齡 – 預(yù)測(cè)值 = [5, 6, 7] – [6, 6, 6],所以殘差為[-1, 0, 1]

1.4 預(yù)測(cè)年齡的殘差

為了讓模型更加準(zhǔn)確,其中一個(gè)思路是讓殘差變小。如何減少殘差呢?我們不妨對(duì)殘差建立一顆回歸樹(shù),然后預(yù)測(cè)出準(zhǔn)確的殘差。假設(shè)這棵樹(shù)預(yù)測(cè)的殘差是[-0.9, 0, 0.9],將上一輪的預(yù)測(cè)值和這一輪的預(yù)測(cè)值求和,每個(gè)同事的年齡 = [6, 6, 6] + [-0.9, 0, 0.9] = [5.1, 6, 6.9],顯然與真實(shí)值[5, 6, 7]更加接近了, 年齡的殘差此時(shí)變?yōu)閇-0.1, 0, 0.1]。顯然,預(yù)測(cè)的準(zhǔn)確性得到了提升。

1.5 GBDT

重新整理一下思路,假設(shè)我們的預(yù)測(cè)一共迭代3輪 年齡:[5, 6, 7]

第1輪預(yù)測(cè):[6, 6, 6] (平均值)

第1輪殘差:[-1, 0, 1]

第2輪預(yù)測(cè):[6, 6, 6] (平均值) + [-0.9, 0, 0.9] (第1顆回歸樹(shù)) = [5.1, 6, 6.9]

第2輪殘差:[-0.1, 0, 0.1]

第3輪預(yù)測(cè):[6, 6, 6] (平均值) + [-0.9, 0, 0.9] (第1顆回歸樹(shù)) + [-0.08, 0, 0.07] (第2顆回歸樹(shù)) = [5.02, 6, 6.97]

第3輪殘差:[-0.08, 0, 0.03]

看上去殘差越來(lái)越小,而這種預(yù)測(cè)方式就是GBDT算法。

1.6 公式推導(dǎo)

看到這里,相信您對(duì)GBDT已經(jīng)有了直觀的認(rèn)識(shí)。這么做有什么科學(xué)依據(jù)么,為什么殘差可以越來(lái)越小呢?前方小段數(shù)學(xué)公式低能預(yù)警。

  1. 假設(shè)要做m輪預(yù)測(cè),預(yù)測(cè)函數(shù)為Fm,初始常量或每一輪的回歸樹(shù)為fm,輸入變量為X,有:

  2. 設(shè)要預(yù)測(cè)的變量為y,采用MSE作為損失函數(shù):

  3. 我們知道泰勒公式的一階展開(kāi)式是長(zhǎng)成這個(gè)樣子滴:

  4. 如果:

  5. 那么,根據(jù)式3和式4可以得出:

  6. 根據(jù)式2可以知道,損失函數(shù)的一階偏導(dǎo)數(shù)為:

  7. 根據(jù)式6可以知道,損失函數(shù)的二階偏導(dǎo)數(shù)為:

  8. 蓄力結(jié)束,開(kāi)始放大招。根據(jù)式1,損失函數(shù)的一階導(dǎo)數(shù)為:

  9. 根據(jù)式5,將式8進(jìn)一步展開(kāi)為:

  10. 令式9,即損失函數(shù)的一階偏導(dǎo)數(shù)為0,那么:

  11. 將式6,式7代入式9得到:

因此,我們需要通過(guò)用第m-1輪殘差的均值來(lái)得到函數(shù)fm,進(jìn)而優(yōu)化函數(shù)Fm。而回歸樹(shù)的原理就是通過(guò)最佳劃分區(qū)域的均值來(lái)進(jìn)行預(yù)測(cè)。所以fm可以選用回歸樹(shù)作為基礎(chǔ)模型,將初始值,m-1顆回歸樹(shù)的預(yù)測(cè)值相加便可以預(yù)測(cè)y。

2. 實(shí)現(xiàn)篇

本人用全宇宙最簡(jiǎn)單的編程語(yǔ)言——Python實(shí)現(xiàn)了GBDT回歸算法,沒(méi)有依賴任何第三方庫(kù),便于學(xué)習(xí)和使用。簡(jiǎn)單說(shuō)明一下實(shí)現(xiàn)過(guò)程,更詳細(xì)的注釋請(qǐng)參考本人github上的代碼。

2.1 導(dǎo)入回歸樹(shù)類

回歸樹(shù)是我之前已經(jīng)寫好的一個(gè)類,在之前的文章詳細(xì)介紹過(guò),代碼請(qǐng)參考:

regression_tree.pygithub.com


1

from ..tree.regression_tree import RegressionTree

2.2 創(chuàng)建GradientBoostingBase類

初始化,存儲(chǔ)回歸樹(shù)、學(xué)習(xí)率、初始預(yù)測(cè)值和變換函數(shù)。(注:回歸不需要做變換,因此函數(shù)的返回值等于參數(shù))


1

2

3

4

5

6

class GradientBoostingBase(object):

????def __init__(self):

????????self.trees = None

????????self.lr = None

????????self.init_val = None

????????self.fn = lambda x: x

2.3 計(jì)算初始預(yù)測(cè)值

初始預(yù)測(cè)值即y的平均值。


1

2

def _get_init_val(self, y):

????return sum(y) / len(y)

2.4 計(jì)算殘差


1

2

def _get_residuals(self, y, y_hat):

????return [yi - self.fn(y_hat_i) for yi, y_hat_i in zip(y, y_hat)]

2.5 訓(xùn)練模型

訓(xùn)練模型的時(shí)候需要注意以下幾點(diǎn): 1. 控制樹(shù)的最大深度max_depth; 2. 控制分裂時(shí)最少的樣本量min_samples_split; 3. 訓(xùn)練每一棵回歸樹(shù)的時(shí)候要乘以一個(gè)學(xué)習(xí)率lr,防止模型過(guò)擬合; 4. 對(duì)樣本進(jìn)行抽樣的時(shí)候要采用有放回的抽樣方式。


1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

def fit(self, X, y, n_estimators, lr, max_depth, min_samples_split, subsample=None):

????self.init_val = self._get_init_val(y)

?

????n = len(y)

????y_hat = [self.init_val] * n

????residuals = self._get_residuals(y, y_hat)

?

????self.trees = []

????self.lr = lr

????for _ in range(n_estimators):

????????idx = range(n)

????????if subsample is not None:

????????????k = int(subsample * n)

????????????idx = choices(population=idx, k=k)

????????X_sub = [X[i] for i in idx]

????????residuals_sub = [residuals[i] for i in idx]

????????y_hat_sub = [y_hat[i] for i in idx]

?

????????tree = RegressionTree()

????????tree.fit(X_sub, residuals_sub, max_depth, min_samples_split)

?

????????self._update_score(tree, X_sub, y_hat_sub, residuals_sub)

?

????????y_hat = [y_hat_i + lr * res_hat_i for y_hat_i,

????????????????????res_hat_i in zip(y_hat, tree.predict(X))]

?

????????residuals = self._get_residuals(y, y_hat)

????????self.trees.append(tree)

2.6 預(yù)測(cè)一個(gè)樣本


1

2

def _predict(self, Xi):

????return self.fn(self.init_val + sum(self.lr * tree._predict(Xi) for tree in self.trees))

2.7 預(yù)測(cè)多個(gè)樣本


1

2

def predict(self, X):

????return [self._predict(Xi) for Xi in X]

3 效果評(píng)估

3.1 main函數(shù)

使用著名的波士頓房?jī)r(jià)數(shù)據(jù)集,按照7:3的比例拆分為訓(xùn)練集和測(cè)試集,訓(xùn)練模型,并統(tǒng)計(jì)準(zhǔn)確度。


1

2

3

4

5

6

7

8

9

10

11

12

13

14

@run_time

def main():

????print("Tesing the accuracy of GBDT regressor...")

?

????X, y = load_boston_house_prices()

?

????X_train, X_test, y_train, y_test = train_test_split(

????????X, y, random_state=10)

?

????reg = GradientBoostingRegressor()

????reg.fit(X=X_train, y=y_train, n_estimators=4,

????????????lr=0.5, max_depth=2, min_samples_split=2)

?

????get_r2(reg, X_test, y_test)

3.2 效果展示

最終擬合優(yōu)度0.851,運(yùn)行時(shí)間2.2秒,效果還算不錯(cuò)~

3.3 工具函數(shù)

本人自定義了一些工具函數(shù),可以在github上查看

utils.pygithub.com

1. run_time – 測(cè)試函數(shù)運(yùn)行時(shí)間

2. load_boston_house_prices – 加載波士頓房?jī)r(jià)數(shù)據(jù)

3. train_test_split – 拆分訓(xùn)練集、測(cè)試集

4. get_r2 – 計(jì)算擬合優(yōu)度


本站聲明: 本文章由作者或相關(guān)機(jī)構(gòu)授權(quán)發(fā)布,目的在于傳遞更多信息,并不代表本站贊同其觀點(diǎn),本站亦不保證或承諾內(nèi)容真實(shí)性等。需要轉(zhuǎn)載請(qǐng)聯(lián)系該專欄作者,如若文章內(nèi)容侵犯您的權(quán)益,請(qǐng)及時(shí)聯(lián)系本站刪除。
換一批
延伸閱讀

9月2日消息,不造車的華為或?qū)⒋呱龈蟮莫?dú)角獸公司,隨著阿維塔和賽力斯的入局,華為引望愈發(fā)顯得引人矚目。

關(guān)鍵字: 阿維塔 塞力斯 華為

倫敦2024年8月29日 /美通社/ -- 英國(guó)汽車技術(shù)公司SODA.Auto推出其旗艦產(chǎn)品SODA V,這是全球首款涵蓋汽車工程師從創(chuàng)意到認(rèn)證的所有需求的工具,可用于創(chuàng)建軟件定義汽車。 SODA V工具的開(kāi)發(fā)耗時(shí)1.5...

關(guān)鍵字: 汽車 人工智能 智能驅(qū)動(dòng) BSP

北京2024年8月28日 /美通社/ -- 越來(lái)越多用戶希望企業(yè)業(yè)務(wù)能7×24不間斷運(yùn)行,同時(shí)企業(yè)卻面臨越來(lái)越多業(yè)務(wù)中斷的風(fēng)險(xiǎn),如企業(yè)系統(tǒng)復(fù)雜性的增加,頻繁的功能更新和發(fā)布等。如何確保業(yè)務(wù)連續(xù)性,提升韌性,成...

關(guān)鍵字: 亞馬遜 解密 控制平面 BSP

8月30日消息,據(jù)媒體報(bào)道,騰訊和網(wǎng)易近期正在縮減他們對(duì)日本游戲市場(chǎng)的投資。

關(guān)鍵字: 騰訊 編碼器 CPU

8月28日消息,今天上午,2024中國(guó)國(guó)際大數(shù)據(jù)產(chǎn)業(yè)博覽會(huì)開(kāi)幕式在貴陽(yáng)舉行,華為董事、質(zhì)量流程IT總裁陶景文發(fā)表了演講。

關(guān)鍵字: 華為 12nm EDA 半導(dǎo)體

8月28日消息,在2024中國(guó)國(guó)際大數(shù)據(jù)產(chǎn)業(yè)博覽會(huì)上,華為常務(wù)董事、華為云CEO張平安發(fā)表演講稱,數(shù)字世界的話語(yǔ)權(quán)最終是由生態(tài)的繁榮決定的。

關(guān)鍵字: 華為 12nm 手機(jī) 衛(wèi)星通信

要點(diǎn): 有效應(yīng)對(duì)環(huán)境變化,經(jīng)營(yíng)業(yè)績(jī)穩(wěn)中有升 落實(shí)提質(zhì)增效舉措,毛利潤(rùn)率延續(xù)升勢(shì) 戰(zhàn)略布局成效顯著,戰(zhàn)新業(yè)務(wù)引領(lǐng)增長(zhǎng) 以科技創(chuàng)新為引領(lǐng),提升企業(yè)核心競(jìng)爭(zhēng)力 堅(jiān)持高質(zhì)量發(fā)展策略,塑強(qiáng)核心競(jìng)爭(zhēng)優(yōu)勢(shì)...

關(guān)鍵字: 通信 BSP 電信運(yùn)營(yíng)商 數(shù)字經(jīng)濟(jì)

北京2024年8月27日 /美通社/ -- 8月21日,由中央廣播電視總臺(tái)與中國(guó)電影電視技術(shù)學(xué)會(huì)聯(lián)合牽頭組建的NVI技術(shù)創(chuàng)新聯(lián)盟在BIRTV2024超高清全產(chǎn)業(yè)鏈發(fā)展研討會(huì)上宣布正式成立。 活動(dòng)現(xiàn)場(chǎng) NVI技術(shù)創(chuàng)新聯(lián)...

關(guān)鍵字: VI 傳輸協(xié)議 音頻 BSP

北京2024年8月27日 /美通社/ -- 在8月23日舉辦的2024年長(zhǎng)三角生態(tài)綠色一體化發(fā)展示范區(qū)聯(lián)合招商會(huì)上,軟通動(dòng)力信息技術(shù)(集團(tuán))股份有限公司(以下簡(jiǎn)稱"軟通動(dòng)力")與長(zhǎng)三角投資(上海)有限...

關(guān)鍵字: BSP 信息技術(shù)
關(guān)閉
關(guān)閉