Skip to content

controller

zeus.controller

Controllers influence the flow or progress of training.

EarlyStopController

Bases: Callback

Controller for early stopping.

Source code in zeus/controller.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class EarlyStopController(Callback):
    """Controller for early stopping.

    Training should be terminated once the `should_training_stop` attribute
    becomes `True`. The flag is raised when any configured criterion is met:
    a maximum number of epochs, a training cost threshold, or a target
    evaluation metric.
    """

    def __init__(
        self,
        monitor: ZeusMonitor | None = None,
        eta_knob: float = 0.5,
        cost_threshold: float | None = None,
        max_epochs: int | None = None,
        target_metric: float | None = None,
        higher_is_better: bool | None = None,
    ) -> None:
        r"""Initialize the controller.

        Check whether training should terminate through the `should_training_stop` attribute.

        - If you gave `max_epochs`, check after `on_epoch_end()`.
        - If you gave `cost_threshold`, check after `on_epoch_end()`.
        - If you gave `target_metric`, check after `on_evaluate()`.

        Args:
            monitor: The monitor instance to use for measuring time and energy.
                Required if `cost_threshold` is given.
            eta_knob: The $0 \le \eta \le 1$ knob for the Zeus time-energy cost.
                (Default: 0.5)
            cost_threshold: Stop training when the expected cost after running
                the next epoch reaches or exceeds this value. Only training
                cost is considered, not validation or testing cost.
            max_epochs: Maximum number of epochs to run.
            target_metric: Stop training when the metric reaches this value.
            higher_is_better: If `True`, `target_metric` is assumed reached when the
                reported metric is larger than or equal to the `target_metric`.
                If `False`, it is assumed reached when the reported metric is
                smaller than or equal to the `target_metric`.
                Required if `target_metric` is given.

        Raises:
            ValueError: If `eta_knob` is outside `[0, 1]`, if `max_epochs` or
                `cost_threshold` is not positive, if only one of
                `cost_threshold` and `monitor` is given, or if only one of
                `target_metric` and `higher_is_better` is given.
        """
        # Sanity check the arguments.
        # The negated chained comparison also rejects NaN.
        if not 0.0 <= eta_knob <= 1.0:
            raise ValueError("eta_knob must be in the range [0, 1]")
        if max_epochs is not None and max_epochs <= 0:
            raise ValueError("max_epochs must be positive")
        if cost_threshold is not None and cost_threshold <= 0:
            raise ValueError("cost_threshold must be positive")
        if (cost_threshold is None) ^ (monitor is None):
            raise ValueError("cost_threshold and monitor must be given together")
        if (target_metric is None) ^ (higher_is_better is None):
            raise ValueError(
                "target_metric and higher_is_better must be given together"
            )

        # Save arguments.
        self.monitor = monitor
        self.eta_knob = eta_knob
        self.cost_threshold = cost_threshold
        self.max_epochs = max_epochs
        self.target_metric = target_metric
        self.higher_is_better = higher_is_better

        # Setup logging.
        self.logger = get_logger(type(self).__name__)

        # Cache the upper power limit of each monitored GPU; it is only needed
        # for computing the cost, so skip it when no cost threshold is set.
        self.max_power: dict[int, int] = {}
        if self.cost_threshold is not None:
            assert self.monitor is not None
            gpus = get_gpus()
            for gpu_index in self.monitor.gpu_indices:
                # NOTE(review): presumably the constraint is (min, max) in
                # milliwatts and `// 1000` converts to watts -- confirm
                # against the `get_gpus()` wrapper.
                self.max_power[gpu_index] = (
                    gpus.getPowerManagementLimitConstraints(gpu_index)[1] // 1000
                )

        # States.
        self.epochs_trained = 0
        self.epoch_costs: list[float] = []

        # Once switched to `True`, there is no switching back to `False`.
        self.should_training_stop = False

    def on_epoch_begin(self) -> None:
        """Start measuring the cost of the next epoch."""
        if self.cost_threshold is not None:
            assert self.monitor is not None
            self.monitor.begin_window("__EarlyStopController_epoch")

    def on_epoch_end(self) -> None:
        """Record the finished epoch and check the epoch and cost criteria.

        When `cost_threshold` is set, the measurement window opened in
        `on_epoch_begin()` is closed and the epoch's cost is appended to
        `epoch_costs`. Then `should_training_stop` is set to `True` if
        `max_epochs` has been reached, or if the expected total cost after
        running one more epoch reaches or exceeds `cost_threshold`.
        """
        # Close the measurement window before any early return. Otherwise the
        # window would stay open (and the last epoch's cost would be lost)
        # when the epoch limit stops training while a cost threshold is set.
        if self.cost_threshold is not None:
            assert self.monitor is not None
            measurement = self.monitor.end_window("__EarlyStopController_epoch")
            cost = sum(
                zeus_cost(
                    energy=measurement.energy[gpu_index],
                    time=measurement.time,
                    eta_knob=self.eta_knob,
                    max_power=self.max_power[gpu_index],
                )
                for gpu_index in self.monitor.gpu_indices
            )
            self.epoch_costs.append(cost)

        # The epoch limit takes precedence over the cost threshold.
        if self.max_epochs is not None:
            self.epochs_trained += 1
            if self.epochs_trained >= self.max_epochs:
                self.logger.info(
                    "[Stop training!] Epochs trained %d >= Max epochs %d",
                    self.epochs_trained,
                    self.max_epochs,
                )
                self.should_training_stop = True
                return

        if self.cost_threshold is not None:
            if (nec := self._expected_next_epoch_cost()) >= self.cost_threshold:
                self.logger.info(
                    "[Stop training!] Expected next epoch cost %f >= Cost threshold %f",
                    nec,
                    self.cost_threshold,
                )
                self.should_training_stop = True
                return

    def on_evaluate(self, metric: float) -> None:
        """Check if the target metric was reached.

        Args:
            metric: The evaluation metric value to compare against
                `target_metric`.
        """
        if self.target_metric is not None:
            assert self.higher_is_better is not None
            # ruff: noqa: SIM108
            if self.higher_is_better:
                reached = metric >= self.target_metric
            else:
                reached = metric <= self.target_metric
            if reached:
                self.logger.info(
                    "[Stop training!] Evaluation metric %f reached target metric %f",
                    metric,
                    self.target_metric,
                )
                self.should_training_stop = True

    def _expected_next_epoch_cost(self) -> float:
        """Predict the total cost if the next training epoch is to be run.

        The prediction is the cost accumulated so far plus the average cost
        of the recorded epochs. At least one epoch cost must have been
        recorded; `on_epoch_end()` guarantees this by appending before calling.
        """
        cost_until_now = sum(self.epoch_costs)
        average_epoch_cost = cost_until_now / len(self.epoch_costs)
        return cost_until_now + average_epoch_cost

__init__

1
2
3
4
5
6
7
8
__init__(
    monitor=None,
    eta_knob=0.5,
    cost_threshold=None,
    max_epochs=None,
    target_metric=None,
    higher_is_better=None,
)

Check whether training should terminate through the should_training_stop attribute:

- If you gave max_epochs, check after on_epoch_end().
- If you gave cost_threshold, check after on_epoch_end().
- If you gave target_metric, check after on_evaluate().

Parameters:

Name Type Description Default
monitor ZeusMonitor | None

The monitor instance to use for measuring time and energy. Required if cost_threshold is given.

None
eta_knob float

The \(0 \le \eta \le 1\) knob for the Zeus time-energy cost. (Default: 0.5)

0.5
cost_threshold float | None

Stop training when the expected cost after running the next epoch reaches or exceeds this value. Only training cost is considered, not validation or testing cost.

None
max_epochs int | None

Maximum number of epochs to run.

None
target_metric float | None

Stop training when the metric reaches this value.

None
higher_is_better bool | None

If True, target_metric is assumed reached when the reported metric is larger than or equal to the target_metric. Required if target_metric is given.

None
Source code in zeus/controller.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def __init__(
    self,
    monitor: ZeusMonitor | None = None,
    eta_knob: float = 0.5,
    cost_threshold: float | None = None,
    max_epochs: int | None = None,
    target_metric: float | None = None,
    higher_is_better: bool | None = None,
) -> None:
    r"""Initialize the controller.

    Check whether training should terminate through the `should_training_stop` attribute.

    - If you gave `max_epochs`, check after `on_epoch_end()`.
    - If you gave `cost_threshold`, check after `on_epoch_end()`.
    - If you gave `target_metric`, check after `on_evaluate()`.

    Args:
        monitor: The monitor instance to use for measuring time and energy.
            Required if `cost_threshold` is given.
        eta_knob: The $0 \le \eta \le 1$ knob for the Zeus time-energy cost.
            (Default: 0.5)
        cost_threshold: Stop training when the expected cost after running
            the next epoch reaches or exceeds this value. Only training
            cost is considered, not validation or testing cost.
        max_epochs: Maximum number of epochs to run.
        target_metric: Stop training when the metric reaches this value.
        higher_is_better: If `True`, `target_metric` is assumed reached when the
            reported metric is larger than or equal to the `target_metric`.
            Required if `target_metric` is given.

    Raises:
        ValueError: If `eta_knob` is outside `[0, 1]`, if `max_epochs` or
            `cost_threshold` is not positive, if only one of
            `cost_threshold` and `monitor` is given, or if only one of
            `target_metric` and `higher_is_better` is given.
    """
    # Sanity check the arguments.
    # The negated chained comparison also rejects NaN.
    if not 0.0 <= eta_knob <= 1.0:
        raise ValueError("eta_knob must be in the range [0, 1]")
    if max_epochs is not None and max_epochs <= 0:
        raise ValueError("max_epochs must be positive")
    if cost_threshold is not None and cost_threshold <= 0:
        raise ValueError("cost_threshold must be positive")
    if (cost_threshold is None) ^ (monitor is None):
        raise ValueError("cost_threshold and monitor must be given together")
    if (target_metric is None) ^ (higher_is_better is None):
        raise ValueError(
            "target_metric and higher_is_better must be given together"
        )

    # Save arguments.
    self.monitor = monitor
    self.eta_knob = eta_knob
    self.cost_threshold = cost_threshold
    self.max_epochs = max_epochs
    self.target_metric = target_metric
    self.higher_is_better = higher_is_better

    # Setup logging.
    self.logger = get_logger(type(self).__name__)

    # Cache the upper power limit of each monitored GPU; it is only needed
    # for computing the cost, so skip it when no cost threshold is set.
    self.max_power: dict[int, int] = {}
    if self.cost_threshold is not None:
        assert self.monitor is not None
        gpus = get_gpus()
        for gpu_index in self.monitor.gpu_indices:
            # NOTE(review): presumably the constraint is (min, max) in
            # milliwatts and `// 1000` converts to watts -- confirm
            # against the `get_gpus()` wrapper.
            self.max_power[gpu_index] = (
                gpus.getPowerManagementLimitConstraints(gpu_index)[1] // 1000
            )

    # States.
    self.epochs_trained = 0
    self.epoch_costs: list[float] = []

    # Once switched to `True`, there is no switching back to `False`.
    self.should_training_stop = False

on_epoch_begin

1
on_epoch_begin()

Start measuring the cost of the next epoch.

Source code in zeus/controller.py
 98
 99
100
101
102
def on_epoch_begin(self) -> None:
    """Open the measurement window that covers the upcoming epoch.

    Nothing is measured unless a `cost_threshold` was configured.
    """
    if self.cost_threshold is None:
        return

    epoch_monitor = self.monitor
    assert epoch_monitor is not None
    epoch_monitor.begin_window("__EarlyStopController_epoch")

on_epoch_end

1
on_epoch_end()

Check if the training cost of the next epoch will exceed the threshold.

Source code in zeus/controller.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def on_epoch_end(self) -> None:
    """Record the finished epoch and check the epoch and cost criteria.

    When `cost_threshold` is set, the measurement window opened in
    `on_epoch_begin()` is closed and the epoch's cost is appended to
    `epoch_costs`. Then `should_training_stop` is set to `True` if
    `max_epochs` has been reached, or if the expected total cost after
    running one more epoch reaches or exceeds `cost_threshold`.
    """
    # Close the measurement window before any early return. Otherwise the
    # window would stay open (and the last epoch's cost would be lost)
    # when the epoch limit stops training while a cost threshold is set.
    if self.cost_threshold is not None:
        assert self.monitor is not None
        measurement = self.monitor.end_window("__EarlyStopController_epoch")
        cost = sum(
            zeus_cost(
                energy=measurement.energy[gpu_index],
                time=measurement.time,
                eta_knob=self.eta_knob,
                max_power=self.max_power[gpu_index],
            )
            for gpu_index in self.monitor.gpu_indices
        )
        self.epoch_costs.append(cost)

    # The epoch limit takes precedence over the cost threshold.
    if self.max_epochs is not None:
        self.epochs_trained += 1
        if self.epochs_trained >= self.max_epochs:
            self.logger.info(
                "[Stop training!] Epochs trained %d >= Max epochs %d",
                self.epochs_trained,
                self.max_epochs,
            )
            self.should_training_stop = True
            return

    if self.cost_threshold is not None:
        if (nec := self._expected_next_epoch_cost()) >= self.cost_threshold:
            self.logger.info(
                "[Stop training!] Expected next epoch cost %f >= Cost threshold %f",
                nec,
                self.cost_threshold,
            )
            self.should_training_stop = True
            return

on_evaluate

1
on_evaluate(metric)

Check if the target metric was reached.

Source code in zeus/controller.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def on_evaluate(self, metric: float) -> None:
    """Flag training to stop once `metric` has reached `target_metric`.

    Does nothing when no `target_metric` was configured. With
    `higher_is_better` set, the target counts as reached when `metric` is at
    least `target_metric`; otherwise when `metric` is at most `target_metric`.

    Args:
        metric: The evaluation metric value to compare against the target.
    """
    goal = self.target_metric
    if goal is None:
        return

    assert self.higher_is_better is not None
    hit = metric >= goal if self.higher_is_better else metric <= goal
    if not hit:
        return

    self.logger.info(
        "[Stop training!] Evaluation metric %f reached target metric %f",
        metric,
        goal,
    )
    self.should_training_stop = True

_expected_next_epoch_cost

1
_expected_next_epoch_cost()

Predict the total cost if the next training epoch is to be run.

Source code in zeus/controller.py
156
157
158
159
160
def _expected_next_epoch_cost(self) -> float:
    """Estimate the cumulative cost after running one more training epoch.

    The estimate is the sum of the recorded epoch costs plus their mean.
    """
    recorded = self.epoch_costs
    spent = sum(recorded)
    return spent + spent / len(recorded)