Skip to content

framework

zeus.utils.framework

Utilities for framework-specific code.

`torch_is_available` (cached)

torch_is_available(ensure_available=False)

Check if PyTorch with CUDA support is available.

Source code in `zeus/utils/framework.py` (lines 29–45):
@lru_cache(maxsize=1)
def torch_is_available(ensure_available: bool = False):
    """Check if PyTorch with CUDA support is available.

    On success, the imported `torch` module is stored in `MODULE_CACHE["torch"]`
    so other helpers (e.g. `cuda_sync`) can use it without re-importing.

    The result is memoized with `lru_cache(maxsize=1)`, so only the most recent
    argument combination is cached; a call that raises is not cached.

    Args:
        ensure_available: If True, raise instead of returning False when PyTorch
            cannot be imported or has no CUDA support.

    Returns:
        True if `torch` imports and `torch.cuda.is_available()` is True;
        False otherwise (only possible when `ensure_available` is False).

    Raises:
        RuntimeError: If `ensure_available` is True and PyTorch cannot be
            imported or does not have CUDA support.
    """
    try:
        import torch
    except ImportError as e:
        logger.info("PyTorch is not available.")
        if ensure_available:
            raise RuntimeError("Failed to import Pytorch") from e
        return False

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would wrongly report (and cache) CUDA support.
    if not torch.cuda.is_available():
        msg = "PyTorch is available but does not have CUDA support."
        logger.info(msg)
        if ensure_available:
            raise RuntimeError(msg)
        return False

    MODULE_CACHE["torch"] = torch
    logger.info("PyTorch with CUDA support is available.")
    return True

`jax_is_available` (cached)

jax_is_available(ensure_available=False)

Check if JAX with GPU (CUDA) support is available.

Source code in `zeus/utils/framework.py` (lines 48–62):
@lru_cache(maxsize=1)
def jax_is_available(ensure_available: bool = False):
    """Check if JAX with GPU (CUDA) support is available.

    On success, the imported `jax` module is stored in `MODULE_CACHE["jax"]`
    so other helpers (e.g. `cuda_sync`) can use it without re-importing.

    The result is memoized with `lru_cache(maxsize=1)`, so only the most recent
    argument combination is cached; a call that raises is not cached.

    Args:
        ensure_available: If True, raise instead of returning False when JAX
            cannot be imported or has no GPU devices.

    Returns:
        True if `jax` imports and `jax.devices("gpu")` is non-empty;
        False otherwise (only possible when `ensure_available` is False).

    Raises:
        RuntimeError: If `ensure_available` is True and JAX cannot be imported
            or does not have GPU support.
    """
    try:
        import jax  # type: ignore
    except ImportError as e:
        logger.info("JAX is not available")
        if ensure_available:
            raise RuntimeError("Failed to import JAX") from e
        return False

    try:
        has_gpu = bool(jax.devices("gpu"))
    except RuntimeError:
        # On CPU-only installs `jax.devices("gpu")` raises RuntimeError (unknown
        # or uninitializable backend) rather than returning an empty list.
        has_gpu = False

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would wrongly report (and cache) GPU support.
    if not has_gpu:
        msg = "JAX is available but does not have CUDA support."
        logger.info(msg)
        if ensure_available:
            raise RuntimeError(msg)
        return False

    MODULE_CACHE["jax"] = jax
    logger.info("JAX with CUDA support is available.")
    return True

cuda_sync

cuda_sync(device=None, backend='torch')

Synchronize CPU with CUDA.

Note: `cupy.cuda.Device.synchronize` may be a good choice to make CUDA device synchronization more general. This has not been tested yet.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `device` | `int \| None` | The device to synchronize. | `None` |
| `backend` | `Literal['torch', 'jax']` | Deep learning framework to use to synchronize GPU computations. Defaults to `"torch"`, in which case `torch.cuda.synchronize` will be used. | `'torch'` |
Source code in `zeus/utils/framework.py` (lines 65–94):
def cuda_sync(
    device: int | None = None, backend: Literal["torch", "jax"] = "torch"
) -> None:
    """Synchronize CPU with CUDA.

    Note: `cupy.cuda.Device.synchronize` may be a good choice to make
          CUDA device synchronization more general. Haven't tested it yet.

    Args:
        device: The device to synchronize.
        backend: Deep learning framework to use to synchronize GPU computations.
            Defaults to `"torch"`, in which case `torch.cuda.synchronize` will be used.
    """
    if backend == "torch" and torch_is_available(ensure_available=True):
        torch = MODULE_CACHE["torch"]

        torch.cuda.synchronize(device)

    elif backend == "jax" and jax_is_available(ensure_available=True):
        jax = MODULE_CACHE["jax"]

        (
            jax.device_put(
                0.0, device=None if device is None else jax.devices("gpu")[device]
            )
            + 0
        ).block_until_ready()

    else:
        raise RuntimeError("No framework is available.")