This is the documentation page for GraphNeuralNetworks.jl, a graph neural network library written in Julia and based on the deep learning framework Flux.jl. GraphNeuralNetworks.jl is largely inspired by PyTorch Geometric, Deep Graph Library, and GeometricFlux.jl.
Among its features:
- Implements common graph convolutional layers.
- Supports computations on batched graphs.
- Easy to define custom layers.
- CUDA support.
- Integration with Graphs.jl (see the sketch after this list).
- Examples of node, edge, and graph level machine learning tasks.
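As a taste of the Graphs.jl integration mentioned above, any Graphs.jl graph can be wrapped in a GNNGraph and endowed with features. This is a minimal sketch; the graph and feature sizes are arbitrary:

```julia
using GraphNeuralNetworks, Graphs

# Wrap a Graphs.jl graph and attach 32-dimensional node features
g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))

g.num_nodes    # 10
g.num_edges    # 60: each undirected edge is stored in both directions
```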
Let's give a brief overview of the package by solving a graph regression problem with synthetic data.
Usage examples on real datasets can be found in the examples folder.
We create a dataset consisting of multiple random graphs, each with node features and a graph-level regression target.
```julia
using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics, MLUtils
using MLUtils: DataLoader

all_graphs = GNNGraph[]

for _ in 1:1000
    g = GNNGraph(random_regular_graph(10, 4),
                 ndata=(; x = randn(Float32, 16, 10)),  # input node features
                 gdata=(; y = randn(Float32)))          # regression target
    push!(all_graphs, g)
end
```
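Each graph carries its features along, so we can inspect a sample to check the shapes involved (a quick sketch using the accessors from the snippet above):

```julia
g = all_graphs[1]

g.num_nodes       # 10
g.num_edges       # 40: 4 neighbors per node, each edge stored in both directions
size(g.ndata.x)   # (16, 10): feature dimension × number of nodes
g.gdata.y         # the scalar regression target for this graph
```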
We concisely define our model as a `GNNChain` containing two graph convolutional layers. If CUDA is available, our model will live on the GPU.
```julia
device = CUDA.functional() ? Flux.gpu : Flux.cpu;

model = GNNChain(GCNConv(16 => 64),
                 BatchNorm(64),     # apply batch normalization on node features (node dimension is batch dimension)
                 x -> relu.(x),
                 GCNConv(64 => 64, relu),
                 GlobalPool(mean),  # aggregate node-wise features into graph-wise features
                 Dense(64, 1)) |> device

ps = Flux.params(model)
opt = Adam(1f-4)
```
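It is a good idea to sanity-check the model with a single forward pass before training. In this sketch, applying the model to one graph and its node features should return a 1 × 1 array, one prediction for the one graph:

```julia
g = all_graphs[1] |> device   # move the graph (and its features) to the device
ŷ = model(g, g.ndata.x)       # 1×1 matrix: GlobalPool reduces the nodes to one graph vector
```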
Finally, we use a standard Flux training pipeline to fit our dataset. We use Flux's `DataLoader` to iterate over mini-batches of graphs that are glued together into a single `GNNGraph` using the `MLUtils.batch` method. This is what happens under the hood when creating a `DataLoader` with the `collate=true` option.
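To see what this collation amounts to, we can batch two graphs by hand. In this minimal sketch, the result is a single `GNNGraph` holding both input graphs as disjoint components:

```julia
gbatch = MLUtils.batch(all_graphs[1:2])

gbatch.num_graphs  # 2
gbatch.num_nodes   # 20: the nodes of the two graphs, concatenated
```

With batching taken care of, the training loop itself is standard Flux: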
```julia
train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at=0.8)

train_loader = DataLoader(train_graphs,
                          batchsize=32, shuffle=true, collate=true)
test_loader = DataLoader(test_graphs,
                         batchsize=32, shuffle=false, collate=true)

loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)

loss(loader) = mean(loss(g |> device) for g in loader)

for epoch in 1:100
    for g in train_loader
        g = g |> device
        grad = gradient(() -> loss(g), ps)
        Flux.Optimise.update!(opt, ps, grad)
    end

    @info (; epoch, train_loss=loss(train_loader), test_loss=loss(test_loader))
end
```
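Once trained, the model can be queried on held-out graphs. The following is a hedged sketch of an evaluation pass; the mean-absolute-error metric is our choice here, not part of the pipeline above:

```julia
for g in test_loader
    g = g |> device
    ŷ = vec(model(g, g.ndata.x))   # one prediction per graph in the batch
    y = g.gdata.y                  # the corresponding targets
    @show mean(abs.(ŷ .- y))       # mean absolute error on this mini-batch
end
```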