refactor testing (#140)
* refactor testing

* rename to architectures
eitanturok authored Aug 12, 2024
1 parent 27d3d2c commit bce5d7b
Showing 4 changed files with 22 additions and 27 deletions.
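The gist of the refactor, as a lead-in to the diffs below: the shared test helpers formerly reached through megablocks.layers.testing are now imported from tests/layers/architectures.py, and the MoE/dMoE classes and load-balancing helpers are imported directly rather than through their modules. A minimal sketch of the new import and construction pattern (hedged: the Arguments fields are abbreviated here, and the real tests also set model and MoE sizes and run on GPU):

    from megablocks.layers.arguments import Arguments
    from megablocks.layers.dmoe import dMoE
    from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
    from tests.layers.architectures import FFN

    args = Arguments(bf16=True)  # sketch only: real tests also pass hidden sizes, expert counts, etc.
    mlp = FFN(args)              # dense reference model from the shared test architectures
    moe_mlp = MoE(args)          # imported directly instead of moe.MoE
    dmoe_mlp = dMoE(args)        # imported directly instead of dmoe.dMoE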
@@ -7,15 +7,6 @@
 from megablocks.layers.arguments import Arguments


-def allclose(x, y, pct=0.5):
-    mask = torch.isclose(x, y, rtol=5e-2)
-    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
-    if pct_diff > pct:
-        print('{:.2f}% of values not close.'.format(pct_diff))
-        return False
-    return True
-
-
 class FFN(torch.nn.Module):

     def __init__(self, args: Arguments):
tests/layers/dmoe_test.py: 9 additions & 7 deletions
@@ -7,8 +7,10 @@
 import torch

 from megablocks import grouped_gemm_util as gg
-from megablocks.layers import dmoe, moe, testing
 from megablocks.layers.arguments import Arguments
+from megablocks.layers.dmoe import dMoE
+from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
+from tests.layers.architectures import FFN

 # min size: (1, 2, 128, 2, 1)
 _FORWARD_TESTS_DEFAULT = (
@@ -64,9 +66,9 @@ def construct_moes(
         bf16=True,
     )

-    mlp = testing.FFN(args)
-    moe_mlp = moe.MoE(args)
-    dmoe_mlp = dmoe.dMoE(args)
+    mlp = FFN(args)
+    moe_mlp = MoE(args)
+    dmoe_mlp = dMoE(args)

     mlp.cuda(torch.cuda.current_device()).to(torch.bfloat16)
     moe_mlp.cuda(torch.cuda.current_device()).to(torch.bfloat16)
@@ -106,7 +108,7 @@ def test_dmoe_forward(

     out, _ = layer(x)
     assert out.shape == x.shape
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()


 @pytest.mark.gpu
@@ -132,12 +134,12 @@ def test_dmoe_forward_backward(

     out, _ = layer(x)
     assert out.shape == x.shape
-    loss = out.sum() + moe.batched_load_balancing_loss(args)
+    loss = out.sum() + batched_load_balancing_loss(args)
     loss.backward()
     assert x.grad is not None
     layer.zero_grad(set_to_none=True)
     x.grad = None
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()


 @pytest.mark.gpu
tests/layers/glu_test.py: 3 additions & 2 deletions
@@ -7,8 +7,9 @@
 import stk
 import torch

-from megablocks.layers import dmlp_registry, testing
+from megablocks.layers import dmlp_registry
 from megablocks.layers.arguments import Arguments
+from tests.layers.architectures import GLU

 _DENSE_TESTS = (
     (16, 1024, 512),
@@ -36,7 +37,7 @@ def construct_dmoe_glu(
         bf16=True,
     )

-    glu = testing.GLU(args)
+    glu = GLU(args)
     dmoe_glu = dmlp_registry.get(args)

     dmoe_glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)
tests/layers/moe_test.py: 10 additions & 9 deletions
@@ -6,8 +6,9 @@
 import pytest
 import torch

-from megablocks.layers import moe, testing
 from megablocks.layers.arguments import Arguments
+from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
+from tests.layers.architectures import FFN

 _FORWARD_TESTS = (
     (16, 1024, 512, 1, 1),
@@ -48,8 +49,8 @@ def construct_moe(
         init_method=init_method,
     )

-    mlp = testing.FFN(args)
-    moe_mlp = moe.MoE(args)
+    mlp = FFN(args)
+    moe_mlp = MoE(args)

     mlp.cuda(torch.cuda.current_device()).half()
     moe_mlp.cuda(torch.cuda.current_device()).half()
@@ -76,7 +77,7 @@ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):

     out, _ = layer(x)
     assert out.shape == x.shape
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()


 @pytest.mark.gpu
@@ -101,11 +102,11 @@ def test_moe_forward_backward(
     out, _ = layer(x)
     assert out.shape == x.shape

-    loss = out.sum() + moe.batched_load_balancing_loss(args)
+    loss = out.sum() + batched_load_balancing_loss(args)
     loss.backward()
     layer.zero_grad(set_to_none=True)
     x.grad = None
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()


 @pytest.mark.gpu
@@ -119,7 +120,7 @@ def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
     out, _ = moe_mlp(x)
     assert out.shape == x.shape == expected_out.shape
     assert torch.allclose(out, expected_out)
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()


 @pytest.mark.gpu
@@ -137,7 +138,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
     w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
     moe_mlp.zero_grad(set_to_none=True)
     x.grad = None
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()

     expected_out = mlp(x)
     expected_loss = expected_out.sum()
@@ -152,4 +153,4 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
     assert w2_grad.shape == expected_w2_grad.shape
     assert torch.allclose(w1_grad, expected_w1_grad)
     assert torch.allclose(w2_grad, expected_w2_grad)
-    moe.clear_load_balancing_loss()
+    clear_load_balancing_loss()