'''
Batched sparse matrix multiplication that scales better on GPUs.
'''
import torch
import torch_geometric as pyg
class Conv(pyg.nn.MessagePassing):
    '''
    Message-passing layer with sum aggregation used to apply a sparse weight matrix.

    Every edge contributes ``edge_weight * x_j`` and the contributions are
    summed per receiving node (``aggr='add'``), after which an optional bias
    is added.
    '''

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index, edge_weight, bias, size):
        '''
        Propagate ``x`` along ``edge_index`` and add the optional bias.

        Args:
            x               node features; flattened to a column by ``message``
            edge_index      edge index tensor handed to ``propagate``
            edge_weight     one weight per edge
            bias            tensor added to the aggregated output, or None to skip it
            size            (num_source_nodes, num_target_nodes) forwarded to ``propagate``
        Returns:
            Tensor of shape (-1, 1) holding the aggregated (and biased) output.
        '''
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)
        out = out.view(-1, 1)
        if bias is not None:
            out = out + bias.view(-1, 1)
        return out

    def message(self, x_j, edge_weight):
        '''
        Per-edge message: the edge weight times the feature gathered for that edge.

        Both operands are viewed as columns so the product has shape (num_edges, 1).
        '''
        return edge_weight.view(-1, 1) * x_j.view(-1, 1)
def batch_graphs(N, M, edge_index, B, device=None):
    '''
    Create a batched edge_index tensor for ``B`` copies of one bipartite graph.

    The edge list is tiled ``B`` times; in the b-th copy the first row is
    shifted by ``b * N`` and the second row by ``b * M`` so that every copy
    addresses its own disjoint range of node ids.

    Args:
        N (int): Size of the first set of nodes in each bipartite graph.
        M (int): Size of the second set of nodes in each bipartite graph.
        edge_index (tensor): edge index tensor to batch, two rows by E edges.
        B (int): batch size.
        device (str or torch.device, optional): device on which the offsets are
            built; defaults to ``edge_index.device``.
    Returns:
        torch.Tensor: Batched edge index of shape (2, B * E). The input tensor
        is not modified.
    '''
    if device is None:
        device = edge_index.device

    E = edge_index.size(1)
    graph_ids = torch.arange(B, dtype=torch.long, device=device)

    # (2, B): per-copy node-id offsets for the first / second node set.
    offsets = torch.stack((graph_ids * N, graph_ids * M), dim=0)
    # Spread each copy's offset over its E edges -> (2, B * E), matching the
    # copy-major layout produced by ``repeat`` below.
    offsets = offsets.repeat_interleave(E, dim=1)

    # Out-of-place add: an in-place ``+=`` raises when edge_index is not int64
    # because the int64 offsets cannot be cast down into it.
    return edge_index.repeat(1, B) + offsets
def xavier_normal(size, fan_in, fan_out, gain=1, dtype=torch.float32):
    '''
    Draw standard-normal samples scaled by ``gain * sqrt(2 / (fan_in + fan_out))``.

    Args:
        size        shape of the returned tensor (int or tuple)
        fan_in      number or tensor; a tensor must broadcast against ``size``
        fan_out     number or tensor; a tensor must broadcast against ``size``
        gain        multiplicative factor applied to the standard deviation
        dtype       dtype of the returned tensor
    Returns:
        Tensor of shape ``size`` and dtype ``dtype``.
    '''
    # as_tensor lets plain Python numbers through torch.sqrt, which only
    # accepts tensors; tensors pass through unchanged.
    std = gain * torch.sqrt(torch.as_tensor(2.0 / (fan_in + fan_out)))
    out = torch.empty(size, dtype=dtype)
    out = torch.nn.init.normal_(out, mean=0, std=1)
    # Align the scale with the sample so a fan tensor living on another device
    # (or in another float dtype) neither raises nor changes the requested dtype.
    std = std.to(device=out.device, dtype=out.dtype)
    return out * std
def normal(size, gain=1, dtype=torch.float32):
    '''
    Draw standard-normal samples and multiply them by ``gain``.

    Args:
        size        shape of the returned tensor (int or tuple)
        gain        multiplicative factor, i.e. the resulting standard deviation
        dtype       dtype of the sampled tensor
    Returns:
        Tensor of shape ``size``.
    '''
    samples = torch.empty(size, dtype=dtype)
    samples = torch.nn.init.normal_(samples, mean=0, std=1)
    return samples * gain
def kaiming_normal(size, fan_in, fan_out, fan_mode='fan_in', gain=1, dtype=torch.float32):
    '''
    Draw standard-normal samples scaled by ``gain / sqrt(fan)``.

    Args:
        size        shape of the returned tensor (int or tuple)
        fan_in      number or tensor; used when ``fan_mode == 'fan_in'``
        fan_out     number or tensor; used for any other ``fan_mode`` value
        fan_mode    'fan_in' selects ``fan_in``; anything else selects ``fan_out``
        gain        multiplicative factor applied to the standard deviation
        dtype       dtype of the returned tensor
    Returns:
        Tensor of shape ``size`` and dtype ``dtype``.
    '''
    fan_val = fan_in if fan_mode == 'fan_in' else fan_out
    # as_tensor lets plain Python numbers through torch.sqrt, which only
    # accepts tensors; tensors pass through unchanged.
    std = gain / torch.sqrt(torch.as_tensor(fan_val))
    out = torch.empty(size, dtype=dtype)
    out = torch.nn.init.normal_(out, mean=0, std=1)
    # Align the scale with the sample so a fan tensor living on another device
    # (or in another float dtype) neither raises nor changes the requested dtype.
    std = std.to(device=out.device, dtype=out.dtype)
    return out * std
class SparseLinear(torch.nn.Module):
    '''
    Linear layer whose weight matrix is sparse, with one learnable value per
    COO coordinate in ``indices``. The product is evaluated by ``Conv`` on a
    batched bipartite graph built with ``batch_graphs``.
    '''

    def __init__(self, indices, size, dtype=torch.float32, bias=True, init='kaiming', init_gain=1., degree_norm_eps=1e-8):
        '''
        Sparse Linear layer, equivalent to sparse matrix multiplication as provided by indices.
        Args:
            indices             COO coordinates for the sparse matrix multiplication; row 0 is
                                the input (source) index, row 1 the output (destination) index
            size                size of weight matrix, (N inputs, M outputs)
            dtype               weight matrix type
            bias                whether to include a bias term; Wx + B
            init                weight initialization strategy: xavier_uniform, xavier_normal,
                                kaiming_uniform, kaiming_normal, uniform, normal,
                                degree_normalized or zeros. 'kaiming' (the default) is an
                                alias for 'kaiming_normal'.
            init_gain           gain factor for initialization
            degree_norm_eps     small epsilon to prevent division by zero in degree normalization
        Raises:
            ValueError: if ``init`` is not one of the recognized strategies.
        '''
        super().__init__()
        self.N, self.M = size
        self.size = size
        self.conv = Conv()

        indices = indices.type(torch.long)
        src, dst = indices
        num_edges = indices.size(1)

        # weight initialization
        fan_in = pyg.utils.degree(dst, num_nodes=self.M)
        fan_out = pyg.utils.degree(src, num_nodes=self.N)
        n_in = fan_in[dst]      # per edge: in-degree of the edge's destination node
        n_out = fan_out[src]    # per edge: out-degree of the edge's source node

        # The default value 'kaiming' previously matched no branch below, so a
        # default-constructed layer always raised ValueError.
        if init == 'kaiming':
            init = 'kaiming_normal'

        # NOTE(review): xavier_uniform, kaiming_uniform and uniform are not defined
        # in the visible part of this module - confirm they exist before relying on
        # the corresponding options. The calls stay inside their branches so the
        # names are only looked up when that option is requested.
        if init == 'xavier_uniform':
            values = xavier_uniform(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'xavier_normal':
            values = xavier_normal(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'kaiming_uniform':
            values = kaiming_uniform(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'kaiming_normal':
            values = kaiming_normal(num_edges, n_in, n_out, gain=init_gain, dtype=dtype)
        elif init == 'uniform':
            values = uniform(num_edges, gain=init_gain, dtype=dtype)
        elif init == 'normal':
            values = normal(num_edges, gain=init_gain, dtype=dtype)
        elif init == 'degree_normalized':
            # Initialize with uniform distribution, then apply degree-based normalization
            # This implements D^(-0.5)AD^(-0.5) style normalization from GCNs
            values = uniform(num_edges, gain=init_gain, dtype=dtype)
            # Compute degree normalization factor: 1/sqrt(deg_i * deg_j) for edge (i,j)
            degree_norm = 1.0 / torch.sqrt((n_in + degree_norm_eps) * (n_out + degree_norm_eps))
            # The degrees live on the device of ``indices``; move the factor to the
            # freshly sampled values so the product does not fail across devices.
            values = values * degree_norm.to(values.device)
        elif init == 'zeros':
            values = torch.zeros((num_edges, 1), dtype=dtype)
        else:
            raise ValueError('unrecognized weight initialization method, options: xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal, uniform, normal, degree_normalized, zeros')

        self.values = torch.nn.Parameter(values)  # torch optimizer require dense parameters
        self.register_buffer('indices', indices)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros((self.M, 1), dtype=dtype))

        # caching (these attributes are initialized here but not read anywhere in this class)
        self._B = 0
        self._edge_index = None

    def forward(self, x, batched_indices=None):
        '''
        batch dimension is handled in `torch_geometric` fashion, e.g., concatenated batch graphs via incremented node idx
        Args:
            x                   input (B, N, 1)
            batched_indices     optional precomputed batched edge index, as returned by
                                ``batch_graphs`` for the same batch size; built on the fly
                                when None
        Returns:
            Tensor (B, M, 1)
        '''
        device = x.device
        B = x.size(0)

        # One copy of the edge values per batch element, flattened copy-major so it
        # lines up with the edge order produced by ``batch_graphs``.
        edge_weight = self.values.expand(B, *self.values.shape).reshape(-1)

        if batched_indices is None:
            batched_indices = batch_graphs(N=self.N, M=self.M, edge_index=self.indices, B=B, device=device)

        # ``bias`` is only set as an attribute when the layer was built with bias=True.
        bias = getattr(self, 'bias', None)
        if bias is not None:
            # Tile the (M, 1) bias once per batch element -> (B * M, 1).
            bias = bias.repeat(B, 1)

        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 1)
        x = self.conv(x, batched_indices, edge_weight, bias, size=(self.N * B, self.M * B))
        return x.view(B, -1, 1)

    def prune(self, idxs):
        '''
        Keep only the edges selected by ``idxs``.

        ``values`` is replaced by a new Parameter holding ``values[idxs]`` and the
        ``indices`` buffer by ``indices[:, idxs]``.

        NOTE(review): because ``values`` becomes a new Parameter object, an optimizer
        created before pruning presumably still references the old one - confirm
        callers rebuild it.

        Args:
            idxs    index or mask selecting the edges to keep
        '''
        self.values = torch.nn.Parameter(self.values[idxs])
        self.register_buffer('indices', self.indices[:, idxs])