<cite id="ffb66"></cite><cite id="ffb66"><track id="ffb66"></track></cite>
      <legend id="ffb66"><li id="ffb66"></li></legend>
      色婷婷久,激情色播,久久久无码专区,亚洲中文字幕av,国产成人A片,av无码免费,精品久久国产,99视频精品3
      網易首頁 > 網易號 > 正文 申請入駐

      從零開始用自定義 Triton 內核編寫 FlashAttention-2

      0
      分享至


      本文實現 FlashAttention-2 的前向傳播,具體包括:為 Q、K、V 設計分塊策略;流式處理 K 和 V 塊而非物化完整注意力矩陣;實現在線 softmax 算法保證數值穩定性;支持因果和非因果兩種注意力模式;用 Triton autotuner 自動調優內核配置;最后用 PyTorch 驗證正確性。



      FlashAttention vs. standard attention vs torch2.2 (spda flashattn) TFLOP/s benchmarks

      標準注意力為什么是內存受限的

      標準注意力的瓶頸不在浮點運算量而在內存帶寬。普通注意力計算 S = QK? 之后,要把完整的 N × N 矩陣寫入 HBM再讀回來算 softmax 并存儲然后再讀一次乘以 V,每個元素被訪問 2-4 次每次都走 HBM。

      序列長度 16K 時,這個矩陣包含 16,3842 ≈ 2.56 億個元素。

      反復在 HBM 和計算單元之間搬運這幾億個值,而HBM 是 GPU 上容量最大的內存也是最慢的。A100 上從 HBM 讀數據比從片上 SRAM 讀大約慢 15 倍。大張量和模型權重都放在這里,所以寫內核的首要目標就是減少 HBM 流量把高頻訪問的數據留在寄存器或共享內存里。

      核心方案——讓注意力具備 IO 感知能力

      FlashAttention 的核心思想是讓注意力變得 IO 感知。所謂 IO 感知就是真正理解并利用一個這個定義:片上 SRAM 比 HBM 快幾個數量級。NVIDIA A100 有 40-80GB HBM(也就是那個讓你頻繁遭遇 CUDA OOM 的全局內存)帶寬 1.5-2.0 TB/s;每個 SM 有 192KB SRAM,共 108 個 SM,帶寬估計 19TB/s 左右。

      GPU 硬件有個黃金法則:

      把數據搬到內存層次的上層然后留在那里。除非萬不得已別回 HBM。

      標準注意力完全無視這條規則,把 HBM 讀寫當成零成本操作。FlashAttention 計算的結果和標準縮放點積注意力完全一樣:

      S = QK? ∈ ????,P = softmax(S) ∈ ????,O = PV ∈ ????

      區別在于計算的調度方式。FlashAttention 不在 HBM 里存儲那個巨大的 N × N 注意力矩陣然后再讀回來算 softmax而是重新組織計算:分塊處理序列從全局內存流式讀取 K 和 V 塊,用在線 softmax 增量計算每個塊的部分結果,逐步構建輸出矩陣 O反向傳播時還可以選擇重算而非存儲。

      具體操作是這樣的:拿一塊查詢 Q_block,然后分塊迭代 K 和 V 序列,邊迭代邊做在線 softmax 同時追蹤必要的統計量,累積輸出塊并在片上歸一化,只把最終結果寫回 HBM。

      這樣注意力的內存復雜度就從 O(N2) 降到了 O(N)。

      最難的部分——Softmax

      分塊矩陣乘法不難,而分塊 softmax 才是麻煩事。注意力中 token i 對其他 token 的關注程度,是對該行所有注意力分數做 softmax 得到的:



      普通注意力里這很簡單,因為一個 token 的全部注意力分數已經物化在內存中,一步就能算完最大值、歸一化、softmax。

      而FlashAttention 里情況不一樣,鍵和值是分塊流式進來的內核迭代 K 和 V 時只能看到部分分數塊,永遠看不到完整的分數集,就沒法一步算完 softmax。

      解決方案是在線 softmax 公式。不一步算完,而是維護三個逐查詢的狀態:運行最大值 m?(保證數值穩定),運行歸一化項 l?,運行輸出累加器 O?。每來一個新的注意力分數塊,就更新這些值,最后恢復的結果和對整個序列做完整 softmax 一模一樣。



      完整代碼分解

      從高層看,實現結構如下:

      for each (batch, head):
      for each Q_block:
      initialize m_i, l_i, O_block
      for each K/V block:
      compute partial scores
      update online softmax state
      accumulate output
      write O_block to memory

      所有邏輯融合在內核里,中間狀態全部駐留在片上快速內存。下面逐步講解這個結構如何映射到 Triton 程序和 GPU 執行。

      Host 包裝器和內核啟動

      Python 包裝器負責準備輸入并啟動 Triton 內核,做三件事:驗證和提取輸入張量的形狀與步幅,構建內核執行網格,啟動前向注意力內核。包裝器本身不含注意力邏輯,只定義工作如何在 GPU 上調度。

      # Host wrapper that prepares our inputs and parameters and runs the triton kernel
      class TritonFlashAttention(torch.autograd.Function):
      @staticmethod
      def flash_attention(Q, K, V, causal):
      assert Q.is_cuda
      assert K.is_cuda
      assert V.is_cuda
      B, H, Lq, D = Q.shape
      B, H, Lk, D = K.shape
      B, H, Lk, D = V.shape
      # create the output buffer
      O = torch.empty_like(Q)
      # we set block_sizes manually for now. We will autotune this later
      #BLOCK_SIZE_Q = 128
      #BLOCK_SIZE_KV = 32
      stage = 3 if causal else 1
      grid = lambda x: (triton.cdiv(Lq, x["BLOCK_SIZE_Q"]),
      B * H, 1)
      M = torch.empty((B, H, Lq), device=Q.device, dtype=torch.float32)
      scaling_factor = 1 / math.sqrt(D)
      fwd_flash_attn_kernel[grid](Q, K, V, O, M, scaling_factor,
      Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
      K.stride(0), K.stride(1), K.stride(2), K.stride(3),
      V.stride(0), V.stride(1), V.stride(2), V.stride(3),
      O.stride(0), O.stride(1), O.stride(2), O.stride(3),
      B, NUM_HEADS=H, SEQ_LEN=Lq, HEAD_DIM=D, STAGE=stage,)
      #ctx.save_for_backward
      return O

      程序網格和并行化策略

      host 包裝器里定義了一個 2D 執行網格,決定 GPU 如何分配工作,也就是并行啟動多少個 Triton 程序實例。

      grid = lambda x: (triton.cdiv(Lq, x["BLOCK_SIZE_Q"]), B * H, 1)

      第一維 program_id(0) 標識程序實例處理的查詢序列塊,第二維 program_id(1) 標識對應的 (batch, head) 對。

      維度 0 把查詢序列分成 BLOCK_SIZE_Q 大小的塊,Lq 是查詢序列長度,每個程序實例負責計算輸出矩陣的一個水平"條帶"。維度 1 跨所有 batch 和 head 并行,每個程序實例對應一個 (batch, head) 對。給每個注意力頭分配獨立程序可以最大化占用率。內核內部用 tl.program_id 配合手動步幅算術(qb_stride、qh_stride)把每個 worker 指向它的內存切片。

      每個程序實例負責計算:

      Q[batch, head, q_block : q_block + BLOCK_SIZE_Q]

      這種網格設計提供了序列維度并行、batch 和 head 并行,而且程序間不需要同步。每個程序在緊湊獨立的工作集上運行,tl.program_id 結合顯式步幅算術把每個實例映射到對應內存切片。

      內核分解

      前向傳播分成兩個內核。fwd_flash_attn_kernel 協調執行,加載查詢塊、處理因果邏輯、寫輸出。_attn_fwd_inner 實現核心 FlashAttention-2 計算,流式處理 K/V 塊并執行在線 softmax 更新。每個 Triton 程序實例計算一個查詢塊 × 一個注意力頭 × 一個 batch 元素。

      這種分解把控制邏輯和流式計算分開內核更容易理解和優化。

      前向內核

      這個內核本身不直接實現注意力算法,負責的是把 GPU 程序實例映射到輸入張量的對應塊,協調流式注意力計算,處理因果邏輯,把最終輸出寫回內存。

      @triton.jit
      def fwd_flash_attn_kernel(q_ptr, k_ptr, v_ptr, o_ptr, m_ptr, scale,
      qb_stride, qh_stride, qn_stride, qd_stride,
      kb_stride, kh_stride, kn_stride, kd_stride,
      vb_stride, vh_stride, vn_stride, vd_stride,
      ob_stride, oh_stride, on_stride, od_stride,
      BATCH_SIZE, NUM_HEADS:tl.constexpr, SEQ_LEN:tl.constexpr, HEAD_DIM:tl.constexpr,
      BLOCK_SIZE_Q:tl.constexpr, BLOCK_SIZE_KV:tl.constexpr, STAGE:tl.constexpr):
      # get the id of this program instance
      block_index_q = tl.program_id(0) # Which chunk of sequence this program is responsible for
      index_batch_head = tl.program_id(1) # what batch-head to process. zooms out
      # get exact batch
      index_batch = index_batch_head // NUM_HEADS
      # get exact head
      index_head = index_batch_head % NUM_HEADS
      # create offsets to get the index of sequences we are going to process
      qkv_offset = index_batch * qb_stride + index_head * qh_stride # i.e move from the first to the correct batch then move to the correct head within that batch
      qkv_offset_K = index_batch * kb_stride + index_head * kh_stride
      qkv_offset_V = index_batch * vb_stride + index_head * vh_stride
      qkv_offset_O = index_batch * ob_stride + index_head * oh_stride
      off_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) # same as off_q (in this head what q block do we need to read )
      off_kv = tl.arange(0, BLOCK_SIZE_KV)
      off_head = tl.arange(0, HEAD_DIM)
      # create blocks of pointers to get the address of where the index lives
      Q_block_ptr = q_ptr + qkv_offset + off_q[:, None] * qn_stride + off_head[None, :] * qd_stride
      O_block_ptr = o_ptr + qkv_offset_O + off_q[:, None] * on_stride + off_head[None, :] * od_stride
      m_i = tl.zeros((BLOCK_SIZE_Q,), dtype= tl.float32) - float("inf")
      l_i = tl.zeros((BLOCK_SIZE_Q,), dtype=tl.float32) + 1.0
      O_block = tl.zeros((BLOCK_SIZE_Q, HEAD_DIM), dtype=tl.float32)
      Q_block = tl.load(Q_block_ptr) # add a mask
      # stage 1: Blocks before the diagonal
      # stage 2: diagonal block itself
      # stage 3: for non-causal no masking is needed. For causal mask all the blocks here.
      # runs if causal is True i.e we mask out the future tokens from contributing
      # this if statement executes for non-causal attention (no masking) or for the blocks to the left of the diagonal in the causal attention
      # Stage = 3 if causal else 1
      if STAGE == 1 or STAGE == 3:
      O_block, l_i, m_i = _attn_fwd_inner(
      O_block,
      l_i,
      m_i,
      Q_block,
      block_index_q,
      scale,
      BLOCK_SIZE_Q,
      BLOCK_SIZE_KV,
      4 - STAGE,
      off_kv,
      off_q,
      off_head,
      kn_stride,
      kd_stride,
      vd_stride,
      vn_stride,
      k_ptr,
      v_ptr,
      qkv_offset_K,
      qkv_offset_V,
      SEQ_LEN,
      HEAD_DIM
      )
      # this executes for blocks to the right of the diagonal in the causal attention
      if STAGE == 3:
      O_block, l_i, m_i = _attn_fwd_inner(
      O_block,
      l_i,
      m_i,
      Q_block,
      block_index_q,
      scale,
      BLOCK_SIZE_Q,
      BLOCK_SIZE_KV,
      2,
      off_kv,
      off_q,
      off_head,
      kn_stride,
      kd_stride,
      vd_stride,
      vn_stride,
      k_ptr,
      v_ptr,
      qkv_offset_K,
      qkv_offset_V,
      SEQ_LEN,
      HEAD_DIM
      )
      m_i += tl.math.log(l_i)
      O_block = O_block / l_i[:, None]
      m_ptrs = m_ptr + index_batch_head * SEQ_LEN + off_q
      tl.store(m_ptrs, m_i)
      tl.store(O_block_ptr, O_block.to(tl.float16))

      網格映射

      回顧 Python 包裝器里的網格:

      grid = (
      ceil_div(Lq, BLOCK_SIZE_Q),
      B * H
      )

      這個 2D 網格映射提供序列維度并行和 batch/head 并行。

      內核內部:

      block_index_q = tl.program_id(0)
      index_batch_head = tl.program_id(1)

      解碼第二維:

      index_batch = index_batch_head // NUM_HEADS
      index_head = index_batch_head % NUM_HEADS

      這幾個變量唯一標識當前程序實例負責哪個 batch 元素、哪個注意力頭、哪個查詢塊。

      指針算術和張量布局

      PyTorch 或 numpy 里用多維語法索引張量,比如 Q[batch, head, seq_pos, dim]。而Triton 內核里沒有多維張量,只有指向輸入第一個元素的裸指針 q_ptr必須用指針算術手動重構索引。

      查詢張量 Q 形狀是 [BATCH, HEADS, SEQ_LEN, HEAD_DIM],硬件層面是扁平一維數組存儲。沿每個維度移動用步幅:qb_stride 跳一個 batch,qh_stride 跳一個 head,qn_stride 跳一個 token,qd_stride 跳一個特征。

      選擇 batch 和 head

      每個程序實例先選定自己負責的 batch 和 head 切片:

      qkv_offset = index_batch * qb_stride + index_head * qh_stride

      這個偏移之后,指針指向 Q[batch, head, 0, :]。K、V、O 同理,用各自的步幅。然后構建當前塊的索引范圍:

      off_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
      off_head = tl.arange(0, HEAD_DIM)

      用這些偏移加廣播,構建指向查詢塊的指針:

      Q_block_ptr = q_ptr + qkv_offset \
      + off_q[:, None] * qn_stride \
      + off_head[None, :] * qd_stride

      輸出 O_block_ptr 也類似:

      O_block_ptr = o_ptr + qkv_offset_O \
      + off_q[:, None] * on_stride \
      + off_head[None, :] * od_stride

      完全用指針算術重現了 4D 索引 Q[batch, head, q_positions, head_dim]。

      這種顯式指針構建很關鍵,確保只加載每個程序實例需要的 Q 塊并送到 SRAM,避免碰不相關的內存,實現合并訪問,最大化緩存復用。

      初始化每塊狀態

      加載查詢塊后,內核初始化在線 softmax 所需的每塊狀態并分派流式計算。流式邏輯和因果階段的細節在 _attn_fwd_inner 里,后面分析。先理解這個每塊狀態為什么存在、代表什么。

      為了在迭代 K 和 V 塊時正確增量計算 softmax,需要追蹤三個量:運行最大值 m_i、運行 softmax 分母 l_i、未歸一化加權和 O_block。

      這三個變量構成在線 softmax 算法的狀態。FlashAttention 分塊處理鍵值,內核永遠無法一次訪問所有注意力分數。要得到和完整 softmax 一樣的結果,必須維護數值穩定用的運行最大值 m_i、運行歸一化因子 l_i、累積加權輸出 O_block。這些狀態共同作用,精確重建 softmax(QK?) @ V,不需要物化注意力矩陣。

      運行最大值 m_i 和運行歸一化器

      Softmax 涉及指數運算,FP16/BF16 下容易數值不穩定。為了把指數保持在合理范圍,每個查詢行追蹤一個運行最大值 m_i。處理新的 K 和 V 塊時,這個運行最大值可能增大。一旦增大,之前用舊最大值計算的累積貢獻就不在同一尺度上了。

      糾正辦法是用一個因子重新縮放累積的分母:



      the numerator



      the scaling factor



      the normalizing denominator

      這種重新縮放確保分母里所有項都相對同一個最大值。流式處理鍵值塊時反復應用這個更新就能恢復精確的 softmax 歸一化因子,不需要物化完整的注意力分數集。

      內核里是這樣寫:

      alpha = exp(m_old - m_new)
      l_i = l_i * alpha + l_ij

      累積輸出 O_block

      注意力輸出定義為:



      Final attention output

      標準實現里可以直接算,因為完整的 softmax 歸一化系數事先就知道。FlashAttention 里鍵值分塊流式進來,最終歸一化因子要等所有 K 和 V 塊處理完才能確定。

      所以只能累積一個未歸一化的加權和,最后再歸一化。

      每次迭代,計算相對于當前運行最大值的塊級 softmax 概率:



      維護一個未歸一化輸出累加器:



      unnormalized softmax output

      處理新 K/V 塊時運行最大值可能變,之前累積的輸出必須重新縮放以匹配新最大值。



      逐塊更新輸出累加器:

      O_block = O_block * alpha[:, None]
      O_block = P_block @ V_block + O_block

      所有 K/V 塊處理完后,把累積的未歸一化輸出除以累積的 softmax 分母 li 得到最終注意力輸出:



      final normalization

      結果和標準 softmax 注意力完全一樣,但永遠不會在內存里物化完整注意力矩陣或 softmax 概率。

      每個程序實例為每個查詢塊初始化這三個狀態一次:

      m_i = tl.zeros((BLOCK_SIZE_Q,), dtype=tl.float32) - inf
      l_i = tl.zeros((BLOCK_SIZE_Q,), dtype=tl.float32) + 1
      O_block =tl.zeros((BLOCK_SIZE_Q, HEAD_DIM), dtype=tl.float32)

      流式注意力內核 _attn_fwd_inner

      _attn_fwd_inner 實現 FlashAttention-2 算法核心,由 fwd_flash_attn_kernel 調用,一次處理一個查詢塊。

      @triton.jit
      def _attn_fwd_inner(O_block, l_i,m_i, Q_block, block_index_q,
      scale: tl.constexpr,
      BLOCK_SIZE_Q: tl.constexpr,
      BLOCK_SIZE_KV: tl.constexpr,
      STAGE: tl.constexpr,
      off_kv: tl.constexpr,
      off_q: tl.constexpr,
      off_head: tl.constexpr,
      kn_stride: tl.constexpr,
      kd_stride: tl.constexpr,
      vd_stride: tl.constexpr,
      vn_stride: tl.constexpr,
      k_ptr,
      v_ptr,
      qkv_offset_K: tl.constexpr,
      qkv_offset_V: tl.constexpr,
      SEQ_LEN:tl.constexpr,
      HEAD_DIM: tl.constexpr):

      其中 Q_block 形狀 [BLOCK_SIZE_Q, HEAD_DIM],O_block 是累積輸出,m_i 是每查詢行的運行最大值,l_i 是運行 softmax 歸一化。

      因果塊范圍選擇

      FA 內核支持因果(只看過去和當前 token)和非因果注意力(雙向,可以看未來)。用一個階段機制實現:

      if STAGE == 1:
      lo, hi = 0, block_index_q * BLOCK_SIZE_Q
      elif STAGE == 2:
      lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
      else:
      lo, hi = 0, SEQ_LEN

      這個邏輯決定當前內核處理哪些 K/V 塊。Stage 1 是對角線左側的塊,K 和 V 范圍僅限于此。Stage 2 是對角線塊本身。Stage 3 是非因果邏輯,K 和 V 關注所有 Q。這樣避免計算因果注意力中肯定會被 mask 掉的分數,減少不必要的 masking 工作。

      K 和 V 塊的流式循環

      查詢雖然分區到各程序實例,但每個查詢塊必須關注所有鍵值——這是全注意力的定義決定的。完整 K 和 V 矩陣從不一次性加載到 SRAM,而是以 BLOCK_SIZE_KV 大小的塊流式處理:

      for start_kv in range(lo, hi, BLOCK_SIZE_KV):

      加載 BLOCK_SIZE_KV 個鍵值,計算部分注意力分數,更新在線 softmax 狀態,丟棄該塊,處理下一個。內存復雜度維持 O(N)。

      每個程序實例只加載一個查詢塊,對應序列中一小部分 token。但這些 token 要正確計算注意力輸出,必須關注序列里所有鍵值。這是自注意力定義決定的:每個查詢都要和每個鍵比較。FlashAttention 沒改這個算法要求,只改計算調度方式。鍵值逐塊流式進來,累積到輸出,立刻丟棄,內存占用小,結果精確。一些新的注意力變體(局部注意力、稀疏注意力、滑動窗口注意力)不會關注所有 token。

      為 K 和 V 構建塊指針

      和 Q_block 一樣,計算當前塊的 token 索引:

      kv_positions = start_kv + off_kv

      然后構建指針:

      K_block_ptr = (
      k_ptr + qkv_offset_K
      + off_head[:, None] * kd_stride
      + kv_positions[None, :] * kn_stride
      )
      V_block_ptr = (
      v_ptr + qkv_offset_V
      + kv_positions[:, None] * vn_stride
      + off_head[None, :] * vd_stride
      )

      得到形狀 [HEAD_DIM, BLOCK_SIZE_KV] 的 K 和 V 指針。邊界 mask 邏輯防止最后一個塊越界訪問:

      mask_k = kv_positions[None, :] < SEQ_LEN
      mask_v = kv_positions[:, None] < SEQ_LEN

      從 HBM 加載 K 和 V 到片上 SRAM:

      K_block = tl.load(K_block_ptr, mask=mask_k, other=0.0)
      V_block = tl.load(V_block_ptr, mask=mask_v, other=0.0)

      部分分數計算和在線更新

      計算分塊點積:

      QK_block = tl.dot(Q_block, K_block)

      應用縮放和 mask(如果是因果的),更新運行最大值:

      mask = off_q[:, None] >= (start_kv + off_kv[None, :])
      QK_block = QK_block * scale + tl.where(mask, 0, -1e6)
      m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
      QK_block -= m_ij[:, None]
      m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * scale)
      QK_block = QK_block * scale - m_ij[:, None]

      更新在線 softmax 狀態:

      P_block = exp(QK_block)
      l_ij = sum(P_block, axis=1)
      alpha = exp(m_i - m_ij)
      l_i = l_i * alpha + l_ij

      更新輸出累加器:

      O_block = O_block * alpha[:, None]
      O_block = dot(P_block, V_block, O_block)

      用當前迭代找到的新最大值更新運行最大值:

      m_i = m_ij

      更新后的狀態返回給外層內核 fwd_flash_attn_kernel。

      最終歸一化和寫回

      所有 K/V 塊處理完后,前向內核完成輸出:

      O_block = O_block / l_i[:, None]

      用累積的分母因子歸一化注意力輸出。當前查詢塊的注意力輸出就算完了。

      性能和基準測試

      前向傳播實現完畢并驗證后,可以看看性能和標準注意力實現比較一下。



      FlashAttention vs. standard attention vs torch2.2 (spda flashattn) TFLOP/s benchmarks



      所有序列長度上標準注意力在 3-4 TFLOPs/sec 左右就到頂了。理論計算量雖然按 O(N2) 增長,但標準注意力被 HBM 流量主導。GPU 大部分時間在搬運 N × N 注意力矩陣,不是在做有用計算。序列變長并不能提高計算單元利用率,只是內存壓力變大。

      Triton FlashAttention 內核則隨序列長度增加激進擴展。512 token 時性能一般,超過 2K token 后吞吐量快速上升。16K token 時維持在約 190 TFLOPs/sec。這正是 FlashAttention 設計要達到的效果:阻止注意力矩陣物化,中間數據駐留 SRAM,內存加載得以攤銷。序列越長,內核越趨向計算受限,GPU 接近有效峰值吞吐量——和標準注意力恰好相反,標準注意力序列越長越內存受限。

      第二張圖在 Nvidia A100 上通過 sdpa API 比較了 Triton FlashAttention 和 PyTorch 官方 FlashAttention 實現。序列較短時 PyTorch 實現有競爭力,序列長度 ≥4k 后,自定義 Triton 內核追平并略微超過 PyTorch 性能。16k token 時,兩者都收斂到約 180-190 TFLOPs/sec。

      所有結果在同一 GPU(Nvidia A100 SXM)相同條件下獲得。吞吐量以 TFLOPs/sec 報告,由縮放點積注意力的理論 FLOP 數除以實測內核運行時間得出。序列長度變化,batch 大小、頭數、頭維度固定。

      這些基準驗證了三件事:標準注意力從根本上內存受限;FlashAttention 把瓶頸從內存轉到計算;Triton 提供了足夠的數據移動和 GPU 內存底層控制,能達到接近最優性能。

      關鍵是性能增益隨序列長度增長。這正是 FlashAttention 在實踐中最重要的地方。

      總結

      現代 GPU 上性能由內存行為主導,不是 FLOPs;內核融合和 SRAM 駐留比數學技巧更重要;在線 softmax 是 IO 感知注意力的關鍵;Triton 暴露了足夠的硬件細節來寫可讀又快的內核;仔細分塊加自動調優,自定義內核能和廠商實現打平。

      FlashAttention 不是因為改了算法才更快,是因為它尊重 GPU 實際的工作方式。

      本文只實現了前向傳播。擴展到完整的訓練級 FlashAttention(反向傳播、dropout、各種 mask 變體)留待后續工作。

      本文源代碼:

      https://avoid.overfit.cn/post/0ae6fbc34b7f4c1788f6399a7a1fc431

      by Katherine Oluwadarasimi Olowookere

      特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。

      Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

      相關推薦
      熱點推薦
      爭奪霍爾木茲海峽,都拼了,但……

      爭奪霍爾木茲海峽,都拼了,但……

      新民周刊
      2026-03-05 09:10:56
      “俄羅斯向伊朗分享美軍坐標,又有核武大國進場”

      “俄羅斯向伊朗分享美軍坐標,又有核武大國進場”

      觀察者網
      2026-03-07 08:36:07
      中國向全世界披露:美國4400顆衛星,包圍中國空間站,這是要做啥

      中國向全世界披露:美國4400顆衛星,包圍中國空間站,這是要做啥

      丁丁鯉史紀
      2026-03-06 17:20:34
      馬克龍就伊朗局勢表態:法國不會在中東“打仗”

      馬克龍就伊朗局勢表態:法國不會在中東“打仗”

      參考消息
      2026-03-06 12:58:11
      44+9+5,三節填滿數據欄,湖人這波太輕松了

      44+9+5,三節填滿數據欄,湖人這波太輕松了

      體育新角度
      2026-03-07 16:56:44
      中國駐法國使館發言人就中方對日本出口管制措施答記者問

      中國駐法國使館發言人就中方對日本出口管制措施答記者問

      環球網資訊
      2026-03-07 06:58:05
      伊朗的第一個盟友,下場了!

      伊朗的第一個盟友,下場了!

      深度知局
      2026-03-06 23:02:41
      中國古代歷史上“最牛”的地方割據勢力,傳承29世,割據724年!

      中國古代歷史上“最牛”的地方割據勢力,傳承29世,割據724年!

      小豫講故事
      2026-03-07 06:00:06
      美參院決議川普打伊不必再請示,川普稱古巴是下一個,果真如此?

      美參院決議川普打伊不必再請示,川普稱古巴是下一個,果真如此?

      邵旭峰域
      2026-03-06 16:32:04
      比亞迪再扔王炸,DM6.0橫空出世,燃油車這次真要涼了?

      比亞迪再扔王炸,DM6.0橫空出世,燃油車這次真要涼了?

      老特有話說
      2026-03-06 16:03:17
      7天之后,臺灣怎么辦?

      7天之后,臺灣怎么辦?

      人生就是要簡單
      2026-03-07 07:41:23
      斯普利特:拼盡全力沒能贏比賽有點失望,克林根今晚攻框很棒

      斯普利特:拼盡全力沒能贏比賽有點失望,克林根今晚攻框很棒

      懂球帝
      2026-03-07 12:57:45
      為什么中國不下場支援伊朗?背后有哪些原因

      為什么中國不下場支援伊朗?背后有哪些原因

      楓冷慕詩
      2026-03-06 15:10:27
      曾經走紅,如今卻“淪為笑柄”的4種數碼產品,還是別再買了

      曾經走紅,如今卻“淪為笑柄”的4種數碼產品,還是別再買了

      美家指南
      2026-03-06 10:31:36
      冷知識:真的不建議大家買超大藍莓

      冷知識:真的不建議大家買超大藍莓

      大象新聞
      2026-03-05 20:15:04
      澤連斯基:美國與俄羅斯竟給出相同的勸降——想停戰就放棄頓巴斯

      澤連斯基:美國與俄羅斯竟給出相同的勸降——想停戰就放棄頓巴斯

      老馬拉車莫少裝
      2026-02-22 12:25:15
      國家發改委主任:新建、改擴建1000所普通高中,增加學位200萬個以上,支持雙一流高校本科擴招10萬人以上

      國家發改委主任:新建、改擴建1000所普通高中,增加學位200萬個以上,支持雙一流高校本科擴招10萬人以上

      極目新聞
      2026-03-06 18:28:25
      全國政協委員楊建德 : 建議將春節連續9天假期固定下來,順應民生期盼、保障休假權益、激發內需活力、疏解春運壓力

      全國政協委員楊建德 : 建議將春節連續9天假期固定下來,順應民生期盼、保障休假權益、激發內需活力、疏解春運壓力

      每日經濟新聞
      2026-03-07 16:12:40
      文旅部部長:7名外國游客到上海旅游,買了40箱貨;“成為中國人”成了熱詞

      文旅部部長:7名外國游客到上海旅游,買了40箱貨;“成為中國人”成了熱詞

      上觀新聞
      2026-03-07 12:47:05
      爆笑女友經典糗事笑話,去年五一放假帶女友回家由于是第一次來我家,飯桌上她不好意思放開量地吃!

      爆笑女友經典糗事笑話,去年五一放假帶女友回家由于是第一次來我家,飯桌上她不好意思放開量地吃!

      天天明星
      2026-03-06 15:05:05
      2026-03-07 17:35:00
      deephub incentive-icons
      deephub
      CV NLP和數據挖掘知識
      1940文章數 1456關注度
      往期回顧 全部

      科技要聞

      OpenClaw爆火,六位"養蝦人"自述與AI共生

      頭條要聞

      伊朗總統:絕不可能無條件投降 向鄰國表示歉意

      頭條要聞

      伊朗總統:絕不可能無條件投降 向鄰國表示歉意

      體育要聞

      塔圖姆298天走完這段路 只用27分鐘征服這座城

      娛樂要聞

      周杰倫田馥甄的“JH戀” 被扒得底朝天

      財經要聞

      針對"不敢休、不讓休"怪圈 國家出手了

      汽車要聞

      逃離ICU,上汽通用“止血”企穩

      態度原創

      房產
      旅游
      數碼
      健康
      軍事航空

      房產要聞

      傳統學區房熄火?2月海口二手房爆火的板塊竟然是…

      旅游要聞

      警報聲中的歸途:一個義烏老板娘的中東“驚魂”之旅

      數碼要聞

      AI存儲需求進一步增長,三星NAND閃存被曝Q2將繼續漲價

      轉頭就暈的耳石癥,能開車上班嗎?

      軍事要聞

      美第三個航母打擊群據稱準備部署至中東

      無障礙瀏覽 進入關懷版