Skip to content

repository

zeus.optimizer.batch_size.server.batch_size_state.repository

Repository for batch size states (Trial, Gaussian TS arm state).

BatchSizeStateRepository

Bases: DatabaseRepository

Repository for handling batch size related operations.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
class BatchSizeStateRepository(DatabaseRepository):
    """Repository for handling batch size related operations.

    Besides issuing queries, the repository remembers the ORM rows most recently
    loaded by `get_trial` and `get_arm`. `updated_current_trial` and
    `update_arm_state` write onto those remembered rows, after checking that
    the caller is updating the same trial/arm that was fetched.
    """

    def __init__(self, session: AsyncSession):
        """Set db session and initialize the fetched trial and arm to None.

        We are only updating one trial per session.

        Args:
            session (AsyncSession): Database session handed to the base repository.
        """
        super().__init__(session)
        # Filled in by `get_trial` / `get_arm`; required by the update methods.
        self.fetched_trial: TrialTable | None = None
        self.fetched_arm: GaussianTsArmStateTable | None = None

    async def get_next_trial_number(self, job_id: str) -> int:
        """Get the next trial number of a given job.

        Trial numbers start at 1 and increase by 1 at a time.

        Args:
            job_id (str): The ID of the job.

        Returns:
            int: 1 if the job has no trial yet, otherwise the current maximum
            trial number plus one.
        """
        stmt = select(func.max(TrialTable.trial_number)).where(
            TrialTable.job_id == job_id
        )
        res = await self.session.scalar(stmt)
        if res is None:
            # MAX() over zero rows yields NULL: this is the job's first trial.
            return 1
        return res + 1

    async def get_trial_results_of_bs(
        self, batch_size: BatchSizeBase, window_size: int
    ) -> TrialResultsPerBs:
        """Load window size amount of results for a given batch size. If window size <= 0, load all of them.

        Only succeeded trials are selected, since failed/dispatched ones don't have a valid result.
        Results are ordered by trial number, highest first, so a positive
        window keeps the most recent trials.

        Args:
            batch_size (BatchSizeBase): The batch size object.
            window_size (int): The size of the measurement window.

        Returns:
            TrialResultsPerBs: trial results for the given batch size.
        """
        stmt = (
            select(TrialTable)
            .where(
                TrialTable.job_id == batch_size.job_id,
                TrialTable.batch_size == batch_size.batch_size,
                TrialTable.status == TrialStatus.Succeeded,
            )
            .order_by(TrialTable.trial_number.desc())
        )
        if window_size > 0:
            stmt = stmt.limit(window_size)

        res = (await self.session.scalars(stmt)).all()
        return TrialResultsPerBs(
            job_id=batch_size.job_id,
            batch_size=batch_size.batch_size,
            results=[TrialResult.from_orm(t) for t in res],
        )

    async def get_arms(self, job_id: str) -> list[GaussianTsArmState]:
        """Retrieve Gaussian Thompson Sampling arms for a given job.

        Args:
            job_id (str): The ID of the job.

        Returns:
            list[GaussianTsArmState]: All arms stored for the job. Stored arms are
            expected to be the "good" arms (those that converged during the pruning
            stage); this method itself applies no filter beyond the job ID.
            Refer to `GaussianTsArmState` for attributes.
        """
        stmt = select(GaussianTsArmStateTable).where(
            GaussianTsArmStateTable.job_id == job_id
        )
        res = (await self.session.scalars(stmt)).all()
        return [GaussianTsArmState.from_orm(arm) for arm in res]

    async def get_arm(self, bs: BatchSizeBase) -> GaussianTsArmState | None:
        """Retrieve Gaussian Thompson Sampling arm for a given job id and batch size.

        On success the ORM row is remembered in `self.fetched_arm` so that
        `update_arm_state` can modify it later. When nothing is found,
        `self.fetched_arm` is left untouched.

        Args:
            bs (BatchSizeBase): The batch size object.

        Returns:
            GaussianTsArmState | None: Gaussian Thompson Sampling arm if found, else None.
            Refer to `GaussianTsArmState` for attributes.
        """
        stmt = select(GaussianTsArmStateTable).where(
            GaussianTsArmStateTable.job_id == bs.job_id,
            GaussianTsArmStateTable.batch_size == bs.batch_size,
        )
        arm = await self.session.scalar(stmt)
        if arm is None:
            return None
        self.fetched_arm = arm
        return GaussianTsArmState.from_orm(arm)

    async def get_trial(self, trial: ReadTrial) -> Trial | None:
        """Get a corresponding trial.

        On success the ORM row is remembered in `self.fetched_trial` so that
        `updated_current_trial` and `get_trial_from_session` can use it. When
        nothing is found, `self.fetched_trial` is left untouched.

        Args:
            trial (ReadTrial): job_id, batch_size, trial_number.

        Returns:
            Trial | None: Found Trial. If none found, return None.
        """
        stmt = select(TrialTable).where(
            TrialTable.job_id == trial.job_id,
            TrialTable.batch_size == trial.batch_size,
            TrialTable.trial_number == trial.trial_number,
        )
        fetched_trial = await self.session.scalar(stmt)

        if fetched_trial is None:
            logger.info("get_trial: NoResultFound")
            return None

        self.fetched_trial = fetched_trial
        return Trial.from_orm(fetched_trial)

    def get_trial_from_session(self, trial: ReadTrial) -> Trial | None:
        """Return the trial previously loaded by `get_trial`, without querying the db.

        Args:
            trial (ReadTrial): job_id, batch_size, trial_number of the wanted trial.

        Returns:
            Trial | None: The remembered trial if it has the same job_id, batch_size
            and trial_number; None if no trial was fetched or it is a different one.
        """
        if (
            self.fetched_trial is None
            or self.fetched_trial.job_id != trial.job_id
            or self.fetched_trial.batch_size != trial.batch_size
            or self.fetched_trial.trial_number != trial.trial_number
        ):
            return None
        return Trial.from_orm(self.fetched_trial)

    def create_trial(self, trial: CreateTrial) -> None:
        """Create a trial in db.

        The trial is converted to its ORM form and added to the session.

        Refer to `CreateTrial`[zeus.optimizer.batch_size.server.batch_size_state.models.CreateTrial] for attributes.

        Args:
            trial (CreateTrial): The trial to add.
        """
        self.session.add(trial.to_orm())

    def updated_current_trial(self, updated_trial: UpdateTrial) -> None:
        """Update trial in the database (report the result of trial).

        The trial must have been loaded with `get_trial` beforehand; the update
        is written onto that remembered ORM row.

        Args:
            updated_trial (UpdateTrial): The updated trial. Refer to `UpdateTrial`[zeus.optimizer.batch_size.server.batch_size_state.models.UpdateTrial] for attributes.

        Raises:
            ZeusBSOValueError: If no trial was fetched, or the fetched trial's
                job_id, batch_size or trial_number differs from `updated_trial`.
        """
        if self.fetched_trial is None:
            raise ZeusBSOValueError("No trial is fetched.")

        if (
            self.fetched_trial.job_id != updated_trial.job_id
            or self.fetched_trial.batch_size != updated_trial.batch_size
            or self.fetched_trial.trial_number != updated_trial.trial_number
        ):
            raise ZeusBSOValueError("Trying to update invalid trial.")

        # Only result fields are copied; the identifying fields were checked above.
        self.fetched_trial.end_timestamp = updated_trial.end_timestamp
        self.fetched_trial.status = updated_trial.status
        self.fetched_trial.time = updated_trial.time
        self.fetched_trial.energy = updated_trial.energy
        self.fetched_trial.converged = updated_trial.converged

    def create_arms(self, new_arms: list[GaussianTsArmState]) -> None:
        """Create Gaussian Thompson Sampling arms in the database.

        Each arm is converted to its ORM form and all of them are added to the
        session in a single call.

        Args:
            new_arms (list[GaussianTsArmState]): List of new arms to create.
                Refer to `GaussianTsArmState` for attributes.
        """
        self.session.add_all([arm.to_orm() for arm in new_arms])

    def update_arm_state(self, updated_mab_state: GaussianTsArmState) -> None:
        """Update Gaussian Thompson Sampling arm state in db.

        The arm must have been loaded with `get_arm` beforehand; the update is
        written onto that remembered ORM row.

        Args:
            updated_mab_state (GaussianTsArmState): The updated arm state.
                Refer to `GaussianTsArmState` for attributes.

        Raises:
            ZeusBSOValueError: If no arm was fetched, or the fetched arm's job_id
                or batch_size differs from `updated_mab_state`.
        """
        if self.fetched_arm is None:
            raise ZeusBSOValueError("No arm is fetched.")

        if (
            self.fetched_arm.job_id != updated_mab_state.job_id
            or self.fetched_arm.batch_size != updated_mab_state.batch_size
        ):
            raise ZeusBSOValueError(
                "Fetch arm does not correspond with the arm trying to update."
            )

        self.fetched_arm.param_mean = updated_mab_state.param_mean
        self.fetched_arm.param_precision = updated_mab_state.param_precision
        self.fetched_arm.reward_precision = updated_mab_state.reward_precision
        self.fetched_arm.num_observations = updated_mab_state.num_observations

    async def get_explorations_of_job(self, job_id: str) -> ExplorationsPerJob:
        """Retrieve succeeded or ongoing explorations for a given job.

        Selects every exploration-type trial of the job whose status is not
        Failed, ordered by ascending trial number, and groups them by batch size.

        Args:
            job_id (str): ID of the job

        Returns:
            ExplorationsPerJob: Explorations for the given job, grouped by batch size.
            Refer to `ExplorationsPerJob`[zeus.optimizer.batch_size.server.batch_size_state.models.ExplorationsPerJob] for attributes.
        """
        stmt = (
            select(TrialTable)
            .where(
                TrialTable.job_id == job_id,
                TrialTable.type == TrialType.Exploration,
                TrialTable.status != TrialStatus.Failed,
            )
            .order_by(TrialTable.trial_number.asc())
        )

        explorations = (await self.session.scalars(stmt)).all()
        exps_per_bs: defaultdict[int, list[Trial]] = defaultdict(list)
        for exp in explorations:
            exps_per_bs[exp.batch_size].append(Trial.from_orm(exp))

        return ExplorationsPerJob(job_id=job_id, explorations_per_bs=exps_per_bs)

__init__

__init__(session)
Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
38
39
40
41
42
def __init__(self, session: AsyncSession):
    """Store the db session and start with no fetched trial or arm.

    We are only updating one trial per session.
    """
    super().__init__(session)
    # Nothing has been loaded yet; both slots start out empty.
    self.fetched_arm: GaussianTsArmStateTable | None = None
    self.fetched_trial: TrialTable | None = None

get_next_trial_number async

get_next_trial_number(job_id)

Get the next trial number of a given job. Trial numbers start at 1 and increase by 1 at a time.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
44
45
46
47
48
49
50
51
52
53
54
async def get_next_trial_number(self, job_id: str) -> int:
    """Return the trial number to use for the next trial of `job_id`.

    Numbering starts at 1; otherwise the result is the job's current maximum
    trial number plus one.
    """
    max_trial_query = select(func.max(TrialTable.trial_number)).where(
        and_(
            TrialTable.job_id == job_id,
        )
    )
    latest = await self.session.scalar(max_trial_query)
    # A None maximum means the job has no trial yet.
    return 1 if latest is None else latest + 1

get_trial_results_of_bs async

get_trial_results_of_bs(batch_size, window_size)

Load window size amount of results for a given batch size. If window size <= 0, load all of them.

Of all trials, only succeeded ones are selected, since failed or dispatched trials don't have a valid result.

Parameters:

Name Type Description Default
batch_size BatchSizeBase

The batch size object.

required
window_size int

The size of the measurement window.

required

Returns:

Name Type Description
TrialResultsPerBs TrialResultsPerBs

trial results for the given batch size.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
async def get_trial_results_of_bs(
    self, batch_size: BatchSizeBase, window_size: int
) -> TrialResultsPerBs:
    """Fetch results of succeeded trials for one batch size of a job.

    Trials are ordered by trial number, highest first. A positive
    `window_size` caps how many are loaded; zero or a negative value loads
    all of them. Only succeeded trials are selected because failed or
    dispatched ones have no valid result.

    Args:
        batch_size (BatchSizeBase): The batch size object.
        window_size (int): The size of the measurement window.

    Returns:
        TrialResultsPerBs: trial results for the given batch size.
    """
    succeeded_for_bs = and_(
        TrialTable.job_id == batch_size.job_id,
        TrialTable.batch_size == batch_size.batch_size,
        TrialTable.status == TrialStatus.Succeeded,
    )
    query = (
        select(TrialTable)
        .where(succeeded_for_bs)
        .order_by(TrialTable.trial_number.desc())
    )
    if window_size > 0:
        query = query.limit(window_size)

    rows = (await self.session.scalars(query)).all()
    results = [TrialResult.from_orm(row) for row in rows]
    return TrialResultsPerBs(
        job_id=batch_size.job_id,
        batch_size=batch_size.batch_size,
        results=results,
    )

get_arms async

get_arms(job_id)

Retrieve Gaussian Thompson Sampling arms for a given job.

Parameters:

Name Type Description Default
job_id str

The ID of the job.

required

Returns:

Type Description
list[GaussianTsArmState]

List of Gaussian Thompson Sampling arms (`GaussianTsArmState`). These arms are all "good" arms, i.e. arms that converged during the pruning stage.

list[GaussianTsArmState]

Refer to GaussianTsArmStateModel[zeus.optimizer.batch_size.server.batch_size_state.models.GaussianTsArmStateModel] for attributes.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
async def get_arms(self, job_id: str) -> list[GaussianTsArmState]:
    """Load every Gaussian Thompson Sampling arm stored for a job.

    Args:
        job_id (str): The ID of the job.

    Returns:
        list[GaussianTsArmState]: One entry per arm row whose job ID matches.
    """
    query = select(GaussianTsArmStateTable).where(
        GaussianTsArmStateTable.job_id == job_id
    )
    rows = await self.session.scalars(query)
    return [GaussianTsArmState.from_orm(row) for row in rows.all()]

get_arm async

get_arm(bs)

Retrieve Gaussian Thompson Sampling arm for a given job id and batch size.

Parameters:

Name Type Description Default
bs BatchSizeBase

The batch size object.

required

Returns:

Type Description
GaussianTsArmState | None

The Gaussian Thompson Sampling arm (`GaussianTsArmState`) if found, else None.

GaussianTsArmState | None

Refer to GaussianTsArmStateModel[zeus.optimizer.batch_size.server.batch_size_state.models.GaussianTsArmStateModel] for attributes.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
async def get_arm(self, bs: BatchSizeBase) -> GaussianTsArmState | None:
    """Load the Gaussian Thompson Sampling arm of one (job, batch size) pair.

    When a row is found it is also kept in `self.fetched_arm`; when nothing
    is found `self.fetched_arm` is left as it was.

    Args:
        bs (BatchSizeBase): The batch size object.

    Returns:
        GaussianTsArmState | None: The arm if found, else None.
    """
    same_arm = and_(
        GaussianTsArmStateTable.job_id == bs.job_id,
        GaussianTsArmStateTable.batch_size == bs.batch_size,
    )
    row = await self.session.scalar(select(GaussianTsArmStateTable).where(same_arm))
    if row is not None:
        self.fetched_arm = row
        return GaussianTsArmState.from_orm(row)
    return None

get_trial async

get_trial(trial)

Get a corresponding trial.

Parameters:

Name Type Description Default
trial ReadTrial

job_id, batch_size, trial_number.

required

Returns:

Type Description
Trial | None

Found Trial. If none found, return None.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
async def get_trial(self, trial: ReadTrial) -> Trial | None:
    """Look up the trial identified by job_id, batch_size and trial_number.

    When a row is found it is also kept in `self.fetched_trial`; when nothing
    is found an info message is logged and `self.fetched_trial` is left as it was.

    Args:
        trial: job_id, batch_size, trial_number.

    Returns:
        Found Trial. If none found, return None.
    """
    query = select(TrialTable).where(
        TrialTable.job_id == trial.job_id,
        TrialTable.batch_size == trial.batch_size,
        TrialTable.trial_number == trial.trial_number,
    )
    row = await self.session.scalar(query)

    if row is not None:
        self.fetched_trial = row
        return Trial.from_orm(row)

    logger.info("get_trial: NoResultFound")
    return None

get_trial_from_session

get_trial_from_session(trial)

Fetch a trial from the session.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
152
153
154
155
156
157
158
159
160
161
def get_trial_from_session(self, trial: ReadTrial) -> Trial | None:
    """Return the already-fetched trial if it is the one asked for.

    No query is issued: only `self.fetched_trial` is consulted. None is
    returned when no trial has been fetched or when its job_id, batch_size or
    trial_number differs from the request.
    """
    cached = self.fetched_trial
    if cached is None:
        return None

    is_requested_trial = (
        cached.job_id == trial.job_id
        and cached.batch_size == trial.batch_size
        and cached.trial_number == trial.trial_number
    )
    if not is_requested_trial:
        return None
    return Trial.from_orm(cached)

create_trial

create_trial(trial)

Create a trial in db.

Refer to CreateTrial[zeus.optimizer.batch_size.server.batch_size_state.models.CreateTrial] for attributes.

Parameters:

Name Type Description Default
trial CreateTrial

The trial to add.

required
Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
163
164
165
166
167
168
169
170
171
def create_trial(self, trial: CreateTrial) -> None:
    """Add a new trial to the db session.

    Refer to `CreateTrial`[zeus.optimizer.batch_size.server.batch_size_state.models.CreateTrial] for attributes.

    Args:
        trial (CreateTrial): The trial to add.
    """
    # Convert to the ORM representation before handing it to the session.
    orm_row = trial.to_orm()
    self.session.add(orm_row)

updated_current_trial

updated_current_trial(updated_trial)

Update trial in the database (report the result of trial).

Parameters:

Name Type Description Default
updated_trial UpdateTrial

The updated trial. Refer to UpdateTrial[zeus.optimizer.batch_size.server.batch_size_state.models.UpdateTrial] for attributes.

required
Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def updated_current_trial(self, updated_trial: UpdateTrial) -> None:
    """Write a reported trial result onto the previously fetched trial row.

    Args:
        updated_trial (UpdateTrial): The updated trial. Refer to `UpdateTrial`[zeus.optimizer.batch_size.server.batch_size_state.models.UpdateTrial] for attributes.

    Raises:
        ZeusBSOValueError: If no trial has been fetched, or the fetched trial's
            job_id, batch_size or trial_number differs from `updated_trial`.
    """
    current = self.fetched_trial
    if current is None:
        raise ZeusBSOValueError("No trial is fetched.")

    is_same_trial = (
        current.job_id == updated_trial.job_id
        and current.batch_size == updated_trial.batch_size
        and current.trial_number == updated_trial.trial_number
    )
    if not is_same_trial:
        raise ZeusBSOValueError("Trying to update invalid trial.")

    # Copy the result fields over, in the same order as before.
    for field in ("end_timestamp", "status", "time", "energy", "converged"):
        setattr(current, field, getattr(updated_trial, field))

create_arms

create_arms(new_arms)

Create Gaussian Thompson Sampling arms in the database.

Parameters:

Name Type Description Default
new_arms List[GaussianTsArmStateModel]

List of new arms to create. Refer to GaussianTsArmStateModel[zeus.optimizer.batch_size.server.batch_size_state.models.GaussianTsArmStateModel] for attributes.

required
Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
195
196
197
198
199
200
201
202
def create_arms(self, new_arms: list[GaussianTsArmState]) -> None:
    """Add new Gaussian Thompson Sampling arms to the db session.

    Args:
        new_arms (list[GaussianTsArmState]): List of new arms to create.
    """
    # Convert every arm to its ORM representation, then add them in one call.
    orm_rows = [arm.to_orm() for arm in new_arms]
    self.session.add_all(orm_rows)

update_arm_state

update_arm_state(updated_mab_state)

Update Gaussian Thompson Sampling arm state in db.

Parameters:

Name Type Description Default
updated_mab_state GaussianTsArmStateModel

The updated arm state. Refer to GaussianTsArmStateModel[zeus.optimizer.batch_size.server.batch_size_state.models.GaussianTsArmStateModel] for attributes.

required
Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def update_arm_state(self, updated_mab_state: GaussianTsArmState) -> None:
    """Write new Gaussian Thompson Sampling parameters onto the fetched arm row.

    Args:
        updated_mab_state (GaussianTsArmState): The updated arm state.

    Raises:
        ZeusBSOValueError: If no arm has been fetched, or the fetched arm's
            job_id or batch_size differs from `updated_mab_state`.
    """
    current = self.fetched_arm
    if current is None:
        raise ZeusBSOValueError("No arm is fetched.")

    is_same_arm = (
        current.job_id == updated_mab_state.job_id
        and current.batch_size == updated_mab_state.batch_size
    )
    if not is_same_arm:
        raise ZeusBSOValueError(
            "Fetch arm does not correspond with the arm trying to update."
        )

    # Copy the arm parameters over, in the same order as before.
    for field in (
        "param_mean",
        "param_precision",
        "reward_precision",
        "num_observations",
    ):
        setattr(current, field, getattr(updated_mab_state, field))

get_explorations_of_job async

get_explorations_of_job(job_id)

Retrieve succeeded or ongoing explorations for a given job.

Parameters:

Name Type Description Default
job_id str

ID of the job

required

Returns:

Name Type Description
ExplorationsPerJob ExplorationsPerJob

Explorations for the given job, grouped by batch size.

ExplorationsPerJob

Refer to ExplorationsPerJob[zeus.optimizer.batch_size.server.batch_size_state.models.ExplorationsPerJob] for attributes.

Source code in zeus/optimizer/batch_size/server/batch_size_state/repository.py
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
async def get_explorations_of_job(self, job_id: str) -> ExplorationsPerJob:
    """Collect a job's exploration trials that have not failed, grouped by batch size.

    Trials are read in ascending trial-number order, so each batch size's list
    keeps that order.

    Args:
        job_id: ID of the job

    Returns:
        ExplorationsPerJob: Explorations for the given job.
        Refer to `ExplorationsPerJob`[zeus.optimizer.batch_size.server.batch_size_state.models.ExplorationsPerJob] for attributes.
    """
    live_explorations = and_(
        TrialTable.job_id == job_id,
        TrialTable.type == TrialType.Exploration,
        TrialTable.status != TrialStatus.Failed,
    )
    query = (
        select(TrialTable)
        .where(live_explorations)
        .order_by(TrialTable.trial_number.asc())
    )

    rows = (await self.session.scalars(query)).all()
    grouped: defaultdict[int, list[Trial]] = defaultdict(list)
    for row in rows:
        grouped[row.batch_size].append(Trial.from_orm(row))

    return ExplorationsPerJob(job_id=job_id, explorations_per_bs=grouped)