import torch
import megablocks
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights

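# Smoke test: the shared-expert MLP can be constructed and exposes the
# shared-expert attributes and setter exercised by the tests below.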
def test_megablocks_moe_mlp_with_shared_expert_import():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert hasattr(mlp, 'shared_up_proj_weight')
    assert hasattr(mlp, 'shared_down_proj_weight')
    assert hasattr(mlp, 'set_shared_expert_weights')

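# set_shared_expert_weights should store the provided weights, biases,
# weighted-sum flag, and activation function on the module unchanged.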
def test_set_shared_expert_weights():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu,
    )

    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    assert mlp.shared_expert_weighted_sum is True
    assert mlp.shared_activation_fn is torch.nn.functional.gelu

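# create_shared_expert_weights should return (up, down, up_bias, down_bias)
# with the expected shapes, device, and dtype; biases default to None.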
def test_create_shared_expert_weights():
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method,
    )

    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_weight.device.type == device.type
    assert down_proj_weight.device.type == device.type
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    assert up_proj_bias is None
    assert down_proj_bias is None

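# A freshly constructed module has no shared-expert weights configured.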
def test_shared_expert_weights_none_by_default():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    assert mlp.shared_expert_weighted_sum is False
    assert mlp.shared_activation_fn is None

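# The shared-expert variant should remain a drop-in MegaBlocksMoeMLP subclass.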
def test_inheritance_from_megablocks_moe_mlp():
    from megablocks.layers import MegaBlocksMoeMLP

    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert isinstance(mlp, MegaBlocksMoeMLP)
    assert hasattr(mlp, 'forward')

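# init_method should initialize the up projection and output_layer_init_method
# the down projection, with the requested dtype (here float16) preserved.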
def test_shared_expert_weights_custom_init():
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init,
    )

    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype

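# The shared-expert projection weights must line up with the token hidden
# dimension so the up/down projections produce the expected output shapes.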
def test_shared_expert_weights_dimensions():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
    )

    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    # The weights must be compatible with the input's hidden dimension.
    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]

    # Plain matmuls confirm the shapes the up/down projections would produce.
    up_output = x @ up_proj_weight.t()
    assert up_output.shape == expected_up_output_shape
    down_output = up_output @ down_proj_weight.t()
    assert down_output.shape == expected_down_output_shape