framework
zeus.utils.framework
Utilities for framework-specific code.
torch_is_available
cached
torch_is_available(ensure_available=False, ensure_cuda=True)
Check if PyTorch is available.
Source code in zeus/utils/framework.py
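This page gives only a one-line summary. The cached label indicates that the function's result is memoized. The sketch below shows one way a check like this could combine memoization with the two flags. It assumes that ensure_available=True turns a failed import into an error and that ensure_cuda=True rejects a PyTorch build without CUDA support. The names module_is_available and torch_like_is_available are hypothetical, not Zeus APIs, and this is not the library's implementation.

```python
import importlib
from functools import lru_cache


@lru_cache(maxsize=None)
def module_is_available(name: str, ensure_available: bool = False) -> bool:
    """Hypothetical helper (not part of Zeus): can `name` be imported?

    lru_cache memoizes the result per argument combination, so the import is
    attempted at most once per distinct call. Raised exceptions are not cached.
    """
    try:
        importlib.import_module(name)
    except ImportError as e:
        if ensure_available:
            # Assumed semantics: the caller requires the framework, so fail loudly.
            raise RuntimeError(f"Failed to import {name!r}") from e
        return False
    return True


@lru_cache(maxsize=None)
def torch_like_is_available(
    ensure_available: bool = False, ensure_cuda: bool = True
) -> bool:
    """Hypothetical sketch mirroring the flags of torch_is_available."""
    if not module_is_available("torch", ensure_available):
        return False
    import torch  # importable, as checked above

    if ensure_cuda and not torch.cuda.is_available():
        # Assumed semantics: a CPU-only PyTorch build is an error when CUDA is required.
        raise RuntimeError("PyTorch is available but has no CUDA support.")
    return True
```

Because of the cache, repeated calls with the same arguments return immediately instead of re-attempting the import.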
jax_is_available
cached
jax_is_available(ensure_available=False, ensure_cuda=True)
Check if JAX is available.
Source code in zeus/utils/framework.py
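For JAX, one way to detect GPU support is to inspect the default backend. jax.default_backend() reports the platform name of the default backend, for example "cpu", "gpu", or "tpu". The minimal sketch below assumes that ensure_available=True makes a missing JAX installation an error and that ensure_cuda=True requires a GPU backend. The name jax_like_is_available is hypothetical, and relying on jax.default_backend() is our choice, not necessarily what Zeus does.

```python
import importlib.util
from functools import lru_cache


@lru_cache(maxsize=None)
def jax_like_is_available(
    ensure_available: bool = False, ensure_cuda: bool = True
) -> bool:
    """Hypothetical sketch mirroring the flags of jax_is_available."""
    if importlib.util.find_spec("jax") is None:
        if ensure_available:
            # Assumed semantics: the caller requires JAX, so fail loudly.
            raise RuntimeError("JAX is not installed.")
        return False
    import jax  # imported lazily so this module loads without JAX installed

    # A usable GPU backend takes priority over the CPU backend, so the default
    # backend reads "gpu" when JAX can see a GPU.
    if ensure_cuda and jax.default_backend() != "gpu":
        raise RuntimeError("JAX is available but has no GPU backend.")
    return True
```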
sync_execution
sync_execution(gpu_devices, sync_with='torch')
Block until all computations on the specified devices are finished.
PyTorch runs only GPU computations asynchronously, so computations on the given GPU devices are synchronized by calling torch.cuda.synchronize on each device. JAX, on the other hand, runs both CPU and GPU computations asynchronously, but by default it exposes only a single CPU device (id=0). Therefore, for JAX, all GPU devices passed in and the CPU device (id=0) are synchronized.
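On the PyTorch side, this strategy amounts to one torch.cuda.synchronize call per requested device index. torch.cuda.synchronize waits for all kernels in all streams on the given CUDA device to complete. The following is a minimal sketch. The name torch_sync_sketch is hypothetical, and the early return for an empty list, which avoids importing PyTorch when there is nothing to wait for, is our addition.

```python
def torch_sync_sketch(gpu_devices: list[int]) -> None:
    """Hypothetical sketch of the PyTorch path: wait for queued CUDA work on each GPU."""
    if not gpu_devices:
        return  # nothing to wait for, so PyTorch is never imported
    import torch

    for index in gpu_devices:
        # Blocks the host until every kernel queued on this device has finished.
        torch.cuda.synchronize(index)
```

For example, torch_sync_sketch([0, 1]) would block until all work queued on GPUs 0 and 1 has finished.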
Note
jax.device_put with block_until_ready is used to synchronize computations on JAX devices. This is a workaround for the lack of a direct JAX API for synchronizing computations on a device. Tracking issue: https://github.com/google/jax/issues/4335
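The JAX workaround can be sketched as follows. First, place a scalar on every target device, meaning the requested GPUs plus CPU device 0. Then block until each result is ready. The name jax_sync_sketch is hypothetical. The trailing + 0 enqueues a small computation rather than only a transfer. It is our assumption about how to make the marker wait behind work already in flight, and it relies on the further assumption that each device executes its work in dispatch order.

```python
def jax_sync_sketch(gpu_devices: list[int]) -> None:
    """Hypothetical sketch of the JAX path: GPUs by index plus CPU device 0."""
    import jax

    devices = [jax.devices("gpu")[i] for i in gpu_devices]
    devices.append(jax.devices("cpu")[0])
    # Assumption: a tiny computation queued on each device completes only after
    # earlier work on that device, so waiting on it approximates a device sync.
    pending = [jax.device_put(0.0, d) + 0 for d in devices]
    for x in pending:
        x.block_until_ready()
```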
Note
Across the Zeus library, an integer device index corresponds to a single whole physical device. This is usually what you want, except when using more advanced device partitioning (e.g., using --xla_force_host_platform_device_count in JAX to partition CPUs into more pieces). In such cases, you probably want to opt out of using this function and handle synchronization manually at the appropriate granularity.
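When the host CPU is split into several XLA devices, synchronizing only CPU device 0 no longer covers all CPU work. The minimal sketch below synchronizes manually at per-partition granularity. It assumes the --xla_force_host_platform_device_count flag is passed through the XLA_FLAGS environment variable before JAX is first imported. The device count of 4 and the name sync_all_cpu_partitions are illustrative, and the + 0 marker computation is our choice.

```python
import os

# Illustrative setting: split the host CPU into 4 XLA devices. It only takes
# effect if set before JAX is first imported in the process; setdefault leaves
# any existing XLA_FLAGS value untouched.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")


def sync_all_cpu_partitions() -> int:
    """Hypothetical helper: wait on every CPU partition, not just device 0.

    Returns the number of CPU devices that were synchronized.
    """
    import jax

    cpu_devices = jax.devices("cpu")
    # Enqueue a tiny computation on each partition, then wait for all of them.
    pending = [jax.device_put(0.0, d) + 0 for d in cpu_devices]
    for x in pending:
        x.block_until_ready()
    return len(cpu_devices)
```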
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gpu_devices | list[int] | GPU device indices to synchronize. | required |
sync_with | Literal['torch', 'jax'] | Deep learning framework to use to synchronize computations. | 'torch' |
Source code in zeus/utils/framework.py
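Timing a region of code is one case where blocking on devices matters. With asynchronous dispatch, a call can return before its device work is done, so the region should be bracketed by synchronization before the clock is read. The following is a minimal sketch. The helper timed is hypothetical, and it takes the synchronization step as a callable so it can be exercised without a GPU.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def timed(fn: Callable[[], T], sync: Callable[[], None]) -> tuple[T, float]:
    """Hypothetical helper: run fn and return (result, elapsed seconds).

    Without the trailing sync, asynchronously launched device work could still
    be running when the clock stops, so the elapsed time would be too small.
    """
    sync()  # drain work queued before the measured region
    start = time.perf_counter()
    result = fn()
    sync()  # wait for work launched inside the region
    return result, time.perf_counter() - start
```

With Zeus and a GPU, the sync step could be lambda: sync_execution([0], sync_with="torch"), with sync_execution imported from zeus.utils.framework. The workload passed as fn would be any zero-argument callable, such as a hypothetical train_step.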