This article implements LLM-JEPA (Large Language Models Meet Joint Embedding Predictive Architectures) from scratch. To be clear, what follows is a concise, minimal training script whose goal is to capture the essence of JEPA: create two views of the same text, predict the embeddings of the masked spans, and train with a representation-alignment loss.
The goal of this post is to make you genuinely understand the method. The code is explained line by line, the purpose of every function is spelled out and tied back to the core intuition of the paper, and each block is described in enough detail that you can adapt it to your own experiments.
The Code
The entire LLM-JEPA training script lives in a single file, llm_jepa_train.py.
It takes raw text and creates two views: a context view in which some spans are replaced with [MASK], and a target view that keeps the original text but is supervised only at the masked positions. The context encoder is trainable and learns to predict the target encoder's representations at those positions. The target encoder is an EMA copy of the context encoder and receives no gradients. The loss is the cosine distance between the predicted embeddings and the target embeddings.
Example runs:
# Small smoke test (no downloads, random init)
python llm_jepa_train.py --smoke_test
# Train with an HF model backbone
python llm_jepa_train.py --model_name distilbert-base-uncased --steps 200 --batch_size 8
# Train on your own text file
python llm_jepa_train.py --model_name distilbert-base-uncased --text_file data.txt --steps 2000
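The only external dependencies are PyTorch and the Transformers library; if they are not installed yet, something along these lines should suffice (standard PyPI package names, pick versions to taste):
pip install torch transformers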
This is a concise reference implementation, not a full repository. The encoders come from the Transformers library.
import argparse
import math
import os
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
try:
from transformers import AutoTokenizer, AutoModel, AutoConfig
except Exception:
AutoTokenizer = None
AutoModel = None
AutoConfig = None
# -----------------------------
# Utilities
# -----------------------------
def set_seed(seed: int):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def pick_device(device_str: str) -> torch.device:
if device_str == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(device_str)
# -----------------------------
# Span masking (simple + effective)
# -----------------------------
def sample_span_mask(
seq_len: int,
mask_ratio: float,
mean_span_len: int,
special_positions: Optional[set] = None,
) -> torch.BoolTensor:
"""
Returns a boolean mask of length seq_len indicating which positions are masked.
We mask contiguous spans until we reach approximately mask_ratio of tokens.
"""
if special_positions is None:
special_positions = set()
mask = torch.zeros(seq_len, dtype=torch.bool)
if seq_len <= 0:
return mask
target_to_mask = max(1, int(round(seq_len * mask_ratio)))
masked = 0
attempts = 0
max_attempts = seq_len * 4
while masked < target_to_mask and attempts < max_attempts:
attempts += 1
span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))
span_len = min(span_len, seq_len)
start = random.randint(0, seq_len - 1)
end = min(seq_len, start + span_len)
span_positions = [i for i in range(start, end) if i not in special_positions]
if not span_positions:
continue
newly = 0
for i in span_positions:
if not mask[i]:
mask[i] = True
newly += 1
masked += newly
return mask
def apply_mask_to_input_ids(
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
tokenizer,
mask_ratio: float,
mean_span_len: int,
) -> Tuple[torch.LongTensor, torch.BoolTensor]:
"""
Masks spans inside non-special, non-padding tokens.
Returns:
masked_input_ids: input ids with masked tokens replaced by [MASK]
pred_mask: boolean mask over positions where we apply JEPA loss
"""
assert input_ids.dim() == 1
seq_len = int(attention_mask.sum().item())
# Identify special token positions (CLS, SEP, etc.) in the visible region
special_positions = set()
for i in range(seq_len):
tid = int(input_ids[i].item())
if tid in {
tokenizer.cls_token_id,
tokenizer.sep_token_id,
tokenizer.pad_token_id,
}:
special_positions.add(i)
pred_mask = sample_span_mask(
seq_len=seq_len,
mask_ratio=mask_ratio,
mean_span_len=mean_span_len,
special_positions=special_positions,
)
masked_input_ids = input_ids.clone()
mask_token_id = tokenizer.mask_token_id
if mask_token_id is None:
raise ValueError("Tokenizer has no mask_token_id. Use a model with [MASK].")
# Replace masked positions with [MASK]
masked_input_ids[:seq_len][pred_mask] = mask_token_id
# pred_mask should be full length (includes pads as False)
full_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
full_mask[:seq_len] = pred_mask
return masked_input_ids, full_mask
# -----------------------------
# Dataset
# -----------------------------
class TextLinesDataset(Dataset):
def __init__(self, texts: List[str]):
self.texts = [t.strip() for t in texts if t.strip()]
def __len__(self) -> int:
return len(self.texts)
def __getitem__(self, idx: int) -> str:
return self.texts[idx]
def load_texts_from_file(path: str, max_lines: Optional[int] = None) -> List[str]:
texts = []
with open(path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
if max_lines is not None and i >= max_lines:
break
texts.append(line.rstrip("\n"))
return texts
def default_tiny_corpus() -> List[str]:
return [
"The cat sat on the mat and looked at the window.",
"A quick brown fox jumps over the lazy dog.",
"Deep learning models can learn useful representations from raw data.",
"Rocket Learning builds AI tools for education in India.",
"Transformers use attention to mix information across tokens.",
"Self-supervised learning can reduce the need for labels.",
"JEPA trains models to predict embeddings, not tokens.",
"Bengaluru is a major tech hub in India.",
"A good system design balances simplicity and scalability.",
"Reading code carefully helps you understand how an idea is implemented.",
]
@dataclass
class Batch:
input_ids: torch.LongTensor # [B, L]
attention_mask: torch.LongTensor # [B, L]
masked_input_ids: torch.LongTensor # [B, L]
pred_mask: torch.BoolTensor # [B, L] positions to compute loss on
def collate_jepa(
batch_texts: List[str],
tokenizer,
max_length: int,
mask_ratio: float,
mean_span_len: int,
) -> Batch:
toks = tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
input_ids = toks["input_ids"] # [B, L]
attention_mask = toks["attention_mask"] # [B, L]
masked_input_ids_list = []
pred_mask_list = []
for b in range(input_ids.size(0)):
mi, pm = apply_mask_to_input_ids(
input_ids[b],
attention_mask[b],
tokenizer,
mask_ratio=mask_ratio,
mean_span_len=mean_span_len,
)
masked_input_ids_list.append(mi)
pred_mask_list.append(pm)
masked_input_ids = torch.stack(masked_input_ids_list, dim=0)
pred_mask = torch.stack(pred_mask_list, dim=0)
return Batch(
input_ids=input_ids,
attention_mask=attention_mask,
masked_input_ids=masked_input_ids,
pred_mask=pred_mask,
)
# -----------------------------
# Model: Encoder + Predictor + EMA target encoder
# -----------------------------
class PredictorMLP(nn.Module):
def __init__(self, dim: int, hidden_mult: int = 4, dropout: float = 0.0):
super().__init__()
hidden = dim * hidden_mult
self.net = nn.Sequential(
nn.Linear(dim, hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden, dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class LLMJEPA(nn.Module):
def __init__(self, encoder: nn.Module, dim: int, ema_m: float = 0.99, pred_hidden_mult: int = 4):
super().__init__()
self.context_encoder = encoder
self.target_encoder = self._copy_encoder(encoder)
self.predictor = PredictorMLP(dim=dim, hidden_mult=pred_hidden_mult, dropout=0.0)
self.ema_m = ema_m
for p in self.target_encoder.parameters():
p.requires_grad = False
@staticmethod
def _copy_encoder(enc: nn.Module) -> nn.Module:
import copy
return copy.deepcopy(enc)
@torch.no_grad()
def ema_update(self):
m = self.ema_m
for p_ctx, p_tgt in zip(self.context_encoder.parameters(), self.target_encoder.parameters()):
p_tgt.data.mul_(m).add_(p_ctx.data, alpha=(1.0 - m))
def forward(
self,
masked_input_ids: torch.LongTensor,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
pred_mask: torch.BoolTensor,
) -> torch.Tensor:
"""
Returns JEPA loss (scalar).
We compute:
z_ctx = context_encoder(masked_input)
z_tgt = target_encoder(full input)
pred = predictor(z_ctx)
loss over positions in pred_mask
"""
out_ctx = self.context_encoder(input_ids=masked_input_ids, attention_mask=attention_mask)
z_ctx = out_ctx.last_hidden_state # [B, L, D]
with torch.no_grad():
out_tgt = self.target_encoder(input_ids=input_ids, attention_mask=attention_mask)
z_tgt = out_tgt.last_hidden_state # [B, L, D]
pred = self.predictor(z_ctx) # [B, L, D]
# Select masked positions
# pred_mask: [B, L] bool
masked_pred = pred[pred_mask] # [N, D]
masked_tgt = z_tgt[pred_mask] # [N, D]
if masked_pred.numel() == 0:
# Safety: if a batch ends up with no masked tokens, return zero loss
return pred.sum() * 0.0
masked_pred = F.normalize(masked_pred, dim=-1)
masked_tgt = F.normalize(masked_tgt, dim=-1)
# Cosine distance
loss = 1.0 - (masked_pred * masked_tgt).sum(dim=-1)
return loss.mean()
# -----------------------------
# Training
# -----------------------------
def build_hf_encoder(model_name: str):
if AutoModel is None:
raise RuntimeError("transformers is not installed. pip install transformers")
config = AutoConfig.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name, config=config)
dim = int(config.hidden_size)
return encoder, dim
def build_random_encoder(vocab_size: int = 30522, dim: int = 256, layers: int = 4, heads: int = 4):
"""
For smoke tests only: small Transformer encoder (random init).
Requires a tokenizer with vocab mapping for ids.
"""
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=layers)
class TinyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.emb = nn.Embedding(vocab_size, dim)
self.pos = nn.Embedding(512, dim)
self.enc = transformer
def forward(self, input_ids, attention_mask):
B, L = input_ids.shape
pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, L)
x = self.emb(input_ids) + self.pos(pos_ids)
# attention_mask: 1 for keep, 0 for pad
# transformer expects src_key_padding_mask: True for pad
pad_mask = attention_mask == 0
h = self.enc(x, src_key_padding_mask=pad_mask)
return type("Out", (), {"last_hidden_state": h})
return TinyEncoder(), dim
def save_checkpoint(path: str, model: LLMJEPA, optimizer: torch.optim.Optimizer, step: int):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(
{
"step": step,
"context_encoder": model.context_encoder.state_dict(),
"target_encoder": model.target_encoder.state_dict(),
"predictor": model.predictor.state_dict(),
"optimizer": optimizer.state_dict(),
},
path,
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="distilbert-base-uncased", help="HF encoder backbone")
parser.add_argument("--text_file", type=str, default="", help="Path to a newline-separated text file")
parser.add_argument("--max_lines", type=int, default=50000)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--mask_ratio", type=float, default=0.3)
parser.add_argument("--mean_span_len", type=int, default=5)
parser.add_argument("--ema_m", type=float, default=0.99)
parser.add_argument("--pred_hidden_mult", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--warmup_steps", type=int, default=50)
parser.add_argument("--log_every", type=int, default=25)
parser.add_argument("--save_every", type=int, default=200)
parser.add_argument("--save_path", type=str, default="checkpoints/llm_jepa.pt")
parser.add_argument("--device", type=str, default="auto")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--smoke_test", action="store_true", help="No downloads, tiny random encoder, tiny corpus")
args = parser.parse_args()
set_seed(args.seed)
device = pick_device(args.device)
if args.smoke_test:
if AutoTokenizer is None:
raise RuntimeError("transformers is required even for smoke_test (for tokenizer).")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Ensure mask token exists
if tokenizer.mask_token_id is None:
raise ValueError("Tokenizer must support [MASK]. Use a masked LM tokenizer.")
texts = default_tiny_corpus()
ds = TextLinesDataset(texts)
encoder, dim = build_random_encoder(vocab_size=int(tokenizer.vocab_size), dim=256, layers=4, heads=4)
model = LLMJEPA(encoder=encoder, dim=dim, ema_m=0.95, pred_hidden_mult=2).to(device)
lr = 1e-4
else:
if AutoTokenizer is None:
raise RuntimeError("transformers is not installed. pip install transformers")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if tokenizer.mask_token_id is None:
raise ValueError(
"This tokenizer has no [MASK]. Pick a masked-encoder model (BERT/DeBERTa/DistilBERT)."
)
if args.text_file:
texts = load_texts_from_file(args.text_file, max_lines=args.max_lines)
else:
texts = default_tiny_corpus()
ds = TextLinesDataset(texts)
encoder, dim = build_hf_encoder(args.model_name)
model = LLMJEPA(encoder=encoder, dim=dim, ema_m=args.ema_m, pred_hidden_mult=args.pred_hidden_mult).to(device)
lr = args.lr
# DataLoader
def _collate(batch_texts):
return collate_jepa(
batch_texts=batch_texts,
tokenizer=tokenizer,
max_length=args.max_length,
mask_ratio=args.mask_ratio,
mean_span_len=args.mean_span_len,
)
dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=_collate)
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)
# Simple warmup + cosine schedule
def lr_at(step: int) -> float:
if step < args.warmup_steps:
return float(step + 1) / float(max(1, args.warmup_steps))
progress = (step - args.warmup_steps) / float(max(1, args.steps - args.warmup_steps))
progress = min(max(progress, 0.0), 1.0)
return 0.5 * (1.0 + math.cos(math.pi * progress))
model.train()
running = 0.0
step = 0
data_iter = iter(dl)
while step < args.steps:
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(dl)
batch = next(data_iter)
# Move to device
input_ids = batch.input_ids.to(device)
attention_mask = batch.attention_mask.to(device)
masked_input_ids = batch.masked_input_ids.to(device)
pred_mask = batch.pred_mask.to(device)
# LR schedule
scale = lr_at(step)
for pg in optimizer.param_groups:
pg["lr"] = lr * scale
loss = model(
masked_input_ids=masked_input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
pred_mask=pred_mask,
)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# EMA update after optimizer step
model.ema_update()
running += float(loss.item())
step += 1
if step % args.log_every == 0:
avg = running / float(args.log_every)
running = 0.0
print(f"step {step:6d} | loss {avg:.4f} | lr {optimizer.param_groups[0]['lr']:.6g}")
if step % args.save_every == 0:
save_checkpoint(args.save_path, model, optimizer, step)
print(f"saved checkpoint to {args.save_path} at step {step}")
save_checkpoint(args.save_path, model, optimizer, step)
print(f"training done. final checkpoint: {args.save_path}")
if __name__ == "__main__":
main()
What we have built is a JEPA-style representation predictor for text.
The input is plain lines of text, and two views are created for each example. The masked view (the context view) is the same sentence with some spans replaced by `[MASK]`; the original view (the target view) is left untouched.
The training flow: the masked view passes through a trainable context encoder, the original view passes through a non-trainable target encoder, and a predictor is trained so that the context encoder's representations can predict the target encoder's representations, with the loss computed only at the masked positions. The target encoder stays stable through EMA updates.
This design encourages the model to learn representations that "fill in meaning" rather than predict specific tokens.
The set_seed function
def set_seed(seed: int):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
This function makes runs reproducible. random.seed(seed) fixes Python's random operations (used by span masking), torch.manual_seed(seed) fixes PyTorch's CPU randomness, and torch.cuda.manual_seed_all(seed) fixes the randomness of CUDA kernels.
Both span masking and model initialization are random, so without a fixed seed every run would give different results.
The pick_device function
def pick_device(device_str: str) -> torch.device:
if device_str == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(device_str)
It returns a PyTorch device object. With --device auto, the script uses a GPU if one is available and otherwise falls back to the CPU. You can also pass --device cpu or --device cuda explicitly.
Tensors and the model must live on the same device; this is a basic requirement.
The sample_span_mask function
def sample_span_mask(seq_len, mask_ratio, mean_span_len, special_positions=None)
One of the most important functions in the whole script.
Its goal is to build a boolean mask marking which positions in the sequence should be masked. The arguments: seq_len is the number of real tokens (excluding padding), mask_ratio is the fraction of tokens to mask (e.g., 0.3), mean_span_len is the average length of a contiguous masked span, and special_positions are positions that must never be masked (CLS, SEP, PAD).
Internally, it first creates an all-False mask and then computes how many tokens need to be masked:
target_to_mask = max(1, int(round(seq_len * mask_ratio)))
Even a very short sequence gets at least one masked token.
It then samples spans in a loop until the target count is reached. Span lengths are drawn from an exponential distribution:
span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))
This yields many short spans and a few long ones, which matches natural text reasonably well. A random start position is chosen, special tokens are filtered out, and the remaining positions are marked True.
The masking strategy has a large effect on representation quality. Span masking forces the model to infer the missing semantics from the surrounding context.
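To get a feel for what it produces, here is a small usage sketch (it assumes the script above is saved as llm_jepa_train.py so the function can be imported):
import random
from llm_jepa_train import sample_span_mask

random.seed(0)
# Mask roughly 30% of a 20-token sequence with spans of mean length 3,
# never touching positions 0 and 19 (e.g. CLS and SEP).
mask = sample_span_mask(seq_len=20, mask_ratio=0.3, mean_span_len=3, special_positions={0, 19})
print(mask.int().tolist())  # e.g. [0, 0, 1, 1, 1, 0, ...] with contiguous runs of 1s
print(int(mask.sum()))      # roughly 6 (= 20 * 0.3)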
The apply_mask_to_input_ids function
def apply_mask_to_input_ids(input_ids, attention_mask, tokenizer, mask_ratio, mean_span_len)
Given the token ids of one example, it returns two things: masked_input_ids, the ids with masked positions replaced by [MASK], and pred_mask, a boolean mask marking the positions where the loss will be computed.
It first computes the visible sequence length: seq_len = int(attention_mask.sum().item()). In attention_mask, real tokens are 1 and padding is 0.
It then identifies special-token positions; CLS and SEP must not be masked, or the model easily runs into trouble. sample_span_mask is called to sample the masked positions, which are then replaced with mask_token_id:
masked_input_ids[:seq_len][pred_mask] = mask_token_id
The returned pred_mask covers the full sequence length, with padding positions set to False. The JEPA loss is computed only at the masked positions; everything else is ignored.
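Here is a quick way to look at the two views side by side (again assuming the script is saved as llm_jepa_train.py; this downloads the distilbert-base-uncased tokenizer):
from transformers import AutoTokenizer
from llm_jepa_train import apply_mask_to_input_ids

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
toks = tokenizer("The cat sat on the mat and looked at the window.", return_tensors="pt")
input_ids = toks["input_ids"][0]            # [L]
attention_mask = toks["attention_mask"][0]  # [L]
masked_ids, pred_mask = apply_mask_to_input_ids(
    input_ids, attention_mask, tokenizer, mask_ratio=0.3, mean_span_len=3
)
print(tokenizer.decode(input_ids))   # target view: the original sentence
print(tokenizer.decode(masked_ids))  # context view: some spans replaced by [MASK]
print(pred_mask.int().tolist())      # 1s mark positions where the JEPA loss applies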
The TextLinesDataset class
class TextLinesDataset(Dataset):
def __init__(self, texts):
self.texts = [t.strip() for t in texts if t.strip()]
A minimal Dataset implementation: it stores a list of text lines, dropping empty lines and surrounding whitespace. __len__ returns the number of lines and __getitem__ returns a single string.
load_texts_from_file reads a file line by line with an optional cap on the number of lines; it is used when you pass --text_file. default_tiny_corpus provides a built-in test corpus.
The Batch dataclass
@dataclass
class Batch:
input_ids
attention_mask
masked_input_ids
pred_mask
Using a dataclass is much clearer than returning a tuple and keeps the code readable.
The collate_jepa function
This is the function the DataLoader calls to build each batch. The input is a list of raw strings, which the tokenizer first tokenizes, pads, and truncates:
toks = tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
This produces input_ids and attention_mask. Each example is then run through apply_mask_to_input_ids to generate the masked version and its pred_mask, and everything is stacked into [B, L] tensors and returned as a Batch.
The DataLoader reads examples one at a time, but training needs batches; both batching and masking happen here.
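The main script wires this into the DataLoader through an inner _collate closure; an equivalent sketch using functools.partial (same file-name assumption as above) looks like this:
from functools import partial
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from llm_jepa_train import TextLinesDataset, collate_jepa, default_tiny_corpus

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
ds = TextLinesDataset(default_tiny_corpus())
collate_fn = partial(collate_jepa, tokenizer=tokenizer, max_length=128, mask_ratio=0.3, mean_span_len=5)
dl = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_fn)

batch = next(iter(dl))
print(batch.input_ids.shape)         # [4, L]
print(batch.masked_input_ids.shape)  # [4, L], masked spans replaced by [MASK]
print(int(batch.pred_mask.sum()))    # number of positions the loss will cover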
The PredictorMLP class
The predictor head has a simple structure:
nn.Linear(dim, hidden)
nn.GELU()
nn.Dropout()
nn.Linear(hidden, dim)
It maps context representations into the target representation space, acting as a learned adapter that helps align the two sets of embeddings.
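Both Linear layers act pointwise over the last dimension, so the predictor can be applied to the full [B, L, D] tensor in a single call; a tiny shape check (same import assumption as above):
import torch
from llm_jepa_train import PredictorMLP

pred = PredictorMLP(dim=256, hidden_mult=4)
x = torch.randn(8, 128, 256)  # [B, L, D] context representations
y = pred(x)
print(y.shape)                # torch.Size([8, 128, 256]): same shape, mapped into the target space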
The LLMJEPA model class
The main model wrapper, with four core pieces: context_encoder is the trainable Transformer encoder, target_encoder is a deep copy of it that is never trained, predictor is the MLP, and ema_m is the EMA momentum factor.
_copy_encoder uses copy.deepcopy so that the target and context encoders start from identical weights.
ema_update slowly updates the target encoder's weights:
p_tgt = m * p_tgt + (1 - m) * p_ctx
With m = 0.99 the target changes very slowly, which stabilizes training and lowers the risk of representation collapse.
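A standalone toy version of the same rule shows how slowly the target drifts toward the context weights at m = 0.99:
# Toy EMA on a single scalar weight: the context weight jumps to 1.0 and stays there.
m = 0.99
p_ctx, p_tgt = 1.0, 0.0
for step in range(1, 201):
    p_tgt = m * p_tgt + (1 - m) * p_ctx  # same rule as LLMJEPA.ema_update
    if step in (1, 50, 200):
        print(step, round(p_tgt, 3))     # ~0.01 after 1 step, ~0.395 after 50, ~0.866 after 200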
The forward pass: the masked view goes through the context encoder (trainable), the original view goes through the target encoder (no gradients), the predictor processes the context output, and then only the vectors at the masked positions are selected:
masked_pred = pred[pred_mask] # [N, D]
masked_tgt = z_tgt[pred_mask] # [N, D]
This goes from [B, L, D] to [N, D], where N is the total number of masked tokens. After normalization, the cosine distance is computed:
loss = 1 - (masked_pred * masked_tgt).sum(dim=-1)
return loss.mean()
Normalization matters because cosine similarity compares only the direction of the vectors, not their magnitude.
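Since both sides are L2-normalized, the dot product is exactly the cosine similarity, so the loss is equivalent to 1 - F.cosine_similarity on the raw vectors; a quick sanity check:
import torch
import torch.nn.functional as F

pred = torch.randn(5, 256)  # predicted embeddings at the masked positions [N, D]
tgt = torch.randn(5, 256)   # target-encoder embeddings at the same positions
a = F.normalize(pred, dim=-1)
b = F.normalize(tgt, dim=-1)
loss_manual = (1.0 - (a * b).sum(dim=-1)).mean()                    # as written in the script
loss_cosine = (1.0 - F.cosine_similarity(pred, tgt, dim=-1)).mean()
print(torch.allclose(loss_manual, loss_cosine, atol=1e-6))          # True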
The build_hf_encoder function
Loads a Hugging Face encoder and returns the model together with its hidden dimension (read from config.hidden_size).
The build_random_encoder function
Used only for smoke tests: it builds a small Transformer encoder from scratch, with a token embedding layer, positional embeddings, and an encoder stack. Note that this is not a masked language model, just an encoder architecture. The returned object exposes a .last_hidden_state attribute so that it matches the HF output format.
Summary
This implementation deliberately favors clarity over completeness, so there are no custom attention masks, multi-view datasets, or mixed objectives; as a reference implementation, however, it serves its purpose well. The original LLM-JEPA paper goes further: it combines JEPA with token prediction and exploits naturally paired views such as text and code. Those design choices matter for downstream performance, but they also add complexity that makes the core mechanism harder to see.
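For reference, combining the JEPA loss with a token-prediction term could look roughly like the sketch below. This is only an illustration, not the paper's exact recipe: lm_head and lambda_jepa are hypothetical additions (e.g. an nn.Linear(dim, vocab_size) trained alongside the model), and the masked view is re-encoded a second time purely for simplicity.
import torch.nn.functional as F

def combined_loss(model, lm_head, batch, lambda_jepa=1.0):
    # model is an LLMJEPA instance; lm_head is a hypothetical nn.Linear(dim, vocab_size).
    # 1) JEPA loss exactly as in LLMJEPA.forward: embedding alignment at masked spans.
    jepa = model(
        masked_input_ids=batch.masked_input_ids,
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        pred_mask=batch.pred_mask,
    )
    # 2) Token prediction: re-encode the masked view and predict the original ids
    #    at the masked positions with cross-entropy.
    h = model.context_encoder(
        input_ids=batch.masked_input_ids,
        attention_mask=batch.attention_mask,
    ).last_hidden_state                   # [B, L, D]
    logits = lm_head(h[batch.pred_mask])  # [N, vocab_size]
    ce = F.cross_entropy(logits, batch.input_ids[batch.pred_mask])
    return ce + lambda_jepa * jepa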
Paper:
https://avoid.overfit.cn/post/09eb991a93f64a83a376cdb52ac5c661
Author: azhar