|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ClsToken(nn.Module): | 
					
						
						|  | def __init__(self, ndim: int, | 
					
						
						|  | num_tokens: int = 1, | 
					
						
						|  | enabled: bool = True, | 
					
						
						|  | register_multiple: int = 0, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.ndim = ndim | 
					
						
						|  | self.enabled = enabled | 
					
						
						|  | self.num_registers = 0 | 
					
						
						|  | self.num_tokens = num_tokens | 
					
						
						|  | if enabled: | 
					
						
						|  | if register_multiple > 0: | 
					
						
						|  | self.num_registers = register_multiple - (num_tokens % register_multiple) | 
					
						
						|  |  | 
					
						
						|  | scale = ndim ** -0.5 | 
					
						
						|  | self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale) | 
					
						
						|  | else: | 
					
						
						|  | self.token = None | 
					
						
						|  |  | 
					
						
						|  | self.num_patches = self.num_tokens + self.num_registers | 
					
						
						|  |  | 
					
						
						|  | def disable(self): | 
					
						
						|  | self.token = None | 
					
						
						|  | self.enabled = False | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor): | 
					
						
						|  | if self.token is None: | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1) | 
					
						
						|  | x = torch.cat([ | 
					
						
						|  | token, | 
					
						
						|  | x, | 
					
						
						|  | ], dim=1) | 
					
						
						|  |  | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | def no_weight_decay(self): | 
					
						
						|  | return [ | 
					
						
						|  | 'token', | 
					
						
						|  | ] | 
					
						
						|  |  |