# Source code for gsnn.models.SparseLinear

'''
Batched sparse matrix multiplication that scales better on GPUs.
'''

import torch
import torch_geometric as pyg 


class Conv(pyg.nn.MessagePassing):
    """Sum-aggregating message-passing layer used as one sparse matmul step.

    Each edge contributes ``edge_weight * x_source``; contributions are added
    up per destination node (``aggr='add'``), and an optional bias is added
    to the aggregated column vector.
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index, edge_weight, bias, size):
        """Propagate weighted source features to destinations.

        Args:
            x: Source node features handed to ``propagate``.
            edge_index: Edge index tensor handed to ``propagate``.
            edge_weight: One weight per edge.
            bias: Optional bias, reshaped to a column and added to the result;
                skipped when ``None``.
            size: ``size`` argument forwarded to ``propagate``.

        Returns:
            The aggregated result reshaped to a single column (``view(-1, 1)``),
            plus the bias when one is given.
        """
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)
        column = aggregated.view(-1, 1)

        if bias is None:
            return column
        return column + bias.view(-1, 1)

    def message(self, x_j, edge_weight):
        """Return the per-edge message: edge weight times source feature, both as columns."""
        weights = edge_weight.view(-1, 1)
        sources = x_j.view(-1, 1)
        return weights * sources


def batch_graphs(N, M, edge_index, B, device):
    '''
    Create a batched edge index for ``B`` copies of one bipartite graph.

    Copy ``b`` of the graph has its source indices shifted by ``b * N`` and its
    destination indices shifted by ``b * M``; the copies are concatenated along
    the edge dimension in batch order. The input ``edge_index`` is not modified.

    Args:
        N (int): Size of the first set of nodes in each bipartite graph.
        M (int): Size of the second set of nodes in each bipartite graph.
        edge_index (tensor): ``(2, E)`` edge index tensor to batch.
        B (int): batch size
        device (str or torch.device): device of the returned tensor.

    Returns:
        torch.Tensor: Batched edge index of shape ``(2, B * E)``, dtype int64.
    '''
    # Bring the template onto the target device as int64 first: adding the
    # int64 offsets in place to e.g. an int32 index raises, and so does mixing
    # an index and offsets that live on different devices.
    edge_index = edge_index.to(device=device, dtype=torch.long)
    num_edges = edge_index.size(1)

    graph_ids = torch.arange(B, dtype=torch.long, device=device)
    # Row 0 holds per-graph source offsets, row 1 per-graph destination offsets.
    offsets = torch.stack((graph_ids * N, graph_ids * M), dim=0)  # (2, B)

    # Broadcast (2, 1, E) + (2, B, 1) -> (2, B, E); flattening the last two
    # dims yields graph 0's edges, then graph 1's edges, and so on.
    batched = edge_index.unsqueeze(1) + offsets.unsqueeze(2)
    return batched.reshape(2, B * num_edges)

def xavier_uniform(size, fan_in, fan_out, gain=1, dtype=torch.float32):
    """Xavier/Glorot uniform init: U(-1, 1) samples scaled by ``gain * sqrt(6 / (fan_in + fan_out))``.

    Args:
        size: Shape passed to ``torch.empty`` (int or tuple).
        fan_in: Fan-in, either a plain number or a tensor that broadcasts
            against ``size`` (e.g. one value per element).
        fan_out: Fan-out, same conventions as ``fan_in``.
        gain: Extra multiplicative factor applied to the bound.
        dtype: dtype of the returned tensor.

    Returns:
        torch.Tensor: Scaled samples with dtype ``dtype``.
    """
    ratio = 6 / (fan_in + fan_out)
    # torch.sqrt only accepts tensors; wrap plain numbers so scalar fans work too.
    if not torch.is_tensor(ratio):
        ratio = torch.tensor(float(ratio))
    bound = gain * torch.sqrt(ratio)

    out = torch.nn.init.uniform_(torch.empty(size, dtype=dtype), a=-1, b=1)
    # Align the bound with the sample's device, and cast back so that a
    # float32 bound cannot promote a lower-precision sample.
    return (out * bound.to(out.device)).to(dtype)

def xavier_normal(size, fan_in, fan_out, gain=1, dtype=torch.float32):
    """Xavier/Glorot normal init: N(0, 1) samples scaled by ``gain * sqrt(2 / (fan_in + fan_out))``.

    Args:
        size: Shape passed to ``torch.empty`` (int or tuple).
        fan_in: Fan-in, either a plain number or a tensor that broadcasts
            against ``size`` (e.g. one value per element).
        fan_out: Fan-out, same conventions as ``fan_in``.
        gain: Extra multiplicative factor applied to the standard deviation.
        dtype: dtype of the returned tensor.

    Returns:
        torch.Tensor: Scaled samples with dtype ``dtype``.
    """
    ratio = 2 / (fan_in + fan_out)
    # torch.sqrt only accepts tensors; wrap plain numbers so scalar fans work too.
    if not torch.is_tensor(ratio):
        ratio = torch.tensor(float(ratio))
    std = gain * torch.sqrt(ratio)

    out = torch.nn.init.normal_(torch.empty(size, dtype=dtype), mean=0, std=1)
    # Align the scale with the sample's device, and cast back so that a
    # float32 scale cannot promote a lower-precision sample.
    return (out * std.to(out.device)).to(dtype)

def uniform(size, gain=1., dtype=torch.float32):
    """Draw U(-1, 1) samples of shape ``size`` and multiply them by ``gain``.

    Args:
        size: Shape passed to ``torch.empty`` (int or tuple).
        gain: Multiplicative factor applied to the samples.
        dtype: dtype of the sample tensor.

    Returns:
        torch.Tensor: The scaled samples (a new tensor).
    """
    samples = torch.nn.init.uniform_(torch.empty(size, dtype=dtype), a=-1, b=1)
    return samples * gain

def normal(size, gain=1, dtype=torch.float32):
    """Draw N(0, 1) samples of shape ``size`` and multiply them by ``gain``.

    Args:
        size: Shape passed to ``torch.empty`` (int or tuple).
        gain: Multiplicative factor applied to the samples.
        dtype: dtype of the sample tensor.

    Returns:
        torch.Tensor: The scaled samples (a new tensor).
    """
    samples = torch.nn.init.normal_(torch.empty(size, dtype=dtype), mean=0, std=1)
    return samples * gain

def kaiming_uniform(size, fan_in, fan_out, fan_mode='fan_in', gain=1, dtype=torch.float32):
    """Kaiming/He uniform init: U(-1, 1) samples scaled by ``gain * sqrt(3 / fan)``.

    Args:
        size: Shape passed to ``torch.empty`` (int or tuple).
        fan_in: Fan-in, either a plain number or a tensor that broadcasts
            against ``size``.
        fan_out: Fan-out, same conventions as ``fan_in``.
        fan_mode: ``'fan_in'`` selects ``fan_in``; any other value selects
            ``fan_out``.
        gain: Extra multiplicative factor applied to the bound.
        dtype: dtype of the returned tensor.

    Returns:
        torch.Tensor: Scaled samples with dtype ``dtype``.
    """
    fan_val = fan_in if fan_mode == 'fan_in' else fan_out
    ratio = 3 / fan_val
    # torch.sqrt only accepts tensors; wrap plain numbers so scalar fans work too.
    if not torch.is_tensor(ratio):
        ratio = torch.tensor(float(ratio))
    bound = gain * torch.sqrt(ratio)

    out = torch.nn.init.uniform_(torch.empty(size, dtype=dtype), a=-1, b=1)
    # Align the bound with the sample's device, and cast back so that a
    # float32 bound cannot promote a lower-precision sample.
    return (out * bound.to(out.device)).to(dtype)

def kaiming_normal(size, fan_in, fan_out, fan_mode='fan_in', gain=1, dtype=torch.float32):
    """Kaiming/He normal init: N(0, 1) samples scaled by ``gain / sqrt(fan)``.

    Args:
        size: Shape passed to ``torch.empty`` (int or tuple).
        fan_in: Fan-in, either a plain number or a tensor that broadcasts
            against ``size``.
        fan_out: Fan-out, same conventions as ``fan_in``.
        fan_mode: ``'fan_in'`` selects ``fan_in``; any other value selects
            ``fan_out``.
        gain: Extra multiplicative factor applied to the standard deviation.
        dtype: dtype of the returned tensor.

    Returns:
        torch.Tensor: Scaled samples with dtype ``dtype``.
    """
    fan_val = fan_in if fan_mode == 'fan_in' else fan_out
    # torch.sqrt only accepts tensors; wrap plain numbers so scalar fans work too.
    if not torch.is_tensor(fan_val):
        fan_val = torch.tensor(float(fan_val))
    std = gain / torch.sqrt(fan_val)

    out = torch.nn.init.normal_(torch.empty(size, dtype=dtype), mean=0, std=1)
    # Align the scale with the sample's device, and cast back so that a
    # float32 scale cannot promote a lower-precision sample.
    return (out * std.to(out.device)).to(dtype)


class SparseLinear(torch.nn.Module):
    """Fixed sparsity pattern linear layer; forward is batched message passing on the COO indices."""

    def __init__(self, indices, size, dtype=torch.float32, bias=True, init='kaiming', init_gain=1., degree_norm_eps=1e-8):
        """
        Args:
            indices: COO ``(src, dst)`` for nonzeros of the weight matrix.
            size: ``(N, M)`` shape of the dense view (first dim source / rows, second dest / cols).
            dtype: Parameter dtype.
            bias: If True, learn a bias per destination node.
            init: Weight init scheme name. One of ``xavier_uniform``,
                ``xavier_normal``, ``kaiming_uniform``, ``kaiming_normal``,
                ``uniform``, ``normal``, ``degree_normalized``, ``zeros``;
                ``kaiming`` (the default) is treated as ``kaiming_normal``.
            init_gain: Scalar gain for init.
            degree_norm_eps: Epsilon for ``degree_normalized`` init.

        Raises:
            ValueError: If ``init`` is not a recognized scheme name.
        """
        super().__init__()

        self.N, self.M = size
        self.size = size
        self.conv = Conv()

        indices = indices.type(torch.long)
        src, dst = indices
        num_edges = indices.size(1)

        # weight initialization
        fan_in = pyg.utils.degree(dst, num_nodes=self.M)
        fan_out = pyg.utils.degree(src, num_nodes=self.N)
        n_in = fan_in[dst]    # per edge: in-degree of the edge's destination node
        n_out = fan_out[src]  # per edge: out-degree of the edge's source node

        if init == 'kaiming':
            # NOTE(review): the default name 'kaiming' matched no branch below,
            # so constructing the layer with default arguments always raised
            # ValueError. It is mapped to 'kaiming_normal' here -- confirm that
            # the normal (rather than uniform) variant is the intended default.
            init = 'kaiming_normal'

        if init == 'xavier_uniform':
            values = xavier_uniform(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'xavier_normal':
            values = xavier_normal(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'kaiming_uniform':
            values = kaiming_uniform(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'kaiming_normal':
            values = kaiming_normal(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'uniform':
            values = uniform(num_edges, gain=init_gain, dtype=dtype)
        elif init == 'normal':
            values = normal(num_edges, gain=init_gain, dtype=dtype)
        elif init == 'degree_normalized':
            # Initialize with uniform distribution, then apply degree-based normalization
            # This implements D^(-0.5)AD^(-0.5) style normalization from GCNs
            values = uniform(num_edges, gain=init_gain, dtype=dtype)
            # Compute degree normalization factor: 1/sqrt(deg_i * deg_j) for edge (i,j)
            degree_norm = 1.0 / torch.sqrt((n_in + degree_norm_eps) * (n_out + degree_norm_eps))
            # The degrees live on the indices' device; bring the factor to the samples.
            values = values * degree_norm.to(values.device)
        elif init == 'zeros':
            # Shape (E, 1) here versus (E,) for the other schemes; forward()
            # flattens the values, so both layouts are handled.
            values = torch.zeros((num_edges, 1), dtype=dtype)
        else:
            raise ValueError('unrecognized weight initialization method, options: xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal, uniform, normal, degree_normalized, zeros')

        self.values = torch.nn.Parameter(values)  # torch optimizer require dense parameters
        self.register_buffer('indices', indices)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros((self.M, 1), dtype=dtype))

        # caching (these attributes are set here but not read anywhere in this class)
        self._B = 0
        self._edge_index = None

    def forward(self, x, batched_indices=None):
        """
        Batch dimension is handled in `torch_geometric` fashion, e.g.,
        concatenated batch graphs via incremented node idx.

        Args:
            x: input (B, N, 1)
            batched_indices: Optional precomputed batched edge index; when
                ``None`` it is built with ``batch_graphs`` for the current
                batch size.

        Returns:
            Tensor (B, M, 1)
        """
        device = x.device
        B = x.size(0)

        # One copy of the edge weights per graph in the batch, flattened in batch order.
        edge_weight = self.values.expand(B, *self.values.shape).reshape(-1)

        if batched_indices is None:
            batched_indices = batch_graphs(N=self.N, M=self.M, edge_index=self.indices, B=B, device=device)

        if hasattr(self, 'bias'):
            # Tile the (M, 1) bias once per graph -> (B*M, 1).
            bias = self.bias.repeat(B, 1)
        else:
            bias = None

        x = x.view(-1, 1)
        x = self.conv(x, batched_indices, edge_weight, bias, size=(self.N*B, self.M*B))
        x = x.view(B, -1, 1)

        return x

    def prune(self, idxs):
        """Keep only edges indexed by ``idxs``; drops corresponding weights and index columns.

        ``values`` is replaced by a new ``Parameter`` object and the
        ``indices`` buffer is re-registered with the selected columns.
        """
        self.values = torch.nn.Parameter(self.values[idxs])
        self.register_buffer('indices', self.indices[:, idxs])