from typing import Set | |
import spconv | |
# For spconv >= 2.2, set spconv.constants.SPCONV_USE_DIRECT_TABLE to False.
# The version is compared as an integer (major, minor) tuple. The previous
# ``float(__version__[2:])`` parse dropped the major version: it read "2.2.0"
# as 2.0 (wrongly skipping the override) and "2.3.6" as 3.6.
_spconv_version = tuple(
    int(part) for part in spconv.__version__.split(".")[:2] if part.isdigit()
)
if _spconv_version >= (2, 2):
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False
try:
    # spconv 2.x layout: the PyTorch-facing API lives in ``spconv.pytorch``.
    import spconv.pytorch as spconv
except ImportError:
    # Older (1.x) layout: use the top-level package directly. Only import
    # failures trigger the fallback; other errors are no longer swallowed.
    import spconv as spconv
import torch.nn as nn | |
def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Finds all spconv keys that need to have their weights transposed.

    Recursively walks the children of ``model`` and collects the dotted key
    ``"<module path>.weight"`` for every ``spconv.conv.SparseConvolution``
    found.

    Args:
        model: Module whose descendants are searched. ``model`` itself is not
            checked, only its children and their descendants.
        prefix: Dotted path of ``model`` relative to the root module; ``""``
            for the root.

    Returns:
        The set of ``"<module path>.weight"`` keys, one per sparse convolution.
    """
    found_keys: Set[str] = set()
    for name, child in model.named_children():
        child_path = f"{prefix}.{name}" if prefix != "" else name
        if isinstance(child, spconv.conv.SparseConvolution):
            found_keys.add(f"{child_path}.weight")
        # Recurse with the module path itself, not the ".weight" key built
        # above, so any nested submodules are reported under their real path.
        found_keys.update(find_all_spconv_keys(child, prefix=child_path))
    return found_keys
def replace_feature(out, new_features):
    """Return ``out`` with its features replaced by ``new_features``.

    If ``out`` provides a ``replace_feature`` method (spconv 2.x behaviour),
    that method is called and its result returned. Otherwise ``out.features``
    is assigned in place and ``out`` itself is returned.
    """
    # hasattr is the idiomatic capability check; it avoids calling the
    # ``__dir__`` dunder directly and building the full attribute list.
    if hasattr(out, "replace_feature"):
        # spconv 2.x behaviour
        return out.replace_feature(new_features)
    out.features = new_features
    return out