gsnn.models.SparseLinear

Batched sparse matrix multiplication that scales better on GPUs.
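The core operation can be pictured as a matrix product in which only the nonzero weights are stored, one weight per edge of a bipartite input-to-output graph. The sketch below illustrates the idea and is not the module's implementation. It assumes that `edge_index[0]` holds input indices and `edge_index[1]` holds output indices, following the PyTorch Geometric source-to-target convention, and the function name `sparse_linear` is hypothetical. It sums the weighted per-edge messages into their output nodes with `torch.Tensor.index_add_`, then checks the result against the equivalent dense product.

```python
import torch

def sparse_linear(x, edge_index, edge_weight, bias, M):
    """Hypothetical sketch: y[j] = bias[j] + sum over edges (i -> j) of w_e * x[i].

    Assumes edge_index[0] holds input-node indices in [0, N) and
    edge_index[1] holds output-node indices in [0, M).
    """
    src, dst = edge_index
    messages = edge_weight * x[src]               # one weighted message per edge
    out = torch.zeros(M, dtype=x.dtype, device=x.device)
    out.index_add_(0, dst, messages)              # sum messages into their output node
    return out + bias

# 3 inputs, 2 outputs, 4 nonzero weights
edge_index = torch.tensor([[0, 1, 1, 2],
                           [0, 0, 1, 1]])
edge_weight = torch.tensor([0.5, -1.0, 2.0, 3.0])
bias = torch.tensor([0.1, 0.2])
x = torch.tensor([1.0, 2.0, 3.0])
y = sparse_linear(x, edge_index, edge_weight, bias, M=2)

# Same result as the equivalent dense (M, N) weight matrix
W = torch.zeros(2, 3)
W[edge_index[1], edge_index[0]] = edge_weight
assert torch.allclose(y, W @ x + bias)
```

Only E weights are stored instead of M * N, and the work scales with the number of edges rather than the size of the dense matrix.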

Functions

batch_graphs(N, M, edge_index, B, device)

Create batched edge_index/edge_weight tensors for bipartite graphs.

kaiming_normal(size, fan_in, fan_out[, ...])

kaiming_uniform(size, fan_in, fan_out[, ...])

normal(size[, gain, dtype])

uniform(size[, gain, dtype])

xavier_normal(size, fan_in, fan_out[, gain, ...])

xavier_uniform(size, fan_in, fan_out[, ...])

Classes

Conv(*args, **kwargs)

SparseLinear(*args, **kwargs)

class gsnn.models.SparseLinear.Conv(*args: Any, **kwargs: Any)

Bases: MessagePassing

forward(x, edge_index, edge_weight, bias, size)

message(x_j, edge_weight)

class gsnn.models.SparseLinear.SparseLinear(*args: Any, **kwargs: Any)

Bases: Module

forward(x, batched_indices=None)

The batch dimension is handled in torch_geometric fashion: the graphs in a batch are concatenated into a single graph by incrementing (offsetting) the node indices of each successive graph.

Parameters:
  • x – Input tensor.

  • batched_indices –

Returns:

Tensor (B, M, 1)
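A minimal sketch of what torch_geometric-style batching means for this layer. It rests on assumptions the docstring does not confirm: the input has shape (B, N, 1), the edge weights and a bias are shared by every sample, and `edge_index[0]` / `edge_index[1]` index inputs / outputs. Graph b's input nodes are shifted by b * N and its output nodes by b * M, so a single scatter over B * E edges computes all B products at once.

```python
import torch

torch.manual_seed(0)
N, M, B = 3, 2, 4                              # inputs, outputs, batch size
edge_index = torch.tensor([[0, 1, 1, 2],       # row 0: input-node index (assumed)
                           [0, 0, 1, 1]])      # row 1: output-node index (assumed)
edge_weight = torch.randn(edge_index.size(1))
bias = torch.randn(M)                          # assumed shared bias
x = torch.randn(B, N, 1)                       # assumed input layout (B, N, 1)

# Concatenate the B graphs into one: shift node indices of graph b by b*N and b*M
offsets = torch.arange(B).unsqueeze(1)              # (B, 1)
src = (edge_index[0] + offsets * N).flatten()       # (B*E,), graph 0's edges first
dst = (edge_index[1] + offsets * M).flatten()       # (B*E,)
w = edge_weight.repeat(B)                           # weights shared across the batch

flat_x = x.reshape(B * N)                           # node b*N + i is x[b, i, 0]
out = torch.zeros(B * M).index_add_(0, dst, w * flat_x[src])
out = (out.reshape(B, M) + bias).unsqueeze(-1)      # (B, M, 1), the documented return shape

# Sanity check against a dense per-sample product: (M, N) @ (B, N, 1) -> (B, M, 1)
W = torch.zeros(M, N)
W[edge_index[1], edge_index[0]] = edge_weight
expected = W @ x + bias.view(M, 1)
assert torch.allclose(out, expected, atol=1e-6)
```

Flattening the batch this way avoids a Python loop over samples. The cost is that the index tensors grow linearly with B.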

prune(idxs)

gsnn.models.SparseLinear.batch_graphs(N, M, edge_index, B, device)

Create batched edge_index/edge_weight tensors for bipartite graphs.

Parameters:
  • N (int) – Size of the first set of nodes in each bipartite graph.

  • M (int) – Size of the second set of nodes in each bipartite graph.

  • edge_index (tensor) – Edge index tensor to batch.

  • B (int) – Number of graphs (batch size) to combine.

  • device (str) – Device on which to create the batched tensor.

Returns:

Batched edge index.

Return type:

torch.Tensor
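One way `batch_graphs` could build the batched edge index is shown here as a hypothetical `batch_graphs_sketch`. This is not the library's code. It assumes that row 0 of `edge_index` indexes the N-node set and row 1 the M-node set. It returns only the edge index, consistent with the documented return type, and no edge weights.

```python
import torch

def batch_graphs_sketch(N, M, edge_index, B, device="cpu"):
    """Hypothetical illustration: tile a bipartite edge_index (2, E) B times.

    Copy b has its first-set (N) indices shifted by b * N and its
    second-set (M) indices shifted by b * M, so the copies never overlap.
    """
    edge_index = edge_index.to(device)
    E = edge_index.size(1)
    offsets = torch.arange(B, device=device).repeat_interleave(E)  # 0..0, 1..1, ...
    src = edge_index[0].repeat(B) + offsets * N
    dst = edge_index[1].repeat(B) + offsets * M
    return torch.stack([src, dst], dim=0)                          # shape (2, B * E)

edge_index = torch.tensor([[0, 1, 1, 2],
                           [0, 0, 1, 1]])
batched = batch_graphs_sketch(N=3, M=2, edge_index=edge_index, B=2)
print(batched)
# tensor([[0, 1, 1, 2, 3, 4, 4, 5],
#         [0, 0, 1, 1, 2, 2, 3, 3]])
```

Because the second copy's indices start at N and M respectively, a node-feature tensor flattened from shape (B, N) lines up directly with the shifted source indices.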

gsnn.models.SparseLinear.kaiming_normal(size, fan_in, fan_out, fan_mode='fan_in', gain=1, dtype=torch.float32)
gsnn.models.SparseLinear.kaiming_uniform(size, fan_in, fan_out, fan_mode='fan_in', gain=1, dtype=torch.float32)
gsnn.models.SparseLinear.normal(size, gain=1, dtype=torch.float32)
gsnn.models.SparseLinear.uniform(size, gain=1.0, dtype=torch.float32)
gsnn.models.SparseLinear.xavier_normal(size, fan_in, fan_out, gain=1, dtype=torch.float32)
gsnn.models.SparseLinear.xavier_uniform(size, fan_in, fan_out, gain=1, dtype=torch.float32)
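These initializers take `fan_in` and `fan_out` explicitly. That is presumably because a sparse weight vector of length E has no (out, in) shape from which `torch.nn.init` could infer the fans. The sketch below assumes the standard Glorot/Xavier and He/Kaiming formulas in the form `torch.nn.init` uses them:

- Xavier normal: std = gain * sqrt(2 / (fan_in + fan_out))
- Xavier uniform: bound = gain * sqrt(6 / (fan_in + fan_out))
- Kaiming normal: std = gain / sqrt(fan)
- Kaiming uniform: bound = gain * sqrt(3 / fan)

The `*_sketch` names are hypothetical, and the library's own functions may apply `gain` or the fan values differently. Accepting fans as either scalars or per-edge tensors is also an assumption.

```python
import torch

def _as_tensor(value, dtype):
    # Fan values may be plain numbers or per-edge tensors (assumption)
    return torch.as_tensor(value, dtype=dtype)

def _select_fan(fan_in, fan_out, fan_mode, dtype):
    if fan_mode not in ("fan_in", "fan_out"):
        raise ValueError(f"fan_mode must be 'fan_in' or 'fan_out', got {fan_mode!r}")
    return _as_tensor(fan_in if fan_mode == "fan_in" else fan_out, dtype)

def xavier_normal_sketch(size, fan_in, fan_out, gain=1.0, dtype=torch.float32):
    # N(0, std^2) with std = gain * sqrt(2 / (fan_in + fan_out))
    std = gain * torch.sqrt(2.0 / (_as_tensor(fan_in, dtype) + _as_tensor(fan_out, dtype)))
    return torch.randn(size, dtype=dtype) * std

def xavier_uniform_sketch(size, fan_in, fan_out, gain=1.0, dtype=torch.float32):
    # U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out))
    bound = gain * torch.sqrt(6.0 / (_as_tensor(fan_in, dtype) + _as_tensor(fan_out, dtype)))
    return (torch.rand(size, dtype=dtype) * 2 - 1) * bound

def kaiming_normal_sketch(size, fan_in, fan_out, fan_mode="fan_in", gain=1.0, dtype=torch.float32):
    # N(0, std^2) with std = gain / sqrt(fan)
    std = gain / torch.sqrt(_select_fan(fan_in, fan_out, fan_mode, dtype))
    return torch.randn(size, dtype=dtype) * std

def kaiming_uniform_sketch(size, fan_in, fan_out, fan_mode="fan_in", gain=1.0, dtype=torch.float32):
    # U(-a, a) with a = gain * sqrt(3 / fan)
    bound = gain * torch.sqrt(3.0 / _select_fan(fan_in, fan_out, fan_mode, dtype))
    return (torch.rand(size, dtype=dtype) * 2 - 1) * bound

torch.manual_seed(0)
w = xavier_normal_sketch((100_000,), fan_in=8, fan_out=24)
print(w.std().item())  # close to sqrt(2 / 32) = 0.25

# Per-edge fans: each weight gets its own scale
fan_in_per_edge = torch.tensor([1.0, 4.0, 9.0])
u = kaiming_uniform_sketch((3,), fan_in_per_edge, fan_out=None)
```

Passing per-edge fans lets each weight be scaled by the degree of the node it feeds. That is the main reason a sparse layer cannot simply reuse shape-based initializers.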