Trainer
The Trainer struct is the main entry point for training a model. It manages the training loop, logging, and checkpointing, as well as the FitState struct, which holds the state of the training loop.
Pass a model (a FluxModule) and a trainer to the function Tsunami.fit! to train the model. After training, you can use the Tsunami.test function to test the model on a test dataset.
Tsunami.Trainer — Type
Trainer(; kws...)
A type storing the training options to be passed to fit!.
A Trainer object also contains a field fit_state of type FitState, which keeps up-to-date information about the state of the fit during the execution of fit!.
Constructor Arguments
- autodiff: The automatic differentiation engine to use. Possible values are :zygote and :enzyme. Default: :zygote.
- callbacks: Pass a single callback or a list of callbacks. Default: nothing.
- checkpointer: If true, enable checkpointing. Default: true.
- default_root_dir: Default path for logs and weights. Default: pwd().
- fast_dev_run: If set to true, runs a single batch for training and validation to catch bugs. Default: false.
- log_every_n_steps: How often, in steps, to log. See also logger. Default: 50.
- logger: If true, use TensorBoard for logging. Every output of train_step is logged every 50 steps by default; set log_every_n_steps to change this. Default: true.
- max_epochs: Stop training once this number of epochs is reached. Disabled by default (nothing). If neither max_epochs nor max_steps is specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1. Default: nothing.
- max_steps: Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = nothing, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1. Default: -1.
- progress_bar: If true, shows a progress bar during training. Default: true.
- val_every_n_epochs: Perform a validation loop after every N training epochs. The validation loop is in any case performed at the end of the last training epoch. Set to 0 or a negative value to disable validation. Default: 1.
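The interplay between max_epochs, max_steps, and val_every_n_epochs can be sketched in plain Julia. The helpers resolve_limits and run_validation below are hypothetical illustrations of the rules listed above, not part of Tsunami's API. The sketch represents "no limit" as nothing and assumes that disabling validation (val_every_n_epochs <= 0) also skips the validation loop after the last epoch.

```julia
# Hypothetical helper (not Tsunami's API): turns the documented defaults of
# `max_epochs` and `max_steps` into concrete limits; `nothing` means "no limit".
function resolve_limits(max_epochs::Union{Int,Nothing}, max_steps::Int)
    # Neither limit given: fall back to 1000 epochs.
    if max_epochs === nothing && max_steps == -1
        max_epochs = 1000
    end
    # `max_epochs = -1` requests infinite training (no epoch limit).
    epoch_limit = (max_epochs === nothing || max_epochs == -1) ? nothing : max_epochs
    step_limit = max_steps == -1 ? nothing : max_steps
    return (epochs = epoch_limit, steps = step_limit)
end

# Hypothetical helper: should a validation loop run at the end of `epoch`?
# Assumes `val_every_n_epochs <= 0` disables validation entirely,
# including the extra loop after the last epoch.
function run_validation(epoch::Int, val_every_n_epochs::Int, is_last_epoch::Bool)
    val_every_n_epochs <= 0 && return false
    return is_last_epoch || epoch % val_every_n_epochs == 0
end
```

For example, resolve_limits(nothing, -1) gives (epochs = 1000, steps = nothing), while resolve_limits(nothing, 500) leaves the epoch count unbounded and stops after 500 steps.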
The constructor also takes any of the Foil constructor arguments:
- accelerator: Supports passing different accelerator types:
  - :auto (default): Automatically select a GPU if available, otherwise fall back on the CPU.
  - :gpu: Like :auto, but throws an error if no GPU is available. For a GPU to be available, the corresponding package must be loaded (e.g. with using CUDA). The trigger packages are CUDA.jl for Nvidia GPUs, AMDGPU.jl for AMD GPUs, and Metal.jl for Apple Silicon.
  - :cpu: Force using the CPU.
  See also the devices option.
- devices: Pass an integer n to train on n devices (only 1 is supported at the moment), or a list of device ids to train on specific devices (e.g. [2] to train on the GPU with index 2). Id indexing starts from 1. If nothing, the default device is used (see the MLDataDevices.gpu_device documentation). Default: nothing.
- precision: Supports passing different precision types (:bf16, :f16, :f32, :f64), where :bf16 is BFloat16, :f16 is Float16, :f32 is Float32, and :f64 is Float64. Default: :f32.
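A minimal sketch of the accelerator rules and the precision symbols, assuming a boolean gpu_available flag in place of real device detection (which in practice depends on loading CUDA.jl, AMDGPU.jl, or Metal.jl). select_accelerator and PRECISION_TYPES are hypothetical names, not Tsunami's API; :bf16 is left out because BFloat16 arithmetic typically requires an extra package such as BFloat16s.jl.

```julia
# Hypothetical sketch of the documented `accelerator` rules; `gpu_available`
# stands in for real device detection.
function select_accelerator(accelerator::Symbol, gpu_available::Bool)
    if accelerator === :auto
        return gpu_available ? :gpu : :cpu      # prefer the GPU, fall back on the CPU
    elseif accelerator === :gpu
        gpu_available || error("accelerator = :gpu, but no GPU is available")
        return :gpu
    elseif accelerator === :cpu
        return :cpu
    else
        throw(ArgumentError("unknown accelerator: $accelerator"))
    end
end

# Precision symbols and the Julia float types they denote
# (:bf16 => BFloat16 omitted, since it needs an extra package).
const PRECISION_TYPES = Dict(:f16 => Float16, :f32 => Float32, :f64 => Float64)
```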
Fields
Besides most of the constructor arguments, a Trainer object also contains the following fields:
- fit_state: A FitState object storing the state of execution during a call to fit!.
- foil: A Foil object.
- loggers: A list of loggers.
- lr_schedulers: The learning rate schedulers used for training.
- optimisers: The optimisers used for training.
Examples
trainer = Trainer(max_epochs = 10,
accelerator = :cpu,
checkpointer = true,
logger = true)
Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)
Tsunami.fit! — Function
fit!(model, trainer, train_dataloader, [val_dataloader]; [ckpt_path])
Train model using the configuration given by trainer. If ckpt_path is given, training is resumed from the checkpoint.
After the fit, trainer.fit_state will contain the final state of the training.
See also Trainer and FitState.
Arguments
- model: A Flux model subtyping FluxModule.
- trainer: A Trainer object storing the configuration options for fit!.
- train_dataloader: An iterator over the training dataset, typically a Flux.DataLoader.
- val_dataloader: An iterator over the validation dataset, typically a Flux.DataLoader. Default: nothing.
- ckpt_path: Path of the checkpoint from which training is resumed. Default: nothing.
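Since train_dataloader and val_dataloader only need to be iterators over batches, a plain vector of batches can also serve (the Tsunami.test example in this section passes a vector of arrays). The hypothetical simple_batches helper below builds such a vector from a feature matrix and a label vector without depending on Flux; in practice a Flux.DataLoader would normally be used instead.

```julia
# Hypothetical helper: split `x` (features × samples) and `y` (one label per
# sample) into a vector of `(x_batch, y_batch)` tuples; the last batch may be smaller.
function simple_batches(x::AbstractMatrix, y::AbstractVector; batchsize::Int = 32)
    n = size(x, 2)
    length(y) == n || throw(DimensionMismatch("x and y must have the same number of samples"))
    return [(x[:, i:min(i + batchsize - 1, n)], y[i:min(i + batchsize - 1, n)])
            for i in 1:batchsize:n]
end
```

For example, simple_batches(rand(Float32, 4, 10), collect(1:10); batchsize = 4) yields three batches holding 4, 4, and 2 samples.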
Examples
model = ...
trainer = Trainer(max_epochs = 10)
Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)
run_dir = trainer.fit_state.run_dir
# Resume training from checkpoint
trainer = Trainer(max_epochs = 20) # train for 10 more epochs
ckpt_path = joinpath(run_dir, "checkpoints", "ckpt_last.jld2")
fit_state′ = Tsunami.fit!(model, trainer, train_dataloader, val_dataloader; ckpt_path)
Tsunami.FitState — Type
FitState
A type storing the state of execution during a call to fit!.
A FitState object is part of a Trainer object.
Fields
- epoch: the current epoch number.
- run_dir: the directory where the logs and checkpoints are saved.
- stage: the current stage of execution. One of :training, :train_epoch_end, :validation, :val_epoch_end.
- step: the current step number.
- batchsize: number of samples in the current batch.
- should_stop: set to true to stop the training loop.
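To illustrate how these fields evolve and how should_stop ends training early, here is a toy loop over a hypothetical ToyFitState that mirrors the FitState field names. It is a simplified model under stated assumptions, not Tsunami's actual training loop: batchsize is taken as length(batch), and a fixed step budget stands in for whatever callback or step would set should_stop.

```julia
# Hypothetical stand-in for FitState (same field names, simplified types).
Base.@kwdef mutable struct ToyFitState
    epoch::Int = 0
    run_dir::String = ""
    stage::Symbol = :training
    step::Int = 0
    batchsize::Int = 0
    should_stop::Bool = false
end

# Toy loop: updates the state per batch and returns as soon as
# `should_stop` becomes true.
function toy_fit!(state::ToyFitState, batches; max_epochs::Int = 3,
                  stop_at_step::Int = typemax(Int))
    for epoch in 1:max_epochs
        state.epoch = epoch
        state.stage = :training
        for batch in batches
            state.step += 1
            state.batchsize = length(batch)
            # Stand-in for a callback requesting an early stop.
            state.step >= stop_at_step && (state.should_stop = true)
            state.should_stop && return state
        end
        state.stage = :train_epoch_end
    end
    return state
end
```

With three batches per epoch and stop_at_step = 5, the loop stops during the second epoch with step == 5.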
Tsunami.test — Function
test(model::FluxModule, trainer, dataloader)
Run the test loop, calling the test_step method on the model for each batch returned by the dataloader. Returns the aggregated results from the values logged in test_step as a dictionary.
Examples
julia> struct Model <: FluxModule end
julia> function Tsunami.test_step(::Model, trainer, batch)
Tsunami.log(trainer, "test/loss", rand())
end
julia> model, trainer = Model(), Trainer();
julia> test_results = Tsunami.test(model, trainer, [rand(2) for i=1:3]);
Testing: 100%|████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00 (6.04 μs/it)
test/loss: 0.675
julia> test_results
Dict{String, Float64} with 1 entry:
  "test/loss" => 0.674665
Tsunami.validate — Function
validate(model::FluxModule, trainer, dataloader)
Run the validation loop, calling the val_step method on the model for each batch returned by the dataloader. Returns the aggregated results from the values logged in val_step as a dictionary.
See also Tsunami.test and Tsunami.fit!.
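The aggregation step of test and validate can be sketched as averaging each logged key over all batches. The aggregate_logs helper below is hypothetical and assumes mean aggregation; Tsunami's actual reduction may differ.

```julia
# Hypothetical sketch: average the values logged under each key across batches.
# `logs` holds one Dict per batch, mapping metric names to logged values.
function aggregate_logs(logs)
    sums = Dict{String,Float64}()
    counts = Dict{String,Int}()
    for batch_log in logs, (key, value) in batch_log
        sums[key] = get(sums, key, 0.0) + value
        counts[key] = get(counts, key, 0) + 1
    end
    return Dict(key => sums[key] / counts[key] for key in keys(sums))
end
```

With "test/loss" logged as 1.0, 2.0, and 3.0 over three batches, this returns Dict("test/loss" => 2.0).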