Callbacks

Callbacks are objects whose hook methods are called at specific points of the training process. They are useful for logging, checkpointing, early stopping, and similar tasks.

Callbacks are passed to the Trainer constructor:

callback1 = Checkpointer(...)
trainer = Trainer(..., callbacks = [callback1, ...])

Callbacks implement their functionality through the hooks described in the Hooks section of the documentation.

Available Callbacks

A few callbacks are provided by Tsunami.

Checkpoints

Callbacks for saving and loading the model and optimizer state.

Tsunami.Checkpointer (Type)
Checkpointer(folder = nothing) <: AbstractCallback

A helper type for saving a FluxModule and the fit state. Each checkpoint is saved as a JLD2 file named ckpt_epoch=X_step=Y.jld2. A symbolic link named ckpt_last.jld2, pointing to the most recent checkpoint, is also created.

A Checkpointer is automatically created when checkpointer = true is passed to fit!.

If folder is not specified, the checkpoints are saved in a folder named checkpoints in the run directory.

See also: load_checkpoint.

Examples

checkpointer = Checkpointer()
trainer = Trainer(..., callbacks = [checkpointer])
Tsunami.load_checkpoint (Function)
load_checkpoint(path)

Loads the checkpoint saved at path. Returns a NamedTuple containing the model state, the fit state, the learning-rate schedulers, and the optimisers.

See also: Checkpointer.

Examples

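using Flux, Tsunami

ckpt = load_checkpoint("checkpoints/ckpt_last.jld2")
model = MyModel(...)  # must have the same architecture as the saved model
Flux.loadmodel!(model, ckpt.model_state)  # copy the saved state into model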
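load_checkpoint accepts the path of any saved checkpoint, not only ckpt_last.jld2. Because each checkpoint's filename encodes its epoch and step, an earlier checkpoint can be located by parsing the names in the checkpoint folder. The sketch below uses two hypothetical helpers, parse_ckpt_name and checkpoint_for_epoch, which are not part of Tsunami; it assumes that X and Y in the filename pattern are plain integers.

```julia
# Hypothetical helper (not part of Tsunami): extract the epoch and step
# encoded in a filename of the form ckpt_epoch=X_step=Y.jld2.
# Returns `nothing` for names that do not follow this pattern,
# such as the ckpt_last.jld2 symbolic link.
function parse_ckpt_name(name::AbstractString)
    m = match(r"^ckpt_epoch=(\d+)_step=(\d+)\.jld2$", name)
    m === nothing && return nothing
    return (epoch = parse(Int, m[1]), step = parse(Int, m[2]))
end

# Hypothetical helper: among `names`, return the checkpoint saved at `epoch`
# (the one with the largest step if several share that epoch),
# or `nothing` if no checkpoint matches.
function checkpoint_for_epoch(names, epoch::Integer)
    best = nothing
    best_step = -1
    for n in names
        p = parse_ckpt_name(n)
        if p !== nothing && p.epoch == epoch && p.step > best_step
            best, best_step = n, p.step
        end
    end
    return best
end

names = ["ckpt_epoch=1_step=100.jld2", "ckpt_epoch=2_step=200.jld2", "ckpt_last.jld2"]
checkpoint_for_epoch(names, 1)  # "ckpt_epoch=1_step=100.jld2"
```

In practice, names would come from readdir on the checkpoint folder, and the selected name would be joined to that folder with joinpath before being passed to load_checkpoint.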

Writing Custom Callbacks

Users can write their own callbacks by defining custom types and implementing the hooks they need. For example:

using Tsunami

struct MyCallback end

function Tsunami.on_train_epoch_end(cb::MyCallback, model, trainer)
    fit_state = trainer.fit_state # contains info about the training status
    # do something
end

trainer = Trainer(..., callbacks = [MyCallback()])
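Because hooks are ordinary methods dispatched on the callback type, a callback can also carry state between calls. The following sketch defines a hypothetical EpochLogger that records the epoch number at the end of each training epoch. It assumes that trainer.fit_state exposes epoch and step fields; check the fit state of your Tsunami version before relying on them.

```julia
using Tsunami

# Hypothetical stateful callback: records each completed epoch and prints
# the current step at the end of every training epoch.
# Assumption: trainer.fit_state has `epoch` and `step` fields.
struct EpochLogger
    epochs::Vector{Int}  # epochs completed so far, mutated in place
end
EpochLogger() = EpochLogger(Int[])

function Tsunami.on_train_epoch_end(cb::EpochLogger, model, trainer)
    fs = trainer.fit_state
    push!(cb.epochs, fs.epoch)
    println("epoch $(fs.epoch) finished at step $(fs.step)")
    return nothing
end
```

Pass EpochLogger() in the callbacks list of the Trainer constructor, as with MyCallback above; after training, its epochs field lists the epochs that were completed. The struct itself is immutable, but the vector it holds can be mutated in place, so a mutable struct is not needed.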

See the implementation of Checkpointer and the Hooks section of the documentation for more information on writing custom callbacks. The examples folder also contains some custom callbacks.