Callbacks
Callbacks are functions that are called at certain points in the training process. They are useful for logging, early stopping, and other tasks.
Callbacks are passed to the Trainer constructor:
callback1 = Checkpointer(...)
trainer = Trainer(..., callbacks = [callback1, ...])
Callbacks implement their functionality through the hooks described in the Hooks section of the documentation.
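To make the relationship between callbacks and hooks concrete, here is a minimal, self-contained sketch of hook dispatch in plain Julia. It is not Tsunami's implementation: the local `on_train_epoch_end` function, the `EpochRecorder` and `SilentCallback` types, the `run_epochs` loop, and the `epoch` field of the stand-in trainer are all hypothetical names chosen for illustration. The idea it shows is that a generic no-op fallback lets each callback implement only the hooks it needs.

```julia
# Stand-in for a library-defined hook (hypothetical, not Tsunami's code).
# The generic fallback does nothing, so callbacks only implement what they need.
on_train_epoch_end(cb, model, trainer) = nothing

# Hypothetical callback that records every epoch at which it was invoked.
struct EpochRecorder
    seen::Vector{Int}
end

on_train_epoch_end(cb::EpochRecorder, model, trainer) = push!(cb.seen, trainer.epoch)

# Hypothetical callback that implements no hook at all.
struct SilentCallback end

# Toy training loop: after each epoch, the hook is called on every callback.
function run_epochs(callbacks, nepochs)
    model = nothing  # placeholder; a real loop would hold the model here
    for epoch in 1:nepochs
        trainer = (; epoch)  # NamedTuple standing in for the trainer state
        for cb in callbacks
            on_train_epoch_end(cb, model, trainer)
        end
    end
end

recorder = EpochRecorder(Int[])
run_epochs([recorder, SilentCallback()], 3)
```

After the loop, `recorder.seen` holds the epochs at which the specialized method ran, while `SilentCallback` fell through to the no-op fallback each time.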
Available Callbacks
A few callbacks are provided by Tsunami.
Checkpoints
Callbacks for saving and loading the model and optimizer state.
Tsunami.Checkpointer - Type
Checkpointer(folder = nothing) <: AbstractCallback
A helper type for saving a FluxModule and the fit state. The checkpoint is saved as a JLD2 file named ckpt_epoch=X_step=Y.jld2. A symbolic link to the most recent checkpoint is also created as ckpt_last.jld2.
A Checkpointer is automatically created when checkpointer = true is passed to fit!.
If folder is not specified, the checkpoints are saved in a folder named checkpoints in the run directory.
See also: load_checkpoint.
Examples
checkpointer = Checkpointer()
Tsunami.fit!(..., callbacks = [checkpointer])
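The file layout described above (one file per checkpoint plus a `ckpt_last.jld2` symbolic link) can be sketched with Julia's standard filesystem functions. This is only an illustration of the naming scheme, not the Checkpointer's actual code: `checkpoint_name` and `save_dummy_checkpoint` are hypothetical helpers, the file contents are a placeholder string instead of real JLD2 data, and the use of a relative link is an assumption.

```julia
# Hypothetical helper: build the checkpoint file name from epoch and step.
checkpoint_name(epoch, step) = "ckpt_epoch=$(epoch)_step=$(step).jld2"

# Hypothetical helper: write a placeholder file and repoint `ckpt_last.jld2` at it.
function save_dummy_checkpoint(folder, epoch, step)
    mkpath(folder)                      # create the folder if it does not exist
    name = checkpoint_name(epoch, step)
    path = joinpath(folder, name)
    write(path, "placeholder")          # a real checkpointer would serialize state here
    link = joinpath(folder, "ckpt_last.jld2")
    islink(link) && rm(link)            # remove the link to the previous checkpoint
    symlink(name, link)                 # relative link to the newest checkpoint
    return path
end

folder = joinpath(mktempdir(), "checkpoints")
save_dummy_checkpoint(folder, 1, 100)
save_dummy_checkpoint(folder, 2, 200)
last_target = readlink(joinpath(folder, "ckpt_last.jld2"))
```

Because the link is always repointed at the newest file, loading `ckpt_last.jld2` resolves to the latest checkpoint without the caller having to know the epoch or step. Creating symbolic links may require extra privileges on some platforms, such as Windows.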
Tsunami.load_checkpoint - Function
load_checkpoint(path)
Loads a checkpoint that was saved to path. Returns a named tuple containing the model state, the fit state, the learning rate schedulers, and the optimisers.
See also: Checkpointer.
Examples
ckpt = load_checkpoint("checkpoints/ckpt_last.jld2")
model = MyModel(...)
Flux.loadmodel!(model, ckpt.model_state)
Writing Custom Callbacks
Users can write their own callbacks by defining custom types and implementing the hooks they need. For example:
struct MyCallback end

function Tsunami.on_train_epoch_end(cb::MyCallback, model, trainer)
    fit_state = trainer.fit_state  # contains info about the training status
    # do something
end
trainer = Trainer(..., callbacks = [MyCallback()])
See the implementation of Checkpointer and the Hooks section of the documentation for more information on how to write custom callbacks. The examples folder also contains some examples of custom callbacks.
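Custom callbacks often need to carry state from one hook invocation to the next, for instance to decide when training has stopped improving. The sketch below shows only that bookkeeping, in plain Julia. The `PatienceTracker` type and the `should_stop!` function are hypothetical names; how to read a validation loss from the trainer and how to actually halt training depend on Tsunami's API and are not covered here.

```julia
# Hypothetical stateful callback: remembers the best loss seen so far.
mutable struct PatienceTracker
    patience::Int   # number of non-improving epochs to tolerate
    best::Float64   # lowest loss observed so far
    wait::Int       # epochs elapsed since the last improvement
end

# Convenience constructor: start with no observations.
PatienceTracker(patience) = PatienceTracker(patience, Inf, 0)

# Update the state with the latest loss; return true once patience is exhausted.
# Logic like this could live inside a hook such as `on_train_epoch_end`.
function should_stop!(t::PatienceTracker, loss)
    if loss < t.best
        t.best = loss
        t.wait = 0
    else
        t.wait += 1
    end
    return t.wait >= t.patience
end

tracker = PatienceTracker(2)
losses = [1.0, 0.8, 0.9, 0.85, 0.7]  # made-up per-epoch losses for illustration

# Index of the first epoch at which the tracker asks to stop (or `nothing`).
stop_epoch = findfirst(loss -> should_stop!(tracker, loss), losses)
```

Using a mutable struct keeps the state attached to the callback instance, so the same instance passed in the `callbacks` list accumulates information across epochs.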