<cite id="ffb66"></cite><cite id="ffb66"><track id="ffb66"></track></cite>
      <legend id="ffb66"><li id="ffb66"></li></legend>
      色婷婷久,激情色播,久久久无码专区,亚洲中文字幕av,国产成人A片,av无码免费,精品久久国产,99视频精品3
      網(wǎng)易首頁 > 網(wǎng)易號(hào) > 正文 申請(qǐng)入駐

      JAX 訓(xùn)練加速指南:8 個(gè)讓 TPU 滿跑的工程實(shí)戰(zhàn)習(xí)慣

      0
      分享至

      TPU 訓(xùn)練的真實(shí)效率往往取決于兩個(gè)核心要素:Shape 的穩(wěn)定性算子的融合度

      很多時(shí)候,JAX 任務(wù)之所以出現(xiàn)嚴(yán)重的性能瓶頸,并非算法本身設(shè)計(jì)有問題,而是忽視了 XLA 編譯器與底層硬件對(duì)“確定性”的極度偏好。基于大量實(shí)戰(zhàn)調(diào)優(yōu)經(jīng)驗(yàn),本文總結(jié)了八條能讓 JAX 訓(xùn)練任務(wù)從“甚至跑不通”蛻變?yōu)椤芭軡M TPU 算力”的工程經(jīng)驗(yàn)。



      1、盡早鎖定 Shape

      TPU 喜歡靜態(tài) Shape,JAX 也是,所以動(dòng)態(tài) Shape 是性能殺手,它會(huì)觸發(fā)重新編譯(Recompile)。一旦發(fā)生重編譯,Step time 和內(nèi)存占用都會(huì)直接炸裂。所以解決方法也很簡(jiǎn)單,選定幾個(gè)規(guī)范的尺寸,剩下的全填(Pad)滿。

      全局 Batch Size要能被 TPU 核心數(shù)整除,然后就是對(duì)于變長(zhǎng)序列,別指望它原本多長(zhǎng)就多長(zhǎng),把它 Pad 到幾個(gè)固定的“桶(Bucket)”里,比如 128、256 或 512,這步工作最好在輸入(Input Pipeline)里就做完。

      Python層面的條件判斷盡量別依賴 Shape,真要分支邏輯,就老老實(shí)實(shí)讓 lax.cond 或 lax.switch 來接管。

      # Example: bucketing & padding (conceptual)
      def pad_to_length(arr, L):
      pad = L - arr.shape[0]
      return jnp.pad(arr, ((0, pad), (0, 0)), mode='constant')
      bucket_sizes = [128, 256, 512]
      def bucket_len(n):
      return next(b for b in bucket_sizes if n <= b)
      def preprocess_batch(batch):
      L = bucket_len(batch["tokens"].shape[1])
      batch["tokens"] = pad_to_length(batch["tokens"], L)
      batch["mask"] = pad_to_length(batch["mask"], L)
      return batch

      每個(gè) Step 喂給 TPU 的 Shape 只要是固定的,XLA 編譯器就不會(huì)找麻煩。

      2、激活值默認(rèn)用 bfloat16,主權(quán)重要 FP32

      在 TPU 上bfloat16 (bf16) 是個(gè)好東西,兼顧了速度、內(nèi)存和數(shù)值穩(wěn)定性。

      工程上的常規(guī)操作是:激活(Activations)和梯度(Gradients)存成 bf16。但是,優(yōu)化器狀態(tài)里的權(quán)重必須保留一份FP32 的“主副本”,不然跑久了數(shù)值就會(huì)漂移。所欲需要在模型邊界做類型轉(zhuǎn)換(Cast)的時(shí)候小心點(diǎn)。

      class MLP(nn.Module):
      features: int
      @nn.compact
      def __call__(self, x):
      x = x.astype(jnp.bfloat16) # fast path on TPUs
      x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)
      x = nn.gelu(x)
      x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)
      return x
      # Optimizer state stays in FP32 (conceptual)
      params_fp32 = params.astype(jnp.float32)
      grads_bf16 = compute_grads_bf16(...)
      updates_fp32 = opt.update(grads_bf16.astype(jnp.float32), opt_state, params_fp32)

      3、pjit和命名網(wǎng)格:切分要明確,別靠猜

      JAX 在 TPU 上最強(qiáng)的一點(diǎn)就是通過 pjit 實(shí)現(xiàn)了GSPMD。你通過 PartitionSpecs 告訴它想要什么切分方式,XLA 負(fù)責(zé)搞定如何在設(shè)備間搬運(yùn)數(shù)據(jù)。

      在 TPU 核心上建個(gè)命名網(wǎng)格(Mesh)。做數(shù)據(jù)并行(Data Parallelism)時(shí),用 PartitionSpec('data', None) 這種模式。如果模型太大需要張量并行(Tensor Model Parallelism),就加個(gè) 'model' 軸。

      import numpy as np
      import jax
      import jax.numpy as jnp
      from jax.sharding import Mesh, PartitionSpec as P
      from jax.experimental import pjit
      devices = np.array(jax.devices()).reshape(1, -1) # 1 x N mesh
      mesh = Mesh(devices, ('data',))
      def loss_fn(params, batch):
      logits = model_apply(params, batch['x'])
      return cross_entropy(logits, batch['y'])
      @pjit.pjit(
      in_shardings=(P(None), P('data')), # params replicated, batch sharded on 'data'
      out_shardings=P(None), # scalar loss replicated
      )
      def step(params, batch):
      grads = jax.grad(loss_fn)(params, batch)
      # aggregate grads across cores
      grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name='data'), grads)
      return grads
      with mesh:
      grads = step(params, sharded_batch)

      切分(Sharding)這事必須顯式。如果偷懶依賴自動(dòng)推導(dǎo),等到后期 debug 那些悄無聲息的跨設(shè)備數(shù)據(jù)傳輸時(shí),絕對(duì)會(huì)很痛苦。

      4、jit, vmap, scan 三件套

      TPU 喜歡大塊頭的 Kernel,討厭成千上萬個(gè)細(xì)碎的小算子。訓(xùn)練 Step 和任何中大型計(jì)算邏輯,必須用 jit 包起來。遇到 Python 循環(huán),如果是時(shí)間步邏輯就換成 lax.scan,如果是批次并行就用 vmap。

      把 Loss 計(jì)算、梯度計(jì)算和參數(shù)更新塞進(jìn)同一個(gè) jitted 函數(shù)里,這樣編譯器才有機(jī)會(huì)把它們?nèi)诤铣梢粋€(gè)大算子。

      import optax
      import jax
      optimizer = optax.adamw(3e-4)
      def loss_and_grads(params, batch):
      def _loss(p):
      logits = model_apply(p, batch['x'])
      return cross_entropy(logits, batch['y'])
      loss, grads = jax.value_and_grad(_loss)(params)
      return loss, grads
      @jax.jit
      def train_step(state, batch):
      loss, grads = loss_and_grads(state.params, batch)
      grads = jax.lax.pmean(grads, axis_name='data')
      updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
      new_params = optax.apply_updates(state.params, updates)
      return state.replace(params=new_params, opt_state=new_opt_state), loss

      5、別讓輸入管道拖后腿

      Host 到 Device 的數(shù)據(jù)傳輸一旦停頓,吞吐量就掉下來了,所以永遠(yuǎn)別讓計(jì)算單元等數(shù)據(jù)。

      用 tf.data 或者高效的 NumPy loader 配合 prefetch。數(shù)據(jù)預(yù)取到設(shè)備(Stage to device) 最好做雙重緩沖。全局 Batch盡量大(當(dāng)然要能被核心數(shù)整除),數(shù)據(jù)增強(qiáng)這種臟活累活在 Host 上一次性做完。

      # tf.data pipeline (conceptual)
      ds = (tf.data.TFRecordDataset(files)
      .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(global_batch_size, drop_remainder=True)
      .prefetch(tf.data.AUTOTUNE))
      # Convert to NumPy and prefetch onto devices
      from flax.jax_utils import prefetch_to_device
      it = prefetch_to_device(map(npify, ds.as_numpy_iterator()), size=2)
      with mesh:
      for step_i in range(num_steps):
      batch = next(it) # already sharded/prefetched
      state, loss = train_step(state, batch)

      6、PRNG要Fold 進(jìn) Step 和 Device ID

      JAX 的 PRNG 是無狀態(tài)的,這意味如果不小心,很容易在不同 Step 或者不同設(shè)備上用了一樣的隨機(jī)數(shù) Key。

      每個(gè) Step 都要 Split 一次絕對(duì)別復(fù)用。所以說為了保證獨(dú)立性必須把Global StepDevice IndexFold進(jìn)去。數(shù)據(jù)增強(qiáng)/Dropout 的 Key 和參數(shù)初始化的 Key 得分開管理。

      def make_step_rng(rng, step):
      step_key = jax.random.fold_in(rng, step)
      dev_key = jax.random.fold_in(step_key, jax.lax.axis_index('data'))
      return jax.random.split(dev_key, 1)[0]
      @jax.jit
      def train_step(state, batch, base_rng):
      rng = make_step_rng(base_rng, state.step)
      logits = model_apply(state.params, batch['x'], rngs={'dropout': rng})
      ...

      7、Remat,智能 Checkpoint,梯度累積

      TPU 內(nèi)存看著大,模型一跑起來就不夠用。深層網(wǎng)絡(luò)可以直接用 Activation Checkpointing(jax.checkpoint 或 nn.remat),用計(jì)算換顯存。想跑大 Batch 但顯存不夠,就用梯度累積(Gradient Accumulation) 把它切成小的 micro-step。

      存盤的時(shí)候,推薦用 Orbax 做異步、分片(Sharded)的 Checkpoint,穩(wěn)。

      from flax import linen as nn
      class DeepBlock(nn.Module):
      @nn.compact
      def __call__(self, x):
      # recompute on backward to trim activation memory
      f = nn.remat(lambda y: nn.gelu(nn.Dense(x.shape[-1])(y)))
      return f(x)
      # Gradient accumulation (conceptual)
      @jax.jit
      def accum_step(state, batch_slices):
      def body(carry, micro):
      state, grad_sum = carry
      _, grads = loss_and_grads(state.params, micro)
      return (state, jax.tree_util.tree_map(jnp.add, grad_sum, grads)), None
      init_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
      (state, grad_sum), _ = jax.lax.scan(body, (state, init_grads), batch_slices)
      grads = jax.tree_map(lambda g: g / len(batch_slices), grad_sum)
      ...

      8、一定要跑 Profiler

      把關(guān)鍵代碼段用 Profiler Annotations 包起來,看 Step Timeline。重點(diǎn)找 Host Waits、Recompiles 和那些沒融合好的細(xì)碎算子(Small op soup)。

      穩(wěn)態(tài)運(yùn)行的時(shí)候,盯著 Tokens/sec 或者Images/sec,還有硬件利用率。

      from jax.experimental import host_callback as hcb
      from jax import profiler
      def tagged(name, fn, *a, **k):
      profiler.annotate_function(name=name)
      return fn(*a, **k)
      @jax.jit
      def train_step(state, batch):
      profiler.annotate_function(name="train_step")
      # do work...
      return state, loss

      一定要在鎖定 Shape 并且 JIT 完熱點(diǎn)路徑之后再做 Profile,不然全是噪音,根本看不到真正的瓶頸。

      極簡(jiǎn) TPU 訓(xùn)練示例

      這基本包含了上面所有的內(nèi)容

      # Pseudo-skeleton (Flax + JAX + TPU)
      mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data',))
      @pjit.pjit(in_shardings=(P(None), P('data'), P(None)), out_shardings=(P(None), P(None)))
      def train_step(state, batch, base_rng):
      rng = jax.random.fold_in(base_rng, state.step)
      rng = jax.random.fold_in(rng, jax.lax.axis_index('data'))
      def loss_fn(p):
      logits = model_apply(p, batch['x'].astype(jnp.bfloat16),
      rngs={'dropout': rng})
      return cross_entropy(logits, batch['y'])
      loss, grads = jax.value_and_grad(loss_fn)(state.params)
      grads = jax.tree_map(lambda g: jax.lax.pmean(g, 'data'), grads)
      updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
      params = optax.apply_updates(state.params, updates)
      return state.replace(params=params, opt_state=opt_state, step=state.step+1), loss
      with mesh:
      for step_i, batch in enumerate(prefetched_iterator):
      state, loss = train_step(state, batch, base_rng)
      if step_i % log_every == 0:
      # Pull back just tiny scalars; keep big tensors on device
      host_loss = jax.device_get(loss)
      print(f"[{step_i}] loss={host_loss:.4f}")

      總結(jié)

      TPU 需要的是 一致性:穩(wěn)定的 Shape,融合的 Kernel,目的明確的切分,不掉鏈子的數(shù)據(jù)管道,把上面的這八件事做好,寫 JAX 訓(xùn)練循環(huán)就非常順暢了。

      https://avoid.overfit.cn/post/16b582a493ba4eca8333314859665dd2

      作者:Modexa

      特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。

      Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

      相關(guān)推薦
      熱點(diǎn)推薦
      歐爾班宣布反制措施:在我們耗盡石油之前,烏克蘭人將先耗盡資金

      歐爾班宣布反制措施:在我們耗盡石油之前,烏克蘭人將先耗盡資金

      陳恧侃故事
      2026-03-07 11:31:12
      現(xiàn)場(chǎng)直擊:伊朗防空系統(tǒng)攔截美以目標(biāo)

      現(xiàn)場(chǎng)直擊:伊朗防空系統(tǒng)攔截美以目標(biāo)

      新華社
      2026-03-06 10:54:01
      鄭爽分享美國近況,穿搭不輸當(dāng)紅明星,自曝做醫(yī)美網(wǎng)友直呼認(rèn)不出

      鄭爽分享美國近況,穿搭不輸當(dāng)紅明星,自曝做醫(yī)美網(wǎng)友直呼認(rèn)不出

      萌神木木
      2026-03-06 17:33:33
      河南女孩六年前為救父親性命,稱誰給40萬就嫁給誰,如今過得如何

      河南女孩六年前為救父親性命,稱誰給40萬就嫁給誰,如今過得如何

      牛鍋巴小釩
      2026-03-06 19:45:04
      致敬西虹市首富?切爾西眾人開球前將球圍在中間,解說員啞然失笑

      致敬西虹市首富?切爾西眾人開球前將球圍在中間,解說員啞然失笑

      懂球帝
      2026-03-07 13:08:08
      廣西女子發(fā)現(xiàn)罕見青竹鯉,時(shí)不時(shí)側(cè)身蹭水底,網(wǎng)友:魚生天花板!

      廣西女子發(fā)現(xiàn)罕見青竹鯉,時(shí)不時(shí)側(cè)身蹭水底,網(wǎng)友:魚生天花板!

      貍貓之一的動(dòng)物圈
      2026-03-06 09:38:48
      小學(xué)生實(shí)名投訴極氪 建議取消讓她寫作業(yè)的小桌板!極氪回應(yīng)

      小學(xué)生實(shí)名投訴極氪 建議取消讓她寫作業(yè)的小桌板!極氪回應(yīng)

      快科技
      2026-03-06 23:12:33
      王震堅(jiān)決反對(duì)中顧委副主任排名,薄一波:我是常務(wù),就這么定了

      王震堅(jiān)決反對(duì)中顧委副主任排名,薄一波:我是常務(wù),就這么定了

      芊芊子吟
      2026-03-06 09:45:07
      莫雷加德全家抵達(dá)重慶:對(duì)這座城市印象深刻,會(huì)請(qǐng)樊振東推薦美食

      莫雷加德全家抵達(dá)重慶:對(duì)這座城市印象深刻,會(huì)請(qǐng)樊振東推薦美食

      乒談
      2026-03-07 00:19:01
      中國女籃72-66再勝巴西,不是張子宇王思雨,她17+7成新核

      中國女籃72-66再勝巴西,不是張子宇王思雨,她17+7成新核

      林子說事
      2026-03-07 08:15:14
      速度滑冰世錦賽:寧忠?guī)r收獲短距離全能、男子1000米兩項(xiàng)季軍

      速度滑冰世錦賽:寧忠?guī)r收獲短距離全能、男子1000米兩項(xiàng)季軍

      懂球帝
      2026-03-07 07:21:57
      高市早苗被逼到絕路:派也死,不派也死

      高市早苗被逼到絕路:派也死,不派也死

      鯨探所長(zhǎng)
      2026-03-07 12:02:36
      霍爾木茲海峽船只遭襲4死3重傷!兩萬海員被困,伊朗稱不會(huì)關(guān)閉海峽,但與以美有關(guān)船只不得通行;普京與伊總統(tǒng)通話:通過多種渠道保持聯(lián)系

      霍爾木茲海峽船只遭襲4死3重傷!兩萬海員被困,伊朗稱不會(huì)關(guān)閉海峽,但與以美有關(guān)船只不得通行;普京與伊總統(tǒng)通話:通過多種渠道保持聯(lián)系

      大風(fēng)新聞
      2026-03-07 10:05:06
      霍震霆也沒想到,46歲的霍啟剛,會(huì)在兩會(huì)上憑一個(gè)舉動(dòng)給霍家長(zhǎng)臉

      霍震霆也沒想到,46歲的霍啟剛,會(huì)在兩會(huì)上憑一個(gè)舉動(dòng)給霍家長(zhǎng)臉

      攬星河的筆記
      2026-03-06 23:55:22
      結(jié)束了!整整27年生涯!曝冠軍主帥最后一舞

      結(jié)束了!整整27年生涯!曝冠軍主帥最后一舞

      籃球?qū)崙?zhàn)寶典
      2026-03-06 18:57:43
      晴好周末,出游安排起來 | 天氣早知道

      晴好周末,出游安排起來 | 天氣早知道

      上觀新聞
      2026-03-07 11:57:06
      針對(duì)“不敢休、不讓休”怪圈,國家出手了!

      針對(duì)“不敢休、不讓休”怪圈,國家出手了!

      國是直通車
      2026-03-07 09:12:15
      村里紅白事從不回,男子母親離世,鄰居等著看笑話,結(jié)果長(zhǎng)了見識(shí)

      村里紅白事從不回,男子母親離世,鄰居等著看笑話,結(jié)果長(zhǎng)了見識(shí)

      子芫伴你成長(zhǎng)
      2026-02-23 12:21:40
      重回國乒?塵埃落定,劉國梁發(fā)聲,崗位曝光,布局國乒男隊(duì)發(fā)展

      重回國乒?塵埃落定,劉國梁發(fā)聲,崗位曝光,布局國乒男隊(duì)發(fā)展

      卿子書
      2026-03-06 09:25:27
      比賽還沒開打,上海申花先迎來兩個(gè)壞消息,新賽季斬獲開門紅懸了

      比賽還沒開打,上海申花先迎來兩個(gè)壞消息,新賽季斬獲開門紅懸了

      零度眼看球
      2026-03-07 08:58:12
      2026-03-07 13:40:49
      deephub incentive-icons
      deephub
      CV NLP和數(shù)據(jù)挖掘知識(shí)
      1940文章數(shù) 1456關(guān)注度
      往期回顧 全部

      科技要聞

      OpenClaw爆火,六位"養(yǎng)蝦人"自述與AI共生

      頭條要聞

      特朗普突然放話"先解決伊朗后解決古巴" 梅西聽懵了

      頭條要聞

      特朗普突然放話"先解決伊朗后解決古巴" 梅西聽懵了

      體育要聞

      塔圖姆歸來:凱爾特人的春之綠

      娛樂要聞

      周杰倫田馥甄的“JH戀” 被扒得底朝天

      財(cái)經(jīng)要聞

      針對(duì)"不敢休、不讓休"怪圈 國家出手了

      汽車要聞

      逃離ICU,上汽通用“止血”企穩(wěn)

      態(tài)度原創(chuàng)

      本地
      親子
      藝術(shù)
      公開課
      軍事航空

      本地新聞

      食味印象|一口入魂!康樂烤肉串起千年絲路香

      親子要聞

      六個(gè)月寶寶查出散光,原因竟是父母長(zhǎng)期身旁玩手機(jī),媽媽懵了:我一直以為他閉著眼就沒事

      藝術(shù)要聞

      Mark Grantham | 城市街景

      公開課

      李玫瑾:為什么性格比能力更重要?

      軍事要聞

      伊朗:使用無人機(jī)擊中美軍"林肯"號(hào)航母

      無障礙瀏覽 進(jìn)入關(guān)懷版