Commit 2d7348d
Parent(s): 098730b

Update class names to MultiHeadLatentAttention

Files changed:
- src/__init__.py        +2 -2
- src/mla.py             +1 -1
- src/tests/test_mla.py  +2 -2
src/__init__.py CHANGED

@@ -5,7 +5,7 @@ Copyright (c) 2025
 Implementation of the Multi-Latent Attention mechanism from the DeepSeek-V2 paper.
 """
 
-from .mla import 
+from .mla import MultiHeadLatentAttention, precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb
 
 __version__ = "0.1.0"
-__all__ = ["
+__all__ = ["MultiHeadLatentAttention", "precompute_freqs_cis", "reshape_for_broadcast","apply_rotary_emb"]
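With this change the renamed class and the RoPE helpers listed in __all__ are re-exported at the package level, so downstream code can import everything from one place. A minimal import sketch, assuming the top-level package is importable as "src" (the installed package name is not shown in this commit; only the exported symbols are):

    # Assumption: the package is importable as "src"; only the exported names
    # below are confirmed by this commit's __init__.py / __all__.
    from src import (
        MultiHeadLatentAttention,
        precompute_freqs_cis,
        reshape_for_broadcast,
        apply_rotary_emb,
    )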
src/mla.py CHANGED

@@ -58,7 +58,7 @@ def apply_rotary_emb(
 
 
 
-class 
+class MultiHeadLatentAttention(nn.Module):
     """
     Multi-Head Latent Attention(MLA) Module As in DeepSeek_V2 pape
     Key innovation from standard MHA:
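The context lines above cut off right where the docstring starts listing MLA's key departure from standard multi-head attention: instead of caching full per-head keys and values, MLA caches a low-rank latent from which K and V are reconstructed. The sketch below illustrates that idea only; it is not this repository's MultiHeadLatentAttention, and the layer names, the d_latent size, the single shared K/V latent, and the omission of RoPE (the file's apply_rotary_emb helper) are all simplifying assumptions.

    # Illustrative-only sketch of MLA's low-rank KV compression.
    # NOT the repository's implementation; names and structure are assumptions.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimplifiedMLA(nn.Module):
        def __init__(self, d_model: int, num_head: int, d_latent: int):
            super().__init__()
            assert d_model % num_head == 0
            self.num_head = num_head
            self.d_head = d_model // num_head
            self.w_q = nn.Linear(d_model, d_model, bias=False)
            # K and V are reconstructed from a small shared latent, so a KV cache
            # would store d_latent values per token instead of 2 * d_model.
            self.w_down_kv = nn.Linear(d_model, d_latent, bias=False)  # compress
            self.w_up_k = nn.Linear(d_latent, d_model, bias=False)     # expand to K
            self.w_up_v = nn.Linear(d_latent, d_model, bias=False)     # expand to V
            self.w_out = nn.Linear(d_model, d_model, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            bsz, seq_len, _ = x.shape
            c_kv = self.w_down_kv(x)   # (bsz, seq_len, d_latent): the cached latent
            q, k, v = self.w_q(x), self.w_up_k(c_kv), self.w_up_v(c_kv)

            def split(t: torch.Tensor) -> torch.Tensor:
                # (bsz, seq_len, d_model) -> (bsz, num_head, seq_len, d_head)
                return t.view(bsz, seq_len, self.num_head, self.d_head).transpose(1, 2)

            out = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal=True)
            out = out.transpose(1, 2).reshape(bsz, seq_len, -1)
            return self.w_out(out)

    # Example: y = SimplifiedMLA(d_model=512, num_head=8, d_latent=64)(torch.randn(2, 10, 512))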
src/tests/test_mla.py CHANGED

@@ -1,6 +1,6 @@
 import unittest
 import torch
-from ..mla import 
+from ..mla import MultiHeadLatentAttention  # Using relative import
 
 class TestMultiLatentAttention(unittest.TestCase):
     def setUp(self):

@@ -15,7 +15,7 @@ class TestMultiLatentAttention(unittest.TestCase):
         self.seq_len = 10
 
         # Initialize MLA
-        self.mla = 
+        self.mla = MultiHeadLatentAttention(
             d_model=self.d_model,
             num_head=self.num_head,
             d_embed=self.d_embed,
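A practical note on the "# Using relative import" comment: "from ..mla import MultiHeadLatentAttention" only resolves when test_mla.py is imported as part of the package, not when the file is executed directly. Assuming both src/ and src/tests/ contain an __init__.py (the tests package's __init__.py is not shown in this commit), the suite can be run from the repository root with:

    python -m unittest src.tests.test_mla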