codestella
commited on
Commit
·
97ec1af
1
Parent(s):
3c9f729
code change
Browse files- .gitattributes +0 -17
- LICENSE +0 -0
- __init__.py +0 -0
- assets/lego-nerf.gif +0 -0
- configs/blender.yaml +0 -0
- configs/demo.yaml +0 -0
- configs/diet_nerf_tpu_vm_4shot.yaml +2 -1
- configs/diet_nerf_tpu_vm_few_shot.yaml +2 -1
- configs/diet_nerf_tpu_vm_test.yaml +3 -2
- configs/eval_diet_nerf_tpu_vm_few_shot.yaml +0 -0
- configs/nerf_tpu_vm_4shot.yaml +0 -0
- configs/nerf_tpu_vm_few_shot.yaml +0 -0
- configs/orig_nerf_tpu_vm_full.yaml +0 -0
- configs/orig_nerf_tpu_vm_test.yaml +0 -0
- eval.py +18 -9
- eval.sh +0 -0
- example_data/imgs/r_0.png +0 -0
- example_data/transforms_test.json +0 -0
- example_data/transforms_train.json +0 -0
- fork-of-first-touch-of-nerf-in-jax.ipynb +0 -0
- nerf/__init__.py +0 -0
- nerf/__pycache__/__init__.cpython-37.pyc +0 -0
- nerf/__pycache__/clip_utils.cpython-37.pyc +0 -0
- nerf/__pycache__/datasets.cpython-37.pyc +0 -0
- nerf/__pycache__/model_utils.cpython-37.pyc +0 -0
- nerf/__pycache__/models.cpython-37.pyc +0 -0
- nerf/__pycache__/utils.cpython-37.pyc +0 -0
- nerf/clip_utils.py +17 -23
- nerf/datasets.py +15 -9
- nerf/model_utils.py +0 -0
- nerf/models.py +2 -3
- nerf/utils.py +4 -2
- requirements.txt +0 -0
- run.sh +0 -0
- train.py +9 -21
- train.sh +0 -0
.gitattributes
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
CHANGED
|
File without changes
|
__init__.py
CHANGED
|
File without changes
|
assets/lego-nerf.gif
DELETED
|
Binary file (519 kB)
|
|
|
configs/blender.yaml
CHANGED
|
File without changes
|
configs/demo.yaml
CHANGED
|
File without changes
|
configs/diet_nerf_tpu_vm_4shot.yaml
CHANGED
|
@@ -8,8 +8,9 @@ white_bkgd: true
|
|
| 8 |
batch_size: 1024
|
| 9 |
randomized: true
|
| 10 |
max_steps: 200000
|
|
|
|
| 11 |
print_every: 100
|
| 12 |
-
render_every:
|
| 13 |
save_every: 5000
|
| 14 |
use_semantic_loss: true
|
| 15 |
clip_model_name: openai/clip-vit-base-patch32
|
|
|
|
| 8 |
batch_size: 1024
|
| 9 |
randomized: true
|
| 10 |
max_steps: 200000
|
| 11 |
+
stop_sc_loss: 160000
|
| 12 |
print_every: 100
|
| 13 |
+
render_every: 1000
|
| 14 |
save_every: 5000
|
| 15 |
use_semantic_loss: true
|
| 16 |
clip_model_name: openai/clip-vit-base-patch32
|
configs/diet_nerf_tpu_vm_few_shot.yaml
CHANGED
|
@@ -8,8 +8,9 @@ white_bkgd: true
|
|
| 8 |
batch_size: 1024
|
| 9 |
randomized: true
|
| 10 |
max_steps: 200000
|
|
|
|
| 11 |
print_every: 100
|
| 12 |
-
render_every:
|
| 13 |
save_every: 5000
|
| 14 |
use_semantic_loss: true
|
| 15 |
clip_model_name: openai/clip-vit-base-patch32
|
|
|
|
| 8 |
batch_size: 1024
|
| 9 |
randomized: true
|
| 10 |
max_steps: 200000
|
| 11 |
+
stop_sc_loss: 160000
|
| 12 |
print_every: 100
|
| 13 |
+
render_every: 1000
|
| 14 |
save_every: 5000
|
| 15 |
use_semantic_loss: true
|
| 16 |
clip_model_name: openai/clip-vit-base-patch32
|
configs/diet_nerf_tpu_vm_test.yaml
CHANGED
|
@@ -2,12 +2,13 @@ dataset: blender
|
|
| 2 |
batching: single_image
|
| 3 |
factor: 0
|
| 4 |
num_coarse_samples: 64
|
| 5 |
-
num_fine_samples:
|
| 6 |
use_viewdirs: true
|
| 7 |
white_bkgd: true
|
| 8 |
-
batch_size:
|
| 9 |
randomized: true
|
| 10 |
max_steps: 200000
|
|
|
|
| 11 |
print_every: 100
|
| 12 |
render_every: 1000
|
| 13 |
save_every: 5000
|
|
|
|
| 2 |
batching: single_image
|
| 3 |
factor: 0
|
| 4 |
num_coarse_samples: 64
|
| 5 |
+
num_fine_samples: 128
|
| 6 |
use_viewdirs: true
|
| 7 |
white_bkgd: true
|
| 8 |
+
batch_size: 1024
|
| 9 |
randomized: true
|
| 10 |
max_steps: 200000
|
| 11 |
+
stop_sc_loss: 160000
|
| 12 |
print_every: 100
|
| 13 |
render_every: 1000
|
| 14 |
save_every: 5000
|
configs/eval_diet_nerf_tpu_vm_few_shot.yaml
CHANGED
|
File without changes
|
configs/nerf_tpu_vm_4shot.yaml
CHANGED
|
File without changes
|
configs/nerf_tpu_vm_few_shot.yaml
CHANGED
|
File without changes
|
configs/orig_nerf_tpu_vm_full.yaml
CHANGED
|
File without changes
|
configs/orig_nerf_tpu_vm_test.yaml
CHANGED
|
File without changes
|
eval.py
CHANGED
|
@@ -112,30 +112,39 @@ def main(unused_argv):
|
|
| 112 |
summary_writer = tensorboard.SummaryWriter(
|
| 113 |
path.join(FLAGS.train_dir, "eval"))
|
| 114 |
|
| 115 |
-
def generate_spinning_gif(radius, phi,
|
| 116 |
_rng = random.PRNGKey(0)
|
| 117 |
partial_render_fn = functools.partial(render_pfn, state.optimizer.target)
|
| 118 |
gif_images = []
|
|
|
|
| 119 |
for theta in tqdm(np.linspace(-math.pi, math.pi, frame_n)):
|
| 120 |
camtoworld = np.array(clip_utils.pose_spherical(radius, theta, phi))
|
| 121 |
rays = dataset.camtoworld_matrix_to_rays(camtoworld, downsample=4)
|
| 122 |
_rng, key0, key1 = random.split(_rng, 3)
|
| 123 |
-
color,
|
| 124 |
_rng, False, chunk=4096)
|
| 125 |
image = predict_to_image(color)
|
|
|
|
| 126 |
gif_images.append(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
gif_images[0].save(gif_fn, save_all=True,
|
| 128 |
append_images=gif_images,
|
| 129 |
duration=100, loop=0)
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
if FLAGS.generate_gif_only:
|
| 133 |
print('generate GIF file only')
|
| 134 |
_radius = 4.
|
| 135 |
_phi = (30 * math.pi) / 180
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
print(f'GIF file for spinning views written: {_gif_fn}')
|
| 139 |
return
|
| 140 |
else:
|
| 141 |
print('generate GIF file AND evaluate model performance')
|
|
@@ -149,6 +158,7 @@ def main(unused_argv):
|
|
| 149 |
utils.makedirs(out_dir)
|
| 150 |
psnr_values = []
|
| 151 |
ssim_values = []
|
|
|
|
| 152 |
#lpips_values = []
|
| 153 |
if not FLAGS.eval_once:
|
| 154 |
showcase_index = np.random.randint(0, dataset.size)
|
|
@@ -225,9 +235,8 @@ def main(unused_argv):
|
|
| 225 |
if not is_gif_written:
|
| 226 |
_radius = 4.
|
| 227 |
_phi = (30 * math.pi) / 180
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
print(f'GIF file for spinning views written: {_gif_fn}')
|
| 231 |
is_gif_written = True
|
| 232 |
|
| 233 |
if FLAGS.eval_once:
|
|
|
|
| 112 |
summary_writer = tensorboard.SummaryWriter(
|
| 113 |
path.join(FLAGS.train_dir, "eval"))
|
| 114 |
|
| 115 |
+
def generate_spinning_gif(radius, phi, output_dir, frame_n):
|
| 116 |
_rng = random.PRNGKey(0)
|
| 117 |
partial_render_fn = functools.partial(render_pfn, state.optimizer.target)
|
| 118 |
gif_images = []
|
| 119 |
+
gif_images2 = []
|
| 120 |
for theta in tqdm(np.linspace(-math.pi, math.pi, frame_n)):
|
| 121 |
camtoworld = np.array(clip_utils.pose_spherical(radius, theta, phi))
|
| 122 |
rays = dataset.camtoworld_matrix_to_rays(camtoworld, downsample=4)
|
| 123 |
_rng, key0, key1 = random.split(_rng, 3)
|
| 124 |
+
color, disp, _ = utils.render_image(partial_render_fn, rays,
|
| 125 |
_rng, False, chunk=4096)
|
| 126 |
image = predict_to_image(color)
|
| 127 |
+
image2 = predict_to_image(disp[Ellipsis, 0])
|
| 128 |
gif_images.append(image)
|
| 129 |
+
gif_images2.append(image2)
|
| 130 |
+
|
| 131 |
+
gif_fn = os.path.join(output_dir, 'rgb_spinning.gif')
|
| 132 |
+
gif_fn2 = os.path.join(output_dir, 'disp_spinning.gif')
|
| 133 |
gif_images[0].save(gif_fn, save_all=True,
|
| 134 |
append_images=gif_images,
|
| 135 |
duration=100, loop=0)
|
| 136 |
+
gif_images2[0].save(gif_fn2, save_all=True,
|
| 137 |
+
append_images=gif_images2,
|
| 138 |
+
duration=100, loop=0)
|
| 139 |
+
|
| 140 |
+
#return gif_images, gif_images2
|
| 141 |
|
| 142 |
if FLAGS.generate_gif_only:
|
| 143 |
print('generate GIF file only')
|
| 144 |
_radius = 4.
|
| 145 |
_phi = (30 * math.pi) / 180
|
| 146 |
+
generate_spinning_gif(_radius, _phi, out_dir, frame_n=30)
|
| 147 |
+
print('GIF file for spinning views written)')
|
|
|
|
| 148 |
return
|
| 149 |
else:
|
| 150 |
print('generate GIF file AND evaluate model performance')
|
|
|
|
| 158 |
utils.makedirs(out_dir)
|
| 159 |
psnr_values = []
|
| 160 |
ssim_values = []
|
| 161 |
+
|
| 162 |
#lpips_values = []
|
| 163 |
if not FLAGS.eval_once:
|
| 164 |
showcase_index = np.random.randint(0, dataset.size)
|
|
|
|
| 235 |
if not is_gif_written:
|
| 236 |
_radius = 4.
|
| 237 |
_phi = (30 * math.pi) / 180
|
| 238 |
+
generate_spinning_gif(_radius, _phi, out_dir, frame_n=30)
|
| 239 |
+
print(f'GIF file for spinning views written')
|
|
|
|
| 240 |
is_gif_written = True
|
| 241 |
|
| 242 |
if FLAGS.eval_once:
|
eval.sh
CHANGED
|
File without changes
|
example_data/imgs/r_0.png
CHANGED
|
|
example_data/transforms_test.json
CHANGED
|
File without changes
|
example_data/transforms_train.json
CHANGED
|
File without changes
|
fork-of-first-touch-of-nerf-in-jax.ipynb
CHANGED
|
File without changes
|
nerf/__init__.py
CHANGED
|
File without changes
|
nerf/__pycache__/__init__.cpython-37.pyc
DELETED
|
Binary file (137 Bytes)
|
|
|
nerf/__pycache__/clip_utils.cpython-37.pyc
DELETED
|
Binary file (5.16 kB)
|
|
|
nerf/__pycache__/datasets.cpython-37.pyc
DELETED
|
Binary file (18.3 kB)
|
|
|
nerf/__pycache__/model_utils.cpython-37.pyc
DELETED
|
Binary file (10 kB)
|
|
|
nerf/__pycache__/models.cpython-37.pyc
DELETED
|
Binary file (5.08 kB)
|
|
|
nerf/__pycache__/utils.cpython-37.pyc
DELETED
|
Binary file (15.8 kB)
|
|
|
nerf/clip_utils.py
CHANGED
|
@@ -15,50 +15,44 @@ FLAGS = flags.FLAGS
|
|
| 15 |
|
| 16 |
@partial(jax.jit, static_argnums=[0])
|
| 17 |
def semantic_loss(clip_model, src_image, target_embedding):
|
| 18 |
-
|
| 19 |
-
f_image = utils.unshard(src_image[
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
#c_image = c_image.reshape([w, w, 3])
|
| 23 |
f_image = f_image.reshape([w, w, 3])
|
| 24 |
-
|
| 25 |
-
src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.
|
| 26 |
-
#src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
|
| 27 |
src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
|
| 28 |
-
sc_loss =
|
| 29 |
return sc_loss, f_image
|
| 30 |
|
| 31 |
def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
|
| 32 |
-
random_rays =
|
| 33 |
-
target_embedding = batch["embedding"]
|
| 34 |
rng, key_0, key_1 = random.split(rng,3)
|
| 35 |
-
|
| 36 |
def loss_fn(variables):
|
| 37 |
-
|
| 38 |
-
sc_loss, f_image = semantic_loss(clip_model,
|
| 39 |
return sc_loss * FLAGS.sc_loss_mult, f_image
|
| 40 |
(sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
|
| 41 |
return sc_loss, grad, src_image
|
| 42 |
|
| 43 |
@partial(jax.jit, static_argnums=[0, 1])
|
| 44 |
def semantic_step_single(model, clip_model, rng, state, batch, lr):
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
random_rays = batch["random_rays"]
|
| 48 |
rng, key_0, key_1 = random.split(rng,3)
|
| 49 |
|
| 50 |
def semantic_loss(variables):
|
| 51 |
c_image, f_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only = True)
|
| 52 |
-
# reshape flat pixel to an image (assume 3 channels & square shape)
|
| 53 |
w = int(math.sqrt(f_image.shape[0]))
|
| 54 |
-
|
| 55 |
f_image = f_image.reshape([w, w, 3])
|
| 56 |
|
| 57 |
-
src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.
|
| 58 |
-
# src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
|
| 59 |
src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
|
| 60 |
-
|
| 61 |
-
sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding)**2)
|
| 62 |
return sc_loss * FLAGS.sc_loss_mult, f_image
|
| 63 |
(sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
|
| 64 |
return sc_loss, grad, src_image
|
|
|
|
| 15 |
|
| 16 |
@partial(jax.jit, static_argnums=[0])
|
| 17 |
def semantic_loss(clip_model, src_image, target_embedding):
|
| 18 |
+
c_image = utils.unshard(src_image[0])
|
| 19 |
+
f_image = utils.unshard(src_image[1])
|
| 20 |
+
w = int(math.sqrt(f_image.shape[0]))
|
| 21 |
+
c_image = c_image.reshape([w, w, 3])
|
|
|
|
| 22 |
f_image = f_image.reshape([w, w, 3])
|
| 23 |
+
|
| 24 |
+
src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image,f_image],0).transpose(0, 3, 1, 2)))
|
|
|
|
| 25 |
src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
|
| 26 |
+
sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
|
| 27 |
return sc_loss, f_image
|
| 28 |
|
| 29 |
def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
|
| 30 |
+
random_rays = batch["random_rays"]
|
| 31 |
+
target_embedding = batch["embedding"]
|
| 32 |
rng, key_0, key_1 = random.split(rng,3)
|
| 33 |
+
|
| 34 |
def loss_fn(variables):
|
| 35 |
+
images = render_pfn(variables, key_0, key_1, random_rays)
|
| 36 |
+
sc_loss, f_image = semantic_loss(clip_model, images, target_embedding)
|
| 37 |
return sc_loss * FLAGS.sc_loss_mult, f_image
|
| 38 |
(sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
|
| 39 |
return sc_loss, grad, src_image
|
| 40 |
|
| 41 |
@partial(jax.jit, static_argnums=[0, 1])
|
| 42 |
def semantic_step_single(model, clip_model, rng, state, batch, lr):
|
| 43 |
+
random_rays = jax.tree_map(lambda x: x.reshape(-1,3), batch["random_rays"])
|
| 44 |
+
target_embedding = batch["embedding"]
|
|
|
|
| 45 |
rng, key_0, key_1 = random.split(rng,3)
|
| 46 |
|
| 47 |
def semantic_loss(variables):
|
| 48 |
c_image, f_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only = True)
|
|
|
|
| 49 |
w = int(math.sqrt(f_image.shape[0]))
|
| 50 |
+
c_image = c_image.reshape([w, w, 3])
|
| 51 |
f_image = f_image.reshape([w, w, 3])
|
| 52 |
|
| 53 |
+
src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image,f_image],0).transpose(0, 3, 1, 2)))
|
|
|
|
| 54 |
src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
|
| 55 |
+
sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
|
|
|
|
| 56 |
return sc_loss * FLAGS.sc_loss_mult, f_image
|
| 57 |
(sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
|
| 58 |
return sc_loss, grad, src_image
|
nerf/datasets.py
CHANGED
|
@@ -236,6 +236,7 @@ class Blender(Dataset):
|
|
| 236 |
camera_angle_x = float(meta["camera_angle_x"])
|
| 237 |
self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
|
| 238 |
self.n_examples = self.images.shape[0]
|
|
|
|
| 239 |
|
| 240 |
if flags.use_semantic_loss and clip_model is not None:
|
| 241 |
embs = []
|
|
@@ -258,8 +259,8 @@ class Blender(Dataset):
|
|
| 258 |
|
| 259 |
frames = np.arange(len(meta["frames"]))
|
| 260 |
if few_shot > 0 and split == 'train':
|
| 261 |
-
np.random.seed(0)
|
| 262 |
-
np.random.shuffle(frames)
|
| 263 |
frames = frames[:few_shot]
|
| 264 |
|
| 265 |
# if split == 'train':
|
|
@@ -308,16 +309,21 @@ class Blender(Dataset):
|
|
| 308 |
src_seed = int(time.time())
|
| 309 |
src_rng = jax.random.PRNGKey(src_seed)
|
| 310 |
src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
|
| 311 |
-
|
| 312 |
-
cx = np.random.randint(
|
| 313 |
-
cy = np.random.randint(
|
| 314 |
-
d =
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
| 316 |
w = random_rays[0].shape[0] - random_rays[0].shape[0]%jax.local_device_count()
|
| 317 |
random_rays = jax.tree_map(lambda x: x[:w,:w].reshape(-1,3), random_rays)
|
| 318 |
-
batch_dict["random_rays"] = random_rays
|
|
|
|
|
|
|
| 319 |
return batch_dict
|
| 320 |
-
|
| 321 |
class LLFF(Dataset):
|
| 322 |
"""LLFF Dataset."""
|
| 323 |
|
|
|
|
| 236 |
camera_angle_x = float(meta["camera_angle_x"])
|
| 237 |
self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
|
| 238 |
self.n_examples = self.images.shape[0]
|
| 239 |
+
self.dtype = flags.clip_output_dtype
|
| 240 |
|
| 241 |
if flags.use_semantic_loss and clip_model is not None:
|
| 242 |
embs = []
|
|
|
|
| 259 |
|
| 260 |
frames = np.arange(len(meta["frames"]))
|
| 261 |
if few_shot > 0 and split == 'train':
|
| 262 |
+
# np.random.seed(0)
|
| 263 |
+
# np.random.shuffle(frames)
|
| 264 |
frames = frames[:few_shot]
|
| 265 |
|
| 266 |
# if split == 'train':
|
|
|
|
| 309 |
src_seed = int(time.time())
|
| 310 |
src_rng = jax.random.PRNGKey(src_seed)
|
| 311 |
src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
|
| 312 |
+
|
| 313 |
+
cx = np.random.randint(320, 480)
|
| 314 |
+
cy = np.random.randint(320, 480)
|
| 315 |
+
d = 140
|
| 316 |
+
|
| 317 |
+
random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 1)
|
| 318 |
+
random_rays = jax.tree_map(lambda x: x[cy-d:cy+d:4,cx-d:cx+d:4], random_rays)
|
| 319 |
+
|
| 320 |
w = random_rays[0].shape[0] - random_rays[0].shape[0]%jax.local_device_count()
|
| 321 |
random_rays = jax.tree_map(lambda x: x[:w,:w].reshape(-1,3), random_rays)
|
| 322 |
+
batch_dict["random_rays"] = utils.shard(random_rays)
|
| 323 |
+
if self.dtype == 'float16':
|
| 324 |
+
batch_dict = jax.tree_map(lambda x: x.astype(np.float16), batch_dict)
|
| 325 |
return batch_dict
|
| 326 |
+
|
| 327 |
class LLFF(Dataset):
|
| 328 |
"""LLFF Dataset."""
|
| 329 |
|
nerf/model_utils.py
CHANGED
|
File without changes
|
nerf/models.py
CHANGED
|
@@ -136,7 +136,7 @@ class NerfModel(nn.Module):
|
|
| 136 |
(comp_rgb, disp, acc),
|
| 137 |
]
|
| 138 |
|
| 139 |
-
if self.num_fine_samples > 0
|
| 140 |
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
| 141 |
key, rng_1 = random.split(rng_1)
|
| 142 |
|
|
@@ -191,8 +191,7 @@ class NerfModel(nn.Module):
|
|
| 191 |
)
|
| 192 |
ret.append((comp_rgb, disp, acc))
|
| 193 |
if rgb_only:
|
| 194 |
-
|
| 195 |
-
return [None, ret[0][0]]
|
| 196 |
return ret
|
| 197 |
|
| 198 |
def construct_nerf(key, example_batch, args):
|
|
|
|
| 136 |
(comp_rgb, disp, acc),
|
| 137 |
]
|
| 138 |
|
| 139 |
+
if self.num_fine_samples > 0:
|
| 140 |
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
| 141 |
key, rng_1 = random.split(rng_1)
|
| 142 |
|
|
|
|
| 191 |
)
|
| 192 |
ret.append((comp_rgb, disp, acc))
|
| 193 |
if rgb_only:
|
| 194 |
+
return [ret[0][0], ret[1][0]]
|
|
|
|
| 195 |
return ret
|
| 196 |
|
| 197 |
def construct_nerf(key, example_batch, args):
|
nerf/utils.py
CHANGED
|
@@ -66,11 +66,11 @@ def define_flags():
|
|
| 66 |
flags.DEFINE_bool("use_semantic_loss", True,
|
| 67 |
"whether use semantic loss or not")
|
| 68 |
flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
|
| 69 |
-
flags.DEFINE_string("clip_output_dtype", "
|
| 70 |
"float32/ float16 (float16 for memory saving)")
|
| 71 |
flags.DEFINE_integer("sc_loss_every", 16,
|
| 72 |
"no. of steps to take before performing semantic loss evaluation")
|
| 73 |
-
flags.DEFINE_float("sc_loss_mult", 1e-
|
| 74 |
"weighting for semantic loss from CLIP")
|
| 75 |
|
| 76 |
# Dataset Flags
|
|
@@ -166,6 +166,8 @@ def define_flags():
|
|
| 166 |
|
| 167 |
flags.DEFINE_integer("max_steps", 1000000,
|
| 168 |
"the number of optimization steps.")
|
|
|
|
|
|
|
| 169 |
flags.DEFINE_integer("save_every", 10000,
|
| 170 |
"the number of steps to save a checkpoint.")
|
| 171 |
flags.DEFINE_integer("print_every", 100,
|
|
|
|
| 66 |
flags.DEFINE_bool("use_semantic_loss", True,
|
| 67 |
"whether use semantic loss or not")
|
| 68 |
flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
|
| 69 |
+
flags.DEFINE_string("clip_output_dtype", "float16",
|
| 70 |
"float32/ float16 (float16 for memory saving)")
|
| 71 |
flags.DEFINE_integer("sc_loss_every", 16,
|
| 72 |
"no. of steps to take before performing semantic loss evaluation")
|
| 73 |
+
flags.DEFINE_float("sc_loss_mult", 1e-2,
|
| 74 |
"weighting for semantic loss from CLIP")
|
| 75 |
|
| 76 |
# Dataset Flags
|
|
|
|
| 166 |
|
| 167 |
flags.DEFINE_integer("max_steps", 1000000,
|
| 168 |
"the number of optimization steps.")
|
| 169 |
+
flags.DEFINE_integer("stop_sc_loss", 1000000,
|
| 170 |
+
"the number of sc_loss optimization steps")
|
| 171 |
flags.DEFINE_integer("save_every", 10000,
|
| 172 |
"the number of steps to save a checkpoint.")
|
| 173 |
flags.DEFINE_integer("print_every", 100,
|
requirements.txt
CHANGED
|
File without changes
|
run.sh
CHANGED
|
File without changes
|
train.py
CHANGED
|
@@ -50,7 +50,6 @@ print(f"detected device: {jax.local_devices()}")
|
|
| 50 |
|
| 51 |
|
| 52 |
def train_step(model, clip_model, rng, state, batch, lr, step, K,):
|
| 53 |
-
# TODO make clip_grad input enable
|
| 54 |
"""One optimization step.
|
| 55 |
|
| 56 |
Args:
|
|
@@ -102,7 +101,6 @@ def train_step(model, clip_model, rng, state, batch, lr, step, K,):
|
|
| 102 |
|
| 103 |
(_, stats), grad = (
|
| 104 |
jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
|
| 105 |
-
#grad = jax.lax.pmean(grad, axis_name="batch")
|
| 106 |
stats = jax.lax.pmean(stats, axis_name="batch")
|
| 107 |
|
| 108 |
# Clip the gradient by value.
|
|
@@ -238,26 +236,16 @@ def main(unused_argv):
|
|
| 238 |
|
| 239 |
grad, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)
|
| 240 |
|
| 241 |
-
if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
|
| 242 |
sc_batch = dataset.get_clip_data()
|
| 243 |
if jax.local_device_count() > 1:
|
| 244 |
sc_loss, sc_grad, sc_image = clip_utils.semantic_step_multi(render_pfn_, clip_model, keys[0], state, sc_batch, lr)
|
| 245 |
else:
|
| 246 |
sc_loss, sc_grad, sc_image = clip_utils.semantic_step_single(model, clip_model, keys[0], state, sc_batch, lr)
|
| 247 |
|
| 248 |
-
if jax.host_id() == 0 and step%FLAGS.print_every:
|
| 249 |
-
for mlp_k, mlp in grad['params'].items():
|
| 250 |
-
for layer_k, layer_g in mlp.items():
|
| 251 |
-
summary_writer.scalar("%s/%s/kernel_grad"%(mlp_k, layer_k), jnp.linalg.norm(jnp.mean(layer_g['kernel'],0)), step)
|
| 252 |
-
for mlp_k, mlp in sc_grad['params'].items():
|
| 253 |
-
for layer_k, layer_g in mlp.items():
|
| 254 |
-
summary_writer.scalar("%s/%s/kernel_sc_grad"%(mlp_k, layer_k), jnp.linalg.norm(layer_g['kernel']), step)
|
| 255 |
-
|
| 256 |
leaves, treedef = jax.tree_flatten(grad)
|
| 257 |
sc_leaves, _ = jax.tree_flatten(sc_grad)
|
| 258 |
grad = treedef.unflatten(g+jnp.expand_dims(sc_g,0) for g, sc_g in zip(leaves, sc_leaves))
|
| 259 |
-
|
| 260 |
-
|
| 261 |
|
| 262 |
state = update_pstep(state, grad, lr)
|
| 263 |
|
|
@@ -276,24 +264,26 @@ def main(unused_argv):
|
|
| 276 |
summary_writer.scalar("psnr/train", stats.psnr[0], step)
|
| 277 |
summary_writer.scalar("train_coarse/loss", stats.loss_c[0], step)
|
| 278 |
summary_writer.scalar("train_coarse/psnr", stats.psnr_c[0], step)
|
| 279 |
-
|
| 280 |
avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
|
| 281 |
avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
|
| 282 |
stats_trace = []
|
| 283 |
summary_writer.scalar("train_avg/loss", avg_loss, step)
|
| 284 |
summary_writer.scalar("train_avg/psnr", avg_psnr, step)
|
| 285 |
-
|
| 286 |
steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
|
| 287 |
reset_timer = True
|
| 288 |
rays_per_sec = FLAGS.batch_size * steps_per_sec
|
| 289 |
-
summary_writer.scalar("
|
| 290 |
-
summary_writer.scalar("
|
|
|
|
|
|
|
| 291 |
precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
|
| 292 |
print(("{:" + "{:d}".format(precision) + "d}").format(step) +
|
| 293 |
f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
|
| 294 |
f"avg_loss={avg_loss:0.4f}, " +
|
| 295 |
f"weight_l2={stats.weight_l2[0]:0.2e}, " +
|
| 296 |
-
|
| 297 |
f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
|
| 298 |
if step % FLAGS.save_every == 0:
|
| 299 |
state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
|
|
@@ -324,12 +314,10 @@ def main(unused_argv):
|
|
| 324 |
eval_time = time.time() - t_eval_start
|
| 325 |
num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
|
| 326 |
rays_per_sec = num_rays / eval_time
|
| 327 |
-
summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
|
| 328 |
print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
|
| 329 |
summary_writer.scalar("psnr/test", psnr, step)
|
| 330 |
-
summary_writer.scalar("test_psnr", psnr, step)
|
| 331 |
summary_writer.scalar("ssim/ssim", ssim, step)
|
| 332 |
-
summary_writer.scalar("test_ssim", ssim, step)
|
| 333 |
if sc_image is not None:
|
| 334 |
summary_writer .image("random_ray_image", sc_image, step)
|
| 335 |
summary_writer.image("test_pred_color", pred_color, step)
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def train_step(model, clip_model, rng, state, batch, lr, step, K,):
|
|
|
|
| 53 |
"""One optimization step.
|
| 54 |
|
| 55 |
Args:
|
|
|
|
| 101 |
|
| 102 |
(_, stats), grad = (
|
| 103 |
jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
|
|
|
|
| 104 |
stats = jax.lax.pmean(stats, axis_name="batch")
|
| 105 |
|
| 106 |
# Clip the gradient by value.
|
|
|
|
| 236 |
|
| 237 |
grad, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)
|
| 238 |
|
| 239 |
+
if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss and step < FLAGS.stop_sc_loss:
|
| 240 |
sc_batch = dataset.get_clip_data()
|
| 241 |
if jax.local_device_count() > 1:
|
| 242 |
sc_loss, sc_grad, sc_image = clip_utils.semantic_step_multi(render_pfn_, clip_model, keys[0], state, sc_batch, lr)
|
| 243 |
else:
|
| 244 |
sc_loss, sc_grad, sc_image = clip_utils.semantic_step_single(model, clip_model, keys[0], state, sc_batch, lr)
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
leaves, treedef = jax.tree_flatten(grad)
|
| 247 |
sc_leaves, _ = jax.tree_flatten(sc_grad)
|
| 248 |
grad = treedef.unflatten(g+jnp.expand_dims(sc_g,0) for g, sc_g in zip(leaves, sc_leaves))
|
|
|
|
|
|
|
| 249 |
|
| 250 |
state = update_pstep(state, grad, lr)
|
| 251 |
|
|
|
|
| 264 |
summary_writer.scalar("psnr/train", stats.psnr[0], step)
|
| 265 |
summary_writer.scalar("train_coarse/loss", stats.loss_c[0], step)
|
| 266 |
summary_writer.scalar("train_coarse/psnr", stats.psnr_c[0], step)
|
| 267 |
+
|
| 268 |
avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
|
| 269 |
avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
|
| 270 |
stats_trace = []
|
| 271 |
summary_writer.scalar("train_avg/loss", avg_loss, step)
|
| 272 |
summary_writer.scalar("train_avg/psnr", avg_psnr, step)
|
| 273 |
+
|
| 274 |
steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
|
| 275 |
reset_timer = True
|
| 276 |
rays_per_sec = FLAGS.batch_size * steps_per_sec
|
| 277 |
+
summary_writer.scalar("stats/weight_l2", stats.weight_l2[0], step)
|
| 278 |
+
summary_writer.scalar("stats/learning_rate", lr, step)
|
| 279 |
+
summary_writer.scalar("iter_speed/train_steps_per_sec", steps_per_sec, step)
|
| 280 |
+
summary_writer.scalar("iter_speed/train_rays_per_sec", rays_per_sec, step)
|
| 281 |
precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
|
| 282 |
print(("{:" + "{:d}".format(precision) + "d}").format(step) +
|
| 283 |
f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
|
| 284 |
f"avg_loss={avg_loss:0.4f}, " +
|
| 285 |
f"weight_l2={stats.weight_l2[0]:0.2e}, " +
|
| 286 |
+
f"sc_loss={sc_loss:0.4f}, " +
|
| 287 |
f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
|
| 288 |
if step % FLAGS.save_every == 0:
|
| 289 |
state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
|
|
|
|
| 314 |
eval_time = time.time() - t_eval_start
|
| 315 |
num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
|
| 316 |
rays_per_sec = num_rays / eval_time
|
| 317 |
+
summary_writer.scalar("iter_speed/test_rays_per_sec", rays_per_sec, step)
|
| 318 |
print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
|
| 319 |
summary_writer.scalar("psnr/test", psnr, step)
|
|
|
|
| 320 |
summary_writer.scalar("ssim/ssim", ssim, step)
|
|
|
|
| 321 |
if sc_image is not None:
|
| 322 |
summary_writer .image("random_ray_image", sc_image, step)
|
| 323 |
summary_writer.image("test_pred_color", pred_color, step)
|
train.sh
CHANGED
|
File without changes
|