Skip to content

Hyperparameter Tuning

Hyperparameter optimization with Ray Tune integration.

LearnerTrainable

Bases: Trainable

Ray Tune Trainable wrapper for Learners.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `config` | | Ray Tune config dict containing `'create_lrn'` and `'dls'` references. | *required* |

HPOptimizer

HPOptimizer(create_lrn: Callable, dls)

High-level interface for hyperparameter optimization with Ray Tune.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `create_lrn` | `Callable` | Factory function that creates a Learner from `(dls, config)`. | *required* |
| `dls` | | DataLoaders to use for training. | *required* |
Source code in tsfast/tune.py
def __init__(self, create_lrn: Callable, dls):
    self.create_lrn = create_lrn
    self.dls = dls
    self.analysis = None

start_ray

start_ray(**kwargs)

Initialize Ray runtime.

Source code in tsfast/tune.py
def start_ray(self, **init_kwargs):
    """(Re)start the Ray runtime.

    Any current Ray session is shut down first, so the supplied options are
    applied to a freshly initialised session.

    Args:
        **init_kwargs: forwarded unchanged to ``ray.init``.
    """
    ray.shutdown()  # drop a possibly stale session before re-initialising
    ray.init(**init_kwargs)

stop_ray

stop_ray()

Shut down Ray runtime.

Source code in tsfast/tune.py
def stop_ray(self):
    """Shut down the Ray runtime by calling ``ray.shutdown()``."""
    ray.shutdown()

optimize

optimize(config: dict, optimize_func: Callable = learner_optimize, resources_per_trial: dict = {'gpu': 1.0}, verbose: int = 1, **kwargs)

Run hyperparameter optimization using the function-based API.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `config` | `dict` | Ray Tune search space configuration dict. | *required* |
| `optimize_func` | `Callable` | Training function to optimize. | `learner_optimize` |
| `resources_per_trial` | `dict` | Resource dict per trial (e.g. GPU/CPU counts). | `{'gpu': 1.0}` |
| `verbose` | `int` | Ray Tune verbosity level. | `1` |
Source code in tsfast/tune.py
def optimize(
    self,
    config: dict,
    optimize_func: Callable = learner_optimize,
    resources_per_trial: dict = {"gpu": 1.0},
    verbose: int = 1,
    **kwargs,
):
    """Run hyperparameter optimization using the function-based API.

    The learner factory and the DataLoaders are placed in the Ray object store
    and added to a *copy* of ``config``; the caller's dict is left unchanged.

    Args:
        config: Ray Tune search space configuration dict.
        optimize_func: training function to optimize.
        resources_per_trial: resource dict per trial (e.g. GPU/CPU counts).
        verbose: Ray Tune verbosity level.
        **kwargs: additional keyword arguments forwarded to ``tune.run``.

    Returns:
        The object returned by ``tune.run``; it is also stored in ``self.analysis``.
    """
    self._ensure_ray()
    # Copy so the caller's search space is not polluted with object references
    # (which would otherwise leak into any later reuse of the same dict).
    trial_config = dict(config)
    trial_config["create_lrn"] = ray.put(self.create_lrn)
    # dls are large objects, letting ray handle the copying process makes it much faster
    trial_config["dls"] = ray.put(self.dls)

    self.analysis = tune.run(
        optimize_func,
        config=trial_config,
        resources_per_trial=resources_per_trial,
        verbose=verbose,
        **kwargs,
    )
    return self.analysis

optimize_pbt

optimize_pbt(opt_name: str, num_samples: int, config: dict, mut_conf: dict, perturbation_interval: int = 2, stop: dict = {'training_iteration': 40}, resources_per_trial: dict = {'gpu': 1}, resample_probability: float = 0.25, quantile_fraction: float = 0.25, **kwargs)

Run Population Based Training optimization.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `opt_name` | `str` | Experiment name for Ray Tune. | *required* |
| `num_samples` | `int` | Number of parallel trials. | *required* |
| `config` | `dict` | Initial hyperparameter configuration dict. | *required* |
| `mut_conf` | `dict` | Mutable hyperparameter space for PBT mutations. | *required* |
| `perturbation_interval` | `int` | Epochs between PBT perturbations. | `2` |
| `stop` | `dict` | Stopping criteria dict. | `{'training_iteration': 40}` |
| `resources_per_trial` | `dict` | Resource dict per trial. | `{'gpu': 1}` |
| `resample_probability` | `float` | Probability of resampling vs. perturbing. | `0.25` |
| `quantile_fraction` | `float` | Fraction of trials to exploit/explore. | `0.25` |
Source code in tsfast/tune.py
def optimize_pbt(
    self,
    opt_name: str,
    num_samples: int,
    config: dict,
    mut_conf: dict,
    perturbation_interval: int = 2,
    stop: dict = {"training_iteration": 40},
    resources_per_trial: dict = {"gpu": 1},
    resample_probability: float = 0.25,
    quantile_fraction: float = 0.25,
    **kwargs,
):
    """Run Population Based Training optimization.

    Trials run ``LearnerTrainable`` under a ``PopulationBasedTraining`` scheduler
    that minimises the reported ``mean_loss``. The learner factory and the
    DataLoaders are placed in the Ray object store and added to a *copy* of
    ``config``; the caller's dict is left unchanged.

    Args:
        opt_name: experiment name for Ray Tune.
        num_samples: number of parallel trials.
        config: initial hyperparameter configuration dict.
        mut_conf: mutable hyperparameter space for PBT mutations. Also kept on
            ``self.mut_conf`` so ``best_model`` can rebuild a model later.
        perturbation_interval: epochs between PBT perturbations.
        stop: stopping criteria dict.
        resources_per_trial: resource dict per trial.
        resample_probability: probability of resampling vs. perturbing.
        quantile_fraction: fraction of trials to exploit/explore.
        **kwargs: forwarded to ``tune.run``. ``verbose`` (default 1),
            ``reuse_actors`` (default True) and ``checkpoint_score_attr``
            (default "mean_loss") may be overridden here.

    Returns:
        The object returned by ``tune.run``; it is also stored in ``self.analysis``.
    """
    self.mut_conf = mut_conf

    self._ensure_ray()
    # Copy so the caller's config is not polluted with object references.
    trial_config = dict(config)
    trial_config["create_lrn"] = ray.put(self.create_lrn)
    # dls are large objects, letting ray handle the copying process makes it much faster
    trial_config["dls"] = ray.put(self.dls)

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_loss",
        mode="min",
        perturbation_interval=perturbation_interval,
        resample_probability=resample_probability,
        quantile_fraction=quantile_fraction,
        hyperparam_mutations=mut_conf,
    )

    # These used to be hard-coded next to **kwargs, so passing any of them raised
    # "got multiple values for keyword argument"; keep them as overridable defaults.
    kwargs.setdefault("verbose", 1)
    kwargs.setdefault("reuse_actors", True)
    kwargs.setdefault("checkpoint_score_attr", "mean_loss")

    self.analysis = tune.run(
        LearnerTrainable,
        name=opt_name,
        scheduler=scheduler,
        stop=stop,
        num_samples=num_samples,
        resources_per_trial=resources_per_trial,
        config=trial_config,
        **kwargs,
    )
    return self.analysis

best_model

best_model() -> nn.Module

Load and return the best model from the optimization run.

Source code in tsfast/tune.py
def best_model(self) -> nn.Module:
    """Load and return the best model from the optimization run.

    A fresh model is built with ``self.create_lrn`` from a configuration sampled
    out of the PBT mutation space (``self.mut_conf``). The weights are then
    loaded from the checkpoint of the trial with the lowest ``mean_loss``.

    Returns:
        The model with the best trial's weights loaded.

    Raises:
        RuntimeError: if no optimization has been run yet, or if the run did not
            come from ``optimize_pbt`` (which records ``mut_conf``).
    """
    if self.analysis is None:
        raise RuntimeError("No optimization results available; run optimize_pbt() first.")
    # ``mut_conf`` is only set by optimize_pbt(); fail clearly instead of with an AttributeError.
    mut_conf = getattr(self, "mut_conf", None)
    if mut_conf is None:
        raise RuntimeError("best_model() requires results from optimize_pbt(); no mutation space recorded.")

    model = self.create_lrn(self.dls, sample_config(mut_conf)).model
    best_trial = self.analysis.get_best_trial("mean_loss", mode="min")
    f_path = ray.get(best_trial.checkpoint.value)
    # Load onto CPU so checkpoints written by GPU workers also load on CPU-only hosts;
    # load_state_dict copies the values into the model's existing parameters.
    model.load_state_dict(torch.load(f_path, map_location="cpu"))
    return model

log_uniform

log_uniform(min_bound: float, max_bound: float, base: float = 10) -> Callable

Sample uniformly in an exponential (log) range.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `min_bound` | `float` | Lower bound of the sampling range. | *required* |
| `max_bound` | `float` | Upper bound of the sampling range. | *required* |
| `base` | `float` | Logarithm base for the exponential scale. | `10` |
Source code in tsfast/tune.py
def log_uniform(min_bound: float, max_bound: float, base: float = 10) -> Callable:
    """Build a sampler that draws values log-uniformly between two bounds.

    Each call draws an exponent uniformly (via ``np.random.uniform``) between
    ``log_base(min_bound)`` and ``log_base(max_bound)``. It returns ``base``
    raised to that exponent.

    Args:
        min_bound: lower bound of the sampling range.
        max_bound: upper bound of the sampling range.
        base: logarithm base for the exponential scale.

    Returns:
        A zero-argument callable producing one sample per call.
    """
    log_of_base = np.log(base)
    low_exp = np.log(min_bound) / log_of_base
    high_exp = np.log(max_bound) / log_of_base

    def _draw():
        exponent = np.random.uniform(low_exp, high_exp)
        return base**exponent

    return _draw

stop_shared_memory_managers

stop_shared_memory_managers(obj: object)

Find and stop all SharedMemoryManager instances within an object.

Parameters:

Name Type Description Default
obj object

root object to traverse for SharedMemoryManager instances.

required
Source code in tsfast/tune.py
def stop_shared_memory_managers(obj: object):
    """Find and stop all SharedMemoryManager instances reachable from ``obj``.

    The traversal is iterative and depth-first. It descends into dict keys and
    values, into list/set/tuple elements, and into the ``__dict__`` attributes
    of other objects. Each object (by ``id``) is visited at most once, so cycles
    terminate. A manager's ``shutdown()`` is called, but its own attributes are
    not traversed.

    Args:
        obj: root object to traverse for SharedMemoryManager instances.
    """
    seen_ids = set()
    pending = [obj]

    while pending:
        item = pending.pop()
        item_id = id(item)
        if item_id in seen_ids:
            continue
        seen_ids.add(item_id)

        if isinstance(item, SharedMemoryManager):
            item.shutdown()
        elif isinstance(item, dict):
            pending.extend(item)  # keys first, then values (same order as before)
            pending.extend(item.values())
        elif isinstance(item, (list, set, tuple)):
            pending.extend(item)
        elif hasattr(item, "__dict__"):
            pending.extend(vars(item).values())

learner_optimize

learner_optimize(config: dict)

Training function for Ray Tune function-based API.

Parameters:

Name Type Description Default
config dict

Ray Tune config dict containing 'create_lrn', 'dls', 'fit_method', and hyperparameters.

required
Source code in tsfast/tune.py
def learner_optimize(config: dict):
    """Training function for Ray Tune function-based API.

    Builds a Learner from the object-store references in ``config`` and
    optionally restores model weights from the trial's Ray Tune checkpoint. It
    then trains either with ``config["fit_method"]`` or with
    ``lrn.fit_flat_cos``. Shared-memory managers reachable from the learner are
    shut down afterwards, even if training is interrupted.

    Args:
        config: Ray Tune config dict containing 'create_lrn', 'dls',
            'fit_method', and hyperparameters. Optional keys read here:
            ``n_epoch`` (default 100), ``pct_start`` (default 0.5), ``lr``
            (a value or a zero-argument callable, default 3e-3) and
            ``fit_method`` (an object-store reference to a callable
            ``fit_method(lrn, **lrn_kwargs)``).
    """
    import os  # local import: only needed to build the checkpoint file path

    lrn = None
    try:
        create_lrn = ray.get(config["create_lrn"])
        dls = ray.get(config["dls"])
        dls = dls.to(_worker_device()) if hasattr(dls, "to") else dls

        # Scheduling parameters for training the model; config entries override the defaults.
        lrn_kwargs = {"n_epoch": 100, "pct_start": 0.5}
        lrn_kwargs.update({attr: config[attr] for attr in lrn_kwargs if attr in config})

        lrn = create_lrn(dls, config)

        # load checkpoint data if provided
        checkpoint: tune.Checkpoint = tune.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                # Join with a separator: plain string concatenation pointed at
                # "<dir>model.pth", a path outside the checkpoint directory.
                model_path = os.path.join(checkpoint_dir, "model.pth")
                lrn.model.load_state_dict(torch.load(model_path))

        lr = config.get("lr", 3e-3)
        # ``lr`` may be a sampler (e.g. from log_uniform) or a plain value.
        lrn.lr = lr() if callable(lr) else lr
        _attach_ray_reporter(lrn)
        fit_method = ray.get(config["fit_method"]) if "fit_method" in config else None
        with lrn.no_bar():
            if fit_method is not None:
                fit_method(lrn, **lrn_kwargs)
            else:
                lrn.fit_flat_cos(**lrn_kwargs)
    finally:
        # cleanup shared memory even when earlystopping occurs
        if lrn is not None:
            stop_shared_memory_managers(lrn)
            del lrn
            gc.collect()

sample_config

sample_config(config: dict) -> dict

Sample concrete values from a config of callables.

Parameters:

Name Type Description Default
config dict

dict mapping keys to callable samplers.

required
Source code in tsfast/tune.py
def sample_config(config: dict) -> dict:
    """Draw one concrete value from every sampler in ``config``.

    Args:
        config: dict mapping keys to callable samplers.

    Returns:
        A copy of ``config`` (made with ``config.copy()``) with each value
        replaced by the result of calling it once. The samplers are called in
        the dict's iteration order, and ``config`` itself is left unchanged.
    """
    sampled = config.copy()
    sampled.update((key, sampler()) for key, sampler in config.items())
    return sampled