optimizer

zeus.optimizer.perseus.optimizer

Perseus optimizer implementation.

The PerseusOptimizer is to be integrated into the user-side training framework. It communicates with the Perseus server and manages the FrequencyController instance, which controls the frequency of the GPU of the current process.
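
For orientation, here is a minimal sketch of how the callback might be driven from a training loop. It is not the canonical integration: the server address, the flat data-parallel layout, and the model/dataloader objects are assumptions for illustration; only the PerseusOptimizer calls correspond to this module.

import torch
import torch.distributed as dist

from zeus.optimizer.perseus.optimizer import PerseusOptimizer

# Assumes a single-node, torchrun-style launch.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)

perseus = PerseusOptimizer(
    rank=rank,
    dp_rank=rank,          # pure data parallelism in this sketch
    pp_rank=0,
    tp_rank=0,
    device_id=rank,
    dp_degree=world_size,
    pp_degree=1,
    tp_degree=1,
    world_size=world_size,
    server_url="http://localhost:7787",  # hypothetical server address
)

for batch in dataloader:  # dataloader and model are assumed to exist
    perseus.on_step_begin()
    perseus.on_instruction_begin("forward")
    loss = model(batch)
    perseus.on_instruction_end("forward")
    perseus.on_instruction_begin("backward")
    loss.backward()
    perseus.on_instruction_end("backward")
    perseus.on_step_end()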

PerseusOptimizer

Bases: Callback

Perseus optimizer.

Source code in zeus/optimizer/perseus/optimizer.py
class PerseusOptimizer(Callback):
    """Perseus optimizer."""

    def __init__(
        self,
        rank: int,
        dp_rank: int,
        pp_rank: int,
        tp_rank: int,
        device_id: int,
        dp_degree: int,
        pp_degree: int,
        tp_degree: int,
        world_size: int,
        server_url: str,
        job_metadata: str | None = None,
    ) -> None:
        """Initialize the Perseus optimizer.

        Assumptions:
            - `torch.distributed` has been initialized.
            - `torch.cuda.set_device` has been called with `device_id`.
                This is needed to broadcast the job ID to all ranks.

        The master process (rank 0) will register the job with the Perseus
        server and retrieve the job ID of this job. Then, each rank will
        report itself to the Perseus server with the job ID.

        Args:
            rank: Global rank of the current process.
            dp_rank: Rank in the data parallel group.
            pp_rank: Rank in the pipeline parallel group.
            tp_rank: Rank in the tensor parallel group.
            device_id: CUDA device ID that the current process manages.
            dp_degree: Size of the data parallel group.
            pp_degree: Size of the pipeline parallel group.
            tp_degree: Size of the tensor parallel group.
            world_size: Total number of ranks that participate in training.
            server_url: URL of the Perseus server.
            job_metadata: An optional arbitrary string that describes the job. This will
                be appended to the job ID if given. Typically for logging purposes.
        """
        if not dist.is_initialized():
            raise RuntimeError(
                "Instantiate `PerseusOptimizer` after `init_process_group`."
            )

        self.server_url = server_url
        self.rank = rank
        self.dp_rank = dp_rank
        self.pp_rank = pp_rank
        self.tp_rank = tp_rank
        # Remember the device ID so on_instruction_begin can synchronize the right GPU.
        self.device_id = device_id

        gpus = get_gpus()
        torch.cuda.set_device(device_id)

        # Rank 0 registers the job with the Perseus server and retrieves the job ID.
        job_id = None
        if rank == 0:
            job_info = JobInfo(
                pp_degree=pp_degree,
                dp_degree=dp_degree,
                tp_degree=tp_degree,
                world_size=world_size,
                job_metadata=job_metadata,
            )
            response = httpx.post(
                self.server_url + REGISTER_JOB_URL, json=job_info.dict()
            )
            if (code := response.status_code) != 200:
                raise RuntimeError(
                    f"Perseus server returned status code {code}: {response.text}"
                )
            job_id = response.json()
            if not isinstance(job_id, str):
                raise RuntimeError(
                    f"Perseus server returned a strange job ID: {job_id=}"
                )

        # Rank 0 broadcasts the job ID across all ranks.
        objects = [job_id]
        dist.broadcast_object_list(objects, src=0)
        self.job_id = objects[0]
        if self.job_id is None:
            raise RuntimeError("Failed to broadcast job ID to all ranks")

        # Query the list of available frequencies of the GPU.
        max_mem_freq = max(gpus.getSupportedMemoryClocks(device_id))
        freqs = sorted(
            gpus.getSupportedGraphicsClocks(device_id, max_mem_freq),
            reverse=True,
        )

        # Each rank reports itself to the Perseus server with the job ID.
        rank_info = RankInfo(
            rank=self.rank,
            dp_rank=self.dp_rank,
            pp_rank=self.pp_rank,
            tp_rank=self.tp_rank,
            available_frequencies=freqs,
        )
        response = httpx.post(
            self.server_url + REGISTER_RANK_URL.format(job_id=self.job_id),
            json=rank_info.dict(),
        )
        if (code := response.status_code) != 200:
            raise RuntimeError(
                f"Perseus server returned status code {code}: {response.text}"
            )

        # The frequency controller is responsible for controlling the frequency
        # of the GPU (device_id) asynchronously.
        self.frequency_controller = FrequencyController(device_id=device_id)

        # Fetch the frequency schedule from the Perseus server.
        self.freq_schedule = self._get_frequency_schedule()
        self.freq_schedule_iter = iter(self.freq_schedule)

    def _get_frequency_schedule(self) -> list[tuple[str, int]]:
        """Get the frequency schedule from the Perseus server."""
        response = httpx.get(
            self.server_url + GET_FREQUENCY_SCHEDULE_URL.format(job_id=self.job_id),
            params={"rank": self.rank},
            timeout=None,
        )
        if (code := response.status_code) != 200:
            raise RuntimeError(
                f"Perseus server returned status code {code}: {response.text}"
            )
        schedule = FrequencySchedule.parse_raw(response.text)
        if schedule.rank != self.rank:
            raise RuntimeError(
                f"Perseus server returned a schedule for rank {schedule.rank} to rank {self.rank}"
            )
        return schedule.frequencies

    def on_step_begin(self) -> None:
        """Mark the beginning of a step.

        TODO(jaywonchung): InstructionProfiler iteration start mark.
        """
        pass

    def on_step_end(self) -> None:
        """Mark the end of a step.

        TODO(jaywonchung): InstructionProfiler iteration end mark.
        Also report the profiling result to the Perseus server after N iterations.
        """
        # Frequency schedule holds one iteration's worth of frequencies, so at
        # the end of each iteration, the iterator should be exhausted.
        item = next(self.freq_schedule_iter, None)
        if item is not None:
            raise RuntimeError(
                "Perseus server returned more frequencies than expected. "
                f"Next expected instruction and frequency is {item}"
            )
        self.freq_schedule_iter = iter(self.freq_schedule)

    def on_instruction_begin(self, name: str) -> None:
        """Mark the beginning of an instruction, like forward and backward.

        Retrieve the next frequency from the schedule, check whether the next
        expected instruction matches the name of the instruction, and set the
        frequency accordingly.
        """
        cuda_sync(self.device_id)

        # Retrieve the next frequency from the schedule.
        item = next(self.freq_schedule_iter, None)
        if item is None:
            raise RuntimeError(
                "Perseus server returned fewer frequencies than expected"
            )

        # Check whether the next expected instruction matches the name of the instruction.
        instruction, frequency = item
        if instruction != name:
            raise RuntimeError(
                f"The next expected instruction is not forward: {instruction}"
            )

        self.frequency_controller.set_frequency(frequency)

    def on_instruction_end(self, _: str) -> None:
        """Mark the end of an instruction, like forward and backward."""

__init__

__init__(
    rank,
    dp_rank,
    pp_rank,
    tp_rank,
    device_id,
    dp_degree,
    pp_degree,
    tp_degree,
    world_size,
    server_url,
    job_metadata=None,
)
Assumptions
  • torch.distributed has been initialized.
  • torch.cuda.set_device has been called with device_id. This is needed to broadcast the job ID to all ranks.

The master process (rank 0) will register the job with the Perseus server and retrieve the job ID of this job. Then, each rank will report itself to the Perseus server with the job ID.

Parameters:

  • rank (int, required): Global rank of the current process.
  • dp_rank (int, required): Rank in the data parallel group.
  • pp_rank (int, required): Rank in the pipeline parallel group.
  • tp_rank (int, required): Rank in the tensor parallel group.
  • device_id (int, required): CUDA device ID that the current process manages.
  • dp_degree (int, required): Size of the data parallel group.
  • pp_degree (int, required): Size of the pipeline parallel group.
  • tp_degree (int, required): Size of the tensor parallel group.
  • world_size (int, required): Total number of ranks that participate in training.
  • server_url (str, required): URL of the Perseus server.
  • job_metadata (str | None, default None): An optional arbitrary string that describes the job. This will be appended to the job ID if given. Typically for logging purposes.
Source code in zeus/optimizer/perseus/optimizer.py
def __init__(
    self,
    rank: int,
    dp_rank: int,
    pp_rank: int,
    tp_rank: int,
    device_id: int,
    dp_degree: int,
    pp_degree: int,
    tp_degree: int,
    world_size: int,
    server_url: str,
    job_metadata: str | None = None,
) -> None:
    """Initialize the Perseus optimizer.

    Assumptions:
        - `torch.distributed` has been initialized.
        - `torch.cuda.set_device` has been called with `device_id`.
            This is needed to broadcast the job ID to all ranks.

    The master process (rank 0) will register the job with the Perseus
    server and retrieve the job ID of this job. Then, each rank will
    report itself to the Perseus server with the job ID.

    Args:
        rank: Global rank of the current process.
        dp_rank: Rank in the data parallel group.
        pp_rank: Rank in the pipeline parallel group.
        tp_rank: Rank in the tensor parallel group.
        device_id: CUDA device ID that the current process manages.
        dp_degree: Size of the data parallel group.
        pp_degree: Size of the pipeline parallel group.
        tp_degree: Size of the tensor parallel group.
        world_size: Total number of ranks that participate in training.
        server_url: URL of the Perseus server.
        job_metadata: An optional arbitrary string that describes the job. This will
            be appended to the job ID if given. Typically for logging purposes.
    """
    if not dist.is_initialized():
        raise RuntimeError(
            "Instantiate `PerseusOptimizer` after `init_process_group`."
        )

    self.server_url = server_url
    self.rank = rank
    self.dp_rank = dp_rank
    self.pp_rank = pp_rank
    self.tp_rank = tp_rank
    # Remember the device ID so on_instruction_begin can synchronize the right GPU.
    self.device_id = device_id

    gpus = get_gpus()
    torch.cuda.set_device(device_id)

    # Rank 0 registers the job with the Perseus server and retrieves the job ID.
    job_id = None
    if rank == 0:
        job_info = JobInfo(
            pp_degree=pp_degree,
            dp_degree=dp_degree,
            tp_degree=tp_degree,
            world_size=world_size,
            job_metadata=job_metadata,
        )
        response = httpx.post(
            self.server_url + REGISTER_JOB_URL, json=job_info.dict()
        )
        if (code := response.status_code) != 200:
            raise RuntimeError(
                f"Perseus server returned status code {code}: {response.text}"
            )
        job_id = response.json()
        if not isinstance(job_id, str):
            raise RuntimeError(
                f"Perseus server returned a strange job ID: {job_id=}"
            )

    # Rank 0 broadcasts the job ID across all ranks.
    objects = [job_id]
    dist.broadcast_object_list(objects, src=0)
    self.job_id = objects[0]
    if self.job_id is None:
        raise RuntimeError("Failed to broadcast job ID to all ranks")

    # Query the list of available frequencies of the GPU.
    max_mem_freq = max(gpus.getSupportedMemoryClocks(device_id))
    freqs = sorted(
        gpus.getSupportedGraphicsClocks(device_id, max_mem_freq),
        reverse=True,
    )

    # Each rank reports itself to the Perseus server with the job ID.
    rank_info = RankInfo(
        rank=self.rank,
        dp_rank=self.dp_rank,
        pp_rank=self.pp_rank,
        tp_rank=self.tp_rank,
        available_frequencies=freqs,
    )
    response = httpx.post(
        self.server_url + REGISTER_RANK_URL.format(job_id=self.job_id),
        json=rank_info.dict(),
    )
    if (code := response.status_code) != 200:
        raise RuntimeError(
            f"Perseus server returned status code {code}: {response.text}"
        )

    # The frequency controller is responsible for controlling the frequency
    # of the GPU (device_id) asynchronously.
    self.frequency_controller = FrequencyController(device_id=device_id)

    # Fetch the frequency schedule from the Perseus server.
    self.freq_schedule = self._get_frequency_schedule()
    self.freq_schedule_iter = iter(self.freq_schedule)
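
For reference, a hedged sketch of how the constructor arguments relate under a hypothetical 3D-parallel layout (dp_degree=2, pp_degree=4, tp_degree=2 on 16 ranks). The rank ordering below (tensor parallel fastest-varying), the server address, and the job_metadata string are assumptions for illustration; real frameworks define their own rank mapping.

import torch
import torch.distributed as dist

from zeus.optimizer.perseus.optimizer import PerseusOptimizer

dp_degree, pp_degree, tp_degree = 2, 4, 2
world_size = dp_degree * pp_degree * tp_degree  # 16 ranks in total

rank = dist.get_rank()
tp_rank = rank % tp_degree
pp_rank = (rank // tp_degree) % pp_degree
dp_rank = rank // (tp_degree * pp_degree)

perseus = PerseusOptimizer(
    rank=rank,
    dp_rank=dp_rank,
    pp_rank=pp_rank,
    tp_rank=tp_rank,
    device_id=rank % torch.cuda.device_count(),  # local GPU index on this node
    dp_degree=dp_degree,
    pp_degree=pp_degree,
    tp_degree=tp_degree,
    world_size=world_size,
    server_url="http://perseus-server:7787",  # hypothetical address
    job_metadata="llama-pretrain-bf16",       # hypothetical, used for logging
)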

_get_frequency_schedule

_get_frequency_schedule()

Get the frequency schedule from the Perseus server.

Source code in zeus/optimizer/perseus/optimizer.py
def _get_frequency_schedule(self) -> list[tuple[str, int]]:
    """Get the frequency schedule from the Perseus server."""
    response = httpx.get(
        self.server_url + GET_FREQUENCY_SCHEDULE_URL.format(job_id=self.job_id),
        params={"rank": self.rank},
        timeout=None,
    )
    if (code := response.status_code) != 200:
        raise RuntimeError(
            f"Perseus server returned status code {code}: {response.text}"
        )
    schedule = FrequencySchedule.parse_raw(response.text)
    if schedule.rank != self.rank:
        raise RuntimeError(
            f"Perseus server returned a schedule for rank {schedule.rank} to rank {self.rank}"
        )
    return schedule.frequencies
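
To make the return value concrete, a schedule for a pipeline stage running four microbatches per step might deserialize to something like the list below. The clock values and ordering are purely illustrative; the Perseus server decides the actual schedule.

# One iteration's worth of (instruction, frequency) pairs for this rank.
freq_schedule: list[tuple[str, int]] = [
    ("forward", 1410),
    ("forward", 1410),
    ("forward", 1395),
    ("forward", 1395),
    ("backward", 1350),
    ("backward", 1350),
    ("backward", 1350),
    ("backward", 1350),
]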

on_step_begin

on_step_begin()

Mark the beginning of a step.

TODO(jaywonchung): InstructionProfiler iteration start mark.

Source code in zeus/optimizer/perseus/optimizer.py
def on_step_begin(self) -> None:
    """Mark the beginning of a step.

    TODO(jaywonchung): InstructionProfiler iteration start mark.
    """
    pass

on_step_end

on_step_end()

Mark the end of a step.

TODO(jaywonchung): InstructionProfiler iteration end mark. Also report the profiling result to the Perseus server after N iterations.

Source code in zeus/optimizer/perseus/optimizer.py
def on_step_end(self) -> None:
    """Mark the end of a step.

    TODO(jaywonchung): InstructionProfiler iteration end mark.
    Also report the profiling result to the Perseus server after N iterations.
    """
    # Frequency schedule holds one iteration's worth of frequencies, so at
    # the end of each iteration, the iterator should be exhausted.
    item = next(self.freq_schedule_iter, None)
    if item is not None:
        raise RuntimeError(
            "Perseus server returned more frequencies than expected. "
            f"Next expected instruction and frequency is {item}"
        )
    self.freq_schedule_iter = iter(self.freq_schedule)

on_instruction_begin

on_instruction_begin(name)

Mark the beginning of an instruction, like forward and backward.

Retrieve the next frequency from the schedule, check whether the next expected instruction matches the name of the instruction, and set the frequency accordingly.

Source code in zeus/optimizer/perseus/optimizer.py
def on_instruction_begin(self, name: str) -> None:
    """Mark the beginning of an instruction, like forward and backward.

    Retrieve the next frequency from the schedule, check whether the next
    expected instruction matches the name of the instruction, and set the
    frequency accordingly.
    """
    cuda_sync(self.device_id)

    # Retrieve the next frequency from the schedule.
    item = next(self.freq_schedule_iter, None)
    if item is None:
        raise RuntimeError(
            "Perseus server returned fewer frequencies than expected"
        )

    # Check whether the next expected instruction matches the name of the instruction.
    instruction, frequency = item
    if instruction != name:
        raise RuntimeError(
            f"The next expected instruction is not forward: {instruction}"
        )

    self.frequency_controller.set_frequency(frequency)

on_instruction_end

on_instruction_end(_)

Mark the end of an instruction, like forward and backward.

Source code in zeus/optimizer/perseus/optimizer.py
def on_instruction_end(self, _: str) -> None:
    """Mark the end of an instruction, like forward and backward."""