gsnn.models.SparseLinear

class gsnn.models.SparseLinear(*args: Any, **kwargs: Any)

Bases: Module

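Linear layer with a fixed sparsity pattern: only the weight-matrix entries listed in the COO indices carry weights, and the forward pass is computed as batched message passing over those indices.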
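For a single sample, message passing over COO indices can be sketched in plain PyTorch as below. This is an illustration, not gsnn's code: it assumes indices is a (2, E) long tensor of (src, dst) pairs with one weight per nonzero, and sparse_linear_forward is a hypothetical helper. The result matches a dense x @ W in which W is zero outside the sparsity pattern.

```python
import torch

def sparse_linear_forward(x, indices, weight, size, bias=None):
    """Hypothetical helper: one sample, x has shape (N,), returns shape (M,)."""
    src, dst = indices          # indices: (2, E) long tensor of (src, dst) pairs
    _, m = size                 # size = (N, M): N source rows, M destination cols
    messages = weight * x[src]  # each edge carries its weighted source value
    # Sum the messages arriving at each destination node.
    out = torch.zeros(m, dtype=x.dtype).index_add_(0, dst, messages)
    if bias is not None:        # one bias per destination node
        out = out + bias
    return out

indices = torch.tensor([[0, 0, 2],    # src
                        [0, 1, 1]])   # dst
weight = torch.tensor([1.0, 2.0, 3.0])
x = torch.tensor([1.0, 10.0, 100.0])
y = sparse_linear_forward(x, indices, weight, size=(3, 2))
# dst 0 receives 1*1 = 1; dst 1 receives 2*1 + 3*100 = 302

# Same result as a dense x @ W where W is zero outside the sparsity pattern.
W = torch.zeros(3, 2)
W[indices[0], indices[1]] = weight
assert torch.allclose(y, x @ W)
```

Only E weights are stored and touched, so the cost scales with the number of nonzeros rather than with N * M.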

__init__(indices, size, dtype=torch.float32, bias=True, init='kaiming', init_gain=1.0, degree_norm_eps=1e-08)
Parameters:
  • indices – COO indices (src, dst) of the nonzero entries of the weight matrix.

  • size – (N, M) shape of the dense view of the weight matrix: the first dimension indexes source nodes (rows), the second destination nodes (columns).

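  • dtype – dtype of the layer's parameters.

  • bias – If True, learn one bias per destination node.

  • init – Name of the weight initialization scheme (default 'kaiming'; see the implementation for the other options).

  • init_gain – Scalar gain applied by the initialization scheme.

  • degree_norm_eps – Epsilon used by the 'degree_normalized' initialization scheme.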
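To make the (N, M) convention concrete, here is a sketch with made-up values that materializes the dense view of a small pattern using plain PyTorch (torch.sparse_coo_tensor), independent of gsnn. The per-destination fan-in computed at the end is the kind of quantity a degree-based init could plausibly depend on; how the 'kaiming' and 'degree_normalized' schemes actually use it is an assumption, so check the implementation.

```python
import torch

# Hypothetical pattern: N = 3 source nodes (rows), M = 2 destination nodes (cols).
indices = torch.tensor([[0, 0, 2],    # src (row) of each nonzero
                        [0, 1, 1]])   # dst (col) of each nonzero
values = torch.tensor([1.0, 2.0, 3.0])  # one weight per nonzero
N, M = 3, 2

# Dense (N, M) view: entries outside the pattern are structural zeros.
W = torch.sparse_coo_tensor(indices, values, (N, M)).to_dense()
# W == [[1., 2.],
#       [0., 0.],
#       [0., 3.]]

# Fan-in (in-degree) of each destination node: number of nonzeros per column.
# Assumption: a degree-based init would scale weights using a count like this.
fan_in = torch.bincount(indices[1], minlength=M)
# fan_in == [1, 2]
```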

Methods

  • __init__(indices, size[, dtype, bias, init, ...]) – Construct the layer; indices are the COO (src, dst) pairs of the nonzero entries of the weight matrix.

  • forward(x[, batched_indices]) – Apply the layer. The batch dimension is handled in torch_geometric fashion: the graphs of a batch are concatenated into one graph, with node indices incremented for each sample.

  • prune(idxs) – Keep only the edges indexed by idxs; the weights and index columns of all other edges are dropped.

forward(x, batched_indices=None)

The batch dimension is handled in torch_geometric fashion: the graphs of a batch are concatenated into one graph, with node indices incremented for each sample.
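The torch_geometric-style batching can be sketched in plain PyTorch as follows. This is not gsnn's code: it assumes x has shape (B, N, 1) (only the (B, M, 1) output shape is documented) and that one weight vector with an entry per edge is shared by every sample; batched_sparse_linear is a hypothetical name. Whether the layer's batched_indices argument holds exactly the offset (src, dst) pairs built here is also an assumption.

```python
import torch

def batched_sparse_linear(x, indices, weight, size):
    """Hypothetical helper: x of shape (B, N, 1) -> output of shape (B, M, 1)."""
    B = x.size(0)
    N, M = size
    E = indices.size(1)
    # Sample id of every edge in the concatenated batch graph: [0]*E, [1]*E, ...
    sample = torch.arange(B).repeat_interleave(E)
    # Copy the pattern B times, offsetting node ids so samples do not overlap.
    src = indices[0].repeat(B) + sample * N
    dst = indices[1].repeat(B) + sample * M
    w = weight.repeat(B)        # the same weights are shared by all samples
    flat_x = x.reshape(B * N)   # nodes of all samples, concatenated
    out = torch.zeros(B * M, dtype=x.dtype).index_add_(0, dst, w * flat_x[src])
    return out.reshape(B, M, 1)

indices = torch.tensor([[0, 0, 2], [0, 1, 1]])
weight = torch.tensor([1.0, 2.0, 3.0])
x = torch.tensor([[1.0, 10.0, 100.0],
                  [2.0, 0.0, 1.0]]).unsqueeze(-1)        # (B=2, N=3, 1)
y = batched_sparse_linear(x, indices, weight, size=(3, 2))  # (2, 2, 1)
```

Offsetting sources by N and destinations by M turns B copies of the pattern into one disjoint graph, so a single index_add_ call covers the whole batch without a Python loop over samples.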

Parameters:
  • x – Input tensor.

  • batched_indices – Optional COO indices of the concatenated batch graph (default None).

Returns:

Output tensor of shape (B, M, 1), where B is the batch size and M the number of destination nodes.

prune(idxs)

Keep only the edges indexed by idxs; the weights and index columns of all other edges are dropped.
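The bookkeeping that prune(idxs) describes can be sketched on plain tensors: the index columns and the weights of the kept edges are selected with the same idxs, so they stay aligned. This assumes idxs is a 1-D tensor of edge positions; prune_edges is a hypothetical helper, and the magnitude threshold used to choose idxs is only an illustrative criterion, not gsnn's.

```python
import torch

def prune_edges(indices, weight, idxs):
    """Hypothetical helper mirroring prune(idxs): keep only the edges listed in idxs."""
    return indices[:, idxs], weight[idxs]

indices = torch.tensor([[0, 0, 2],
                        [0, 1, 1]])
weight = torch.tensor([0.5, -0.01, 2.0])

# Illustrative criterion (an assumption): keep edges with |w| > 0.1.
idxs = (weight.abs() > 0.1).nonzero(as_tuple=True)[0]   # tensor([0, 2])
indices, weight = prune_edges(indices, weight, idxs)
# indices == [[0, 2], [0, 1]], weight == [0.5, 2.0]
```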