Bayesian Deep Learning via Subnetwork Inference
貝葉斯深度學習中的子網絡推斷
https://proceedings.mlr.press/v139/daxberger21a/daxberger21a.pdf
![]()
摘要
貝葉斯范式有望解決深度神經網絡的核心問題,如校準性能差和數據利用效率低。然而,將貝葉斯推斷擴展至大規模參數空間通常需引入強約束性近似。本文指出:僅對模型權重的一小部分子集進行推斷,即可獲得準確的預測后驗分布;其余權重則保持為點估計。該子網絡推斷(subnetwork inference)框架使我們得以在子集上使用表達能力更強、原本難以處理的后驗近似方法。具體而言,我們實現了一種簡潔、可擴展的貝葉斯深度學習方法——子網絡線性化拉普拉斯近似(subnetwork linearized Laplace):首先獲得全網絡權重的最大后驗(MAP)估計,隨后基于線性化拉普拉斯近似,在選定子網絡上推斷一個全協方差高斯后驗分布。我們提出一種子網絡選擇策略,旨在最大程度保留模型的預測不確定性。實驗表明,該方法在性能上優于集成方法(ensembles)及對全網絡采用表達能力較弱后驗近似的其他方法。
引言
深度神經網絡(NNs)的一個關鍵缺陷是:其預測往往校準不良且過度自信——尤其當訓練與測試數據分布存在偏移時(Nguyen et al., 2015; Guo et al., 2017)。為支持可靠決策,神經網絡需穩健地量化其預測不確定性(Bhatt et al., 2020),這對醫療、自動駕駛等安全攸關應用尤為重要(Amodei et al., 2016)。
貝葉斯建模(Bishop, 2006; Ghahramani, 2015)通過模型參數的后驗分布,為不確定性量化提供了原則性途徑。不幸的是,神經網絡中精確后驗推斷不可行。盡管貝葉斯深度學習領域近年取得進展(Osawa et al., 2019; Maddox et al., 2019; Dusenberry et al., 2020),現有方法為適配大規模網絡,仍不得不采用不切實際的假設,嚴重限制了后驗分布的表達能力,進而損害不確定性估計質量(Ovadia et al., 2019; Fort et al., 2019; Foong et al., 2019a)。
或許,這些不切實際的推斷近似可被避免。鑒于神經網絡高度過參數化,其精度可由一個小型子網絡良好保持(Cheng et al., 2017);且在低維權重子空間中進行推斷,即可實現準確的不確定性量化(Izmailov et al., 2019)。這引出如下問題:一個完整神經網絡的模型不確定性能否被小型子網絡充分保留? 本文證明:全網絡的后驗預測分布可由子網絡的后驗預測分布良好近似。具體貢獻如下:
提出 子網絡推斷 ——一種可擴展的貝葉斯深度學習通用框架:僅對神經網絡權重的一小部分子集進行推斷,其余權重保持為確定性點估計;由此允許使用原本在大規模網絡中難以處理的高表達力后驗近似方法。我們給出該框架的一種具體實現:先擬合全網絡的最大后驗(MAP)估計,再以線性化拉普拉斯近似在子網絡上推斷全協方差高斯后驗(見圖1)。
提出一種基于 全網絡近似后驗 與 子網絡近似后驗 之間Wasserstein距離的子網絡選擇策略。為提升可擴展性,子網絡選擇階段采用對角近似;選定小規模子網絡后,即可推斷權重間的協方差。實驗發現: 在子網絡選擇階段做近似,對后驗預測的影響遠小于在推斷階段做近似 。
我們在一系列不確定性校準與分布偏移魯棒性基準上評估該方法。實驗表明:高表達力的子網絡推斷方法,其性能優于對全網絡進行低表達力推斷的主流貝葉斯深度學習方法,也優于深度集成(deep ensembles)。
子網絡后驗近似
設 ∈ ?? 為所有神經網絡權重的 D 維向量(即所有層權重矩陣的拼接與展平)。貝葉斯神經網絡(BNNs)旨在捕捉模型不確定性,即由于訓練數據 = {, } 存在多種合理解釋而產生的關于權重 選擇的不確定性。其中, ∈ ?? 為輸出變量(例如分類標簽), ∈ ???? 為特征矩陣。首先,需在 BNN 的權重 上指定一個先驗分布 ()。隨后,我們希望推斷其完整的后驗分布。
這種后驗預測分布將權重中的不確定性轉化為預測中的不確定性。遺憾的是,由于神經網絡(NNs)的非線性特性,推斷精確的后驗分布 (|) 是不可行的;又因權重 的高維度,即使要忠實地近似后驗分布也面臨巨大的計算挑戰。因此,通常采用粗略的后驗近似方法,例如完全因子化近似,即 (|) ≈ ∏?_{d=1} (_d),其中 _d 是權重向量 中的第 d 個權重(Hernández-Lobato & Adams, 2015; Blundell et al., 2015; Khan et al., 2018; Osawa et al., 2019)。然而,已有研究表明,此類近似存在嚴重缺陷(Foong et al., 2019a,b)。
在本工作中,我們質疑廣泛存在的隱含假設——即一個表達能力強的后驗近似必須包含全部 D 個模型權重。相反,我們嘗試僅對權重的一個小規模子集 ? 進行推斷。以下論證支持這一方法:
- 過參數化:Maddox 等人(2020)表明,在局部最優解附近,存在許多方向不會改變神經網絡的預測結果。此外,神經網絡可被大量剪枝而不犧牲測試集精度(Frankle & Carbin, 2019)。這表明,神經網絡的大部分預測能力可集中于一個小規模子網絡中。
- 子模型上的推斷:先前研究1 已提供證據表明,即使推斷未在完整參數空間上進行,仍可有效。例如,Izmailov 等人(2019)和 Snoek 等人(2015)分別在權重的低維投影空間和神經網絡的最后一層上執行推斷。
因此,我們將上述兩個想法結合起來,對公式 (1) 中的后驗分布做出如下兩步近似:
![]()
![]()
與權重剪枝方法的關系。注意,(4)中的后驗近似可以被視為將權重的方差修剪為零。這與權重剪枝方法(Cheng et al., 2017)形成對比,后者將權重本身設置為零。即,權重剪枝方法可以被視為移除權重以保留預測均值(即保持與完整模型接近的準確性)。相比之下,子網絡推斷可以被視為僅移除某些權重的方差——同時保持它們的均值——以保留預測不確定性(例如,保持與完整模型接近的校準)。因此,它們是互補的方法。重要的是,通過不剪枝權重,子網絡推斷保留了完整神經網絡的全部預測能力以保持其預測準確性。
背景:線性化拉普拉斯近似
在本工作中,我們通過使用線性化拉普拉斯近似(MacKay, 1992)對權重上的后驗分布進行近似,從而滿足公式(4)。這是一種可處理的推斷技術,近期已被證明表現優異(Foong 等,2019b;Immer 等,2020),并可事后應用于預訓練模型。下面我們將在一般設定下對其進行描述。
我們將神經網絡函數記為 : ?? → ??。首先,我們定義一個關于神經網絡權重的先驗分布,我們選擇其為完全因子化的高斯分布 () = (; , )。接著,我們尋找后驗分布的一個局部最優解,也稱為權重的最大后驗(MAP)估計:
![]()
隨后,利用在 MAP 估計點處的二階泰勒展開對后驗分布進行近似:
![]()
有趣的是,當采用高斯似然時,以廣義高斯-牛頓(GGN)精度矩陣定義的高斯分布,恰好對應于將神經網絡在處進行一階泰勒展開線性化后的 真實后驗分布 (Khan et al., 2019;Immer et al., 2020)。該局部線性化函數為:
![]()
![]()
![]()
這些閉式表達式頗具吸引力,因其所得預測均值與分類決策邊界 與 MAP 估計所得神經網絡完全一致 。
然而,存儲現代神經網絡(即參數維度 D 極大)權重空間上完整的 D × D
協方差矩陣在計算上是不可行的。盡管已有研究致力于開發更廉價的近似方案(例如僅存儲對角元(Denker & LeCun, 1990)或塊對角元(Ritter et al., 2018; Immer et al., 2020)),但這些近似均以降低預測性能為代價。
線性化拉普拉斯子網絡推斷
我們概述以下程序,用于在子網絡推斷框架內將線性化拉普拉斯近似擴展至大規模神經網絡模型。
![]()
![]()
![]()
![]()
![]()
![]()
![]()
![]()
子網絡選擇
理想情況下,我們希望所選子網絡誘導出的預測后驗分布盡可能接近對全網絡進行推斷所得的預測后驗分布(式11)。這種隨機過程之間的差異通常通過函數空間的 KL 散度(functional Kullback–Leibler divergence)來量化(Sun 等,2019;Burt 等,2020):
![]()
![]()
在權重空間中,我們的目標是最小化全網絡精確后驗分布(式1)與子網絡近似后驗分布(式4)之間的差異。這帶來了兩個挑戰:首先,計算精確后驗分布仍是不可行的;其次,常見的差異度量(如 KL 散度或 Hellinger 距離)對于式(4)中出現的狄拉克δ分布并未良好定義。
為解決第一個問題,我們再次借助第3節中引入的局部線性化方法。線性化模型的真實后驗分布是高斯分布或近似高斯分布2:
我們通過選用平方 2-Wasserstein 距離來解決第二個問題,該度量對于支撐集不相交的分布仍有良好定義。對于全協方差高斯分布(式21)與一個全協方差高斯分布和若干狄拉克δ函數的乘積(式16)的情形,該度量具有如下形式:
![]()
![]()
![]()
![]()
表面上看,我們似乎又回到了最初試圖避免的性能較差的對角假設(Ovadia 等,2019;Foong 等,2019a;Ashukha 等,2020)。然而,這里存在一個關鍵區別:我們是在 子網絡選擇階段 做出對角假設,而非在 推斷階段 ;我們在子網絡 上執行的是 全協方差推斷 。在第6節中,我們將提供證據表明,在子網絡選擇階段采用對角假設是合理的,原因如下:1)相較于在推斷階段做相同假設,它對預測性能的損害要小得多;2)它優于隨機子網絡選擇。
實驗
我們通過實驗評估子網絡推斷的有效性,并將其與以下方法進行比較:(1)對全網絡采用表達能力較弱的推斷方法;(2)深度學習中當前最先進的不確定性量化方法。我們考慮三類基準設置:
1)小規模玩具回歸任務;
2)中等規模表格數據回歸任務;
3)基于 ResNet-18 的圖像分類任務。
更多實驗結果與設置細節分別見附錄 A 與附錄 D。
6.1 子網絡推斷如何保留后驗預測不確定性?
我們首先定性評估:在選定子網絡上采用全協方差高斯后驗所得預測分布,與以下方法所得預測分布的對比情況:
1)全網絡上的全協方差高斯后驗(Full Cov);
2)全網絡上的因子分解高斯后驗(Diag);
3)僅在網絡最后一層上采用全協方差高斯后驗(Final layer)(Snoek 等,2015);
4)點估計(MAP)。
對于子網絡推斷,我們同時考慮兩種子網絡選擇策略:第5節所述的Wasserstein策略(Wass)與均勻隨機選擇策略(Rand),以構建僅包含模型參數總量50%、3%和1%的子網絡。在此玩具實驗中,精確計算后驗邊際方差以指導子網絡選擇尚屬可行。
我們的神經網絡包含2個ReLU隱藏層,每層50個隱藏單元。采用同方差高斯似然函數,其噪聲方差通過最大似然估計優化。我們在網絡權重(不含偏置)上采用GGN拉普拉斯推斷,并結合式(18)中的線性化預測分布。因此,所考察的所有方法共享相同的預測均值,便于更公平地比較其不確定性估計。
我們將全網絡先驗精度設為? = 3?(經驗上表現良好),子網絡先驗精度設為?= ? S/D。
我們采用 Antorán 等人(2020)提出的合成一維回歸任務——輸入數據形成兩個分離的簇,從而可檢驗模型對“簇間區域”的不確定性響應(Foong 等,2019b)。結果如圖2所示:
子網絡推斷在 推斷更少權重 的同時,比對角高斯或僅最后一層推斷更能保留全網絡推斷的不確定性;
通過捕捉權重間的相關性,子網絡推斷可在數據簇之間維持較高不確定性;
該特性在隨機與Wasserstein子網絡選擇下均成立,但后者在子網絡更小時能保留更多不確定性;
相較于對角拉普拉斯,其顯著優勢表明: 在子網絡選擇階段采用對角假設,但在推斷階段轉而使用全協方差高斯后驗(即本文做法),顯著優于直接對推斷后驗采用對角假設 (參見第5節)。
綜上,結果表明:在精心選擇的子網絡上進行高表達力推斷,相較對全網絡采用粗糙近似,能更好地保留預測不確定性。
6.2 大型模型中的子網絡推斷 vs 小型模型中的全網絡推斷
![]()
我們首先獲得每個神經網絡權重的最大后驗(MAP)估計及其同方差似然函數的噪聲方差。隨后,對每個網絡執行全網絡 GGN-Laplace 推斷。我們還使用所提出的 Wasserstein 規則修剪每個網絡的權重方差,使剩余方差數量匹配每一個較小網絡的規模。我們采用對角拉普拉斯近似來廉價地估算用于子網絡選擇的后驗邊際方差。我們利用式(12)和(18)中的線性化方法計算預測分布。因此,具有相同權重數量的神經網絡會產生相同的預測均值;增加所考慮的權重方差數量只會提升預測不確定性。
我們選用三個規模遞增的表格數據集(輸入維度、樣本點數):wine(11維,1439點)、kin8nm(8維,7373點)和 protein(9維,41157點)。我們采用其標準訓練-測試劃分(Hernández-Lobato & Adams, 2015),以及專為測試分布外不確定性的變體劃分(Foong 等, 2019b)。具體細節見附錄 D.4。對于每個劃分,我們將訓練數據的15%留作驗證集,用于在尋找MAP估計及選擇權重先驗精度時進行早停。所有模型和數據集保持其他超參數固定。結果如圖3所示。
![]()
我們呈現平均測試對數似然(LL)值,因其同時考慮了準確率與不確定性。當結合全網絡推斷時,規模更大的模型通常表現最佳,盡管 Wine-gap 和 Protein-gap 是例外。有趣的是,即使我們僅在與小型模型同等規模的子網絡上進行推斷,這些大型模型的表現依然最優。我們推測,這源于權重后驗神經網絡模型中存在大量退化方向(即冗余權重)(Maddox 等,2020)。小型模型的全網絡推斷會同時捕獲有用與無用權重的信息;而在大型模型中,我們的子網絡選擇策略使我們能將更多計算資源用于建模信息豐富的權重方差與協方差。在6個數據集中有3個,我們發現:隨著推斷所涉及權重數量的增加,LL 值出現驟升,隨后進入平臺期。這種平臺現象可能是因為大部分信息豐富的權重方差已被納入模型。考慮到計算 GGN 的成本遠高于神經網絡訓練成本,這些結果表明: 在相同計算量下,對大型模型執行子網絡推斷比對小型模型執行全網絡推斷更優 。
6.3 分布偏移下的圖像分類
我們現在評估采用子網絡推斷的大型卷積神經網絡在圖像分類任務中對分布偏移的魯棒性,并與以下基線方法進行比較:
點估計網絡(MAP);
對全網絡采用表達能力較弱推斷的貝葉斯深度學習方法:MC Dropout(Gal & Ghahramani, 2016)、對角拉普拉斯、VOGN(Osawa 等,2019)——三者均假設權重后驗完全因子化;以及 SWAG(Maddox 等,2019)——假設后驗為“對角+低秩”結構;
深度集成(deep ensembles)(Lakshminarayanan 等,2017)——目前被公認為深度學習不確定性量化的最先進方法(Ovadia 等,2019;Ashukha 等,2020)。
我們采用5個網絡構成的集成(據 Ovadia 等建議),并對 MC Dropout、對角拉普拉斯與 SWAG 均采樣16次。Dropout 概率設為 0.1;對角拉普拉斯的先驗精度通過網格搜索確定為 = 4 × 10?。所有方法均應用于 ResNet-18(He 等,2016):包含1個輸入卷積塊、8個殘差塊和1個線性層,共計 11,168,000 個參數。
對于子網絡推斷,我們采用式(19)中的線性化預測分布;并使用 Wasserstein 子網絡選擇策略,僅保留 0.38% 的權重,得到一個僅含 42,438 個權重的子網絡——這是當前計算條件下可處理全協方差矩陣的最大規模(其大小為)。我們采用對角 SWAG(Maddox 等,2019)估算子網絡選擇所需的邊際權重方差。我們曾嘗試對角拉普拉斯,但發現其所選權重對應于在訓練點上神經網絡雅可比恒為零的位置(即“死亡 ReLU”);此類權重的后驗方差雖大(近似先驗),但對網絡輸出幾乎無影響。SWAG 不受此問題困擾,因其忽略了訓練梯度為零的權重。子網絡推斷的先驗精度經網格搜索設為 = 500。
為評估原則性子網絡選擇的重要性,我們另設一基線:均勻隨機選擇子網絡(記為 Ours (Rand))。我們開展以下兩個實驗,結果見圖4:
旋轉 MNIST:參照(Ovadia 等,2019;Antorán 等,2020),所有方法在 MNIST 上訓練,并在逐步增大的數字旋轉角度下評估其預測分布。盡管所有方法在原始 MNIST 測試集上表現良好,但當旋轉角度超過 30 度時,準確率迅速下降。就對數似然(LL)而言,集成在基線方法中表現最優;而子網絡推斷的 LL 顯著高于幾乎所有基線(包括集成),唯一例外是 VOGN(表現略優)。值得注意的是,Ovadia 等(2019)亦觀察到:平均場變分推斷(VOGN 屬于此類)在 MNIST 上表現極強,但在更大規模數據集上性能顯著下降。子網絡推斷在分布內能做出準確預測,同時對分布外樣本賦予比基線更高的不確定性。
損壞 CIFAR:同樣參照(Ovadia 等,2019;Antorán 等,2020),所有方法在 CIFAR10 上訓練,并在經 16 類不同損壞(每類5個強度等級)的數據上評估(Hendrycks & Dietterich, 2019)。由于局部線性化使預測均值與 MAP 一致,子網絡推斷在預測誤差上與 MAP 網絡相當;集成與 SWAG 準確率最高。然而,子網絡推斷的獨特優勢在于過自信程度最低——在所有損壞強度等級下,其對數似然均優于所有基線方法。此時 VOGN 表現較差;但這與其在 MNIST 上的優異表現看似矛盾——實則再次印證了 Ovadia 等(2019)的發現:平均場變分推斷在 MNIST 上表現良好,但在更大數據集上性能下降。
此外,在兩項基準測試中,隨機選擇子網絡的表現顯著劣于我們提出的 Wasserstein 選擇策略,凸顯了子網絡選擇方式的重要性。
綜上,這些結果表明:子網絡推斷在不確定性校準與分布偏移魯棒性方面,優于其他主流不確定性量化方法。
![]()
![]()
適用范圍與局限性
多輸出模型中的雅可比矩陣計算仍具挑戰性。在當前主流深度學習框架中,由于采用反向模式自動微分,其計算所需反向傳播次數等于模型輸出數量。這使得線性化拉普拉斯方法難以應用于語義分割(Liu 等,2019)或類別數極多的分類任務(Deng 等,2009)。需注意,該問題僅限于線性化拉普拉斯方法本身;其他無此限制的推斷方法仍可納入本框架使用。
先驗精度 的選擇在很大程度上決定了拉普拉斯近似的性能。我們提出的子網絡先驗精度更新方案依賴于對全網絡已有合理參數設定。然而,由于全網絡推斷常不可行,目前選擇 的最佳方式是直接在子網絡近似上進行交叉驗證。
海森矩陣的存儲需求限制了子網絡權重的最大規模。例如,存儲 4 萬個權重對應的海森矩陣約需 6.4 GB 內存。對于現代 Transformer 等超大規模模型,可計算的子網絡僅占總權重極小比例。盡管我們已證明優異性能未必依賴大型子網絡(見圖5),但探索更優的子網絡選擇策略仍是未來研究的關鍵方向。

相關工作
貝葉斯深度學習:針對神經網絡權重后驗分布 p ( w ∣ D ) 的刻畫已有大量研究。迄今為止,哈密頓蒙特卡洛(Hamiltonian Monte Carlo, HMC;Neal, 1995)仍是貝葉斯神經網絡(BNNs)中近似推斷的黃金標準。盡管其在漸近意義上無偏,但基于采樣的方法難以擴展至大規模數據集(Betancourt, 2015)。因此,近年來更流行的做法是在某一近似分布族(通常為高斯分布)中尋找最優代理后驗。其中最早的是 MacKay(1992)提出的拉普拉斯近似,他也同時建議使用線性化模型的后驗來近似預測后驗(Khan 等,2019;Immer 等,2020)。隨著更大規模神經網絡的普及,能捕捉權重間相關性的代理分布因計算不可行而受限;因此,絕大多數現代方法轉而采用平均場假設(Blundell 等,2015;Hernández-Lobato & Adams,2015;Gal & Ghahramani,2016;Mishkin 等,2018;Osawa 等,2019),但這犧牲了模型表達能力(Foong 等,2019a)并導致實證性能下降(Ovadia 等,2019;Antorán 等,2020)。Farquhar 等(2020)曾提出:在更深網絡中,平均場假設或許并不構成限制;但我們的實證結果似乎與該觀點相悖。我們發現,通過降低權重空間維度來擴展那些能考慮權重相關性的近似方法(如 MacKay,1992;Louizos & Welling,2016;Maddox 等,2019;Ritter 等,2018),其性能優于對角近似。由此我們認為,該方向仍需進一步深入研究。
神經線性方法(Neural Linear Methods):此類方法可視為廣義線性模型,其基函數由神經網絡前 l ? 1 層定義;即僅對神經網絡最后一層進行推斷,其余層保持固定(Snoek 等,2015;Riquelme 等,2018;Ovadia 等,2019;Ober & Rasmussen,2019;Pinsler 等,2019;Kristiadi 等,2020)。它們也可被視作子網絡推斷的特例——其中子網絡被簡單地定義為網絡的最后一層。
子空間推斷:神經網絡剪枝這一子領域旨在通過識別實現準確預測所需的最小權重子集來提升計算效率(例如 Frankle & Carbin,2019;Wang 等,2020)。我們的工作與其不同:我們保留全部網絡權重,但目標是找到一個用于概率推理的小型權重子集。與我們更密切相關的是 Izmailov 等(2019)的工作,他們提出在低維權重子空間(例如由 SGD 軌跡主成分構造的子空間)上進行推斷。此外,若干近期方法在變分推斷框架下采用低秩參數化來近似后驗(Rossi 等,2019;Swiatkowski 等,2020;Dusenberry 等,2020),這亦可視為在權重空間的某種隱式子空間上進行推斷。相比之下,我們提出了一種顯式識別與預測不確定性相關權重子集的技術——即尋找坐標軸對齊的子空間(axis-aligned subspaces)。
結論
本研究得出三項主要結論:
1)在神經網絡中建模權重相關性對獲得可靠的預測后驗至關重要;
2)在考慮此類相關性的前提下,單峰后驗近似即可與多峰近似(如深度集成)相媲美;
3)為獲得可靠的預測后驗,無需對全部權重進行推斷。
基于上述洞見,我們構建了一種將貝葉斯推斷擴展至大規模神經網絡的框架:僅對權重子集進行后驗近似,其余權重保持為確定性點估計。該框架將計算成本與總參數量解耦,從而可靈活權衡計算開銷與近似質量,并得以采用更具表達力的后驗近似(如全協方差高斯分布)。
線性化拉普拉斯子網絡推斷方法可事后應用于任意預訓練模型,極具實用價值。實證分析表明,該方法:
1)相較于全網絡采用粗糙近似的方法,表達能力更強,能保留更多不確定性;
2)允許我們使用容量更大、函數擬合能力更廣的神經網絡,而不犧牲不確定性估計質量;
3)性能與當前最先進的不確定性量化方法(如深度集成)相當。
我們期待未來進一步探索:將子網絡推斷與不同近似推斷方法結合、開發更優的子網絡選擇策略,并深入研究子網絡對預測分布特性的影響。
原文鏈接:https://proceedings.mlr.press/v139/daxberger21a/daxberger21a.pdf
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.