import torch
import numpy as np
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
#########################################################################################################################
######################################### GSNN utils ################################################################
#########################################################################################################################
def hetero2homo(edge_index_dict, node_names_dict, edge_weight_dict=None):
    r"""Convert a heterogeneous GSNN graph into a homogeneous graph representation.

    The GSNN pipeline distinguishes three edge types:

    1. ``('input', 'to', 'function')``
    2. ``('function', 'to', 'function')``
    3. ``('function', 'to', 'output')``

    This function stacks these edge sets into one homogeneous graph and returns
    boolean masks that let you recover the original node semantics. Nodes are
    re-indexed in the order ``function``, ``input``, ``output``:

    - function node ``i`` keeps index ``i``;
    - input node ``j`` becomes ``N_function + j``;
    - output node ``k`` becomes ``N_function + N_input + k``.

    Edges (and edge weights) are concatenated in the order function->function,
    input->function, function->output.

    Args:
        edge_index_dict (Dict[Tuple[str, str, str], Tensor]): Edge-type mapping where each
            value is a :obj:`LongTensor` with shape :obj:`[2, num_edges_of_type]`.
        node_names_dict (Dict[str, Sequence[str]]): Mapping of node types ('input', 'function', 'output')
            to their respective node names (any sequence: list, tuple, numpy array, ...).
        edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional): Dictionary mapping edge types
            to edge weights. Expected keys are ('input', 'to', 'function'), ('function', 'to', 'function'),
            and ('function', 'to', 'output'). Values should be tensors of shape :obj:`[num_edges]`.
            (default: :obj:`None`)

    Returns:
        tuple: A tuple containing:

        - edge_index (Tensor): Homogeneous edge indices of shape :obj:`[2, num_edges]`
        - input_mask (Tensor): Boolean mask for input nodes of shape :obj:`[num_nodes]`
        - output_mask (Tensor): Boolean mask for output nodes of shape :obj:`[num_nodes]`
        - num_nodes (int): Total number of nodes in the homogeneous graph
        - homo_names (List[str]): Node names in homogeneous ordering
        - edge_weight (Optional[Tensor]): Homogeneous edge weights of shape :obj:`[num_edges]`,
          or :obj:`None` if :obj:`edge_weight_dict` was :obj:`None`.

    Example:
        >>> edge_index_dict = {
        ...     ('input', 'to', 'function'): torch.tensor([[0, 1], [0, 0]]),
        ...     ('function', 'to', 'function'): torch.tensor([[0], [0]]),
        ...     ('function', 'to', 'output'): torch.tensor([[0], [0]])
        ... }
        >>> node_names_dict = {
        ...     'input': ['in1', 'in2'],
        ...     'function': ['func1'],
        ...     'output': ['out1']
        ... }
        >>> edge_index, in_mask, out_mask, n_nodes, names, edge_weight = hetero2homo(
        ...     edge_index_dict, node_names_dict
        ... )
        >>> print(edge_index.shape)  # [2, 4]
        >>> print(in_mask.sum())  # 2 (number of input nodes)
        >>> print(out_mask.sum())  # 1 (number of output nodes)
    """
    # Clone so the offsets below never mutate the caller's tensors.
    input_edge_index = edge_index_dict['input', 'to', 'function'].clone()
    function_edge_index = edge_index_dict['function', 'to', 'function'].clone()
    output_edge_index = edge_index_dict['function', 'to', 'output'].clone()

    n_input = len(node_names_dict['input'])
    n_function = len(node_names_dict['function'])
    n_output = len(node_names_dict['output'])
    num_nodes = n_input + n_function + n_output

    # Shift node ids so every node type occupies its own index range:
    # only the input-node side (sources) and output-node side (destinations) move.
    input_edge_index[0, :] += n_function
    output_edge_index[1, :] += n_function + n_input

    edge_index = torch.cat((function_edge_index, input_edge_index, output_edge_index), dim=1)

    if edge_weight_dict is not None:
        # Same concatenation order as edge_index so weights stay aligned with edges.
        edge_weight = torch.cat((edge_weight_dict['function', 'to', 'function'],
                                 edge_weight_dict['input', 'to', 'function'],
                                 edge_weight_dict['function', 'to', 'output']), dim=0)
    else:
        edge_weight = None

    input_node_mask = torch.zeros(num_nodes, dtype=torch.bool)
    input_node_mask[n_function:n_function + n_input] = True

    output_node_mask = torch.zeros(num_nodes, dtype=torch.bool)
    output_node_mask[n_function + n_input:] = True

    # list() so tuples / numpy arrays of names concatenate instead of failing
    # or being added elementwise.
    homo_names = (list(node_names_dict['function'])
                  + list(node_names_dict['input'])
                  + list(node_names_dict['output']))

    return edge_index, input_node_mask, output_node_mask, num_nodes, homo_names, edge_weight
def get_Win_indices(edge_index, channels, function_nodes):
    r"""Build sparse COO indices for the input weight matrix :math:`W_{in}`.

    :math:`W_{in}` maps edges to hidden channels. Every edge whose destination
    is a function node is connected to each hidden channel of that destination
    node. Hidden channels are laid out contiguously, node by node, in node-id
    order.

    Args:
        edge_index (Tensor): Homogeneous edge index of shape :obj:`[2, num_edges]`.
        channels (int or array-like): If int, every function node gets the same number of
            hidden channels and all other nodes get 0. If a 1-D tensor/array, it must contain
            the per-node channel count and have length equal to the number of distinct node
            ids occurring in :obj:`edge_index`.
        function_nodes (Tensor): Index list of nodes that represent functions.

    Returns:
        tuple: A tuple containing:

        - indices (Tensor): COO indices of shape :obj:`[2, nnz]` for sparse tensor construction.
          Row 0 holds edge ids, row 1 holds hidden-channel ids.
        - channel_count (numpy.ndarray): Per-node channel counts for later reuse.

    Raises:
        ValueError: If :obj:`channels` is not an int and not a 1-D array of the expected length.

    Example:
        >>> edge_index = torch.tensor([[0, 1], [1, 0]])  # edges 0->1 and 1->0
        >>> channels = 3  # 3 channels per function node
        >>> function_nodes = torch.tensor([0])  # Node 0 is a function node
        >>> indices, counts = get_Win_indices(edge_index, channels, function_nodes)
        >>> print(indices.shape)  # [2, 3] (only edge 1->0 ends at a function node; 3 channels)
        >>> print(counts)  # [3 0] (3 channels for node 0, 0 for node 1)
    """
    # Number of distinct node ids that occur in the edge list.
    num_nodes = torch.unique(edge_index.reshape(-1)).size(0)

    # Function node indices as numpy, for numpy array indexing.
    func_nodes_np = function_nodes.detach().cpu().numpy()

    if isinstance(channels, (int, np.integer)):
        # Make every function-node id addressable even when some node ids never
        # occur in edge_index (isolated nodes); otherwise the table is sized as before.
        size = int(num_nodes)
        if func_nodes_np.size > 0:
            size = max(size, int(func_nodes_np.max()) + 1)
        _channels = np.zeros(size, dtype=int)
        _channels[func_nodes_np] = int(channels)
    else:
        ch_arr = np.asarray(channels, dtype=int)
        if ch_arr.ndim != 1 or ch_arr.shape[0] != int(num_nodes):
            raise ValueError(
                f"channels must be an int or a length-{int(num_nodes)} array; got shape {ch_arr.shape}"
            )
        _channels = ch_arr.copy()

    # offsets[n] = id of node n's first hidden channel (exclusive prefix sum),
    # computed once instead of re-summing the prefix for every edge.
    offsets = np.cumsum(_channels) - _channels

    func_nodes_set = set(func_nodes_np.tolist())
    dst_ids = edge_index[1].detach().cpu().tolist()

    row = []
    col = []
    for edge_id, node_id in enumerate(dst_ids):
        # skip edges whose destination is not a function node
        if node_id not in func_nodes_set:
            continue
        start = int(offsets[node_id])
        n_ch = int(_channels[node_id])
        row.extend([edge_id] * n_ch)
        col.extend(range(start, start + n_ch))

    indices = torch.stack((torch.tensor(row, dtype=torch.long),
                           torch.tensor(col, dtype=torch.long)), dim=0)
    return indices, _channels
def get_Wout_indices(edge_index, function_nodes, channels):
    r"""Build sparse COO indices for the output weight matrix :math:`W_{out}`.

    :math:`W_{out}` maps hidden channels back to edges. Every hidden channel of a
    function node is connected to each out-going edge of that node. Hidden
    channels are laid out contiguously, node by node, in node-id order (the same
    layout as :math:`W_{in}`).

    Args:
        edge_index (Tensor): Homogeneous edge index of shape :obj:`[2, num_edges]`.
        function_nodes (Tensor): Index list of nodes that represent functions.
        channels (numpy.ndarray): Array indicating the number of channels for each node.

    Returns:
        Tensor: COO indices of shape :obj:`[2, nnz]` for sparse tensor construction.
        Row 0 holds hidden-channel ids, row 1 holds edge ids.

    Example:
        >>> edge_index = torch.tensor([[0, 1], [1, 0]])  # edges 0->1 and 1->0
        >>> function_nodes = torch.tensor([0])  # Node 0 is a function node
        >>> channels = np.array([3, 0])  # 3 channels for node 0, 0 for node 1
        >>> indices = get_Wout_indices(edge_index, function_nodes, channels)
        >>> print(indices.shape)  # [2, 3] (3 channels * 1 out-going edge of node 0)
    """
    ch = np.asarray(channels, dtype=int)
    # offsets[n] = id of node n's first hidden channel (exclusive prefix sum).
    offsets = np.cumsum(ch) - ch

    # Group out-going edge ids by source node in one pass over the edge list,
    # instead of rescanning every edge for every function node.
    out_edges_by_node = {}
    for edge_id, src_node in enumerate(edge_index[0].tolist()):
        out_edges_by_node.setdefault(src_node, []).append(edge_id)

    row = []
    col = []
    for node_id in torch.as_tensor(function_nodes).reshape(-1).tolist():
        out_edges = out_edges_by_node.get(node_id, [])
        start = int(offsets[node_id])
        for k in range(int(ch[node_id])):
            row.extend([start + k] * len(out_edges))
            col.extend(out_edges)

    indices = torch.stack((torch.tensor(row, dtype=torch.long),
                           torch.tensor(col, dtype=torch.long)), dim=0)
    return indices
def node2edge(x, edge_index):
    r"""Gather node-level features onto edges.

    Each edge takes the feature of its *source* node (row 0 of
    :obj:`edge_index`). A node with several out-going edges has its feature
    copied to each of them.

    Args:
        x (Tensor): Node features of shape :obj:`[batch_size, num_nodes]`.
        edge_index (Tensor): Edge indices of shape :obj:`[2, num_edges]`.

    Returns:
        Tensor: Edge features of shape :obj:`[batch_size, num_edges]`.

    Example:
        >>> x = torch.randn(32, 4)  # [batch_size, num_nodes]
        >>> edge_index = torch.tensor([[0, 1], [1, 2]])  # 2 edges
        >>> edge_features = node2edge(x, edge_index)
        >>> print(edge_features.shape)  # [32, 2]
    """
    source_nodes, _ = edge_index
    edge_features = x[:, source_nodes]
    return edge_features
def edge2node(x, edge_index, output_node_mask):
    r"""Convert edge-level features back to node-level features, focusing on output nodes.

    Only edges whose destination is an output node contribute. Output nodes should
    typically have an in-degree of 1. When an output node has several incoming
    edges, their features are summed and divided by the square root of that
    node's in-degree.

    Args:
        x (Tensor): Edge features of shape :obj:`[batch_size, num_edges]`.
        edge_index (Tensor): Edge indices of shape :obj:`[2, num_edges]`.
        output_node_mask (Tensor): Mask of shape :obj:`[num_nodes]` whose non-zero
            entries mark output nodes.

    Returns:
        Tensor: Node features of shape :obj:`[batch_size, num_nodes]` with the same
        dtype as :obj:`x`. Columns of non-output nodes are zero; select the outputs
        with ``out[:, output_node_mask]``.

    Example:
        >>> x = torch.randn(32, 3)  # [batch_size, num_edges]
        >>> edge_index = torch.tensor([[0, 1, 1], [2, 2, 3]])  # 3 edges
        >>> output_mask = torch.tensor([False, False, True, True])  # Nodes 2,3 are outputs
        >>> node_features = edge2node(x, edge_index, output_mask)
        >>> print(node_features.shape)  # [32, 4]
    """
    output_node_ixs = output_node_mask.nonzero(as_tuple=True)[0]
    _, dst = edge_index
    output_edge_mask = torch.isin(dst, output_node_ixs)

    batch_size = x.size(0)
    # Allocate with x's dtype: scatter_add requires self and src to share a dtype,
    # so a hard-coded float32 buffer breaks float64 / half-precision inputs.
    out = torch.zeros(batch_size, output_node_mask.size(0), dtype=x.dtype, device=x.device)
    idx = dst[output_edge_mask].view(1, -1).expand(batch_size, -1)
    edge_vals = x[:, output_edge_mask].reshape(batch_size, -1)
    out = out.scatter_add(1, idx, edge_vals)

    # Only matters when an output node has several incoming edges; the graph can
    # be built to avoid this, but normalise just in case. clamp_min(1) avoids
    # dividing the (all-zero) columns of nodes without incoming edges by 0.
    deg_in = torch.bincount(dst, minlength=out.size(1)).clamp_min(1)
    return out / deg_in.to(out.dtype).sqrt()
def get_conv_indices(edge_index, channels, function_nodes):
    r"""Compute indexing structures for convolutional (sparse linear) layers.

    Combines :func:`get_Win_indices` and :func:`get_Wout_indices` and derives the
    dense sizes of both weight matrices plus the node that owns each hidden channel.

    Args:
        edge_index (Tensor): Homogeneous edge indices of shape :obj:`[2, num_edges]`.
        channels (int or array-like): Number of channels per function node, or a 1-D
            per-node channel-count array (forwarded to :func:`get_Win_indices`).
        function_nodes (Tensor): Indices of function nodes.

    Returns:
        tuple: A tuple containing:

        - w_in_indices (Tensor): Indexing for :math:`W_{in}`
        - w_out_indices (Tensor): Indexing for :math:`W_{out}`
        - w_in_size (Tuple[int, int]): ``(num_edges, total_channels)`` for :math:`W_{in}`
        - w_out_size (Tuple[int, int]): ``(total_channels, num_edges)`` for :math:`W_{out}`
        - channel_groups (List[int]): For each hidden channel, the id of the node that owns it

    Example:
        >>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
        >>> channels = 3  # 3 channels per function node
        >>> function_nodes = torch.tensor([0])  # Node 0 is a function node
        >>> indices = get_conv_indices(edge_index, channels, function_nodes)
        >>> print(len(indices))  # 5 (w_in_indices, w_out_indices, w_in_size, w_out_size, channel_groups)
    """
    num_edges = edge_index.size(1)
    w_in_indices, node_hidden_channels = get_Win_indices(edge_index, channels, function_nodes)
    w_out_indices = get_Wout_indices(edge_index, function_nodes, node_hidden_channels)

    # Plain Python ints so the sizes behave like ordinary shapes (not numpy scalars).
    total_channels = int(np.sum(node_hidden_channels))
    w_in_size = (num_edges, total_channels)
    w_out_size = (total_channels, num_edges)

    # Node id repeated once per hidden channel, in the same node-major layout
    # used by the W_in / W_out indices.
    channel_groups = [node_id
                      for node_id, c in enumerate(node_hidden_channels)
                      for _ in range(c)]

    return (w_in_indices, w_out_indices, w_in_size, w_out_size, channel_groups)
#########################################################################################################################
######################################### ResBlock utils ################################################################
#########################################################################################################################
def apply_norm_and_nonlin(norm, nonlin, out, norm_first):
    r"""Compose a normalization and a nonlinearity in a configurable order.

    With :obj:`norm_first` truthy the result is ``nonlin(norm(out))``;
    otherwise it is ``norm(nonlin(out))``.

    Args:
        norm (callable): Normalization layer or operation.
        nonlin (callable): Nonlinear activation function.
        out (Tensor): Input tensor to be normalized and activated.
        norm_first (bool): If :obj:`True`, apply normalization before nonlinearity.

    Returns:
        Tensor: The transformed tensor.

    Example:
        >>> norm = torch.nn.BatchNorm1d(32)
        >>> nonlin = torch.nn.ReLU()
        >>> x = torch.randn(16, 32)  # [batch_size, num_features]
        >>> # Apply normalization first
        >>> out = apply_norm_and_nonlin(norm, nonlin, x, norm_first=True)
        >>> print(out.shape)  # [16, 32]
    """
    first_op, second_op = (norm, nonlin) if norm_first else (nonlin, norm)
    return second_op(first_op(out))
#########################################################################################################################
######################################### Prediction utils ################################################################
#########################################################################################################################
def predict_gsnn(loader, model, device, verbose=True):
    """Run ``model`` on every batch of ``loader`` and collect targets and predictions.

    Each batch must unpack as ``(x, y, *extra)``. ``x`` is moved to ``device`` and
    fed to the model under ``torch.no_grad()``. Predictions and targets are gathered
    on the CPU and concatenated along dim 0. The ``extra`` items of each batch are
    appended to ``sig_ids`` as-is (one entry per extra item per batch, not flattened).

    Note: the model is switched to eval mode and is left in eval mode on return.

    Args:
        loader: Iterable of batches (e.g. a ``DataLoader``). ``len(loader)`` is only
            used for the progress message when ``verbose`` is true.
        model (torch.nn.Module): Model to evaluate.
        device: Device the inputs are moved to before the forward pass.
        verbose (bool): Print a carriage-return progress indicator.

    Returns:
        tuple: ``(y, yhat, sig_ids)`` where ``y`` and ``yhat`` are numpy arrays and
        ``sig_ids`` is a list of the extra batch items.
    """
    model = model.eval()
    ys = []
    yhats = []
    sig_ids = []
    with torch.no_grad():
        for i, (x, y, *sig_id) in enumerate(loader):
            if verbose:
                print(f'progress: {i}/{len(loader)}', end='\r')
            yhat = model(x.to(device))
            # Targets are never used on ``device``; keep them on the CPU rather
            # than paying for a host->device->host round trip every batch.
            ys.append(y.detach().cpu())
            yhats.append(yhat.detach().cpu())
            sig_ids += sig_id
    # Everything collected above is already detached and on the CPU.
    y = torch.cat(ys, dim=0).numpy()
    yhat = torch.cat(yhats, dim=0).numpy()
    return y, yhat, sig_ids
def corr_score(y, yhat, multioutput='uniform_weighted', method='pearson', eps=1e-6):
    '''
    Score agreement between ``y`` and ``yhat`` column by column and aggregate.

    For each output column the chosen ``method`` is computed between the true
    and predicted values. Columns where either side has standard deviation
    below ``eps`` score 0 (the correlation would be undefined). NaN scores are
    ignored by the mean/median aggregations.

    Args:
        y (array-like): True values, shape (n_samples, n_outputs) or (n_samples,).
        yhat (array-like): Predicted values, same shape as ``y``.
        multioutput (str): 'uniform_weighted' (nan-mean over outputs),
            'uniform_median' (nan-median over outputs) or 'raw_values'
            (per-output scores as an array).
        method (str): 'pearson', 'spearman' or 'r2' (r2 is not a correlation;
            it is offered for convenience).
        eps (float): Standard-deviation threshold below which a column scores 0.

    Returns:
        float or numpy.ndarray: Aggregated score, or per-output scores for 'raw_values'.

    Raises:
        ValueError: If ``method`` or ``multioutput`` is not recognised.
    '''
    # Accept lists / other array-likes as well as ndarrays.
    y = np.asarray(y)
    yhat = np.asarray(yhat)

    if y.ndim == 1:
        y = y.reshape(-1, 1)
        yhat = yhat.reshape(-1, 1)

    if method == 'pearson':
        metric = lambda a, b: np.corrcoef(a, b)[0, 1]
    elif method == 'spearman':
        metric = lambda a, b: spearmanr(a, b)[0]
    elif method == 'r2':
        # NOTE: hacky since r2 is not a corr.
        metric = lambda a, b: r2_score(a, b)
    else:
        raise ValueError('unrecognized metric')

    corrs = []
    for i in range(y.shape[1]):
        y_i = y[:, i]
        yhat_i = yhat[:, i]
        if np.std(y_i) < eps or np.std(yhat_i) < eps:
            # Constant column: correlation is undefined, score it as 0.
            p = 0
        else:
            p = metric(y_i, yhat_i)
        corrs.append(p)

    if multioutput == 'uniform_weighted':
        return np.nanmean(corrs)
    elif multioutput == 'uniform_median':
        return np.nanmedian(corrs)
    elif multioutput == 'raw_values':
        return np.array(corrs)
    else:
        raise ValueError('unrecognized multioutput value, expected one of '
                         '"uniform_weighted", "uniform_median", "raw_values"')