initial commit
README.md
@@ -58,6 +58,19 @@

- **NeRF-MAE**: The first large-scale pretraining utilizing Neural Radiance Fields (NeRF) as an input modality. We pretrain a single Transformer model on thousands of NeRFs for 3D representation learning.
- **NeRF-MAE Dataset**: A large-scale NeRF pretraining and downstream task finetuning dataset.

## 🏷️ TODO 🚀

- [x] Release large-scale pretraining code 🚀
- [x] Release NeRF-MAE dataset comprising radiance and density grids 🚀
- [x] Release 3D object detection finetuning and eval code 🚀
- [x] Pretrained NeRF-MAE checkpoints and out-of-the-box model usage 🚀

## NeRF-MAE Model Architecture

<p align="center">
<img src="demo/nerf-mae_architecture.jpg" width="90%">
</p>

## Citation

If you find this repository or our dataset useful, please star ⭐ this repository and consider citing 📝:

@@ -102,15 +115,68 @@ cd ../../../..

## ⛳ Model Usage and Checkpoints

- [Hugging Face repo to download pretrained and finetuned checkpoints](https://huggingface.co/mirshad7/NeRF-MAE) (see the download sketch below)
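
Checkpoints hosted on the Hub can also be pulled programmatically. Here is a minimal sketch assuming the `huggingface_hub` client is installed; the filename below is a hypothetical placeholder, so check the repo's file listing for the actual checkpoint names:

```
from huggingface_hub import hf_hub_download

# Fetch one checkpoint file from the NeRF-MAE Hugging Face repo.
# NOTE: "nerf_mae_pretrained.pt" is a placeholder filename; pick a real file from the repo listing.
checkpoint_path = hf_hub_download(repo_id="mirshad7/NeRF-MAE", filename="nerf_mae_pretrained.pt")
print(checkpoint_path)  # local cache path of the downloaded checkpoint
```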

NeRF-MAE is structured to provide easy access to pretrained NeRF-MAE models (and reproductions), to facilitate use for various downstream tasks. This is for extracting good visual features from NeRFs if you don't have the resources for large-scale pretraining. Our pretraining provides an easy-to-access embedding of any NeRF scene, which can be used for a variety of downstream tasks in a straightforward way.

We have released pretrained and finetuned checkpoints so you can use our codebase out of the box. There are two types of usage: 1. the most common is using the features directly in a downstream task, such as an FPN head for 3D object detection, and 2. reconstructing the original grid to enforce losses such as a masked reconstruction loss. Below is a sample usage of our model with spelled-out comments.

1. Get the features to be used in a downstream task

```
import torch

# Define Swin Transformer configurations
swin_config = {
    "swin_t": {"embed_dim": 96, "depths": [2, 2, 6, 2], "num_heads": [3, 6, 12, 24]},
    "swin_s": {"embed_dim": 96, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "swin_b": {"embed_dim": 128, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "swin_l": {"embed_dim": 192, "depths": [2, 2, 18, 2], "num_heads": [6, 12, 24, 48]},
}

# Set the desired backbone type
backbone_type = "swin_s"
config = swin_config[backbone_type]

# Initialize Swin Transformer model
model = SwinTransformer_MAE3D_New(
    patch_size=[4, 4, 4],
    embed_dim=config["embed_dim"],
    depths=config["depths"],
    num_heads=config["num_heads"],
    window_size=[4, 4, 4],
    stochastic_depth_prob=0.1,
    expand_dim=True,
    resolution=resolution,
)

# Load checkpoint and remove unused layers
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
for attr in ["decoder4", "decoder3", "decoder2", "decoder1", "out", "mask_token"]:
    delattr(model, attr)

# Extract features using the Swin Transformer backbone. input_grid has sample shape torch.randn((1, 4, 160, 160, 160))
features = []
input_grid = model.patch_partition(input_grid) + model.pos_embed.type_as(input_grid).to(input_grid.device).clone().detach()
for stage in model.stages:
    input_grid = stage(input_grid)
    features.append(torch.permute(input_grid, [0, 4, 1, 2, 3]).contiguous())  # Format: [N, C, H, W, D]

# Multi-scale features have shape: [torch.Size([1, 96, 40, 40, 40]), torch.Size([1, 192, 20, 20, 20]), torch.Size([1, 384, 10, 10, 10]), torch.Size([1, 768, 5, 5, 5])]

# Process features through FPN
```
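
The `# Process features through FPN` step above is left open. Purely as an illustration (this is not the detection head shipped with this repo), a generic 3D FPN neck could fuse the four multi-scale maps as sketched below; `SimpleFPN3D` and its 256-channel output width are assumptions made for the sketch:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleFPN3D(nn.Module):
    """Minimal 3D feature pyramid: 1x1x1 lateral convs + top-down upsampling (illustrative only)."""

    def __init__(self, in_channels=(96, 192, 384, 768), out_channels=256):
        super().__init__()
        self.lateral = nn.ModuleList([nn.Conv3d(c, out_channels, kernel_size=1) for c in in_channels])
        self.output = nn.ModuleList([nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) for _ in in_channels])

    def forward(self, features):
        # features: list of [N, C_i, D_i, H_i, W_i] maps, finest to coarsest.
        laterals = [lat(f) for lat, f in zip(self.lateral, features)]
        # Top-down pathway: upsample the coarser level and add it to the next finer one.
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=laterals[i - 1].shape[2:], mode="trilinear", align_corners=False
            )
        return [conv(x) for conv, x in zip(self.output, laterals)]

# Toy inputs matching the multi-scale shapes printed above.
features = [
    torch.randn(1, 96, 40, 40, 40),
    torch.randn(1, 192, 20, 20, 20),
    torch.randn(1, 384, 10, 10, 10),
    torch.randn(1, 768, 5, 5, 5),
]
pyramid = SimpleFPN3D()(features)
print([p.shape for p in pyramid])  # four maps, each with 256 channels, at 40^3 ... 5^3 resolution
```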

2. Get the Original Grid Output

```
import torch
# Load data from the specified folder and filename with the given resolution.
res, rgbsigma = load_data(folder_name, filename, resolution=args.resolution)

# rgbsigma has sample shape torch.randn((1, 4, 160, 160, 160))

# Build the model using provided arguments.
model = build_model(args)
```
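
The grid-output path pairs with the masked reconstruction objective mentioned above; the actual loss lives in this repo's training code. Purely to illustrate the idea, a masked MSE over an rgbsigma grid could look like the following sketch (the helper name, mask convention, and 25% masking ratio are assumptions, not the repo's exact recipe):

```
import torch

def masked_reconstruction_loss(pred, target, mask):
    """MSE averaged over masked voxels only (hypothetical helper, not from the repo).

    pred, target: [N, 4, D, H, W] rgbsigma grids (3 RGB channels + 1 density channel).
    mask:         [N, 1, D, H, W] binary mask, 1 where the input voxel patch was masked out.
    """
    squared_error = (pred - target) ** 2
    masked_error = squared_error * mask  # ignore voxels the model actually saw
    num_masked = (mask.sum() * pred.shape[1]).clamp(min=1)  # masked entries counted across channels
    return masked_error.sum() / num_masked

# Toy example at a small 32^3 resolution to keep memory low.
pred = torch.randn(1, 4, 32, 32, 32)    # e.g. the grid reconstructed by the model
target = torch.randn(1, 4, 32, 32, 32)  # e.g. the rgbsigma grid returned by load_data
mask = (torch.rand(1, 1, 32, 32, 32) > 0.75).float()  # ~25% of voxels treated as masked
print(masked_reconstruction_loss(pred, target, mask))
```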