mohammed-aljafry committed on
Commit
ce78280
·
verified ·
1 Parent(s): d7bf34a

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +72 -433
model.py CHANGED
@@ -619,456 +619,95 @@ def build_attn_mask(mask_type):
619
  mask[:, 101:151] = False
620
  return mask
621
 
622
- class Interfuser(nn.Module):
 
 
623
  def __init__(
624
  self,
625
- img_size=224,
626
- multi_view_img_size=112,
627
- patch_size=8,
628
- in_chans=3,
629
- embed_dim=768,
630
  enc_depth=6,
631
  dec_depth=6,
632
- dim_feedforward=2048,
633
- normalize_before=False,
634
- rgb_backbone_name="r26",
635
- lidar_backbone_name="r26",
636
  num_heads=8,
637
- norm_layer=None,
638
  dropout=0.1,
639
- end2end=False,
640
- direct_concat=True,
641
- separate_view_attention=False,
642
- separate_all_attention=False,
643
- act_layer=None,
644
- weight_init="",
645
- freeze_num=-1,
646
- with_lidar=False,
647
- with_right_left_sensors=True,
648
- with_center_sensor=False,
649
- traffic_pred_head_type="det",
650
- waypoints_pred_head="heatmap",
651
- reverse_pos=True,
652
- use_different_backbone=False,
653
- use_view_embed=True,
654
- use_mmad_pretrain=None,
655
  ):
656
- super().__init__()
657
- self.traffic_pred_head_type = traffic_pred_head_type
658
- self.num_features = (
659
- self.embed_dim
660
- ) = embed_dim # num_features for consistency with other models
661
- norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
662
- act_layer = act_layer or nn.GELU
663
-
664
- self.reverse_pos = reverse_pos
665
  self.waypoints_pred_head = waypoints_pred_head
666
- self.with_lidar = with_lidar
667
- self.with_right_left_sensors = with_right_left_sensors
668
- self.with_center_sensor = with_center_sensor
669
-
670
- self.direct_concat = direct_concat
671
- self.separate_view_attention = separate_view_attention
672
- self.separate_all_attention = separate_all_attention
673
- self.end2end = end2end
674
- self.use_view_embed = use_view_embed
675
-
676
- if self.direct_concat:
677
- in_chans = in_chans * 4
678
- self.with_center_sensor = False
679
- self.with_right_left_sensors = False
680
-
681
- if self.separate_view_attention:
682
- self.attn_mask = build_attn_mask("seperate_view")
683
- elif self.separate_all_attention:
684
- self.attn_mask = build_attn_mask("seperate_all")
685
- else:
686
- self.attn_mask = None
687
-
688
  if use_different_backbone:
689
- if rgb_backbone_name == "r50":
690
- self.rgb_backbone = resnet50d(
691
- pretrained=True,
692
- in_chans=in_chans,
693
- features_only=True,
694
- out_indices=[4],
695
- )
696
- elif rgb_backbone_name == "r26":
697
- self.rgb_backbone = resnet26d(
698
- pretrained=True,
699
- in_chans=in_chans,
700
- features_only=True,
701
- out_indices=[4],
702
- )
703
- elif rgb_backbone_name == "r18":
704
- self.rgb_backbone = resnet18d(
705
- pretrained=True,
706
- in_chans=in_chans,
707
- features_only=True,
708
- out_indices=[4],
709
- )
710
- if lidar_backbone_name == "r50":
711
- self.lidar_backbone = resnet50d(
712
- pretrained=False,
713
- in_chans=in_chans,
714
- features_only=True,
715
- out_indices=[4],
716
- )
717
- elif lidar_backbone_name == "r26":
718
- self.lidar_backbone = resnet26d(
719
- pretrained=False,
720
- in_chans=in_chans,
721
- features_only=True,
722
- out_indices=[4],
723
- )
724
- elif lidar_backbone_name == "r18":
725
- self.lidar_backbone = resnet18d(
726
- pretrained=False, in_chans=3, features_only=True, out_indices=[4]
727
- )
728
- rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
729
- lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
730
-
731
- if use_mmad_pretrain:
732
- params = torch.load(use_mmad_pretrain)["state_dict"]
733
- updated_params = OrderedDict()
734
- for key in params:
735
- if "backbone" in key:
736
- updated_params[key.replace("backbone.", "")] = params[key]
737
- self.rgb_backbone.load_state_dict(updated_params)
738
-
739
- self.rgb_patch_embed = rgb_embed_layer(
740
- img_size=img_size,
741
- patch_size=patch_size,
742
- in_chans=in_chans,
743
- embed_dim=embed_dim,
744
- )
745
- self.lidar_patch_embed = lidar_embed_layer(
746
- img_size=img_size,
747
- patch_size=patch_size,
748
- in_chans=3,
749
- embed_dim=embed_dim,
750
- )
751
- else:
752
- if rgb_backbone_name == "r50":
753
- self.rgb_backbone = resnet50d(
754
- pretrained=True, in_chans=3, features_only=True, out_indices=[4]
755
- )
756
- elif rgb_backbone_name == "r101":
757
- self.rgb_backbone = resnet101d(
758
- pretrained=True, in_chans=3, features_only=True, out_indices=[4]
759
- )
760
- elif rgb_backbone_name == "r26":
761
- self.rgb_backbone = resnet26d(
762
- pretrained=True, in_chans=3, features_only=True, out_indices=[4]
763
- )
764
- elif rgb_backbone_name == "r18":
765
- self.rgb_backbone = resnet18d(
766
- pretrained=True, in_chans=3, features_only=True, out_indices=[4]
767
- )
768
- embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
769
-
770
- self.rgb_patch_embed = embed_layer(
771
- img_size=img_size,
772
- patch_size=patch_size,
773
- in_chans=in_chans,
774
- embed_dim=embed_dim,
775
- )
776
- self.lidar_patch_embed = embed_layer(
777
- img_size=img_size,
778
- patch_size=patch_size,
779
- in_chans=in_chans,
780
- embed_dim=embed_dim,
781
- )
782
-
783
- self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
784
- self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
785
-
786
- if self.end2end:
787
- self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
788
- self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
789
- elif self.waypoints_pred_head == "heatmap":
790
- self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
791
- self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
792
- else:
793
- self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
794
- self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
795
-
796
- if self.end2end:
797
- self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
798
- elif self.waypoints_pred_head == "heatmap":
799
- self.waypoints_generator = MultiPath_Generator(
800
- embed_dim + 32, embed_dim, 10
801
- )
802
- elif self.waypoints_pred_head == "gru":
803
- self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
804
- elif self.waypoints_pred_head == "gru-command":
805
- self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
806
- elif self.waypoints_pred_head == "linear":
807
- self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
808
- elif self.waypoints_pred_head == "linear-sum":
809
- self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
810
-
811
- self.junction_pred_head = nn.Linear(embed_dim, 2)
812
- self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
813
- self.stop_sign_head = nn.Linear(embed_dim, 2)
814
-
815
- if self.traffic_pred_head_type == "det":
816
- self.traffic_pred_head = nn.Sequential(
817
- *[
818
- nn.Linear(embed_dim + 32, 64),
819
- nn.ReLU(),
820
- nn.Linear(64, 7),
821
- nn.Sigmoid(),
822
- ]
823
- )
824
- elif self.traffic_pred_head_type == "seg":
825
- self.traffic_pred_head = nn.Sequential(
826
- *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
827
- )
828
-
829
  self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
830
-
831
- encoder_layer = TransformerEncoderLayer(
832
- embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
833
- )
834
  self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
 
 
835
 
836
- decoder_layer = TransformerDecoderLayer(
837
- embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
838
- )
839
- decoder_norm = nn.LayerNorm(embed_dim)
840
- self.decoder = TransformerDecoder(
841
- decoder_layer, dec_depth, decoder_norm, return_intermediate=False
842
- )
843
- self.reset_parameters()
844
-
845
- def reset_parameters(self):
846
- nn.init.uniform_(self.global_embed)
847
- nn.init.uniform_(self.view_embed)
848
- nn.init.uniform_(self.query_embed)
849
- nn.init.uniform_(self.query_pos_embed)
850
-
851
- def forward_features(
852
- self,
853
- front_image,
854
- left_image,
855
- right_image,
856
- front_center_image,
857
- lidar,
858
- measurements,
859
- ):
860
- features = []
861
-
862
- # Front view processing
863
- front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
864
- if self.use_view_embed:
865
- front_image_token = (
866
- front_image_token
867
- + self.view_embed[:, :, 0:1, :]
868
- + self.position_encoding(front_image_token)
869
- )
870
- else:
871
- front_image_token = front_image_token + self.position_encoding(
872
- front_image_token
873
- )
874
- front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
875
- front_image_token_global = (
876
- front_image_token_global
877
- + self.view_embed[:, :, 0, :]
878
- + self.global_embed[:, :, 0:1]
879
- )
880
- front_image_token_global = front_image_token_global.permute(2, 0, 1)
881
- features.extend([front_image_token, front_image_token_global])
882
-
883
  if self.with_right_left_sensors:
884
- # Left view processing
885
- left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
886
- if self.use_view_embed:
887
- left_image_token = (
888
- left_image_token
889
- + self.view_embed[:, :, 1:2, :]
890
- + self.position_encoding(left_image_token)
891
- )
892
- else:
893
- left_image_token = left_image_token + self.position_encoding(
894
- left_image_token
895
- )
896
- left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
897
- left_image_token_global = (
898
- left_image_token_global
899
- + self.view_embed[:, :, 1, :]
900
- + self.global_embed[:, :, 1:2]
901
- )
902
- left_image_token_global = left_image_token_global.permute(2, 0, 1)
903
-
904
- # Right view processing
905
- right_image_token, right_image_token_global = self.rgb_patch_embed(
906
- right_image
907
- )
908
- if self.use_view_embed:
909
- right_image_token = (
910
- right_image_token
911
- + self.view_embed[:, :, 2:3, :]
912
- + self.position_encoding(right_image_token)
913
- )
914
- else:
915
- right_image_token = right_image_token + self.position_encoding(
916
- right_image_token
917
- )
918
- right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
919
- right_image_token_global = (
920
- right_image_token_global
921
- + self.view_embed[:, :, 2, :]
922
- + self.global_embed[:, :, 2:3]
923
- )
924
- right_image_token_global = right_image_token_global.permute(2, 0, 1)
925
-
926
- features.extend(
927
- [
928
- left_image_token,
929
- left_image_token_global,
930
- right_image_token,
931
- right_image_token_global,
932
- ]
933
- )
934
-
935
- if self.with_center_sensor:
936
- # Front center view processing
937
- (
938
- front_center_image_token,
939
- front_center_image_token_global,
940
- ) = self.rgb_patch_embed(front_center_image)
941
- if self.use_view_embed:
942
- front_center_image_token = (
943
- front_center_image_token
944
- + self.view_embed[:, :, 3:4, :]
945
- + self.position_encoding(front_center_image_token)
946
- )
947
- else:
948
- front_center_image_token = (
949
- front_center_image_token
950
- + self.position_encoding(front_center_image_token)
951
- )
952
-
953
- front_center_image_token = front_center_image_token.flatten(2).permute(
954
- 2, 0, 1
955
- )
956
- front_center_image_token_global = (
957
- front_center_image_token_global
958
- + self.view_embed[:, :, 3, :]
959
- + self.global_embed[:, :, 3:4]
960
- )
961
- front_center_image_token_global = front_center_image_token_global.permute(
962
- 2, 0, 1
963
- )
964
- features.extend([front_center_image_token, front_center_image_token_global])
965
-
966
  if self.with_lidar:
967
- lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
968
- if self.use_view_embed:
969
- lidar_token = (
970
- lidar_token
971
- + self.view_embed[:, :, 4:5, :]
972
- + self.position_encoding(lidar_token)
973
- )
974
- else:
975
- lidar_token = lidar_token + self.position_encoding(lidar_token)
976
- lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
977
- lidar_token_global = (
978
- lidar_token_global
979
- + self.view_embed[:, :, 4, :]
980
- + self.global_embed[:, :, 4:5]
981
- )
982
- lidar_token_global = lidar_token_global.permute(2, 0, 1)
983
- features.extend([lidar_token, lidar_token_global])
984
-
985
- features = torch.cat(features, 0)
986
- return features
987
-
988
- def forward(self, x):
989
- front_image = x["rgb"]
990
- left_image = x["rgb_left"]
991
- right_image = x["rgb_right"]
992
- front_center_image = x["rgb_center"]
993
- measurements = x["measurements"]
994
- target_point = x["target_point"]
995
- lidar = x["lidar"]
996
-
997
  if self.direct_concat:
998
- img_size = front_image.shape[-1]
999
- left_image = torch.nn.functional.interpolate(
1000
- left_image, size=(img_size, img_size)
1001
- )
1002
- right_image = torch.nn.functional.interpolate(
1003
- right_image, size=(img_size, img_size)
1004
- )
1005
- front_center_image = torch.nn.functional.interpolate(
1006
- front_center_image, size=(img_size, img_size)
1007
- )
1008
- front_image = torch.cat(
1009
- [front_image, left_image, right_image, front_center_image], dim=1
1010
- )
1011
- features = self.forward_features(
1012
- front_image,
1013
- left_image,
1014
- right_image,
1015
- front_center_image,
1016
- lidar,
1017
- measurements,
1018
- )
1019
-
1020
  bs = front_image.shape[0]
1021
-
1022
- if self.end2end:
1023
- tgt = self.query_pos_embed.repeat(bs, 1, 1)
1024
- else:
1025
- tgt = self.position_encoding(
1026
- torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
1027
- )
1028
- tgt = tgt.flatten(2)
1029
- tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
1030
- tgt = tgt.permute(2, 0, 1)
1031
-
1032
- memory = self.encoder(features, mask=self.attn_mask)
1033
- hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
1034
-
1035
- hs = hs.permute(1, 0, 2) # Batchsize , N, C
1036
- if self.end2end:
1037
- waypoints = self.waypoints_generator(hs, target_point)
1038
- return waypoints
1039
-
1040
- if self.waypoints_pred_head != "heatmap":
1041
- traffic_feature = hs[:, :400]
1042
- is_junction_feature = hs[:, 400]
1043
- traffic_light_state_feature = hs[:, 400]
1044
- stop_sign_feature = hs[:, 400]
1045
- waypoints_feature = hs[:, 401:411]
1046
- else:
1047
- traffic_feature = hs[:, :400]
1048
- is_junction_feature = hs[:, 400]
1049
- traffic_light_state_feature = hs[:, 400]
1050
- stop_sign_feature = hs[:, 400]
1051
- waypoints_feature = hs[:, 401:405]
1052
-
1053
- if self.waypoints_pred_head == "heatmap":
1054
- waypoints = self.waypoints_generator(waypoints_feature, measurements)
1055
- elif self.waypoints_pred_head == "gru":
1056
- waypoints = self.waypoints_generator(waypoints_feature, target_point)
1057
- elif self.waypoints_pred_head == "gru-command":
1058
- waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
1059
- elif self.waypoints_pred_head == "linear":
1060
- waypoints = self.waypoints_generator(waypoints_feature, measurements)
1061
- elif self.waypoints_pred_head == "linear-sum":
1062
- waypoints = self.waypoints_generator(waypoints_feature, measurements)
1063
-
1064
  is_junction = self.junction_pred_head(is_junction_feature)
1065
- traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
1066
- stop_sign = self.stop_sign_head(stop_sign_feature)
1067
-
1068
- velocity = measurements[:, 6:7].unsqueeze(-1)
1069
- velocity = velocity.repeat(1, 400, 32)
1070
- traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1071
- traffic = self.traffic_pred_head(traffic_feature_with_vel)
1072
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
1073
 
1074
 
 
619
  mask[:, 101:151] = False
620
  return mask
621
 
622
class InterfuserConfig(PretrainedConfig):
    """Hyper-parameter container for :class:`InterfuserModel`.

    Every named argument is stored on the instance under the same name;
    any remaining keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        embed_dim: Width of the token embeddings.
        enc_depth: Depth value handed to the transformer encoder.
        dec_depth: Depth value handed to the transformer decoder.
        num_heads: Head count handed to the encoder/decoder layers.
        dim_feedforward: Feed-forward width handed to the
            encoder/decoder layers.
        dropout: Dropout value handed to the encoder/decoder layers.
        rgb_backbone_name: Key selecting the RGB backbone
            (``"r50"``, ``"r26"`` or ``"r18"``).
        lidar_backbone_name: Key selecting the LiDAR backbone
            (``"r50"``, ``"r26"`` or ``"r18"``).
        use_different_backbone: Whether RGB and LiDAR get separate
            backbones.
        waypoints_pred_head: Name of the waypoint head variant.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "interfuser"

    def __init__(
        self,
        embed_dim=256,
        enc_depth=6,
        dec_depth=6,
        num_heads=8,
        dim_feedforward=2048,
        dropout=0.1,
        rgb_backbone_name="r50",
        lidar_backbone_name="r18",
        use_different_backbone=True,
        waypoints_pred_head="gru",
        **kwargs,
    ):
        # Insertion order of this mapping mirrors the attribute order the
        # config has always been populated in.
        hyper_parameters = {
            "embed_dim": embed_dim,
            "enc_depth": enc_depth,
            "dec_depth": dec_depth,
            "num_heads": num_heads,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
            "rgb_backbone_name": rgb_backbone_name,
            "lidar_backbone_name": lidar_backbone_name,
            "use_different_backbone": use_different_backbone,
            "waypoints_pred_head": waypoints_pred_head,
        }
        for attr_name, attr_value in hyper_parameters.items():
            setattr(self, attr_name, attr_value)

        # The base class consumes the generic options last, once the
        # model-specific fields are already in place.
        super().__init__(**kwargs)
650
+
651
class InterfuserModel(PreTrainedModel):
    """InterFuser network packaged as a Hugging Face ``PreTrainedModel``.

    Camera and LiDAR inputs are tokenised by backbone-based patch
    embeddings, fused by a transformer encoder, and decoded with a fixed
    set of learned queries into traffic, waypoint, junction, traffic-light
    and stop-sign predictions.

    The sensor layout is fixed in code rather than read from the config:
    3-channel inputs, 224-pixel RGB embedding size, 112-pixel LiDAR
    embedding size, ``direct_concat`` enabled and LiDAR enabled.
    """

    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        """Build all sub-modules from ``config``.

        Args:
            config: Hyper-parameters for the network.

        Raises:
            ValueError: If ``config.use_different_backbone`` is false; no
                shared-backbone variant is implemented here, so the model
                would otherwise be left without patch embeddings.
            KeyError: If a backbone name is not ``"r50"``, ``"r26"`` or
                ``"r18"``.
        """
        super().__init__(config)
        self.config = config

        embed_dim = config.embed_dim
        num_heads = config.num_heads
        dim_feedforward = config.dim_feedforward
        dropout = config.dropout

        # Fixed sensor layout (not configurable through InterfuserConfig).
        in_chans = 3
        img_size = 224

        self.embed_dim = embed_dim
        self.direct_concat = True
        self.with_lidar = True
        # With direct_concat the side views are stacked into the front
        # image's channels and the RGB backbone is built for 4 * in_chans
        # input channels. Sending the raw 3-channel side images through
        # that same backbone cannot work, so side-view tokens are only
        # produced when direct_concat is off.
        self.with_right_left_sensors = not self.direct_concat
        in_chans_rgb = in_chans * 4 if self.direct_concat else in_chans

        if not config.use_different_backbone:
            raise ValueError(
                "InterfuserModel requires use_different_backbone=True; "
                "a shared RGB/LiDAR backbone is not implemented."
            )

        backbones = {"r50": resnet50d, "r26": resnet26d, "r18": resnet18d}
        self.rgb_backbone = backbones[config.rgb_backbone_name](
            pretrained=False,
            in_chans=in_chans_rgb,
            features_only=True,
            out_indices=[4],
        )
        self.lidar_backbone = backbones[config.lidar_backbone_name](
            pretrained=False,
            in_chans=in_chans,
            features_only=True,
            out_indices=[4],
        )
        self.rgb_patch_embed = HybridEmbed(
            self.rgb_backbone,
            img_size=img_size,
            in_chans=in_chans_rgb,
            embed_dim=embed_dim,
        )
        self.lidar_patch_embed = HybridEmbed(
            self.lidar_backbone,
            img_size=112,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # One slot per view along the size-5 axis; this class uses
        # 0 = front, 1 = left, 2 = right and 4 = LiDAR.
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        # 400 grid queries plus 11 extra queries (see forward for the split).
        self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
        self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        # NOTE(review): config.waypoints_pred_head is not consulted; the GRU
        # predictor is always used. Confirm this is intended.
        self.waypoints_generator = GRUWaypointsPredictor(embed_dim)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)
        self.traffic_pred_head = nn.Sequential(
            nn.Linear(embed_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 7),
            nn.Sigmoid(),
        )

        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        # NOTE(review): an nn.GELU *instance* is passed here. Verify that
        # TransformerEncoderLayer/TransformerDecoderLayer expect an instance
        # rather than the activation class.
        act_layer = nn.GELU()
        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer
        )
        self.encoder = TransformerEncoder(encoder_layer, config.enc_depth, None)
        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer
        )
        self.decoder = TransformerDecoder(
            decoder_layer,
            config.dec_depth,
            nn.LayerNorm(embed_dim),
            return_intermediate=False,
        )

    def _embed_view(self, patch_embed, image, view_idx):
        """Tokenise one sensor view and tag it with its view/global embeddings.

        Args:
            patch_embed: Patch-embedding module; called as
                ``patch_embed(image)`` and expected to return a
                ``(tokens, global_token)`` pair.
            image: Input tensor for this view.
            view_idx: Slot of this view in ``view_embed``/``global_embed``.

        Returns:
            ``[tokens, global_token]``, both permuted with ``(2, 0, 1)`` so
            they can be concatenated along dimension 0.
        """
        tokens, global_token = patch_embed(image)
        tokens = (
            tokens
            + self.view_embed[:, :, view_idx:view_idx + 1, :]
            + self.position_encoding(tokens)
        )
        tokens = tokens.flatten(2).permute(2, 0, 1)
        global_token = (
            global_token
            + self.view_embed[:, :, view_idx, :]
            + self.global_embed[:, :, view_idx:view_idx + 1]
        )
        global_token = global_token.permute(2, 0, 1)
        return [tokens, global_token]

    def forward_features(self, front_image, left_image, right_image, lidar):
        """Build the encoder input sequence from the enabled sensor views.

        Args:
            front_image: Front RGB input (already channel-stacked with the
                other cameras when ``direct_concat`` is on).
            left_image: Left RGB input; used only when
                ``with_right_left_sensors`` is true.
            right_image: Right RGB input; used only when
                ``with_right_left_sensors`` is true.
            lidar: LiDAR input; used only when ``with_lidar`` is true.

        Returns:
            All view tokens concatenated along dimension 0.
        """
        features = self._embed_view(self.rgb_patch_embed, front_image, 0)
        if self.with_right_left_sensors:
            features.extend(self._embed_view(self.rgb_patch_embed, left_image, 1))
            features.extend(self._embed_view(self.rgb_patch_embed, right_image, 2))
        if self.with_lidar:
            features.extend(self._embed_view(self.lidar_patch_embed, lidar, 4))
        return torch.cat(features, 0)

    def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
        """Run the full model on one batch of sensor inputs.

        Args:
            rgb: Front camera tensor; its last dimension sets the size the
                other cameras are resized to when ``direct_concat`` is on.
            rgb_left: Left camera tensor.
            rgb_right: Right camera tensor.
            rgb_center: Centre camera tensor.
            lidar: LiDAR tensor.
            measurements: Accepted for interface compatibility; not read.
            target_point: Passed to the waypoint generator.
            **kwargs: Ignored.

        Returns:
            Tuple ``(traffic, waypoints, is_junction, traffic_light_state,
            stop_sign, traffic_feature)``.
        """
        front_image = rgb
        if self.direct_concat:
            side = front_image.shape[-1]
            size = (side, side)
            # Resize the other cameras to the front view and stack all four
            # along the channel dimension.
            front_image = torch.cat(
                [
                    front_image,
                    F.interpolate(rgb_left, size=size),
                    F.interpolate(rgb_right, size=size),
                    F.interpolate(rgb_center, size=size),
                ],
                dim=1,
            )

        features = self.forward_features(front_image, rgb_left, rgb_right, lidar)

        bs = front_image.shape[0]
        # Positional queries: a 20x20 sine grid followed by the learned
        # positions of the 11 extra queries.
        tgt = self.position_encoding(
            torch.ones((bs, 1, 20, 20), device=rgb.device)
        ).flatten(2)
        tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2).permute(2, 0, 1)

        memory = self.encoder(features)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
        hs = hs.permute(1, 0, 2)  # batch, queries, channels

        # Query layout: [0, 400) traffic, 400 shared by the three
        # classification heads, [401, 411) waypoints.
        traffic_feature = hs[:, :400]
        is_junction_feature = hs[:, 400]
        waypoints_feature = hs[:, 401:411]

        waypoints = self.waypoints_generator(waypoints_feature, target_point)
        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(is_junction_feature)
        stop_sign = self.stop_sign_head(is_junction_feature)
        traffic = self.traffic_pred_head(traffic_feature)

        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
712
 
713