Usage Examples

This page provides examples of how to use FluxNet in practice.

Basic Example

import torch
from torch_geometric.data import Data

# NOTE: FluxNet must already be imported from the package that provides it;
# that import is not shown on this page.

# Feature sizes shared by the inputs and the layer
node_in_dim = 32
edge_in_dim = 16
pe_dim = 8
out_channels = 64

# Toy graph: 3 nodes, with the pairs 0-1 and 1-2 each stored in both directions.
# Edges are listed as (row 0, row 1) pairs and transposed into the [2, num_edges]
# layout that PyG uses for edge_index.
num_nodes = 3
edge_index = torch.tensor(
    [(0, 1), (1, 0), (1, 2), (2, 1)],
    dtype=torch.long,
).t().contiguous()
num_edges = edge_index.size(1)  # 4

# Random input features and positional encodings
x = torch.randn(num_nodes, node_in_dim)  # one 32-dim feature vector per node
edge_attr = torch.randn(num_edges, edge_in_dim)  # one 16-dim feature vector per edge
x_pe = torch.randn(num_nodes, pe_dim)  # positional encoding for nodes
edge_pe = torch.randn(num_edges, pe_dim)  # positional encoding for edges

# Build the layer
model = FluxNet(
    node_in_dim=node_in_dim,
    edge_in_dim=edge_in_dim,
    pe_dim=pe_dim,
    out_channels=out_channels,
    dropout=0.1,
    norm_type='layer'
)

# Forward pass: one out_channels-sized vector per node
output = model(x, x_pe, edge_index, edge_attr, edge_pe)
print(output.shape)  # Should be [3, 64]

Creating a Complete GNN Model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool

class FluxNetModel(nn.Module):
    """Graph-level model built from a stack of FluxNet layers.

    Node and edge features are projected to ``hidden_dim``, passed through
    ``num_layers`` FluxNet layers, mean-pooled per graph, and mapped to
    ``output_dim`` values by a two-layer MLP head.

    Args:
        node_in_dim: Size of the raw node feature vectors.
        edge_in_dim: Size of the raw edge feature vectors.
        pe_dim: Positional-encoding size handed to every FluxNet layer.
        hidden_dim: Width of the embeddings, the FluxNet layers and the head.
        output_dim: Size of the per-graph output vector.
        num_layers: Number of stacked FluxNet layers.
        dropout: Dropout probability used by every FluxNet layer and by the
            MLP head. Defaults to 0.1, the value that was previously
            hard-coded in both places.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        pe_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Project raw features to the shared hidden width so that every
        # FluxNet layer can use identical input/output sizes.
        self.node_embedding = nn.Linear(node_in_dim, hidden_dim)
        self.edge_embedding = nn.Linear(edge_in_dim, hidden_dim)

        self.conv_layers = nn.ModuleList(
            [
                FluxNet(
                    node_in_dim=hidden_dim,
                    edge_in_dim=hidden_dim,
                    pe_dim=pe_dim,
                    out_channels=hidden_dim,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )

        # Prediction head applied to the pooled, per-graph representation.
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        x_pe: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_pe: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        """Compute one output vector per graph in the batch.

        Args:
            x: Node features whose last dimension is ``node_in_dim``.
            x_pe: Node positional encodings, passed unchanged to every layer
                (presumably of size ``pe_dim`` -- confirm against FluxNet).
            edge_index: Graph connectivity, passed unchanged to every layer.
            edge_attr: Edge features whose last dimension is ``edge_in_dim``.
            edge_pe: Edge positional encodings, passed unchanged to every
                layer (presumably of size ``pe_dim`` -- confirm against
                FluxNet).
            batch: Graph assignment of each node, used for mean pooling.

        Returns:
            The output of the MLP head applied to the mean-pooled node
            representations, with last dimension ``output_dim``.
        """
        # Initial embeddings
        x = self.node_embedding(x)
        edge_attr = self.edge_embedding(edge_attr)

        # Only the node representation is updated between layers; the
        # embedded edge features and both positional encodings are reused.
        # NOTE(review): the Basic Example calls FluxNet without `batch`;
        # this call assumes FluxNet.forward also accepts it as a sixth
        # argument -- confirm against the FluxNet signature.
        for conv in self.conv_layers:
            x = conv(x, x_pe, edge_index, edge_attr, edge_pe, batch)

        # Global pooling: average the node representations of each graph
        x = global_mean_pool(x, batch)

        # Output layer
        return self.output_layer(x)

# Example usage: 32-dim node features, 16-dim edge features and 8-dim
# positional encodings, a hidden width of 64 and 10 outputs per graph.
# num_layers is omitted, so its default of 3 applies.
model = FluxNetModel(
    node_in_dim=32,
    edge_in_dim=16,
    pe_dim=8,
    hidden_dim=64,
    output_dim=10
)

Training Loop Example

import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader

# Assuming you have a dataset of PyG Data objects. `dataset` is not defined
# on this page and must be supplied by you. Batches hold up to 32 graphs and
# the order is reshuffled on every pass over the loader.
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize model. `...` is a placeholder for the FluxNetModel constructor
# arguments (node_in_dim, edge_in_dim, pe_dim, hidden_dim, output_dim);
# replace it with real values before running.
model = FluxNetModel(...)

# Define optimizer: Adam over all model parameters with learning rate 0.001
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train(epoch):
    """Run one pass over ``loader`` and return the mean loss per graph.

    Uses the module-level ``model``, ``loader`` and ``optimizer`` defined
    above.

    Args:
        epoch: Index of the current epoch. It is not used inside the
            function and is kept so existing ``train(epoch)`` calls work.

    Returns:
        The cross-entropy loss averaged over all graphs processed in this
        pass, or 0.0 if the loader yielded no batches.
    """
    model.train()

    # Batches are moved to wherever the model lives, so the loop also works
    # after `model.to("cuda")`; on a CPU model this is a no-op.
    device = next(model.parameters()).device

    total_loss = 0.0
    total_graphs = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass. x_pe and edge_pe are assumed to be stored on the
        # Data objects of your dataset.
        out = model(
            data.x,
            data.x_pe,
            data.edge_index,
            data.edge_attr,
            data.edge_pe,
            data.batch,
        )

        # Calculate loss. The function is reached through `torch` so this
        # snippet does not depend on a separate `F` alias being imported.
        loss = torch.nn.functional.cross_entropy(out, data.y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # cross_entropy returns the batch mean, so weight it by the batch
        # size to accumulate a per-graph sum.
        total_loss += loss.item() * data.num_graphs
        total_graphs += data.num_graphs

    # Divide by the graphs actually seen: this equals len(loader.dataset)
    # for a plain loader and stays correct if the loader drops or
    # subsamples graphs.
    return total_loss / total_graphs if total_graphs else 0.0

# Run training for multiple epochs: 100 passes, numbered 1 to 100, printing
# the value returned by train() after each one
for epoch in range(1, 101):
    loss = train(epoch)
    print(f'Epoch: {epoch}, Loss: {loss:.4f}')