gsnn.models.NN

class gsnn.models.NN(in_channels, hidden_channels, out_channels, layers=2, dropout=0, nonlin=torch.nn.ELU, out=None, norm=torch.nn.LayerNorm)

Bases: torch.nn.Module

Fully-connected baseline model: a stack of Linear blocks, each with optional normalization, activation, and dropout.

__init__(in_channels, hidden_channels, out_channels, layers=2, dropout=0, nonlin=torch.nn.ELU, out=None, norm=torch.nn.LayerNorm)

Builds a stack of linear layers. out is an optional module class applied after the last linear layer.
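
This page does not spell out how layers, norm, nonlin, dropout and out are wired together. The following is a minimal PyTorch sketch of one plausible construction, not gsnn's actual source. The class name MLPSketch is hypothetical, and the sketch assumes that layers counts all Linear layers, that each hidden block runs Linear → norm → nonlin → Dropout, that the final Linear has no norm or activation, and that out is a module class instantiated with no arguments.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Hypothetical re-implementation of a Linear/norm/activation/dropout stack.

    Assumptions (not taken from gsnn): `layers` counts Linear layers in total,
    each hidden block is Linear -> norm -> nonlin -> Dropout, the last Linear
    has no norm/activation, and `out` is instantiated with no arguments.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, layers=2,
                 dropout=0, nonlin=nn.ELU, out=None, norm=nn.LayerNorm):
        super().__init__()
        blocks = []
        width = in_channels
        # Hidden blocks: every Linear layer except the last one.
        for _ in range(layers - 1):
            blocks.append(nn.Linear(width, hidden_channels))
            if norm is not None:
                blocks.append(norm(hidden_channels))
            blocks.append(nonlin())
            blocks.append(nn.Dropout(dropout))
            width = hidden_channels
        # Final projection to the output size.
        blocks.append(nn.Linear(width, out_channels))
        # Optional output module class (e.g. nn.Sigmoid), instantiated here.
        if out is not None:
            blocks.append(out())
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)


model = MLPSketch(8, 16, 3, layers=3, dropout=0.1, out=nn.Sigmoid)
print(model.net)
```

Under these assumptions, layers=1 reduces to a single Linear(in_channels, out_channels) followed by the optional out module.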

forward(x)

Maps flat input features (last dimension in_channels) to predictions (last dimension out_channels).
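
A quick shape check illustrates the forward pass. This sketch uses a plain torch.nn.Sequential as a hypothetical stand-in for NN(8, 16, 3, layers=2, dropout=0.5), again assuming the Linear → LayerNorm → ELU → Dropout ordering. It shows the in_channels → out_channels mapping and that dropout should be switched off with eval() at inference time.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for NN(8, 16, 3, layers=2, dropout=0.5), assuming the
# block order Linear -> LayerNorm -> ELU -> Dropout -> Linear.
model = nn.Sequential(
    nn.Linear(8, 16), nn.LayerNorm(16), nn.ELU(), nn.Dropout(0.5),
    nn.Linear(16, 3),
)

x = torch.randn(32, 8)           # batch of 32 flat feature vectors
y = model(x)
print(y.shape)                   # torch.Size([32, 3])

# Linear and LayerNorm act on the last dimension only, so extra leading
# dimensions pass through unchanged.
print(model(torch.randn(5, 32, 8)).shape)  # torch.Size([5, 32, 3])

# Dropout is only active in training mode; eval() makes predictions
# deterministic.
model.eval()
with torch.no_grad():
    assert torch.equal(model(x), model(x))
```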