Upload folder using huggingface_hub
Browse files
model.py
CHANGED
@@ -619,456 +619,95 @@ def build_attn_mask(mask_type):
|
|
619 |
mask[:, 101:151] = False
|
620 |
return mask
|
621 |
|
622 |
-
class
|
|
|
|
|
623 |
def __init__(
|
624 |
self,
|
625 |
-
|
626 |
-
multi_view_img_size=112,
|
627 |
-
patch_size=8,
|
628 |
-
in_chans=3,
|
629 |
-
embed_dim=768,
|
630 |
enc_depth=6,
|
631 |
dec_depth=6,
|
632 |
-
dim_feedforward=2048,
|
633 |
-
normalize_before=False,
|
634 |
-
rgb_backbone_name="r26",
|
635 |
-
lidar_backbone_name="r26",
|
636 |
num_heads=8,
|
637 |
-
|
638 |
dropout=0.1,
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
weight_init="",
|
645 |
-
freeze_num=-1,
|
646 |
-
with_lidar=False,
|
647 |
-
with_right_left_sensors=True,
|
648 |
-
with_center_sensor=False,
|
649 |
-
traffic_pred_head_type="det",
|
650 |
-
waypoints_pred_head="heatmap",
|
651 |
-
reverse_pos=True,
|
652 |
-
use_different_backbone=False,
|
653 |
-
use_view_embed=True,
|
654 |
-
use_mmad_pretrain=None,
|
655 |
):
|
656 |
-
|
657 |
-
self.
|
658 |
-
self.
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
self.
|
665 |
self.waypoints_pred_head = waypoints_pred_head
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
self.
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
self.attn_mask = build_attn_mask("seperate_view")
|
683 |
-
elif self.separate_all_attention:
|
684 |
-
self.attn_mask = build_attn_mask("seperate_all")
|
685 |
-
else:
|
686 |
-
self.attn_mask = None
|
687 |
-
|
688 |
if use_different_backbone:
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
in_chans=in_chans,
|
700 |
-
features_only=True,
|
701 |
-
out_indices=[4],
|
702 |
-
)
|
703 |
-
elif rgb_backbone_name == "r18":
|
704 |
-
self.rgb_backbone = resnet18d(
|
705 |
-
pretrained=True,
|
706 |
-
in_chans=in_chans,
|
707 |
-
features_only=True,
|
708 |
-
out_indices=[4],
|
709 |
-
)
|
710 |
-
if lidar_backbone_name == "r50":
|
711 |
-
self.lidar_backbone = resnet50d(
|
712 |
-
pretrained=False,
|
713 |
-
in_chans=in_chans,
|
714 |
-
features_only=True,
|
715 |
-
out_indices=[4],
|
716 |
-
)
|
717 |
-
elif lidar_backbone_name == "r26":
|
718 |
-
self.lidar_backbone = resnet26d(
|
719 |
-
pretrained=False,
|
720 |
-
in_chans=in_chans,
|
721 |
-
features_only=True,
|
722 |
-
out_indices=[4],
|
723 |
-
)
|
724 |
-
elif lidar_backbone_name == "r18":
|
725 |
-
self.lidar_backbone = resnet18d(
|
726 |
-
pretrained=False, in_chans=3, features_only=True, out_indices=[4]
|
727 |
-
)
|
728 |
-
rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
|
729 |
-
lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
|
730 |
-
|
731 |
-
if use_mmad_pretrain:
|
732 |
-
params = torch.load(use_mmad_pretrain)["state_dict"]
|
733 |
-
updated_params = OrderedDict()
|
734 |
-
for key in params:
|
735 |
-
if "backbone" in key:
|
736 |
-
updated_params[key.replace("backbone.", "")] = params[key]
|
737 |
-
self.rgb_backbone.load_state_dict(updated_params)
|
738 |
-
|
739 |
-
self.rgb_patch_embed = rgb_embed_layer(
|
740 |
-
img_size=img_size,
|
741 |
-
patch_size=patch_size,
|
742 |
-
in_chans=in_chans,
|
743 |
-
embed_dim=embed_dim,
|
744 |
-
)
|
745 |
-
self.lidar_patch_embed = lidar_embed_layer(
|
746 |
-
img_size=img_size,
|
747 |
-
patch_size=patch_size,
|
748 |
-
in_chans=3,
|
749 |
-
embed_dim=embed_dim,
|
750 |
-
)
|
751 |
-
else:
|
752 |
-
if rgb_backbone_name == "r50":
|
753 |
-
self.rgb_backbone = resnet50d(
|
754 |
-
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
|
755 |
-
)
|
756 |
-
elif rgb_backbone_name == "r101":
|
757 |
-
self.rgb_backbone = resnet101d(
|
758 |
-
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
|
759 |
-
)
|
760 |
-
elif rgb_backbone_name == "r26":
|
761 |
-
self.rgb_backbone = resnet26d(
|
762 |
-
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
|
763 |
-
)
|
764 |
-
elif rgb_backbone_name == "r18":
|
765 |
-
self.rgb_backbone = resnet18d(
|
766 |
-
pretrained=True, in_chans=3, features_only=True, out_indices=[4]
|
767 |
-
)
|
768 |
-
embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
|
769 |
-
|
770 |
-
self.rgb_patch_embed = embed_layer(
|
771 |
-
img_size=img_size,
|
772 |
-
patch_size=patch_size,
|
773 |
-
in_chans=in_chans,
|
774 |
-
embed_dim=embed_dim,
|
775 |
-
)
|
776 |
-
self.lidar_patch_embed = embed_layer(
|
777 |
-
img_size=img_size,
|
778 |
-
patch_size=patch_size,
|
779 |
-
in_chans=in_chans,
|
780 |
-
embed_dim=embed_dim,
|
781 |
-
)
|
782 |
-
|
783 |
-
self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
|
784 |
-
self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
|
785 |
-
|
786 |
-
if self.end2end:
|
787 |
-
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
|
788 |
-
self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
|
789 |
-
elif self.waypoints_pred_head == "heatmap":
|
790 |
-
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
|
791 |
-
self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
|
792 |
-
else:
|
793 |
-
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
|
794 |
-
self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
|
795 |
-
|
796 |
-
if self.end2end:
|
797 |
-
self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
|
798 |
-
elif self.waypoints_pred_head == "heatmap":
|
799 |
-
self.waypoints_generator = MultiPath_Generator(
|
800 |
-
embed_dim + 32, embed_dim, 10
|
801 |
-
)
|
802 |
-
elif self.waypoints_pred_head == "gru":
|
803 |
-
self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
|
804 |
-
elif self.waypoints_pred_head == "gru-command":
|
805 |
-
self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
|
806 |
-
elif self.waypoints_pred_head == "linear":
|
807 |
-
self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
|
808 |
-
elif self.waypoints_pred_head == "linear-sum":
|
809 |
-
self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
|
810 |
-
|
811 |
-
self.junction_pred_head = nn.Linear(embed_dim, 2)
|
812 |
-
self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
|
813 |
-
self.stop_sign_head = nn.Linear(embed_dim, 2)
|
814 |
-
|
815 |
-
if self.traffic_pred_head_type == "det":
|
816 |
-
self.traffic_pred_head = nn.Sequential(
|
817 |
-
*[
|
818 |
-
nn.Linear(embed_dim + 32, 64),
|
819 |
-
nn.ReLU(),
|
820 |
-
nn.Linear(64, 7),
|
821 |
-
nn.Sigmoid(),
|
822 |
-
]
|
823 |
-
)
|
824 |
-
elif self.traffic_pred_head_type == "seg":
|
825 |
-
self.traffic_pred_head = nn.Sequential(
|
826 |
-
*[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
|
827 |
-
)
|
828 |
-
|
829 |
self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
|
830 |
-
|
831 |
-
encoder_layer = TransformerEncoderLayer(
|
832 |
-
embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
|
833 |
-
)
|
834 |
self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
|
|
|
|
|
835 |
|
836 |
-
|
837 |
-
|
838 |
-
)
|
839 |
-
decoder_norm = nn.LayerNorm(embed_dim)
|
840 |
-
self.decoder = TransformerDecoder(
|
841 |
-
decoder_layer, dec_depth, decoder_norm, return_intermediate=False
|
842 |
-
)
|
843 |
-
self.reset_parameters()
|
844 |
-
|
845 |
-
def reset_parameters(self):
    """Re-initialise the learned embedding tensors in place.

    Each of ``global_embed``, ``view_embed``, ``query_embed`` and
    ``query_pos_embed`` is passed to ``nn.init.uniform_`` with its default
    bounds.
    """
    # The order is fixed so that a seeded RNG always yields the same values.
    embedding_names = (
        "global_embed",
        "view_embed",
        "query_embed",
        "query_pos_embed",
    )
    for embedding_name in embedding_names:
        nn.init.uniform_(getattr(self, embedding_name))
|
850 |
-
|
851 |
-
def forward_features(
|
852 |
-
self,
|
853 |
-
front_image,
|
854 |
-
left_image,
|
855 |
-
right_image,
|
856 |
-
front_center_image,
|
857 |
-
lidar,
|
858 |
-
measurements,
|
859 |
-
):
|
860 |
-
features = []
|
861 |
-
|
862 |
-
# Front view processing
|
863 |
-
front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
|
864 |
-
if self.use_view_embed:
|
865 |
-
front_image_token = (
|
866 |
-
front_image_token
|
867 |
-
+ self.view_embed[:, :, 0:1, :]
|
868 |
-
+ self.position_encoding(front_image_token)
|
869 |
-
)
|
870 |
-
else:
|
871 |
-
front_image_token = front_image_token + self.position_encoding(
|
872 |
-
front_image_token
|
873 |
-
)
|
874 |
-
front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
|
875 |
-
front_image_token_global = (
|
876 |
-
front_image_token_global
|
877 |
-
+ self.view_embed[:, :, 0, :]
|
878 |
-
+ self.global_embed[:, :, 0:1]
|
879 |
-
)
|
880 |
-
front_image_token_global = front_image_token_global.permute(2, 0, 1)
|
881 |
-
features.extend([front_image_token, front_image_token_global])
|
882 |
-
|
883 |
if self.with_right_left_sensors:
|
884 |
-
|
885 |
-
|
886 |
-
if self.use_view_embed:
|
887 |
-
left_image_token = (
|
888 |
-
left_image_token
|
889 |
-
+ self.view_embed[:, :, 1:2, :]
|
890 |
-
+ self.position_encoding(left_image_token)
|
891 |
-
)
|
892 |
-
else:
|
893 |
-
left_image_token = left_image_token + self.position_encoding(
|
894 |
-
left_image_token
|
895 |
-
)
|
896 |
-
left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
|
897 |
-
left_image_token_global = (
|
898 |
-
left_image_token_global
|
899 |
-
+ self.view_embed[:, :, 1, :]
|
900 |
-
+ self.global_embed[:, :, 1:2]
|
901 |
-
)
|
902 |
-
left_image_token_global = left_image_token_global.permute(2, 0, 1)
|
903 |
-
|
904 |
-
# Right view processing
|
905 |
-
right_image_token, right_image_token_global = self.rgb_patch_embed(
|
906 |
-
right_image
|
907 |
-
)
|
908 |
-
if self.use_view_embed:
|
909 |
-
right_image_token = (
|
910 |
-
right_image_token
|
911 |
-
+ self.view_embed[:, :, 2:3, :]
|
912 |
-
+ self.position_encoding(right_image_token)
|
913 |
-
)
|
914 |
-
else:
|
915 |
-
right_image_token = right_image_token + self.position_encoding(
|
916 |
-
right_image_token
|
917 |
-
)
|
918 |
-
right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
|
919 |
-
right_image_token_global = (
|
920 |
-
right_image_token_global
|
921 |
-
+ self.view_embed[:, :, 2, :]
|
922 |
-
+ self.global_embed[:, :, 2:3]
|
923 |
-
)
|
924 |
-
right_image_token_global = right_image_token_global.permute(2, 0, 1)
|
925 |
-
|
926 |
-
features.extend(
|
927 |
-
[
|
928 |
-
left_image_token,
|
929 |
-
left_image_token_global,
|
930 |
-
right_image_token,
|
931 |
-
right_image_token_global,
|
932 |
-
]
|
933 |
-
)
|
934 |
-
|
935 |
-
if self.with_center_sensor:
|
936 |
-
# Front center view processing
|
937 |
-
(
|
938 |
-
front_center_image_token,
|
939 |
-
front_center_image_token_global,
|
940 |
-
) = self.rgb_patch_embed(front_center_image)
|
941 |
-
if self.use_view_embed:
|
942 |
-
front_center_image_token = (
|
943 |
-
front_center_image_token
|
944 |
-
+ self.view_embed[:, :, 3:4, :]
|
945 |
-
+ self.position_encoding(front_center_image_token)
|
946 |
-
)
|
947 |
-
else:
|
948 |
-
front_center_image_token = (
|
949 |
-
front_center_image_token
|
950 |
-
+ self.position_encoding(front_center_image_token)
|
951 |
-
)
|
952 |
-
|
953 |
-
front_center_image_token = front_center_image_token.flatten(2).permute(
|
954 |
-
2, 0, 1
|
955 |
-
)
|
956 |
-
front_center_image_token_global = (
|
957 |
-
front_center_image_token_global
|
958 |
-
+ self.view_embed[:, :, 3, :]
|
959 |
-
+ self.global_embed[:, :, 3:4]
|
960 |
-
)
|
961 |
-
front_center_image_token_global = front_center_image_token_global.permute(
|
962 |
-
2, 0, 1
|
963 |
-
)
|
964 |
-
features.extend([front_center_image_token, front_center_image_token_global])
|
965 |
-
|
966 |
if self.with_lidar:
|
967 |
-
|
968 |
-
|
969 |
-
|
970 |
-
|
971 |
-
|
972 |
-
+ self.position_encoding(lidar_token)
|
973 |
-
)
|
974 |
-
else:
|
975 |
-
lidar_token = lidar_token + self.position_encoding(lidar_token)
|
976 |
-
lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
|
977 |
-
lidar_token_global = (
|
978 |
-
lidar_token_global
|
979 |
-
+ self.view_embed[:, :, 4, :]
|
980 |
-
+ self.global_embed[:, :, 4:5]
|
981 |
-
)
|
982 |
-
lidar_token_global = lidar_token_global.permute(2, 0, 1)
|
983 |
-
features.extend([lidar_token, lidar_token_global])
|
984 |
-
|
985 |
-
features = torch.cat(features, 0)
|
986 |
-
return features
|
987 |
-
|
988 |
-
def forward(self, x):
|
989 |
-
front_image = x["rgb"]
|
990 |
-
left_image = x["rgb_left"]
|
991 |
-
right_image = x["rgb_right"]
|
992 |
-
front_center_image = x["rgb_center"]
|
993 |
-
measurements = x["measurements"]
|
994 |
-
target_point = x["target_point"]
|
995 |
-
lidar = x["lidar"]
|
996 |
-
|
997 |
if self.direct_concat:
|
998 |
-
img_size
|
999 |
-
|
1000 |
-
|
1001 |
-
)
|
1002 |
-
right_image = torch.nn.functional.interpolate(
|
1003 |
-
right_image, size=(img_size, img_size)
|
1004 |
-
)
|
1005 |
-
front_center_image = torch.nn.functional.interpolate(
|
1006 |
-
front_center_image, size=(img_size, img_size)
|
1007 |
-
)
|
1008 |
-
front_image = torch.cat(
|
1009 |
-
[front_image, left_image, right_image, front_center_image], dim=1
|
1010 |
-
)
|
1011 |
-
features = self.forward_features(
|
1012 |
-
front_image,
|
1013 |
-
left_image,
|
1014 |
-
right_image,
|
1015 |
-
front_center_image,
|
1016 |
-
lidar,
|
1017 |
-
measurements,
|
1018 |
-
)
|
1019 |
-
|
1020 |
bs = front_image.shape[0]
|
1021 |
-
|
1022 |
-
|
1023 |
-
|
1024 |
-
|
1025 |
-
|
1026 |
-
|
1027 |
-
)
|
1028 |
-
tgt = tgt.flatten(2)
|
1029 |
-
tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
|
1030 |
-
tgt = tgt.permute(2, 0, 1)
|
1031 |
-
|
1032 |
-
memory = self.encoder(features, mask=self.attn_mask)
|
1033 |
-
hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
|
1034 |
-
|
1035 |
-
hs = hs.permute(1, 0, 2) # Batchsize , N, C
|
1036 |
-
if self.end2end:
|
1037 |
-
waypoints = self.waypoints_generator(hs, target_point)
|
1038 |
-
return waypoints
|
1039 |
-
|
1040 |
-
if self.waypoints_pred_head != "heatmap":
|
1041 |
-
traffic_feature = hs[:, :400]
|
1042 |
-
is_junction_feature = hs[:, 400]
|
1043 |
-
traffic_light_state_feature = hs[:, 400]
|
1044 |
-
stop_sign_feature = hs[:, 400]
|
1045 |
-
waypoints_feature = hs[:, 401:411]
|
1046 |
-
else:
|
1047 |
-
traffic_feature = hs[:, :400]
|
1048 |
-
is_junction_feature = hs[:, 400]
|
1049 |
-
traffic_light_state_feature = hs[:, 400]
|
1050 |
-
stop_sign_feature = hs[:, 400]
|
1051 |
-
waypoints_feature = hs[:, 401:405]
|
1052 |
-
|
1053 |
-
if self.waypoints_pred_head == "heatmap":
|
1054 |
-
waypoints = self.waypoints_generator(waypoints_feature, measurements)
|
1055 |
-
elif self.waypoints_pred_head == "gru":
|
1056 |
-
waypoints = self.waypoints_generator(waypoints_feature, target_point)
|
1057 |
-
elif self.waypoints_pred_head == "gru-command":
|
1058 |
-
waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
|
1059 |
-
elif self.waypoints_pred_head == "linear":
|
1060 |
-
waypoints = self.waypoints_generator(waypoints_feature, measurements)
|
1061 |
-
elif self.waypoints_pred_head == "linear-sum":
|
1062 |
-
waypoints = self.waypoints_generator(waypoints_feature, measurements)
|
1063 |
-
|
1064 |
is_junction = self.junction_pred_head(is_junction_feature)
|
1065 |
-
traffic_light_state = self.traffic_light_pred_head(
|
1066 |
-
stop_sign = self.stop_sign_head(
|
1067 |
-
|
1068 |
-
|
1069 |
-
velocity = velocity.repeat(1, 400, 32)
|
1070 |
-
traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
|
1071 |
-
traffic = self.traffic_pred_head(traffic_feature_with_vel)
|
1072 |
return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
|
1073 |
|
1074 |
|
|
|
619 |
mask[:, 101:151] = False
|
620 |
return mask
|
621 |
|
622 |
+
class InterfuserConfig(PretrainedConfig):
    """Configuration object holding the hyper-parameters of the Interfuser model.

    Every argument is stored unchanged as an attribute of the same name; any
    additional keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        embed_dim: Stored as ``embed_dim``.
        enc_depth: Stored as ``enc_depth``.
        dec_depth: Stored as ``dec_depth``.
        num_heads: Stored as ``num_heads``.
        dim_feedforward: Stored as ``dim_feedforward``.
        dropout: Stored as ``dropout``.
        rgb_backbone_name: Short name of the RGB backbone (e.g. ``"r50"``).
        lidar_backbone_name: Short name of the lidar backbone (e.g. ``"r18"``).
        use_different_backbone: Stored as ``use_different_backbone``.
        waypoints_pred_head: Stored as ``waypoints_pred_head``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "interfuser"

    def __init__(
        self,
        embed_dim: int = 256,
        enc_depth: int = 6,
        dec_depth: int = 6,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        rgb_backbone_name: str = "r50",
        lidar_backbone_name: str = "r18",
        use_different_backbone: bool = True,
        waypoints_pred_head: str = "gru",
        **kwargs,
    ):
        # Backbone selection.
        self.use_different_backbone = use_different_backbone
        self.rgb_backbone_name = rgb_backbone_name
        self.lidar_backbone_name = lidar_backbone_name

        # Transformer sizing.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth

        # Prediction head selection.
        self.waypoints_pred_head = waypoints_pred_head

        # The base class is initialised last, after the model-specific fields.
        super().__init__(**kwargs)
|
650 |
+
|
651 |
+
class InterfuserModel(PreTrainedModel):
    """Interfuser network wrapped as a ``PreTrainedModel``.

    Camera and lidar inputs are tokenised by ``HybridEmbed`` modules, passed
    through a transformer encoder, and decoded with ``400 + 11`` learned
    queries. The decoder output is split by index into traffic features
    (first 400), one shared junction / traffic-light / stop-sign feature
    (index 400), and waypoint features (indices 401-410).
    """

    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        """Build all sub-modules from ``config``.

        Args:
            config: Supplies ``embed_dim``, ``enc_depth``, ``dec_depth``,
                ``num_heads``, ``dim_feedforward``, ``dropout``, the two
                backbone names and ``use_different_backbone``.

        Raises:
            ValueError: If ``config.use_different_backbone`` is false, or if
                a backbone name is not one of ``"r50"``, ``"r26"``, ``"r18"``.
        """
        super().__init__(config)
        self.config = config

        embed_dim = config.embed_dim
        num_heads = config.num_heads
        dim_feedforward = config.dim_feedforward
        dropout = config.dropout

        # Values that are fixed for this wrapper rather than configurable.
        in_chans = 3
        img_size = 224

        self.embed_dim = embed_dim
        self.direct_concat = True
        self.with_lidar = True
        self.with_right_left_sensors = True
        # NOTE(review): with direct_concat the RGB backbone is built for
        # in_chans * 4 channels, yet forward_features also feeds it rgb_left
        # and rgb_right un-concatenated. Confirm those inputs really carry
        # that many channels, or that one of the two flags should be False.
        in_chans_rgb = in_chans * 4 if self.direct_concat else in_chans

        # Only the two-backbone layout is implemented. Fail here instead of
        # with an AttributeError on a missing sub-module later on.
        if not config.use_different_backbone:
            raise ValueError(
                "InterfuserModel only supports use_different_backbone=True"
            )

        self.rgb_backbone = self._build_backbone(
            config.rgb_backbone_name, in_chans_rgb
        )
        self.lidar_backbone = self._build_backbone(
            config.lidar_backbone_name, in_chans
        )
        self.rgb_patch_embed = HybridEmbed(
            self.rgb_backbone,
            img_size=img_size,
            in_chans=in_chans_rgb,
            embed_dim=embed_dim,
        )
        self.lidar_patch_embed = HybridEmbed(
            self.lidar_backbone,
            img_size=112,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # One slot per view along the size-5 axis; this class uses slots
        # 0 (front), 1 (left), 2 (right) and 4 (lidar).
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        # 400 grid queries plus 11 extra queries (see the slicing in forward).
        self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
        self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        # NOTE(review): config.waypoints_pred_head is not consulted; a
        # GRUWaypointsPredictor is always built. Confirm this is intended.
        self.waypoints_generator = GRUWaypointsPredictor(embed_dim)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)
        self.traffic_pred_head = nn.Sequential(
            nn.Linear(embed_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 7),
            nn.Sigmoid(),
        )

        self.position_encoding = PositionEmbeddingSine(
            embed_dim // 2, normalize=True
        )

        act_layer = nn.GELU()
        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer
        )
        self.encoder = TransformerEncoder(encoder_layer, config.enc_depth, None)
        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer
        )
        self.decoder = TransformerDecoder(
            decoder_layer,
            config.dec_depth,
            nn.LayerNorm(embed_dim),
            return_intermediate=False,
        )

    @staticmethod
    def _build_backbone(name, in_chans):
        """Create the ResNet-D feature extractor selected by ``name``.

        The backbone is created with ``pretrained=False``,
        ``features_only=True`` and ``out_indices=[4]``.

        Raises:
            ValueError: If ``name`` is not ``"r50"``, ``"r26"`` or ``"r18"``.
        """
        factories = {"r50": resnet50d, "r26": resnet26d, "r18": resnet18d}
        if name not in factories:
            raise ValueError(
                f"Unsupported backbone name {name!r}; "
                f"expected one of {sorted(factories)}"
            )
        return factories[name](
            pretrained=False,
            in_chans=in_chans,
            features_only=True,
            out_indices=[4],
        )

    def _embed_view(self, patch_embed, image, view_idx):
        """Tokenise one sensor view and tag it with the embeddings of slot ``view_idx``.

        Returns a ``(tokens, global_token)`` pair, each permuted with
        ``permute(2, 0, 1)`` so the token axis comes first.
        """
        tokens, global_token = patch_embed(image)
        tokens = (
            tokens
            + self.view_embed[:, :, view_idx:view_idx + 1, :]
            + self.position_encoding(tokens)
        )
        tokens = tokens.flatten(2).permute(2, 0, 1)
        global_token = (
            global_token
            + self.view_embed[:, :, view_idx, :]
            + self.global_embed[:, :, view_idx:view_idx + 1]
        )
        global_token = global_token.permute(2, 0, 1)
        return tokens, global_token

    def forward_features(self, front_image, left_image, right_image, lidar):
        """Embed every enabled view and concatenate the tokens along dim 0.

        The front view is always embedded. Left and right views are added
        when ``with_right_left_sensors`` is set, and lidar when ``with_lidar``
        is set.
        """
        features = list(self._embed_view(self.rgb_patch_embed, front_image, 0))
        if self.with_right_left_sensors:
            features.extend(self._embed_view(self.rgb_patch_embed, left_image, 1))
            features.extend(self._embed_view(self.rgb_patch_embed, right_image, 2))
        if self.with_lidar:
            features.extend(self._embed_view(self.lidar_patch_embed, lidar, 4))
        return torch.cat(features, 0)

    def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
        """Run the full model on one batch of sensor inputs.

        Args:
            rgb: Front camera batch; its last dimension sets the resize
                target when ``direct_concat`` is enabled.
            rgb_left: Left camera batch.
            rgb_right: Right camera batch.
            rgb_center: Centre camera batch; only used when ``direct_concat``
                is enabled.
            lidar: Lidar input passed to ``lidar_patch_embed``.
            measurements: Accepted for interface compatibility; not read here.
            target_point: Passed to the waypoint generator.
            **kwargs: Ignored.

        Returns:
            Tuple ``(traffic, waypoints, is_junction, traffic_light_state,
            stop_sign, traffic_feature)``.
        """
        front_image = rgb
        if self.direct_concat:
            # Resize the other views to a square whose side is the front
            # image's last dimension, then stack all four along dim 1.
            side = front_image.shape[-1]
            size = (side, side)
            front_image = torch.cat(
                [
                    front_image,
                    F.interpolate(rgb_left, size),
                    F.interpolate(rgb_right, size),
                    F.interpolate(rgb_center, size),
                ],
                dim=1,
            )

        features = self.forward_features(front_image, rgb_left, rgb_right, lidar)

        bs = front_image.shape[0]
        # Positional encoding of a 20x20 grid, followed by the 11 learned
        # query positions.
        tgt = self.position_encoding(
            torch.ones((bs, 1, 20, 20), device=rgb.device)
        ).flatten(2)
        tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features)
        hs = self.decoder(
            self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt
        )[0]
        hs = hs.permute(1, 0, 2)

        traffic_feature = hs[:, :400]
        # Index 400 feeds the junction, traffic-light and stop-sign heads alike.
        is_junction_feature = hs[:, 400]
        waypoints_feature = hs[:, 401:411]

        waypoints = self.waypoints_generator(waypoints_feature, target_point)
        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(is_junction_feature)
        stop_sign = self.stop_sign_head(is_junction_feature)
        traffic = self.traffic_pred_head(traffic_feature)

        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
|
712 |
|
713 |
|