dnadna.training
Implements model training on simulation data using neural nets.
Classes
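- DummySummaryWriter: Replaces torch.utils.tensorboard.writer.SummaryWriter when error logging is disabled.
- ModelTrainer
- prepared_property: Helper descriptor for attributes which are not set until ModelTrainer.prepare() is called.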
- class dnadna.training.DummySummaryWriter(*args, **kwargs)[source]
Bases:
object
Replaces
torch.utils.tensorboard.writer.SummaryWriter
when error logging is disabled.
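The idea of a stand-in writer can be illustrated with a minimal sketch. It assumes the replacement only needs to accept any SummaryWriter-style call and discard it; the class name NoOpWriter is hypothetical and this is not necessarily how DummySummaryWriter is implemented.

```python
class NoOpWriter:
    """Hypothetical stand-in that accepts any method call and does nothing."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore whatever arguments a real writer would take.
        pass

    def __getattr__(self, name):
        # Only called for attributes that are not found normally; return a
        # function that swallows its arguments, so that calls such as
        # writer.add_scalar(...) succeed without side effects.
        def _noop(*args, **kwargs):
            return None
        return _noop


writer = NoOpWriter("some/log/dir")
writer.add_scalar("loss/train", 0.5, 1)  # silently ignored
writer.close()                           # also ignored
```

With such an object in place, training code can log unconditionally instead of checking whether logging is enabled before each call.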
- class dnadna.training.ModelTrainer(config, validate=True, progress_bar=False)[source]
Bases:
dnadna.utils.config.ConfigMixIn
- property dataset
Returns a
DNATrainingDataset
instance.
- property device
The device for PyTorch to use (CPU or a CUDA device).
- ensure_run_dir(run_id=None, overwrite=False)[source]
Make sure the run directory for this training run exists.
Creates the directory if it does not already exist. If it already exists and overwrite=False, a FileExistsError is raised.
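The create-or-fail behaviour described here can be sketched with the standard library alone. The helper name ensure_dir and the example paths are hypothetical; the real method works on the trainer's own run directory and may do more than this.

```python
import os
import tempfile


def ensure_dir(path, overwrite=False):
    """Hypothetical helper: create path, or fail if it exists and overwrite is False."""
    if os.path.isdir(path):
        if not overwrite:
            raise FileExistsError(f"run directory {path} already exists")
        return path
    os.makedirs(path)
    return path


base = tempfile.mkdtemp()
run_dir = os.path.join(base, "run_000")
ensure_dir(run_dir)                  # creates the directory
ensure_dir(run_dir, overwrite=True)  # an existing directory is accepted
try:
    ensure_dir(run_dir)              # exists and overwrite=False
    raised = False
except FileExistsError:
    raised = True
```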
- property error_log
torch.utils.tensorboard.writer.SummaryWriter
for streaming error statistics and other statistics. The name might change in the future, as it can be used for logging more than just errors.
- classmethod from_config_file(filename, progress_bar=False, **kwargs)[source]
Instantiate from a config file.
- property full_net_params
Returns the arguments that the
ModelTrainer.net
instance was instantiated with, as a dict.
- get_run_info(run_id=None)[source]
Return the run_id (if not specified) and the model_root-relative path to the run output directory.
- init_data_loaders()[source]
Returns a training data
DataLoader
and a validation DataLoader
for data from a single model dataset.
- property learned_params
Returns a
LearnedParams
instance from the config.
- property loss_funcs
Loss functions to use for each trained parameter.
- property loss_weights
Weights to apply to the loss functions of each trained parameter.
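One common way to use per-parameter weights is a weighted sum of the per-parameter losses. The sketch below shows only that arithmetic; the parameter names and values are made up, and whether the trainer combines the losses in exactly this way is an assumption.

```python
# Hypothetical per-parameter losses and weights for one batch.
losses = {"event_time": 0.5, "recent_size": 0.25, "selection": 1.0}
weights = {"event_time": 1.0, "recent_size": 2.0, "selection": 0.5}

# Weighted sum over parameters: each loss is scaled by its weight.
total_loss = sum(weights[name] * losses[name] for name in losses)
```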
- property n_outputs
Number of outputs from the neural net model; this is related to the number of parameters.
Specifically, it is the number of regression parameters, plus the number of classes in each classification parameter.
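As a worked example of this count, assume two regression parameters and two classification parameters with 2 and 3 classes. The parameter names below are made up for illustration.

```python
# Hypothetical parameter description: two regression parameters and
# two classification parameters with 2 and 3 classes respectively.
regression_params = ["event_time", "recent_size"]
classification_params = {"selection": 2, "model_type": 3}

# One output per regression parameter, plus one output per class of
# each classification parameter.
n_outputs = len(regression_params) + sum(classification_params.values())
print(n_outputs)  # 7
```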
- property net
Returns the
torch.nn.Module
instance for the neural net to be trained.
- property optimizer
Returns the
torch.optim.Optimizer
subclass (see also dnadna.optim.Optimizer
) that provides the optimization step algorithm for training the model.
- property param_names
List of the names of trained parameters.
- prepare(dataset=None, preprocessed_scenario_params=None, run_id=None, overwrite=False, error_log=False, save_best=False, save_checkpoints=False)[source]
Perform initial preparation for model training based on the provided Config.
Keyword Arguments
- dataset (DNATrainingDataset) – (optional) – A DNATrainingDataset instance, or if omitted the dataset specified by the config.
- preprocessed_scenario_params (pandas.DataFrame) – (optional) – Pandas DataFrame containing the preprocessed scenario parameters returned by DataPreprocessor.preprocess_scenario_params. If omitted, the parameters are read from the dataset.
- run_id (int or str) – (optional) – Unique ID to identify this training run. If not specified, it will be generated automatically from the sequence of existing runs in the model_root directory.
- overwrite (bool) – (optional) – If the output directory for this run already exists and overwrite=True, any existing run artifacts will be overwritten; otherwise an error is raised.
- error_log (bool) – (optional) – Enable error logging for the training run. Disabled by default except when running ModelTrainer.run_training.
- save_best (bool) – (optional) – Enable saving the version of the model with the best losses during the training run, overwriting this copy of the model each time the loss improves over the previous best loss. Disabled by default except when running ModelTrainer.run_training.
- save_checkpoints (bool) – (optional) – Enable saving checkpoints of the model during the training run. Disabled by default except when running ModelTrainer.run_training.
- property run_dir
The directory where training run artifacts will be output.
- property run_id
The number or string identifying this training run.
- property run_name
The full name of the training run. Typically
run_{run_id}
.
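To illustrate how a run ID could be derived from existing runs and turned into a run name of the form run_{run_id}, here is a minimal sketch. The helper next_run_id, the numbering scheme (start at 0, no zero padding), and the directory names are assumptions for illustration, not the library's documented behaviour.

```python
import os
import re
import tempfile


def next_run_id(model_root):
    """Hypothetical: one more than the largest existing run_<int> directory."""
    pattern = re.compile(r"^run_(\d+)$")
    existing = []
    for entry in os.listdir(model_root):
        match = pattern.match(entry)
        if match and os.path.isdir(os.path.join(model_root, entry)):
            existing.append(int(match.group(1)))
    return max(existing) + 1 if existing else 0


model_root = tempfile.mkdtemp()
for name in ("run_0", "run_1", "run_final"):
    os.mkdir(os.path.join(model_root, name))

run_id = next_run_id(model_root)   # "run_final" is ignored: not an integer ID
run_name = f"run_{run_id}"
```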
- run_training(run_id=None, overwrite=False)[source]
High-level interface to prepare and execute a training run, and save the resulting net to a file.
This is the method run by the command-line interface.
- property save_checkpoints
Whether or not model checkpoints are being saved.
- save_net(filename, quiet=False, **kwargs)[source]
Save the current network state dict to a pickle file.
Additional
**kwargs
are included in the pickled dict alongside the network state dict.
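The way extra keyword arguments end up next to the state dict can be sketched with the standard pickle module. The helper save_state, the "state_dict" key name, and the plain dict standing in for a real network state dict are all assumptions made to keep the example self-contained; the actual file layout written by save_net may differ.

```python
import os
import pickle
import tempfile


def save_state(state_dict, filename, **kwargs):
    """Hypothetical helper: pickle the state dict together with extra entries."""
    payload = {"state_dict": state_dict}  # key name is an assumption
    payload.update(kwargs)
    with open(filename, "wb") as fobj:
        pickle.dump(payload, fobj)


# A plain dict stands in for a real network state dict here.
fake_state = {"linear.weight": [[0.1, 0.2]], "linear.bias": [0.0]}
path = os.path.join(tempfile.mkdtemp(), "net.pkl")
save_state(fake_state, path, epoch=3, best_loss=0.42)

with open(path, "rb") as fobj:
    loaded = pickle.load(fobj)
```

Storing metadata such as the epoch in the same file keeps a saved network self-describing when it is loaded later.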
- property training_loader
torch.utils.data.DataLoader
for the training data set.
- property validation_loader
torch.utils.data.DataLoader
for the validation data set.
- class dnadna.training.prepared_property(fget=None, fset=None, fdel=None, doc=None)[source]
Bases:
property
Helper descriptor for attributes which are not set until
ModelTrainer.prepare()
is called.
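A descriptor of this kind can be sketched as a property subclass that refuses access until the owner has been prepared. The names prepared_only and Trainer, the _prepared flag, and the choice of AttributeError are all hypothetical; the actual prepared_property may track preparation differently.

```python
class prepared_only(property):
    """Hypothetical property that refuses access until the owner is prepared."""

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Accessed on the class itself: return the descriptor.
            return self
        if not getattr(obj, "_prepared", False):
            raise AttributeError(
                f"{self.fget.__name__} is not available until prepare() is called")
        return super().__get__(obj, objtype)


class Trainer:
    def __init__(self):
        self._prepared = False
        self._run_id = None

    def prepare(self, run_id=0):
        self._run_id = run_id
        self._prepared = True

    @prepared_only
    def run_id(self):
        return self._run_id


trainer = Trainer()
try:
    trainer.run_id                 # accessed before prepare()
    early_access_failed = False
except AttributeError:
    early_access_failed = True

trainer.prepare(run_id=7)          # after this, trainer.run_id is readable
```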