Implements model training on simulation data using neural nets.


DummySummaryWriter(*args, **kwargs)

Replaces torch.utils.tensorboard.writer.SummaryWriter when error logging is disabled.

ModelTrainer(config[, validate, progress_bar])

prepared_property([fget, fset, fdel, doc])

Helper descriptor for attributes which are not set until ModelTrainer.prepare() is called.

class dnadna.training.DummySummaryWriter(*args, **kwargs)[source]

Bases: object

Replaces torch.utils.tensorboard.writer.SummaryWriter when error logging is disabled.
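A "dummy" writer like this is an instance of the null-object pattern: it accepts the same calls as the real writer and does nothing, so training code can log unconditionally. The sketch below is illustrative only. DummySummaryWriterSketch and make_writer are hypothetical names, and catching arbitrary method names with __getattr__ is an assumed implementation, not necessarily how dnadna does it. Only SummaryWriter(log_dir=...) and add_scalar are real TensorBoard calls.

```python
class DummySummaryWriterSketch:
    """Null-object stand-in for a TensorBoard SummaryWriter (illustrative only)."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore whatever arguments a real SummaryWriter would take.
        pass

    def __getattr__(self, name):
        # Leave dunder lookups alone so copy/pickle machinery behaves normally.
        if name.startswith("__"):
            raise AttributeError(name)

        # Any other method (add_scalar, add_histogram, flush, close, ...)
        # resolves to a function that silently does nothing.
        def _noop(*args, **kwargs):
            return None

        return _noop


def make_writer(error_log, log_dir):
    """Hypothetical helper: build a real writer only when error logging is enabled."""
    if error_log:
        from torch.utils.tensorboard import SummaryWriter
        return SummaryWriter(log_dir=log_dir)
    return DummySummaryWriterSketch(log_dir=log_dir)


writer = make_writer(error_log=False, log_dir="runs/example")
writer.add_scalar("loss/train", 0.5, 1)  # silently ignored
writer.close()
```

Because torch is imported only inside the error_log=True branch, the disabled path works even where TensorBoard is not installed.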

class dnadna.training.ModelTrainer(config, validate=True, progress_bar=False)[source]

Bases: dnadna.utils.config.ConfigMixIn

property dataset

Returns a DNATrainingDataset instance.

property device

The device for PyTorch to use (CPU or a CUDA device).

ensure_run_dir(run_id=None, overwrite=False)[source]

Make sure the run directory for this training run exists.

Creates the directory if it does not already exist. If it does already exist and overwrite=False, a FileExistsError is raised.
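The rule described above can be sketched with the standard library. This is a hypothetical stand-in (ensure_run_dir_sketch), not the dnadna method: it takes the directory path directly instead of deriving it from run_id, and with overwrite=True it simply reuses the existing directory without deleting anything in it.

```python
import os
import tempfile


def ensure_run_dir_sketch(run_dir, overwrite=False):
    """Create run_dir; refuse to reuse an existing one unless overwrite is set."""
    if os.path.exists(run_dir) and not overwrite:
        raise FileExistsError(
            f"run directory {run_dir!r} already exists; pass overwrite=True to reuse it"
        )
    # exist_ok=True so that overwrite=True reuses the existing directory.
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


with tempfile.TemporaryDirectory() as model_root:
    run_dir = os.path.join(model_root, "run_000")
    ensure_run_dir_sketch(run_dir)                  # created
    ensure_run_dir_sketch(run_dir, overwrite=True)  # reused, no error
    try:
        ensure_run_dir_sketch(run_dir)              # already exists
    except FileExistsError as exc:
        print("refused:", exc)
```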

property error_log

torch.utils.tensorboard.writer.SummaryWriter for streaming error statistics and other statistics.

The name might change in the future, since this writer can be used for logging more than just errors.

classmethod from_config_file(filename, progress_bar=False, **kwargs)[source]

Instantiate from a config file.

property full_net_params

Returns, as a dict, the arguments with which the ModelTrainer.net instance was instantiated.


Return the run_id (generating one if it was not specified) and the path of the run output directory relative to model_root.


Returns a training data DataLoader and a validation DataLoader for data from a single model dataset.
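One common way to derive separate training and validation sets is to shuffle the sample indices once and split them by a fixed fraction. The sketch below assumes that approach. split_indices, validation_fraction and seed are hypothetical names, and the actual split rule used by dnadna may differ.

```python
import random


def split_indices(n_samples, validation_fraction=0.2, seed=None):
    """Shuffle sample indices and split them into training and validation lists."""
    if not 0 <= validation_fraction < 1:
        raise ValueError("validation_fraction must be in [0, 1)")
    indices = list(range(n_samples))
    # A dedicated Random instance keeps the split reproducible for a given seed.
    random.Random(seed).shuffle(indices)
    n_validation = int(round(n_samples * validation_fraction))
    return indices[n_validation:], indices[:n_validation]


train_idx, val_idx = split_indices(10, validation_fraction=0.3, seed=0)
print(len(train_idx), len(val_idx))  # → 7 3
```

In PyTorch, each index list could then be wrapped with torch.utils.data.Subset(dataset, indices) and passed to its own torch.utils.data.DataLoader.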

property learned_params

Returns a LearnedParams instance from the config.

property loss_funcs

Loss functions to use for each trained parameter.

property loss_weights

Weights to apply to the loss functions of each trained parameter.
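Taken together, loss_funcs and loss_weights suggest a per-parameter loss scaled by its weight. The sketch below assumes the weighted losses are summed into a single training loss. It also assumes squared error for a regression parameter and cross-entropy for a classification parameter. The parameter names and weights are hypothetical, and plain floats stand in for tensors.

```python
import math


def mse(pred, target):
    """Squared error for a single regression output."""
    return (pred - target) ** 2


def cross_entropy(logits, target_class):
    """Cross-entropy of one classification target given unnormalized logits."""
    # log-sum-exp with max subtraction for numerical stability
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_norm - logits[target_class]


# Hypothetical per-parameter configuration: one loss function and one weight each.
loss_funcs = {"mutation_rate": mse, "model": cross_entropy}
loss_weights = {"mutation_rate": 1.0, "model": 0.5}


def total_loss(outputs, targets):
    """Weighted sum of per-parameter losses (assumed combination rule)."""
    return sum(
        loss_weights[name] * loss_funcs[name](outputs[name], targets[name])
        for name in loss_funcs
    )


outputs = {"mutation_rate": 0.8, "model": [2.0, 0.5, -1.0]}
targets = {"mutation_rate": 1.0, "model": 0}
print(round(total_loss(outputs, targets), 4))  # → 0.1607
```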

property n_outputs

Number of outputs from the neural net model; this is determined by the trained parameters.

Specifically, it is the number of regression parameters, plus the number of classes in each classification parameter.
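As a worked example of this rule, two regression parameters plus classification parameters with 2 and 3 classes give 2 + 2 + 3 = 7 outputs. The parameter names and the dict layout below are hypothetical and do not follow dnadna's config schema.

```python
# Hypothetical learned-parameter spec: regression parameters have no classes;
# classification parameters list their classes.
learned_params = {
    "mutation_rate": {"type": "regression"},
    "recombination_rate": {"type": "regression"},
    "selection": {"type": "classification", "classes": ["neutral", "sweep"]},
    "demography": {"type": "classification",
                   "classes": ["constant", "growth", "bottleneck"]},
}


def count_outputs(params):
    """One output per regression parameter plus one per class of each classification parameter."""
    n = 0
    for spec in params.values():
        if spec["type"] == "regression":
            n += 1
        else:
            n += len(spec["classes"])
    return n


print(count_outputs(learned_params))  # → 7
```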

property net

Returns the torch.nn.Module instance for the neural net to be trained.

property optimizer

Returns the torch.optim.Optimizer subclass (see also dnadna.optim.Optimizer) that provides the optimization step algorithm for training the model.

property param_names

List of the names of trained parameters.

prepare(dataset=None, preprocessed_scenario_params=None, run_id=None, overwrite=False, error_log=False, save_best=False, save_checkpoints=False)[source]

Perform initial preparation for model training based on the provided Config.

Keyword Arguments
  • dataset (DNATrainingDataset) – (optional) – A DNATrainingDataset instance, or if omitted the dataset specified by the config.

  • preprocessed_scenario_params (pandas.DataFrame) – (optional) – Pandas DataFrame containing the preprocessed scenario parameters returned by DataPreprocessor.preprocess_scenario_params. If omitted, the parameters are read from the dataset.

  • run_id (int or str) – (optional) – Unique ID to identify this training run. If not specified, it will be generated automatically from the sequence of existing runs in the model_root directory.

  • overwrite (bool) – (optional) – If the output directory for this run already exists and overwrite=True, any existing run artifacts will be overwritten; otherwise an error is raised.

  • error_log (bool) – (optional) – Enable error logging for the training run. Disabled by default except when running ModelTrainer.run_training.

  • save_best (bool) – (optional) – Enable saving the version of the model with the best loss during the training run, overwriting this copy of the model each time the loss improves on the previous best. Disabled by default except when running ModelTrainer.run_training.

  • save_checkpoints (bool) – (optional) – Enable saving checkpoints of the model during the training run. Disabled by default except when running ModelTrainer.run_training.
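The save_best behavior amounts to keeping the lowest loss seen so far and rewriting a single "best" file whenever it improves. The sketch below is a hypothetical helper (BestModelTracker) that assumes one loss value per epoch is compared. It records the epochs where a save would happen instead of writing files.

```python
import math


class BestModelTracker:
    """Track the best loss seen so far and report when a new best is reached."""

    def __init__(self):
        self.best_loss = math.inf

    def update(self, loss):
        # Strict improvement only: ties do not trigger a new save.
        if loss < self.best_loss:
            self.best_loss = loss
            return True
        return False


tracker = BestModelTracker()
saved_at = []
for epoch, val_loss in enumerate([0.9, 0.7, 0.75, 0.6, 0.65]):
    if tracker.update(val_loss):
        # In a real run, this is where the single "best" model file would be overwritten.
        saved_at.append(epoch)
print(saved_at)  # → [0, 1, 3]
```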

property run_dir

The directory where training run artifacts will be output.

property run_id

The number or string identifying this training run.

property run_name

The full name of the training run. Typically run_{run_id}.
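Automatically choosing the next run_id "from the sequence of existing runs" could look like the sketch below. It assumes run directories are named run_{run_id} with integer IDs and that the next ID is one past the largest existing one. The zero-padded run_000 format and the helper next_run_id are assumptions for illustration, not confirmed dnadna behavior.

```python
import os
import re
import tempfile

RUN_DIR_PATTERN = re.compile(r"^run_(\d+)$")


def next_run_id(model_root):
    """Return one more than the largest integer run ID under model_root, or 0 if none exist."""
    ids = []
    for entry in os.listdir(model_root):
        match = RUN_DIR_PATTERN.match(entry)
        # Only directories named run_<digits> count; string IDs such as run_custom are ignored.
        if match and os.path.isdir(os.path.join(model_root, entry)):
            ids.append(int(match.group(1)))
    return max(ids) + 1 if ids else 0


with tempfile.TemporaryDirectory() as root:
    for name in ("run_000", "run_001", "run_custom"):
        os.mkdir(os.path.join(root, name))
    run_id = next_run_id(root)
    print(f"run_{run_id:03d}")  # → run_002
```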

run_training(run_id=None, overwrite=False)[source]

High-level interface to prepare and execute a training run, and save the resulting net to a file.

This is the method run by the command-line interface.

property save_checkpoints

Whether or not model checkpoints are being saved.

save_net(filename, quiet=False, **kwargs)[source]

Save the current network state dict to a pickle file.

Additional **kwargs are included in the pickled dict alongside the network state dict.
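The idea of pickling the state dict alongside extra keyword arguments can be sketched with the standard pickle module. save_net_sketch is hypothetical and takes the state dict explicitly, and the "state_dict" key name is an assumption. A plain dict stands in for torch.nn.Module.state_dict().

```python
import os
import pickle
import tempfile


def save_net_sketch(filename, state_dict, **kwargs):
    """Pickle a network state dict together with any extra metadata passed as kwargs."""
    payload = dict(kwargs)
    payload["state_dict"] = state_dict  # assumed key name
    with open(filename, "wb") as fobj:
        pickle.dump(payload, fobj)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.pth")
    save_net_sketch(path, {"fc.weight": [[0.1, 0.2]]}, epoch=5, loss=0.42)
    with open(path, "rb") as fobj:
        restored = pickle.load(fobj)
print(sorted(restored))  # → ['epoch', 'loss', 'state_dict']
```

With real tensors, torch.save serves the same purpose and also relies on pickle for serialization.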

property training_loader

torch.utils.data.DataLoader for the training data set.

property validation_loader

torch.utils.data.DataLoader for the validation data set.

class dnadna.training.prepared_property(fget=None, fset=None, fdel=None, doc=None)[source]

Bases: property

Helper descriptor for attributes which are not set until ModelTrainer.prepare() is called.
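A property subclass can give a clear error when such an attribute is read too early. The sketch below assumes the getter returns None until prepare() has set the underlying value, and it raises AttributeError in that case. prepared_property_sketch and the toy Trainer class are hypothetical and do not reproduce dnadna's implementation.

```python
class prepared_property_sketch(property):
    """A property that raises a clear error if read before prepare() has run."""

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Accessed on the class itself: return the descriptor.
            return self
        value = super().__get__(obj, objtype)
        if value is None:
            raise AttributeError(
                f"{self.fget.__name__} is not available until prepare() is called"
            )
        return value


class Trainer:
    def __init__(self):
        self._net = None

    @prepared_property_sketch
    def net(self):
        return self._net

    def prepare(self):
        self._net = "network object"


trainer = Trainer()
print(hasattr(trainer, "net"))  # → False
trainer.prepare()
print(trainer.net)  # → network object
```

Raising AttributeError (rather than returning None) keeps hasattr() meaningful and fails at the point of misuse instead of later in training.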