<cite id="ffb66"></cite><cite id="ffb66"><track id="ffb66"></track></cite>
      <legend id="ffb66"><li id="ffb66"></li></legend>
      色婷婷久,激情色播,久久久无码专区,亚洲中文字幕av,国产成人A片,av无码免费,精品久久国产,99视频精品3
      網易首頁 > 網易號 > 正文 申請入駐

      用 PyTorch 實現 LLM-JEPA:不預測 token,預測嵌入

      0
      分享至

      這篇文章從頭實現 LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures。需要說明的是,這里寫的是一個簡潔的最小化訓練腳本,目標是了解 JEPA 的本質:對同一文本創建兩個視圖,預測被遮蔽片段的嵌入,用表示對齊損失來訓練。

      本文的目標是讓你真正理解這套方法。代碼會逐行講解,每個函數的用途都會解釋清楚,并和論文的核心直覺對應起來。每個代碼塊都會詳細說明,方便你根據自己的實驗需求進行修改。



      代碼

      整個 LLM-JEPA 訓練腳本放在一個文件里:

      它接收原始文本然后創建兩個視圖:context 視圖把某些片段替換成 [MASK],target 視圖保留原始文本但只在被遮蔽位置做監督。Context 編碼器是可訓練的,負責預測 target 編碼器在遮蔽位置的表示。Target 編碼器則是 context 編碼器的 EMA 副本,不參與梯度計算。損失函數用的是預測嵌入和目標嵌入之間的余弦距離。

      運行示例:

      # 小型冒煙測試(無需下載,隨機初始化)
      python llm_jepa_train.py --smoke_test
      # 使用 HF 模型骨干訓練
      python llm_jepa_train.py --model_name distilbert-base-uncased --steps 200 --batch_size 8
      # 在自己的文本文件上訓練
      python llm_jepa_train.py --model_name distilbert-base-uncased --text_file data.txt --steps 2000

      這是一個簡潔的參考實現,不是完整的倉庫代碼。編碼器用的是 Transformers 庫。

      import argparse
      import math
      import os
      import random
      from dataclasses import dataclass
      from typing import List, Tuple, Optional
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      from torch.utils.data import Dataset, DataLoader
      try:
      from transformers import AutoTokenizer, AutoModel, AutoConfig
      except Exception:
      AutoTokenizer = None
      AutoModel = None
      AutoConfig = None
      # -----------------------------
      # Utilities
      # -----------------------------
      def set_seed(seed: int):
      random.seed(seed)
      torch.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
      def pick_device(device_str: str) -> torch.device:
      if device_str == "auto":
      return torch.device("cuda" if torch.cuda.is_available() else "cpu")
      return torch.device(device_str)
      # -----------------------------
      # Span masking (simple + effective)
      # -----------------------------
      def sample_span_mask(
      seq_len: int,
      mask_ratio: float,
      mean_span_len: int,
      special_positions: Optional[set] = None,
      ) -> torch.BoolTensor:
      """
      Returns a boolean mask of length seq_len indicating which positions are masked.
      We mask contiguous spans until we reach approximately mask_ratio of tokens.
      """
      if special_positions is None:
      special_positions = set()
      mask = torch.zeros(seq_len, dtype=torch.bool)
      if seq_len <= 0:
      return mask
      target_to_mask = max(1, int(round(seq_len * mask_ratio)))
      masked = 0
      attempts = 0
      max_attempts = seq_len * 4
      while masked < target_to_mask and attempts < max_attempts:
      attempts += 1
      span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))
      span_len = min(span_len, seq_len)
      start = random.randint(0, seq_len - 1)
      end = min(seq_len, start + span_len)
      span_positions = [i for i in range(start, end) if i not in special_positions]
      if not span_positions:
      continue
      newly = 0
      for i in span_positions:
      if not mask[i]:
      mask[i] = True
      newly += 1
      masked += newly
      return mask
      def apply_mask_to_input_ids(
      input_ids: torch.LongTensor,
      attention_mask: torch.LongTensor,
      tokenizer,
      mask_ratio: float,
      mean_span_len: int,
      ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
      """
      Masks spans inside non-special, non-padding tokens.
      Returns:
      masked_input_ids: input ids with masked tokens replaced by [MASK]
      pred_mask: boolean mask over positions where we apply JEPA loss
      """
      assert input_ids.dim() == 1
      seq_len = int(attention_mask.sum().item())
      # Identify special token positions (CLS, SEP, etc.) in the visible region
      special_positions = set()
      for i in range(seq_len):
      tid = int(input_ids[i].item())
      if tid in {
      tokenizer.cls_token_id,
      tokenizer.sep_token_id,
      tokenizer.pad_token_id,
      }:
      special_positions.add(i)
      pred_mask = sample_span_mask(
      seq_len=seq_len,
      mask_ratio=mask_ratio,
      mean_span_len=mean_span_len,
      special_positions=special_positions,
      )
      masked_input_ids = input_ids.clone()
      mask_token_id = tokenizer.mask_token_id
      if mask_token_id is None:
      raise ValueError("Tokenizer has no mask_token_id. Use a model with [MASK].")
      # Replace masked positions with [MASK]
      masked_input_ids[:seq_len][pred_mask] = mask_token_id
      # pred_mask should be full length (includes pads as False)
      full_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
      full_mask[:seq_len] = pred_mask
      return masked_input_ids, full_mask
      # -----------------------------
      # Dataset
      # -----------------------------
      class TextLinesDataset(Dataset):
      def __init__(self, texts: List[str]):
      self.texts = [t.strip() for t in texts if t.strip()]
      def __len__(self) -> int:
      return len(self.texts)
      def __getitem__(self, idx: int) -> str:
      return self.texts[idx]
      def load_texts_from_file(path: str, max_lines: Optional[int] = None) -> List[str]:
      texts = []
      with open(path, "r", encoding="utf-8") as f:
      for i, line in enumerate(f):
      if max_lines is not None and i >= max_lines:
      break
      texts.append(line.rstrip("\n"))
      return texts
      def default_tiny_corpus() -> List[str]:
      return [
      "The cat sat on the mat and looked at the window.",
      "A quick brown fox jumps over the lazy dog.",
      "Deep learning models can learn useful representations from raw data.",
      "Rocket Learning builds AI tools for education in India.",
      "Transformers use attention to mix information across tokens.",
      "Self-supervised learning can reduce the need for labels.",
      "JEPA trains models to predict embeddings, not tokens.",
      "Bengaluru is a major tech hub in India.",
      "A good system design balances simplicity and scalability.",
      "Reading code carefully helps you understand how an idea is implemented.",
      ]
      @dataclass
      class Batch:
      input_ids: torch.LongTensor # [B, L]
      attention_mask: torch.LongTensor # [B, L]
      masked_input_ids: torch.LongTensor # [B, L]
      pred_mask: torch.BoolTensor # [B, L] positions to compute loss on
      def collate_jepa(
      batch_texts: List[str],
      tokenizer,
      max_length: int,
      mask_ratio: float,
      mean_span_len: int,
      ) -> Batch:
      toks = tokenizer(
      batch_texts,
      padding=True,
      truncation=True,
      max_length=max_length,
      return_tensors="pt",
      )
      input_ids = toks["input_ids"] # [B, L]
      attention_mask = toks["attention_mask"] # [B, L]
      masked_input_ids_list = []
      pred_mask_list = []
      for b in range(input_ids.size(0)):
      mi, pm = apply_mask_to_input_ids(
      input_ids[b],
      attention_mask[b],
      tokenizer,
      mask_ratio=mask_ratio,
      mean_span_len=mean_span_len,
      )
      masked_input_ids_list.append(mi)
      pred_mask_list.append(pm)
      masked_input_ids = torch.stack(masked_input_ids_list, dim=0)
      pred_mask = torch.stack(pred_mask_list, dim=0)
      return Batch(
      input_ids=input_ids,
      attention_mask=attention_mask,
      masked_input_ids=masked_input_ids,
      pred_mask=pred_mask,
      )
      # -----------------------------
      # Model: Encoder + Predictor + EMA target encoder
      # -----------------------------
      class PredictorMLP(nn.Module):
      def __init__(self, dim: int, hidden_mult: int = 4, dropout: float = 0.0):
      super().__init__()
      hidden = dim * hidden_mult
      self.net = nn.Sequential(
      nn.Linear(dim, hidden),
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(hidden, dim),
      )
      def forward(self, x: torch.Tensor) -> torch.Tensor:
      return self.net(x)
      class LLMJEPA(nn.Module):
      def __init__(self, encoder: nn.Module, dim: int, ema_m: float = 0.99, pred_hidden_mult: int = 4):
      super().__init__()
      self.context_encoder = encoder
      self.target_encoder = self._copy_encoder(encoder)
      self.predictor = PredictorMLP(dim=dim, hidden_mult=pred_hidden_mult, dropout=0.0)
      self.ema_m = ema_m
      for p in self.target_encoder.parameters():
      p.requires_grad = False
      @staticmethod
      def _copy_encoder(enc: nn.Module) -> nn.Module:
      import copy
      return copy.deepcopy(enc)
      @torch.no_grad()
      def ema_update(self):
      m = self.ema_m
      for p_ctx, p_tgt in zip(self.context_encoder.parameters(), self.target_encoder.parameters()):
      p_tgt.data.mul_(m).add_(p_ctx.data, alpha=(1.0 - m))
      def forward(
      self,
      masked_input_ids: torch.LongTensor,
      input_ids: torch.LongTensor,
      attention_mask: torch.LongTensor,
      pred_mask: torch.BoolTensor,
      ) -> torch.Tensor:
      """
      Returns JEPA loss (scalar).
      We compute:
      z_ctx = context_encoder(masked_input)
      z_tgt = target_encoder(full input)
      pred = predictor(z_ctx)
      loss over positions in pred_mask
      """
      out_ctx = self.context_encoder(input_ids=masked_input_ids, attention_mask=attention_mask)
      z_ctx = out_ctx.last_hidden_state # [B, L, D]
      with torch.no_grad():
      out_tgt = self.target_encoder(input_ids=input_ids, attention_mask=attention_mask)
      z_tgt = out_tgt.last_hidden_state # [B, L, D]
      pred = self.predictor(z_ctx) # [B, L, D]
      # Select masked positions
      # pred_mask: [B, L] bool
      masked_pred = pred[pred_mask] # [N, D]
      masked_tgt = z_tgt[pred_mask] # [N, D]
      if masked_pred.numel() == 0:
      # Safety: if a batch ends up with no masked tokens, return zero loss
      return pred.sum() * 0.0
      masked_pred = F.normalize(masked_pred, dim=-1)
      masked_tgt = F.normalize(masked_tgt, dim=-1)
      # Cosine distance
      loss = 1.0 - (masked_pred * masked_tgt).sum(dim=-1)
      return loss.mean()
      # -----------------------------
      # Training
      # -----------------------------
      def build_hf_encoder(model_name: str):
      if AutoModel is None:
      raise RuntimeError("transformers is not installed. pip install transformers")
      config = AutoConfig.from_pretrained(model_name)
      encoder = AutoModel.from_pretrained(model_name, config=config)
      dim = int(config.hidden_size)
      return encoder, dim
      def build_random_encoder(vocab_size: int = 30522, dim: int = 256, layers: int = 4, heads: int = 4):
      """
      For smoke tests only: small Transformer encoder (random init).
      Requires a tokenizer with vocab mapping for ids.
      """
      encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
      transformer = nn.TransformerEncoder(encoder_layer, num_layers=layers)
      class TinyEncoder(nn.Module):
      def __init__(self):
      super().__init__()
      self.emb = nn.Embedding(vocab_size, dim)
      self.pos = nn.Embedding(512, dim)
      self.enc = transformer
      def forward(self, input_ids, attention_mask):
      B, L = input_ids.shape
      pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, L)
      x = self.emb(input_ids) + self.pos(pos_ids)
      # attention_mask: 1 for keep, 0 for pad
      # transformer expects src_key_padding_mask: True for pad
      pad_mask = attention_mask == 0
      h = self.enc(x, src_key_padding_mask=pad_mask)
      return type("Out", (), {"last_hidden_state": h})
      return TinyEncoder(), dim
      def save_checkpoint(path: str, model: LLMJEPA, optimizer: torch.optim.Optimizer, step: int):
      os.makedirs(os.path.dirname(path), exist_ok=True)
      torch.save(
      {
      "step": step,
      "context_encoder": model.context_encoder.state_dict(),
      "target_encoder": model.target_encoder.state_dict(),
      "predictor": model.predictor.state_dict(),
      "optimizer": optimizer.state_dict(),
      },
      path,
      )
      def main():
      parser = argparse.ArgumentParser()
      parser.add_argument("--model_name", type=str, default="distilbert-base-uncased", help="HF encoder backbone")
      parser.add_argument("--text_file", type=str, default="", help="Path to a newline-separated text file")
      parser.add_argument("--max_lines", type=int, default=50000)
      parser.add_argument("--max_length", type=int, default=128)
      parser.add_argument("--mask_ratio", type=float, default=0.3)
      parser.add_argument("--mean_span_len", type=int, default=5)
      parser.add_argument("--ema_m", type=float, default=0.99)
      parser.add_argument("--pred_hidden_mult", type=int, default=4)
      parser.add_argument("--batch_size", type=int, default=8)
      parser.add_argument("--lr", type=float, default=2e-5)
      parser.add_argument("--weight_decay", type=float, default=0.01)
      parser.add_argument("--steps", type=int, default=500)
      parser.add_argument("--warmup_steps", type=int, default=50)
      parser.add_argument("--log_every", type=int, default=25)
      parser.add_argument("--save_every", type=int, default=200)
      parser.add_argument("--save_path", type=str, default="checkpoints/llm_jepa.pt")
      parser.add_argument("--device", type=str, default="auto")
      parser.add_argument("--seed", type=int, default=42)
      parser.add_argument("--smoke_test", action="store_true", help="No downloads, tiny random encoder, tiny corpus")
      args = parser.parse_args()
      set_seed(args.seed)
      device = pick_device(args.device)
      if args.smoke_test:
      if AutoTokenizer is None:
      raise RuntimeError("transformers is required even for smoke_test (for tokenizer).")
      tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
      # Ensure mask token exists
      if tokenizer.mask_token_id is None:
      raise ValueError("Tokenizer must support [MASK]. Use a masked LM tokenizer.")
      texts = default_tiny_corpus()
      ds = TextLinesDataset(texts)
      encoder, dim = build_random_encoder(vocab_size=int(tokenizer.vocab_size), dim=256, layers=4, heads=4)
      model = LLMJEPA(encoder=encoder, dim=dim, ema_m=0.95, pred_hidden_mult=2).to(device)
      lr = 1e-4
      else:
      if AutoTokenizer is None:
      raise RuntimeError("transformers is not installed. pip install transformers")
      tokenizer = AutoTokenizer.from_pretrained(args.model_name)
      if tokenizer.mask_token_id is None:
      raise ValueError(
      "This tokenizer has no [MASK]. Pick a masked-encoder model (BERT/DeBERTa/DistilBERT)."
      )
      if args.text_file:
      texts = load_texts_from_file(args.text_file, max_lines=args.max_lines)
      else:
      texts = default_tiny_corpus()
      ds = TextLinesDataset(texts)
      encoder, dim = build_hf_encoder(args.model_name)
      model = LLMJEPA(encoder=encoder, dim=dim, ema_m=args.ema_m, pred_hidden_mult=args.pred_hidden_mult).to(device)
      lr = args.lr
      # DataLoader
      def _collate(batch_texts):
      return collate_jepa(
      batch_texts=batch_texts,
      tokenizer=tokenizer,
      max_length=args.max_length,
      mask_ratio=args.mask_ratio,
      mean_span_len=args.mean_span_len,
      )
      dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=_collate)
      # Optimizer
      optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)
      # Simple warmup + cosine schedule
      def lr_at(step: int) -> float:
      if step < args.warmup_steps:
      return float(step + 1) / float(max(1, args.warmup_steps))
      progress = (step - args.warmup_steps) / float(max(1, args.steps - args.warmup_steps))
      progress = min(max(progress, 0.0), 1.0)
      return 0.5 * (1.0 + math.cos(math.pi * progress))
      model.train()
      running = 0.0
      step = 0
      data_iter = iter(dl)
      while step < args.steps:
      try:
      batch = next(data_iter)
      except StopIteration:
      data_iter = iter(dl)
      batch = next(data_iter)
      # Move to device
      input_ids = batch.input_ids.to(device)
      attention_mask = batch.attention_mask.to(device)
      masked_input_ids = batch.masked_input_ids.to(device)
      pred_mask = batch.pred_mask.to(device)
      # LR schedule
      scale = lr_at(step)
      for pg in optimizer.param_groups:
      pg["lr"] = lr * scale
      loss = model(
      masked_input_ids=masked_input_ids,
      input_ids=input_ids,
      attention_mask=attention_mask,
      pred_mask=pred_mask,
      )
      optimizer.zero_grad(set_to_none=True)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      optimizer.step()
      # EMA update after optimizer step
      model.ema_update()
      running += float(loss.item())
      step += 1
      if step % args.log_every == 0:
      avg = running / float(args.log_every)
      running = 0.0
      print(f"step {step:6d} | loss {avg:.4f} | lr {optimizer.param_groups[0]['lr']:.6g}")
      if step % args.save_every == 0:
      save_checkpoint(args.save_path, model, optimizer, step)
      print(f"saved checkpoint to {args.save_path} at step {step}")
      save_checkpoint(args.save_path, model, optimizer, step)
      print(f"training done. final checkpoint: {args.save_path}")
      if __name__ == "__main__":
      main()

      這是一個面向文本的 JEPA 風格表示預測器。

      輸入普通文本行,對每個樣本創建兩個視圖。遮蔽視圖(context view)是同一個句子,但某些 span 被替換成 `[MASK];原始視圖(target view)保持原樣,沒有遮蔽。

      訓練流程是這樣的:遮蔽視圖過一個可訓練的 context 編碼器,原始視圖過一個不可訓練的 target 編碼器,然后訓練一個預測器,讓 context 編碼器的表示能預測 target 編碼器的表示——但只在被遮蔽的位置上計算損失。Target 編碼器通過 EMA 更新來保持穩定。

      這種設計鼓勵模型學習"填補語義"的表示,而不是預測具體的 token。

      set_seed 函數

      def set_seed(seed: int):
      random.seed(seed)
      torch.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)

      這個函數確保運行可復現。random.seed(seed) 固定 Python 的隨機操作(span 遮蔽會用到),torch.manual_seed(seed) 固定 PyTorch 在 CPU 上的隨機性,torch.cuda.manual_seed_all(seed) 固定 CUDA 內核的隨機性。

      span 遮蔽和模型初始化都是隨機的,不設種子的話每次跑結果都不一樣。

      pick_device 函數

      def pick_device(device_str: str) -> torch.device:
      if device_str == "auto":
      return torch.device("cuda" if torch.cuda.is_available() else "cpu")
      return torch.device(device_str)

      返回 PyTorch 設備對象。如果傳 --device auto,有 GPU 就用 GPU,沒有就用 CPU。也可以直接指定 --device cpu 或 --device cuda。

      張量和模型必須在同一設備上,這是基本要求。

      sample_span_mask 函數

      def sample_span_mask(seq_len, mask_ratio, mean_span_len, special_positions=None)

      整個腳本里最重要的函數之一。

      目標是創建一個布爾掩碼,標記序列中哪些位置該被遮蔽。參數包括:seq_len 是真實 token 數量(不含 padding),mask_ratio 是遮蔽比例(比如 0.3),mean_span_len 是連續遮蔽 span 的平均長度,special_positions 是永遠不該遮蔽的位置(CLS、SEP、PAD)。

      內部邏輯是先創建一個全 False 的掩碼,然后計算需要遮蔽多少 token:

      target_to_mask = max(1, int(round(seq_len * mask_ratio)))

      即使序列很短也至少遮蔽 1 個。

      接下來循環采樣 span 直到湊夠數。Span 長度從指數分布采樣:

      span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))

      這會產出很多短 span 和少量長 span,比較符合自然分布。隨機選一個起始位置,過濾掉特殊 token,把剩下的位置標記為 True。

      遮蔽策略對表示學習質量影響很大。Span 遮蔽能迫使模型從周圍上下文推斷缺失的語義。

      apply_mask_to_input_ids 函數

      def apply_mask_to_input_ids(input_ids, attention_mask, tokenizer, mask_ratio, mean_span_len)

      拿到一個樣本的 token ids,輸出兩個東西:masked_input_ids 是把遮蔽位置換成 [MASK] 后的 ids,pred_mask 是標記哪些位置要算損失的布爾掩碼。

      先算可見序列長度:seq_len = int(attention_mask.sum().item())。attention_mask 里真實 token 是 1,padding 是 0。

      然后識別特殊 token 位置,CLS 和 SEP 不能遮蔽,否則模型容易出問題。調用 sample_span_mask 采樣遮蔽位置,把這些位置替換成 mask_token_id:

      masked_input_ids[:seq_len][pred_mask] = mask_token_id

      返回的 pred_mask 是完整長度的,padding 位置都是 False。只在遮蔽位置算 JEPA 損失,其他位置忽略。

      TextLinesDataset 類

      class TextLinesDataset(Dataset):
      def __init__(self, texts):
      self.texts = [t.strip() for t in texts if t.strip()]

      極簡的數據集實現,存文本行列表,去掉空行和首尾空白。__len__ 返回行數,__getitem__ 返回單條文本。

      load_texts_from_file 逐行讀文件,可限制最大行數,傳 --text_file 時用。default_tiny_corpus 提供內置測試數據集。

      Batch 數據類

      @dataclass
      class Batch:
      input_ids
      attention_mask
      masked_input_ids
      pred_mask

      用 dataclass 比返回元組清晰多了,代碼可讀性好。

      collate_jepa 函數

      DataLoader 創建批次時調用的函數。輸入是原始文本列表,先用 tokenizer 做分詞、padding、截斷:

      toks = tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

      產出 input_ids 和 attention_mask。然后對每個樣本調 apply_mask_to_input_ids 生成遮蔽版本和 pred_mask,最后堆疊成 [B, L] 張量返回 Batch。

      DataLoader 是逐樣本讀的,但訓練需要批次。批處理和遮蔽都在這里發生。

      PredictorMLP 類

      預測器頭,結構簡單:

      nn.Linear(dim, hidden)
      nn.GELU()
      nn.Dropout()
      nn.Linear(hidden, dim)

      把 context 表示映射到 target 表示空間,相當于一個學習出來的適配器,幫助對齊兩邊的嵌入。

      LLMJEPA 模型類

      主模型包裝器,包含四個核心部件:context_encoder 是可訓練的 Transformer 編碼器,target_encoder 是它的深拷貝但不可訓練,predictor 是 MLP,ema_m 是 EMA 動量因子。

      _copy_encoder 用 copy.deepcopy 確保 target 和 context 初始狀態一致。

      ema_update 緩慢更新 target 編碼器權重:

      p_tgt = m * p_tgt + (1 - m) * p_ctx

      m=0.99 時 target 變化非常慢,這能穩定訓練、降低表示坍塌風險。

      forward 的流程:把遮蔽視圖過 context 編碼器(可訓練),原始視圖過 target 編碼器(無梯度),predictor 處理 context 輸出,然后只取遮蔽位置的向量:

      masked_pred = pred[pred_mask] # [N, D]
      masked_tgt = z_tgt[pred_mask] # [N, D]

      從 [B, L, D] 變成 [N, D],N 是遮蔽 token 總數。歸一化后算余弦距離:

      loss = 1 - (masked_pred * masked_tgt).sum(dim=-1)
      return loss.mean()

      歸一化是因為余弦相似度只看向量方向,不看大小。

      build_hf_encoder 函數

      加載 Hugging Face 編碼器,返回模型和隱藏維度(從 config.hidden_size 讀)。

      build_random_encoder 函數

      冒煙測試專用,從頭建一個小 Transformer 編碼器,包括嵌入層、位置嵌入、編碼器堆棧。注意這不是掩碼語言模型,只是個編碼器架構。返回對象帶 .last_hidden_state 屬性是為了匹配 HF 輸出格式。

      總結

      這個實現刻意追求清晰而非完整,所以沒有自定義注意力掩碼、多視圖數據集或混合目標。但是把它當參考實現用是非常合適的。原始 LLM-JEPA 論文做得更深入,把 JEPA 和 token 預測結合起來,還利用了文本-代碼這樣的自然配對視圖。那些設計對下游任務表現很重要,但也增加了復雜度,容易讓人看不清核心機制。

      論文:

      https://avoid.overfit.cn/post/09eb991a93f64a83a376cdb52ac5c661

      作者:azhar

      特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。

      Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

      相關推薦
      熱點推薦
      伊朗最大“內鬼”被抓?革命衛隊:勾結以色列,指揮官卡尼被拘!

      伊朗最大“內鬼”被抓?革命衛隊:勾結以色列,指揮官卡尼被拘!

      青青子衿
      2026-03-05 11:57:03
      打瘋了!東契奇首節狂轟22+5三分 生涯30次單節20+升歷史第四

      打瘋了!東契奇首節狂轟22+5三分 生涯30次單節20+升歷史第四

      醉臥浮生
      2026-03-07 12:13:33
      伊拉克庫爾德第一夫人宣言:我們不是任人驅使的炮灰!

      伊拉克庫爾德第一夫人宣言:我們不是任人驅使的炮灰!

      勝研集
      2026-03-06 13:44:23
      廣東一女子不愿上班常年坐街邊,因長得好看被路人投喂:又懶又饞

      廣東一女子不愿上班常年坐街邊,因長得好看被路人投喂:又懶又饞

      明智家庭教育
      2026-03-06 17:19:16
      美以伊軍事沖突最大副作用,是斬斷了俄羅斯的“救命稻草”

      美以伊軍事沖突最大副作用,是斬斷了俄羅斯的“救命稻草”

      廖保平
      2026-03-05 12:08:52
      “不想為以色列賣命”:帝國最后的遮羞布,美式民主終成笑話

      “不想為以色列賣命”:帝國最后的遮羞布,美式民主終成笑話

      怪口歷史的K先生
      2026-03-06 15:22:51
      為何關閉霍爾木茲海峽就能掐全球脖子?因為伊朗原油是全世界最好的

      為何關閉霍爾木茲海峽就能掐全球脖子?因為伊朗原油是全世界最好的

      風向觀察
      2026-03-06 21:31:15
      兩會不到3天,5大好消息傳來!老百姓暗暗叫好:希望國家盡快落實

      兩會不到3天,5大好消息傳來!老百姓暗暗叫好:希望國家盡快落實

      談史論天地
      2026-03-07 06:54:29
      1979年,張國燾凍死在養老院,許世友:除了主席,沒人是他的對手

      1979年,張國燾凍死在養老院,許世友:除了主席,沒人是他的對手

      文史季季紅
      2026-03-05 13:35:03
      寫入教科書的一天:F-35在德黑蘭完成全球首次實戰空對空擊殺

      寫入教科書的一天:F-35在德黑蘭完成全球首次實戰空對空擊殺

      斌聞天下
      2026-03-06 07:30:03
      伊方:因美以襲擊喪生的伊朗人三成為青少年

      伊方:因美以襲擊喪生的伊朗人三成為青少年

      環球網資訊
      2026-03-07 06:39:29
      為什么美國的華人華裔地位那么低 網友從各方面分析 真就那樣

      為什么美國的華人華裔地位那么低 網友從各方面分析 真就那樣

      侃神評故事
      2026-03-06 07:10:03
      我包養過一個女大學生,七年花了一千多萬

      我包養過一個女大學生,七年花了一千多萬

      煙火人間故事匯
      2026-03-06 23:05:03
      性壓抑已經變態至此了?

      性壓抑已經變態至此了?

      黯泉
      2026-03-07 11:28:43
      蘿莉島,是進入核心圈層的投名狀,你猜他們為什么都穿紅皮鞋

      蘿莉島,是進入核心圈層的投名狀,你猜他們為什么都穿紅皮鞋

      百曉生談歷史
      2026-03-05 22:00:08
      一份“煮熟的三文魚”火了,原來低認知的家長,真能搞出人命!

      一份“煮熟的三文魚”火了,原來低認知的家長,真能搞出人命!

      妍妍教育日記
      2026-03-07 08:45:06
      伊朗萬萬沒想到,自家王牌武器遭到破解,美軍多了一張底牌

      伊朗萬萬沒想到,自家王牌武器遭到破解,美軍多了一張底牌

      空天力量
      2026-03-06 13:09:18
      上次被發現還是1911年!上海寶山驚現1只,專家:可能是坐船來的

      上次被發現還是1911年!上海寶山驚現1只,專家:可能是坐船來的

      萬象硬核本尊
      2026-03-06 23:54:22
      女子實名舉報某團外賣:不上大額券就讓我變成“凌晨營業”,你們真黑!

      女子實名舉報某團外賣:不上大額券就讓我變成“凌晨營業”,你們真黑!

      回旋鏢
      2026-03-06 21:13:59
      塔圖姆復出15分12板7助攻凱爾特人大勝獨行俠,布朗24分7板7助

      塔圖姆復出15分12板7助攻凱爾特人大勝獨行俠,布朗24分7板7助

      湖人崛起
      2026-03-07 10:25:09
      2026-03-07 13:43:00
      deephub incentive-icons
      deephub
      CV NLP和數據挖掘知識
      1940文章數 1456關注度
      往期回顧 全部

      科技要聞

      OpenClaw爆火,六位"養蝦人"自述與AI共生

      頭條要聞

      特朗普突然放話"先解決伊朗后解決古巴" 梅西聽懵了

      頭條要聞

      特朗普突然放話"先解決伊朗后解決古巴" 梅西聽懵了

      體育要聞

      塔圖姆歸來:凱爾特人的春之綠

      娛樂要聞

      周杰倫田馥甄的“JH戀” 被扒得底朝天

      財經要聞

      針對"不敢休、不讓休"怪圈 國家出手了

      汽車要聞

      逃離ICU,上汽通用“止血”企穩

      態度原創

      教育
      親子
      房產
      公開課
      軍事航空

      教育要聞

      兩會速遞|教育部部長:將實施新一輪學生心理健康促進行動

      親子要聞

      六個月寶寶查出散光,原因竟是父母長期身旁玩手機,媽媽懵了:我一直以為他閉著眼就沒事

      房產要聞

      傳統學區房熄火?2月海口二手房爆火的板塊竟然是…

      公開課

      李玫瑾:為什么性格比能力更重要?

      軍事要聞

      伊朗:使用無人機擊中美軍"林肯"號航母

      無障礙瀏覽 進入關懷版