import torch
[docs]class NN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, layers=2, dropout=0,
nonlin=torch.nn.ELU, out=None, norm=torch.nn.LayerNorm):
'''
Args:
in_channels int number of input channels
hidden_channels int number of hidden channels per layer
out_channels int number of output channels
layers int number of hidden layers
dropout float dropout regularization probability
nonlin pytorch.module non-linear activation function
out pytorch.module output transformation to be applied (default: None)
norm pytorch.module normalization method to use
'''
super().__init__()
seq = [torch.nn.Linear(in_channels, hidden_channels)]
if norm is not None: seq.append(norm(hidden_channels))
seq += [nonlin(), torch.nn.Dropout(dropout)]
for _ in range(layers - 1):
seq += [torch.nn.Linear(hidden_channels, hidden_channels)]
if norm is not None: seq.append(norm(hidden_channels))
seq += [nonlin(), torch.nn.Dropout(dropout)]
seq += [torch.nn.Linear(hidden_channels, out_channels)]
if out is not None: seq += [out()]
self.nn = torch.nn.Sequential(*seq)
[docs] def forward(self, x):
return self.nn(x)