![]()
NeRF(Neural Radiance Fields,神經(jīng)輻射場)的核心思路是用一個(gè)全連接網(wǎng)絡(luò)表示三維場景。輸入是5D向量空間坐標(biāo)(x, y, z)加上視角方向(θ, φ),輸出則是該點(diǎn)的顏色和體積密度。訓(xùn)練的數(shù)據(jù)則是同一物體從不同角度拍攝的若干張照片。
通常情況下泛化能力是模型的追求目標(biāo),需要在大量不同樣本上訓(xùn)練以避免過擬合。但NeRF恰恰相反,它只在單一場景的多個(gè)視角上訓(xùn)練,刻意讓網(wǎng)絡(luò)"過擬合"到這個(gè)特定場景,這與傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的訓(xùn)練邏輯完全相反。
這樣NeRF把網(wǎng)絡(luò)訓(xùn)練成了某個(gè)場景的"專家",這個(gè)專家只懂一件事,但懂得很透徹:給它任意一個(gè)新視角,它都能告訴你從那個(gè)方向看場景是什么樣子,存儲(chǔ)的不再是一堆圖片,而是場景本身的隱式表示。

基本概念
把5D輸入向量拆開來看:空間位置(x, y, z)和觀察方向(θ, φ)。
顏色(也就是輻射度)同時(shí)依賴位置和觀察方向,這很好理解,因?yàn)橥粋€(gè)點(diǎn)從不同角度看可能有不同的反光效果。但密度只跟位置有關(guān)與觀察方向無關(guān)。這里的假設(shè)是材質(zhì)本身不會(huì)因?yàn)槟銚Q個(gè)角度看就變透明或變不透明,這個(gè)約束大幅降低了模型復(fù)雜度。
用來表示這個(gè)映射關(guān)系的是一個(gè)多層感知機(jī)(MLP)而且沒有卷積層,這個(gè)MLP被有意過擬合到特定場景。
![]()
渲染流程分三步:沿每條光線采樣生成3D點(diǎn),用網(wǎng)絡(luò)預(yù)測每個(gè)點(diǎn)的顏色和密度,最后用體積渲染把這些顏色累積成二維圖像。
訓(xùn)練時(shí)用梯度下降最小化渲染圖像與真實(shí)圖像之間的差距。不過直接訓(xùn)練效果不好原始5D輸入需要經(jīng)過位置編碼轉(zhuǎn)換才能讓網(wǎng)絡(luò)更好地捕捉高頻細(xì)節(jié)。
傳統(tǒng)體素表示需要顯式存儲(chǔ)整個(gè)場景占用空間巨大。NeRF則把場景信息壓縮在網(wǎng)絡(luò)參數(shù)里,最終模型可以比原始圖片集小很多。這是NeRF的一個(gè)關(guān)鍵優(yōu)勢。
相關(guān)工作
NeRF出現(xiàn)之前,神經(jīng)場景表示一直比不過體素、三角網(wǎng)格這些離散表示方法。
早期也有人用網(wǎng)絡(luò)把位置坐標(biāo)映射到距離函數(shù)或占用場,但只能處理ShapeNet這類合成3D數(shù)據(jù)。
![]()
arxiv:1912.07372 用3D占用場做隱式表示提出了可微渲染公式。arxiv:1906.01618的方法在每個(gè)3D點(diǎn)輸出特征向量和顏色用循環(huán)神經(jīng)網(wǎng)絡(luò)沿光線移動(dòng)來檢測表面,但這些方法生成的表面往往過于平滑。
如果視角采樣足夠密集,光場插值技術(shù)就能生成新視角。但視角稀疏時(shí)必須用表示方法,體積方法能生成真實(shí)感強(qiáng)的圖像但分辨率上不去。
場景表示機(jī)制
![]()
輸入是位置x= (x, y, z) 和觀察方向d= (θ, φ),輸出是顏色 c = (r, g, b) 和密度 σ。整個(gè)5D映射用MLP來近似。
![]()
優(yōu)化目標(biāo)是網(wǎng)絡(luò)權(quán)重 Θ。密度被假設(shè)為多視角一致的,顏色則同時(shí)取決于位置和觀察方向。
網(wǎng)絡(luò)結(jié)構(gòu)上先用8個(gè)全連接層處理空間位置,輸出密度σ和一個(gè)256維特征向量。這個(gè)特征再和觀察方向拼接,再經(jīng)過一個(gè)全連接層得到顏色。
體積渲染
光線參數(shù)化如下:
![]()
密度σ描述的是某點(diǎn)對(duì)光線的阻擋程度,可以理解為吸收概率。更嚴(yán)格地說它是光線在該點(diǎn)終止的微分概率。根據(jù)這個(gè)定義,光線從t傳播到t?的透射概率可以表示為:
![]()
σ和T之間的關(guān)系可以畫圖來理解:
![]()
密度升高時(shí)透射率下降。一旦透射率降到零,后面的東西就完全被遮住了,也就是看不見了。
光線的期望顏色C(r)定義如下,沿光線從近到遠(yuǎn)積分:
![]()
問題在于c和σ都來自神經(jīng)網(wǎng)絡(luò)這個(gè)積分沒有解析解。
實(shí)際計(jì)算時(shí)用數(shù)值積分,采用分層采樣策略——把積分范圍分成N個(gè)區(qū)間,每個(gè)區(qū)間均勻隨機(jī)抽一個(gè)點(diǎn)。
![]()
分層采樣保證MLP在整個(gè)優(yōu)化過程中都能在連續(xù)位置上被評(píng)估。采樣點(diǎn)通過求積公式計(jì)算C(t)這個(gè)公式選擇上考慮了可微性。跟純隨機(jī)采樣比方差更低。
T?是光線存活到第i個(gè)區(qū)間之前的概率。那光線在第i個(gè)區(qū)間內(nèi)終止的概率呢?可以用密度來算:
![]()
σ越大這個(gè)概率越趨近于零,再往下推導(dǎo):
![]()
光線顏色可以寫成:
![]()
其中:
![]()
位置編碼
直接拿5D坐標(biāo)訓(xùn)練MLP,高頻細(xì)節(jié)渲染不出來。因?yàn)樯疃染W(wǎng)絡(luò)天生偏好學(xué)習(xí)低頻信號(hào),解決辦法是用高頻函數(shù)把輸入映射到更高維空間。
![]()
γ對(duì)每個(gè)坐標(biāo)分別應(yīng)用,是個(gè)確定性函數(shù)沒有可學(xué)習(xí)參數(shù)。p歸一化到[-1,+1]。L=4時(shí)的編碼可視化:
![]()
L=4時(shí)的位置編碼示意
編碼用的是不同頻率的正弦函數(shù)。Transformer里也用類似的位置編碼但目的不同——Transformer是為了讓模型感知token順序,NeRF是為了注入高頻信息。
分層采樣
均勻采樣的問題在于大量計(jì)算浪費(fèi)在空曠區(qū)域。分層采樣的思路是訓(xùn)練兩個(gè)網(wǎng)絡(luò),一個(gè)粗糙一個(gè)精細(xì)。
先用粗糙網(wǎng)絡(luò)采樣評(píng)估一批點(diǎn),再根據(jù)結(jié)果用逆變換采樣在重要區(qū)域加密采樣。精細(xì)網(wǎng)絡(luò)用兩組樣本一起計(jì)算最終顏色。粗糙網(wǎng)絡(luò)的顏色可以寫成采樣顏色的加權(quán)和。
實(shí)現(xiàn)
每個(gè)場景單獨(dú)訓(xùn)練一個(gè)網(wǎng)絡(luò),只需要RGB圖像作為訓(xùn)練數(shù)據(jù)。每次迭代從所有像素里采樣一批光線,損失函數(shù)是粗糙和精細(xì)網(wǎng)絡(luò)預(yù)測值與真值之間的均方誤差。
![]()
接下來從零實(shí)現(xiàn)NeRF架構(gòu),在一個(gè)包含藍(lán)色立方體和紅色球體的簡單數(shù)據(jù)集上訓(xùn)練。
數(shù)據(jù)集生成代碼不在本文范圍內(nèi)——只涉及基礎(chǔ)幾何變換沒有NeRF特有的概念。
![]()
數(shù)據(jù)集里的一些渲染圖像。相機(jī)矩陣和坐標(biāo)也存在了JSON文件里。
![]()
先導(dǎo)入必要的庫:
import os, json, math
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
位置編碼函數(shù):
def positional_encoding(x, L):
freqs = (2.0 ** torch.arange(L, device=x.device)) * math.pi # Define the frequencies
xb = x[..., None, :] * freqs[:, None] # Multiply by the frequencies
xb = xb.reshape(*x.shape[:-1], L * 3) # Flatten the (x,y,z) coordinates
return torch.cat([torch.sin(xb), torch.cos(xb)], dim=-1)
根據(jù)相機(jī)參數(shù)生成光線:
def get_rays(H, W, camera_angle_x, c2w, device):
# assume the pinhole camera model
fx = 0.5 * W / math.tan(0.5 * camera_angle_x) # calculate the focal lengths (assume fx=fy)
# principal point of the camera or the optical center of the image.
cx = (W - 1) * 0.5
cy = (H - 1) * 0.5
i, j = torch.meshgrid(torch.arange(W, device=device),
torch.arange(H, device=device), indexing="xy")
i, j = i.float(), j.float()
# convert pixels to normalized camera-plane coordinates
x = (i - cx) / fx
y = -(j - cy) / fx
z = -torch.ones_like(x)
# pack into 3D directions and normalize
dirs = torch.stack([x, y, z], dim=-1)
dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)
# rotate rays into world coordinates using pose matrix
R, t = c2w[:3, :3], c2w[:3, 3]
rd = dirs @ R.T
ro = t.expand_as(rd)
return ro, rd
NeRF網(wǎng)絡(luò)結(jié)構(gòu):
class NeRF(nn.Module):
def __init__(self, L_pos=10, L_dir=4, hidden=256):
super().__init__()
# original vector is concatented with the fourier features
in_pos = 3 + 2 * L_pos * 3
in_dir = 3 + 2 * L_dir * 3
self.fc1 = nn.Linear(in_pos, hidden)
self.fc2 = nn.Linear(hidden, hidden)
self.fc3 = nn.Linear(hidden, hidden)
self.fc4 = nn.Linear(hidden, hidden)
self.fc5 = nn.Linear(hidden + in_pos, hidden)
self.fc6 = nn.Linear(hidden, hidden)
self.fc7 = nn.Linear(hidden, hidden)
self.fc8 = nn.Linear(hidden, hidden)
self.sigma = nn.Linear(hidden, 1)
self.feat = nn.Linear(hidden, hidden)
self.rgb1 = nn.Linear(hidden + in_dir, 128)
self.rgb2 = nn.Linear(128, 3)
self.L_pos, self.L_dir = L_pos, L_dir
def forward(self, x, d):
x_enc = torch.cat([x, positional_encoding(x, self.L_pos)], dim=-1)
d_enc = torch.cat([d, positional_encoding(d, self.L_dir)], dim=-1)
h = F.relu(self.fc1(x_enc))
h = F.relu(self.fc2(h))
h = F.relu(self.fc3(h))
h = F.relu(self.fc4(h))
h = torch.cat([h, x_enc], dim=-1) # skip connection
h = F.relu(self.fc5(h))
h = F.relu(self.fc6(h))
h = F.relu(self.fc7(h))
h = F.relu(self.fc8(h))
sigma = F.relu(self.sigma(h)) # density is calculated using positional information
feat = self.feat(h)
h = torch.cat([feat, d_enc], dim=-1) # add directional information for color
h = F.relu(self.rgb1(h))
rgb = torch.sigmoid(self.rgb2(h))
return rgb, sigma
渲染函數(shù),這個(gè)是整個(gè)流程的核心:
def render_rays(model, ro, rd, near=2.0, far=6.0, N=64):
# sample along the ray
t = torch.linspace(near, far, N, device=ro.device)
pts = ro[:, None, :] + rd[:, None, :] * t[None, :, None] # r = o + td
# attach view directions to each sample
# each point knows where the ray comes from
dirs = rd[:, None, :].expand_as(pts)
# query NeRF at each point and reshape
rgb, sigma = model(pts.reshape(-1,3), dirs.reshape(-1,3))
rgb = rgb.reshape(ro.shape[0], N, 3)
sigma = sigma.reshape(ro.shape[0], N)
# compute the distance between the samples
delta = t[1:] - t[:-1]
delta = torch.cat([delta, torch.tensor([1e10], device=ro.device)])
# convert density into opacity
alpha = 1 - torch.exp(-sigma * delta)
# compute transmittance along the ray
T = torch.cumprod(torch.cat([torch.ones((ro.shape[0],1), device=ro.device),
1 - alpha + 1e-10], dim=-1), dim=-1)[:, :-1]
weights = T * alpha
return (weights[...,None] * rgb).sum(dim=1) # accumulate the colors
訓(xùn)練循環(huán):
device = "cuda" if torch.cuda.is_available() else "cpu"
images, c2ws, H, W, fov = load_dataset("nerf_synth_cube_sphere")
images, c2ws = images.to(device), c2ws.to(device)
model = NeRF().to(device)
opt = torch.optim.Adam(model.parameters(), lr=5e-4)
loss_hist, psnr_hist, iters = [], [], []
for it in range(1, 5001):
idx = torch.randint(0, images.shape[0], (1,)).item()
ro, rd = get_rays(H, W, fov, c2ws[idx], device)
gt = images[idx].reshape(-1,3)
sel = torch.randint(0, ro.numel()//3, (2048,), device=device)
pred = render_rays(model, ro.reshape(-1,3)[sel], rd.reshape(-1,3)[sel])
# for simplicity, we will only implement the coarse sampling.
loss = F.mse_loss(pred, gt[sel])
opt.zero_grad()
loss.backward()
opt.step()
if it % 200 == 0:
psnr = -10 * torch.log10(loss).item()
loss_hist.append(loss.item())
psnr_hist.append(psnr)
iters.append(it)
print(f"Iter {it} | Loss {loss.item():.6f} | PSNR {psnr:.2f} dB")
torch.save(model.state_dict(), "nerf_cube_sphere_coarse.pth")
# ---- Plots ----
plt.figure()
plt.plot(iters, loss_hist, color='red', lw=5)
plt.title("Training Loss")
plt.show()
plt.figure()
plt.plot(iters, psnr_hist, color='black', lw=5)
plt.title("Training PSNR")
plt.show()
迭代次數(shù)與PSNR、損失值的變化曲線:
![]()
模型訓(xùn)練完成下一步是生成新視角。
look_at函數(shù)用于從指定相機(jī)位置構(gòu)建位姿矩陣:
def look_at(eye):
eye = torch.tensor(eye, dtype=torch.float32) # where the camera is
target = torch.tensor([0.0, 0.0, 0.0])
up = torch.tensor([0,1,0], dtype=torch.float32) # which direction is "up" in the world
f = (target - eye); f /= torch.norm(f) # forward direction of the camera
r = torch.cross(f, up); r /= torch.norm(r) # right direction. use cross product between f and up
u = torch.cross(r, f) # true camera up direction
c2w = torch.eye(4)
c2w[:3,0], c2w[:3,1], c2w[:3,2], c2w[:3,3] = r, u, -f, eye
return c2w
推理代碼:
device = "cuda" if torch.cuda.is_available() else "cpu"
with open("nerf_synth_cube_sphere/transforms.json") as f:
meta = json.load(f)
H, W, fov = meta["h"], meta["w"], meta["camera_angle_x"]
model = NeRF().to(device)
model.load_state_dict(torch.load("nerf_cube_sphere_coarse.pth", map_location=device))
model.eval()
os.makedirs("novel_views", exist_ok=True)
for i in range(120):
angle = 2 * math.pi * i / 120
eye = [4 * math.cos(angle), 1.0, 4 * math.sin(angle)]
c2w = look_at(eye).to(device)
with torch.no_grad():
ro, rd = get_rays(H, W, fov, c2w, device)
rgb = render_rays(model, ro.reshape(-1,3), rd.reshape(-1,3))
img = rgb.reshape(H, W, 3).clamp(0,1).cpu().numpy()
Image.fromarray((img*255).astype(np.uint8)).save(f"novel_views/view_{i:03d}.png")
print("Rendered view", i)
新視角渲染結(jié)果(訓(xùn)練集中沒有這些角度):
![]()
圖中的偽影——椒鹽噪聲、條紋、浮動(dòng)的亮點(diǎn)——來自空曠區(qū)域的密度估計(jì)誤差。只用粗糙模型、不做精細(xì)采樣的情況下這些問題會(huì)更明顯。另外場景里大片空白區(qū)域也是個(gè)麻煩,模型不得不花大量計(jì)算去探索這些沒什么內(nèi)容的地方。
再看看深度圖:
![]()
立方體的平面捕捉得相當(dāng)準(zhǔn)確沒有幽靈表面。空曠區(qū)域有些斑點(diǎn)噪聲說明雖然空白區(qū)域整體學(xué)得還行,但稀疏性還是帶來了一些小誤差。
參考文獻(xiàn)
Mildenhall, B., Srinivasan, P. P., Gharbi, M., Tancik, M., Barron, J. T., Simonyan, K., Abbeel, P., & Malik, J. (2020). NeRF: Representing scenes as neural radiance fields for view synthesis.
https://avoid.overfit.cn/post/4a1b21ea7d754b81b875928c95a45856
作者:Kavishka Abeywardana
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.