job_manager

zeus.optimizer.perseus.server.job_manager

The JobManager singleton class manages all job states.

JobManager

A singleton class that manages all job states.

Source code in zeus/optimizer/perseus/server/job_manager.py
class JobManager:
    """A singleton class that manages all states."""

    def __init__(self, perseus_settings: PerseusSettings) -> None:
        """Initialize the job manager."""
        self.perseus_settings = perseus_settings

        self._job_infos: dict[str, JobInfo] = {}
        self._job_rank_infos: dict[str, list[RankInfo]] = {}
        self._job_tasks: dict[str, asyncio.Task] = {}
        self._job_result_channels: dict[str, asyncio.Queue[ProfilingResult]] = {}
        self._job_sched_request_channels: dict[str, asyncio.Queue] = {}
        self._job_sched_response_channels: dict[str, list[asyncio.Queue]] = {}
        self._job_last_active_time: dict[str, float] = {}

        # Spawn cleanup task that evicts the state of jobs that have not been active
        # for a long time.
        create_task(
            self._cleanup_task(
                cleanup_period=60,
                max_idle_time=perseus_settings.max_job_idle_time,
            ),
            logger=logger,
        )

    def register_job(self, job_info: JobInfo) -> None:
        """Prepare internal state for a new job.

        This method will be invoked exactly once by the global rank 0 (master) process.
        """
        job_id = job_info.job_id
        world_size = job_info.world_size
        self._job_infos[job_id] = job_info
        self._job_rank_infos[job_id] = []
        self._job_result_channels[job_id] = asyncio.Queue(maxsize=world_size)
        self._job_sched_request_channels[job_id] = asyncio.Queue(maxsize=world_size)
        self._job_sched_response_channels[job_id] = [
            asyncio.Queue(maxsize=1) for _ in range(world_size)
        ]
        self._job_tasks[job_id] = create_task(
            self._job_task(job_id, self.perseus_settings.dump_data),
            logger=logger,
        )
        self._job_last_active_time[job_id] = time.monotonic()

    def register_rank(self, job_id: str, rank_info: RankInfo) -> None:
        """Register rank-specific information for an already registered job.

        This method will be invoked `world_size` number of times (once per rank).
        """
        self._job_rank_infos[job_id].append(rank_info)
        self._job_last_active_time[job_id] = time.monotonic()

    async def get_frequency_schedule(self, job_id: str, rank: int) -> FrequencySchedule:
        """Get the next frequency schedule for a rank.

        This method will be called `world_size` number of times (once per rank).
        All ranks will block on this method until everyone reports their
        profiling results and calls this method.

        If an internal scheduler error occurs at any point while servicing the
        job, clients are notified through this API with an HTTP error
        (400 for `ValueError`s, 500 otherwise).
        """
        await self._job_sched_request_channels[job_id].put(rank)
        res = await self._job_sched_response_channels[job_id][rank].get()
        if isinstance(res, Exception):
            code = 400 if isinstance(res, ValueError) else 500
            raise HTTPException(
                status_code=code,
                detail="".join(
                    traceback.format_exception(type(res), res, res.__traceback__)
                ),
            )
        self._job_last_active_time[job_id] = time.monotonic()
        return res

    def report_profiling_result(self, job_id: str, result: ProfilingResult) -> None:
        """Send the profiling result to the job task and immediately return.

        This method will be called `world_size` number of times (once per rank).
        """
        self._job_result_channels[job_id].put_nowait(result)
        self._job_last_active_time[job_id] = time.monotonic()

    async def _cleanup_task(
        self,
        cleanup_period: int,
        max_idle_time: int,
    ) -> None:
        """Periodically evict job states.

        Args:
            cleanup_period: How often to run the cleanup task, in seconds.
            max_idle_time: Maximum amount of time a job can be idle for, in seconds.
        """
        while True:
            await asyncio.sleep(cleanup_period)
            for job_id in list(self._job_last_active_time.keys()):
                if (
                    time.monotonic() - self._job_last_active_time[job_id]
                    > max_idle_time
                ):
                    self._job_tasks[job_id].cancel()
                    del self._job_infos[job_id]
                    del self._job_rank_infos[job_id]
                    del self._job_result_channels[job_id]
                    del self._job_sched_request_channels[job_id]
                    del self._job_sched_response_channels[job_id]
                    del self._job_tasks[job_id]
                    del self._job_last_active_time[job_id]

    async def _job_task(self, job_id: str, dump_data: bool) -> None:
        """Coalese requests and responses of each rank and interface with the scheduler."""
        result_chan = self._job_result_channels[job_id]
        sched_req_chan = self._job_sched_request_channels[job_id]
        sched_resp_chan = self._job_sched_response_channels[job_id]

        job_info = self._job_infos[job_id]

        try:
            # Wait until all ranks have reported their `RankInfo`s.
            rank_infos = self._job_rank_infos[job_id]
            while True:
                await asyncio.sleep(0.1)
                # `rank_infos` aliases `self._job_rank_infos[job_id]`, which
                # `register_rank` appends to, so this length check observes
                # ranks as they register.
                if len(rank_infos) == job_info.world_size:
                    break

            # Sort `RankInfo`s in rank order.
            rank_infos.sort(key=lambda r: r.rank)

            # Create directory to dump Perseus states.
            dump_dir = f"{self.perseus_settings.dump_dir}/{job_id}"
            if dump_data:
                await save_ranks(rank_infos, dump_dir)

            # Instantiate the frequency scheduler.
            scheduler = self.perseus_settings.scheduler(
                job_info,
                rank_infos,
                self.perseus_settings,
                **self.perseus_settings.scheduler_args,
            )

            # Provide next schedules, observe profiling results, and repeat.
            schedule_num = 0
            while True:
                # Compute the next `FrequencySchedule`s.
                schedules = scheduler.next_schedule()

                # Wait until all the ranks ask for the next schedule.
                await asyncio.gather(*[sched_req_chan.get() for _ in rank_infos])

                # Send out `FrequencySchedule`s.
                await asyncio.gather(
                    *[sched_resp_chan[s.rank].put(s) for s in schedules]
                )

                # Gather profiling results from all ranks.
                results = await asyncio.gather(*[result_chan.get() for _ in rank_infos])
                results.sort(key=lambda r: r.rank)

                # Dump profiling results and schedules.
                if dump_data:
                    schedules.sort(key=lambda s: s.rank)
                    await save_prof(results, dump_dir, schedule_num)
                    await save_sched(schedules, dump_dir, schedule_num)

                # Send `ProfilingResult`s to the scheduler.
                scheduler.observe(results)

                # Increment schedule number.
                schedule_num += 1

        except asyncio.CancelledError:
            # This task gets cancelled when it's idle for too long and evicted.
            pass

        except Exception as exc:
            # In case the scheduler errored, send out the exception to the clients.
            # The clients will receive the error when they ask for the next schedule.
            for chan in sched_resp_chan:
                chan.put_nowait(exc)
            raise

__init__

__init__(perseus_settings)
Source code in zeus/optimizer/perseus/server/job_manager.py
def __init__(self, perseus_settings: PerseusSettings) -> None:
    """Initialize the job manager."""
    self.perseus_settings = perseus_settings

    self._job_infos: dict[str, JobInfo] = {}
    self._job_rank_infos: dict[str, list[RankInfo]] = {}
    self._job_tasks: dict[str, asyncio.Task] = {}
    self._job_result_channels: dict[str, asyncio.Queue[ProfilingResult]] = {}
    self._job_sched_request_channels: dict[str, asyncio.Queue] = {}
    self._job_sched_response_channels: dict[str, list[asyncio.Queue]] = {}
    self._job_last_active_time: dict[str, float] = {}

    # Spawn cleanup task that evicts the state of jobs that have not been active
    # for a long time.
    create_task(
        self._cleanup_task(
            cleanup_period=60,
            max_idle_time=perseus_settings.max_job_idle_time,
        ),
        logger=logger,
    )

register_job

register_job(job_info)

Prepare internal state for a new job.

This method will be invoked exactly once by the global rank 0 (master) process.

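The interesting part of this setup is the channel topology. Below is a minimal sketch of it for a hypothetical world_size of 4; the variable names are illustrative, not the real attributes.

import asyncio

world_size = 4  # hypothetical; comes from JobInfo.world_size in the real code

# One shared request queue: the job task only needs to count arrivals,
# so all ranks can share it.
sched_request_chan: asyncio.Queue = asyncio.Queue(maxsize=world_size)

# One response queue per rank, each holding at most one item: every rank
# must receive exactly its own FrequencySchedule, so responses cannot be
# funneled through a shared queue.
sched_response_chans = [asyncio.Queue(maxsize=1) for _ in range(world_size)]

# Sized to hold one ProfilingResult per rank, so a single round of
# report_profiling_result calls can never overflow it.
result_chan: asyncio.Queue = asyncio.Queue(maxsize=world_size)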
Source code in zeus/optimizer/perseus/server/job_manager.py
def register_job(self, job_info: JobInfo) -> None:
    """Prepare internal state for a new job.

    This method will be invoked exactly once by the global rank 0 (master) process.
    """
    job_id = job_info.job_id
    world_size = job_info.world_size
    self._job_infos[job_id] = job_info
    self._job_rank_infos[job_id] = []
    self._job_result_channels[job_id] = asyncio.Queue(maxsize=world_size)
    self._job_sched_request_channels[job_id] = asyncio.Queue(maxsize=world_size)
    self._job_sched_response_channels[job_id] = [
        asyncio.Queue(maxsize=1) for _ in range(world_size)
    ]
    self._job_tasks[job_id] = create_task(
        self._job_task(job_id, self.perseus_settings.dump_data),
        logger=logger,
    )
    self._job_last_active_time[job_id] = time.monotonic()

register_rank

register_rank(job_id, rank_info)

Register rank-specific information for an already registered job.

This method will be invoked world_size number of times (once per rank).

Source code in zeus/optimizer/perseus/server/job_manager.py
def register_rank(self, job_id: str, rank_info: RankInfo) -> None:
    """Register rank-specific information for an already registered job.

    This method will be invoked `world_size` number of times (once per rank).
    """
    self._job_rank_infos[job_id].append(rank_info)
    self._job_last_active_time[job_id] = time.monotonic()

get_frequency_schedule async

get_frequency_schedule(job_id, rank)

Get the next frequency schedule for a rank.

This method will be called world_size number of times (once per rank). All ranks will block on this method until everyone reports their profiling results and calls this method.

If an internal scheduler error occurs at any point while servicing the job, clients are notified through this API with an HTTP error (400 for ValueErrors, 500 otherwise).

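The response channel carries either a FrequencySchedule or an Exception put there by the job task. Below is a self-contained sketch of that dispatch, with a toy value in place of the real types:

import asyncio

async def demo() -> None:
    # Stand-in for self._job_sched_response_channels[job_id][rank].
    resp: asyncio.Queue = asyncio.Queue(maxsize=1)

    # On failure, the job task pushes an exception instead of a schedule.
    await resp.put(ValueError("bad job configuration"))

    res = await resp.get()
    if isinstance(res, Exception):
        # The real method wraps the formatted traceback in an
        # HTTPException; here we only print the status code mapping.
        code = 400 if isinstance(res, ValueError) else 500
        print(f"would raise HTTP {code}: {res}")
    else:
        print(f"next schedule: {res}")

asyncio.run(demo())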
Source code in zeus/optimizer/perseus/server/job_manager.py
async def get_frequency_schedule(self, job_id: str, rank: int) -> FrequencySchedule:
    """Get the next frequency schedule for a rank.

    This method will be called `world_size` number of times (once per rank).
    All ranks will block on this method until everyone reports their
    profiling results and calls this method.

    If an internal scheduler error occurs at any point while servicing the
    job, clients are notified through this API with an HTTP error
    (400 for `ValueError`s, 500 otherwise).
    """
    await self._job_sched_request_channels[job_id].put(rank)
    res = await self._job_sched_response_channels[job_id][rank].get()
    if isinstance(res, Exception):
        code = 400 if isinstance(res, ValueError) else 500
        raise HTTPException(
            status_code=code,
            detail="".join(
                traceback.format_exception(type(res), res, res.__traceback__)
            ),
        )
    self._job_last_active_time[job_id] = time.monotonic()
    return res

report_profiling_result

report_profiling_result(job_id, result)

Send the profiling result to the job task and immediately return.

This method will be called world_size number of times (once per rank).

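put_nowait is safe here because the result queue was created with maxsize equal to world_size and each rank reports once per schedule round. A small sketch of that invariant, with dicts standing in for ProfilingResult:

import asyncio

world_size = 2
result_chan: asyncio.Queue = asyncio.Queue(maxsize=world_size)

# One report per rank fits exactly.
result_chan.put_nowait({"rank": 0, "iter_time": 1.23})
result_chan.put_nowait({"rank": 1, "iter_time": 1.31})

try:
    result_chan.put_nowait({"rank": 0, "iter_time": 9.99})
except asyncio.QueueFull:
    print("a rank reporting twice in one round would overflow the queue")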
Source code in zeus/optimizer/perseus/server/job_manager.py
def report_profiling_result(self, job_id: str, result: ProfilingResult) -> None:
    """Send the profiling result to the job task and immediately return.

    This method will be called `world_size` number of times (once per rank).
    """
    self._job_result_channels[job_id].put_nowait(result)
    self._job_last_active_time[job_id] = time.monotonic()

_cleanup_task async

_cleanup_task(cleanup_period, max_idle_time)

Periodically evict job states.

Parameters:

Name             Type   Description                                                 Default
cleanup_period   int    How often to run the cleanup task, in seconds.              required
max_idle_time    int    Maximum amount of time a job can be idle for, in seconds.   required
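The loop below uses time.monotonic() rather than time.time(), so wall-clock adjustments cannot trigger spurious evictions, and it iterates over a snapshot of the keys so entries can be deleted mid-scan. A reduced sketch of the same pattern:

import asyncio
import time

async def cleanup(last_active: dict[str, float], period: float, max_idle: float) -> None:
    # Reduced sketch: `last_active` maps job IDs to time.monotonic() stamps.
    while True:
        await asyncio.sleep(period)
        now = time.monotonic()
        for job_id in list(last_active):  # snapshot; we delete while scanning
            if now - last_active[job_id] > max_idle:
                del last_active[job_id]
                print(f"evicted idle job {job_id}")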
Source code in zeus/optimizer/perseus/server/job_manager.py
async def _cleanup_task(
    self,
    cleanup_period: int,
    max_idle_time: int,
) -> None:
    """Periodically evict job states.

    Args:
        cleanup_period: How often to run the cleanup task, in seconds.
        max_idle_time: Maximum amount of time a job can be idle for, in seconds.
    """
    while True:
        await asyncio.sleep(cleanup_period)
        for job_id in list(self._job_last_active_time.keys()):
            if (
                time.monotonic() - self._job_last_active_time[job_id]
                > max_idle_time
            ):
                self._job_tasks[job_id].cancel()
                del self._job_infos[job_id]
                del self._job_rank_infos[job_id]
                del self._job_result_channels[job_id]
                del self._job_sched_request_channels[job_id]
                del self._job_sched_response_channels[job_id]
                del self._job_tasks[job_id]
                del self._job_last_active_time[job_id]

_job_task async

_job_task(job_id, dump_data)

Coalesce requests and responses of each rank and interface with the scheduler.

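Before the source, here is a self-contained toy of one schedule round with world_size = 2, using ints in place of the real FrequencySchedule and ProfilingResult types, to show how gathering queue reads coalesces all ranks:

import asyncio

async def demo() -> None:
    world_size = 2
    req: asyncio.Queue = asyncio.Queue(maxsize=world_size)
    resp = [asyncio.Queue(maxsize=1) for _ in range(world_size)]
    results: asyncio.Queue = asyncio.Queue(maxsize=world_size)

    async def rank(r: int) -> None:
        await req.put(r)                       # ask for the next schedule
        schedule = await resp[r].get()         # block until the job task answers
        await results.put((r, schedule * 10))  # report a toy profiling result

    async def job_task() -> None:
        # gather() over world_size get()s is a barrier: it completes only
        # after every rank has asked for its next schedule.
        await asyncio.gather(*[req.get() for _ in range(world_size)])
        # Fan the (toy) schedules out, one per rank.
        await asyncio.gather(*[resp[r].put(r + 1) for r in range(world_size)])
        # Fan the results back in.
        out = await asyncio.gather(*[results.get() for _ in range(world_size)])
        print(sorted(out))  # [(0, 10), (1, 20)]

    await asyncio.gather(job_task(), *[rank(r) for r in range(world_size)])

asyncio.run(demo())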
Source code in zeus/optimizer/perseus/server/job_manager.py
async def _job_task(self, job_id: str, dump_data: bool) -> None:
    """Coalese requests and responses of each rank and interface with the scheduler."""
    result_chan = self._job_result_channels[job_id]
    sched_req_chan = self._job_sched_request_channels[job_id]
    sched_resp_chan = self._job_sched_response_channels[job_id]

    job_info = self._job_infos[job_id]

    try:
        # Wait until all ranks have reported their `RankInfo`s.
        rank_infos = self._job_rank_infos[job_id]
        while True:
            await asyncio.sleep(0.1)
            # `rank_infos` aliases `self._job_rank_infos[job_id]`, which
            # `register_rank` appends to, so this length check observes
            # ranks as they register.
            if len(rank_infos) == job_info.world_size:
                break

        # Sort `RankInfo`s in rank order.
        rank_infos.sort(key=lambda r: r.rank)

        # Create directory to dump Perseus states.
        dump_dir = f"{self.perseus_settings.dump_dir}/{job_id}"
        if dump_data:
            await save_ranks(rank_infos, dump_dir)

        # Instantiate the frequency scheduler.
        scheduler = self.perseus_settings.scheduler(
            job_info,
            rank_infos,
            self.perseus_settings,
            **self.perseus_settings.scheduler_args,
        )

        # Provide next schedules, observe profiling results, and repeat.
        schedule_num = 0
        while True:
            # Compute the next `FrequencySchedule`s.
            schedules = scheduler.next_schedule()

            # Wait until all the ranks ask for the next schedule.
            await asyncio.gather(*[sched_req_chan.get() for _ in rank_infos])

            # Send out `FrequencySchedule`s.
            await asyncio.gather(
                *[sched_resp_chan[s.rank].put(s) for s in schedules]
            )

            # Gather profiling results from all ranks.
            results = await asyncio.gather(*[result_chan.get() for _ in rank_infos])
            results.sort(key=lambda r: r.rank)

            # Dump profiling results and schedules.
            if dump_data:
                schedules.sort(key=lambda s: s.rank)
                await save_prof(results, dump_dir, schedule_num)
                await save_sched(schedules, dump_dir, schedule_num)

            # Send `ProfilingResult`s to the scheduler.
            scheduler.observe(results)

            # Increment schedule number.
            schedule_num += 1

    except asyncio.CancelledError:
        # This task gets cancelled when it's idle for too long and evicted.
        pass

    except Exception as exc:
        # In case the scheduler errored, send out the exception to the clients.
        # The clients will receive the error when they ask for the next schedule.
        for chan in sched_resp_chan:
            chan.put_nowait(exc)
        raise

init_global_job_manager

init_global_job_manager(perseus_settings)

Instantiate the global singleton JobManager.

Source code in zeus/optimizer/perseus/server/job_manager.py
def init_global_job_manager(perseus_settings: PerseusSettings) -> None:
    """Instantiate the global singleton `JobManager`."""
    global GLOBAL_JOB_MANAGER
    GLOBAL_JOB_MANAGER = JobManager(perseus_settings=perseus_settings)

get_global_job_manager

get_global_job_manager()

Fetch the global singleton JobManager.

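A plausible wiring sketch for these two functions, assuming a FastAPI app (suggested by the module's use of HTTPException). The endpoint path and the PerseusSettings import path are assumptions, not documented on this page:

from fastapi import FastAPI

# Import path for PerseusSettings is an assumption; only job_manager is
# documented on this page.
from zeus.optimizer.perseus.common import PerseusSettings
from zeus.optimizer.perseus.server.job_manager import (
    get_global_job_manager,
    init_global_job_manager,
)

app = FastAPI()

@app.on_event("startup")
async def startup() -> None:
    # Assumes PerseusSettings can be constructed from defaults/environment.
    init_global_job_manager(perseus_settings=PerseusSettings())

@app.get("/ping")  # hypothetical endpoint
async def ping() -> dict[str, str]:
    # Fails the assert inside if init_global_job_manager was never called.
    get_global_job_manager()
    return {"status": "ok"}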
Source code in zeus/optimizer/perseus/server/job_manager.py
def get_global_job_manager() -> JobManager:
    """Fetch the global singleton `JobManager`."""
    assert GLOBAL_JOB_MANAGER is not None, "`init_global_job_manager` was not called."
    return GLOBAL_JOB_MANAGER