# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import os
from typing import List, Optional
import pytest
# from composer.utils import reproducibility
# Allowed options for pytest.mark.world_size()
WORLD_SIZE_OPTIONS = (1, 2)
# Enforce deterministic mode before any tests start.
# reproducibility.configure_deterministic_mode()
# TODO: allow plugind when deps resolved
# Add the path of any pytest fixture files you want to make global
pytest_plugins = [
    # 'tests.fixtures.autouse',
    'tests.fixtures.fixtures',
]
def _get_world_size(item: pytest.Item):
    """Return the ``world_size`` a test is marked with, or 1 if unmarked.

    The closest ``@pytest.mark.world_size(n)`` marker on the item wins
    (e.g. a function-level marker overrides a class- or module-level one);
    ``n`` is taken from the marker's first positional argument.
    """
    fallback = pytest.mark.world_size(1).mark
    marker = item.get_closest_marker('world_size')
    chosen = fallback if marker is None else marker
    return chosen.args[0]
def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:  # type: ignore
    """Resolve a string option from the command line, then the ini file, then ``default``.

    Args:
        config: The active pytest configuration.
        name: Option name, looked up both as the CLI option dest and as the
            ini key.
        default: Fallback value used when neither source supplies one. When
            ``None``, the option is treated as required.

    Returns:
        The resolved option value.

    Raises:
        Fails the run via ``pytest.fail`` when the option is missing from
        every source and no ``default`` was given.
    """
    from_cli = config.getoption(name)
    if from_cli is not None:
        assert isinstance(from_cli, str)
        return from_cli

    from_ini = config.getini(name)
    # getini may report an unset key as an empty list; normalize that to None.
    if from_ini == []:
        from_ini = None

    if from_ini is not None:
        resolved = from_ini
    elif default is not None:
        resolved = default
    else:
        pytest.fail(f'Config option {name} is not specified but is required')

    assert isinstance(resolved, str)
    return resolved
def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,
    choices: Optional[list[str]] = None,
):
    """Register ``name`` both as a ``--name`` command-line flag and as an ini key.

    Args:
        parser: The pytest argument parser to register with.
        name: Option name; the CLI flag becomes ``--<name>``.
        help: Help text shown for both the flag and the ini key.
        choices: Optional allowed values for the CLI flag.

    Both registrations default to ``None`` so an unset option can be told
    apart from an explicitly supplied value.
    """
    # NOTE(review): ``choices`` is only enforced on the CLI flag; ini values
    # are not validated against it.
    common = {'help': help, 'default': None}
    parser.addoption(f'--{name}', type=str, choices=choices, **common)
    parser.addini(name=name, type='string', **common)
def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Deselect tests whose ``world_size`` marker does not match this run (for multi-GPU tests).

    The target size is read from the ``WORLD_SIZE`` environment variable,
    defaulting to 1; unmarked tests count as world size 1. Deselected items
    are reported through the ``pytest_deselected`` hook, and ``items`` is
    updated in place as the hook contract requires.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    kept: List[pytest.Item] = []
    dropped: List[pytest.Item] = []
    for item in items:
        bucket = kept if _get_world_size(item) == world_size else dropped
        bucket.append(item)

    if not dropped:
        return
    config.hook.pytest_deselected(items=dropped)
    items[:] = kept
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register this suite's custom options (currently just ``seed``)."""
    seed_help = """\
        Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
        before each test."""
    _add_option(parser, name='seed', help=seed_help)
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
    """Report a run in which no tests were collected as a success.

    pytest exits with ``NO_TESTS_COLLECTED`` (5) when nothing ran, which
    happens, for instance, when every test is deselected. That status is
    rewritten to ``OK`` (0).
    """
    nothing_ran = exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED
    if nothing_ran:
        session.exitstatus = pytest.ExitCode.OK  # Ignore no-test-ran errors
