FluxModule

The FluxModule abstract type is the entry point for defining custom models in Tsunami. Subtypes of FluxModule are designed to be used with the Tsunami.fit! method, but can also be used independently.

Tsunami.FluxModule — Type
abstract type FluxModule end

An abstract type for Flux models. A FluxModule helps organise your code and provides a standard interface for training.

A FluxModule comes with the functionality provided by Flux.@layer (pretty printing, etc.) and the ability to interact with Trainer and Optimisers.jl.

You can restrict which parameters are trained by overloading Optimisers.trainable.
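For example, a model wrapping a pretrained backbone could expose only its head to the optimiser. A minimal sketch (the struct and field names here are illustrative, not part of Tsunami):

using Flux, Tsunami, Optimisers

struct FrozenBackbone <: FluxModule
    backbone   # pretrained part, kept fixed
    head       # task-specific part to train
end

# Only `head` is reported as trainable, so Optimisers.setup builds
# optimiser state for it alone and `backbone` is left untouched.
Optimisers.trainable(m::FrozenBackbone) = (; head = m.head)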

To interact with a Trainer, a subtype of FluxModule must implement the required methods listed below and may also override the optional ones.

  • Required methods
  • Optional methods

Examples

using Flux, Tsunami, Optimisers

# Define a Multilayer Perceptron implementing the FluxModule interface

struct Model <: FluxModule
    net
end

function Model()
    net = Chain(Dense(4 => 32, relu), Dense(32 => 2))
    return Model(net)
end

(model::Model)(x) = model.net(x)

function Tsunami.train_step(model::Model, trainer, batch)
    x, y = batch
    y_hat = model(x)
    loss = Flux.Losses.mse(y_hat, y)
    return loss
end

function Tsunami.configure_optimisers(model::Model, trainer)
    return Optimisers.setup(Optimisers.Adam(1f-3), model)
end

# Prepare the dataset and the DataLoader
X, Y = rand(Float32, 4, 100), rand(Float32, 2, 100)
train_dataloader = Flux.DataLoader((X, Y), batchsize=10)

# Create and Train the model
model = Model()
trainer = Trainer(max_epochs=10)
Tsunami.fit!(model, trainer, train_dataloader)
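
# Since a FluxModule is an ordinary callable Flux model,
# the trained model can then be used directly for inference:
Y_hat = model(X)   # 2×100 matrix of predictions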

Required methods

The following methods must be implemented for a subtype of FluxModule to be used with Tsunami.

Tsunami.configure_optimisers — Function
configure_optimisers(model, trainer)

Return an optimiser state initialized for the model. It can also return a tuple (opt_state, scheduler), where scheduler is any callable object that takes the current epoch as input and returns a scalar to be set as the learning rate for the next epoch.

Examples

using Optimisers, ParameterSchedulers

function Tsunami.configure_optimisers(model::Model, trainer)
    return Optimisers.setup(AdamW(1f-3), model)
end

# Now with a scheduler dropping the learning rate by a factor 10 
# at epochs [50, 100, 200] starting from the initial value of 1e-2
function Tsunami.configure_optimisers(model::Model, trainer)

    function lr_scheduler(epoch)
        if epoch <= 50
            return 1e-2
        elseif epoch <= 100
            return 1e-3
        elseif epoch <= 200
            return 1e-4
        else
            return 1e-5
        end
    end
    
    opt_state = Optimisers.setup(AdamW(), model)
    return opt_state, lr_scheduler
end

# Same as above but using the ParameterSchedulers package.
function Tsunami.configure_optimisers(model::Model, trainer)
    lr_scheduler = ParameterSchedulers.Step(1f-2, 0.1f0, [50, 50, 100])
    opt_state = Optimisers.setup(AdamW(), model)
    return opt_state, lr_scheduler
end
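
For reference, applying the scheduler's output at the start of an epoch boils down to something like the following (a sketch of the mechanism, not Tsunami's exact internals):

lr = lr_scheduler(epoch)
Optimisers.adjust!(opt_state, lr)   # set the learning rate in every rule of the state tree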
Tsunami.train_step — Function
train_step(model, trainer, batch, [batch_idx])

The method called at each training step during Tsunami.fit!. It should compute the forward pass of the model and return the loss (a scalar) corresponding to the minibatch batch. The optional argument batch_idx is the index of the batch in the current epoch.

Any Model <: FluxModule should implement either train_step(model::Model, trainer, batch) or train_step(model::Model, trainer, batch, batch_idx).

The training loop in Tsunami.fit! approximately looks like this:

for epoch in 1:epochs
    for (batch_idx, batch) in enumerate(train_dataloader)
        grads = Flux.gradient(model) do m
            loss = train_step(m, trainer, batch, batch_idx)
            return loss
        end
        Optimisers.update!(opt, model, grads[1])
    end
end

The output can be either a scalar or a named tuple:

  • If a scalar is returned, it is assumed to be the loss.
  • If a named tuple is returned, it has to contain the loss field.

The output can be accessed in hooks such as on_before_update or on_train_batch_end.
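For instance, extra fields returned alongside the loss can be read back in a hook. A sketch, assuming on_train_batch_end receives the train_step output as its out argument (see the Hooks section for the exact signatures); the accuracy field is illustrative:

function Tsunami.train_step(model::Model, trainer, batch)
    x, y = batch
    ŷ = model(x)
    loss = Flux.Losses.logitcrossentropy(ŷ, y)
    return (; loss, accuracy = Tsunami.accuracy(ŷ, y))
end

function Tsunami.on_train_batch_end(model::Model, trainer, out, batch, batch_idx)
    # `out` is the named tuple returned by train_step
    @info "batch $batch_idx" out.loss out.accuracy
end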

Examples

function Tsunami.train_step(model::Model, trainer, batch)
    x, y = batch
    ŷ = model(x)
    loss = Flux.Losses.logitcrossentropy(ŷ, y)
    Tsunami.log(trainer, "loss/train", loss)
    Tsunami.log(trainer, "accuracy/train", Tsunami.accuracy(ŷ, y))
    return loss
end

Optional methods

The following methods have default implementations that can be overridden if necessary. See also the Hooks section of the documentation for other methods that can be overridden.

Tsunami.val_step — Function
val_step(model, trainer, batch, [batch_idx])

The method called at each validation step during Tsunami.fit!. Typically used to compute metrics and statistics on the validation batch batch. The optional argument batch_idx is the index of the batch in the current validation epoch.

A Model <: FluxModule should implement either val_step(model::Model, trainer, batch) or val_step(model::Model, trainer, batch, batch_idx).

Optionally, the method can return a scalar or a named tuple, to be used in hooks such as on_val_batch_end.

See also train_step.

Examples

function Tsunami.val_step(model::Model, trainer, batch)
    x, y = batch
    ŷ = model(x)
    loss = Flux.Losses.logitcrossentropy(ŷ, y)
    accuracy = Tsunami.accuracy(ŷ, y)
    Tsunami.log(trainer, "loss/val", loss, on_step = false, on_epoch = true)
    Tsunami.log(trainer, "loss/accuracy", accuracy, on_step = false, on_epoch = true)
end
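
The validation loop runs when a validation dataloader is passed to Tsunami.fit!, as in the following (assuming Xval and Yval are validation arrays shaped like the training data):

val_dataloader = Flux.DataLoader((Xval, Yval), batchsize=10)
Tsunami.fit!(model, trainer, train_dataloader, val_dataloader)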