gsnn.models.NodeMLP
- class gsnn.models.NodeMLP(*args: Any, **kwargs: Any)
Bases: Module
Small MLP applied independently to each node's channel vector inside a ResBlock.
- __init__(in_features: int, hidden_features: int, nonlin, dropout)
- Parameters:
in_features – Channels per node (width of each node’s slice).
hidden_features – Hidden width of the two-layer MLP.
nonlin – Activation module class (e.g. torch.nn.ELU).
dropout – Dropout probability between layers.
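The parameters above suggest a two-layer MLP along the lines of the sketch below. This is a hypothetical re-creation, not gsnn's source. It assumes the layer order Linear → nonlin → Dropout → Linear and assumes the second layer projects back to in_features, because forward returns the same shape it receives. Since __init__ takes no node count, it also assumes one set of weights is shared across all nodes. The class name NodeMLPSketch is illustrative only.

```python
import torch
import torch.nn as nn


class NodeMLPSketch(nn.Module):
    """Hypothetical per-node two-layer MLP (illustration, not gsnn's code)."""

    def __init__(self, in_features: int, hidden_features: int, nonlin, dropout: float):
        super().__init__()
        # Assumed layout: expand to hidden_features, activate, drop, project back.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nonlin(),  # nonlin is an activation *class*, so instantiate it here
            nn.Dropout(dropout),
            nn.Linear(hidden_features, in_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_nodes, channels_per_node). nn.Linear acts only on the
        # last dimension, so each node's channel vector is transformed by the
        # same weights without mixing information across nodes.
        return self.net(x)


mlp = NodeMLPSketch(in_features=4, hidden_features=16, nonlin=nn.ELU, dropout=0.1)
x = torch.randn(8, 10, 4)  # batch=8, num_nodes=10, channels_per_node=4
y = mlp(x)
print(y.shape)  # torch.Size([8, 10, 4])
```

Passing the activation as a class (nn.ELU rather than nn.ELU()) lets the module create its own activation instance, matching the "Activation module class" description of nonlin.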
Methods
__init__(in_features, hidden_features, ...) – Construct the module; see Parameters above.
forward(x) – x has shape (batch, num_nodes, channels_per_node); returns a tensor of the same shape.
- forward(x: torch.Tensor) → torch.Tensor
x – Tensor of shape (batch, num_nodes, channels_per_node). Returns a tensor of the same shape.
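To illustrate the shape contract of forward, and what "applied independently to each node" means, the check below compares one batched call against running each node's channel vector through the MLP separately. The nn.Sequential is a stand-in with an assumed layer order, not gsnn's actual NodeMLP. Dropout is disabled with eval() so that both paths are deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in per-node MLP (assumed layer order; not gsnn's implementation).
mlp = nn.Sequential(nn.Linear(4, 16), nn.ELU(), nn.Dropout(0.1), nn.Linear(16, 4))
mlp.eval()  # turn off dropout so both computations give the same result

x = torch.randn(2, 5, 4)  # (batch, num_nodes, channels_per_node)

# Batched call: the Linear layers broadcast over the (batch, node) dimensions.
y_batched = mlp(x)

# Reference: feed each node's (batch, channels) slice through the MLP on its own,
# then restack along the node dimension.
y_loop = torch.stack([mlp(x[:, n, :]) for n in range(x.shape[1])], dim=1)

print(y_batched.shape)                               # torch.Size([2, 5, 4])
print(torch.allclose(y_batched, y_loop, atol=1e-6))  # True
```

Because no operation mixes the node dimension, the batched output matches the per-node loop up to floating-point rounding, and the output shape equals the input shape.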