NeRFとGaussian Splatting~NeRFの実装編
先日、NeRFと3D Gaussian Splattingについてアルゴリズム概要を書きました
今回はNeRFについてもう少し実装面を取り上げてみます
論文
NeRFの親戚はたくさんありますが、提案論文のこちらを見ていきます
コード
論文のサポートGitHubレポジトリにtiny_nerfという、Google Colabで動く軽量なNeRFが公開されています
論文の理解には都合が良いので、こちらを読んでいきましょう
モデルの初期化 - init_model
8層の全結合層+Reluで構成されています
NeRFの提案論文はもともとそんなに凝ったモデルは使っていないので、このような感じです
入力はカメラに入射した光線が通った各点の座標です
出力は各点が持つ色(RGB三次元)と吸収率で計四次元のベクトル(をpositoinal encodingしたもの)です
def init_model(D=8, W=256):
'''
8層のパーセプトロン、出力は4値
'''
relu = tf.keras.layers.ReLU()
dense = lambda W=W, act=relu : tf.keras.layers.Dense(W, activation=act)
inputs = tf.keras.Input(shape=(3 + 3*2*L_embed))
outputs = inputs
for i in range(D):
outputs = dense()(outputs)
if i%4==0 and i>0:
outputs = tf.concat([outputs, inputs], -1)
outputs = dense(4, act=None)(outputs) #\vec{c}, \sigma
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
光線の取得 - get_rays
論文中にo+tdという形式で書かれている光線ベクトルの定義ですね
初見だと何を言っているかわからないところもありますが、透視投影モデルであらわしたカメラ姿勢の式がベースになっています
こちらの教科書が参考になると思います
def get_rays(H, W, focal, c2w):
'''
H, W: 画像サイズ
c2w: pose, camera2
rays_o : coordinates in image plane
rays_d : viewing direction
'''
i, j = tf.meshgrid(tf.range(W, dtype=tf.float32), tf.range(H, dtype=tf.float32), indexing='xy')
dirs = tf.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -tf.ones_like(i)], -1) #カメラ全体の回転をc2wが持っている 、xk_x + yk_yと基本発想は同じ
rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
rays_o = tf.broadcast_to(c2w[:3,-1], tf.shape(rays_d))
return rays_o, rays_d
ボリュームレンダリング - render_rays
ざっくりいうと
1. 光線を計算
2. 学習するネットワークで推論
3. ボリュームレンダリング
を繰り返しています
def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, rand=False):
def batchify(fn, chunk=1024*32):
return lambda inputs : tf.concat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
# Compute 3D query points
z_vals = tf.linspace(near, far, N_samples)
if rand:
z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]
# Run network
pts_flat = tf.reshape(pts, [-1,3])
pts_flat = embed_fn(pts_flat)
raw = batchify(network_fn)(pts_flat)
raw = tf.reshape(raw, list(pts.shape[:-1]) + [4])
# Compute opacities and colors
sigma_a = tf.nn.relu(raw[...,3])
rgb = tf.math.sigmoid(raw[...,:3])
# Do volume rendering
dists = tf.concat([z_vals[..., 1:] - z_vals[..., :-1], tf.broadcast_to([1e10], z_vals[...,:1].shape)], -1)
alpha = 1.-tf.exp(-sigma_a * dists)
weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
rgb_map = tf.reduce_sum(weights[...,None] * rgb, -2)
depth_map = tf.reduce_sum(weights * z_vals, -1)
acc_map = tf.reduce_sum(weights, -1)
return rgb_map, depth_map, acc_map
モデルのtrain
これらをもとにモデルを学習します
次の良な感じですね
1. カメラが取得した画像とそれに対応づくカメラの姿勢を取ってくる
2. カメラの姿勢に紐づくカメラに入射した光線をシミュレート
3. その光線をもとにボリュームレンダリングする
4. 2~3に対して勾配を計算し、ネットワークを最適化
model = init_model()
optimizer = tf.keras.optimizers.Adam(5e-4)
N_samples = 64
N_iters = 1000
psnrs = []
iternums = []
i_plot = 25
import time
t = time.time()
for i in range(N_iters+1):
img_i = np.random.randint(images.shape[0])
target = images[img_i]
pose = poses[img_i]
rays_o, rays_d = get_rays(H, W, focal, pose) # \vec{o]+ t\vec{d}
with tf.GradientTape() as tape:
rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, rand=True)
loss = tf.reduce_mean(tf.square(rgb - target))
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
if i%i_plot==0:
print(i, (time.time() - t) / i_plot, 'secs per iter')
t = time.time()
# Render the holdout view for logging
rays_o, rays_d = get_rays(H, W, focal, testpose)
rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
loss = tf.reduce_mean(tf.square(rgb - testimg))
psnr = -10. * tf.math.log(loss) / tf.math.log(10.)
psnrs.append(psnr.numpy())
iternums.append(i)
plt.figure(figsize=(10,4))
plt.subplot(121)
plt.imshow(rgb)
plt.title(f'Iteration: {i}')
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title('PSNR')
plt.show()
print('Done')
最終的にはこんな出力が出てくるはずです:
(補足) positional encodingについて
前回説明しなかったので、positional encodingについて説明します
画像を出力するようなニューラルネットワークでは特に、ニューラルネットワークを最適化しても出力像の鮮明さが失われるということが良く起こります
出力像 yの座標(i, j)の画素にはx(i, j)だけでなくx(i-5, j-5)など近傍の画素からも情報が入り込んでしまうためですね
これを回避する手段はいくつかあるのですが、その一つにpositional encodingを挙げられます
positional encodingでは出力の(i, j)の画素に周囲からの情報が混ざらないよう、画素ごとに周波数の異なる正弦波を掛け算しておきます
周波数の異なる正弦波は互いに打ち消しあいますので、先ほどの例に即していうと出力画像のy(i, j)の画素に対して、入力画像のx(i, j)の画素が強く影響するように調整できる…というトリックです
今回は光線ベクトルをネットワークに入れる直前にpositional encodingをしています
上記例とは若干状況が異なりますが、同様の効果が得られると期待できます
(光線の進行方向に隣接している画素間の情報の混ざりあいが抑制できるので、結果的に出力モデルの先鋭化に寄与するはずです)
この記事が気に入ったらサポートをしてみませんか?