Skip to content

router

zeus.optimizer.batch_size.server.router

FastAPI router for the Zeus batch size optimizer server.

get_job_locks

get_job_locks()

Get global job locks.

Source code in zeus/optimizer/batch_size/server/router.py
39
40
41
def get_job_locks() -> defaultdict[str, asyncio.Lock]:
    """Return the module-level mapping of job locks.

    Returns:
        The `JOB_LOCKS` object itself (not a copy).
    """
    locks = JOB_LOCKS
    return locks

get_prefix_locks

get_prefix_locks()

Get global job ID prefix locks.

Source code in zeus/optimizer/batch_size/server/router.py
44
45
46
def get_prefix_locks() -> defaultdict[str, asyncio.Lock]:
    """Return the module-level mapping of job ID prefix locks.

    Returns:
        The `PREFIX_LOCKS` object itself (not a copy).
    """
    locks = PREFIX_LOCKS
    return locks

register_job async

register_job(
    job,
    response,
    db_session=Depends(get_db_session),
    prefix_locks=Depends(get_prefix_locks),
)

Endpoint for users to register a job or check if the job is registered and configuration is identical.

Source code in zeus/optimizer/batch_size/server/router.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@app.post(
    REGISTER_JOB_URL,
    responses={
        200: {"description": "Job is already registered"},
        201: {"description": "Job is successfully registered"},
    },
    response_model=JobSpecFromClient,
)
async def register_job(
    job: JobSpecFromClient,
    response: Response,
    db_session: AsyncSession = Depends(get_db_session),
    prefix_locks: defaultdict[str, asyncio.Lock] = Depends(get_prefix_locks),
):
    """Register a job, or confirm that it is already registered.

    Args:
        job: Job specification sent by the client.
        response: Response object whose status code is set to 201 or 200.
        db_session: Database session used for this request.
        prefix_locks: Locks keyed by job ID prefix.

    Returns:
        `job` on success, with status 201 when `optimizer.register_job`
        reports a newly created job and 200 otherwise. On failure, a
        `JSONResponse` carrying a `message` field: the status code of the
        `ZeusBSOServerBaseError`, or 500 for any other exception.
    """
    # Registrations that share a job ID prefix are serialized on one lock.
    prefix_lock = prefix_locks[job.job_id_prefix]
    async with prefix_lock:
        bso = ZeusBatchSizeOptimizer(ZeusService(db_session))
        try:
            is_new = await bso.register_job(job)
            await db_session.commit()
            # 201 for a freshly created job, 200 for one that already existed.
            response.status_code = (
                status.HTTP_201_CREATED if is_new else status.HTTP_200_OK
            )
            return job
        except ZeusBSOServerBaseError as exc:
            # Known server error: undo the transaction and relay its status.
            await db_session.rollback()
            return JSONResponse(
                status_code=exc.status_code,
                content={"message": exc.message},
            )
        except Exception as exc:
            # Anything else: undo the transaction, log, and answer with 500.
            await db_session.rollback()
            detail = str(exc)
            logger.error("Commit Failed: %s", detail)
            return JSONResponse(status_code=500, content={"message": detail})

delete_job async

delete_job(
    job_id,
    db_session=Depends(get_db_session),
    job_locks=Depends(get_job_locks),
)

Endpoint for users to delete a job.

Source code in zeus/optimizer/batch_size/server/router.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
@app.delete(DELETE_JOB_URL)
async def delete_job(
    job_id: str,
    db_session: AsyncSession = Depends(get_db_session),
    job_locks: defaultdict[str, asyncio.Lock] = Depends(get_job_locks),
):
    """Endpoint for users to delete a job.

    Args:
        job_id: ID of the job to delete.
        db_session: Database session used for this request.
        job_locks: Locks keyed by job ID.

    Returns:
        `None` on success. On failure, a `JSONResponse` carrying a `message`
        field: the status code of the `ZeusBSOServerBaseError`, or 500 for
        any other exception.
    """
    # Keep a reference to the exact lock object we acquire so the cleanup
    # below can tell whether the mapping still points at it.
    lock = job_locks[job_id]
    async with lock:
        try:
            optimizer = ZeusBatchSizeOptimizer(ZeusService(db_session))
            await optimizer.delete_job(job_id)
            await db_session.commit()
        except ZeusBSOServerBaseError as err:
            await db_session.rollback()
            return JSONResponse(
                status_code=err.status_code,
                content={"message": err.message},
            )
        except Exception as err:
            await db_session.rollback()
            logger.error("Commit Failed: %s", str(err))
            return JSONResponse(
                status_code=500,
                content={"message": str(err)},
            )
        finally:
            # Remove the entry only if it is still our lock. A request that
            # waited on this same lock may already have removed it (a bare
            # `pop` would then raise KeyError and mask the response), or a
            # later request may have installed a new lock that must not be
            # discarded. `.get` does not trigger the defaultdict factory.
            if job_locks.get(job_id) is lock:
                del job_locks[job_id]

end_trial async

end_trial(
    trial,
    db_session=Depends(get_db_session),
    job_locks=Depends(get_job_locks),
)

Endpoint for users to end the trial.

Source code in zeus/optimizer/batch_size/server/router.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
@app.patch(REPORT_END_URL)
async def end_trial(
    trial: TrialId,
    db_session: AsyncSession = Depends(get_db_session),
    job_locks: defaultdict[str, asyncio.Lock] = Depends(get_job_locks),
):
    """Mark a trial as ended on behalf of the client.

    Args:
        trial: Identifier of the trial to end; its `job_id` selects the lock.
        db_session: Database session used for this request.
        job_locks: Locks keyed by job ID.

    Returns:
        `None` on success. On failure, a `JSONResponse` carrying a `message`
        field: the status code of the `ZeusBSOServerBaseError`, or 500 for
        any other exception.
    """
    job_lock = job_locks[trial.job_id]
    async with job_lock:
        bso = ZeusBatchSizeOptimizer(ZeusService(db_session))
        try:
            await bso.end_trial(trial)
            await db_session.commit()
        except ZeusBSOServerBaseError as exc:
            # Known server error: undo the transaction and relay its status.
            await db_session.rollback()
            return JSONResponse(
                status_code=exc.status_code,
                content={"message": exc.message},
            )
        except Exception as exc:
            # Anything else: undo the transaction, log, and answer with 500.
            await db_session.rollback()
            detail = str(exc)
            logger.error("Commit Failed: %s", detail)
            return JSONResponse(status_code=500, content={"message": detail})

predict async

predict(
    job_id,
    db_session=Depends(get_db_session),
    job_locks=Depends(get_job_locks),
)

Endpoint for users to receive a batch size.

Source code in zeus/optimizer/batch_size/server/router.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
@app.get(GET_NEXT_BATCH_SIZE_URL, response_model=TrialId)
async def predict(
    job_id: str,
    db_session: AsyncSession = Depends(get_db_session),
    job_locks: defaultdict[str, asyncio.Lock] = Depends(get_job_locks),
):
    """Hand the client the next batch size for a job.

    Args:
        job_id: ID of the job to predict for; also selects the lock.
        db_session: Database session used for this request.
        job_locks: Locks keyed by job ID.

    Returns:
        The result of `optimizer.predict(job_id)` once the transaction is
        committed. On failure, a `JSONResponse` carrying a `message` field:
        the status code of the `ZeusBSOServerBaseError`, or 500 for any other
        exception.
    """
    job_lock = job_locks[job_id]
    async with job_lock:
        bso = ZeusBatchSizeOptimizer(ZeusService(db_session))
        try:
            prediction = await bso.predict(job_id)
            await db_session.commit()
        except ZeusBSOServerBaseError as exc:
            # Known server error: undo the transaction and relay its status.
            await db_session.rollback()
            return JSONResponse(
                status_code=exc.status_code,
                content={"message": exc.message},
            )
        except Exception as exc:
            # Anything else: undo the transaction, log, and answer with 500.
            await db_session.rollback()
            detail = str(exc)
            logger.error("Commit Failed: %s", detail)
            return JSONResponse(status_code=500, content={"message": detail})
        else:
            return prediction

report async

report(
    result,
    db_session=Depends(get_db_session),
    job_locks=Depends(get_job_locks),
)

Endpoint for users to report the training result.

Source code in zeus/optimizer/batch_size/server/router.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
@app.post(REPORT_RESULT_URL, response_model=ReportResponse)
async def report(
    result: TrainingResult,
    db_session: AsyncSession = Depends(get_db_session),
    job_locks: defaultdict[str, asyncio.Lock] = Depends(get_job_locks),
):
    """Accept a training result reported by the client.

    Args:
        result: Reported training result; its `job_id` selects the lock.
        db_session: Database session used for this request.
        job_locks: Locks keyed by job ID.

    Returns:
        The result of `optimizer.report(result)` once the transaction is
        committed. On failure, a `JSONResponse` carrying a `message` field:
        the status code of the `ZeusBSOServerBaseError`, or 500 for any other
        exception.
    """
    job_lock = job_locks[result.job_id]
    async with job_lock:
        bso = ZeusBatchSizeOptimizer(ZeusService(db_session))
        try:
            logger.info("Report with result %s", str(result))
            outcome = await bso.report(result)
            await db_session.commit()
        except ZeusBSOServerBaseError as exc:
            # Known server error: undo the transaction and relay its status.
            await db_session.rollback()
            return JSONResponse(
                status_code=exc.status_code,
                content={"message": exc.message},
            )
        except Exception as exc:
            # Anything else: undo the transaction, log, and answer with 500.
            await db_session.rollback()
            detail = str(exc)
            logger.error("Commit Failed: %s", detail)
            return JSONResponse(status_code=500, content={"message": detail})
        else:
            return outcome