AI 和 LLM 的進步通常歸因于三個方面的持續(xù)改進:模型、數(shù)據(jù)、計算。三者互相關(guān)聯(lián)。要跑起那些參數(shù)量龐大的模型,就需要足夠的計算資源來支撐。Llama 3 最大的模型超過 4000 億參數(shù)在 16000 塊 GPU 上訓練了數(shù)周乃至數(shù)月,優(yōu)化計算意味著在更低的成本下訓練更大的模型。
本文將介紹 GPU 的核心特性,并據(jù)此討論如何設(shè)計更快的算法。
GPU 與 CPU 的區(qū)別
CPU 的優(yōu)化目標是單任務延遲,盡可能快地完成一個任務然后轉(zhuǎn)向下一個,這對通用計算是非常合理的。但是GPU 則不同,它優(yōu)化的是吞吐量追求的是同時完成多個并行任務。打個比方:CPU 像一個能力極強的工人,GPU 像一群普通工人同時干活。在 LLM 訓練這種大規(guī)模并行處理場景下GPU 的架構(gòu)天然占優(yōu)。
繼續(xù)用工廠來打比方。GPU 可以看作一個龐大的工廠城鎮(zhèn)。城鎮(zhèn)中有多個"工廠集群"(技術(shù)上叫流式多處理器,SM),每個集群包含多個工廠(流式處理器,SP)和一個小倉庫(共享內(nèi)存)。整個城鎮(zhèn)里還有一個全局倉庫(DRAM),離各集群更遠但容量大得多。
![]()
類比雖然簡化但說明了 GPU 中一條核心:集群內(nèi)的小倉庫訪問速度遠快于全局倉庫,代價是容量小得多。
全局倉庫的運輸通道到底有多慢?過去 20 年間,硬件浮點運算能力(對應工廠車間的加工速度)提升了 60000 倍,DRAM 帶寬只提升了 100 倍,互連帶寬更是只有 30 倍。
![]()
過去的瓶頸在計算,但是現(xiàn)在的瓶頸在內(nèi)存帶寬。既然數(shù)據(jù)搬運才是真正的瓶頸,減少搬運次數(shù)和搬運量就是讓 GPU 跑得更快的關(guān)鍵。以下五個技巧,都圍繞這一思路展開,來自 CS336 課程。
技巧 1:低精度計算
矩陣乘法中,數(shù)字精度是可以選擇的。精度越高,存儲一個數(shù)字所需的字節(jié)越多:9.327595 比 9.33 占的空間大。用低精度數(shù)字意味著搬運的"貨物"更少,在擁堵的"道路"上花費的時間也更短。
這意味著用 fp16 代替 fp32,但并非訓練的所有階段都需要低精度,只需在數(shù)據(jù)搬運階段降到 fp16 即可。具體做法是:輸入以 fp16 格式傳入,矩陣乘法在 32 位精度下完成(計算并非瓶頸,而且高精度可以防止舍入誤差的逐步累積),輸出再降回 fp16 用于傳輸。
![]()
回到工廠類比:道路擁堵(下圖紅線),所以進出工廠的箱子越小越好。工廠內(nèi)部空間充裕,可以在大空間中完成加工,加工完成后再打包成小箱子運出。
技巧 2:算子融合
假設(shè)工廠有三步操作:正方形變圓形,圓形變?nèi)切危切巫冃切巍H绻客瓿梢徊骄桶寻氤善匪突貍}庫再取回來做下一步,那來回搬運的次數(shù)非常多。
![]()
算子融合的做法是把多步操作在工廠內(nèi)一次性完成,省去中間產(chǎn)品的反復搬運。
![]()
實現(xiàn)方式有兩種:手寫低級代碼控制融合細節(jié),或者直接用 torch.compile 自動完成優(yōu)化。
技巧 3:重計算
這個場景稍微復雜一些,假設(shè)工廠從倉庫取了一個正方形,依次加工為圓形、三角形、星形。星形被送回倉庫供后續(xù)使用。但到了最終步驟,四種形狀全部要用——正方形、圓形、三角形、星形。工廠內(nèi)部存不下東西,所有存儲必須依賴倉庫。
安排生產(chǎn)線有兩條路:
- 選項 1:加工過程中把圓形和三角形也送回倉庫保管。需要的時候直接取回。
- 選項 2:不保存中間形狀,丟掉就丟掉。需要的時候從正方形重新加工一輪。
選項 1 省了重新加工的電力,但倉庫搬運量增大。選項 2 搬運量小,但要額外消耗算力。這是一個內(nèi)存與計算之間的權(quán)衡。
既然瓶頸在道路擁堵而非車間產(chǎn)能,重計算(選項 2)是更合理的選擇:重新加工成本低,但從倉庫搬運的成本可能高出幾個數(shù)量級。用算力換內(nèi)存帶寬,劃算。
技巧 4:內(nèi)存合并訪問
倉庫有個特點:貨物按板條箱整箱發(fā)出。工廠請求任何一件物品,倉庫都會把整個板條箱送過來。優(yōu)化的要點在于:把需要的物品盡量集中在同一個箱子里。
假設(shè)每箱 4 件,工廠需要 8 件。如果這 8 件集中在 2 個箱子里,取 2 箱就夠了。如果散落在 8 個箱子中,就得搬 8 箱——搬運成本翻了四倍。
技術(shù)上,DRAM 以"突發(fā)模式"讀取,每次讀取返回一段連續(xù)字節(jié)。即使處理器只需要其中一個地址的數(shù)據(jù),整個突發(fā)段也會被送過來。當所有線程的訪問地址落在同一個突發(fā)段內(nèi)時,只需一次 DRAM 請求,這種情況稱為完全合并訪問。
一個直接的推論:把維度(比如詞匯表大小)對齊到 64 的倍數(shù)會帶來可觀的速度提升。
![]()
原因很簡單:分塊操作(見下一個技巧)需要沿突發(fā)段的邊界讀取數(shù)據(jù),如果分塊邊界與突發(fā)段不對齊,讀取次數(shù)會急劇增加。
技巧 5:分塊
分塊的核心思想:把大矩陣切成小塊,加載到共享內(nèi)存(集群內(nèi)的小倉庫)中,避免反復訪問全局內(nèi)存。
以兩個 4x4 矩陣 A 和 B 的乘法為例,結(jié)果是 4x4 矩陣 C。計算 C 的某幾個元素時,需要在 A 和 B 矩陣上多次跨行/跨列讀取,每次讀取都要訪問全局內(nèi)存。
![]()
分塊的做法是將 A 和 B 各切成四塊。小塊可以整塊加載到共享內(nèi)存中。先加載紅色塊,計算部分和(圖中橙色部分):
![]()
接著加載下一組塊,繼續(xù)累加部分和。總計算量不變,但每一步都在共享內(nèi)存中完成而非反復訪問全局內(nèi)存,節(jié)省的時間相當可觀。
FlashAttention
有了上面五個技巧做鋪墊,可以來看 FlashAttention 了。
先簡要回顧注意力機制。權(quán)重矩陣將隱藏向量投影為 Q、K、V,然后對每個詞的 q 和 k 向量求點積(等價于 Q × K.T 的矩陣乘法),得到原始注意力分數(shù)——即每個查詢詞對各個鍵詞的關(guān)注程度。對原始分數(shù)做 softmax 歸一化,使其加和為 1。
數(shù)值穩(wěn)定性方面,取指數(shù)之前先減去最大值。e12 已經(jīng)是 162,755,超出 fp16 的上限 65,504,直接計算會溢出。減去最大值不改變 softmax 結(jié)果,但規(guī)避了溢出(詳見附錄)。
![]()
歸一化后的 softmax 分數(shù)與每個詞的"值"向量相乘、求和,得到最終的注意力輸出。
![]()
回到 FlashAttention。Q 和 K 相乘產(chǎn)生一個 N × N 矩陣(N 為序列長度)。當上下文窗口很大時,這個矩陣無法整個放入共享內(nèi)存。
解決方案是沿 N 維度分塊。比如上下文窗口 1028,按 64 切塊,每塊可以載入共享內(nèi)存。這樣仍有完整的點積結(jié)果(無需計算部分和),只是逐塊填充結(jié)果矩陣。
![]()
分塊本身是標準操作,棘手的部分在于 softmax 和后續(xù)的值向量加權(quán)求和。計算 softmax 通常需要整行數(shù)據(jù)來做歸一化,而訪問整行意味著要回全局內(nèi)存取數(shù)據(jù)。FlashAttention 的突破在于"在線 softmax"——softmax 計算和值向量加權(quán)求和可以在塊內(nèi)一次性完成,無需看到全行數(shù)據(jù)。關(guān)鍵條件是最終操作是加權(quán)求和,這給了逐塊修正的數(shù)學余地。
下面用一個例子來說明。假設(shè) QK 矩陣乘法產(chǎn)生了六個原始分數(shù),表示某個查詢詞對六個其他詞的關(guān)注度。常規(guī)做法是一次性對六個分數(shù)做 softmax 再與六個值向量加權(quán)求和,得到 A:
![]()
但遍歷整個長度 N 的序列在塊內(nèi)放不下。于是按"在線"方式進行:將六個分數(shù)分為三個塊(每塊 2 個元素),逐塊處理。第一個塊中只有兩個原始分數(shù),先基于這兩個值做計算:
![]()
這一步不做歸一化。雖然可以用當前的和(1+0.0082)歸一化,但后續(xù)塊會改變總和,到頭來還得修正。所以更好的做法是記錄歸一化分母的累積值最后一步統(tǒng)一歸一化。
進入第二個塊。目標是得到與一次性看到所有四個值相同的結(jié)果。四個值的全局最大值是 12,第一個塊需要把自己的最大值傳遞過來。累積的加權(quán)和與歸一化分母也要一并傳遞。
![]()
到目前為止,如果只有四個值,取 A_(1+2) 除以歸一化總和 1.3098 就能得到最終結(jié)果。
最后一個邊界情況是:新塊出現(xiàn)了更大的最大值。第三個塊的最大值從 12 變成了 13,但之前的 A_(1+2) 是按 max=12 算的。要讓結(jié)果與一次性看到全部六個值一致,就需要修正之前的計算——將所有舊指數(shù)乘以 e^(-1)(即 e^(12-13)),補償最大值的變化。
不需要逐個回去修正每個指數(shù)值,只需將 A_(1+2) 和歸一化分母整體乘以 e^(-1) 即可:
![]()
最后用累積分子 A_(1+2+3) 除以更新后的分母 1.4955,得到結(jié)果。整個過程從未回訪之前的塊:只要跟蹤最大值和歸一化分母,就能逐塊完成 softmax。這些操作都在共享內(nèi)存中進行,不必頻繁訪問全局內(nèi)存。
效果如何?FlashAttention 原始論文顯示,在 GPT-2 上注意力計算的耗時減少了數(shù)倍。
![]()
處理大規(guī)模模型時,內(nèi)存放置策略,比如盡量在共享內(nèi)存中完成計算對整體性能的影響遠超我們的想象。
并行計算簡介
以上討論都局限在單 GPU 上,小模型沒問題,但現(xiàn)代大型 LLM 根本裝不進一塊 GPU。Llama 3 用了 16K 塊 GPU,核心問題變成了:如何將訓練計算分配到多臺機器上,再將結(jié)果匯總起來。
在展開不同的并行策略之前,先回顧訓練流程。以一個 2 層神經(jīng)網(wǎng)絡(luò)為例,batch size 16,使用 Adam 優(yōu)化器(為每個參數(shù)維護一階和二階矩估計)。
![]()
數(shù)據(jù)并行
拆分計算的第一種方式是拆分數(shù)據(jù)。假設(shè)有效 batch size 為 16,但每塊 GPU 內(nèi)存只夠放 4 個樣本。單 GPU 下需要跑 4 輪前向傳播來累積梯度,再做一次反向傳播,即梯度累積。
![]()
數(shù)據(jù)并行的做法:把 16 個樣本分給 4 塊 GPU,每塊拿 4 個樣本,各自并行執(zhí)行前向傳播。問題在于如何聚合梯度。
一種方式是匯集所有激活值來計算平均 loss 再求梯度,但更聰明的做法是在每塊 GPU 上各自計算 4 個樣本的梯度再求和——搬運的數(shù)據(jù)量更小,數(shù)學上完全等價。梯度求和后傳回各機器,分別更新本地模型。
![]()
這個操作的技術(shù)術(shù)語是 all-reduce:每臺機器貢獻各自的梯度,合并后每臺機器都拿到結(jié)果。雖然圖示中畫了一個"聚合器"(灰色方框),實際的 all-reduce 實現(xiàn)通常是環(huán)形傳遞——GPU 之間互相傳梯度,最終全部拿到平均值。
4 塊 GPU 并行,有效 batch size 仍然是 16,速度卻快了很多。但有一個效率問題:每塊 GPU 都在做完整模型的更新,要維護所有參數(shù)的 Adam 狀態(tài)(一階矩和二階矩)。
內(nèi)存充裕時這不成問題。但實際上每塊 GPU 里復制了完整的模型參數(shù)、梯度、主權(quán)重以及 Adam 優(yōu)化器狀態(tài)。Adam 的狀態(tài)量是模型參數(shù)量的兩倍,內(nèi)存占用很大。
對于大模型,內(nèi)存成為硬瓶頸。ZeRO(Zero Redundancy Optimizer)針對的就是這個問題:一組內(nèi)存優(yōu)化技術(shù),在保持數(shù)據(jù)并行的前提下大幅減少每塊 GPU 的內(nèi)存占用。
ZeRO Stage 1
核心思想是讓每塊 GPU 只負責更新一部分參數(shù)。比如將每層參數(shù)分成四份,GPU 1 負責 Part 1,GPU 2 負責 Part 2,以此類推。
走一遍流程:數(shù)據(jù)仍然拆分到四塊 GPU 上,每塊 GPU 基于自己看到的 4 個樣本計算完整的梯度——到這里還是標準的數(shù)據(jù)并行。但梯度匯總后不再發(fā)回給所有人,而是按參數(shù)分片發(fā)送:每塊 GPU 只收到自己負責那部分參數(shù)的梯度。術(shù)語上叫 reduce-scatter——每人只拿到合并結(jié)果的一個切片。
各 GPU 只更新自己負責的那部分參數(shù),也只需要保留該部分的優(yōu)化器狀態(tài)。更新完成后,各 GPU 把自己的參數(shù)切片分享出去,拼接成完整模型。術(shù)語上叫 all-gather——每人貢獻一個切片,每人拿到完整拼接結(jié)果。
![]()
ZeRO Stage 1
整個過程可以概括為兩階段:第一階段按數(shù)據(jù)維度拆分,各 GPU 算全參數(shù)梯度再匯總;第二階段按參數(shù)維度拆分,各 GPU 只更新自己負責的參數(shù)切片,最后拼接出完整模型。
效果是每塊 GPU 只保留一小部分優(yōu)化器狀態(tài),內(nèi)存節(jié)省很可觀。計算量方面,reduce-scatter 加 all-gather 的總通信量與樸素數(shù)據(jù)并行中的 all-reduce 等價,沒有額外開銷。
ZeRO Stage 2
ZeRO Stage 2 更進一步——不僅優(yōu)化器狀態(tài)分片,梯度本身也要分片。
關(guān)鍵在于反向傳播是逐層進行的。每一層的梯度算完后,立刻將不屬于自己管轄的部分發(fā)送給對應 GPU 并丟棄。不需要在任何時刻存儲全部層的完整梯度。
在 ZeRO Stage 1 的流程中,要改變的是這一部分:
![]()
改為逐層處理梯度,紅框中的部分變成如下流程:
![]()
第 2 層的梯度算完,把不負責的部分發(fā)出去、丟棄,然后處理第 1 層,重復同樣的步驟。層數(shù)多的 LLM 從中獲益明顯——不需要同時存儲所有層的梯度。代價是逐層通信帶來少量額外開銷。
ZeRO Stage 3,也稱為完全分片數(shù)據(jù)并行(FSDP)
ZeRO Stage 3 把分片推到了極致——連模型權(quán)重都只存各自負責的那部分。這意味著前向傳播也會受到影響。
流程同樣是逐層進行的。到第 1 層時,執(zhí)行 all-gather,各 GPU 各出自己的權(quán)重切片,拼出完整的第 1 層。每塊 GPU 用完整的第 1 層權(quán)重和各自的數(shù)據(jù)計算激活值,算完后立刻丟棄不屬于自己的權(quán)重切片。第 2 層同理。
反向傳播與 ZeRO Stage 2 類似,但多了一步:每層計算梯度前要先 all-gather 把完整權(quán)重拼出來(因為本地沒有完整權(quán)重),算完后再丟棄非本地切片。
本質(zhì)上是按需逐層從各 GPU 拼出模型,任何時候都沒有一塊 GPU 持有全部權(quán)重。通信開銷增加了,但內(nèi)存節(jié)省巨大。對于給定的 GPU 配置,ZeRO Stage 3 能訓練的模型規(guī)模遠超前兩個階段。
CS336 課程給出的數(shù)據(jù):8 塊 A100 80GB GPU 上,不同策略可訓練的最大模型尺寸差異很大。
![]()
同樣的硬件配置下,ZeRO Stage 3 能訓練的模型大了很多。
不過數(shù)據(jù)并行有一個約束條件:batch size。batch size 不能小于 GPU 數(shù)量:沒法給一臺機器半個樣本。而 batch size 越大收益越低:大 batch 降低數(shù)據(jù)噪聲方差,但超過一定閾值后邊際收益接近于零。batch size 的"自然上限"直接限制了數(shù)據(jù)并行的擴展規(guī)模。
模型并行
除了按數(shù)據(jù)維度拆分,還可以按模型維度切分,即模型并行。這里介紹兩種形式:流水線并行和張量并行。
模型并行:流水線并行
流水線并行沿深度方向切分模型,一層分配給一塊 GPU。問題在于前向和反向傳播都是逐層串行的——每層需要前一層的輸出才能開始計算,GPU 在等待輸入時空閑,形成"氣泡"。
![]()
縮小氣泡的方法是引入 mini-batch 級別的流水線:第二塊 GPU 處理某個 mini-batch 的第二層時,第一塊 GPU 可以開始處理下一個 mini-batch 的第一層。
![]()
流水線并行的優(yōu)勢在于內(nèi)存節(jié)省,每個設(shè)備只存一層的參數(shù)以及通信模式簡單,只需將激活值從一層傳到下一層。這種簡單的通信特性使它適合部署在跨集群等帶寬較低的網(wǎng)絡(luò)鏈路上。
張量并行
張量并行沿寬度方向切分,把單層內(nèi)的矩陣乘法分配到多塊 GPU 上并行執(zhí)行,各自得到部分結(jié)果后再跨 GPU 求和。概念上類似于分塊運算,區(qū)別在于分塊是串行處理各塊,張量并行是并發(fā)處理。
通信量很大——每層都要同步激活值。節(jié)點內(nèi)部 NVLink 帶寬在 600-900 GB/s,跨節(jié)點互連慢 10-20 倍。實踐經(jīng)驗表明:張量并行擴展到 8 塊 GPU 以上時,收益會急劇衰減。所以通常將張量并行限制在單個節(jié)點(最多 8 塊 GPU)內(nèi)。
張量并行有一個獨特優(yōu)勢:不依賴 batch size。batch size 是數(shù)據(jù)并行和流水線并行共享的約束資源,張量并行與之正交,可以疊加使用而不消耗這項資源。
組合不同形式的并行
幾種并行策略分別沿不同維度拆分計算:數(shù)據(jù)維度、模型深度維度、模型寬度維度。實際訓練中通常是多種策略的組合。
經(jīng)驗法則很簡單:先解決內(nèi)存問題,確保模型能裝進 GPU。裝不下就用流水線并行、張量并行、ZeRO Stage 3 等節(jié)省內(nèi)存的技術(shù)。模型能裝下之后,再用數(shù)據(jù)并行等手段堆算力,加快每個 batch 的處理速度。
附錄:Softmax 解釋
softmax 將一組原始分數(shù)變換為加和為 1 的概率分布:對每個分數(shù)取指數(shù),然后除以所有指數(shù)之和。
![]()
以三個分數(shù)(12, 7.2, 9.1)為例:
![]()
問題在于指數(shù)值增長極快。e12 已經(jīng)是 162,755,超過 fp16 的最大值 65,504。理論值雖然正確,但計算過程中會溢出。解決辦法是將分子和分母同時除以 e^(max),等價于從所有原始分數(shù)中減去最大值:
![]()
數(shù)學上結(jié)果完全一致,但避免了溢出。可能出現(xiàn)下溢(值太接近零),不過下溢時 0 已經(jīng)是足夠好的近似。這一數(shù)學技巧被幾乎所有 LLM 的 softmax 實現(xiàn)采用。
總結(jié)
這篇文章從 GPU 架構(gòu)講到并行策略,涉及的是把模型從玩具規(guī)模拉到生產(chǎn)規(guī)模所必須面對的工程問題。在專業(yè)團隊中,訓練一個無法放入單塊 GPU 的 LLM 是常態(tài),優(yōu)化訓練成本也是日常工作的一部分。理解底層硬件和并行機制,是做好這些工作的前提。
https://avoid.overfit.cn/post/8b2888b82d7c40c3b60e7e8847dafc9f
by Joseph
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.