# Models

GraphNeuralNetworks.jl provides common graph convolutional layers from which you can assemble arbitrarily deep or complex models. GNN layers are compatible with Flux.jl layers, so experienced Flux users can define and train their models right away.
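As a minimal sketch of this compatibility, assuming recent versions of GraphNeuralNetworks.jl and Flux.jl: a graph layer is called with the graph and the node-feature matrix (one column per node), while an ordinary Flux layer such as `Dense` is applied to the resulting features only. The feature sizes below are arbitrary choices for illustration.

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)               # random graph with 10 nodes and 30 edges
x = randn(Float32, 3, g.num_nodes)   # node features: 3 features per node

conv = GCNConv(3 => 5, relu)   # graph layer: called as conv(g, x)
dense = Dense(5 => 2)          # plain Flux layer: called as dense(x)

h = conv(g, x)   # updated node features, size (5, 10)
y = dense(h)     # per-node output, size (2, 10)
```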

In what follows, we discuss two different styles for model creation: the *explicit modeling* style, more verbose but more flexible, and the *implicit modeling* style based on `GNNChain`, more concise but less flexible.

## Explicit modeling

In the explicit modeling style, the model is created according to the following steps:

- Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
- Apply `Flux.@layer` to the new type to make it Flux compatible (parameter collection, GPU movement, etc.).
- Optionally, define a convenience constructor for your model.
- Define the forward pass by implementing the call method for your type.
- Instantiate the model.

Here is an example of this construction:

```
using Flux, Graphs, GraphNeuralNetworks

struct GNN                                # step 1
    conv1
    bn
    conv2
    dropout
    dense
end

Flux.@layer GNN                           # step 2

function GNN(din::Int, d::Int, dout::Int) # step 3
    GNN(GCNConv(din => d),
        BatchNorm(d),
        GraphConv(d => d, relu),
        Dropout(0.5),
        Dense(d, dout))
end

function (model::GNN)(g::GNNGraph, x)     # step 4
    x = model.conv1(g, x)
    x = relu.(model.bn(x))
    x = model.conv2(g, x)
    x = model.dropout(x)
    x = model.dense(x)
    return x
end

din, d, dout = 3, 4, 2
model = GNN(din, d, dout)                 # step 5

g = rand_graph(10, 30)
X = randn(Float32, din, 10)

y = model(g, X)  # output size: (dout, g.num_nodes)
grad = gradient(model -> sum(model(g, X)), model)
```
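A gradient such as the one computed above is typically fed to an optimiser. The sketch below shows a plain training loop for a model written in the explicit style, assuming Flux's `Flux.setup`/`Flux.update!` optimiser interface. `TinyGNN` is a hypothetical two-layer model used only for illustration, and the regression targets are synthetic random data.

```julia
using Flux, GraphNeuralNetworks

# Hypothetical minimal model in the explicit style (not part of the library).
struct TinyGNN
    conv
    dense
end
Flux.@layer TinyGNN

(m::TinyGNN)(g::GNNGraph, x) = m.dense(m.conv(g, x))

g = rand_graph(10, 30)
x = randn(Float32, 3, 10)
ytarget = randn(Float32, 2, 10)   # synthetic per-node regression targets

model = TinyGNN(GCNConv(3 => 8, relu), Dense(8 => 2))
loss(m) = Flux.Losses.mse(m(g, x), ytarget)

opt_state = Flux.setup(Adam(1e-2), model)   # optimiser state mirrors the model's structure
initial_loss = loss(model)
for epoch in 1:200
    grads = Flux.gradient(loss, model)       # gradient w.r.t. the model's parameters
    Flux.update!(opt_state, model, grads[1]) # in-place parameter update
end
final_loss = loss(model)
```

The graph `g` and the features are captured by the `loss` closure, so only the model's parameters are differentiated and updated.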

## Implicit modeling with GNNChains

While very flexible, the way in which we defined the `GNN` model in the last section is a bit verbose. To simplify things, we provide the `GNNChain` type. It is very similar to Flux's well-known `Chain`: it composes layers sequentially, feeding the output of each layer to the next one. In addition, `GNNChain` propagates the input graph as well, passing it as the first argument to the layers that subtype the `GNNLayer` abstract type.
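The propagation rule can be checked directly. In this sketch (layer sizes are arbitrary), a `GNNChain` of a graph layer and a `Dense` layer gives the same result as composing the two calls by hand: the `GNNLayer` subtype receives `(g, x)`, while `Dense` receives only the features. The abstract type is accessed with its qualified name `GraphNeuralNetworks.GNNLayer`.

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)
x = randn(Float32, 3, 10)

conv = GCNConv(3 => 4, relu)   # subtype of GNNLayer: receives the graph
dense = Dense(4 => 2)          # ordinary Flux layer: receives features only

chain = GNNChain(conv, dense)
manual = dense(conv(g, x))     # what the chain does, written by hand

@show chain(g, x) ≈ manual
```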

Using `GNNChain`, the `GNN` model from the explicit modeling section becomes

```
using Flux, Graphs, GraphNeuralNetworks

din, d, dout = 3, 4, 2
g = rand_graph(10, 30)
X = randn(Float32, din, 10)

model = GNNChain(GCNConv(din => d),
                 BatchNorm(d),
                 x -> relu.(x),
                 GraphConv(d => d, relu),
                 Dropout(0.5),
                 Dense(d, dout))

y = model(g, X)  # output size: (dout, g.num_nodes)
```

The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled by defining the forward pass explicitly.
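As a sketch of such a scenario in the explicit style, the hypothetical `EdgeGNN` below updates the edge features with a `Dense` layer and then feeds them, together with the node features, to an edge-aware convolution. It assumes `CGConv((din, de) => d)` accepts edge features of size `de` when called as `conv(g, x, e)`; the forward pass returns both the new node and edge features, which a `GNNChain` could not do.

```julia
using Flux, GraphNeuralNetworks

# Hypothetical model: edge features are transformed, then used by an
# edge-aware convolution; both updated feature sets are returned.
struct EdgeGNN
    edge_mlp   # updates edge features
    conv       # updates node features using the (updated) edge features
end
Flux.@layer EdgeGNN

EdgeGNN(din::Int, de::Int, d::Int) =
    EdgeGNN(Dense(de => de, relu), CGConv((din, de) => d, relu))

function (m::EdgeGNN)(g::GNNGraph, x, e)
    e = m.edge_mlp(e)      # new edge features, size (de, num_edges)
    x = m.conv(g, x, e)    # new node features, size (d, num_nodes)
    return x, e
end

g = rand_graph(10, 30)
x = randn(Float32, 3, g.num_nodes)
e = randn(Float32, 2, g.num_edges)

model = EdgeGNN(3, 2, 4)
x_out, e_out = model(g, x, e)
```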

A `GNNChain` also propagates the graph into the branches created by the `Flux.Parallel` layer:

```
using Statistics: mean  # aggregation function used by the pooling layer

# Reuses din, d, dout, g and X from the previous example.
AddResidual(l) = Parallel(+, identity, l)  # implementing a skip/residual connection

model = GNNChain(ResGatedGraphConv(din => d, relu),
                 AddResidual(ResGatedGraphConv(d => d, relu)),
                 AddResidual(ResGatedGraphConv(d => d, relu)),
                 AddResidual(ResGatedGraphConv(d => d, relu)),
                 GlobalPool(mean),
                 Dense(d, dout))

y = model(g, X)  # output size: (dout, g.num_graphs)
```
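To see what the residual wrapper computes, the sketch below wraps a single graph layer in `Parallel(+, identity, conv)` inside a `GNNChain` and compares it with the hand-written skip connection `x + conv(g, x)`. Inside the chain, the graph is passed to the `ResGatedGraphConv` branch, while `identity` receives only the features. The feature size 4 is arbitrary.

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)
x = randn(Float32, 4, 10)

conv = ResGatedGraphConv(4 => 4, relu)
residual = GNNChain(Parallel(+, identity, conv))

expected = x .+ conv(g, x)   # skip connection written by hand
@show residual(g, x) ≈ expected
```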

## Embedding a graph in the model

Sometimes it is useful to treat a specific graph as part of a model rather than as its input. GraphNeuralNetworks.jl provides the `WithGraph` type to deal with this scenario.

```
chain = GNNChain(GCNConv(din => d, relu),
                 GCNConv(d => d))

g = rand_graph(10, 30)
model = WithGraph(chain, g)

X = randn(Float32, din, 10)
# Pass only X as input, the model already contains the graph.
y = model(X)
```
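In other words, calling the wrapped model on the features alone should match calling the chain with the stored graph. A small self-contained check, with arbitrary layer sizes:

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)
chain = GNNChain(GCNConv(3 => 4, relu), GCNConv(4 => 2))
model = WithGraph(chain, g)   # the graph is stored inside the model

X = randn(Float32, 3, 10)
@show model(X) ≈ chain(g, X)  # same result as passing the graph explicitly
```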

An example of `WithGraph` usage is given in the graph neural ODE script in the examples folder.