feat: Add CPU support
#18 opened by gabegoodhart

modeling_nemotron_h.py  (+38, -12)  CHANGED
@@ -16,6 +16,7 @@
 """PyTorch NemotronH model."""
 
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
@@ -61,8 +62,9 @@ else:
 try:
     #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
     from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
+    FAST_RMSNORM = True
 except ImportError:
-
+    FAST_RMSNORM = False
 
 if is_causal_conv1d_available():
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
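Note on the hunk above: FAST_RMSNORM is plain import-based feature detection, so the same probe can be run standalone to see which normalization path a given environment will take. A minimal sketch (the variable name `fast` is illustrative, not part of the PR):

try:
    # Fused Triton kernel shipped with mamba_ssm; typically unavailable on CPU-only installs
    from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
    fast = True
except ImportError:
    fast = False

print("fused gated RMSNorm available:", fast)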
@@ -269,14 +271,30 @@ class MambaRMSNormGated(torch.nn.Module):
 
     # jan28b version
     def forward(self, hidden_states, gate=None):
-
-
-
-
-
-
-
-
+        if FAST_RMSNORM:
+            return rmsnorm_fn(x=hidden_states,
+                              weight=self.weight,
+                              bias=None,  # No bias
+                              z=gate,
+                              eps=self.variance_epsilon,
+                              group_size=self.group_size,
+                              norm_before_gate=False
+                              )
+
+        # standard version
+        input_dtype = hidden_states.dtype
+        batch_size, seq_len, hidden_size = hidden_states.shape
+        num_groups = self.weight.shape[0] // self.group_size
+        hidden_states = hidden_states.to(torch.float32)
+
+        if gate is not None:
+            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+        hidden_states = hidden_states.view(batch_size, seq_len, num_groups, hidden_size // num_groups)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        hidden_states = hidden_states.view(batch_size, seq_len, hidden_size)
+        return self.weight * hidden_states.to(input_dtype)
 
 class NemotronHMamba2Mixer(nn.Module):
     """
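Note on the hunk above: the new fallback is a plain-PyTorch gated, grouped RMSNorm, which is what makes the layer usable without the Triton kernel. A minimal sketch of the same computation pulled out of the class (the helper name and the toy shapes are assumptions for illustration): apply the SiLU gate first, RMS-normalize each group in float32, then rescale by the learned weight.

import torch
import torch.nn.functional as F

def gated_group_rmsnorm(x, weight, gate=None, group_size=None, eps=1e-5):
    # Mirrors the "standard version" branch above: gate, then per-group RMS norm.
    batch_size, seq_len, hidden_size = x.shape
    group_size = group_size or hidden_size
    num_groups = hidden_size // group_size
    y = x.to(torch.float32)
    if gate is not None:
        y = y * F.silu(gate.to(torch.float32))                    # gate before normalizing
    y = y.view(batch_size, seq_len, num_groups, group_size)
    y = y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)    # per-group RMS norm
    return weight * y.view(batch_size, seq_len, hidden_size).to(x.dtype)

x, gate = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
print(gated_group_rmsnorm(x, torch.ones(8), gate=gate, group_size=4).shape)  # torch.Size([2, 4, 8])

The gate-then-normalize order matches the fused call, which is invoked with norm_before_gate=False.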
@@ -623,8 +641,8 @@ class NemotronHMamba2Mixer(nn.Module):
             hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
             B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
             C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-            B = B.
-            C = C.
+            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
             pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
 
             D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
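Note on the hunk above: B and C are projected once per SSM group, while the torch path updates state per head, so the two repeat_interleave calls copy each group's tensor to every head in that group along dim=2. A toy shape check (the sizes are made up for illustration, not the model's real config):

import torch

num_heads, n_groups, state_size = 8, 2, 16
B = torch.randn(1, 6, n_groups, state_size)   # (batch, seq_len, n_groups, ssm_state_size)
B = B.repeat_interleave(num_heads // n_groups, dim=2, output_size=num_heads)
print(B.shape)                                # torch.Size([1, 6, 8, 16]): one copy per head

Passing output_size simply tells PyTorch the resulting length along dim=2 up front rather than deriving it from the repeat count.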
@@ -757,6 +775,14 @@ class NemotronHBlock(nn.Module):
         else:
             raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")
 
+    @contextmanager
+    def _maybe_cuda_stream(self, device):
+        if torch.cuda.is_available():
+            with torch.cuda.stream(torch.cuda.default_stream(device)):
+                yield
+        else:
+            yield
+
     def forward(
         self,
         hidden_states,
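Note on the hunk above: the new helper is what makes the block device-agnostic. It enters the default CUDA stream when CUDA is present (preserving the multi-GPU NaN workaround mentioned in the comment in the next hunk) and otherwise just yields. A standalone sketch plus a usage line (the name maybe_cuda_stream and the toy tensor are illustrative, not part of the PR):

import torch
from contextlib import contextmanager

@contextmanager
def maybe_cuda_stream(device):
    if torch.cuda.is_available():
        # Same guard as the method above: pin work to the default stream on GPU
        with torch.cuda.stream(torch.cuda.default_stream(device)):
            yield
    else:
        yield  # no-op on CPU-only machines

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 3, device=device)
with maybe_cuda_stream(x.device):
    y = x * 2  # runs unchanged in either case

Note that the guard keys off torch.cuda.is_available() rather than the device of the incoming tensor, so it assumes the hidden states live on the GPU whenever one is present.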
@@ -764,7 +790,7 @@ class NemotronHBlock(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-        with
+        with self._maybe_cuda_stream(hidden_states.device):
             # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
             residual = hidden_states
             hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
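With these changes the model can be exercised end to end on a CPU-only machine. A hedged usage sketch, assuming the checkpoint is loaded with trust_remote_code so that this modeling file is picked up; the model_id below is a placeholder, not a real path:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/nemotron-h-checkpoint"  # placeholder: substitute the actual repo id or local path

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # plain float32 on CPU; the fused Triton/causal-conv1d paths are skipped when those packages are absent
).to("cpu")

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))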