Hooks

Hooks are a way to extend the functionality of Tsunami: they let you inject custom code into a FluxModule or into a Callback at various points in the training, testing, and validation loops.

At a high level, and omitting function inputs and outputs, a simplified version of the Tsunami.fit! method looks like this:

# Simplified outline of `Tsunami.fit!`; function inputs and outputs are omitted.
function fit!()
    # Runs once before training starts; presumably sets up the `opt_state` and
    # `lr_scheduler` used inside `train_loop` (outputs omitted here — confirm).
    configure_optimizers()
    
    # One pass of the training loop per epoch.
    for epoch in epochs
        train_loop()
    end
end

# Simplified outline of one training epoch. The `model` and `trainer` arguments that
# every hook receives are omitted for brevity.
function train_loop()
    on_train_epoch_start()
    # `epoch` is the current epoch of the enclosing `fit!` loop.
    set_learning_rate(lr_scheduler, epoch)

    # `enumerate` yields `(index, element)` pairs, so the batch index comes first.
    for (batch_idx, batch) in enumerate(train_dataloader)
        batch = transfer_batch_to_device(batch)
        on_train_batch_start(batch, batch_idx)
        # `out` is what `train_step` returns; `grad` is the gradient used to update `model`.
        out, grad = out_and_gradient(train_step, model, trainer, batch, batch_idx)
        on_before_update(out, grad)
        update!(opt_state, model, grad)
        on_train_batch_end(out, batch, batch_idx)
        # Validation can be triggered from inside the training epoch, between batches.
        if should_check_val
            val_loop()
        end
    end
    on_train_epoch_end()
end

# Simplified outline of one validation pass. The `model` and `trainer` arguments that
# every hook receives are omitted for brevity.
function val_loop()
    on_val_epoch_start()
    # `enumerate` yields `(index, element)` pairs, so the batch index comes first.
    for (batch_idx, batch) in enumerate(val_dataloader)
        batch = transfer_batch_to_device(batch)
        on_val_batch_start(batch, batch_idx)
        # No gradient or parameter update here: only the step output is computed.
        out = val_step(model, trainer, batch, batch_idx)
        on_val_batch_end(out, batch, batch_idx)
    end
    on_val_epoch_end()
end

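Each on_* hook also receives the model and the trainer as its first arguments, preceded by the callback when the hook is implemented for a Callback (see the signatures below). These arguments are omitted from the pseudocode above.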

Hooks API

Tsunami.on_before_update — Function
on_before_update([callback,] model, trainer, out, grad)

Called before the call to `Optimisers.update!` that applies the gradient `grad` to update the model's parameters. `out` is the output of the last call to `train_step`.

source
Tsunami.on_train_epoch_end — Function
on_train_epoch_end([callback,] model, trainer)

Called at the end of each training epoch.

To access all batch outputs at the end of the epoch, you can cache the step outputs in a field of the model or of a callback (the example below uses a callback), and then access them in this hook:

See also on_train_epoch_start.

Examples

# Example callback that caches one value per training batch for the current epoch.
# An immutable `struct` is enough: the hooks below only mutate the vector's contents
# (`empty!`, `push!`), never rebind the field.
struct Callback
    # Per-batch values appended in `on_train_batch_end`, cleared in `on_train_epoch_start`.
    training_step_outputs::Vector{Float32}
    # other fields...
end

# Returning a NamedTuple makes every field available to the hooks as `out`
# (e.g. `out.accuracy` in `on_train_batch_end` below).
function Tsunami.train_step(model::MyModel, trainer, batch)
    ... # placeholder: user code computing `loss` and `accuracy` for this batch
    return (loss = loss, accuracy = accuracy)
end

# Start every epoch with an empty cache, so it only holds outputs from the current epoch.
function Tsunami.on_train_epoch_start(cb::Callback, model, trainer)
    outputs = cb.training_step_outputs
    return empty!(outputs)
end

# After each training batch, record the metric returned by `train_step` as `out`.
function Tsunami.on_train_batch_end(cb::Callback, model, trainer, out, batch, batch_idx)
    acc = out.accuracy
    return push!(cb.training_step_outputs, acc)
end

# Aggregate the values cached during the epoch.
# `mean` is not in Base: this needs `using Statistics` (stdlib) in the enclosing module.
function Tsunami.on_train_epoch_end(cb::Callback, model, trainer)
    println("Mean accuracy: ", mean(cb.training_step_outputs))
end
source
Tsunami.on_val_batch_end — Function
on_val_batch_end([callback,] model, trainer, out, batch, batch_idx)

Called at the end of each iteration in the validation loop. `out` is the output of `val_step`.

source