Skip to content

framework

zeus.utils.framework

Utilities for framework-specific code.

torch_is_available cached

torch_is_available(
    ensure_available=False, ensure_cuda=True
)

Check if PyTorch is available.

Source code in `zeus/utils/framework.py` (lines 15–34):
@lru_cache(maxsize=1)
def torch_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
    """Check if PyTorch is available.

    On success the imported ``torch`` module is stored in ``MODULE_CACHE["torch"]``
    so other helpers can reuse it. The result is memoized by ``lru_cache`` with a
    single entry keyed on the call arguments.

    Args:
        ensure_available: If True, raise instead of returning False when PyTorch
            cannot be imported.
        ensure_cuda: If True, raise when PyTorch imports but
            ``torch.cuda.is_available()`` reports no CUDA support.

    Returns:
        True if PyTorch was imported (and passed the CUDA check when requested);
        False if the import failed and ``ensure_available`` is False.

    Raises:
        RuntimeError: If ``ensure_cuda`` is set and CUDA is unavailable, or if
            ``ensure_available`` is set and the import fails.
    """
    try:
        import torch

        has_cuda = torch.cuda.is_available()
        if ensure_cuda and not has_cuda:
            raise RuntimeError("PyTorch is available but does not have CUDA support.")
        MODULE_CACHE["torch"] = torch
        support = "with" if has_cuda else "without"
        logger.info("PyTorch %s CUDA support is available.", support)
    except ImportError as err:
        logger.info("PyTorch is not available.")
        # Only escalate a failed import when the caller demanded PyTorch.
        if not ensure_available:
            return False
        raise RuntimeError("Failed to import Pytorch") from err
    return True

jax_is_available cached

jax_is_available(ensure_available=False, ensure_cuda=True)

Check if JAX is available.

Source code in `zeus/utils/framework.py` (lines 37–55):
@lru_cache(maxsize=1)
def jax_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
    """Check if JAX is available.

    On success the imported ``jax`` module is stored in ``MODULE_CACHE["jax"]``
    so other helpers can reuse it. The result is memoized by ``lru_cache`` with a
    single entry keyed on the call arguments.

    Args:
        ensure_available: If True, raise instead of returning False when JAX
            cannot be imported.
        ensure_cuda: If True, raise when JAX imports but exposes no GPU devices
            (``jax.devices("gpu")`` is empty or its backend cannot be initialized).

    Returns:
        True if JAX was imported (and passed the GPU check when requested);
        False if the import failed and ``ensure_available`` is False.

    Raises:
        RuntimeError: If ``ensure_cuda`` is set and no GPU backend is usable, or
            if ``ensure_available`` is set and the import fails.
    """
    try:
        import jax  # type: ignore

        gpu_error = None
        try:
            # jax.devices("gpu") raises RuntimeError (instead of returning an
            # empty list) when no GPU backend is present or it fails to
            # initialize, e.g. on a CPU-only install. Treat that as "no CUDA" so
            # ensure_cuda=False still reports JAX as available.
            cuda_available = bool(jax.devices("gpu"))
        except RuntimeError as err:
            cuda_available = False
            gpu_error = err
        if ensure_cuda and not cuda_available:
            raise RuntimeError(
                "JAX is available but does not have CUDA support."
            ) from gpu_error
        MODULE_CACHE["jax"] = jax
        logger.info(
            "JAX %s CUDA support is available.", "with" if cuda_available else "without"
        )
        return True
    except ImportError as e:
        logger.info("JAX is not available")
        if ensure_available:
            raise RuntimeError("Failed to import JAX") from e
        return False

sync_execution

sync_execution(gpu_devices, sync_with='torch')

Block until all computations on the specified devices are finished.

PyTorch only runs GPU computations asynchronously, so synchronizing computations for the given GPU devices is done by calling torch.cuda.synchronize on each device. On the other hand, JAX runs both CPU and GPU computations asynchronously, but by default it only has a single CPU device (id=0). Therefore for JAX, all GPU devices passed in and the CPU device (id=0) are synchronized.

Note

jax.device_put with block_until_ready is used to synchronize computations on JAX devices. This is a workaround to the lack of a direct API for synchronizing computations on JAX devices. Tracking issue: https://github.com/google/jax/issues/4335

Note

Across the Zeus library, an integer device index corresponds to a single whole physical device. This is usually what you want, except when using more advanced device partitioning (e.g., using --xla_force_host_platform_device_count in JAX to partition CPUs into more pieces). In such cases, you probably want to opt out from using this function and handle synchronization manually at the appropriate granularity.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gpu_devices` | `list[int]` | GPU device indices to synchronize. | *required* |
| `sync_with` | `Literal['torch', 'jax']` | Deep learning framework to use to synchronize computations. Defaults to `"torch"`, in which case `torch.cuda.synchronize` will be used. | `'torch'` |
Source code in `zeus/utils/framework.py` (lines 58–104):
def sync_execution(
    gpu_devices: list[int], sync_with: Literal["torch", "jax"] = "torch"
) -> None:
    """Block until all computations on the specified devices are finished.

    PyTorch only runs GPU computations asynchronously, so synchronizing computations
    for the given GPU devices is done by calling `torch.cuda.synchronize` on each
    device. On the other hand, JAX runs both CPU and GPU computations asynchronously,
    but by default it only has a single CPU device (id=0). Therefore for JAX, all GPU
    devices passed in and the CPU device (id=0) are synchronized.

    !!! Note
        `jax.device_put` with `block_until_ready` is used to synchronize computations
        on JAX devices. This is a workaround to the lack of a direct API for
        synchronizing computations on JAX devices. Tracking issue:
        https://github.com/google/jax/issues/4335

    !!! Note
        Across the Zeus library, an integer device index corresponds to a single whole
        physical device. This is usually what you want, except when using more advanced
        device partitioning (e.g., using `--xla_force_host_platform_device_count` in JAX
        to partition CPUs into more pieces). In such cases, you probably want to opt out
        from using this function and handle synchronization manually at the appropriate
        granularity.

    Args:
        gpu_devices: GPU device indices to synchronize.
        sync_with: Deep learning framework to use to synchronize computations.
            Defaults to `"torch"`, in which case `torch.cuda.synchronize` will be used.

    Raises:
        RuntimeError: If the requested framework cannot be imported or lacks GPU
            support (raised by `torch_is_available` / `jax_is_available`), or if
            `sync_with` is neither `"torch"` nor `"jax"`.
    """
    if sync_with == "torch" and torch_is_available(ensure_available=True):
        torch = MODULE_CACHE["torch"]
        for device in gpu_devices:
            torch.cuda.synchronize(device)
        return

    if sync_with == "jax" and jax_is_available(ensure_available=True):
        jax = MODULE_CACHE["jax"]
        # Query the device lists once rather than once per requested index; skip
        # the GPU query entirely when no GPU indices were given.
        gpus = jax.devices("gpu") if gpu_devices else []
        cpu = jax.devices("cpu")[0]
        # The `+ 0` presumably forces a computation to be dispatched on each
        # device, so blocking on the results waits for earlier queued work there.
        futures = [jax.device_put(0.0, device=gpus[index]) + 0 for index in gpu_devices]
        futures.append(jax.device_put(0.0, device=cpu) + 0)
        jax.block_until_ready(futures)
        return

    # With ensure_available=True the availability checks above raise instead of
    # returning False, so this is only reached for an unsupported `sync_with`.
    raise RuntimeError(
        f"Unsupported framework for sync_with: {sync_with!r}. "
        "Expected 'torch' or 'jax'."
    )