TPU 訓(xùn)練的真實(shí)效率往往取決于兩個(gè)核心要素:Shape 的穩(wěn)定性與算子的融合度。
很多時(shí)候,JAX 任務(wù)之所以出現(xiàn)嚴(yán)重的性能瓶頸,并非算法本身設(shè)計(jì)有問題,而是忽視了 XLA 編譯器與底層硬件對(duì)“確定性”的極度偏好。基于大量實(shí)戰(zhàn)調(diào)優(yōu)經(jīng)驗(yàn),本文總結(jié)了八條能讓 JAX 訓(xùn)練任務(wù)從“甚至跑不通”蛻變?yōu)椤芭軡M TPU 算力”的工程經(jīng)驗(yàn)。
![]()
1、盡早鎖定 Shape
TPU 喜歡靜態(tài) Shape,JAX 也是,所以動(dòng)態(tài) Shape 是性能殺手,它會(huì)觸發(fā)重新編譯(Recompile)。一旦發(fā)生重編譯,Step time 和內(nèi)存占用都會(huì)直接炸裂。所以解決方法也很簡(jiǎn)單,選定幾個(gè)規(guī)范的尺寸,剩下的全填(Pad)滿。
全局 Batch Size要能被 TPU 核心數(shù)整除,然后就是對(duì)于變長(zhǎng)序列,別指望它原本多長(zhǎng)就多長(zhǎng),把它 Pad 到幾個(gè)固定的“桶(Bucket)”里,比如 128、256 或 512,這步工作最好在輸入(Input Pipeline)里就做完。
Python層面的條件判斷盡量別依賴 Shape,真要分支邏輯,就老老實(shí)實(shí)讓 lax.cond 或 lax.switch 來接管。
# Example: bucketing & padding (conceptual)
def pad_to_length(arr, L):
pad = L - arr.shape[0]
return jnp.pad(arr, ((0, pad), (0, 0)), mode='constant')
bucket_sizes = [128, 256, 512]
def bucket_len(n):
return next(b for b in bucket_sizes if n <= b)
def preprocess_batch(batch):
L = bucket_len(batch["tokens"].shape[1])
batch["tokens"] = pad_to_length(batch["tokens"], L)
batch["mask"] = pad_to_length(batch["mask"], L)
return batch
每個(gè) Step 喂給 TPU 的 Shape 只要是固定的,XLA 編譯器就不會(huì)找麻煩。
2、激活值默認(rèn)用 bfloat16,主權(quán)重要 FP32
在 TPU 上bfloat16 (bf16) 是個(gè)好東西,兼顧了速度、內(nèi)存和數(shù)值穩(wěn)定性。
工程上的常規(guī)操作是:激活(Activations)和梯度(Gradients)存成 bf16。但是,優(yōu)化器狀態(tài)里的權(quán)重必須保留一份FP32 的“主副本”,不然跑久了數(shù)值就會(huì)漂移。所欲需要在模型邊界做類型轉(zhuǎn)換(Cast)的時(shí)候小心點(diǎn)。
class MLP(nn.Module):
features: int
@nn.compact
def __call__(self, x):
x = x.astype(jnp.bfloat16) # fast path on TPUs
x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)
x = nn.gelu(x)
x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)
return x
# Optimizer state stays in FP32 (conceptual)
params_fp32 = params.astype(jnp.float32)
grads_bf16 = compute_grads_bf16(...)
updates_fp32 = opt.update(grads_bf16.astype(jnp.float32), opt_state, params_fp32)
3、pjit和命名網(wǎng)格:切分要明確,別靠猜
JAX 在 TPU 上最強(qiáng)的一點(diǎn)就是通過 pjit 實(shí)現(xiàn)了GSPMD。你通過 PartitionSpecs 告訴它想要什么切分方式,XLA 負(fù)責(zé)搞定如何在設(shè)備間搬運(yùn)數(shù)據(jù)。
在 TPU 核心上建個(gè)命名網(wǎng)格(Mesh)。做數(shù)據(jù)并行(Data Parallelism)時(shí),用 PartitionSpec('data', None) 這種模式。如果模型太大需要張量并行(Tensor Model Parallelism),就加個(gè) 'model' 軸。
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import pjit
devices = np.array(jax.devices()).reshape(1, -1) # 1 x N mesh
mesh = Mesh(devices, ('data',))
def loss_fn(params, batch):
logits = model_apply(params, batch['x'])
return cross_entropy(logits, batch['y'])
@pjit.pjit(
in_shardings=(P(None), P('data')), # params replicated, batch sharded on 'data'
out_shardings=P(None), # scalar loss replicated
)
def step(params, batch):
grads = jax.grad(loss_fn)(params, batch)
# aggregate grads across cores
grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name='data'), grads)
return grads
with mesh:
grads = step(params, sharded_batch)
切分(Sharding)這事必須顯式。如果偷懶依賴自動(dòng)推導(dǎo),等到后期 debug 那些悄無聲息的跨設(shè)備數(shù)據(jù)傳輸時(shí),絕對(duì)會(huì)很痛苦。
4、jit, vmap, scan 三件套
TPU 喜歡大塊頭的 Kernel,討厭成千上萬個(gè)細(xì)碎的小算子。訓(xùn)練 Step 和任何中大型計(jì)算邏輯,必須用 jit 包起來。遇到 Python 循環(huán),如果是時(shí)間步邏輯就換成 lax.scan,如果是批次并行就用 vmap。
把 Loss 計(jì)算、梯度計(jì)算和參數(shù)更新塞進(jìn)同一個(gè) jitted 函數(shù)里,這樣編譯器才有機(jī)會(huì)把它們?nèi)诤铣梢粋€(gè)大算子。
import optax
import jax
optimizer = optax.adamw(3e-4)
def loss_and_grads(params, batch):
def _loss(p):
logits = model_apply(p, batch['x'])
return cross_entropy(logits, batch['y'])
loss, grads = jax.value_and_grad(_loss)(params)
return loss, grads
@jax.jit
def train_step(state, batch):
loss, grads = loss_and_grads(state.params, batch)
grads = jax.lax.pmean(grads, axis_name='data')
updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
new_params = optax.apply_updates(state.params, updates)
return state.replace(params=new_params, opt_state=new_opt_state), loss
5、別讓輸入管道拖后腿
Host 到 Device 的數(shù)據(jù)傳輸一旦停頓,吞吐量就掉下來了,所以永遠(yuǎn)別讓計(jì)算單元等數(shù)據(jù)。
用 tf.data 或者高效的 NumPy loader 配合 prefetch。數(shù)據(jù)預(yù)取到設(shè)備(Stage to device) 最好做雙重緩沖。全局 Batch盡量大(當(dāng)然要能被核心數(shù)整除),數(shù)據(jù)增強(qiáng)這種臟活累活在 Host 上一次性做完。
# tf.data pipeline (conceptual)
ds = (tf.data.TFRecordDataset(files)
.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
.batch(global_batch_size, drop_remainder=True)
.prefetch(tf.data.AUTOTUNE))
# Convert to NumPy and prefetch onto devices
from flax.jax_utils import prefetch_to_device
it = prefetch_to_device(map(npify, ds.as_numpy_iterator()), size=2)
with mesh:
for step_i in range(num_steps):
batch = next(it) # already sharded/prefetched
state, loss = train_step(state, batch)
6、PRNG要Fold 進(jìn) Step 和 Device ID
JAX 的 PRNG 是無狀態(tài)的,這意味如果不小心,很容易在不同 Step 或者不同設(shè)備上用了一樣的隨機(jī)數(shù) Key。
每個(gè) Step 都要 Split 一次絕對(duì)別復(fù)用。所以說為了保證獨(dú)立性必須把Global Step和Device Index都Fold進(jìn)去。數(shù)據(jù)增強(qiáng)/Dropout 的 Key 和參數(shù)初始化的 Key 得分開管理。
def make_step_rng(rng, step):
step_key = jax.random.fold_in(rng, step)
dev_key = jax.random.fold_in(step_key, jax.lax.axis_index('data'))
return jax.random.split(dev_key, 1)[0]
@jax.jit
def train_step(state, batch, base_rng):
rng = make_step_rng(base_rng, state.step)
logits = model_apply(state.params, batch['x'], rngs={'dropout': rng})
...
7、Remat,智能 Checkpoint,梯度累積
TPU 內(nèi)存看著大,模型一跑起來就不夠用。深層網(wǎng)絡(luò)可以直接用 Activation Checkpointing(jax.checkpoint 或 nn.remat),用計(jì)算換顯存。想跑大 Batch 但顯存不夠,就用梯度累積(Gradient Accumulation) 把它切成小的 micro-step。
存盤的時(shí)候,推薦用 Orbax 做異步、分片(Sharded)的 Checkpoint,穩(wěn)。
from flax import linen as nn
class DeepBlock(nn.Module):
@nn.compact
def __call__(self, x):
# recompute on backward to trim activation memory
f = nn.remat(lambda y: nn.gelu(nn.Dense(x.shape[-1])(y)))
return f(x)
# Gradient accumulation (conceptual)
@jax.jit
def accum_step(state, batch_slices):
def body(carry, micro):
state, grad_sum = carry
_, grads = loss_and_grads(state.params, micro)
return (state, jax.tree_util.tree_map(jnp.add, grad_sum, grads)), None
init_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
(state, grad_sum), _ = jax.lax.scan(body, (state, init_grads), batch_slices)
grads = jax.tree_map(lambda g: g / len(batch_slices), grad_sum)
...
8、一定要跑 Profiler
把關(guān)鍵代碼段用 Profiler Annotations 包起來,看 Step Timeline。重點(diǎn)找 Host Waits、Recompiles 和那些沒融合好的細(xì)碎算子(Small op soup)。
穩(wěn)態(tài)運(yùn)行的時(shí)候,盯著 Tokens/sec 或者Images/sec,還有硬件利用率。
from jax.experimental import host_callback as hcb
from jax import profiler
def tagged(name, fn, *a, **k):
profiler.annotate_function(name=name)
return fn(*a, **k)
@jax.jit
def train_step(state, batch):
profiler.annotate_function(name="train_step")
# do work...
return state, loss
一定要在鎖定 Shape 并且 JIT 完熱點(diǎn)路徑之后再做 Profile,不然全是噪音,根本看不到真正的瓶頸。
極簡(jiǎn) TPU 訓(xùn)練示例
這基本包含了上面所有的內(nèi)容
# Pseudo-skeleton (Flax + JAX + TPU)
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data',))
@pjit.pjit(in_shardings=(P(None), P('data'), P(None)), out_shardings=(P(None), P(None)))
def train_step(state, batch, base_rng):
rng = jax.random.fold_in(base_rng, state.step)
rng = jax.random.fold_in(rng, jax.lax.axis_index('data'))
def loss_fn(p):
logits = model_apply(p, batch['x'].astype(jnp.bfloat16),
rngs={'dropout': rng})
return cross_entropy(logits, batch['y'])
loss, grads = jax.value_and_grad(loss_fn)(state.params)
grads = jax.tree_map(lambda g: jax.lax.pmean(g, 'data'), grads)
updates, opt_state = optimizer.update(grads, state.opt_state, state.params)
params = optax.apply_updates(state.params, updates)
return state.replace(params=params, opt_state=opt_state, step=state.step+1), loss
with mesh:
for step_i, batch in enumerate(prefetched_iterator):
state, loss = train_step(state, batch, base_rng)
if step_i % log_every == 0:
# Pull back just tiny scalars; keep big tensors on device
host_loss = jax.device_get(loss)
print(f"[{step_i}] loss={host_loss:.4f}")
總結(jié)
TPU 需要的是 一致性:穩(wěn)定的 Shape,融合的 Kernel,目的明確的切分,不掉鏈子的數(shù)據(jù)管道,把上面的這八件事做好,寫 JAX 訓(xùn)練循環(huán)就非常順暢了。
https://avoid.overfit.cn/post/16b582a493ba4eca8333314859665dd2
作者:Modexa
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.