Skip to content

repository

zeus.optimizer.batch_size.server.job.repository

Repository for manipulating Job table.

JobStateRepository

Bases: DatabaseRepository

Repository that provides basic interfaces to interact with Job table.

Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 26–168)
class JobStateRepository(DatabaseRepository):
    """Repository that provides basic interfaces to interact with Job table.

    A repository instance works with at most one job per session: the job row
    returned by the last successful `get_job` call is cached in `fetched_job`,
    and the `update_*` methods mutate that cached ORM object in place.
    """

    def __init__(self, session: AsyncSession):
        """Set the db session and initialize the cached job.

        We work with only one job per session, so `fetched_job` starts empty and
        is populated by `get_job`.
        """
        super().__init__(session)
        self.fetched_job: JobTable | None = None

    def _require_fetched_job(self, job_id: str) -> JobTable:
        """Return the cached job after checking that it is the job `job_id`.

        Args:
            job_id: Job id the caller intends to update.

        Returns:
            The cached `JobTable` row.

        Raises:
            ZeusBSOServiceBadOperationError: If no job has been fetched yet.
            ZeusBSOValueError: If the fetched job has a different job id.
        """
        if self.fetched_job is None:
            raise ZeusBSOServiceBadOperationError("No job is fetched.")
        if self.fetched_job.job_id != job_id:
            raise ZeusBSOValueError(
                f"Unknown job_id ({job_id}). Expecting {self.fetched_job.job_id}"
            )
        return self.fetched_job

    async def get_job(self, job_id: str) -> JobState | None:
        """Get job State, which includes jobSpec + batch_sizes(list[int]), without specific states of each batch_size.

        Args:
            job_id: Job id.

        Returns:
            `JobState` if the job was found (and caches the row in `fetched_job`);
            otherwise None, leaving `fetched_job` unchanged.
        """
        stmt = select(JobTable).where(JobTable.job_id == job_id)
        job = await self.session.scalar(stmt)

        if job is None:
            logger.info("get_job: NoResultFound")
            return None

        self.fetched_job = job
        return JobState.from_orm(job)

    def get_job_from_session(self, job_id: str) -> JobState | None:
        """Get a job that was fetched from this session.

        Args:
            job_id: Job id.

        Returns:
            Corresponding `JobState`. If none was found, return None.
        """
        if not self.check_job_fetched(job_id):
            return None
        return JobState.from_orm(self.fetched_job)

    def update_exp_default_bs(self, updated_bs: UpdateExpDefaultBs) -> None:
        """Update exploration default batch size on fetched job.

        Args:
            updated_bs: Job id and new batch size.

        Raises:
            ZeusBSOServiceBadOperationError: If no job has been fetched yet.
            ZeusBSOValueError: If `updated_bs.job_id` is not the fetched job's id.
        """
        job = self._require_fetched_job(updated_bs.job_id)
        job.exp_default_batch_size = updated_bs.exp_default_batch_size

    def update_stage(self, updated_stage: UpdateJobStage) -> None:
        """Update stage on fetched job.

        Args:
            updated_stage: Job id and new stage.

        Raises:
            ZeusBSOServiceBadOperationError: If no job has been fetched yet.
            ZeusBSOValueError: If `updated_stage.job_id` is not the fetched job's id.
        """
        job = self._require_fetched_job(updated_stage.job_id)
        job.stage = updated_stage.stage

    def update_min(self, updated_min: UpdateJobMinCost) -> None:
        """Update exploration min training cost and corresponding batch size on fetched job.

        Args:
            updated_min: Job id, new min cost and batch size.

        Raises:
            ZeusBSOServiceBadOperationError: If no job has been fetched yet.
            ZeusBSOValueError: If `updated_min.job_id` is not the fetched job's id.
        """
        job = self._require_fetched_job(updated_min.job_id)
        job.min_cost = updated_min.min_cost
        job.min_cost_batch_size = updated_min.min_cost_batch_size

    def update_generator_state(self, updated_state: UpdateGeneratorState) -> None:
        """Update generator state on fetched job.

        Args:
            updated_state: Job id and new generator state.

        Raises:
            ZeusBSOServiceBadOperationError: If no job has been fetched yet.
            ZeusBSOValueError: If `updated_state.job_id` is not the fetched job's id.
        """
        job = self._require_fetched_job(updated_state.job_id)
        job.mab_random_generator_state = updated_state.state

    def create_job(self, new_job: CreateJob) -> None:
        """Create a new job by adding a new job to the session.

        Args:
            new_job: Job configuration for a new job.
        """
        self.session.add(new_job.to_orm())

    def check_job_fetched(self, job_id: str) -> bool:
        """Check if this job is already fetched before.

        Args:
            job_id: Job id.

        Returns:
            True if this job was fetched and is in the session; otherwise False.
        """
        return self.fetched_job is not None and self.fetched_job.job_id == job_id

    async def delete_job(self, job_id: str) -> bool:
        """Delete the job with the given job_id.

        Args:
            job_id: Job id.

        Returns:
            True if the job got deleted, False if no job with that id exists.
        """
        stmt = select(JobTable).where(JobTable.job_id == job_id)
        job = await self.session.scalar(stmt)

        if job is None:
            return False

        # We can't straight delete using a query: on some databases such as SQLite
        # foreign keys are OFF by default, so "on delete = cascade" will not fire.
        await self.session.delete(job)

        # Forget the cached job if it is the one just deleted, so later
        # update/lookup calls on this repository do not act on a deleted row.
        if self.check_job_fetched(job_id):
            self.fetched_job = None
        return True

__init__

__init__(session)
Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 29–32)
def __init__(self, session: AsyncSession):
    """Set the db session and initialize the cached job.

    We work with only one job per session, so `fetched_job` starts empty and is
    populated by `get_job`.
    """
    super().__init__(session)
    self.fetched_job: JobTable | None = None

get_job async

get_job(job_id)

Get the job state, which includes the job spec plus its batch sizes (list[int]), but not the individual state of each batch size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `job_id` | `str` | Job id. | required |

Returns:

| Type | Description |
| --- | --- |
| `JobState \| None` | Sets `fetched_job` and returns the `JobState` if the job was found; otherwise returns None. |

Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 34–51)
async def get_job(self, job_id: str) -> JobState | None:
    """Load a job's state by id and remember the row on this repository.

    The returned state carries the job spec and its batch sizes (list[int]), but
    not the per-batch-size states.

    Args:
        job_id: Job id.

    Returns:
        The `JobState` when a matching row exists, in which case the row is also
        stored in `fetched_job`; otherwise None.
    """
    query = select(JobTable).where(JobTable.job_id == job_id)
    row = await self.session.scalar(query)

    if row is not None:
        # Remember the row so later update calls can modify it in place.
        self.fetched_job = row
        return JobState.from_orm(row)

    logger.info("get_job: NoResultFound")
    return None

get_job_from_session

get_job_from_session(job_id)

Get a job that was fetched from this session.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `job_id` | `str` | Job id. | required |

Returns:

| Type | Description |
| --- | --- |
| `JobState \| None` | The corresponding `JobState`, or None if no matching job was fetched in this session. |

Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 53–64)
def get_job_from_session(self, job_id: str) -> JobState | None:
    """Return the state of the job cached by an earlier fetch in this session.

    Args:
        job_id: Job id.

    Returns:
        The `JobState` of the cached job when its id equals `job_id`;
        None when nothing is cached or a different job is cached.
    """
    cached = self.fetched_job
    if cached is not None and cached.job_id == job_id:
        return JobState.from_orm(cached)
    return None

update_exp_default_bs

update_exp_default_bs(updated_bs)

Update exploration default batch size on fetched job.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `updated_bs` | `UpdateExpDefaultBs` | Job id and new batch size. | required |
Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 66–80)
def update_exp_default_bs(self, updated_bs: UpdateExpDefaultBs) -> None:
    """Overwrite the exploration default batch size of the cached job.

    Args:
        updated_bs: Carries the target job id and the new batch size.

    Raises:
        ZeusBSOServiceBadOperationError: When nothing has been fetched yet.
        ZeusBSOValueError: When the id does not match the cached job.
    """
    current = self.fetched_job
    if current is None:
        raise ZeusBSOServiceBadOperationError("No job is fetched.")
    if updated_bs.job_id != current.job_id:
        raise ZeusBSOValueError(
            f"Unknown job_id ({updated_bs.job_id}). Expecting {current.job_id}"
        )
    current.exp_default_batch_size = updated_bs.exp_default_batch_size

update_stage

update_stage(updated_stage)

Update stage on fetched job.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `updated_stage` | `UpdateJobStage` | Job id and new stage. | required |
Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 82–96)
def update_stage(self, updated_stage: UpdateJobStage) -> None:
    """Set the stage of the cached job.

    Args:
        updated_stage: Carries the target job id and the stage to store.

    Raises:
        ZeusBSOServiceBadOperationError: When nothing has been fetched yet.
        ZeusBSOValueError: When the id does not match the cached job.
    """
    current = self.fetched_job
    if current is None:
        raise ZeusBSOServiceBadOperationError("No job is fetched.")
    if current.job_id != updated_stage.job_id:
        raise ZeusBSOValueError(
            f"Unknown job_id ({updated_stage.job_id}). Expecting {current.job_id}"
        )
    current.stage = updated_stage.stage

update_min

update_min(updated_min)

Update exploration min training cost and corresponding batch size on fetched job.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `updated_min` | `UpdateJobMinCost` | Job id, new min cost and batch size. | required |
Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 98–113)
def update_min(self, updated_min: UpdateJobMinCost) -> None:
    """Record a new minimum exploration cost and its batch size on the cached job.

    Args:
        updated_min: Carries the target job id, the min cost and its batch size.

    Raises:
        ZeusBSOServiceBadOperationError: When nothing has been fetched yet.
        ZeusBSOValueError: When the id does not match the cached job.
    """
    current = self.fetched_job
    if current is None:
        raise ZeusBSOServiceBadOperationError("No job is fetched.")
    if current.job_id != updated_min.job_id:
        raise ZeusBSOValueError(
            f"Unknown job_id ({updated_min.job_id}). Expecting {current.job_id}"
        )
    # Cost and batch size are always written together so they stay paired.
    current.min_cost = updated_min.min_cost
    current.min_cost_batch_size = updated_min.min_cost_batch_size

update_generator_state

update_generator_state(updated_state)

Update generator state on fetched job.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `updated_state` | `UpdateGeneratorState` | Job id and new generator state. | required |
Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 115–129)
def update_generator_state(self, updated_state: UpdateGeneratorState) -> None:
    """Store a new MAB random generator state on the cached job.

    Args:
        updated_state: Carries the target job id and the generator state.

    Raises:
        ZeusBSOServiceBadOperationError: When nothing has been fetched yet.
        ZeusBSOValueError: When the id does not match the cached job.
    """
    current = self.fetched_job
    if current is None:
        raise ZeusBSOServiceBadOperationError("No job is fetched.")
    if current.job_id != updated_state.job_id:
        raise ZeusBSOValueError(
            f"Unknown job_id ({updated_state.job_id}). Expecting {current.job_id}"
        )
    current.mab_random_generator_state = updated_state.state

create_job

create_job(new_job)

Create a new job by adding a new job to the session.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `new_job` | `CreateJob` | Job configuration for a new job. | required |
Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 131–137)
def create_job(self, new_job: CreateJob) -> None:
    """Stage a new job row in the session.

    Args:
        new_job: Configuration of the job to create; converted to its ORM row.
    """
    orm_row = new_job.to_orm()
    self.session.add(orm_row)

check_job_fetched

check_job_fetched(job_id)

Check if this job is already fetched before.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `job_id` | `str` | Job id. | required |

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if this job was fetched and is in the session; otherwise False. |

Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 139–148)
def check_job_fetched(self, job_id: str) -> bool:
    """Tell whether `job_id` is the job currently cached on this repository.

    Args:
        job_id: Job id.

    Returns:
        True when a job is cached and its id equals `job_id`, else False.
    """
    cached = self.fetched_job
    return cached is not None and cached.job_id == job_id

delete_job async

delete_job(job_id)

Delete the job with the given job_id.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `job_id` | `str` | Job id. | required |

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if the job got deleted; False if no job with that id exists. |

Source code in zeus/optimizer/batch_size/server/job/repository.py (lines 150–168)
async def delete_job(self, job_id: str) -> bool:
    """Delete the job with the given job_id.

    Args:
        job_id: Job id.

    Returns:
        True if the job got deleted, False if no job with that id exists.
    """
    stmt = select(JobTable).where(JobTable.job_id == job_id)
    job = await self.session.scalar(stmt)

    if job is None:
        return False

    # We can't straight delete using a query: on some databases such as SQLite
    # foreign keys are OFF by default, so "on delete = cascade" will not fire.
    await self.session.delete(job)

    # Forget the cached job if it is the one just deleted, so later
    # update/lookup calls on this repository do not act on a deleted row.
    if self.fetched_job is not None and self.fetched_job.job_id == job_id:
        self.fetched_job = None
    return True