mohammed-aljafry committed · Commit 82742f8 · verified · 1 Parent(s): 012ca41

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +65 -0
  2. config.json +16 -0
  3. model.py +1069 -0
  4. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,65 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: pytorch
6
+ tags:
7
+ - autonomous-driving
8
+ - multi-modal
9
+ - transformer
10
+ - interfuser
11
+ ---
12
+
13
+ # InterFuser: Multi-modal Transformer for Autonomous Driving
14
+
15
+ This is a Hugging Face repository for the InterFuser model, designed for end-to-end autonomous driving tasks. It processes multi-view camera images and LiDAR data to predict driving waypoints and perceive the surrounding traffic environment.
16
+
17
+ This model was trained by [Your Name or Alias].
18
+
19
+ ## Model Architecture
20
+
21
+ The model, `Interfuser`, is a Transformer-based architecture with three stages:
22
+ - CNN backbones (ResNet-50 for RGB, ResNet-18 for LiDAR) extract features from each input.
23
+ - A Transformer encoder fuses these multi-modal features.
24
+ - A Transformer decoder predicts waypoints, junction status, and a Bird's-Eye-View (BEV) map of traffic.
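
The fusion step is easier to follow once the encoder's sequence length is counted. The sketch below is a rough illustration, not the repository's code: it assumes the last ResNet stage downsamples by 32 and that each view contributes its flattened feature map plus one mean-pooled global token, as `HybridEmbed.forward` in `model.py` suggests. The helper `tokens_per_view`, the constant `BACKBONE_STRIDE`, and the per-view input sizes are hypothetical. The sizes were chosen so that the total matches the 151 x 151 mask in `build_attn_mask`.

```python
import math

BACKBONE_STRIDE = 32  # assumption: the last ResNet stage reduces resolution by 32


def tokens_per_view(height, width, stride=BACKBONE_STRIDE):
    """Patch tokens from the backbone feature map plus one global token."""
    feat_h = math.ceil(height / stride)
    feat_w = math.ceil(width / stride)
    return feat_h * feat_w + 1


# Hypothetical input sizes, inferred from the mask layout rather than stated in the README.
views = {
    "front": (224, 224),
    "left": (128, 128),
    "right": (128, 128),
    "center": (128, 128),
    "lidar": (224, 224),
}

counts = {name: tokens_per_view(h, w) for name, (h, w) in views.items()}
total_tokens = sum(counts.values())

print(counts)
print(total_tokens)
```

Under these assumptions the per-view counts are 50, 17, 17, 17, and 50, which line up with the block boundaries (0:50, 50:67, 67:84, 84:101, 101:151) used in `build_attn_mask`.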
25
+
26
+ ## How to Use
27
+
28
+ First, make sure you have `timm` installed: `pip install timm`.
29
+
30
+ You can then use the model with `AutoModel`. Remember to pass `trust_remote_code=True`.
31
+
32
+ ```python
33
+ import torch
34
+ from transformers import AutoModel, AutoConfig
35
+
36
+ # === 1. Load Model from Hugging Face Hub ===
37
+ model_id = "your-username/interfuser-driving-model" # Replace with your REPO_ID
38
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
39
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
40
+
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ model.to(device)
43
+ model.eval()
44
+
45
+ # === 2. Prepare Dummy Input Data ===
46
+ batch_size = 1
47
+ dummy_inputs = {
48
+ 'rgb': torch.randn(batch_size, 3 * 4, 224, 224, device=device),
49
+ 'rgb_left': torch.randn(batch_size, 3, 224, 224, device=device),
50
+ 'rgb_right': torch.randn(batch_size, 3, 224, 224, device=device),
51
+ 'rgb_center': torch.randn(batch_size, 3, 224, 224, device=device),
52
+ 'lidar': torch.randn(batch_size, 3, 112, 112, device=device),
53
+ 'measurements': torch.randn(batch_size, 10, device=device),
54
+ 'target_point': torch.randn(batch_size, 2, device=device)
55
+ }
56
+
57
+ # === 3. Run Inference ===
58
+ with torch.no_grad():
59
+ outputs = model(**dummy_inputs)
60
+
61
+ # === 4. Interpret the Outputs ===
62
+ traffic, waypoints, is_junc, light, stop, _ = outputs
63
+ print("Inference successful!")
64
+ print(f"Waypoints shape: {waypoints.shape}")
65
+ print(f"Traffic BEV map shape: {traffic.shape}")
+ ```
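
With `"waypoints_pred_head": "gru"`, the waypoint head in `model.py` (`GRUWaypointsPredictor`) decodes one 2-D value per step and then applies `torch.cumsum` along the step axis, so each waypoint is the running sum of the per-step offsets. The snippet below is a minimal pure-Python sketch of that accumulation only. The helper name `offsets_to_waypoints` and the offset values are made up for illustration.

```python
from itertools import accumulate


def offsets_to_waypoints(offsets):
    """Turn per-step (dx, dy) offsets into cumulative (x, y) waypoints."""
    xs = accumulate(dx for dx, _ in offsets)
    ys = accumulate(dy for _, dy in offsets)
    return list(zip(xs, ys))


# Illustrative offsets, not model output.
offsets = [(0.0, 1.0), (0.5, 1.0), (0.5, 0.5)]
waypoints = offsets_to_waypoints(offsets)
print(waypoints)
```

Because of the cumulative sum, an error in an early offset shifts every later waypoint, which is worth keeping in mind when inspecting predictions.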
config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "model_type": "interfuser",
3
+ "architectures": [
4
+ "Interfuser"
5
+ ],
6
+ "embed_dim": 256,
7
+ "enc_depth": 6,
8
+ "dec_depth": 6,
9
+ "num_heads": 8,
10
+ "dim_feedforward": 2048,
11
+ "dropout": 0.1,
12
+ "rgb_backbone_name": "r50",
13
+ "lidar_backbone_name": "r18",
14
+ "use_different_backbone": true,
15
+ "waypoints_pred_head": "gru"
16
+ }
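
Every key in this file except `model_type` and `architectures` matches a keyword argument of `Interfuser.__init__` in `model.py`. The sketch below shows one way to separate the two groups before constructing the model. It is an illustration under that assumption, not the loading path used by `transformers`. The helper `split_config` and the constant `METADATA_KEYS` are hypothetical names.

```python
import json

# Inline copy of config.json from this commit.
CONFIG_JSON = """
{
  "model_type": "interfuser",
  "architectures": ["Interfuser"],
  "embed_dim": 256,
  "enc_depth": 6,
  "dec_depth": 6,
  "num_heads": 8,
  "dim_feedforward": 2048,
  "dropout": 0.1,
  "rgb_backbone_name": "r50",
  "lidar_backbone_name": "r18",
  "use_different_backbone": true,
  "waypoints_pred_head": "gru"
}
"""

METADATA_KEYS = {"model_type", "architectures"}  # not constructor arguments


def split_config(raw_json):
    """Split the config into Hub metadata and model hyperparameters."""
    cfg = json.loads(raw_json)
    metadata = {k: v for k, v in cfg.items() if k in METADATA_KEYS}
    model_kwargs = {k: v for k, v in cfg.items() if k not in METADATA_KEYS}
    return metadata, model_kwargs


metadata, model_kwargs = split_config(CONFIG_JSON)
print(metadata)
print(sorted(model_kwargs))
# model = Interfuser(**model_kwargs)  # requires model.py and timm to be importable
```

Arguments that the config omits keep their defaults from `model.py`, for example `direct_concat=True` and `with_lidar=False`.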
model.py ADDED
@@ -0,0 +1,1069 @@
1
+ import math
2
+ import copy
3
+ import logging
4
+ from collections import OrderedDict
5
+ from functools import partial
6
+ from typing import Optional, List
7
+
8
+ import numpy as np  # used by SpatialSoftmax (np.meshgrid, np.linspace)
+ import torch
9
+ from torch import nn, Tensor
10
+ import torch.nn.functional as F
11
+ from torch.nn.parameter import Parameter
12
+
13
+ try:
14
+ from timm.models.layers import to_2tuple
15
+ from timm.models.resnet import resnet18d, resnet26d, resnet50d, resnet101d
16
+ except ImportError:
17
+ print("Please install timm: pip install timm")
18
+ raise
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+ # --- Building blocks used by Interfuser ---
23
+ # HybridEmbed, PositionEmbeddingSine, Transformer encoder/decoder, waypoint predictors, etc.
24
+
25
+
26
+ class HybridEmbed(nn.Module):
27
+ def __init__(
28
+ self,
29
+ backbone,
30
+ img_size=224,
31
+ patch_size=1,
32
+ feature_size=None,
33
+ in_chans=3,
34
+ embed_dim=768,
35
+ ):
36
+ super().__init__()
37
+ assert isinstance(backbone, nn.Module)
38
+ img_size = to_2tuple(img_size)
39
+ patch_size = to_2tuple(patch_size)
40
+ self.img_size = img_size
41
+ self.patch_size = patch_size
42
+ self.backbone = backbone
43
+ if feature_size is None:
44
+ with torch.no_grad():
45
+ training = backbone.training
46
+ if training:
47
+ backbone.eval()
48
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
49
+ if isinstance(o, (list, tuple)):
50
+ o = o[-1] # last feature if backbone outputs list/tuple of features
51
+ feature_size = o.shape[-2:]
52
+ feature_dim = o.shape[1]
53
+ backbone.train(training)
54
+ else:
55
+ feature_size = to_2tuple(feature_size)
56
+ if hasattr(self.backbone, "feature_info"):
57
+ feature_dim = self.backbone.feature_info.channels()[-1]
58
+ else:
59
+ feature_dim = self.backbone.num_features
60
+
61
+ self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
62
+
63
+ def forward(self, x):
64
+ x = self.backbone(x)
65
+ if isinstance(x, (list, tuple)):
66
+ x = x[-1] # last feature if backbone outputs list/tuple of features
67
+ x = self.proj(x)
68
+ global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
69
+ return x, global_x
70
+
71
+
72
+ class PositionEmbeddingSine(nn.Module):
73
+ """
74
+ This is a more standard version of the position embedding, very similar to the one
75
+ used by the Attention is all you need paper, generalized to work on images.
76
+ """
77
+
78
+ def __init__(
79
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
80
+ ):
81
+ super().__init__()
82
+ self.num_pos_feats = num_pos_feats
83
+ self.temperature = temperature
84
+ self.normalize = normalize
85
+ if scale is not None and normalize is False:
86
+ raise ValueError("normalize should be True if scale is passed")
87
+ if scale is None:
88
+ scale = 2 * math.pi
89
+ self.scale = scale
90
+
91
+ def forward(self, tensor):
92
+ x = tensor
93
+ bs, _, h, w = x.shape
94
+ not_mask = torch.ones((bs, h, w), device=x.device)
95
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
96
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
97
+ if self.normalize:
98
+ eps = 1e-6
99
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
100
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
101
+
102
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
103
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
104
+
105
+ pos_x = x_embed[:, :, :, None] / dim_t
106
+ pos_y = y_embed[:, :, :, None] / dim_t
107
+ pos_x = torch.stack(
108
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
109
+ ).flatten(3)
110
+ pos_y = torch.stack(
111
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
112
+ ).flatten(3)
113
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
114
+ return pos
115
+
116
+
117
+ class TransformerEncoder(nn.Module):
118
+ def __init__(self, encoder_layer, num_layers, norm=None):
119
+ super().__init__()
120
+ self.layers = _get_clones(encoder_layer, num_layers)
121
+ self.num_layers = num_layers
122
+ self.norm = norm
123
+
124
+ def forward(
125
+ self,
126
+ src,
127
+ mask: Optional[Tensor] = None,
128
+ src_key_padding_mask: Optional[Tensor] = None,
129
+ pos: Optional[Tensor] = None,
130
+ ):
131
+ output = src
132
+
133
+ for layer in self.layers:
134
+ output = layer(
135
+ output,
136
+ src_mask=mask,
137
+ src_key_padding_mask=src_key_padding_mask,
138
+ pos=pos,
139
+ )
140
+
141
+ if self.norm is not None:
142
+ output = self.norm(output)
143
+
144
+ return output
145
+
146
+
147
+ class SpatialSoftmax(nn.Module):
148
+ def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
149
+ super().__init__()
150
+
151
+ self.data_format = data_format
152
+ self.height = height
153
+ self.width = width
154
+ self.channel = channel
155
+
156
+ if temperature:
157
+ self.temperature = Parameter(torch.ones(1) * temperature)
158
+ else:
159
+ self.temperature = 1.0
160
+
161
+ pos_x, pos_y = np.meshgrid(
162
+ np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
163
+ )
164
+ pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
165
+ pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
166
+ self.register_buffer("pos_x", pos_x)
167
+ self.register_buffer("pos_y", pos_y)
168
+
169
+ def forward(self, feature):
170
+ # Output:
171
+ # (N, C*2) x_0 y_0 ...
172
+
173
+ if self.data_format == "NHWC":
174
+ feature = (
175
+ feature.transpose(1, 3)
176
+ .transpose(2, 3)
177
+ .reshape(-1, self.height * self.width)  # reshape: the transposed tensor is non-contiguous
178
+ )
179
+ else:
180
+ feature = feature.view(-1, self.height * self.width)
181
+
182
+ weight = F.softmax(feature / self.temperature, dim=-1)
183
+ expected_x = torch.sum(
184
+ torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True
185
+ )
186
+ expected_y = torch.sum(
187
+ torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True
188
+ )
189
+ expected_xy = torch.cat([expected_x, expected_y], 1)
190
+ feature_keypoints = expected_xy.view(-1, self.channel, 2)
191
+ feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
192
+ feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
193
+ return feature_keypoints
194
+
195
+
196
+ class MultiPath_Generator(nn.Module):
197
+ def __init__(self, in_channel, embed_dim, out_channel):
198
+ super().__init__()
199
+ self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
200
+ self.tconv0 = nn.Sequential(
201
+ nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
202
+ nn.BatchNorm2d(256),
203
+ nn.ReLU(True),
204
+ )
205
+ self.tconv1 = nn.Sequential(
206
+ nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
207
+ nn.BatchNorm2d(256),
208
+ nn.ReLU(True),
209
+ )
210
+ self.tconv2 = nn.Sequential(
211
+ nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
212
+ nn.BatchNorm2d(192),
213
+ nn.ReLU(True),
214
+ )
215
+ self.tconv3 = nn.Sequential(
216
+ nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
217
+ nn.BatchNorm2d(64),
218
+ nn.ReLU(True),
219
+ )
220
+ self.tconv4_list = torch.nn.ModuleList(
221
+ [
222
+ nn.Sequential(
223
+ nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
224
+ nn.Tanh(),
225
+ )
226
+ for _ in range(6)
227
+ ]
228
+ )
229
+
230
+ self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
231
+
232
+ def forward(self, x, measurements):
233
+ mask = measurements[:, :6]
234
+ mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
235
+ velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
236
+ velocity = velocity.repeat(1, 32, 2, 2)
237
+
238
+ n, d, c = x.shape
239
+ x = x.transpose(1, 2)
240
+ x = x.view(n, -1, 2, 2)
241
+ x = torch.cat([x, velocity], dim=1)
242
+ x = self.tconv0(x)
243
+ x = self.tconv1(x)
244
+ x = self.tconv2(x)
245
+ x = self.tconv3(x)
246
+ x = self.upsample(x)
247
+ xs = []
248
+ for i in range(6):
249
+ xt = self.tconv4_list[i](x)
250
+ xs.append(xt)
251
+ xs = torch.stack(xs, dim=1)
252
+ x = torch.sum(xs * mask, dim=1)
253
+ x = self.spatial_softmax(x)
254
+ return x
255
+
256
+
257
+ class LinearWaypointsPredictor(nn.Module):
258
+ def __init__(self, input_dim, cumsum=True):
259
+ super().__init__()
260
+ self.cumsum = cumsum
261
+ self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
262
+ self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
263
+ self.head_relu = nn.ReLU(inplace=True)
264
+ self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
265
+
266
+ def forward(self, x, measurements):
267
+ # input shape: n 10 embed_dim
268
+ bs, n, dim = x.shape
269
+ x = x + self.rank_embed
270
+ x = x.reshape(-1, dim)
271
+
272
+ mask = measurements[:, :6]
273
+ mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
274
+
275
+ rs = []
276
+ for i in range(6):
277
+ res = self.head_fc1_list[i](x)
278
+ res = self.head_relu(res)
279
+ res = self.head_fc2_list[i](res)
280
+ rs.append(res)
281
+ rs = torch.stack(rs, 1)
282
+ x = torch.sum(rs * mask, dim=1)
283
+
284
+ x = x.view(bs, n, 2)
285
+ if self.cumsum:
286
+ x = torch.cumsum(x, 1)
287
+ return x
288
+
289
+
290
+ class GRUWaypointsPredictor(nn.Module):
291
+ def __init__(self, input_dim, waypoints=10):
292
+ super().__init__()
293
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
294
+ self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
295
+ self.encoder = nn.Linear(2, 64)
296
+ self.decoder = nn.Linear(64, 2)
297
+ self.waypoints = waypoints
298
+
299
+ def forward(self, x, target_point):
300
+ bs = x.shape[0]
301
+ z = self.encoder(target_point).unsqueeze(0)
302
+ output, _ = self.gru(x, z)
303
+ output = output.reshape(bs * self.waypoints, -1)
304
+ output = self.decoder(output).reshape(bs, self.waypoints, 2)
305
+ output = torch.cumsum(output, 1)
306
+ return output
307
+
308
+ class GRUWaypointsPredictorWithCommand(nn.Module):
309
+ def __init__(self, input_dim, waypoints=10):
310
+ super().__init__()
311
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
312
+ self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
313
+ self.encoder = nn.Linear(2, 64)
314
+ self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
315
+ self.waypoints = waypoints
316
+
317
+ def forward(self, x, target_point, measurements):
318
+ bs, n, dim = x.shape
319
+ mask = measurements[:, :6, None, None]
320
+ mask = mask.repeat(1, 1, self.waypoints, 2)
321
+
322
+ z = self.encoder(target_point).unsqueeze(0)
323
+ outputs = []
324
+ for i in range(6):
325
+ output, _ = self.grus[i](x, z)
326
+ output = output.reshape(bs * self.waypoints, -1)
327
+ output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
328
+ output = torch.cumsum(output, 1)
329
+ outputs.append(output)
330
+ outputs = torch.stack(outputs, 1)
331
+ output = torch.sum(outputs * mask, dim=1)
332
+ return output
333
+
334
+
335
+ class TransformerDecoder(nn.Module):
336
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
337
+ super().__init__()
338
+ self.layers = _get_clones(decoder_layer, num_layers)
339
+ self.num_layers = num_layers
340
+ self.norm = norm
341
+ self.return_intermediate = return_intermediate
342
+
343
+ def forward(
344
+ self,
345
+ tgt,
346
+ memory,
347
+ tgt_mask: Optional[Tensor] = None,
348
+ memory_mask: Optional[Tensor] = None,
349
+ tgt_key_padding_mask: Optional[Tensor] = None,
350
+ memory_key_padding_mask: Optional[Tensor] = None,
351
+ pos: Optional[Tensor] = None,
352
+ query_pos: Optional[Tensor] = None,
353
+ ):
354
+ output = tgt
355
+
356
+ intermediate = []
357
+
358
+ for layer in self.layers:
359
+ output = layer(
360
+ output,
361
+ memory,
362
+ tgt_mask=tgt_mask,
363
+ memory_mask=memory_mask,
364
+ tgt_key_padding_mask=tgt_key_padding_mask,
365
+ memory_key_padding_mask=memory_key_padding_mask,
366
+ pos=pos,
367
+ query_pos=query_pos,
368
+ )
369
+ if self.return_intermediate:
370
+ intermediate.append(self.norm(output))
371
+
372
+ if self.norm is not None:
373
+ output = self.norm(output)
374
+ if self.return_intermediate:
375
+ intermediate.pop()
376
+ intermediate.append(output)
377
+
378
+ if self.return_intermediate:
379
+ return torch.stack(intermediate)
380
+
381
+ return output.unsqueeze(0)
382
+
383
+
384
+ class TransformerEncoderLayer(nn.Module):
385
+ def __init__(
386
+ self,
387
+ d_model,
388
+ nhead,
389
+ dim_feedforward=2048,
390
+ dropout=0.1,
391
+ activation=nn.ReLU,  # pass the class; it is instantiated in __init__ via activation()
392
+ normalize_before=False,
393
+ ):
394
+ super().__init__()
395
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
396
+ # Implementation of Feedforward model
397
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
398
+ self.dropout = nn.Dropout(dropout)
399
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
400
+
401
+ self.norm1 = nn.LayerNorm(d_model)
402
+ self.norm2 = nn.LayerNorm(d_model)
403
+ self.dropout1 = nn.Dropout(dropout)
404
+ self.dropout2 = nn.Dropout(dropout)
405
+
406
+ self.activation = activation()
407
+ self.normalize_before = normalize_before
408
+
409
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
410
+ return tensor if pos is None else tensor + pos
411
+
412
+ def forward_post(
413
+ self,
414
+ src,
415
+ src_mask: Optional[Tensor] = None,
416
+ src_key_padding_mask: Optional[Tensor] = None,
417
+ pos: Optional[Tensor] = None,
418
+ ):
419
+ q = k = self.with_pos_embed(src, pos)
420
+ src2 = self.self_attn(
421
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
422
+ )[0]
423
+ src = src + self.dropout1(src2)
424
+ src = self.norm1(src)
425
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
426
+ src = src + self.dropout2(src2)
427
+ src = self.norm2(src)
428
+ return src
429
+
430
+ def forward_pre(
431
+ self,
432
+ src,
433
+ src_mask: Optional[Tensor] = None,
434
+ src_key_padding_mask: Optional[Tensor] = None,
435
+ pos: Optional[Tensor] = None,
436
+ ):
437
+ src2 = self.norm1(src)
438
+ q = k = self.with_pos_embed(src2, pos)
439
+ src2 = self.self_attn(
440
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
441
+ )[0]
442
+ src = src + self.dropout1(src2)
443
+ src2 = self.norm2(src)
444
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
445
+ src = src + self.dropout2(src2)
446
+ return src
447
+
448
+ def forward(
449
+ self,
450
+ src,
451
+ src_mask: Optional[Tensor] = None,
452
+ src_key_padding_mask: Optional[Tensor] = None,
453
+ pos: Optional[Tensor] = None,
454
+ ):
455
+ if self.normalize_before:
456
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
457
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
458
+
459
+
460
+ class TransformerDecoderLayer(nn.Module):
461
+ def __init__(
462
+ self,
463
+ d_model,
464
+ nhead,
465
+ dim_feedforward=2048,
466
+ dropout=0.1,
467
+ activation=nn.ReLU,  # pass the class; it is instantiated in __init__ via activation()
468
+ normalize_before=False,
469
+ ):
470
+ super().__init__()
471
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
472
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
473
+ # Implementation of Feedforward model
474
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
475
+ self.dropout = nn.Dropout(dropout)
476
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
477
+
478
+ self.norm1 = nn.LayerNorm(d_model)
479
+ self.norm2 = nn.LayerNorm(d_model)
480
+ self.norm3 = nn.LayerNorm(d_model)
481
+ self.dropout1 = nn.Dropout(dropout)
482
+ self.dropout2 = nn.Dropout(dropout)
483
+ self.dropout3 = nn.Dropout(dropout)
484
+
485
+ self.activation = activation()
486
+ self.normalize_before = normalize_before
487
+
488
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
489
+ return tensor if pos is None else tensor + pos
490
+
491
+ def forward_post(
492
+ self,
493
+ tgt,
494
+ memory,
495
+ tgt_mask: Optional[Tensor] = None,
496
+ memory_mask: Optional[Tensor] = None,
497
+ tgt_key_padding_mask: Optional[Tensor] = None,
498
+ memory_key_padding_mask: Optional[Tensor] = None,
499
+ pos: Optional[Tensor] = None,
500
+ query_pos: Optional[Tensor] = None,
501
+ ):
502
+ q = k = self.with_pos_embed(tgt, query_pos)
503
+ tgt2 = self.self_attn(
504
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
505
+ )[0]
506
+ tgt = tgt + self.dropout1(tgt2)
507
+ tgt = self.norm1(tgt)
508
+ tgt2 = self.multihead_attn(
509
+ query=self.with_pos_embed(tgt, query_pos),
510
+ key=self.with_pos_embed(memory, pos),
511
+ value=memory,
512
+ attn_mask=memory_mask,
513
+ key_padding_mask=memory_key_padding_mask,
514
+ )[0]
515
+ tgt = tgt + self.dropout2(tgt2)
516
+ tgt = self.norm2(tgt)
517
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
518
+ tgt = tgt + self.dropout3(tgt2)
519
+ tgt = self.norm3(tgt)
520
+ return tgt
521
+
522
+ def forward_pre(
523
+ self,
524
+ tgt,
525
+ memory,
526
+ tgt_mask: Optional[Tensor] = None,
527
+ memory_mask: Optional[Tensor] = None,
528
+ tgt_key_padding_mask: Optional[Tensor] = None,
529
+ memory_key_padding_mask: Optional[Tensor] = None,
530
+ pos: Optional[Tensor] = None,
531
+ query_pos: Optional[Tensor] = None,
532
+ ):
533
+ tgt2 = self.norm1(tgt)
534
+ q = k = self.with_pos_embed(tgt2, query_pos)
535
+ tgt2 = self.self_attn(
536
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
537
+ )[0]
538
+ tgt = tgt + self.dropout1(tgt2)
539
+ tgt2 = self.norm2(tgt)
540
+ tgt2 = self.multihead_attn(
541
+ query=self.with_pos_embed(tgt2, query_pos),
542
+ key=self.with_pos_embed(memory, pos),
543
+ value=memory,
544
+ attn_mask=memory_mask,
545
+ key_padding_mask=memory_key_padding_mask,
546
+ )[0]
547
+ tgt = tgt + self.dropout2(tgt2)
548
+ tgt2 = self.norm3(tgt)
549
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
550
+ tgt = tgt + self.dropout3(tgt2)
551
+ return tgt
552
+
553
+ def forward(
554
+ self,
555
+ tgt,
556
+ memory,
557
+ tgt_mask: Optional[Tensor] = None,
558
+ memory_mask: Optional[Tensor] = None,
559
+ tgt_key_padding_mask: Optional[Tensor] = None,
560
+ memory_key_padding_mask: Optional[Tensor] = None,
561
+ pos: Optional[Tensor] = None,
562
+ query_pos: Optional[Tensor] = None,
563
+ ):
564
+ if self.normalize_before:
565
+ return self.forward_pre(
566
+ tgt,
567
+ memory,
568
+ tgt_mask,
569
+ memory_mask,
570
+ tgt_key_padding_mask,
571
+ memory_key_padding_mask,
572
+ pos,
573
+ query_pos,
574
+ )
575
+ return self.forward_post(
576
+ tgt,
577
+ memory,
578
+ tgt_mask,
579
+ memory_mask,
580
+ tgt_key_padding_mask,
581
+ memory_key_padding_mask,
582
+ pos,
583
+ query_pos,
584
+ )
585
+
586
+
587
+ def _get_clones(module, N):
588
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
589
+
590
+
591
+ def _get_activation_fn(activation):
592
+ """Return an activation function given a string"""
593
+ if activation == "relu":
594
+ return F.relu
595
+ if activation == "gelu":
596
+ return F.gelu
597
+ if activation == "glu":
598
+ return F.glu
599
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
600
+
601
+
602
+ def build_attn_mask(mask_type):
603
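+ # Fall back to CPU when CUDA is unavailable instead of failing at construction time.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ mask = torch.ones((151, 151), dtype=torch.bool, device=device)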
604
+ if mask_type == "seperate_all":
605
+ mask[:50, :50] = False
606
+ mask[50:67, 50:67] = False
607
+ mask[67:84, 67:84] = False
608
+ mask[84:101, 84:101] = False
609
+ mask[101:151, 101:151] = False
610
+ elif mask_type == "seperate_view":
611
+ mask[:50, :50] = False
612
+ mask[50:67, 50:67] = False
613
+ mask[67:84, 67:84] = False
614
+ mask[84:101, 84:101] = False
615
+ mask[101:151, :] = False
616
+ mask[:, 101:151] = False
617
+ return mask
618
+
619
+ class Interfuser(nn.Module):
620
+ def __init__(
621
+ self,
622
+ img_size=224,
623
+ multi_view_img_size=112,
624
+ patch_size=8,
625
+ in_chans=3,
626
+ embed_dim=768,
627
+ enc_depth=6,
628
+ dec_depth=6,
629
+ dim_feedforward=2048,
630
+ normalize_before=False,
631
+ rgb_backbone_name="r26",
632
+ lidar_backbone_name="r26",
633
+ num_heads=8,
634
+ norm_layer=None,
635
+ dropout=0.1,
636
+ end2end=False,
637
+ direct_concat=True,
638
+ separate_view_attention=False,
639
+ separate_all_attention=False,
640
+ act_layer=None,
641
+ weight_init="",
642
+ freeze_num=-1,
643
+ with_lidar=False,
644
+ with_right_left_sensors=True,
645
+ with_center_sensor=False,
646
+ traffic_pred_head_type="det",
647
+ waypoints_pred_head="heatmap",
648
+ reverse_pos=True,
649
+ use_different_backbone=False,
650
+ use_view_embed=True,
651
+ use_mmad_pretrain=None,
652
+ ):
653
+ super().__init__()
654
+ self.traffic_pred_head_type = traffic_pred_head_type
655
+ self.num_features = (
656
+ self.embed_dim
657
+ ) = embed_dim # num_features for consistency with other models
658
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
659
+ act_layer = act_layer or nn.GELU
660
+
661
+ self.reverse_pos = reverse_pos
662
+ self.waypoints_pred_head = waypoints_pred_head
663
+ self.with_lidar = with_lidar
664
+ self.with_right_left_sensors = with_right_left_sensors
665
+ self.with_center_sensor = with_center_sensor
666
+
667
+ self.direct_concat = direct_concat
668
+ self.separate_view_attention = separate_view_attention
669
+ self.separate_all_attention = separate_all_attention
670
+ self.end2end = end2end
671
+ self.use_view_embed = use_view_embed
672
+
673
+ if self.direct_concat:
674
+ in_chans = in_chans * 4
675
+ self.with_center_sensor = False
676
+ self.with_right_left_sensors = False
677
+
678
+ if self.separate_view_attention:
679
+ self.attn_mask = build_attn_mask("seperate_view")
680
+ elif self.separate_all_attention:
681
+ self.attn_mask = build_attn_mask("seperate_all")
682
+ else:
683
+ self.attn_mask = None
684
+
685
+ if use_different_backbone:
686
+ if rgb_backbone_name == "r50":
687
+ self.rgb_backbone = resnet50d(
688
+ pretrained=True,
689
+ in_chans=in_chans,
690
+ features_only=True,
691
+ out_indices=[4],
692
+ )
693
+ elif rgb_backbone_name == "r26":
694
+ self.rgb_backbone = resnet26d(
695
+ pretrained=True,
696
+ in_chans=in_chans,
697
+ features_only=True,
698
+ out_indices=[4],
699
+ )
700
+ elif rgb_backbone_name == "r18":
701
+ self.rgb_backbone = resnet18d(
702
+ pretrained=True,
703
+ in_chans=in_chans,
704
+ features_only=True,
705
+ out_indices=[4],
706
+ )
707
+ if lidar_backbone_name == "r50":
708
+ self.lidar_backbone = resnet50d(
709
+ pretrained=False,
710
+ in_chans=3,  # LiDAR input is 3-channel, matching lidar_patch_embed and the r18 branch
711
+ features_only=True,
712
+ out_indices=[4],
713
+ )
714
+ elif lidar_backbone_name == "r26":
715
+ self.lidar_backbone = resnet26d(
716
+ pretrained=False,
717
+ in_chans=3,  # LiDAR input is 3-channel, matching lidar_patch_embed and the r18 branch
718
+ features_only=True,
719
+ out_indices=[4],
720
+ )
721
+ elif lidar_backbone_name == "r18":
722
+ self.lidar_backbone = resnet18d(
723
+ pretrained=False, in_chans=3, features_only=True, out_indices=[4]
724
+ )
725
+ rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
726
+ lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
727
+
728
+ if use_mmad_pretrain:
729
+ params = torch.load(use_mmad_pretrain)["state_dict"]
730
+ updated_params = OrderedDict()
731
+ for key in params:
732
+ if "backbone" in key:
733
+ updated_params[key.replace("backbone.", "")] = params[key]
734
+ self.rgb_backbone.load_state_dict(updated_params)
735
+
736
+ self.rgb_patch_embed = rgb_embed_layer(
737
+ img_size=img_size,
738
+ patch_size=patch_size,
739
+ in_chans=in_chans,
740
+ embed_dim=embed_dim,
741
+ )
742
+ self.lidar_patch_embed = lidar_embed_layer(
743
+ img_size=img_size,
744
+ patch_size=patch_size,
745
+ in_chans=3,
746
+ embed_dim=embed_dim,
747
+ )
748
+ else:
749
+ if rgb_backbone_name == "r50":
750
+ self.rgb_backbone = resnet50d(
751
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
752
+ )
753
+ elif rgb_backbone_name == "r101":
754
+ self.rgb_backbone = resnet101d(
755
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
756
+ )
757
+ elif rgb_backbone_name == "r26":
758
+ self.rgb_backbone = resnet26d(
759
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
760
+ )
761
+ elif rgb_backbone_name == "r18":
762
+ self.rgb_backbone = resnet18d(
763
+ pretrained=True, in_chans=3, features_only=True, out_indices=[4]
764
+ )
765
+ embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
766
+
767
+ self.rgb_patch_embed = embed_layer(
768
+ img_size=img_size,
769
+ patch_size=patch_size,
770
+ in_chans=in_chans,
771
+ embed_dim=embed_dim,
772
+ )
773
+ self.lidar_patch_embed = embed_layer(
774
+ img_size=img_size,
775
+ patch_size=patch_size,
776
+ in_chans=in_chans,
777
+ embed_dim=embed_dim,
778
+ )
779
+
780
+ self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
781
+ self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
782
+
783
+ if self.end2end:
784
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
785
+ self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
786
+ elif self.waypoints_pred_head == "heatmap":
787
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
788
+ self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
789
+ else:
790
+ self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
791
+ self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
792
+
793
+ if self.end2end:
794
+ self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
795
+ elif self.waypoints_pred_head == "heatmap":
796
+ self.waypoints_generator = MultiPath_Generator(
797
+ embed_dim + 32, embed_dim, 10
798
+ )
799
+ elif self.waypoints_pred_head == "gru":
800
+ self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
801
+ elif self.waypoints_pred_head == "gru-command":
802
+ self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
803
+ elif self.waypoints_pred_head == "linear":
804
+ self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
805
+ elif self.waypoints_pred_head == "linear-sum":
806
+ self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
807
+
808
+ self.junction_pred_head = nn.Linear(embed_dim, 2)
809
+ self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
810
+ self.stop_sign_head = nn.Linear(embed_dim, 2)
811
+
812
+ if self.traffic_pred_head_type == "det":
813
+ self.traffic_pred_head = nn.Sequential(
814
+ *[
815
+ nn.Linear(embed_dim + 32, 64),
816
+ nn.ReLU(),
817
+ nn.Linear(64, 7),
818
+ nn.Sigmoid(),
819
+ ]
820
+ )
821
+ elif self.traffic_pred_head_type == "seg":
822
+ self.traffic_pred_head = nn.Sequential(
823
+ *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
824
+ )
825
+
826
+ self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
827
+
828
+ encoder_layer = TransformerEncoderLayer(
829
+ embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
830
+ )
831
+ self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
832
+
833
+ decoder_layer = TransformerDecoderLayer(
834
+ embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
835
+ )
836
+ decoder_norm = nn.LayerNorm(embed_dim)
837
+ self.decoder = TransformerDecoder(
838
+ decoder_layer, dec_depth, decoder_norm, return_intermediate=False
839
+ )
840
+ self.reset_parameters()
841
+
842
+ def reset_parameters(self):
843
+ nn.init.uniform_(self.global_embed)
844
+ nn.init.uniform_(self.view_embed)
845
+ nn.init.uniform_(self.query_embed)
846
+ nn.init.uniform_(self.query_pos_embed)
847
+
848
+ def forward_features(
849
+ self,
850
+ front_image,
851
+ left_image,
852
+ right_image,
853
+ front_center_image,
854
+ lidar,
855
+ measurements,
856
+ ):
857
+ features = []
858
+
859
+ # Front view processing
860
+ front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
861
+ if self.use_view_embed:
862
+ front_image_token = (
863
+ front_image_token
864
+ + self.view_embed[:, :, 0:1, :]
865
+ + self.position_encoding(front_image_token)
866
+ )
867
+ else:
868
+ front_image_token = front_image_token + self.position_encoding(
869
+ front_image_token
870
+ )
871
+ front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
872
+ front_image_token_global = (
873
+ front_image_token_global
874
+ + self.view_embed[:, :, 0, :]
875
+ + self.global_embed[:, :, 0:1]
876
+ )
877
+ front_image_token_global = front_image_token_global.permute(2, 0, 1)
878
+ features.extend([front_image_token, front_image_token_global])
879
+
880
+ if self.with_right_left_sensors:
881
+ # Left view processing
882
+ left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
883
+ if self.use_view_embed:
884
+ left_image_token = (
885
+ left_image_token
886
+ + self.view_embed[:, :, 1:2, :]
887
+ + self.position_encoding(left_image_token)
888
+ )
889
+ else:
890
+ left_image_token = left_image_token + self.position_encoding(
891
+ left_image_token
892
+ )
893
+ left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
894
+ left_image_token_global = (
895
+ left_image_token_global
896
+ + self.view_embed[:, :, 1, :]
897
+ + self.global_embed[:, :, 1:2]
898
+ )
899
+ left_image_token_global = left_image_token_global.permute(2, 0, 1)
900
+
901
+ # Right view processing
902
+ right_image_token, right_image_token_global = self.rgb_patch_embed(
903
+ right_image
904
+ )
905
+ if self.use_view_embed:
906
+ right_image_token = (
907
+ right_image_token
908
+ + self.view_embed[:, :, 2:3, :]
909
+ + self.position_encoding(right_image_token)
910
+ )
911
+ else:
912
+ right_image_token = right_image_token + self.position_encoding(
913
+ right_image_token
914
+ )
915
+ right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
916
+ right_image_token_global = (
917
+ right_image_token_global
918
+ + self.view_embed[:, :, 2, :]
919
+ + self.global_embed[:, :, 2:3]
920
+ )
921
+ right_image_token_global = right_image_token_global.permute(2, 0, 1)
922
+
923
+ features.extend(
924
+ [
925
+ left_image_token,
926
+ left_image_token_global,
927
+ right_image_token,
928
+ right_image_token_global,
929
+ ]
930
+ )
931
+
932
+ if self.with_center_sensor:
933
+ # Front center view processing
934
+ (
935
+ front_center_image_token,
936
+ front_center_image_token_global,
937
+ ) = self.rgb_patch_embed(front_center_image)
938
+ if self.use_view_embed:
939
+ front_center_image_token = (
940
+ front_center_image_token
941
+ + self.view_embed[:, :, 3:4, :]
942
+ + self.position_encoding(front_center_image_token)
943
+ )
944
+ else:
945
+ front_center_image_token = (
946
+ front_center_image_token
947
+ + self.position_encoding(front_center_image_token)
948
+ )
949
+
950
+ front_center_image_token = front_center_image_token.flatten(2).permute(
951
+ 2, 0, 1
952
+ )
953
+ front_center_image_token_global = (
954
+ front_center_image_token_global
955
+ + self.view_embed[:, :, 3, :]
956
+ + self.global_embed[:, :, 3:4]
957
+ )
958
+ front_center_image_token_global = front_center_image_token_global.permute(
959
+ 2, 0, 1
960
+ )
961
+ features.extend([front_center_image_token, front_center_image_token_global])
962
+
963
+ if self.with_lidar:
964
+ lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
965
+ if self.use_view_embed:
966
+ lidar_token = (
967
+ lidar_token
968
+ + self.view_embed[:, :, 4:5, :]
969
+ + self.position_encoding(lidar_token)
970
+ )
971
+ else:
972
+ lidar_token = lidar_token + self.position_encoding(lidar_token)
973
+ lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
974
+ lidar_token_global = (
975
+ lidar_token_global
976
+ + self.view_embed[:, :, 4, :]
977
+ + self.global_embed[:, :, 4:5]
978
+ )
979
+ lidar_token_global = lidar_token_global.permute(2, 0, 1)
980
+ features.extend([lidar_token, lidar_token_global])
981
+
982
+ features = torch.cat(features, 0)
983
+ return features
984
+
985
+ def forward(self, x):
986
+ front_image = x["rgb"]
987
+ left_image = x["rgb_left"]
988
+ right_image = x["rgb_right"]
989
+ front_center_image = x["rgb_center"]
990
+ measurements = x["measurements"]
991
+ target_point = x["target_point"]
992
+ lidar = x["lidar"]
993
+
994
+ if self.direct_concat:
995
+ img_size = front_image.shape[-1]
996
+ left_image = torch.nn.functional.interpolate(
997
+ left_image, size=(img_size, img_size)
998
+ )
999
+ right_image = torch.nn.functional.interpolate(
1000
+ right_image, size=(img_size, img_size)
1001
+ )
1002
+ front_center_image = torch.nn.functional.interpolate(
1003
+ front_center_image, size=(img_size, img_size)
1004
+ )
1005
+ front_image = torch.cat(
1006
+ [front_image, left_image, right_image, front_center_image], dim=1
1007
+ )
1008
+ features = self.forward_features(
1009
+ front_image,
1010
+ left_image,
1011
+ right_image,
1012
+ front_center_image,
1013
+ lidar,
1014
+ measurements,
1015
+ )
1016
+
1017
+ bs = front_image.shape[0]
1018
+
1019
+ if self.end2end:
1020
+ tgt = self.query_pos_embed.repeat(bs, 1, 1)
1021
+ else:
1022
+ tgt = self.position_encoding(
1023
+ torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
1024
+ )
1025
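+ tgt = tgt.flatten(2)  # (B, C, 20 * 20): one positional query per BEV grid cell
1026
+ tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)  # append the 11 (or 5 for "heatmap") learned query positions
1027
+ tgt = tgt.permute(2, 0, 1)  # (num_queries, B, C), the layout the decoder expects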
1028
+
1029
+ memory = self.encoder(features, mask=self.attn_mask)
1030
+ hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
1031
+
1032
+ hs = hs.permute(1, 0, 2)  # (batch_size, num_queries, embed_dim)
1033
+ if self.end2end:
1034
+ waypoints = self.waypoints_generator(hs, target_point)
1035
+ return waypoints
1036
+
1037
+ if self.waypoints_pred_head != "heatmap":
1038
+ traffic_feature = hs[:, :400]
1039
+ is_junction_feature = hs[:, 400]
1040
+ traffic_light_state_feature = hs[:, 400]
1041
+ stop_sign_feature = hs[:, 400]
1042
+ waypoints_feature = hs[:, 401:411]
1043
+ else:
1044
+ traffic_feature = hs[:, :400]
1045
+ is_junction_feature = hs[:, 400]
1046
+ traffic_light_state_feature = hs[:, 400]
1047
+ stop_sign_feature = hs[:, 400]
1048
+ waypoints_feature = hs[:, 401:405]
1049
+
1050
+ if self.waypoints_pred_head == "heatmap":
1051
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1052
+ elif self.waypoints_pred_head == "gru":
1053
+ waypoints = self.waypoints_generator(waypoints_feature, target_point)
1054
+ elif self.waypoints_pred_head == "gru-command":
1055
+ waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
1056
+ elif self.waypoints_pred_head == "linear":
1057
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1058
+ elif self.waypoints_pred_head == "linear-sum":
1059
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1060
+
1061
+ is_junction = self.junction_pred_head(is_junction_feature)
1062
+ traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
1063
+ stop_sign = self.stop_sign_head(stop_sign_feature)
1064
+
1065
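+ velocity = measurements[:, 6:7].unsqueeze(-1)  # (B, 1, 1): column 6 of measurements
1066
+ velocity = velocity.repeat(1, 400, 32)  # (B, 400, 32): broadcast to every BEV query
1067
+ traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)  # (B, 400, embed_dim + 32)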
1068
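+ traffic = self.traffic_pred_head(traffic_feature_with_vel if self.traffic_pred_head_type == "det" else traffic_feature)  # the "seg" head takes embed_dim inputs, without the 32 velocity channels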
1069
+ return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
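
The slicing of the decoder output in `forward` follows from the sizes of `query_embed` defined in `__init__`: 400 BEV grid queries (the 20 x 20 positional grid), one token shared by the junction, traffic-light and stop-sign heads, and then the waypoint queries. The sketch below only restates that layout in plain Python as a reading aid. The helper name `query_layout` is hypothetical and is not part of `model.py`; it covers the non-`end2end` path only.

```python
def query_layout(waypoints_pred_head: str) -> dict:
    """Index ranges into the decoder output `hs` (shape: batch, num_queries, embed_dim).

    Hypothetical helper for illustration; mirrors the slices used in `forward`
    when `end2end` is False.
    """
    n_traffic = 20 * 20  # BEV grid queries, sliced as hs[:, :400]
    # "heatmap" keeps 4 waypoint queries (hs[:, 401:405]); other heads keep 10 (hs[:, 401:411]).
    n_waypoints = 4 if waypoints_pred_head == "heatmap" else 10
    state_index = n_traffic  # hs[:, 400], shared by junction / traffic light / stop sign
    return {
        "traffic": slice(0, n_traffic),
        "state": state_index,
        "waypoints": slice(state_index + 1, state_index + 1 + n_waypoints),
        "total": n_traffic + 1 + n_waypoints,
    }


if __name__ == "__main__":
    for head in ("gru", "heatmap"):
        print(head, query_layout(head))
```

The totals match the `query_embed` shapes above: `400 + 11` for the default heads and `400 + 5` for `"heatmap"`.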
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ef39ec4197f3e9ed8b659709b331a2cd63c30b6019c61eb155eb438135e3dd7
3
+ size 212334062