Foil

The `Foil` is a minimalistic version of the `Trainer` that allows you to make only minimal changes to your Flux code while still obtaining many of the benefits of Tsunami. This is similar to what Lightning Fabric is to PyTorch Lightning. `Foil` also resembles HuggingFace's `accelerate` library.
Tsunami.Foil — Type

```julia
Foil(; kws...)
```

A type that takes care of the acceleration of the training process.

Constructor Arguments

- `accelerator`: Supports passing different accelerator types:
  - `:auto` (default): Automatically selects a GPU if available, otherwise falls back to the CPU.
  - `:gpu`: Like `:auto`, but throws an error if no GPU is available. For a GPU to be available, the corresponding package must be loaded (e.g. with `using CUDA`). The trigger packages are `CUDA.jl` for Nvidia GPUs, `AMDGPU.jl` for AMD GPUs, and `Metal.jl` for Apple Silicon.
  - `:cpu`: Force using the CPU.

  See also the `devices` option.
- `devices`: Pass an integer `n` to train on `n` devices (only `1` is supported at the moment), or a list of device ids to train on specific devices (e.g. `[2]` to train on the GPU with id 2). Id indexing starts from `1`. If `nothing`, the default device is used (see the `MLDataDevices.gpu_device` documentation). Default: `nothing`.
- `precision`: Supports passing different precision types (`:bf16`, `:f16`, `:f32`, `:f64`), where `:bf16` is BFloat16, `:f16` is Float16, `:f32` is Float32, and `:f64` is Float64. Default: `:f32`.
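As an illustration of the constructor arguments, here is a minimal sketch of creating a `Foil`. The keyword names follow the documentation above; the particular values chosen are only examples:

```julia
using Tsunami
# using CUDA  # load a trigger package to make Nvidia GPUs visible to :auto / :gpu

# Pick the accelerator automatically, use the default device,
# and train in Float32 (these are the documented defaults).
foil = Foil(accelerator = :auto, devices = nothing, precision = :f32)

# Alternatively, force CPU training in BFloat16:
foil_cpu = Foil(accelerator = :cpu, precision = :bf16)
```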
Tsunami.setup — Function

```julia
setup(foil::Foil, model, optimisers)
```

Set up the model and optimisers for training, sending them to the device and setting the precision. This function is called internally by `Tsunami.fit!`.

See also `Foil`.
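As a sketch of how `setup` might be used directly, under the assumptions that the model is a plain Flux model, the optimiser is an Optimisers.jl rule, and `setup` returns the device-resident model together with the optimiser state (the `Chain`/`Dense` model and the `Adam` rule here are purely illustrative):

```julia
using Tsunami, Flux, Optimisers

model = Chain(Dense(4 => 8, relu), Dense(8 => 2))
opt = Optimisers.Adam(1f-3)
foil = Foil(accelerator = :auto, precision = :f32)

# Assumed: setup returns the model and optimiser state moved to the
# selected device, with the requested precision applied to both.
model, opt_state = Tsunami.setup(foil, model, opt)
```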