mohammed-aljafry committed on
Commit 62ba2ba · verified · 1 Parent(s): 58e1007

Upload folder using huggingface_hub

Files changed (1): model.py (+951, −210)
model.py CHANGED
@@ -1,66 +1,121 @@
  import math
  import copy
  import logging
  from collections import OrderedDict
  from functools import partial
  from typing import Optional, List
- import numpy as np

  import torch
  from torch import nn, Tensor
  import torch.nn.functional as F
- from torch.nn.parameter import Parameter

- # Base classes from transformers
- from transformers import PretrainedConfig, PreTrainedModel

- # Dependencies from timm
  try:
-     from timm.models.layers import to_2tuple
-     # The original code uses resnet from a local file, but for portability, we use timm's resnets
-     from timm.models.resnet import resnet26d, resnet50d, resnet18d
- except ImportError:
-     print("This model requires the 'timm' library. Please install it using: pip install timm")
-     raise

- # --- Helper Classes (Unchanged from original code) ---

- def _get_clones(module, N):
-     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

  class HybridEmbed(nn.Module):
-     def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
          super().__init__()
          assert isinstance(backbone, nn.Module)
          img_size = to_2tuple(img_size)
          self.img_size = img_size
          self.backbone = backbone
          if feature_size is None:
              with torch.no_grad():
                  training = backbone.training
-                 if training: backbone.eval()
                  o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
-                 if isinstance(o, (list, tuple)): o = o[-1]
                  feature_dim = o.shape[1]
                  backbone.train(training)
          else:
-             feature_dim = self.backbone.num_features
          self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

      def forward(self, x):
          x = self.backbone(x)
-         if isinstance(x, (list, tuple)): x = x[-1]
          x = self.proj(x)
          global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
          return x, global_x

  class PositionEmbeddingSine(nn.Module):
-     def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
          super().__init__()
          self.num_pos_feats = num_pos_feats
          self.temperature = temperature
          self.normalize = normalize
-         if scale is None: scale = 2 * math.pi
          self.scale = scale

      def forward(self, tensor):
@@ -73,35 +128,328 @@ class PositionEmbeddingSine(nn.Module):
          eps = 1e-6
          y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
          x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
          dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
          dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
          pos_x = x_embed[:, :, :, None] / dim_t
          pos_y = y_embed[:, :, :, None] / dim_t
-         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
-         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
          pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
          return pos

  class TransformerEncoderLayer(nn.Module):
-     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU()):
          super().__init__()
          self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
          self.linear1 = nn.Linear(d_model, dim_feedforward)
          self.dropout = nn.Dropout(dropout)
          self.linear2 = nn.Linear(dim_feedforward, d_model)
          self.norm1 = nn.LayerNorm(d_model)
          self.norm2 = nn.LayerNorm(d_model)
          self.dropout1 = nn.Dropout(dropout)
          self.dropout2 = nn.Dropout(dropout)
-         # THIS IS THE FIX: We receive an INSTANCE of the activation function, not the class.
-         self.activation = activation

      def with_pos_embed(self, tensor, pos: Optional[Tensor]):
          return tensor if pos is None else tensor + pos

-     def forward(self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
          q = k = self.with_pos_embed(src, pos)
-         src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
          src = src + self.dropout1(src2)
          src = self.norm1(src)
          src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
@@ -109,47 +457,91 @@ class TransformerEncoderLayer(nn.Module):
          src = self.norm2(src)
          return src

- class TransformerEncoder(nn.Module):
-     def __init__(self, encoder_layer, num_layers, norm=None):
-         super().__init__()
-         self.layers = _get_clones(encoder_layer, num_layers)
-         self.num_layers = num_layers
-         self.norm = norm

-     def forward(self, src, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
-         output = src
-         for layer in self.layers:
-             output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
-         if self.norm is not None:
-             output = self.norm(output)
-         return output

  class TransformerDecoderLayer(nn.Module):
-     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU()):
          super().__init__()
          self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
          self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
          self.linear1 = nn.Linear(d_model, dim_feedforward)
          self.dropout = nn.Dropout(dropout)
          self.linear2 = nn.Linear(dim_feedforward, d_model)
          self.norm1 = nn.LayerNorm(d_model)
          self.norm2 = nn.LayerNorm(d_model)
          self.norm3 = nn.LayerNorm(d_model)
          self.dropout1 = nn.Dropout(dropout)
          self.dropout2 = nn.Dropout(dropout)
          self.dropout3 = nn.Dropout(dropout)
-         # THIS IS THE FIX: We receive an INSTANCE of the activation function, not the class.
-         self.activation = activation

      def with_pos_embed(self, tensor, pos: Optional[Tensor]):
          return tensor if pos is None else tensor + pos

-     def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
          q = k = self.with_pos_embed(tgt, query_pos)
-         tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
          tgt = tgt + self.dropout1(tgt2)
          tgt = self.norm1(tgt)
-         tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
          tgt = tgt + self.dropout2(tgt2)
          tgt = self.norm2(tgt)
          tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
@@ -157,202 +549,551 @@ class TransformerDecoderLayer(nn.Module):
          tgt = self.norm3(tgt)
          return tgt

- class TransformerDecoder(nn.Module):
-     def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
-         super().__init__()
-         self.layers = _get_clones(decoder_layer, num_layers)
-         self.num_layers = num_layers
-         self.norm = norm
-         self.return_intermediate = return_intermediate

-     def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
-         output = tgt
-         intermediate = []
-         for layer in self.layers:
-             output = layer(output, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
-         if self.norm is not None:
-             output = self.norm(output)
-         if self.return_intermediate:
-             return torch.stack(intermediate)
-         return output.unsqueeze(0)

- class GRUWaypointsPredictor(nn.Module):
-     def __init__(self, input_dim, waypoints=10):
-         super().__init__()
-         self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
-         self.encoder = nn.Linear(2, 64)
-         self.decoder = nn.Linear(64, 2)
-         self.waypoints = waypoints

-     def forward(self, x, target_point):
-         bs = x.shape[0]
-         z = self.encoder(target_point).unsqueeze(0)
-         output, _ = self.gru(x, z)
-         output = output.reshape(bs * self.waypoints, -1)
-         output = self.decoder(output).reshape(bs, self.waypoints, 2)
-         output = torch.cumsum(output, 1)
-         return output

- # --- Hugging Face Compatible Main Classes ---

- class InterfuserConfig(PretrainedConfig):
-     model_type = "interfuser"

      def __init__(
          self,
-         embed_dim=256,
          enc_depth=6,
          dec_depth=6,
-         num_heads=8,
          dim_feedforward=2048,
          dropout=0.1,
-         rgb_backbone_name="r50",
-         lidar_backbone_name="r18",
-         use_different_backbone=True,
-         waypoints_pred_head="gru",
          direct_concat=True,
-         with_lidar=True,
          with_right_left_sensors=True,
          use_view_embed=True,
-         **kwargs,
      ):
-         self.embed_dim = embed_dim
-         self.enc_depth = enc_depth
-         self.dec_depth = dec_depth
-         self.num_heads = num_heads
-         self.dim_feedforward = dim_feedforward
-         self.dropout = dropout
-         self.rgb_backbone_name = rgb_backbone_name
-         self.lidar_backbone_name = lidar_backbone_name
-         self.use_different_backbone = use_different_backbone
          self.waypoints_pred_head = waypoints_pred_head
-         self.direct_concat = direct_concat
          self.with_lidar = with_lidar
          self.with_right_left_sensors = with_right_left_sensors
          self.use_view_embed = use_view_embed
-         super().__init__(**kwargs)
-
- class InterfuserModel(PreTrainedModel):
-     config_class = InterfuserConfig
-     # This base_model_prefix is needed to know that the core model is inside 'model'
-     # when loading/saving weights. Since our model is monolithic, we set it to something simple.
-     base_model_prefix = "interfuser"
-
-     def __init__(self, config: InterfuserConfig):
-         super().__init__(config)
-         self.config = config
-
-         # Extract params from config
-         embed_dim=config.embed_dim; enc_depth=config.enc_depth; dec_depth=config.dec_depth
-         num_heads=config.num_heads; dim_feedforward=config.dim_feedforward; dropout=config.dropout
-         rgb_backbone_name=config.rgb_backbone_name; lidar_backbone_name=config.lidar_backbone_name
-         use_different_backbone=config.use_different_backbone
-
-         # Hardcoded params from original code
-         in_chans=3; img_size=224;
-
-         # Main model attributes
-         self.direct_concat = config.direct_concat
-         self.with_lidar = config.with_lidar
-         self.with_right_left_sensors = config.with_right_left_sensors
-         self.use_view_embed = config.use_view_embed
-
-         in_chans_rgb = in_chans * 4 if self.direct_concat else in_chans
-
-         # Initialize backbones
-         backbone_kwargs = {'pretrained': False, 'features_only': True, 'out_indices': [4]}
-         self.rgb_backbone = {'r50': resnet50d, 'r26': resnet26d, 'r18': resnet18d}[rgb_backbone_name](in_chans=in_chans_rgb, **backbone_kwargs)
-         self.lidar_backbone = {'r50': resnet50d, 'r26': resnet26d, 'r18': resnet18d}[lidar_backbone_name](in_chans=in_chans, **backbone_kwargs)
-
-         # Initialize embedding layers
-         self.rgb_patch_embed = HybridEmbed(self.rgb_backbone, img_size=img_size, in_chans=in_chans_rgb, embed_dim=embed_dim)
-         self.lidar_patch_embed = HybridEmbed(self.lidar_backbone, img_size=112, in_chans=in_chans, embed_dim=embed_dim)
-
          self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
          self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
-         self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
-         self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
-
-         # Initialize prediction heads
-         self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
          self.junction_pred_head = nn.Linear(embed_dim, 2)
          self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
          self.stop_sign_head = nn.Linear(embed_dim, 2)
-         self.traffic_pred_head = nn.Sequential(
-             nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()
-         )
-
          self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
-
-         # --- THE FIX IS HERE ---
-         # Create an INSTANCE of the activation function
-         activation_instance = nn.GELU()
-
-         # Pass the INSTANCE, not the class, to the layers
-         encoder_layer = TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward, dropout, activation=activation_instance)
          self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
-
-         decoder_layer = TransformerDecoderLayer(embed_dim, num_heads, dim_feedforward, dropout, activation=activation_instance)
-         self.decoder = TransformerDecoder(decoder_layer, dec_depth, nn.LayerNorm(embed_dim), return_intermediate=False)

-     def forward_features(self, front_image, left_image, right_image, lidar):
          features = []
-
-         x, x_g = self.rgb_patch_embed(front_image)
-         if self.use_view_embed: x = x + self.view_embed[:,:,0:1,:]
-         x = x + self.position_encoding(x)
-         x=x.flatten(2).permute(2,0,1)
-         if self.use_view_embed: x_g = x_g + self.view_embed[:,:,0,:]
-         x_g = x_g + self.global_embed[:,:,0:1]
-         x_g=x_g.permute(2,0,1)
-         features.extend([x,x_g])
-
-         if self.with_right_left_sensors and not self.direct_concat:
-             x, x_g = self.rgb_patch_embed(left_image);
-             if self.use_view_embed: x = x + self.view_embed[:,:,1:2,:];
-             x=x.flatten(2).permute(2,0,1);
-             if self.use_view_embed: x_g=x_g+self.view_embed[:,:,1,:];
-             x_g=x_g+self.global_embed[:,:,1:2]; x_g=x_g.permute(2,0,1); features.extend([x,x_g])
-
-             x, x_g = self.rgb_patch_embed(right_image);
-             if self.use_view_embed: x = x + self.view_embed[:,:,2:3,:];
-             x=x.flatten(2).permute(2,0,1);
-             if self.use_view_embed: x_g=x_g+self.view_embed[:,:,2,:];
-             x_g=x_g+self.global_embed[:,:,2:3]; x_g=x_g.permute(2,0,1); features.extend([x,x_g])

          if self.with_lidar:
-             x, x_g = self.lidar_patch_embed(lidar);
-             if self.use_view_embed: x = x + self.view_embed[:,:,4:5,:];
-             x = x + self.position_encoding(x)
-             x=x.flatten(2).permute(2,0,1);
-             if self.use_view_embed: x_g=x_g+self.view_embed[:,:,4,:];
-             x_g=x_g+self.global_embed[:,:,4:5]; x_g=x_g.permute(2,0,1); features.extend([x,x_g])
-
-         return torch.cat(features, 0)
-
-     def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
-         front_image = rgb
          if self.direct_concat:
              img_size = front_image.shape[-1]
-             front_image = torch.cat([front_image, F.interpolate(rgb_left,s:=(img_size,img_size)), F.interpolate(rgb_right,s), F.interpolate(rgb_center,s)], dim=1)
-
-         features = self.forward_features(front_image, rgb_left, rgb_right, lidar)
          bs = front_image.shape[0]
-
-         tgt = self.position_encoding(torch.ones((bs, 1, 20, 20), device=rgb.device)).flatten(2)
-         tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2).permute(2, 0, 1)
-
-         memory = self.encoder(features)
-         hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0].permute(1, 0, 2)
-
-         traffic_feature = hs[:, :400]
-         waypoints_feature = hs[:, 401:411]
-         is_junction_feature = hs[:, 400]
-
-         waypoints = self.waypoints_generator(waypoints_feature, target_point)
          is_junction = self.junction_pred_head(is_junction_feature)
-         traffic_light_state = self.traffic_light_pred_head(is_junction_feature)
-         stop_sign = self.stop_sign_head(is_junction_feature)
-         traffic = self.traffic_pred_head(traffic_feature)
-
          return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature

@@ -1,66 +1,121 @@
  import math
  import copy
  import logging
+ import sys
  from collections import OrderedDict
  from functools import partial
  from typing import Optional, List

+ import numpy as np
  import torch
  from torch import nn, Tensor
  import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, Dataset
+
+ # import wandb # Import wandb

+ # Add InterFuser to Python path
+ sys.path.append('/content/InterFuser')

+ # --- W&B Login ---
+ # You might need to provide your API key when running this in Colab
+ # try:
+ #     wandb.login()
+ # except Exception as e:
+ #     print(f"Wandb login failed. Please ensure you have provided your API key. Error: {e}")
+
+ # Import specific modules from the cloned repository (adjust paths if needed)
  try:
+     # Assuming the structure within the cloned repo is InterFuser/interfuser/...
+     from InterFuser.interfuser.timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple
+     from InterFuser.interfuser.timm.models.registry import register_model
+     # Note: The original code seemed to have local imports like '.resnet',
+     # these need to be adjusted based on the actual file structure after cloning.
+     # Using the direct import path assuming it's available after appending '/content'
+     from InterFuser.interfuser.timm.models.resnet import resnet26d, resnet50d, resnet18d, resnet26, resnet50, resnet101d
+ except ImportError as e:
+     print(f"Error importing from InterFuser repository: {e}")
+     print("Please ensure the repository structure is correct and accessible.")
+     import torch
+     from torch.utils.data import Dataset, DataLoader
+     import numpy as np
+     import cv2
+     import json
+     from pathlib import Path
+     from torchvision import transforms
+     import os
+     import tqdm


+
+ _logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)  # Show logs, including warnings
+

  class HybridEmbed(nn.Module):
+     def __init__(
+         self,
+         backbone,
+         img_size=224,
+         patch_size=1,
+         feature_size=None,
+         in_chans=3,
+         embed_dim=768,
+     ):
          super().__init__()
          assert isinstance(backbone, nn.Module)
          img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
          self.img_size = img_size
+         self.patch_size = patch_size
          self.backbone = backbone
          if feature_size is None:
              with torch.no_grad():
                  training = backbone.training
+                 if training:
+                     backbone.eval()
                  o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+                 if isinstance(o, (list, tuple)):
+                     o = o[-1]  # last feature if backbone outputs list/tuple of features
+                 feature_size = o.shape[-2:]
                  feature_dim = o.shape[1]
                  backbone.train(training)
          else:
+             feature_size = to_2tuple(feature_size)
+             if hasattr(self.backbone, "feature_info"):
+                 feature_dim = self.backbone.feature_info.channels()[-1]
+             else:
+                 feature_dim = self.backbone.num_features
+
          self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

      def forward(self, x):
          x = self.backbone(x)
+         if isinstance(x, (list, tuple)):
+             x = x[-1]  # last feature if backbone outputs list/tuple of features
          x = self.proj(x)
          global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
          return x, global_x

+
  class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+
+     def __init__(
+         self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+     ):
          super().__init__()
          self.num_pos_feats = num_pos_feats
          self.temperature = temperature
          self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
          self.scale = scale

      def forward(self, tensor):
@@ -73,35 +128,328 @@ class PositionEmbeddingSine(nn.Module):
          eps = 1e-6
          y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
          x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
          dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
          dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
          pos_x = x_embed[:, :, :, None] / dim_t
          pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos_y = torch.stack(
+             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
          pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
          return pos

+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(
+         self,
+         src,
+         mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         output = src
+
+         for layer in self.layers:
+             output = layer(
+                 output,
+                 src_mask=mask,
+                 src_key_padding_mask=src_key_padding_mask,
+                 pos=pos,
+             )
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
+
+
177
+ class SpatialSoftmax(nn.Module):
178
+ def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
179
+ super().__init__()
180
+
181
+ self.data_format = data_format
182
+ self.height = height
183
+ self.width = width
184
+ self.channel = channel
185
+
186
+ if temperature:
187
+ self.temperature = Parameter(torch.ones(1) * temperature)
188
+ else:
189
+ self.temperature = 1.0
190
+
191
+ pos_x, pos_y = np.meshgrid(
192
+ np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
193
+ )
194
+ pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
195
+ pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
196
+ self.register_buffer("pos_x", pos_x)
197
+ self.register_buffer("pos_y", pos_y)
198
+
199
+ def forward(self, feature):
200
+ # Output:
201
+ # (N, C*2) x_0 y_0 ...
202
+
203
+ if self.data_format == "NHWC":
204
+ feature = (
205
+ feature.transpose(1, 3)
206
+ .tranpose(2, 3)
207
+ .view(-1, self.height * self.width)
208
+ )
209
+ else:
210
+ feature = feature.view(-1, self.height * self.width)
211
+
212
+ weight = F.softmax(feature / self.temperature, dim=-1)
213
+ expected_x = torch.sum(
214
+ torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True
215
+ )
216
+ expected_y = torch.sum(
217
+ torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True
218
+ )
219
+ expected_xy = torch.cat([expected_x, expected_y], 1)
220
+ feature_keypoints = expected_xy.view(-1, self.channel, 2)
221
+ feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
222
+ feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
223
+ return feature_keypoints
224
+
225
+
226
+ class MultiPath_Generator(nn.Module):
227
+ def __init__(self, in_channel, embed_dim, out_channel):
228
+ super().__init__()
229
+ self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
230
+ self.tconv0 = nn.Sequential(
231
+ nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
232
+ nn.BatchNorm2d(256),
233
+ nn.ReLU(True),
234
+ )
235
+ self.tconv1 = nn.Sequential(
236
+ nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
237
+ nn.BatchNorm2d(256),
238
+ nn.ReLU(True),
239
+ )
240
+ self.tconv2 = nn.Sequential(
241
+ nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
242
+ nn.BatchNorm2d(192),
243
+ nn.ReLU(True),
244
+ )
245
+ self.tconv3 = nn.Sequential(
246
+ nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
247
+ nn.BatchNorm2d(64),
248
+ nn.ReLU(True),
249
+ )
250
+ self.tconv4_list = torch.nn.ModuleList(
251
+ [
252
+ nn.Sequential(
253
+ nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
254
+ nn.Tanh(),
255
+ )
256
+ for _ in range(6)
257
+ ]
258
+ )
259
+
260
+ self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
261
+
262
+ def forward(self, x, measurements):
263
+ mask = measurements[:, :6]
264
+ mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
265
+ velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
266
+ velocity = velocity.repeat(1, 32, 2, 2)
267
+
268
+ n, d, c = x.shape
269
+ x = x.transpose(1, 2)
270
+ x = x.view(n, -1, 2, 2)
271
+ x = torch.cat([x, velocity], dim=1)
272
+ x = self.tconv0(x)
273
+ x = self.tconv1(x)
274
+ x = self.tconv2(x)
275
+ x = self.tconv3(x)
276
+ x = self.upsample(x)
277
+ xs = []
278
+ for i in range(6):
279
+ xt = self.tconv4_list[i](x)
280
+ xs.append(xt)
281
+ xs = torch.stack(xs, dim=1)
282
+ x = torch.sum(xs * mask, dim=1)
283
+ x = self.spatial_softmax(x)
284
+ return x
285
+
286
+
287
+ class LinearWaypointsPredictor(nn.Module):
288
+ def __init__(self, input_dim, cumsum=True):
289
+ super().__init__()
290
+ self.cumsum = cumsum
291
+ self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
292
+ self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
293
+ self.head_relu = nn.ReLU(inplace=True)
294
+ self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
295
+
296
+ def forward(self, x, measurements):
297
+ # input shape: n 10 embed_dim
298
+ bs, n, dim = x.shape
299
+ x = x + self.rank_embed
300
+ x = x.reshape(-1, dim)
301
+
302
+ mask = measurements[:, :6]
303
+ mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
304
+
305
+ rs = []
306
+ for i in range(6):
307
+ res = self.head_fc1_list[i](x)
308
+ res = self.head_relu(res)
309
+ res = self.head_fc2_list[i](res)
310
+ rs.append(res)
311
+ rs = torch.stack(rs, 1)
312
+ x = torch.sum(rs * mask, dim=1)
313
+
314
+ x = x.view(bs, n, 2)
315
+ if self.cumsum:
316
+ x = torch.cumsum(x, 1)
317
+ return x
318
+
319
+
320
+ class GRUWaypointsPredictor(nn.Module):
321
+ def __init__(self, input_dim, waypoints=10):
322
+ super().__init__()
323
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
324
+ self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
325
+ self.encoder = nn.Linear(2, 64)
326
+ self.decoder = nn.Linear(64, 2)
327
+ self.waypoints = waypoints
328
+
329
+ def forward(self, x, target_point):
330
+ bs = x.shape[0]
331
+ z = self.encoder(target_point).unsqueeze(0)
332
+ output, _ = self.gru(x, z)
333
+ output = output.reshape(bs * self.waypoints, -1)
334
+ output = self.decoder(output).reshape(bs, self.waypoints, 2)
335
+ output = torch.cumsum(output, 1)
336
+ return output
337
+
338
+ class GRUWaypointsPredictorWithCommand(nn.Module):
339
+ def __init__(self, input_dim, waypoints=10):
340
+ super().__init__()
341
+ # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
342
+ self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
343
+ self.encoder = nn.Linear(2, 64)
344
+ self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
345
+ self.waypoints = waypoints
346
+
347
+ def forward(self, x, target_point, measurements):
348
+ bs, n, dim = x.shape
349
+ mask = measurements[:, :6, None, None]
350
+ mask = mask.repeat(1, 1, self.waypoints, 2)
351
+
352
+ z = self.encoder(target_point).unsqueeze(0)
353
+ outputs = []
354
+ for i in range(6):
355
+ output, _ = self.grus[i](x, z)
356
+ output = output.reshape(bs * self.waypoints, -1)
357
+ output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
358
+ output = torch.cumsum(output, 1)
359
+ outputs.append(output)
360
+ outputs = torch.stack(outputs, 1)
361
+ output = torch.sum(outputs * mask, dim=1)
362
+ return output
363
+
364
+
365
+ class TransformerDecoder(nn.Module):
366
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
367
+ super().__init__()
368
+ self.layers = _get_clones(decoder_layer, num_layers)
369
+ self.num_layers = num_layers
370
+ self.norm = norm
371
+ self.return_intermediate = return_intermediate
372
+
373
+ def forward(
374
+ self,
375
+ tgt,
376
+ memory,
377
+ tgt_mask: Optional[Tensor] = None,
378
+ memory_mask: Optional[Tensor] = None,
379
+ tgt_key_padding_mask: Optional[Tensor] = None,
380
+ memory_key_padding_mask: Optional[Tensor] = None,
381
+ pos: Optional[Tensor] = None,
382
+ query_pos: Optional[Tensor] = None,
383
+ ):
384
+ output = tgt
385
+
386
+ intermediate = []
387
+
388
+ for layer in self.layers:
389
+ output = layer(
390
+ output,
391
+ memory,
392
+ tgt_mask=tgt_mask,
393
+ memory_mask=memory_mask,
394
+ tgt_key_padding_mask=tgt_key_padding_mask,
395
+ memory_key_padding_mask=memory_key_padding_mask,
396
+ pos=pos,
397
+ query_pos=query_pos,
398
+ )
399
+ if self.return_intermediate:
400
+ intermediate.append(self.norm(output))
401
+
402
+ if self.norm is not None:
403
+ output = self.norm(output)
404
+ if self.return_intermediate:
405
+ intermediate.pop()
406
+ intermediate.append(output)
407
+
408
+ if self.return_intermediate:
409
+ return torch.stack(intermediate)
410
+
411
+ return output.unsqueeze(0)
412
+
413
+
  class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation=nn.ReLU(),
+         normalize_before=False,
+     ):
          super().__init__()
          self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
          self.linear1 = nn.Linear(d_model, dim_feedforward)
          self.dropout = nn.Dropout(dropout)
          self.linear2 = nn.Linear(dim_feedforward, d_model)
+
          self.norm1 = nn.LayerNorm(d_model)
          self.norm2 = nn.LayerNorm(d_model)
          self.dropout1 = nn.Dropout(dropout)
          self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = activation()
+         self.normalize_before = normalize_before

      def with_pos_embed(self, tensor, pos: Optional[Tensor]):
          return tensor if pos is None else tensor + pos

+     def forward_post(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
          q = k = self.with_pos_embed(src, pos)
+         src2 = self.self_attn(
+             q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
          src = src + self.dropout1(src2)
          src = self.norm1(src)
          src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
@@ -109,47 +457,91 @@ class TransformerEncoderLayer(nn.Module):
          src = self.norm2(src)
          return src

+     def forward_pre(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         src2 = self.norm1(src)
+         q = k = self.with_pos_embed(src2, pos)
+         src2 = self.self_attn(
+             q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src2 = self.norm2(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+         src = src + self.dropout2(src2)
+         return src
+
+     def forward(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+         return self.forward_post(src, src_mask, src_key_padding_mask, pos)

  class TransformerDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation=nn.ReLU(),
+         normalize_before=False,
+     ):
          super().__init__()
          self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
          self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
          self.linear1 = nn.Linear(d_model, dim_feedforward)
          self.dropout = nn.Dropout(dropout)
          self.linear2 = nn.Linear(dim_feedforward, d_model)
+
          self.norm1 = nn.LayerNorm(d_model)
          self.norm2 = nn.LayerNorm(d_model)
          self.norm3 = nn.LayerNorm(d_model)
          self.dropout1 = nn.Dropout(dropout)
          self.dropout2 = nn.Dropout(dropout)
          self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = activation()
+         self.normalize_before = normalize_before

      def with_pos_embed(self, tensor, pos: Optional[Tensor]):
          return tensor if pos is None else tensor + pos

+     def forward_post(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
          q = k = self.with_pos_embed(tgt, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
          tgt = tgt + self.dropout1(tgt2)
          tgt = self.norm1(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
          tgt = tgt + self.dropout2(tgt2)
          tgt = self.norm2(tgt)
          tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
@@ -157,202 +549,551 @@ class TransformerDecoderLayer(nn.Module):
          tgt = self.norm3(tgt)
          return tgt

+     def forward_pre(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt2 = self.norm2(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt2, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt

+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(
+                 tgt,
+                 memory,
+                 tgt_mask,
+                 memory_mask,
+                 tgt_key_padding_mask,
+                 memory_key_padding_mask,
+                 pos,
+                 query_pos,
+             )
+         return self.forward_post(
+             tgt,
+             memory,
+             tgt_mask,
+             memory_mask,
+             tgt_key_padding_mask,
+             memory_key_padding_mask,
+             pos,
+             query_pos,
+         )


+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+

+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


+ def build_attn_mask(mask_type):
+     mask = torch.ones((151, 151), dtype=torch.bool).cuda()
+     if mask_type == "seperate_all":
+         mask[:50, :50] = False
+         mask[50:67, 50:67] = False
+         mask[67:84, 67:84] = False
+         mask[84:101, 84:101] = False
+         mask[101:151, 101:151] = False
+     elif mask_type == "seperate_view":
+         mask[:50, :50] = False
+         mask[50:67, 50:67] = False
+         mask[67:84, 67:84] = False
+         mask[84:101, 84:101] = False
+         mask[101:151, :] = False
+         mask[:, 101:151] = False
+     return mask
+
+ class Interfuser(nn.Module):
      def __init__(
          self,
+         img_size=224,
+         multi_view_img_size=112,
+         patch_size=8,
+         in_chans=3,
+         embed_dim=768,
          enc_depth=6,
          dec_depth=6,
          dim_feedforward=2048,
+         normalize_before=False,
+         rgb_backbone_name="r26",
+         lidar_backbone_name="r26",
+         num_heads=8,
+         norm_layer=None,
          dropout=0.1,
+         end2end=False,
          direct_concat=True,
+         separate_view_attention=False,
+         separate_all_attention=False,
+         act_layer=None,
+         weight_init="",
+         freeze_num=-1,
+         with_lidar=False,
          with_right_left_sensors=True,
+         with_center_sensor=False,
+         traffic_pred_head_type="det",
+         waypoints_pred_head="heatmap",
+         reverse_pos=True,
+         use_different_backbone=False,
          use_view_embed=True,
+         use_mmad_pretrain=None,
      ):
+         super().__init__()
+         self.traffic_pred_head_type = traffic_pred_head_type
+         self.num_features = (
+             self.embed_dim
+         ) = embed_dim  # num_features for consistency with other models
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+         act_layer = act_layer or nn.GELU
+
+         self.reverse_pos = reverse_pos
          self.waypoints_pred_head = waypoints_pred_head
          self.with_lidar = with_lidar
          self.with_right_left_sensors = with_right_left_sensors
+         self.with_center_sensor = with_center_sensor
+
+         self.direct_concat = direct_concat
+         self.separate_view_attention = separate_view_attention
+         self.separate_all_attention = separate_all_attention
+         self.end2end = end2end
          self.use_view_embed = use_view_embed
+
+         if self.direct_concat:
+             in_chans = in_chans * 4
+             self.with_center_sensor = False
+             self.with_right_left_sensors = False
+
+         if self.separate_view_attention:
+             self.attn_mask = build_attn_mask("seperate_view")
+         elif self.separate_all_attention:
+             self.attn_mask = build_attn_mask("seperate_all")
+         else:
+             self.attn_mask = None
+
+         if use_different_backbone:
+             if rgb_backbone_name == "r50":
+                 self.rgb_backbone = resnet50d(
+                     pretrained=True,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif rgb_backbone_name == "r26":
+                 self.rgb_backbone = resnet26d(
+                     pretrained=True,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif rgb_backbone_name == "r18":
+                 self.rgb_backbone = resnet18d(
+                     pretrained=True,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             if lidar_backbone_name == "r50":
+                 self.lidar_backbone = resnet50d(
+                     pretrained=False,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif lidar_backbone_name == "r26":
+                 self.lidar_backbone = resnet26d(
+                     pretrained=False,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif lidar_backbone_name == "r18":
+                 self.lidar_backbone = resnet18d(
+                     pretrained=False, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
+             lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
+
+             if use_mmad_pretrain:
+                 params = torch.load(use_mmad_pretrain)["state_dict"]
+                 updated_params = OrderedDict()
+                 for key in params:
+                     if "backbone" in key:
+                         updated_params[key.replace("backbone.", "")] = params[key]
+                 self.rgb_backbone.load_state_dict(updated_params)
+
+             self.rgb_patch_embed = rgb_embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=in_chans,
+                 embed_dim=embed_dim,
+             )
+             self.lidar_patch_embed = lidar_embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=3,
+                 embed_dim=embed_dim,
+             )
+         else:
+             if rgb_backbone_name == "r50":
+                 self.rgb_backbone = resnet50d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             elif rgb_backbone_name == "r101":
+                 self.rgb_backbone = resnet101d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             elif rgb_backbone_name == "r26":
+                 self.rgb_backbone = resnet26d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             elif rgb_backbone_name == "r18":
+                 self.rgb_backbone = resnet18d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
+
+             self.rgb_patch_embed = embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=in_chans,
+                 embed_dim=embed_dim,
+             )
+             self.lidar_patch_embed = embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=in_chans,
+                 embed_dim=embed_dim,
+             )
+
          self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
          self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
+
+         if self.end2end:
+             self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
+             self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
+         elif self.waypoints_pred_head == "heatmap":
+             self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
+             self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
+         else:
+             self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
+             self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
+
+         if self.end2end:
+             self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
+         elif self.waypoints_pred_head == "heatmap":
+             self.waypoints_generator = MultiPath_Generator(
+                 embed_dim + 32, embed_dim, 10
+             )
+         elif self.waypoints_pred_head == "gru":
+             self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
+         elif self.waypoints_pred_head == "gru-command":
+             self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
+         elif self.waypoints_pred_head == "linear":
+             self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
+         elif self.waypoints_pred_head == "linear-sum":
+             self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
+
          self.junction_pred_head = nn.Linear(embed_dim, 2)
          self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
          self.stop_sign_head = nn.Linear(embed_dim, 2)
+
+         if self.traffic_pred_head_type == "det":
+             self.traffic_pred_head = nn.Sequential(
+                 *[
+                     nn.Linear(embed_dim + 32, 64),
+                     nn.ReLU(),
+                     nn.Linear(64, 7),
+                     nn.Sigmoid(),
+                 ]
+             )
+         elif self.traffic_pred_head_type == "seg":
+             self.traffic_pred_head = nn.Sequential(
+                 *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
+             )
+
          self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
+
+         encoder_layer = TransformerEncoderLayer(
+             embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
+         )
          self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)

+         decoder_layer = TransformerDecoderLayer(
+             embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
+         )
+         decoder_norm = nn.LayerNorm(embed_dim)
+         self.decoder = TransformerDecoder(
+             decoder_layer, dec_depth, decoder_norm, return_intermediate=False
+         )
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.global_embed)
+         nn.init.uniform_(self.view_embed)
+         nn.init.uniform_(self.query_embed)
+         nn.init.uniform_(self.query_pos_embed)
+
+
878
+ def forward_features(
879
+ self,
880
+ front_image,
881
+ left_image,
882
+ right_image,
883
+ front_center_image,
884
+ lidar,
885
+ measurements,
886
+ ):
887
  features = []
888
+
889
+ # Front view processing
890
+ front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
891
+ if self.use_view_embed:
892
+ front_image_token = (
893
+ front_image_token
894
+ + self.view_embed[:, :, 0:1, :]
895
+ + self.position_encoding(front_image_token)
896
+ )
897
+ else:
898
+ front_image_token = front_image_token + self.position_encoding(
899
+ front_image_token
900
+ )
901
+ front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
902
+ front_image_token_global = (
903
+ front_image_token_global
904
+ + self.view_embed[:, :, 0, :]
905
+ + self.global_embed[:, :, 0:1]
906
+ )
907
+ front_image_token_global = front_image_token_global.permute(2, 0, 1)
908
+ features.extend([front_image_token, front_image_token_global])
909
+
910
+ if self.with_right_left_sensors:
911
+ # Left view processing
912
+ left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
913
+ if self.use_view_embed:
914
+ left_image_token = (
915
+ left_image_token
916
+ + self.view_embed[:, :, 1:2, :]
917
+ + self.position_encoding(left_image_token)
918
+ )
919
+ else:
920
+ left_image_token = left_image_token + self.position_encoding(
921
+ left_image_token
922
+ )
923
+ left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
924
+ left_image_token_global = (
925
+ left_image_token_global
926
+ + self.view_embed[:, :, 1, :]
927
+ + self.global_embed[:, :, 1:2]
928
+ )
929
+ left_image_token_global = left_image_token_global.permute(2, 0, 1)
930
+
931
+ # Right view processing
932
+ right_image_token, right_image_token_global = self.rgb_patch_embed(
933
+ right_image
934
+ )
935
+ if self.use_view_embed:
936
+ right_image_token = (
937
+ right_image_token
938
+ + self.view_embed[:, :, 2:3, :]
939
+ + self.position_encoding(right_image_token)
940
+ )
941
+ else:
942
+ right_image_token = right_image_token + self.position_encoding(
943
+ right_image_token
944
+ )
945
+ right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
946
+ right_image_token_global = (
947
+ right_image_token_global
948
+ + self.view_embed[:, :, 2, :]
949
+ + self.global_embed[:, :, 2:3]
950
+ )
951
+ right_image_token_global = right_image_token_global.permute(2, 0, 1)
952
+
953
+ features.extend(
954
+ [
955
+ left_image_token,
956
+ left_image_token_global,
957
+ right_image_token,
958
+ right_image_token_global,
959
+ ]
960
+ )
961
+
962
+ if self.with_center_sensor:
963
+ # Front center view processing
964
+ (
965
+ front_center_image_token,
966
+ front_center_image_token_global,
967
+ ) = self.rgb_patch_embed(front_center_image)
968
+ if self.use_view_embed:
969
+ front_center_image_token = (
970
+ front_center_image_token
971
+ + self.view_embed[:, :, 3:4, :]
972
+ + self.position_encoding(front_center_image_token)
973
+ )
974
+ else:
975
+ front_center_image_token = (
976
+ front_center_image_token
977
+ + self.position_encoding(front_center_image_token)
978
+ )
979
+
980
+ front_center_image_token = front_center_image_token.flatten(2).permute(
981
+ 2, 0, 1
982
+ )
983
+ front_center_image_token_global = (
984
+ front_center_image_token_global
985
+ + self.view_embed[:, :, 3, :]
986
+ + self.global_embed[:, :, 3:4]
987
+ )
988
+ front_center_image_token_global = front_center_image_token_global.permute(
989
+ 2, 0, 1
990
+ )
991
+ features.extend([front_center_image_token, front_center_image_token_global])
992
 
993
  if self.with_lidar:
994
+ lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
995
+ if self.use_view_embed:
996
+ lidar_token = (
997
+ lidar_token
998
+ + self.view_embed[:, :, 4:5, :]
999
+ + self.position_encoding(lidar_token)
1000
+ )
1001
+ else:
1002
+ lidar_token = lidar_token + self.position_encoding(lidar_token)
1003
+ lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
1004
+ lidar_token_global = (
1005
+ lidar_token_global
1006
+ + self.view_embed[:, :, 4, :]
1007
+ + self.global_embed[:, :, 4:5]
1008
+ )
1009
+ lidar_token_global = lidar_token_global.permute(2, 0, 1)
1010
+ features.extend([lidar_token, lidar_token_global])
1011
+
1012
+ features = torch.cat(features, 0)
1013
+ return features
1014
+
+     def forward(self, x):
+         front_image = x["rgb"]
+         left_image = x["rgb_left"]
+         right_image = x["rgb_right"]
+         front_center_image = x["rgb_center"]
+         measurements = x["measurements"]
+         target_point = x["target_point"]
+         lidar = x["lidar"]
+
          if self.direct_concat:
              img_size = front_image.shape[-1]
+             left_image = torch.nn.functional.interpolate(
+                 left_image, size=(img_size, img_size)
+             )
+             right_image = torch.nn.functional.interpolate(
+                 right_image, size=(img_size, img_size)
+             )
+             front_center_image = torch.nn.functional.interpolate(
+                 front_center_image, size=(img_size, img_size)
+             )
+             front_image = torch.cat(
+                 [front_image, left_image, right_image, front_center_image], dim=1
+             )
+         features = self.forward_features(
+             front_image,
+             left_image,
+             right_image,
+             front_center_image,
+             lidar,
+             measurements,
+         )
+
          bs = front_image.shape[0]
+
+         if self.end2end:
+             tgt = self.query_pos_embed.repeat(bs, 1, 1)
+         else:
+             tgt = self.position_encoding(
+                 torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
+             )
+             tgt = tgt.flatten(2)
+             tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
+         tgt = tgt.permute(2, 0, 1)
+
+         memory = self.encoder(features, mask=self.attn_mask)
+         hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
+
+         hs = hs.permute(1, 0, 2)  # Batchsize, N, C
+         if self.end2end:
+             waypoints = self.waypoints_generator(hs, target_point)
+             return waypoints
+
+         if self.waypoints_pred_head != "heatmap":
+             traffic_feature = hs[:, :400]
+             is_junction_feature = hs[:, 400]
+             traffic_light_state_feature = hs[:, 400]
+             stop_sign_feature = hs[:, 400]
+             waypoints_feature = hs[:, 401:411]
+         else:
+             traffic_feature = hs[:, :400]
+             is_junction_feature = hs[:, 400]
+             traffic_light_state_feature = hs[:, 400]
+             stop_sign_feature = hs[:, 400]
+             waypoints_feature = hs[:, 401:405]
+
+         if self.waypoints_pred_head == "heatmap":
+             waypoints = self.waypoints_generator(waypoints_feature, measurements)
+         elif self.waypoints_pred_head == "gru":
+             waypoints = self.waypoints_generator(waypoints_feature, target_point)
+         elif self.waypoints_pred_head == "gru-command":
+             waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
+         elif self.waypoints_pred_head == "linear":
+             waypoints = self.waypoints_generator(waypoints_feature, measurements)
+         elif self.waypoints_pred_head == "linear-sum":
+             waypoints = self.waypoints_generator(waypoints_feature, measurements)
+
          is_junction = self.junction_pred_head(is_junction_feature)
+         traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
+         stop_sign = self.stop_sign_head(stop_sign_feature)
+
+         velocity = measurements[:, 6:7].unsqueeze(-1)
+         velocity = velocity.repeat(1, 400, 32)
+         traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
+         traffic = self.traffic_pred_head(traffic_feature_with_vel)
          return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
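
For reference, here is a minimal smoke-test sketch of the new Interfuser module; it is not part of the commit. It assumes the timm-based resnet imports at the top of model.py resolve, picks the "gru" waypoint head with use_different_backbone=True and direct_concat=False so both backbones accept 3-channel inputs, and feeds random tensors shaped like the dictionary keys read in forward (rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point). The input resolutions and the 10-dimensional measurements vector are assumptions, and the RGB backbone will fetch ImageNet weights through timm.

# Hedged usage sketch (not part of the commit): build Interfuser with the "gru"
# waypoint head and run a single forward pass on random inputs.
import torch

model = Interfuser(
    embed_dim=256,
    num_heads=8,
    enc_depth=2,
    dec_depth=2,
    rgb_backbone_name="r26",      # resnet26d, downloads ImageNet weights via timm
    lidar_backbone_name="r18",
    use_different_backbone=True,
    direct_concat=False,          # keep 3-channel inputs for the default backbones
    with_lidar=False,
    waypoints_pred_head="gru",
).eval()

batch = {
    "rgb": torch.randn(1, 3, 224, 224),
    "rgb_left": torch.randn(1, 3, 112, 112),
    "rgb_right": torch.randn(1, 3, 112, 112),
    "rgb_center": torch.randn(1, 3, 112, 112),  # unused unless with_center_sensor=True
    "lidar": torch.randn(1, 3, 112, 112),       # unused unless with_lidar=True
    "measurements": torch.zeros(1, 10),          # only [:, :6] (command mask) and [:, 6:7] (speed) are read here
    "target_point": torch.zeros(1, 2),
}

with torch.no_grad():
    traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature = model(batch)
print(traffic.shape, waypoints.shape)  # expected: (1, 400, 7) and (1, 10, 2)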