<cite id="ffb66"></cite><cite id="ffb66"><track id="ffb66"></track></cite>
      <legend id="ffb66"><li id="ffb66"></li></legend>
      色婷婷久,激情色播,久久久无码专区,亚洲中文字幕av,国产成人A片,av无码免费,精品久久国产,99视频精品3
      網(wǎng)易首頁 > 網(wǎng)易號 > 正文 申請入駐

      JAX性能優(yōu)化實戰(zhàn):7個變換讓TPU/GPU吃滿算力

      0
      分享至

      JAX跑得快的技巧其實很簡單:通過組合變換讓XLA能看到大塊連續(xù)的計算,比如說批處理、融合、分片,讓每一步在單設(shè)備或多設(shè)備同步時都像一個干凈的kernel。

      我們今天就來總結(jié)7個能夠提高運行速度的JAX變換組合



      1、 jit 優(yōu)先,形狀穩(wěn)定

      jit對函數(shù)做一次追蹤后XLA負責融合算子,形狀穩(wěn)定、無副作用時,Python處理的開銷就被分攤掉,可以提高運行速度。

      形狀創(chuàng)建和靜態(tài)參數(shù)要么挪到step外部,要么顯式標記為static。donate_argnums能讓JAX復(fù)用緩沖區(qū),省掉不必要的內(nèi)存拷貝。step之間保持dtype和shape一致,trace結(jié)果才能被緩存下來。

      import jax, jax.numpy as jnp
      @jax.jit(donate_argnums=(0,))
      def sgd_step(params, batch, lr):
      x, y = batch
      def loss_fn(p):
      preds = model_apply(p, x) # pure function
      return jnp.mean((preds - y) ** 2)
      grads = jax.grad(loss_fn)(params)
      return jax.tree_map(lambda p, g: p - lr * g, params, grads)

      每個(shape, dtype, static-arg)組合只追蹤一次。頻繁retrace多半是輸入shape在變,或者Python邏輯泄漏進了計算圖。

      2、vmap替換Python循環(huán)

      vmap在leading axis上做向量化,XLA直接把batch融進kernel。for循環(huán)沒了設(shè)備launch就少了,內(nèi)存訪問也更連續(xù)。

      # per-example loss
      def example_loss(params, x, y):
      pred = model_apply(params, x)
      return jnp.mean((pred - y) ** 2)
      # batch it without writing loops
      batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0)) # params broadcasted

      嵌套vmap可以搞2D batch,比如time × batch,只要別超HBM容量。vmap適合做內(nèi)層微批處理,比如ensemble或MC sampling這類場景,外層維度留給分片。

      3、長循環(huán)的融合利器Scan

      RNN、展開解碼、迭代求解器,這些場景用scan比Python循環(huán)快。scan只編譯一次循環(huán)體跑在XLA的while-loop里,Python開銷基本為0,融合和內(nèi)存復(fù)用也更激進。

      from jax import lax
      def rnn_cell(carry, x):
      h = carry
      h = jnp.tanh(W_hh @ h + W_xh @ x + b)
      y = W_hy @ h
      return h, y # (carry, output)
      def rnn_forward(h0, xs):
      hT, ys = lax.scan(rnn_cell, h0, xs) # xs: [T, B, D]
      return hT, ys

      循環(huán)狀態(tài)用carry傳遞,body保持小而純凈,要注意保持形狀不要變,比如:序列模型、diffusion step循環(huán)、定點迭代、beam解碼(形狀穩(wěn)定時)都適用。

      4、remat可以用計算換內(nèi)存

      批次大了TPU/GPU的FLOP利用率往往更高。remat(也叫checkpoint)會丟掉部分中間激活,反向時重算這樣峰值顯存下來batch就能開的更大。

      from jax import remat
      def block(params, x):
      x = jax.nn.gelu(x @ params['w1'])
      x = x @ params['w2']
      return x
      fast_block = remat(block) # checkpointed
      @jax.jit
      def forward(params, x):
      for _ in range(6):
      x = x + fast_block(params, x)
      return x

      只包最重的子塊就行,比如attention加MLP那幾層。同時配合vmap或分片,全局batch能再往上拉。不過需要一些額外FLOPs,但如果換來1.3到2倍的batch increase,wall-clock往往更短。

      5、pmap單機多卡數(shù)據(jù)并行

      pmap把函數(shù)復(fù)制到單主機的多個設(shè)備上(8卡工作站、單節(jié)點8核TPU),梯度可以自動all-reduce,并且每設(shè)備只編譯一次。

      from jax import pmap, lax
      @pmap(axis_name='d')
      def train_step(params, batch, lr):
      x, y = batch # each device sees [local_B, ...]
      def loss_fn(p):
      pred = model_apply(p, x)
      loss = jnp.mean((pred - y) ** 2)
      return loss
      loss, grads = jax.value_and_grad(loss_fn)(params)
      loss = lax.pmean(loss, axis_name='d')
      grads = lax.pmean(grads, axis_name='d')
      params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
      return params, loss

      batch在leading axis分片,lax.pmean聚合loss和grads。單機場景下pmap簡單可靠。跨主機擴展或者想做張量級細粒度分片可以成換pjit。

      6、pjit+ 命名分片:SPMD并行

      pjit編譯出單一SPMD程序可以跨設(shè)備跨主機運行。用mesh和PartitionSpec描述數(shù)組怎么切,JAX處理collective通信,這樣數(shù)據(jù)并行、張量并行、混合并行都能做。

      import jax
      from jax.sharding import Mesh, PartitionSpec as P
      import numpy as np
      devices = np.array(jax.devices()).reshape(2, 4) # 2 × 4 mesh (dp × mp)
      mesh = Mesh(devices, ('dp', 'mp'))
      @jax.jit # jit is optional when using pjit; shown when composing
      def model_apply_sharded(params, x):
      return model_apply(params, x)
      from jax.experimental.pjit import pjit
      with mesh:
      in_shard = (P('mp',), P('dp',)) # example; tailor to your shapes
      out_shard = P('dp',) # e.g., shard batch across dp
      step = pjit(model_apply_sharded,
      in_shardings=(P('mp',), P('dp',)),
      out_shardings=out_shard)
      y = step(params_sharded, x_sharded)

      一般都是batch軸走dp,大矩陣維度(hidden size、heads)走mp。分片數(shù)需要跟設(shè)備拓撲對齊,跨主機流量才少。

      7、value_and_grad的正確堆疊方式

      規(guī)范寫法是jit(value_and_grad(loss, has_aux=True)),外面可以再套一層pmap或pjit。這樣forward只跑一遍metrics留在aux里帶出來。

      def loss_with_aux(params, batch):
      x, y = batch
      pred = model_apply(params, x)
      loss = jnp.mean((pred - y) ** 2)
      aux = {'mse': loss, 'mean_pred': jnp.mean(pred)}
      return loss, aux
      @jax.jit
      def train_step(params, opt_state, batch, lr):
      (loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, batch)
      updates, opt_state = optimizer_update(grads, opt_state, params, lr)
      params = optax_apply(updates, params)
      return params, opt_state, loss, aux

      value_and_grad放jit里面,JAX會把forward和backward一起stage。返回(loss, aux)日志指標不用再跑一遍forward。

      這套組合很靈活:vmap做微批次,scan跑時序循環(huán),外面套pmap或pjit,donate_argnums標上buffer。

      總結(jié)

      變長序列pad加mask,shape穩(wěn)定是前提條件。traced代碼里不要添加Python隨機性,比如PRNG key要在外面split好。矩陣乘用bfloat16,這樣數(shù)值穩(wěn)定性也夠用,吞吐量在TPU/GPU上表現(xiàn)的也很好。性能profile要重點看warm-up之后的tokens/sec或samples/sec。日志只看標量aux metrics就行,每step把大數(shù)組傳回host是性能殺手。

      JAX的性能不是黑盒:jit + shape可以穩(wěn)定打底,vmap做batch,scan融合循環(huán),remat回收顯存,pmap或pjit做擴展,value_and_grad(..., has_aux=True)讓每一步只跑一次forward一次backward。

      https://avoid.overfit.cn/post/84e4e28e3ca8473488a0e9248d1ec51b

      作者:Nexumo

      特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務(wù)。

      Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

      相關(guān)推薦
      熱點推薦
      伊朗最大“內(nèi)鬼”被抓?革命衛(wèi)隊:勾結(jié)以色列,指揮官卡尼被拘!

      伊朗最大“內(nèi)鬼”被抓?革命衛(wèi)隊:勾結(jié)以色列,指揮官卡尼被拘!

      青青子衿
      2026-03-05 11:57:03
      打瘋了!東契奇首節(jié)狂轟22+5三分 生涯30次單節(jié)20+升歷史第四

      打瘋了!東契奇首節(jié)狂轟22+5三分 生涯30次單節(jié)20+升歷史第四

      醉臥浮生
      2026-03-07 12:13:33
      伊拉克庫爾德第一夫人宣言:我們不是任人驅(qū)使的炮灰!

      伊拉克庫爾德第一夫人宣言:我們不是任人驅(qū)使的炮灰!

      勝研集
      2026-03-06 13:44:23
      廣東一女子不愿上班常年坐街邊,因長得好看被路人投喂:又懶又饞

      廣東一女子不愿上班常年坐街邊,因長得好看被路人投喂:又懶又饞

      明智家庭教育
      2026-03-06 17:19:16
      美以伊軍事沖突最大副作用,是斬斷了俄羅斯的“救命稻草”

      美以伊軍事沖突最大副作用,是斬斷了俄羅斯的“救命稻草”

      廖保平
      2026-03-05 12:08:52
      “不想為以色列賣命”:帝國最后的遮羞布,美式民主終成笑話

      “不想為以色列賣命”:帝國最后的遮羞布,美式民主終成笑話

      怪口歷史的K先生
      2026-03-06 15:22:51
      為何關(guān)閉霍爾木茲海峽就能掐全球脖子?因為伊朗原油是全世界最好的

      為何關(guān)閉霍爾木茲海峽就能掐全球脖子?因為伊朗原油是全世界最好的

      風向觀察
      2026-03-06 21:31:15
      兩會不到3天,5大好消息傳來!老百姓暗暗叫好:希望國家盡快落實

      兩會不到3天,5大好消息傳來!老百姓暗暗叫好:希望國家盡快落實

      談史論天地
      2026-03-07 06:54:29
      1979年,張國燾凍死在養(yǎng)老院,許世友:除了主席,沒人是他的對手

      1979年,張國燾凍死在養(yǎng)老院,許世友:除了主席,沒人是他的對手

      文史季季紅
      2026-03-05 13:35:03
      寫入教科書的一天:F-35在德黑蘭完成全球首次實戰(zhàn)空對空擊殺

      寫入教科書的一天:F-35在德黑蘭完成全球首次實戰(zhàn)空對空擊殺

      斌聞天下
      2026-03-06 07:30:03
      伊方:因美以襲擊喪生的伊朗人三成為青少年

      伊方:因美以襲擊喪生的伊朗人三成為青少年

      環(huán)球網(wǎng)資訊
      2026-03-07 06:39:29
      為什么美國的華人華裔地位那么低 網(wǎng)友從各方面分析 真就那樣

      為什么美國的華人華裔地位那么低 網(wǎng)友從各方面分析 真就那樣

      侃神評故事
      2026-03-06 07:10:03
      我包養(yǎng)過一個女大學生,七年花了一千多萬

      我包養(yǎng)過一個女大學生,七年花了一千多萬

      煙火人間故事匯
      2026-03-06 23:05:03
      性壓抑已經(jīng)變態(tài)至此了?

      性壓抑已經(jīng)變態(tài)至此了?

      黯泉
      2026-03-07 11:28:43
      蘿莉島,是進入核心圈層的投名狀,你猜他們?yōu)槭裁炊即┘t皮鞋

      蘿莉島,是進入核心圈層的投名狀,你猜他們?yōu)槭裁炊即┘t皮鞋

      百曉生談歷史
      2026-03-05 22:00:08
      一份“煮熟的三文魚”火了,原來低認知的家長,真能搞出人命!

      一份“煮熟的三文魚”火了,原來低認知的家長,真能搞出人命!

      妍妍教育日記
      2026-03-07 08:45:06
      伊朗萬萬沒想到,自家王牌武器遭到破解,美軍多了一張底牌

      伊朗萬萬沒想到,自家王牌武器遭到破解,美軍多了一張底牌

      空天力量
      2026-03-06 13:09:18
      上次被發(fā)現(xiàn)還是1911年!上海寶山驚現(xiàn)1只,專家:可能是坐船來的

      上次被發(fā)現(xiàn)還是1911年!上海寶山驚現(xiàn)1只,專家:可能是坐船來的

      萬象硬核本尊
      2026-03-06 23:54:22
      女子實名舉報某團外賣:不上大額券就讓我變成“凌晨營業(yè)”,你們真黑!

      女子實名舉報某團外賣:不上大額券就讓我變成“凌晨營業(yè)”,你們真黑!

      回旋鏢
      2026-03-06 21:13:59
      塔圖姆復(fù)出15分12板7助攻凱爾特人大勝獨行俠,布朗24分7板7助

      塔圖姆復(fù)出15分12板7助攻凱爾特人大勝獨行俠,布朗24分7板7助

      湖人崛起
      2026-03-07 10:25:09
      2026-03-07 13:43:00
      deephub incentive-icons
      deephub
      CV NLP和數(shù)據(jù)挖掘知識
      1940文章數(shù) 1456關(guān)注度
      往期回顧 全部

      科技要聞

      OpenClaw爆火,六位"養(yǎng)蝦人"自述與AI共生

      頭條要聞

      特朗普突然放話"先解決伊朗后解決古巴" 梅西聽懵了

      頭條要聞

      特朗普突然放話"先解決伊朗后解決古巴" 梅西聽懵了

      體育要聞

      塔圖姆歸來:凱爾特人的春之綠

      娛樂要聞

      周杰倫田馥甄的“JH戀” 被扒得底朝天

      財經(jīng)要聞

      針對"不敢休、不讓休"怪圈 國家出手了

      汽車要聞

      逃離ICU,上汽通用“止血”企穩(wěn)

      態(tài)度原創(chuàng)

      時尚
      教育
      本地
      健康
      公開課

      這些才是適合普通人的穿搭!搭配腰帶、多穿牛仔褲,簡單舒適

      教育要聞

      兩會速遞|教育部部長:將實施新一輪學生心理健康促進行動

      本地新聞

      食味印象|一口入魂!康樂烤肉串起千年絲路香

      轉(zhuǎn)頭就暈的耳石癥,能開車上班嗎?

      公開課

      李玫瑾:為什么性格比能力更重要?

      無障礙瀏覽 進入關(guān)懷版