Skip to content

power

zeus.monitor.power

Monitor the power usage of GPUs.

PowerMonitor

Monitor power usage from GPUs.

This class acts as a lower level wrapper around a Python process that polls the power consumption of GPUs. This is primarily used by ZeusMonitor for older architecture GPUs that do not support the nvmlDeviceGetTotalEnergyConsumption API.

Attributes:

Name Type Description
gpu_indices list[int]

Indices of the GPUs to monitor.

update_period float

Update period of the power monitor in seconds. Holds inferred update period if update_period was given as None.

Source code in zeus/monitor/power.py (lines 101–241)
class PowerMonitor:
    """Monitor power usage from GPUs.

    This class acts as a lower level wrapper around a Python process that polls
    the power consumption of GPUs. This is primarily used by
    [`ZeusMonitor`][zeus.monitor.ZeusMonitor] for older architecture GPUs that
    do not support the nvmlDeviceGetTotalEnergyConsumption API.

    Attributes:
        gpu_indices (list[int]): Indices of the GPUs to monitor.
        update_period (int): Update period of the power monitor in seconds.
            Holds inferred update period if `update_period` was given as `None`.
    """

    def __init__(
        self,
        gpu_indices: list[int] | None = None,
        update_period: float | None = None,
    ) -> None:
        """Initialize the power monitor.

        Initialization should not be done in global scope due to python's protection.
        Refer to the "Safe importing of main module" section in https://docs.python.org/3/library/multiprocessing.html for more detail.

        Args:
            gpu_indices: Indices of the GPUs to monitor. If None, monitor all GPUs.
            update_period: Update period of the power monitor in seconds. If None,
                infer the update period by max speed polling the power counter for
                each GPU model.
        """
        if gpu_indices is not None and not gpu_indices:
            raise ValueError("`gpu_indices` must be either `None` or non-empty")

        # Get GPUs
        gpus = get_gpus()

        # Set up logging.
        self.logger = get_logger(type(self).__name__)

        # Get GPUs
        self.gpu_indices = (
            gpu_indices if gpu_indices is not None else list(range(len(gpus)))
        )
        self.logger.info("Monitoring power usage of GPUs %s", self.gpu_indices)

        # Infer the update period if necessary.
        if update_period is None:
            update_period = infer_counter_update_period(self.gpu_indices)
        self.update_period = update_period

        # Create the CSV file for power measurements.
        power_csv = tempfile.mkstemp(suffix=".csv", text=True)[1]
        open(power_csv, "w").close()
        self.power_f = open(power_csv)
        self.power_df_columns = ["time"] + [f"power{i}" for i in self.gpu_indices]
        self.power_df = pd.DataFrame(columns=self.power_df_columns)

        # Spawn the power polling process.
        atexit.register(self._stop)
        self.process = mp.get_context("spawn").Process(
            target=_polling_process, args=(self.gpu_indices, power_csv, update_period)
        )
        self.process.start()

    def _stop(self) -> None:
        """Stop monitoring power usage."""
        if self.process is not None:
            self.process.terminate()
            self.process.join(timeout=1.0)
            self.process.kill()
            self.process = None

    def _update_df(self) -> None:
        """Add rows to the power dataframe from the CSV file."""
        try:
            additional_df = typing.cast(
                pd.DataFrame,
                pd.read_csv(self.power_f, header=None, names=self.power_df_columns),  # type: ignore
            )
        except pd.errors.EmptyDataError:
            return
        self.power_df = pd.concat([self.power_df, additional_df], axis=0)

    def get_energy(self, start_time: float, end_time: float) -> dict[int, float] | None:
        """Get the energy used by the GPUs between two times.

        Args:
            start_time: Start time of the interval, from time.time().
            end_time: End time of the interval, from time.time().

        Returns:
            A dictionary mapping GPU indices to the energy used by the GPU between the
            two times. GPU indices are from the DL framework's perspective after
            applying `CUDA_VISIBLE_DEVICES`.
            If there are no power readings, return None.
        """
        self._update_df()

        if self.power_df.empty:
            return None

        df = typing.cast(
            pd.DataFrame, self.power_df.query(f"{start_time} <= time <= {end_time}")
        )

        try:
            return {
                i: float(auc(df["time"], df[f"power{i}"])) for i in self.gpu_indices
            }
        except ValueError:
            return None

    def get_power(self, time: float | None = None) -> dict[int, float] | None:
        """Get the power usage of the GPUs at a specific time point.

        Args:
            time: Time point to get the power usage at. If None, get the power usage
                at the last recorded time point.

        Returns:
            A dictionary mapping GPU indices to the power usage of the GPU at the
            specified time point. GPU indices are from the DL framework's perspective
            after applying `CUDA_VISIBLE_DEVICES`.
            If there are no power readings, return None.
        """
        self._update_df()

        if self.power_df.empty:
            return None

        if time is None:
            row = self.power_df.iloc[-1]
        else:
            ind = self.power_df.time.searchsorted(time)
            try:
                row = self.power_df.iloc[ind]
            except IndexError:
                # This means that the time is after the last recorded power reading.
                row = self.power_df.iloc[-1]

        return {i: float(row[f"power{i}"]) for i in self.gpu_indices}

__init__

__init__(gpu_indices=None, update_period=None)

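Initialization should not be done at module (global) scope: the monitor starts its polling process with the `spawn` start method, which re-imports the main module in the child process. Refer to the "Safe importing of main module" section in https://docs.python.org/3/library/multiprocessing.html for more detail.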

Parameters:

Name Type Description Default
gpu_indices list[int] | None

Indices of the GPUs to monitor. If None, monitor all GPUs.

None
update_period float | None

Update period of the power monitor in seconds. If None, infer the update period by max speed polling the power counter for each GPU model.

None
Source code in zeus/monitor/power.py (lines 115–163)
def __init__(
    self,
    gpu_indices: list[int] | None = None,
    update_period: float | None = None,
) -> None:
    """Initialize the power monitor and start the polling process.

    Do not construct this at module (global) scope: the polling process is
    started with the `spawn` start method, which re-imports the main module in
    the child process. Refer to the "Safe importing of main module" section in
    https://docs.python.org/3/library/multiprocessing.html for more detail.

    Args:
        gpu_indices: Indices of the GPUs to monitor. If None, monitor all GPUs.
        update_period: Update period of the power monitor in seconds. If None,
            infer the update period by max speed polling the power counter for
            each GPU model.

    Raises:
        ValueError: If `gpu_indices` is an empty list.
    """
    import os  # Only needed to close the descriptor returned by mkstemp.

    if gpu_indices is not None and not gpu_indices:
        raise ValueError("`gpu_indices` must be either `None` or non-empty")

    gpus = get_gpus()

    # Set up logging.
    self.logger = get_logger(type(self).__name__)

    # Default to every GPU reported by `get_gpus()`.
    self.gpu_indices = (
        gpu_indices if gpu_indices is not None else list(range(len(gpus)))
    )
    self.logger.info("Monitoring power usage of GPUs %s", self.gpu_indices)

    # Infer the update period if necessary.
    if update_period is None:
        update_period = infer_counter_update_period(self.gpu_indices)
    self.update_period = update_period

    # Create the (empty) CSV file the polling process writes to. `mkstemp` also
    # hands back an open OS-level descriptor that is never used here, so close
    # it rather than leak it.
    fd, power_csv = tempfile.mkstemp(suffix=".csv", text=True)
    os.close(fd)
    # Reader handle; each `_update_df` call continues from where the last read
    # stopped.
    self.power_f = open(power_csv)
    self.power_df_columns = ["time"] + [f"power{i}" for i in self.gpu_indices]
    self.power_df = pd.DataFrame(columns=self.power_df_columns)

    # Spawn the power polling process. The exit hook is registered only after the
    # process has started, so `_stop` never runs against a monitor whose
    # `process` attribute was never set.
    self.process = mp.get_context("spawn").Process(
        target=_polling_process, args=(self.gpu_indices, power_csv, update_period)
    )
    self.process.start()
    atexit.register(self._stop)

_stop

_stop()

Stop monitoring power usage.

Source code in zeus/monitor/power.py (lines 165–171)
def _stop(self) -> None:
    """Stop monitoring power usage."""
    if self.process is not None:
        self.process.terminate()
        self.process.join(timeout=1.0)
        self.process.kill()
        self.process = None

_update_df

_update_df()

Add rows to the power dataframe from the CSV file.

Source code in zeus/monitor/power.py (lines 173–182)
def _update_df(self) -> None:
    """Add rows to the power dataframe from the CSV file."""
    try:
        additional_df = typing.cast(
            pd.DataFrame,
            pd.read_csv(self.power_f, header=None, names=self.power_df_columns),  # type: ignore
        )
    except pd.errors.EmptyDataError:
        return
    self.power_df = pd.concat([self.power_df, additional_df], axis=0)

get_energy

get_energy(start_time, end_time)

Get the energy used by the GPUs between two times.

Parameters:

Name Type Description Default
start_time float

Start time of the interval, from time.time().

required
end_time float

End time of the interval, from time.time().

required

Returns:

Type Description
dict[int, float] | None

A dictionary mapping GPU indices to the energy used by the GPU between the two times. GPU indices are from the DL framework's perspective after applying CUDA_VISIBLE_DEVICES. If there are no power readings, return None.

Source code in zeus/monitor/power.py (lines 184–211)
def get_energy(self, start_time: float, end_time: float) -> dict[int, float] | None:
    """Get the energy used by the GPUs between two times.

    Args:
        start_time: Start time of the interval, from time.time().
        end_time: End time of the interval, from time.time().

    Returns:
        A dictionary mapping GPU indices to the energy used by the GPU between the
        two times. GPU indices are from the DL framework's perspective after
        applying `CUDA_VISIBLE_DEVICES`.
        If there are no power readings, or the area under the power curve cannot
        be computed for the interval (`auc` raises ValueError, e.g. when too few
        readings fall inside it), return None.
    """
    self._update_df()

    if self.power_df.empty:
        return None

    # Select readings with start_time <= time <= end_time using a boolean mask
    # rather than a query string built from the float bounds; the mask also
    # accepts non-finite bounds such as inf.
    df = self.power_df
    window = df.loc[df["time"].between(start_time, end_time)]

    try:
        return {
            i: float(auc(window["time"], window[f"power{i}"]))
            for i in self.gpu_indices
        }
    except ValueError:
        return None

get_power

get_power(time=None)

Get the power usage of the GPUs at a specific time point.

Parameters:

Name Type Description Default
time float | None

Time point to get the power usage at. If None, get the power usage at the last recorded time point.

None

Returns:

Type Description
dict[int, float] | None

A dictionary mapping GPU indices to the power usage of the GPU at the specified time point. GPU indices are from the DL framework's perspective after applying CUDA_VISIBLE_DEVICES. If there are no power readings, return None.

Source code in zeus/monitor/power.py (lines 213–241)
def get_power(self, time: float | None = None) -> dict[int, float] | None:
    """Get the power usage of the GPUs at a specific time point.

    Args:
        time: Time point to get the power usage at. If None, get the power usage
            at the last recorded time point.

    Returns:
        A dictionary mapping GPU indices to the power usage of the GPU at the
        specified time point. GPU indices are from the DL framework's perspective
        after applying `CUDA_VISIBLE_DEVICES`.
        If there are no power readings, return None.
    """
    self._update_df()

    if self.power_df.empty:
        return None

    if time is None:
        row = self.power_df.iloc[-1]
    else:
        ind = self.power_df.time.searchsorted(time)
        try:
            row = self.power_df.iloc[ind]
        except IndexError:
            # This means that the time is after the last recorded power reading.
            row = self.power_df.iloc[-1]

    return {i: float(row[f"power{i}"]) for i in self.gpu_indices}

infer_counter_update_period

infer_counter_update_period(gpu_indicies)

Infer the update period of the NVML power counter.

NVML counters can update as slowly as 10 Hz depending on the GPU model, so there's no need to poll them much faster than that. This function infers the update period for each unique GPU model and selects the fastest-updating (shortest) period detected. Then, it returns half the period to ensure that the counter is polled at least twice per update period.

Source code in zeus/monitor/power.py (lines 32–74)
def infer_counter_update_period(gpu_indicies: list[int]) -> float:
    """Infer the update period of the NVML power counter.

    NVML counters can update as slowly as 10 Hz depending on the GPU model, so
    there's no need to poll them much faster than that. This function infers the
    update period for each unique GPU model and selects the fastest-updating
    (shortest) period detected. Then, it returns half the period to ensure that
    the counter is polled at least twice per update period.

    Args:
        gpu_indicies: Indices of the GPUs to probe. GPUs that report the same
            model name are probed only once.

    Returns:
        The polling period in seconds: half the shortest detected counter update
        period, capped at 0.1 s. If no GPU is probed, 0.1 s.
    """
    logger = get_logger(__name__)

    gpus = get_gpus()

    # Probe each unique GPU model once and collect its counter update period.
    detected_periods: list[float] = []
    gpu_models_covered = set()
    for index in gpu_indicies:
        model = gpus.getName(index)
        if model in gpu_models_covered:
            continue
        gpu_models_covered.add(model)
        logger.info(
            "Detected %s, inferring NVML power counter update period.", model
        )
        detected_period = _infer_counter_update_period_single(index)
        logger.info(
            "Counter update period for %s is %.2f s",
            model,
            detected_period,
        )
        detected_periods.append(detected_period)

    if not detected_periods:
        logger.warning(
            "No GPUs to probe for the NVML power counter update period. Using 0.1 s."
        )
        return 0.1

    # Follow the fastest-updating counter, and poll at half its period so that
    # every counter update is observed at least once.
    update_period = min(detected_periods) / 2.0

    # Anything less than ten times a second is probably too slow.
    if update_period > 0.1:
        logger.warning(
            "Inferred update period (%.2f s) is too long. Using 0.1 s instead.",
            update_period,
        )
        update_period = 0.1
    return update_period

_infer_counter_update_period_single

_infer_counter_update_period_single(gpu_index)

Infer the update period of the NVML power counter for a single GPU.

Source code in zeus/monitor/power.py (lines 77–98)
def _infer_counter_update_period_single(gpu_index: int) -> float:
    """Infer the update period of the NVML power counter for a single GPU.

    The counter is read 1000 times back to back, and the period is estimated as
    the shortest gap between two consecutive changes of the reading.

    Args:
        gpu_index: Index of the GPU to probe.

    Returns:
        The estimated counter update period in seconds. If fewer than two changes
        were observed, the length of the sampling window is returned instead.
    """
    gpus = get_gpus()

    # Collect 1000 timestamped samples of the power counter as fast as possible.
    # The timestamp is taken immediately before each reading.
    samples: list[tuple[float, int]] = [
        (time(), gpus.getPowerUsage(gpu_index)) for _ in range(1000)
    ]

    # Find the timestamps when the power readings changed.
    changed_times: list[float] = []
    prev_power = samples[0][1]
    for t, p in samples:
        if p != prev_power:
            changed_times.append(t)
            prev_power = p

    if len(changed_times) < 2:
        # The counter did not change often enough within the window to measure a
        # gap between two updates (e.g. a slow counter sampled over a short window,
        # or a constant reading). `min()` over the gaps would raise on an empty
        # sequence; report the window length as a best-effort estimate instead.
        # The caller still caps the final polling period.
        return samples[-1][0] - samples[0][0]

    # Compute the minimum time difference between power change timestamps.
    return min(later - earlier for earlier, later in zip(changed_times, changed_times[1:]))

_polling_process

_polling_process(gpu_indices, power_csv, update_period)

Run the power monitor.

Source code in zeus/monitor/power.py (lines 244–266)
def _polling_process(
    gpu_indices: list[int],
    power_csv: str,
    update_period: float,
) -> None:
    """Poll GPU power in a loop, appending one CSV row per sample.

    Each row holds the sample timestamp followed by, for each entry of
    `gpu_indices` in order, the `getPowerUsage` reading divided by 1000. The
    loop returns quietly on KeyboardInterrupt; otherwise it runs until the
    process is terminated.

    Args:
        gpu_indices: GPUs to sample, in column order.
        power_csv: Path of the CSV file to write (truncated on open).
        update_period: Target time between the starts of consecutive samples,
            in seconds.
    """
    try:
        gpus = get_gpus()

        # Line buffering: every completed row is flushed for the reader.
        with open(power_csv, "w", buffering=1) as csv_file:
            while True:
                sample_start = time()
                readings = [gpus.getPowerUsage(gpu) for gpu in gpu_indices]
                row_tail = ",".join(str(reading / 1000) for reading in readings)
                csv_file.write(f"{sample_start},{row_tail}\n")

                # Sleep off whatever remains of this period, if anything.
                remaining = update_period - (time() - sample_start)
                if remaining > 0:
                    sleep(remaining)
    except KeyboardInterrupt:
        return