feat: Add CPU support

#18
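In effect, the change lets the custom modeling code import and run without mamba-ssm, causal-conv1d, or a CUDA device. A minimal CPU-only usage sketch (the repository id, dtype, prompt, and generation settings below are placeholders and assumptions, not taken from this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<this-model-repo>"  # placeholder: substitute the actual Hugging Face repo id

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,     # the NemotronH modeling code ships with the repo
    torch_dtype=torch.float32,  # assumption: plain fp32 on CPU
    device_map="cpu",
)

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Note that the pure-PyTorch fallback path will be far slower than the fused Triton/CUDA kernels.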
Files changed (1)
  1. modeling_nemotron_h.py +38 -12
modeling_nemotron_h.py CHANGED
@@ -16,6 +16,7 @@
 """PyTorch NemotronH model."""

 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

@@ -61,8 +62,9 @@ else:
 try:
     #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
     from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
+    FAST_RMSNORM = True
 except ImportError:
-    raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")
+    FAST_RMSNORM = False

 if is_causal_conv1d_available():
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -269,14 +271,30 @@ class MambaRMSNormGated(torch.nn.Module):

     # jan28b version
     def forward(self, hidden_states, gate=None):
-        return rmsnorm_fn(x=hidden_states,
-                          weight=self.weight,
-                          bias=None, # No bias
-                          z=gate,
-                          eps=self.variance_epsilon,
-                          group_size=self.group_size,
-                          norm_before_gate=False
-                          )
+        if FAST_RMSNORM:
+            return rmsnorm_fn(x=hidden_states,
+                              weight=self.weight,
+                              bias=None, # No bias
+                              z=gate,
+                              eps=self.variance_epsilon,
+                              group_size=self.group_size,
+                              norm_before_gate=False
+                              )
+
+        # standard version
+        input_dtype = hidden_states.dtype
+        batch_size, seq_len, hidden_size = hidden_states.shape
+        num_groups = self.weight.shape[0] // self.group_size
+        hidden_states = hidden_states.to(torch.float32)
+
+        if gate is not None:
+            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+        hidden_states = hidden_states.view(batch_size, seq_len, num_groups, hidden_size // num_groups)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        hidden_states = hidden_states.view(batch_size, seq_len, hidden_size)
+        return self.weight * hidden_states.to(input_dtype)

 class NemotronHMamba2Mixer(nn.Module):
     """
@@ -623,8 +641,8 @@ class NemotronHMamba2Mixer(nn.Module):
         hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
         B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
         C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-        B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-        C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+        B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+        C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
         pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

         D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
@@ -757,6 +775,14 @@ class NemotronHBlock(nn.Module):
         else:
             raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")

+    @contextmanager
+    def _maybe_cuda_stream(self, device):
+        if torch.cuda.is_available():
+            with torch.cuda.stream(torch.cuda.default_stream(device)):
+                yield
+        else:
+            yield
+
     def forward(
         self,
         hidden_states,
@@ -764,7 +790,7 @@ class NemotronHBlock(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-        with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
+        with self._maybe_cuda_stream(hidden_states.device):
             # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
             residual = hidden_states
             hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
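The new "standard version" of the gated RMSNorm applies the SiLU gate and then normalizes each block of group_size channels independently. A self-contained sanity check of that grouping against an explicit per-group loop, with made-up shapes (none of these numbers come from the model config):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, hidden, group_size, eps = 2, 3, 8, 4, 1e-5
num_groups = hidden // group_size

x, gate = torch.randn(batch, seq, hidden), torch.randn(batch, seq, hidden)
weight = torch.randn(hidden)

# Vectorized, following the same steps as the fallback path in the diff.
h = x.float() * F.silu(gate.float())
h = h.view(batch, seq, num_groups, group_size)
h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
vectorized = weight * h.view(batch, seq, hidden)

# Reference: normalize each channel group with an explicit loop.
g = x.float() * F.silu(gate.float())
chunks = []
for i in range(num_groups):
    c = g[..., i * group_size:(i + 1) * group_size]
    chunks.append(c * torch.rsqrt(c.pow(2).mean(-1, keepdim=True) + eps))
reference = weight * torch.cat(chunks, dim=-1)

print(torch.allclose(vectorized, reference))  # True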
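A note on the B/C expansion fix above: repeat tiles the group dimension (group order 0, 1, 0, 1, ...), while repeat_interleave repeats each group contiguously (0, 0, 1, 1, ...), which matches an assignment where consecutive heads share a group, as the change implies. A standalone illustration with made-up sizes (2 groups, 4 heads, state size 3), not the model's real dimensions:

import torch

n_groups, num_heads, ssm_state_size = 2, 4, 3  # illustration values only
B = torch.arange(n_groups).float().view(1, 1, n_groups, 1).expand(1, 1, n_groups, ssm_state_size)

tiled = B.repeat(1, 1, num_heads // n_groups, 1)                 # old behaviour: groups 0, 1, 0, 1
interleaved = B.repeat_interleave(num_heads // n_groups, dim=2)  # new behaviour: groups 0, 0, 1, 1

print(tiled[0, 0, :, 0])        # tensor([0., 1., 0., 1.])
print(interleaved[0, 0, :, 0])  # tensor([0., 0., 1., 1.])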