Hooks
Hooks extend the functionality of Tsunami. They let you inject custom code, defined on the FluxModule or on a Callback, at specific points in the training, validation, and testing loops.
At a high level, and omitting function inputs and outputs, a simplified version of the Tsunami.fit! method looks like this:
    function fit!()
        configure_optimizers()
        for epoch in epochs
            train_loop()
        end
    end

    function train_loop()
        on_train_epoch_start()
        set_learning_rate(lr_scheduler, epoch)
        # enumerate yields (index, element) pairs
        for (batch_idx, batch) in enumerate(train_dataloader)
            batch = transfer_batch_to_device(batch)
            on_train_batch_start(batch, batch_idx)
            out, grad = out_and_gradient(train_step, model, trainer, batch, batch_idx)
            on_before_update(out, grad)
            update!(opt_state, model, grad)
            on_train_batch_end(out, batch, batch_idx)
            if should_check_val
                val_loop()
            end
        end
        on_train_epoch_end()
    end

    function val_loop()
        on_val_epoch_start()
        for (batch_idx, batch) in enumerate(val_dataloader)
            batch = transfer_batch_to_device(batch)
            on_val_batch_start(batch, batch_idx)
            out = val_step(model, trainer, batch, batch_idx)
            on_val_batch_end(out, batch, batch_idx)
        end
        on_val_epoch_end()
    end
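To make the call order concrete, here is a toy, self-contained Julia re-creation of the simplified loop. It does not use Tsunami: toy_fit!, toy_train_loop, toy_val_loop, and the rule of validating every val_every batches (standing in for should_check_val) are all hypothetical, and each hook simply records its own name.

```julia
# Toy re-creation of the simplified loop above. It does NOT use Tsunami;
# every name here is hypothetical and only mirrors the hook call order.

const CALLS = String[]   # hook invocations, in the order they happen

hook(name) = push!(CALLS, name)

function toy_val_loop(val_data)
    hook("on_val_epoch_start")
    for (batch_idx, batch) in enumerate(val_data)
        hook("on_val_batch_start")
        hook("on_val_batch_end")
    end
    hook("on_val_epoch_end")
end

function toy_train_loop(train_data, val_data; val_every = 2)
    hook("on_train_epoch_start")
    for (batch_idx, batch) in enumerate(train_data)
        hook("on_train_batch_start")
        hook("on_before_update")      # gradient computed, update not yet applied
        hook("on_train_batch_end")    # runs after the parameter update
        # hypothetical stand-in for should_check_val
        if batch_idx % val_every == 0
            toy_val_loop(val_data)
        end
    end
    hook("on_train_epoch_end")
end

function toy_fit!(train_data, val_data; epochs = 1)
    empty!(CALLS)
    for epoch in 1:epochs
        toy_train_loop(train_data, val_data)
    end
    return CALLS
end
```

Calling toy_fit!([1, 2], [10]) records twelve calls, with the four validation hooks nested between the second on_train_batch_end and on_train_epoch_end.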
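Each on_* hook receives the model and the trainer as its first arguments (preceded by the callback when the hook is defined on a callback), followed by any hook-specific arguments such as out, batch, or batch_idx.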
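The optional leading callback argument relies on Julia's multiple dispatch. The sketch below is a hypothetical, Tsunami-free illustration: every hook has a no-op fallback, a callback overrides only the hooks it needs, and a made-up run_hook helper calls the hook on the model and then on each callback. The call order and all names (ToyModel, ToyTrainer, EpochCounter, run_hook) are assumptions, not Tsunami's internals; in real code you add methods to the Tsunami.on_* functions instead.

```julia
# Hypothetical sketch of the "[callback,] model, trainer" convention.
# ToyModel, ToyTrainer, EpochCounter and run_hook are illustrative only.

abstract type AbstractCallback end

struct ToyModel end

struct ToyTrainer
    callbacks::Vector{AbstractCallback}
end

# Default no-op methods: a hook does nothing unless a method overrides it.
on_train_epoch_start(model, trainer) = nothing
on_train_epoch_start(cb::AbstractCallback, model, trainer) = nothing

mutable struct EpochCounter <: AbstractCallback
    n::Int
end

# Override the hook for this callback type only.
on_train_epoch_start(cb::EpochCounter, model, trainer) = (cb.n += 1; nothing)

# Call the hook on the model, then on every registered callback.
function run_hook(hook, model, trainer, args...)
    hook(model, trainer, args...)
    for cb in trainer.callbacks
        hook(cb, model, trainer, args...)
    end
    return nothing
end
```

After two calls to run_hook(on_train_epoch_start, ToyModel(), ToyTrainer([cb])) with cb = EpochCounter(0), cb.n is 2, while the model's fallback method does nothing.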
Hooks API
Tsunami.on_before_update
    on_before_update([callback,] model, trainer, out, grad)
Called before the call to Optimisers.update! that applies the gradient grad to update the model's parameters. out is the output of the last call to train_step.
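A common use of on_before_update is monitoring gradients. The sketch below assumes the gradient is a nested structure of NamedTuples, Tuples, arrays, and nothing, which is the shape Zygote typically returns for Flux models. GradNormLogger and global_grad_norm are hypothetical names, and the hook is a plain function rather than a method of Tsunami.on_before_update so that the block runs on its own.

```julia
# Hypothetical helpers: global L2 norm of a nested gradient structure.
sqnorm(x::AbstractArray{<:Number}) = sum(abs2, x)
sqnorm(x::Number) = abs2(x)
sqnorm(::Nothing) = 0.0                                   # fields with no gradient
sqnorm(x::Union{Tuple,NamedTuple}) = sum(sqnorm, x; init = 0.0)
sqnorm(x::AbstractArray) = sum(sqnorm, x; init = 0.0)     # containers of layers

global_grad_norm(grad) = sqrt(sqnorm(grad))

# Hypothetical callback storing one gradient norm per training step.
mutable struct GradNormLogger
    norms::Vector{Float64}
end

# Stand-in for a method of Tsunami.on_before_update.
function on_before_update(cb::GradNormLogger, model, trainer, out, grad)
    push!(cb.norms, global_grad_norm(grad))
    return nothing
end
```

For a gradient with a single weight matrix [3.0 4.0], a zero bias, and a nothing field, the logged norm is 5.0.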
Tsunami.on_train_batch_start
    on_train_batch_start([callback,] model, trainer, batch, batch_idx)
Called at the beginning of each training batch, after the batch has been transferred to the device.
Tsunami.on_train_batch_end
    on_train_batch_end([callback,] model, trainer, out, batch, batch_idx)
Called at the end of each iteration in the training loop, after the parameter update. out is the output of train_step.
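Because on_train_batch_start runs after the device transfer but before the gradient computation, and on_train_batch_end runs after the update, pairing them measures the cost of one optimization step. BatchTimer is a hypothetical callback, and the two hooks are plain functions standing in for methods of Tsunami.on_train_batch_start and Tsunami.on_train_batch_end.

```julia
# Hypothetical callback timing each training step.
mutable struct BatchTimer
    t0::UInt64                 # time_ns() at the start of the current batch
    seconds::Vector{Float64}   # elapsed time of each finished batch
end
BatchTimer() = BatchTimer(0, Float64[])

# Stand-in for a method of Tsunami.on_train_batch_start.
function on_train_batch_start(cb::BatchTimer, model, trainer, batch, batch_idx)
    cb.t0 = time_ns()
    return nothing
end

# Stand-in for a method of Tsunami.on_train_batch_end.
function on_train_batch_end(cb::BatchTimer, model, trainer, out, batch, batch_idx)
    push!(cb.seconds, (time_ns() - cb.t0) / 1e9)   # nanoseconds to seconds
    return nothing
end
```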
Tsunami.on_train_epoch_start
    on_train_epoch_start([callback,] model, trainer)
Called at the beginning of each training epoch.
Tsunami.on_train_epoch_end
    on_train_epoch_end([callback,] model, trainer)
Called at the end of each training epoch.
To access all batch outputs at the end of the epoch, you can cache the step outputs in a field of the model or of a callback and read them in this hook.
See also on_train_epoch_start.
Examples
    using Tsunami
    using Statistics: mean

    # Callback caching the accuracy of every training batch in the current epoch
    struct Callback
        training_step_outputs::Vector{Float32}
        # other fields...
    end

    function Tsunami.train_step(model::MyModel, trainer, batch)
        ...
        return (loss = loss, accuracy = accuracy)
    end

    # Reset the cache at the start of each epoch
    function Tsunami.on_train_epoch_start(cb::Callback, model, trainer)
        empty!(cb.training_step_outputs)
    end

    # Store the accuracy returned by train_step
    function Tsunami.on_train_batch_end(cb::Callback, model, trainer, out, batch, batch_idx)
        push!(cb.training_step_outputs, out.accuracy)
    end

    # Summarize the epoch
    function Tsunami.on_train_epoch_end(cb::Callback, model, trainer)
        println("Mean accuracy: ", mean(cb.training_step_outputs))
    end
Tsunami.on_test_batch_start
    on_test_batch_start([callback,] model, trainer, batch, batch_idx)
Called at the beginning of each test batch.
Tsunami.on_test_batch_end
    on_test_batch_end([callback,] model, trainer, out, batch, batch_idx)
Called at the end of each iteration in the test loop. out is the output of test_step.
Tsunami.on_test_epoch_start
    on_test_epoch_start([callback,] model, trainer)
Called at the beginning of each test epoch.
Tsunami.on_test_epoch_end
    on_test_epoch_end([callback,] model, trainer)
Called at the end of each test epoch.
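The test hooks can aggregate a metric over the whole test set. Averaging per-batch accuracies over-weights a smaller final batch, so this hypothetical TestAccuracy callback accumulates counts instead. It assumes test_step returns a NamedTuple with a correct field and that batch.x stores observations along its last dimension (the usual Flux convention); the hooks are plain functions standing in for the Tsunami.on_test_* methods.

```julia
# Hypothetical callback computing a sample-weighted test accuracy.
mutable struct TestAccuracy
    correct::Int
    total::Int
    result::Float64
end
TestAccuracy() = TestAccuracy(0, 0, NaN)

# Stand-in for Tsunami.on_test_epoch_start: reset the counters.
function on_test_epoch_start(cb::TestAccuracy, model, trainer)
    cb.correct = 0
    cb.total = 0
    return nothing
end

# Stand-in for Tsunami.on_test_batch_end: accumulate counts, not ratios.
function on_test_batch_end(cb::TestAccuracy, model, trainer, out, batch, batch_idx)
    cb.correct += out.correct                       # assumed field of test_step output
    cb.total += size(batch.x, ndims(batch.x))       # assumed: last dim = batch size
    return nothing
end

# Stand-in for Tsunami.on_test_epoch_end: one ratio over all samples.
function on_test_epoch_end(cb::TestAccuracy, model, trainer)
    cb.result = cb.total == 0 ? NaN : cb.correct / cb.total
    return nothing
end
```

With a fully correct 4-sample batch followed by a wrong 1-sample batch, the result is 4/5 = 0.8, whereas averaging the two per-batch accuracies would give 0.5.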
Tsunami.on_val_batch_start
    on_val_batch_start([callback,] model, trainer, batch, batch_idx)
Called at the beginning of each validation batch, after the batch has been transferred to the device.
Tsunami.on_val_batch_end
    on_val_batch_end([callback,] model, trainer, out, batch, batch_idx)
Called at the end of each iteration in the validation loop. out is the output of val_step.
Tsunami.on_val_epoch_start
    on_val_epoch_start([callback,] model, trainer)
Called at the beginning of each validation epoch.
Tsunami.on_val_epoch_end
    on_val_epoch_end([callback,] model, trainer)
Called at the end of each validation epoch.
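The validation hooks combine naturally into a tracker for the best validation loss, a building block for checkpointing or early stopping. This is a hypothetical sketch: BestValLoss is not a Tsunami type, val_step is assumed to return a NamedTuple with a loss field, and the plain functions stand in for the Tsunami.on_val_* methods. Since the training loop may run validation several times per training epoch (whenever should_check_val holds), each "epoch" here is one validation pass.

```julia
# Hypothetical callback tracking the best mean validation loss.
mutable struct BestValLoss
    epoch_losses::Vector{Float64}   # per-batch losses of the current validation pass
    best::Float64                   # best mean loss seen so far
    improved::Bool                  # did the last validation pass improve on best?
end
BestValLoss() = BestValLoss(Float64[], Inf, false)

# Stand-in for Tsunami.on_val_epoch_start.
function on_val_epoch_start(cb::BestValLoss, model, trainer)
    empty!(cb.epoch_losses)
    return nothing
end

# Stand-in for Tsunami.on_val_batch_end; assumes out has a loss field.
function on_val_batch_end(cb::BestValLoss, model, trainer, out, batch, batch_idx)
    push!(cb.epoch_losses, out.loss)
    return nothing
end

# Stand-in for Tsunami.on_val_epoch_end.
function on_val_epoch_end(cb::BestValLoss, model, trainer)
    cb.improved = false
    isempty(cb.epoch_losses) && return nothing
    mean_loss = sum(cb.epoch_losses) / length(cb.epoch_losses)
    if mean_loss < cb.best
        cb.best = mean_loss
        cb.improved = true
    end
    return nothing
end
```

A checkpointing callback could, for instance, save the model whenever improved is true after on_val_epoch_end.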