Trainer
The Trainer struct is the main entry point for training a model. It manages the training loop, logging, and checkpointing. It also owns the FitState struct, which holds the state of the training loop.
Pass a model (a FluxModule) and a trainer to the function Tsunami.fit! to train the model. After training, you can use the Tsunami.test function to evaluate the model on a test dataset.
Tsunami.Trainer — Type
Trainer(; kws...)
A type storing the training options to be passed to fit!.
A Trainer object also contains a field fit_state of type FitState, which maintains up-to-date information about the fit state during the execution of fit!.
Constructor Arguments
- autodiff: The automatic differentiation engine to use. Possible values are :zygote and :enzyme. Default: :zygote.
- callbacks: Pass a single callback or a list of callbacks. Default: nothing.
- checkpointer: If true, enable checkpointing. Default: true.
- default_root_dir: Default path for logs and weights. Default: pwd().
- fast_dev_run: If set to true, runs a single batch for train and validation to find any bugs. Default: false.
- log_every_n_steps: How often to log within steps. See also logger. Default: 50.
- logger: If true, use tensorboard for logging. Every output of the train_step will be logged every 50 steps by default. Set log_every_n_steps to change this. Default: true.
- max_epochs: Stop training once this number of epochs is reached. Disabled by default (nothing). If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1. Default: nothing.
- max_steps: Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = nothing, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1. Default: -1.
- progress_bar: If true, shows a progress bar during training. Default: true.
- val_every_n_epochs: Perform a validation loop after every N training epochs. The validation loop is in any case performed at the end of the last training epoch. Set to 0 or negative to disable validation. Default: 1.
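The interaction between max_epochs and max_steps described above can be sketched in a few lines of plain Julia. The helper resolve_stopping below is hypothetical and is not part of Tsunami. It encodes only the documented default (an epoch cap of 1000 when neither limit is given) and passes any other combination through unchanged.

```julia
# Hypothetical helper, NOT part of Tsunami: it mirrors only the documented defaults.
function resolve_stopping(; max_epochs = nothing, max_steps = -1)
    if max_epochs === nothing && max_steps == -1
        # Neither limit was given: the docs say training falls back to 1000 epochs.
        return (max_epochs = 1000, max_steps = -1)
    end
    # Otherwise the user-supplied limits are kept as they are.
    return (max_epochs = max_epochs, max_steps = max_steps)
end

resolve_stopping()                   # neither limit set: epoch cap of 1000
resolve_stopping(max_steps = 500)    # only a step budget is given
resolve_stopping(max_epochs = -1)    # -1 requests infinite training
```

How Tsunami treats max_epochs = nothing when only max_steps is set is not spelled out above, so the sketch leaves that case untouched rather than guessing.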
The constructor also takes any of the Foil's constructor arguments:
- accelerator: Supports passing different accelerator types:
  - :auto (default): Automatically select a gpu if available, otherwise fall back to the cpu.
  - :gpu: Like :auto, but will throw an error if no gpu is available. In order for a gpu to be available, the corresponding package must be loaded (e.g. with using CUDA). The trigger packages are CUDA.jl for Nvidia GPUs, AMDGPU.jl for AMD GPUs, and Metal.jl for Apple Silicon.
  - :cpu: Force using the cpu.
  See also the devices option.
- devices: Pass an integer n to train on n devices (only 1 supported at the moment), or a list of device ids to train on specific devices (e.g. [2] to train on the gpu with idx 2). Id indexing starts from 1. If nothing, will use the default device (see MLDataDevices.gpu_device documentation). Default: nothing.
- precision: Supports passing different precision types (:bf16, :f16, :f32, :f64), where :bf16 is BFloat16, :f16 is Float16, :f32 is Float32, and :f64 is Float64. Default: :f32.
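As a small illustration of the Foil-related options, the sketch below builds a trainer pinned to the cpu with Float64 precision. It uses only keyword arguments listed above. The chosen values are illustrative assumptions, not recommendations.

```julia
using Tsunami

# Force cpu execution and double precision; both keywords are documented above.
# To target a gpu instead, load a trigger package first (e.g. `using CUDA`)
# and pass `accelerator = :gpu`.
trainer = Trainer(accelerator = :cpu,
                  precision = :f64,
                  max_epochs = 3)
```

Leaving devices at its default (nothing) lets the default device be selected.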
Fields
Besides most of the constructor arguments, a Trainer object also contains the following fields:
- fit_state: A FitState object storing the state of execution during a call to fit!.
- foil: A Foil object.
- loggers: A list of loggers.
- lr_schedulers: The learning rate schedulers used for training.
- optimisers: The optimisers used for training.
Examples
trainer = Trainer(max_epochs = 10,
                  accelerator = :cpu,
                  checkpointer = true,
                  logger = true)
Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)
Tsunami.fit! — Function
fit!(model, trainer, train_dataloader, [val_dataloader]; [ckpt_path])
Train model using the configuration given by trainer. If ckpt_path is given, training is resumed from the checkpoint.
After the fit, trainer.fit_state will contain the final state of the training.
See also Trainer and FitState.
Arguments
- model: A Flux model subtyping FluxModule.
- trainer: A Trainer object storing the configuration options for fit!.
- train_dataloader: An iterator over the training dataset, typically a Flux.DataLoader.
- val_dataloader: An iterator over the validation dataset, typically a Flux.DataLoader. Default: nothing.
- ckpt_path: Path of the checkpoint from which training is resumed. Default: nothing.
Examples
model = ...
trainer = Trainer(max_epochs = 10)
Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)
run_dir = trainer.fit_state.run_dir
# Resume training from checkpoint
trainer = Trainer(max_epochs = 20) # train for 10 more epochs
ckpt_path = joinpath(run_dir, "checkpoints", "ckpt_last.jld2")
fit_state′ = Tsunami.fit!(model, trainer, train_dataloader, val_dataloader; ckpt_path)
Tsunami.FitState — Type
FitState
A type storing the state of execution during a call to fit!.
A FitState object is part of a Trainer object.
Fields
- epoch: the current epoch number.
- run_dir: the directory where the logs and checkpoints are saved.
- stage: the current stage of execution. One of :training, :train_epoch_end, :validation, :val_epoch_end.
- step: the current step number.
- batchsize: number of samples in the current batch.
- should_stop: set to true to stop the training loop.
Tsunami.test — Function
test(model::FluxModule, trainer, dataloader)
Run the test loop, calling the test_step method on the model for each batch returned by the dataloader. Returns the aggregated results from the values logged in the test_step as a dictionary.
Examples
julia> struct Model <: FluxModule end
julia> function Tsunami.test_step(::Model, trainer, batch)
           Tsunami.log(trainer, "test/loss", rand())
       end
julia> model, trainer = Model(), Trainer();
julia> test_results = Tsunami.test(model, trainer, [rand(2) for i=1:3]);
Testing: 100%|████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00 (6.04 μs/it)
test/loss: 0.675
julia> test_results
Dict{String, Float64} with 1 entry:
"test/loss" => 0.674665
Tsunami.validate — Function
validate(model::FluxModule, trainer, dataloader)
Run the validation loop, calling the val_step method on the model for each batch returned by the dataloader. Returns the aggregated results from the values logged in the val_step as a dictionary.
See also Tsunami.test and Tsunami.fit!.
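A usage sketch for validate, modelled on the Tsunami.test example earlier in this page. ValModel and the logged key val/loss are hypothetical names chosen for illustration. The sketch assumes val_step accepts the same three-argument form shown for test_step.

```julia
using Tsunami

# Hypothetical model with no parameters; it only logs a value per batch.
struct ValModel <: FluxModule end

function Tsunami.val_step(::ValModel, trainer, batch)
    # Log one scalar per batch; validate aggregates the logged values.
    Tsunami.log(trainer, "val/loss", sum(batch))
end

model, trainer = ValModel(), Trainer()

# Any iterator over batches can stand in for a dataloader here.
val_results = Tsunami.validate(model, trainer, [ones(2) for _ in 1:3])
```

As with test, the returned dictionary is keyed by the names passed to Tsunami.log.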