from typing import Any, Dict, List, Optional

import torch
from einops import rearrange
from enformer_pytorch import Enformer
from transformers import PretrainedConfig, PreTrainedModel

from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
    TorchUNetHead,
)
# Names of the 14 genomic annotation features predicted by the segmentation
# head — one output channel per entry. Used as the default feature set in
# SegmentEnformerConfig.
FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]
class SegmentEnformerConfig(PretrainedConfig):
    """Configuration for :class:`SegmentEnformer`.

    Args:
        features: Names of the genomic features the segmentation head
            predicts (one output channel each). Defaults to a copy of the
            module-level ``FEATURES`` list.
        embed_dim: Embedding dimension of the Enformer transformer trunk.
        dim_divisible_by: Number of nucleotides represented by one token
            position fed to the segmentation head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: Optional[List[str]] = None,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs: Any,
    ) -> None:
        # Fix: the original used the module-level FEATURES list as a mutable
        # default argument, so every config instance shared (and could
        # mutate) the same list object. Default to None and copy instead.
        self.features = list(features) if features is not None else list(FEATURES)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        super().__init__(**kwargs)
class SegmentEnformer(PreTrainedModel):
    """Segmentation model: a pretrained Enformer trunk plus a U-Net head.

    The stem, convolution tower, and transformer are taken from the
    published ``EleutherAI/enformer-official-rough`` checkpoint; Enformer's
    own prediction heads are discarded and replaced by ``TorchUNetHead``,
    which emits one channel per entry of ``config.features``.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig) -> None:
        super().__init__(config=config)
        # NOTE: downloads the pretrained trunk on construction (network I/O).
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")
        # Keep only the feature-extraction trunk; Enformer's heads are unused.
        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer
        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    # Fix: the original defined `__call__`, which overrides
    # nn.Module.__call__ and silently bypasses forward/backward hooks and
    # the PreTrainedModel machinery. Defining `forward` restores standard
    # dispatch; callers invoking `model(x)` are unaffected.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict per-position feature logits for a batch of sequences.

        Args:
            x: Input batch of shape (batch, seq_len, channels) — presumably
                one-hot encoded DNA with channels == 4; TODO confirm against
                the upstream tokenizer.

        Returns:
            Output of ``TorchUNetHead`` applied to the channels-first
            transformer features.
        """
        x = rearrange(x, "b n d -> b d n")  # conv stem wants channels-first
        x = self.stem(x)
        x = self.conv_tower(x)
        x = rearrange(x, "b d n -> b n d")  # transformer wants channels-last
        x = self.transformer(x)
        x = rearrange(x, "b n d -> b d n")  # head consumes channels-first
        x = self.unet_head(x)
        return x