FluxNet
fluxnet.FluxNet(node_in_dim, edge_in_dim, pe_dim, out_channels,
                ffn_hidden_dim=None, modulator_hidden_dim=64,
                dropout=0.0, norm_type='batch', add_self_loops=True,
                aggr='mean', num_heads=4, use_attention=True)
Combines a CKGConv layer with a GATv2 attention mechanism to form a complete continuous kernel graph convolution block.
Parameters:
- node_in_dim (int) – Input dimension of node features
- edge_in_dim (int) – Input dimension of edge features
- pe_dim (int) – Dimension of positional encodings
- out_channels (int) – Output dimension of the convolution
- ffn_hidden_dim (int, optional) – Hidden dimension of the feed-forward network. Default: 4 * out_channels
- modulator_hidden_dim (int, optional) – Hidden dimension of the feature modulator. Default: 64
- dropout (float, optional) – Dropout probability. Default: 0.0
- norm_type (str, optional) – Normalization type, one of ['batch', 'layer', 'instance', 'none']; see the sketch after this list. Default: 'batch'
- add_self_loops (bool, optional) – Whether to add self-loops to the graph. Default: True
- aggr (str, optional) – Aggregation method. Default: 'mean'
- num_heads (int, optional) – Number of attention heads for GATv2. Default: 4
- use_attention (bool, optional) – Whether to use the GATv2 attention mechanism. Default: True
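The norm_type options correspond to standard normalization modules. Below is a sketch of how the selection might look, using torch_geometric's graph-aware norm layers; the build_norm helper is hypothetical, not part of the FluxNet API, and the actual implementation may use plain torch.nn modules instead.
from torch.nn import Identity
from torch_geometric.nn import BatchNorm, LayerNorm, InstanceNorm

def build_norm(norm_type, channels):
    # Hypothetical helper mapping the norm_type string to a module.
    if norm_type == 'batch':
        return BatchNorm(channels)
    if norm_type == 'layer':
        return LayerNorm(channels)
    if norm_type == 'instance':
        return InstanceNorm(channels)   # per-graph; uses the batch vector
    if norm_type == 'none':
        return Identity()
    raise ValueError(f"unknown norm_type: {norm_type!r}")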
Inputs:
- x (Tensor) – Node feature matrix of shape [num_nodes, node_in_dim]
- x_pe (Tensor) – Node positional encoding matrix of shape [num_nodes, pe_dim]
- edge_index (LongTensor) – Graph connectivity matrix of shape [2, num_edges]
- edge_attr (Tensor) – Edge feature matrix of shape [num_edges, edge_in_dim]
- edge_pe (Tensor) – Edge positional encoding matrix of shape [num_edges, pe_dim]
- batch (LongTensor, optional) – Batch vector of shape [num_nodes] indicating the assignment of each node to a batch instance. Default: None
Returns:
- out (Tensor) – Updated node feature matrix of shape [num_nodes, out_channels]
Architecture:
FluxNet combines several components:
- CKGConv Layer: Base graph convolution operation
- Normalization: Configurable normalization applied after each major component
- GATv2 Attention: Multi-head graph attention mechanism (optional)
- Feed-Forward Network: Two-layer MLP with GELU activation (sketched after this list)
- Residual Connections: Added after each major component
- Dropout: Applied to outputs of attention and feed-forward network
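The feed-forward component is plausibly a standard two-layer GELU MLP, with dropout applied to its output as noted above. A minimal sketch, where the exact layer composition is an assumption:
import torch.nn as nn

out_channels, ffn_hidden_dim = 64, 4 * 64   # ffn_hidden_dim defaults to 4 * out_channels
ffn = nn.Sequential(
    nn.Linear(out_channels, ffn_hidden_dim),
    nn.GELU(),
    nn.Linear(ffn_hidden_dim, out_channels),
)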
Processing Flow (see the code sketch after this list):
- Apply CKGConv to the input features
- Apply normalization
- Add a residual connection if dimensions match
- If use_attention=True:
  - Apply the GATv2 attention mechanism
  - Add a residual connection with dropout
  - Apply normalization
- Apply the feed-forward network
- Add a residual connection with dropout
- Apply final normalization
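The flow above corresponds roughly to the forward pass below. This is a minimal sketch, not the actual implementation: the attribute names (self.conv, self.attn, self.ffn, self.norm1/2/3) are assumptions, self.conv stands for a CKGConv layer, and self.attn for a multi-head GATv2 layer such as torch_geometric.nn.GATv2Conv.
import torch.nn.functional as F

def forward(self, x, x_pe, edge_index, edge_attr, edge_pe, batch=None):
    # Continuous kernel graph convolution, then normalization
    h = self.conv(x, x_pe, edge_index, edge_attr, edge_pe)
    h = self.norm1(h)
    # Residual connection only when input and output widths agree
    if x.size(-1) == h.size(-1):
        h = h + x
    # Optional multi-head GATv2 attention, with dropout on its output
    if self.use_attention:
        a = self.attn(h, edge_index)
        h = h + F.dropout(a, p=self.dropout, training=self.training)
        h = self.norm2(h)
    # Two-layer GELU feed-forward network, with dropout on its output
    f = self.ffn(h)
    h = h + F.dropout(f, p=self.dropout, training=self.training)
    return self.norm3(h)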
Example:
import torch
from fluxnet import FluxNet
# Create a FluxNet layer
model = FluxNet(
node_in_dim=32,
edge_in_dim=16,
pe_dim=8,
out_channels=64,
dropout=0.1,
norm_type='layer'
)
# Input features
num_nodes = 100
num_edges = 500
x = torch.randn(num_nodes, 32)
x_pe = torch.randn(num_nodes, 8)
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_attr = torch.randn(num_edges, 16)
edge_pe = torch.randn(num_edges, 8)
# Forward pass
output = model(x, x_pe, edge_index, edge_attr, edge_pe)
print(output.shape) # [100, 64]
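For batched graphs, supply the optional batch vector so that graph-aware normalization can respect graph boundaries. A sketch continuing the example above, using torch_geometric.data.Batch; the random positional encodings stand in for a real PE transform such as RWSE or Laplacian eigenvectors:
import torch
from torch_geometric.data import Data, Batch

g1 = Data(x=torch.randn(10, 32), edge_index=torch.randint(0, 10, (2, 20)),
          edge_attr=torch.randn(20, 16))
g2 = Data(x=torch.randn(15, 32), edge_index=torch.randint(0, 15, (2, 30)),
          edge_attr=torch.randn(30, 16))
graphs = Batch.from_data_list([g1, g2])
x_pe = torch.randn(graphs.num_nodes, 8)      # stand-in positional encodings
edge_pe = torch.randn(graphs.num_edges, 8)
out = model(graphs.x, x_pe, graphs.edge_index, graphs.edge_attr, edge_pe,
            batch=graphs.batch)
print(out.shape)  # [25, 64]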
Notes:
- When ffn_hidden_dim is not provided, it defaults to 4 * out_channels
- The choice of normalization type significantly impacts performance on different graph datasets
- If use_attention=False, the model skips the GATv2 attention mechanism
- For best performance with large graphs, consider using norm_type='instance' (a stacked example follows below)
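Because each block maps to [num_nodes, out_channels], FluxNet blocks stack naturally; the second block's node_in_dim must equal the first block's out_channels. A sketch of a two-block stack, where the class and attribute names are illustrative:
import torch.nn as nn
from fluxnet import FluxNet

class FluxStack(nn.Module):
    def __init__(self):
        super().__init__()
        # node_in_dim of the second block matches out_channels of the first
        self.block1 = FluxNet(node_in_dim=32, edge_in_dim=16, pe_dim=8,
                              out_channels=64, norm_type='instance')
        self.block2 = FluxNet(node_in_dim=64, edge_in_dim=16, pe_dim=8,
                              out_channels=64, norm_type='instance')

    def forward(self, x, x_pe, edge_index, edge_attr, edge_pe, batch=None):
        h = self.block1(x, x_pe, edge_index, edge_attr, edge_pe, batch)
        return self.block2(h, x_pe, edge_index, edge_attr, edge_pe, batch)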