intelligence.zenith_tune.strategies.megatron_heuristic
Megatron 2-stage heuristic strategy for hyperparameter optimization.
MegatronHeuristicStrategy Objects
@StrategyRegistry.register("megatron-heuristic")
class MegatronHeuristicStrategy(TuningStrategy)
2-stage heuristic strategy for Megatron hyperparameter optimization.
Implements a deterministic 2-stage search based on domain knowledge of Megatron parallelization and recompute configurations.
All parameters must be CategoricalParameter in search_space:
Searched parameters (multiple choices):
- TP, EP: exhaustively searched in Stage 1; tried in descending/ascending order
- MBS: searched in Stage 2; choices must include 1
- RECOMPUTE_GRANULARITY: choices must include "selective" and "full"; values are determined by the strategy heuristic
Fixed parameters (choices must include the fixed value):
- PP, CP, ETP: fixed to 1
- RECOMPUTE_METHOD: fixed to "uniform"
- RECOMPUTE_NUM_LAYERS: fixed to 1
- VIT_GRADIENT_CHECKPOINTING: fixed to "true"
By requiring fixed parameters in search_space, other strategies can vary these values by providing multiple choices.
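The requirements above can be checked before a run. The following is a minimal sketch, not the library's own validation: `find_search_space_problems` and `REQUIRED_VALUES` are hypothetical names, each parameter is represented only by its plain list of choices rather than a CategoricalParameter, and the lowercase key names are an assumption.

```python
from typing import Any

# Values each parameter's choices must contain, per the lists above.
# Lowercase key names are an assumption; adapt them to your search_space.
REQUIRED_VALUES: dict[str, tuple[Any, ...]] = {
    "tp": (),  # searched; no specific value required
    "ep": (),  # searched; no specific value required
    "mbs": (1,),
    "recompute_granularity": ("selective", "full"),
    "pp": (1,),
    "cp": (1,),
    "etp": (1,),
    "recompute_method": ("uniform",),
    "recompute_num_layers": (1,),
    "vit_gradient_checkpointing": ("true",),
}


def find_search_space_problems(choices_by_name: dict[str, list[Any]]) -> list[str]:
    """Return readable problems; an empty list means the space looks valid."""
    problems: list[str] = []
    for name, required in REQUIRED_VALUES.items():
        if name not in choices_by_name:
            problems.append(f"missing parameter: {name}")
            continue
        for value in required:
            if value not in choices_by_name[name]:
                problems.append(f"{name}: choices must include {value!r}")
    return problems
```

Returning a list of problems instead of raising on the first one lets a caller report every issue in a search space at once.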
Stage 1 (parallelization search): Exhaustively searches TP and EP combinations. For each TP value, the best EP is recorded.
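The bookkeeping in Stage 1 can be sketched as follows. This is an illustration under stated assumptions, not the strategy's actual code: `best_ep_per_tp` is a hypothetical helper, the candidate (tp, ep) pairs are supplied by the caller in whatever order they should be tried, and the remaining parameters (held at fixed values during Stage 1) are omitted from the parameter dict for brevity.

```python
from typing import Any, Callable, Iterable, Optional

EvalFn = Callable[[dict[str, Any]], Optional[float]]


def best_ep_per_tp(
    eval_fn: EvalFn,
    pairs: Iterable[tuple[int, int]],
    maximize: bool = True,
) -> dict[int, tuple[int, float]]:
    """Evaluate every (tp, ep) pair and keep the best-scoring ep for each tp."""
    best: dict[int, tuple[int, float]] = {}
    for tp, ep in pairs:
        score = eval_fn({"tp": tp, "ep": ep})
        if score is None:
            continue  # failed trial (e.g. OOM): this pair cannot win
        incumbent = best.get(tp)
        if incumbent is None or (
            score > incumbent[1] if maximize else score < incumbent[1]
        ):
            best[tp] = (ep, score)
    return best
```

A TP value for which every trial fails never appears in the result, which matches the statement that Stage 2 only considers TP values that succeeded in Stage 1.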
Stage 2 (computational efficiency search): For each TP that succeeded in Stage 1, fixes the best EP and searches MBS and recompute granularity:
- selective: tries MBS from 1 upward, stops on failure (OOM/ERROR)
- full: tries MBS from best_mbs (found in selective) upward, stops on failure. MBS=1 is skipped because it was already evaluated in Stage 1.
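The two Stage 2 passes for a single TP value can be sketched like this. It is a hedged illustration rather than the real implementation: `stage2_for_tp` is a hypothetical function, `best_mbs` is assumed to be the best-scoring MBS from the selective pass, and if the selective pass has no successful trial the full pass is assumed to start from the smallest MBS above 1.

```python
from typing import Any, Callable, Optional

EvalFn = Callable[[dict[str, Any]], Optional[float]]


def stage2_for_tp(
    eval_fn: EvalFn,
    tp: int,
    best_ep: int,
    mbs_choices: list[int],
    maximize: bool = True,
) -> list[tuple[dict[str, Any], float]]:
    """Run the selective pass, then the full pass; return successful trials."""
    ordered = sorted(mbs_choices)
    successes: list[tuple[dict[str, Any], float]] = []

    def run(mbs: int, granularity: str) -> Optional[float]:
        params = {"tp": tp, "ep": best_ep, "mbs": mbs,
                  "recompute_granularity": granularity}
        score = eval_fn(params)
        if score is not None:
            successes.append((params, score))
        return score

    # Selective pass: MBS from 1 upward, stop at the first failure.
    best_mbs: Optional[int] = None
    best_score: Optional[float] = None
    for mbs in ordered:
        score = run(mbs, "selective")
        if score is None:
            break
        if best_score is None or (
            score > best_score if maximize else score < best_score
        ):
            best_mbs, best_score = mbs, score

    # Full pass: from best_mbs upward, skipping MBS=1, stop at the first failure.
    start = best_mbs if best_mbs is not None else 1
    for mbs in ordered:
        if mbs < start or mbs == 1:
            continue
        if run(mbs, "full") is None:
            break
    return successes
```

Stopping each pass at the first failure relies on the expectation that a larger MBS needs at least as much memory as a smaller one, so later values would fail as well.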
n_trials requirement: This strategy runs a fixed, deterministic number of trials. Set --n-trials large enough to avoid premature termination; excess budget is simply unused.
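When choosing a value for --n-trials, a rough upper bound can be derived from the stage descriptions above. The helper below is a hypothetical back-of-the-envelope calculation, assuming Stage 1 runs one trial per valid (TP, EP) pair and Stage 2 runs at most one trial per MBS choice for selective and one per MBS choice other than 1 for full; the strategy's real trial count may be lower because each pass stops at the first failure.

```python
def max_trials_upper_bound(n_tp_ep_pairs: int, n_tp: int, n_mbs: int) -> int:
    """Upper bound on trials, under the assumptions stated in the text."""
    stage1 = n_tp_ep_pairs           # one trial per valid (TP, EP) pair
    selective = n_mbs                # at most every MBS choice
    full = max(n_mbs - 1, 0)         # every MBS choice except 1
    return stage1 + n_tp * (selective + full)
```

For example, with 10 valid (TP, EP) pairs, 4 TP choices, and 3 MBS choices, the bound is 10 + 4 * (3 + 2) = 30 trials.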
Example:
strategy = MegatronHeuristicStrategy(n_gpus=8)
strategy.optimize(eval_fn, search_space, Direction.MAXIMIZE)
__init__
def __init__(*, n_gpus: int, **kwargs: Any) -> None
Initialize the Megatron heuristic strategy.
Arguments:
- n_gpus - Total number of GPUs. Must be a power of 2. Used to constrain n_gpus % (tp * ep) == 0.
- **kwargs - Ignored (absorbs extra preset args).
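The two conditions on n_gpus can be illustrated with a short sketch. `is_power_of_two` and `valid_tp_ep_pairs` are hypothetical helpers written for this example, and raising ValueError for an invalid n_gpus is an assumption about error handling, not documented behavior.

```python
def is_power_of_two(n: int) -> bool:
    """True for 1, 2, 4, 8, ...; a power of two has exactly one bit set."""
    return n > 0 and (n & (n - 1)) == 0


def valid_tp_ep_pairs(
    n_gpus: int, tp_choices: list[int], ep_choices: list[int]
) -> list[tuple[int, int]]:
    """Keep only the (tp, ep) pairs that satisfy n_gpus % (tp * ep) == 0."""
    if not is_power_of_two(n_gpus):
        raise ValueError(f"n_gpus must be a power of 2, got {n_gpus}")
    return [
        (tp, ep)
        for tp in tp_choices
        for ep in ep_choices
        if n_gpus % (tp * ep) == 0
    ]
```

With n_gpus=8 and choices 1, 2, 4, 8 for both TP and EP, the pair (4, 2) is kept because 8 % 8 == 0, while (4, 4) is dropped because 8 % 16 != 0.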
optimize
def optimize(eval_fn: Callable[[dict[str, Any]], float | None],
search_space: dict[str, Parameter], direction: Direction) -> None
Run the 2-stage heuristic optimization loop.
Arguments:
- eval_fn - Callable that executes one trial. Takes a parameter dict and returns the objective value, or None if the trial failed. Raises TrialExhausted when the trial budget is exceeded.
- search_space - Dictionary mapping parameter names to Parameter definitions. Searched keys: tp, ep, mbs (choices must include 1); recompute_granularity (choices must include "selective" and "full"; values are determined by the strategy). Fixed keys (choices must include the fixed value): pp, cp, etp=1; recompute_method="uniform"; recompute_num_layers=1; vit_gradient_checkpointing="true".
- direction - Optimization direction (MAXIMIZE for TFLOP/s).
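The eval_fn contract described above (objective value on success, None on failure, TrialExhausted once the budget is spent) can be sketched with a small wrapper. Everything here is illustrative: `make_eval_fn` and `run_trial` are hypothetical, the `TrialExhausted` class is a local stand-in for the library's exception, and in practice the tuner may construct eval_fn itself.

```python
from typing import Any, Callable, Optional


class TrialExhausted(Exception):
    """Local stand-in; the real exception comes from the tuning library."""


def make_eval_fn(
    run_trial: Callable[[dict[str, Any]], float], n_trials: int
) -> Callable[[dict[str, Any]], Optional[float]]:
    """Wrap a trial runner so that it follows the eval_fn contract."""
    used = 0

    def eval_fn(params: dict[str, Any]) -> Optional[float]:
        nonlocal used
        if used >= n_trials:
            raise TrialExhausted(f"budget of {n_trials} trials exceeded")
        used += 1
        try:
            return run_trial(params)
        except (MemoryError, RuntimeError):
            return None  # report OOM/ERROR as a failed trial

    return eval_fn
```

Note that failed trials still consume budget in this sketch, so a generous n_trials leaves room for the failures that end each Stage 2 pass.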