gsnn.models.NodeMLP

class gsnn.models.NodeMLP(*args: Any, **kwargs: Any)

Bases: Module

Small MLP applied independently to each node’s channel vector inside a ResBlock.

__init__(in_features: int, hidden_features: int, nonlin, dropout)
Parameters:
  • in_features – Channels per node (width of each node’s slice).

  • hidden_features – Hidden width of the two-layer MLP.

  • nonlin – Activation module class (e.g. torch.nn.ELU).

  • dropout – Dropout probability applied between the two layers.
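To make the parameter list concrete, here is a minimal sketch of a two-layer per-node MLP in PyTorch. It is not the gsnn implementation: the class name NodeMLPSketch is hypothetical, and it assumes (1) one set of weights shared by all nodes, (2) the layer order Linear → nonlin → Dropout → Linear, and (3) a second layer that projects back to in_features, which is what the "returns same shape" contract of forward requires.

```python
import torch
import torch.nn as nn


class NodeMLPSketch(nn.Module):
    """Hypothetical per-node MLP; NOT the gsnn.models.NodeMLP source.

    Assumes the same weights are shared by every node and applied to the
    last (channel) axis of a (batch, num_nodes, channels_per_node) tensor.
    """

    def __init__(self, in_features: int, hidden_features: int,
                 nonlin=nn.ELU, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),  # widen each node's channels
            nonlin(),                                  # nonlin is a class; instantiate it here
            nn.Dropout(dropout),                       # dropout between the two layers
            nn.Linear(hidden_features, in_features),   # project back to in_features
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last axis only, so every (batch, node)
        # channel vector is transformed independently with shared weights.
        return self.net(x)


mlp = NodeMLPSketch(in_features=4, hidden_features=16, nonlin=nn.ELU, dropout=0.1)
x = torch.randn(8, 10, 4)  # (batch, num_nodes, channels_per_node)
print(mlp(x).shape)        # torch.Size([8, 10, 4])
```

Passing nonlin as a class rather than an instance lets the module create its own activation layer, which matches the "Activation module class" wording above.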

Methods

  • __init__(in_features, hidden_features, ...) – Construct the per-node MLP; the parameters are listed above.

  • forward(x) – x has shape (batch, num_nodes, channels_per_node); returns a tensor of the same shape.

forward(x: torch.Tensor) → torch.Tensor

Input x has shape (batch, num_nodes, channels_per_node); the output has the same shape.
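The shape contract can be illustrated with a small stand-in. This sketch does not call gsnn; it uses a plain nn.Sequential as a hypothetical per-node MLP with weights shared across nodes, and it assumes a flat hidden vector in which each node owns a contiguous slice of width channels_per_node. It checks that the output shape matches the input and that perturbing one node leaves the other nodes' outputs unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, num_nodes, channels_per_node = 2, 5, 3

# Hypothetical stand-in for a per-node MLP (shared weights on the channel axis).
node_mlp = nn.Sequential(
    nn.Linear(channels_per_node, 8),
    nn.ELU(),
    nn.Linear(8, channels_per_node),
)

# Assumed layout: node i owns flat[:, i*channels_per_node:(i+1)*channels_per_node].
flat = torch.randn(batch, num_nodes * channels_per_node)
x = flat.view(batch, num_nodes, channels_per_node)
y = node_mlp(x)
print(y.shape)  # torch.Size([2, 5, 3]), same as x

# Changing node 0's channels does not affect any other node's output.
x2 = x.clone()
x2[:, 0] += 1.0
y2 = node_mlp(x2)
print(torch.allclose(y[:, 1:], y2[:, 1:]))  # True

# Flatten back to the per-slice layout.
flat_out = y.reshape(batch, num_nodes * channels_per_node)
```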