Trainer

The Trainer struct is the main entry point for training a model. It manages the training loop, logging, and checkpointing, and it holds the FitState struct, which contains the state of the training loop.

Pass a model (a FluxModule) and a trainer to the function Tsunami.fit! to train the model. After training, you can use the Tsunami.test function to test the model on a test dataset.

Tsunami.Trainer - Type
Trainer(; kws...)

A type storing the training options to be passed to fit!.

A Trainer object also contains a field fit_state of type FitState, which maintains up-to-date information about the fit state during the execution of fit!.

Constructor Arguments

  • autodiff: The automatic differentiation engine to use. Possible values are :zygote and :enzyme. Default: :zygote.

  • callbacks: Pass a single callback or a list of callbacks. Default: nothing.

  • checkpointer: If true, enable checkpointing. Default: true.

  • default_root_dir: Default path for logs and weights. Default: pwd().

  • fast_dev_run: If set to true, runs a single batch for train and validation to find any bugs. Default: false.

  • log_every_n_steps: Number of steps between two consecutive logging events. See also logger. Default: 50.

  • logger: If true, use TensorBoard for logging. Every output of the train_step will be logged every 50 steps by default. Set log_every_n_steps to change this. Default: true.

  • max_epochs: Stop training once this number of epochs is reached. Disabled by default (nothing). If neither max_epochs nor max_steps is specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1. Default: nothing.

  • max_steps: Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = nothing, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1. Default: -1.

  • progress_bar: If true, shows a progress bar during training. Default: true.

  • val_every_n_epochs: Perform a validation loop after every N training epochs. The validation loop is in any case performed at the end of the last training epoch. Set to 0 or negative to disable validation. Default: 1.
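The way max_epochs, max_steps, and val_every_n_epochs interact can be sketched in plain Julia. The helpers below (resolve_max_epochs, limit_reached, should_validate) are hypothetical names that only mirror the rules stated in the list above; they are not part of Tsunami's API, and the handling of edge cases is an assumption.

```julia
# Hypothetical helpers, not part of Tsunami: they only mirror the documented rules.

# Effective epoch limit: if neither limit is set, fall back to 1000 epochs.
function resolve_max_epochs(max_epochs, max_steps)
    if max_epochs === nothing && max_steps == -1
        return 1000
    end
    return max_epochs
end

# Report whether an enabled limit has been reached.
# Assumption: `nothing` or a negative value disables the corresponding limit.
function limit_reached(epoch, step; max_epochs = nothing, max_steps = -1)
    max_epochs = resolve_max_epochs(max_epochs, max_steps)
    epochs_done = max_epochs !== nothing && max_epochs >= 0 && epoch >= max_epochs
    steps_done = max_steps >= 0 && step >= max_steps
    return epochs_done || steps_done
end

# Validation runs after every `n` epochs and always after the last epoch.
# Assumption: n <= 0 disables validation entirely, including after the last epoch.
function should_validate(epoch, n; is_last_epoch = false)
    n <= 0 && return false
    return is_last_epoch || epoch % n == 0
end
```

Under these assumptions, limit_reached(1000, 0) is true with both options left at their defaults, while passing max_epochs = -1 never triggers a stop.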

The constructor also takes any of Foil's constructor arguments:

  • accelerator: Supports passing different accelerator types:

    • :auto (default): Automatically select a gpu if available, otherwise fall back to the cpu.
    • :gpu: Like :auto, but will throw an error if no gpu is available. In order for a gpu to be available, the corresponding package must be loaded (e.g. with using CUDA). The trigger packages are CUDA.jl for Nvidia GPUs, AMDGPU.jl for AMD GPUs, and Metal.jl for Apple Silicon.
    • :cpu: Force using the cpu.

    See also the devices option.

  • devices: Pass an integer n to train on n devices (only 1 supported at the moment), or a list of device ids to train on specific devices (e.g. [2] to train on the gpu with idx 2). Device ids are indexed from 1. If nothing, will use the default device (see MLDataDevices.gpu_device documentation). Default: nothing.

  • precision: Supports passing different precision types (:bf16, :f16, :f32, :f64), where :bf16 is BFloat16, :f16 is Float16, :f32 is Float32, and :f64 is Float64. Default: :f32.
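As an illustration of what the precision symbols denote, the following sketch maps them to Julia element types and converts an array accordingly. PRECISION_TYPES and to_precision are hypothetical names introduced here, not Tsunami functions, and the sketch makes no claim about how the library applies the setting internally. :bf16 is left out to avoid an extra package dependency.

```julia
# Hypothetical lookup table mirroring the documented precision symbols.
# :bf16 (BFloat16) is omitted so that the sketch needs no extra packages.
const PRECISION_TYPES = Dict(:f16 => Float16, :f32 => Float32, :f64 => Float64)

# Convert a floating-point array to the element type named by `precision`.
function to_precision(x::AbstractArray{<:AbstractFloat}, precision::Symbol)
    T = PRECISION_TYPES[precision]
    return T.(x)
end
```

For example, to_precision([1.0, 2.0], :f32) returns an array whose element type is Float32.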

Fields

Besides most of the constructor arguments, a Trainer object also contains the following fields:

  • fit_state: A FitState object storing the state of execution during a call to fit!.
  • foil: A Foil object.
  • loggers: A list of loggers.
  • lr_schedulers: The learning rate schedulers used for training.
  • optimisers: The optimisers used for training.

Examples

trainer = Trainer(max_epochs = 10, 
                  accelerator = :cpu,
                  checkpointer = true,
                  logger = true)

Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)

Tsunami.fit! - Function
fit!(model, trainer, train_dataloader, [val_dataloader]; [ckpt_path])

Train model using the configuration given by trainer. If ckpt_path is given, training is resumed from the checkpoint.

After the fit, trainer.fit_state will contain the final state of the training.

See also Trainer and FitState.

Arguments

  • model: A Flux model subtyping FluxModule.
  • trainer: A Trainer object storing the configuration options for fit!.
  • train_dataloader: An iterator over the training dataset, typically a Flux.DataLoader.
  • val_dataloader: An iterator over the validation dataset, typically a Flux.DataLoader. Default: nothing.
  • ckpt_path: Path of the checkpoint from which training is resumed. Default: nothing.

Examples

model = ...
trainer = Trainer(max_epochs = 10)
Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)
run_dir = trainer.fit_state.run_dir

# Resume training from checkpoint
trainer = Trainer(max_epochs = 20) # train for 10 more epochs
ckpt_path = joinpath(run_dir, "checkpoints", "ckpt_last.jld2")
fit_state′ = Tsunami.fit!(model, trainer, train_dataloader, val_dataloader; ckpt_path)

Tsunami.FitState - Type
FitState

A type storing the state of execution during a call to fit!.

A FitState object is part of a Trainer object.

Fields

  • epoch: the current epoch number.
  • run_dir: the directory where the logs and checkpoints are saved.
  • stage: the current stage of execution. One of :training, :train_epoch_end, :validation, :val_epoch_end.
  • step: the current step number.
  • batchsize: number of samples in the current batch.
  • should_stop: set to true to stop the training loop.

Tsunami.test - Function
test(model::FluxModule, trainer, dataloader)

Run the test loop, calling the test_step method on the model for each batch returned by the dataloader. Returns the aggregated results from the values logged in the test_step as a dictionary.

Examples

julia> struct Model <: FluxModule end 

julia> function Tsunami.test_step(::Model, trainer, batch)
    Tsunami.log(trainer, "test/loss", rand())
end

julia> model, trainer = Model(), Trainer();

julia> test_results = Tsunami.test(model, trainer, [rand(2) for i=1:3]);
Testing: 100%|████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00 (6.04 μs/it)
  test/loss:  0.675

julia> test_results
Dict{String, Float64} with 1 entry:
  "test/loss" => 0.674665

Tsunami.validate - Function
validate(model::FluxModule, trainer, dataloader)

Run the validation loop, calling the val_step method on the model for each batch returned by the dataloader. Returns the aggregated results from the values logged in the val_step as a dictionary.

See also Tsunami.test and Tsunami.fit!.
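Examples

A minimal sketch of a validation run, modeled on the Tsunami.test example. ValModel is a hypothetical stand-in model, and the val_step signature is assumed to match that of test_step.

```julia
using Tsunami

# Hypothetical stand-in model with no parameters, for illustration only.
struct ValModel <: FluxModule end

# Log one value per batch; `validate` aggregates the logged values.
# Assumption: `val_step` takes the same arguments as `test_step`.
function Tsunami.val_step(::ValModel, trainer, batch)
    Tsunami.log(trainer, "val/loss", rand())
end

model, trainer = ValModel(), Trainer()

# Any iterator of batches can act as the dataloader; here, three random vectors.
val_results = Tsunami.validate(model, trainer, [rand(2) for i in 1:3])
```

The returned dictionary should then contain an aggregated entry under the key "val/loss".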
