gsnn.models.SparseLinear
- class gsnn.models.SparseLinear(*args: Any, **kwargs: Any)
Bases: Module
Fixed-sparsity-pattern linear layer; the forward pass is batched message passing on the COO indices.
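The message-passing view of the forward pass can be sketched as follows. This is a minimal, dependency-free illustration rather than the library's implementation: it assumes one scalar weight per COO edge and one scalar feature per node, and the helper names (`sparse_forward`, `dense_forward`) are hypothetical. It checks that accumulating `weight[e] * x[src[e]]` into `y[dst[e]]` matches multiplying by the dense `(N, M)` view.

```python
def sparse_forward(x, src, dst, weight, num_dst, bias=None):
    """Message passing: y[d] = bias[d] + sum of weight[e] * x[s] over edges s -> d."""
    y = [0.0] * num_dst if bias is None else list(bias)
    for s, d, w in zip(src, dst, weight):
        y[d] += w * x[s]
    return y


def dense_forward(x, src, dst, weight, size, bias=None):
    """Reference: scatter the edge weights into a dense (N, M) matrix, then y = x @ W."""
    n, m = size
    W = [[0.0] * m for _ in range(n)]
    for s, d, w in zip(src, dst, weight):
        W[s][d] = w  # assumes no duplicate (src, dst) pairs
    y = [sum(x[i] * W[i][j] for i in range(n)) for j in range(m)]
    if bias is not None:
        y = [yj + bj for yj, bj in zip(y, bias)]
    return y


# Toy layer: N = 3 source nodes, M = 2 destination nodes, 4 nonzero weights.
src = [0, 1, 2, 0]
dst = [0, 0, 1, 1]
weight = [0.5, -1.0, 2.0, 3.0]
bias = [0.25, -1.0]  # one bias per destination node
x = [1.0, 2.0, 3.0]

y_sparse = sparse_forward(x, src, dst, weight, num_dst=2, bias=bias)
y_dense = dense_forward(x, src, dst, weight, size=(3, 2), bias=bias)
```

Only the four stored weights are touched in the sparse version, which is the point of fixing the sparsity pattern up front.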
- __init__(indices, size, dtype=torch.float32, bias=True, init='kaiming', init_gain=1.0, degree_norm_eps=1e-08)
- Parameters:
indices – COO (src, dst) for nonzeros of the weight matrix.
size – (N, M) shape of the dense view (first dim source / rows, second dest / cols).
dtype – Parameter dtype.
bias – If True, learn a bias per destination node.
init – Weight init scheme name (see implementation for options).
init_gain – Scalar gain for init.
degree_norm_eps – Epsilon for degree_normalized init.
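As an illustration of the `indices` and `size` arguments, the sketch below derives COO `(src, dst)` pairs and the `(N, M)` shape from a dense 0/1 connectivity mask. The helper `mask_to_coo` is hypothetical and written in plain Python; the exact tensor layout the constructor expects (for example a `2 x E` integer tensor) is an assumption and should be checked against the implementation.

```python
def mask_to_coo(mask):
    """Return ((src, dst), (N, M)) for the nonzero entries of a dense 0/1 mask.

    Rows index source nodes and columns index destination nodes, matching the
    (N, M) convention described for `size`.
    """
    n, m = len(mask), len(mask[0])
    src, dst = [], []
    for i, row in enumerate(mask):
        for j, keep in enumerate(row):
            if keep:
                src.append(i)
                dst.append(j)
    return (src, dst), (n, m)


# 3 source nodes, 2 destination nodes, 4 allowed connections.
mask = [
    [1, 0],
    [1, 1],
    [0, 1],
]
indices, size = mask_to_coo(mask)
```

The resulting `indices` and `size` would then be converted to tensors and passed to the constructor along with the remaining keyword arguments.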
Methods
__init__(indices, size[, dtype, bias, init, ...]) – indices: COO (src, dst) for nonzeros of the weight matrix.
forward(x[, batched_indices]) – The batch dimension is handled in torch_geometric fashion, i.e., the graphs in a batch are concatenated and their node indices incremented.
prune(idxs) – Keep only the edges indexed by idxs; drops the corresponding weights and index columns.
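The batching and pruning behaviour summarized above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: it assumes that sample `b` in a batch has its source indices shifted by `b * N` and its destination indices by `b * M`, and that pruning selects edges by position in the COO arrays. The names `batch_indices`, `forward_flat`, and `prune_edges` are hypothetical.

```python
def batch_indices(src, dst, size, batch_size):
    """Repeat the edge list once per sample, offsetting node indices per sample."""
    n, m = size
    b_src = [s + b * n for b in range(batch_size) for s in src]
    b_dst = [d + b * m for b in range(batch_size) for d in dst]
    return b_src, b_dst


def forward_flat(x_flat, src, dst, weight, num_out):
    """Message passing over a flat (concatenated) node vector, without bias."""
    y = [0.0] * num_out
    for s, d, w in zip(src, dst, weight):
        y[d] += w * x_flat[s]
    return y


def prune_edges(src, dst, weight, idxs):
    """Keep only the edges at positions `idxs`, dropping their weights and indices."""
    return (
        [src[i] for i in idxs],
        [dst[i] for i in idxs],
        [weight[i] for i in idxs],
    )


src, dst, weight, size = [0, 1, 2, 0], [0, 0, 1, 1], [0.5, -1.0, 2.0, 3.0], (3, 2)

# Two samples, concatenated into one flat vector of 2 * N source values.
x_flat = [1.0, 2.0, 3.0] + [0.0, 1.0, 0.0]
b_src, b_dst = batch_indices(src, dst, size, batch_size=2)
y_flat = forward_flat(x_flat, b_src, b_dst, weight * 2, num_out=2 * size[1])

# Keep edges 0 and 3 only.
p_src, p_dst, p_weight = prune_edges(src, dst, weight, idxs=[0, 3])
```

Here `y_flat` holds the `M` outputs of the first sample followed by the `M` outputs of the second, so it can be reshaped to `(batch_size, M)` afterwards.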