267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
class PruningGTSBatchSizeOptimizer(BatchSizeOptimizer):
    """One Gaussian Thompson Sampling MAB for each job with double pruning exploration."""

    def __init__(
        self,
        prior_mean: float = 0.0,
        prior_precision: float = 0.0,
        window_size: int = 0,
        concurrency: bool = False,
        seed: int = 123456,
        verbose: bool = True,
    ) -> None:
        """Initialize the optimizer.

        Refer to the constructor of [`GaussianTS`][zeus.policy.mab.GaussianTS]
        for descriptions of other arguments.

        Args:
            window_size: Size of the window for the MAB (for drift handling).
                A value of 0 means an unbounded window (use the full history).
            concurrency: Whether to support concurrent job submissions.
        """
        self.prior_mean = prior_mean
        self.prior_precision = prior_precision
        self.window_size = window_size
        self.concurrency = concurrency
        self.seed = seed
        self.verbose = verbose

        # One MAB for each job.
        self.mabs: dict[Job, GaussianTS] = {}
        # One PruningExplorationManager for each job.
        self.exp_manager: dict[Job, PruningExploreManager] = {}
        # Observation history (batch size, reward) for each job.
        self.history: dict[Job, list[tuple[int, float]]] = {}

    @property
    def name(self) -> str:
        """Name of the batch size optimizer."""
        return "Pruning GaussianTS BSO"

    def register_job(self, job: Job, batch_sizes: list[int]) -> None:
        """Register the job and set up its pruning exploration manager.

        Args:
            job: The job to register.
            batch_sizes: The list of batch sizes to consider for the job.

        Raises:
            ValueError: The job has no default batch size, or `batch_sizes` is empty.
        """
        # Sanity checks.
        if job.default_bs is None:
            raise ValueError(f"Default BS not specified for {job}.")
        if not batch_sizes:
            raise ValueError(f"Batch size list for {job} is empty.")

        # Set internal states.
        self.exp_manager[job] = PruningExploreManager(
            sorted(batch_sizes), job.default_bs
        )
        self.history[job] = []
        if self.verbose:
            self._log(f"Registered {job}")

    def predict(self, job: Job) -> int:
        """Return the batch size to use for the job.

        While the job is still in its pruning exploration stage, the exploration
        manager dictates the next batch size. Once exploration is exhausted
        (signaled by `StopIteration`), we fall through to Thompson Sampling.
        """
        # Try to see if the exploration manager has something.
        try:
            batch_size = self.exp_manager[job].next_batch_size()
            if self.verbose:
                self._log(f"{job} in pruning stage -> \033[31mBS = {batch_size}\033[0m")
        except StopIteration as exp:
            # Pruning stage is over. `exp.value` carries the list of surviving
            # ("good") batch sizes returned by the exploration generator.
            if job not in self.mabs:
                self._construct_mab(job, exp.value)
            batch_size = self.mabs[job].predict()
            if self.verbose:
                self._log(
                    f"{job} in Thompson Sampling stage -> \033[31mBS = {batch_size}\033[0m"
                )

        return batch_size

    def observe(
        self, job: Job, batch_size: int, cost: float, converged: bool | None = None
    ) -> None:
        """Learn from the cost of using the given batch size for the job.

        Args:
            job: The job the observation belongs to.
            batch_size: The batch size that was used.
            cost: The observed cost (lower is better; stored negated as reward).
            converged: Whether the job converged. Must be given while the job
                is still in its pruning stage; ignored afterwards.
        """
        # Add observation to history. Reward is the negated cost.
        self.history[job].append((batch_size, -cost))

        # We're in Thompson Sampling stage.
        if job in self.mabs:
            # Since we're learning the reward precision, we need to
            # 1. re-compute the precision of this arm based on the reward history,
            # 2. update the arm's reward precision
            # 3. and `fit` the new MAB instance on all the reward history.
            # Note that `arm_rewards` always has more than one entry (and hence a
            # non-zero variance) because we've been through pruning exploration.
            arm_rewards = np.array(self._get_history_for_bs(job, batch_size))
            precision = np.reciprocal(np.var(arm_rewards))
            mab = self.mabs[job]
            mab.arm_reward_prec[batch_size] = precision
            mab.fit_arm(batch_size, arm_rewards, reset=True)
            if self.verbose:
                arm_rewards_repr = ", ".join([f"{r:.2f}" for r in arm_rewards])
                self._log(
                    f"{job} @ {batch_size}: "
                    f"arm_rewards = [{arm_rewards_repr}], reward_prec = {precision}"
                )

        # We're in pruning stage.
        else:
            assert converged is not None
            # Log before we potentially error out.
            # BUGFIX: the "not " prefix must appear when the job did NOT
            # converge, so it is multiplied by `not converged` (previously
            # `converged`, which inverted the message).
            if self.verbose:
                self._log(
                    f"{job} in pruning stage, expecting BS {self.exp_manager[job].expecting}."
                    f" Current BS {batch_size} that did {'not ' * (not converged)}converge."
                )

            # If we don't support concurrency, we can just pass the results to the
            # exploration manager, and the manager will err if the order of batch sizes
            # is screwed up.
            if not self.concurrency:
                self.exp_manager[job].report_batch_size_result(
                    batch_size, cost, converged
                )
                return

            # If we are supporting concurrency, there's a subtle issue.
            # Pruning exploration demands a specific order of trying out a batch size
            # and receiving the results (cost and whether reached). This breaks in the
            # following situation, for example:
            # 1. Job with BS 32 that is part of pruning exploration starts.
            # 2. Concurrent job comes in, and we launch it with the best known BS 64.
            # 3. Job with BS 64 finishes first, and calls bso.observe with BS 64.
            # This breaks the observation order assumption of PruningExplorationManager.
            # Thus we check whether the current batch size is the one expected by
            # PruningExplorationManager, and then only if so, call bso.observe.
            # Otherwise, we silently insert the cost observation into the bso's history
            # (first line of this method) and don't touch the PruningExplorationManager.
            if self.exp_manager[job].expecting == batch_size:
                self.exp_manager[job].report_batch_size_result(
                    batch_size, cost, converged
                )

    def _get_history_for_bs(self, job: Job, batch_size: int) -> list[float]:
        """Return the windowed reward history for the given job's batch size.

        Returns at most `self.window_size` most recent rewards (all of them
        when `window_size` is 0, since the length check below can never match).
        """
        history = self.history[job]
        rewards = []
        # Collect rewards starting from the most recent ones and backwards.
        for bs, reward in reversed(history):
            if bs == batch_size:
                rewards.append(reward)
                if len(rewards) == self.window_size:
                    break
        # There's no need to return this in time order, but just in case.
        return list(reversed(rewards))

    def _construct_mab(self, job: Job, batch_sizes: list[int]) -> None:
        """When exploration is over, this method is called to construct and learn GTS.

        Args:
            job: The job whose MAB is being constructed.
            batch_sizes: Batch sizes that survived pruning exploration.

        Raises:
            ValueError: `batch_sizes` is empty (everything got pruned).
        """
        # Sanity check.
        if not batch_sizes:
            raise ValueError(
                "Empty batch size set when constructing MAB. "
                "Probably all batch sizes have been pruned."
            )

        if self.verbose:
            self._log(f"Construct MAB for {job} with arms {batch_sizes}")

        mab = GaussianTS(
            arms=batch_sizes,  # The MAB only has "good" arms.
            reward_precision=0.0,
            prior_mean=self.prior_mean,
            prior_precision=self.prior_precision,
            num_exploration=2,
            seed=self.seed,
            verbose=self.verbose,
        )
        # Fit the arm for each good batch size.
        for batch_size in self.exp_manager[job].batch_sizes:
            arm_rewards = np.array(self._get_history_for_bs(job, batch_size))
            # Pruning exploration guarantees at least two observations per
            # surviving arm, so the variance below is well-defined.
            assert (
                len(arm_rewards) >= 2
            ), f"Number of observations for {batch_size} is {len(arm_rewards)}."
            mab.arm_reward_prec[batch_size] = np.reciprocal(np.var(arm_rewards))
            mab.fit_arm(batch_size, arm_rewards, reset=True)
        # Save the MAB.
        self.mabs[job] = mab