# yxc97's picture
# Upload folder using huggingface_hub
# 62a2f1c verified
from typing import Set
import spconv

# Switch off spconv's SPCONV_USE_DIRECT_TABLE constant when the version check
# below passes.
# NOTE(review): the slice drops the first two characters of the version
# string before the float conversion, so for a string such as "2.3.6" the
# value compared against 2.2 is 3.6 -- confirm this is the intended threshold.
if float(spconv.__version__[2:]) >= 2.2:
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False

# Prefer the ``spconv.pytorch`` namespace when it can be imported and fall
# back to the top-level package otherwise. Only an import failure triggers
# the fallback; anything else (KeyboardInterrupt, SystemExit, unrelated
# errors raised while importing) propagates instead of being hidden.
try:
    import spconv.pytorch as spconv
except ImportError:
    import spconv as spconv
import torch.nn as nn
def find_all_spconv_keys(model: nn.Module, prefix: str = "") -> Set[str]:
    """Find all spconv keys that need to have their weights transposed.

    Walks ``model.named_children()`` recursively and, for every child that is
    an instance of ``spconv.conv.SparseConvolution``, records the child's
    dotted module path followed by ``.weight``.

    Args:
        model: Root module to inspect. Only its children and their
            descendants are type-checked, never ``model`` itself.
        prefix: Dotted path prepended to each child name; leave empty for the
            top-level call.

    Returns:
        The set of ``"<dotted.module.path>.weight"`` strings that were found.
    """
    found_keys: Set[str] = set()
    for name, child in model.named_children():
        child_path = f"{prefix}.{name}" if prefix != "" else name
        if isinstance(child, spconv.conv.SparseConvolution):
            found_keys.add(f"{child_path}.weight")
        # Recurse with the module path itself rather than the ".weight" key,
        # so keys collected below this child stay plain dotted module paths.
        found_keys.update(find_all_spconv_keys(child, prefix=child_path))
    return found_keys
def replace_feature(out, new_features):
    """Swap the features carried by ``out`` for ``new_features``.

    If ``out`` exposes a callable ``replace_feature`` attribute (the spconv
    2.x behaviour), that method is called and whatever it returns is handed
    back. Otherwise ``out.features`` is assigned in place and ``out`` itself
    is returned.

    Args:
        out: Object holding a ``features`` attribute and, optionally, a
            ``replace_feature`` method.
        new_features: Value to install as the new features.

    Returns:
        The result of ``out.replace_feature(new_features)`` when that method
        exists, else the mutated ``out``.
    """
    # A direct attribute lookup avoids building the full attribute list that
    # ``out.__dir__()`` produces on every call.
    replacer = getattr(out, "replace_feature", None)
    if callable(replacer):
        # spconv 2.x behaviour
        return replacer(new_features)
    out.features = new_features
    return out