Basic Layers

Docs

GraphNeuralNetworks.DotDecoder (Type)
DotDecoder()

A graph neural network layer that, for a given input graph g and node features x, returns the dot product x_i ⋅ x_j on each edge (i, j).

Examples

julia> g = rand_graph(5, 6)
GNNGraph:
    num_nodes = 5
    num_edges = 6

julia> dotdec = DotDecoder()
DotDecoder()

julia> dotdec(g, rand(2, 5))
1×6 Matrix{Float64}:
 0.345098  0.458305  0.106353  0.345098  0.458305  0.106353
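To make the per-edge score concrete, the computation can be reproduced by hand. The sketch below uses a hypothetical helper `dot_decode` (not part of GraphNeuralNetworks.jl) that takes explicit source and target index vectors `s` and `t`, with node features stored one column per node as in the example above. It is an illustration of the formula, not the library's implementation.

```julia
using LinearAlgebra  # for dot

# Hypothetical helper (not a GraphNeuralNetworks.jl function): for every edge k,
# computes x_i ⋅ x_j with i = s[k] (source) and j = t[k] (target).
function dot_decode(s::AbstractVector{<:Integer}, t::AbstractVector{<:Integer}, x::AbstractMatrix)
    length(s) == length(t) || throw(ArgumentError("s and t must have the same length"))
    scores = [dot(view(x, :, s[k]), view(x, :, t[k])) for k in eachindex(s)]
    return reshape(scores, 1, :)  # 1×num_edges, the same shape DotDecoder returns
end

x = [1.0 0.0 2.0;
     1.0 3.0 0.0]   # 2 features × 3 nodes
s = [1, 2, 3]       # edge sources
t = [2, 3, 1]       # edge targets
scores = dot_decode(s, t, x)   # 1×3 Matrix: [3.0 0.0 2.0]
```

With an actual GNNGraph, the index vectors could presumably be obtained with `s, t = edge_index(g)`, and the result would then be expected to match `DotDecoder()(g, x)` up to floating-point error.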
GraphNeuralNetworks.GNNChain (Type)
GNNChain(layers...)
GNNChain(name = layer, ...)

Collects multiple layers/functions to be called in sequence on a given input graph and input node features.

It allows composing layers sequentially, as Flux.Chain does, propagating the output of each layer to the next one. In addition, GNNChain handles the input graph, passing it as the first argument only to layers that subtype the GNNLayer abstract type.
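The graph-forwarding rule can be sketched in plain Julia by dispatching on an abstract layer type. Everything below is hypothetical (`MyGNNLayer`, `ScaleByDegree`, `MyChain`, `applylayer`, and a NamedTuple standing in for a graph); it illustrates the idea only and is not GNNChain's actual implementation.

```julia
# Hypothetical abstract type playing the role of GNNLayer.
abstract type MyGNNLayer end

# A toy "graph-aware" layer: scales each node's features by its degree.
struct ScaleByDegree <: MyGNNLayer end
(l::ScaleByDegree)(g, x) = x .* g.degrees'

# Hypothetical chain holding its layers in a Tuple.
struct MyChain{T<:Tuple}
    layers::T
end
MyChain(layers...) = MyChain(layers)

# Graph-aware layers receive (g, x); any other callable receives x only.
applylayer(l::MyGNNLayer, g, x) = l(g, x)
applylayer(l, g, x) = l(x)

# Feed the output of each layer into the next, starting from x.
(c::MyChain)(g, x) = foldl((h, l) -> applylayer(l, g, h), c.layers; init = x)

g = (degrees = [1.0, 2.0, 3.0],)  # toy stand-in for a graph
x = ones(2, 3)                    # 2 features × 3 nodes
m = MyChain(ScaleByDegree(), h -> h .+ 1)
y = m(g, x)   # 2×3 Matrix: [2.0 3.0 4.0; 2.0 3.0 4.0]
```

The anonymous function `h -> h .+ 1` does not subtype `MyGNNLayer`, so the fallback method calls it with the features alone, mirroring how `x -> relu.(x)` is treated in the GNNChain example.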

GNNChain supports indexing and slicing (e.g. m[2] or m[1:end-1]) and, if names are given, lookup by name, so that m[:name] == m[1] when the first layer is called name, etc.
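The indexing rules can be mimicked with a small container that stores its layers in a Tuple (positional) or a NamedTuple (named). `NamedSeq` and `_slice` are hypothetical names used only for this sketch; GNNChain's real methods may differ.

```julia
# Hypothetical container (not GNNChain itself).
struct NamedSeq{T<:Union{Tuple,NamedTuple}}
    layers::T
end
NamedSeq(layers...) = NamedSeq(layers)        # NamedSeq(f, g, ...)
NamedSeq(; kws...) = NamedSeq(values(kws))    # NamedSeq(enc = f, dec = g)

Base.getindex(c::NamedSeq, i::Integer) = c.layers[i]
Base.getindex(c::NamedSeq{<:NamedTuple}, name::Symbol) = c.layers[name]
Base.lastindex(c::NamedSeq) = length(c.layers)   # enables m[1:end-1]

# Slicing returns a new container; for NamedTuples the names are kept.
_slice(t::Tuple, r) = t[r]
_slice(nt::NamedTuple, r) = NamedTuple{keys(nt)[r]}(Tuple(nt)[r])
Base.getindex(c::NamedSeq, r::AbstractUnitRange) = NamedSeq(_slice(c.layers, r))

m = NamedSeq(enc = sum, dec = sqrt)
m[:enc] === m[1]           # true
m[1:end-1][:enc] === sum   # true: the slice keeps the name :enc
```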

Examples

julia> using Flux, GraphNeuralNetworks

julia> m = GNNChain(GCNConv(2=>5), 
                    BatchNorm(5), 
                    x -> relu.(x), 
                    Dense(5, 4))
GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4))

julia> x = randn(Float32, 2, 3);

julia> g = rand_graph(3, 6)
GNNGraph:
    num_nodes = 3
    num_edges = 6

julia> m(g, x)
4×3 Matrix{Float32}:
 -0.795592  -0.795592  -0.795592
 -0.736409  -0.736409  -0.736409
  0.994925   0.994925   0.994925
  0.857549   0.857549   0.857549

julia> m2 = GNNChain(enc = m, 
                     dec = DotDecoder())
GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder())

julia> m2(g, x)
1×6 Matrix{Float32}:
 2.90053  2.90053  2.90053  2.90053  2.90053  2.90053

julia> m2[:enc](g, x) == m(g, x)
true
GraphNeuralNetworks.WithGraph (Type)
WithGraph(model, g::GNNGraph; traingraph=false)

A type wrapping the model and tying it to the graph g. In the forward pass, it needs only the feature arrays as inputs and returns model(g, x...; kws...).

If traingraph=false, the graph's parameters won't be collected when calling Flux.params on a WithGraph object.
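The wrapping idea can be sketched without the library. `ToyGraph`, `MyWithGraph`, and `aggregate` are hypothetical stand-ins for GNNGraph, WithGraph, and a real layer; the sketch omits keyword-argument forwarding and the traingraph / Flux.params behaviour entirely.

```julia
# Hypothetical toy graph: edge k goes from src[k] to dst[k].
struct ToyGraph
    src::Vector{Int}
    dst::Vector{Int}
end

# Toy model: each node receives the sum of its in-neighbours' features.
function aggregate(g::ToyGraph, x::AbstractMatrix)
    y = zero(x)
    for (s, t) in zip(g.src, g.dst)
        y[:, t] .+= x[:, s]
    end
    return y
end

# Hypothetical wrapper that stores the graph and supplies it automatically.
struct MyWithGraph{M,G}
    model::M
    g::G
end
(wg::MyWithGraph)(x...) = wg.model(wg.g, x...)
# A graph passed explicitly is used instead of the stored one.
(wg::MyWithGraph)(g::ToyGraph, x...) = wg.model(g, x...)

g = ToyGraph([1, 2, 3], [2, 3, 1])
x = [1.0 2.0 3.0]              # 1 feature × 3 nodes
wg = MyWithGraph(aggregate, g)
y = wg(x)                      # [3.0 1.0 2.0], same as aggregate(g, x)
```

The second method is more specific than the first, so `wg(g2, x2)` with a `ToyGraph` as first argument dispatches to it and ignores the stored graph, matching the behaviour described in the example that follows.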

Examples

using GraphNeuralNetworks

g = GNNGraph([1,2,3], [2,3,1])
x = rand(Float32, 2, 3)
model = SAGEConv(2 => 3)
wg = WithGraph(model, g)
# No need to feed the graph to `wg`
@assert wg(x) == model(g, x)

g2 = GNNGraph([1,1,2,3], [2,4,1,1])
x2 = rand(Float32, 2, 4)
# WithGraph will ignore the internal graph if fed with a new one. 
@assert wg(g2, x2) == model(g2, x2)