Overview
Introduction
What is GSNN?
Key Concepts
Why Use GSNN?
Core Features
How Are GSNNs Different from Graph Neural Networks?
Getting Started
Installation
Citation
Next Steps
Methods
Graph Structured Neural Network (GSNN)
Explainers
Edge Attribution Methods
Direct Perturbation Methods
Optimization-Based Methods
Robustness and Stability Methods
Choosing the Right Explainer
Tutorials
General Premise
Simulating structured data
Performance comparison on simulated data
Reinforcement learning for structure optimization
Gradient checkpointing and compiling
Uncertainty quantification with hypernetworks
GSNN interpretation methods
DrugCell implementation example
Inferring output edges
API Reference
API Reference
Graph Structured Neural Networks
Index
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
K
|
L
|
M
|
N
|
O
|
P
|
R
|
S
|
T
|
U
|
V
|
X
_
__init__() (gsnn.gsnn.interpret.ContrastiveIGExplainer method)
(gsnn.gsnn.interpret.ContrastiveOcclusionExplainer method)
(gsnn.gsnn.interpret.CounterfactualExplainer method)
(gsnn.gsnn.interpret.GSNNExplainer method)
(gsnn.gsnn.interpret.IGExplainer method)
(gsnn.gsnn.interpret.NoiseTunnel method)
(gsnn.gsnn.interpret.OcclusionExplainer method)
(gsnn.gsnn.optim.TrainingDiagnostics method)
(gsnn.interpret.ContrastiveIGExplainer method)
(gsnn.interpret.ContrastiveOcclusionExplainer method)
(gsnn.interpret.CounterfactualExplainer method)
(gsnn.interpret.GSNNExplainer method)
(gsnn.interpret.IGExplainer method)
(gsnn.interpret.NoiseTunnel method)
(gsnn.interpret.OcclusionExplainer method)
(gsnn.optim.TrainingDiagnostics method)
A
add_hparam_results() (gsnn.gsnn.models.utils.TBLogger method)
(gsnn.models.utils.TBLogger method)
adjust_label_positions() (in module gsnn.gsnn.interpret.plot_explanation_graph)
(in module gsnn.interpret.plot_explanation_graph)
analyze() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
apply_norm_and_nonlin() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
augment_edge_index() (gsnn.gsnn.optim.Environment.Environment method)
(gsnn.optim.Environment.Environment method)
B
batch_graphs() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)
bfs_distance() (in module gsnn.gsnn.proc.subset)
(in module gsnn.proc.subset)
bootstrap_r() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
bounding_box_overlap() (in module gsnn.gsnn.interpret.plot_explanation_graph)
(in module gsnn.interpret.plot_explanation_graph)
build() (gsnn.gsnn.proc.construct.GSNNNetworkConstructor method)
(gsnn.proc.construct.GSNNNetworkConstructor method)
build_nx() (in module gsnn.gsnn.proc.subset)
(in module gsnn.proc.subset)
C
ChannelEMANorm (class in gsnn.gsnn.models.ChannelEMANorm)
(class in gsnn.models.ChannelEMANorm)
compute_ECE() (in module gsnn.gsnn.optim.utils)
(in module gsnn.optim.utils)
compute_picp() (in module gsnn.gsnn.optim.utils)
(in module gsnn.optim.utils)
compute_sample_weights() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
compute_scalar_mmd() (in module gsnn.gsnn.external.mmd)
(in module gsnn.gsnn.ot.mmd)
ContrastiveGSNNExplainer (class in gsnn.gsnn.interpret.ContrastiveGSNNExplainer)
(class in gsnn.interpret.ContrastiveGSNNExplainer)
ContrastiveIGExplainer (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
ContrastiveOcclusionExplainer (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
Conv (class in gsnn.gsnn.models.SparseLinear)
(class in gsnn.models.SparseLinear)
corr_score() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
CounterfactualExplainer (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
D
dbscan_silhouette_score() (in module gsnn.gsnn.optim.utils)
(in module gsnn.optim.utils)
dense_func_node (class in gsnn.gsnn.interpret.extract_entity_function)
(class in gsnn.interpret.extract_entity_function)
diff_equivalence() (in module gsnn.gsnn.proc.coarsen)
(in module gsnn.proc.coarsen)
diff_io_equivalence() (in module gsnn.gsnn.proc.coarsen)
(in module gsnn.proc.coarsen)
E
early_stop() (gsnn.gsnn.optim.EarlyStopper.EarlyStopper method)
(gsnn.optim.EarlyStopper.EarlyStopper method)
EarlyStopper (class in gsnn.gsnn.optim.EarlyStopper)
(class in gsnn.optim.EarlyStopper)
edge2node() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
edw() (in module gsnn.gsnn.optim.RewardScaler)
(in module gsnn.optim.RewardScaler)
ema() (in module gsnn.gsnn.optim.Environment)
(in module gsnn.optim.Environment)
Environment (class in gsnn.gsnn.optim.Environment)
(class in gsnn.optim.Environment)
evaluate() (gsnn.gsnn.optim.OutputEdgeInferer.OutputEdgeInferer method)
(gsnn.optim.OutputEdgeInferer.OutputEdgeInferer method)
explain() (gsnn.gsnn.interpret.ContrastiveGSNNExplainer.ContrastiveGSNNExplainer method)
(gsnn.gsnn.interpret.ContrastiveIGExplainer method)
,
[1]
(gsnn.gsnn.interpret.ContrastiveOcclusionExplainer method)
,
[1]
(gsnn.gsnn.interpret.CounterfactualExplainer method)
,
[1]
(gsnn.gsnn.interpret.GSNNExplainer method)
,
[1]
(gsnn.gsnn.interpret.IGExplainer method)
,
[1]
(gsnn.gsnn.interpret.NoiseTunnel method)
,
[1]
(gsnn.gsnn.interpret.OcclusionExplainer method)
,
[1]
(gsnn.interpret.ContrastiveGSNNExplainer.ContrastiveGSNNExplainer method)
(gsnn.interpret.ContrastiveIGExplainer method)
,
[1]
(gsnn.interpret.ContrastiveOcclusionExplainer method)
,
[1]
(gsnn.interpret.CounterfactualExplainer method)
,
[1]
(gsnn.interpret.GSNNExplainer method)
,
[1]
(gsnn.interpret.IGExplainer method)
,
[1]
(gsnn.interpret.NoiseTunnel method)
,
[1]
(gsnn.interpret.OcclusionExplainer method)
,
[1]
extract_entity_function() (in module gsnn.gsnn.interpret.extract_entity_function)
(in module gsnn.interpret.extract_entity_function)
F
fit() (gsnn.gsnn.optim.OutputEdgeInferer.OutputEdgeInferer method)
(gsnn.optim.OutputEdgeInferer.OutputEdgeInferer method)
forward() (gsnn.gsnn.interpret.extract_entity_function.dense_func_node method)
(gsnn.gsnn.models.ChannelEMANorm.ChannelEMANorm method)
(gsnn.gsnn.models.GroupBatchNorm.GroupBatchNorm method)
(gsnn.gsnn.models.GroupEMANorm.GroupEMANorm method)
(gsnn.gsnn.models.GroupLayerNorm.GroupLayerNorm method)
(gsnn.gsnn.models.GroupRMSNorm.GroupRMSNorm method)
(gsnn.gsnn.models.GSNN.GSNN method)
(gsnn.gsnn.models.GSNN.NodeAttention method)
(gsnn.gsnn.models.GSNN.NodeMLP method)
(gsnn.gsnn.models.GSNN.ResBlock method)
(gsnn.gsnn.models.GSNN.SignedMessagePassing method)
(gsnn.gsnn.models.NN.NN method)
(gsnn.gsnn.models.SoftmaxGroupNorm.SoftmaxGroupNorm method)
(gsnn.gsnn.models.SparseLinear.Conv method)
(gsnn.gsnn.models.SparseLinear.SparseLinear method)
(gsnn.gsnn.optim.OutputEdgeInferer.OutputEdgeInferer method)
(gsnn.interpret.extract_entity_function.dense_func_node method)
(gsnn.models.ChannelEMANorm.ChannelEMANorm method)
(gsnn.models.GroupBatchNorm.GroupBatchNorm method)
(gsnn.models.GroupEMANorm.GroupEMANorm method)
(gsnn.models.GroupLayerNorm.GroupLayerNorm method)
(gsnn.models.GroupRMSNorm.GroupRMSNorm method)
(gsnn.models.GSNN.GSNN method)
(gsnn.models.GSNN.NodeAttention method)
(gsnn.models.GSNN.NodeMLP method)
(gsnn.models.GSNN.ResBlock method)
(gsnn.models.GSNN.SignedMessagePassing method)
(gsnn.models.NN.NN method)
(gsnn.models.SoftmaxGroupNorm.SoftmaxGroupNorm method)
(gsnn.models.SparseLinear.Conv method)
(gsnn.models.SparseLinear.SparseLinear method)
(gsnn.optim.OutputEdgeInferer.OutputEdgeInferer method)
G
get_activation() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
get_all_possible_paths_set() (in module gsnn.gsnn.proc.subset)
(in module gsnn.proc.subset)
get_batch_params() (gsnn.gsnn.models.GSNN.GSNN method)
(gsnn.models.GSNN.GSNN method)
get_conv_indices() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
get_crit() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
get_dependency_details() (gsnn.gsnn.simulate.graph_comparison.GraphComparison method)
(gsnn.simulate.graph_comparison.GraphComparison method)
get_edge_probs() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.REINFORCE.REINFORCE method)
get_gradient_histogram_data() (gsnn.gsnn.optim.TrainingDiagnostics method)
,
[1]
(gsnn.optim.TrainingDiagnostics method)
,
[1]
get_node_activations() (gsnn.gsnn.models.GSNN.GSNN method)
(gsnn.models.GSNN.GSNN method)
get_node_attention() (gsnn.gsnn.models.GSNN.GSNN method)
(gsnn.models.GSNN.GSNN method)
get_optim() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
get_params() (gsnn.gsnn.optim.RewardScaler.RewardScaler method)
(gsnn.optim.RewardScaler.RewardScaler method)
get_regressed_r() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
get_reward_params() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.REINFORCE.REINFORCE method)
get_scheduler() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
get_sigid_attrs() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
get_summary() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.gsnn.optim.TrainingDiagnostics method)
,
[1]
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.TrainingDiagnostics method)
,
[1]
get_Win_indices() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
get_Wout_indices() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
GradDiagnostics (class in gsnn.gsnn.optim.GradDiagnostics)
(class in gsnn.optim.GradDiagnostics)
GraphComparison (class in gsnn.gsnn.simulate.graph_comparison)
(class in gsnn.simulate.graph_comparison)
GroupBatchNorm (class in gsnn.gsnn.models.GroupBatchNorm)
(class in gsnn.models.GroupBatchNorm)
GroupEMANorm (class in gsnn.gsnn.models.GroupEMANorm)
(class in gsnn.models.GroupEMANorm)
GroupLayerNorm (class in gsnn.gsnn.models.GroupLayerNorm)
(class in gsnn.models.GroupLayerNorm)
GroupRMSNorm (class in gsnn.gsnn.models.GroupRMSNorm)
(class in gsnn.models.GroupRMSNorm)
GSNN (class in gsnn.gsnn.models.GSNN)
(class in gsnn.models.GSNN)
gsnn.gsnn
module
gsnn.gsnn.external
module
gsnn.gsnn.external.mmd
module
gsnn.gsnn.interpret
module
gsnn.gsnn.interpret.ContrastiveGSNNExplainer
module
gsnn.gsnn.interpret.extract_entity_function
module
gsnn.gsnn.interpret.plot_explanation_graph
module
gsnn.gsnn.interpret.utils
module
gsnn.gsnn.models
module
gsnn.gsnn.models.ChannelEMANorm
module
gsnn.gsnn.models.GroupBatchNorm
module
gsnn.gsnn.models.GroupEMANorm
module
gsnn.gsnn.models.GroupLayerNorm
module
gsnn.gsnn.models.GroupRMSNorm
module
gsnn.gsnn.models.GSNN
module
gsnn.gsnn.models.NN
module
gsnn.gsnn.models.SoftmaxGroupNorm
module
gsnn.gsnn.models.SparseLinear
module
gsnn.gsnn.models.utils
module
gsnn.gsnn.optim
module
gsnn.gsnn.optim.EarlyStopper
module
gsnn.gsnn.optim.Environment
module
gsnn.gsnn.optim.GradDiagnostics
module
gsnn.gsnn.optim.OutputEdgeInferer
module
gsnn.gsnn.optim.REINFORCE
module
gsnn.gsnn.optim.RewardScaler
module
gsnn.gsnn.optim.utils
module
gsnn.gsnn.ot
module
gsnn.gsnn.ot.mmd
module
gsnn.gsnn.proc
module
gsnn.gsnn.proc.coarsen
module
gsnn.gsnn.proc.construct
module
gsnn.gsnn.proc.subset
module
gsnn.gsnn.simulate
module
gsnn.gsnn.simulate.datasets
module
gsnn.gsnn.simulate.graph_comparison
module
gsnn.gsnn.simulate.nx2pyg
module
gsnn.gsnn.simulate.utils
module
gsnn.interpret
module
gsnn.interpret.ContrastiveGSNNExplainer
module
gsnn.interpret.extract_entity_function
module
gsnn.interpret.plot_explanation_graph
module
gsnn.interpret.utils
module
gsnn.models
module
gsnn.models.ChannelEMANorm
module
gsnn.models.GroupBatchNorm
module
gsnn.models.GroupEMANorm
module
gsnn.models.GroupLayerNorm
module
gsnn.models.GroupRMSNorm
module
gsnn.models.GSNN
module
gsnn.models.NN
module
gsnn.models.SoftmaxGroupNorm
module
gsnn.models.SparseLinear
module
gsnn.models.utils
module
gsnn.optim
module
gsnn.optim.EarlyStopper
module
gsnn.optim.Environment
module
gsnn.optim.GradDiagnostics
module
gsnn.optim.OutputEdgeInferer
module
gsnn.optim.REINFORCE
module
gsnn.optim.RewardScaler
module
gsnn.optim.utils
module
gsnn.proc
module
gsnn.proc.coarsen
module
gsnn.proc.construct
module
gsnn.proc.subset
module
gsnn.simulate
module
gsnn.simulate.datasets
module
gsnn.simulate.graph_comparison
module
gsnn.simulate.nx2pyg
module
gsnn.simulate.utils
module
GSNNExplainer (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
GSNNNetworkConstructor (class in gsnn.gsnn.proc.construct)
(class in gsnn.proc.construct)
H
hetero2homo() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
I
IGExplainer (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
io_equivalence() (in module gsnn.gsnn.proc.coarsen)
(in module gsnn.proc.coarsen)
K
kaiming_normal() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)
kaiming_uniform() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)
L
log() (gsnn.gsnn.models.utils.TBLogger method)
(gsnn.models.utils.TBLogger method)
M
message() (gsnn.gsnn.models.GSNN.SignedMessagePassing method)
(gsnn.gsnn.models.SparseLinear.Conv method)
(gsnn.models.GSNN.SignedMessagePassing method)
(gsnn.models.SparseLinear.Conv method)
mmd_distance() (in module gsnn.gsnn.external.mmd)
(in module gsnn.gsnn.ot.mmd)
module
gsnn.gsnn
gsnn.gsnn.external
gsnn.gsnn.external.mmd
gsnn.gsnn.interpret
gsnn.gsnn.interpret.ContrastiveGSNNExplainer
gsnn.gsnn.interpret.extract_entity_function
gsnn.gsnn.interpret.plot_explanation_graph
gsnn.gsnn.interpret.utils
gsnn.gsnn.models
gsnn.gsnn.models.ChannelEMANorm
gsnn.gsnn.models.GroupBatchNorm
gsnn.gsnn.models.GroupEMANorm
gsnn.gsnn.models.GroupLayerNorm
gsnn.gsnn.models.GroupRMSNorm
gsnn.gsnn.models.GSNN
gsnn.gsnn.models.NN
gsnn.gsnn.models.SoftmaxGroupNorm
gsnn.gsnn.models.SparseLinear
gsnn.gsnn.models.utils
gsnn.gsnn.optim
gsnn.gsnn.optim.EarlyStopper
gsnn.gsnn.optim.Environment
gsnn.gsnn.optim.GradDiagnostics
gsnn.gsnn.optim.OutputEdgeInferer
gsnn.gsnn.optim.REINFORCE
gsnn.gsnn.optim.RewardScaler
gsnn.gsnn.optim.utils
gsnn.gsnn.ot
gsnn.gsnn.ot.mmd
gsnn.gsnn.proc
gsnn.gsnn.proc.coarsen
gsnn.gsnn.proc.construct
gsnn.gsnn.proc.subset
gsnn.gsnn.simulate
gsnn.gsnn.simulate.datasets
gsnn.gsnn.simulate.graph_comparison
gsnn.gsnn.simulate.nx2pyg
gsnn.gsnn.simulate.utils
gsnn.interpret
gsnn.interpret.ContrastiveGSNNExplainer
gsnn.interpret.extract_entity_function
gsnn.interpret.plot_explanation_graph
gsnn.interpret.utils
gsnn.models
gsnn.models.ChannelEMANorm
gsnn.models.GroupBatchNorm
gsnn.models.GroupEMANorm
gsnn.models.GroupLayerNorm
gsnn.models.GroupRMSNorm
gsnn.models.GSNN
gsnn.models.NN
gsnn.models.SoftmaxGroupNorm
gsnn.models.SparseLinear
gsnn.models.utils
gsnn.optim
gsnn.optim.EarlyStopper
gsnn.optim.Environment
gsnn.optim.GradDiagnostics
gsnn.optim.OutputEdgeInferer
gsnn.optim.REINFORCE
gsnn.optim.RewardScaler
gsnn.optim.utils
gsnn.proc
gsnn.proc.coarsen
gsnn.proc.construct
gsnn.proc.subset
gsnn.simulate
gsnn.simulate.datasets
gsnn.simulate.graph_comparison
gsnn.simulate.nx2pyg
gsnn.simulate.utils
N
neighborhood_preservation_score() (in module gsnn.gsnn.optim.utils)
(in module gsnn.optim.utils)
next_divisor() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
NN (class in gsnn.gsnn.models.NN)
(class in gsnn.models.NN)
node2edge() (in module gsnn.gsnn.models.GSNN)
(in module gsnn.models.GSNN)
NodeAttention (class in gsnn.gsnn.models.GSNN)
(class in gsnn.models.GSNN)
NodeMLP (class in gsnn.gsnn.models.GSNN)
(class in gsnn.models.GSNN)
NoiseTunnel (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
normal() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)
nx2pyg() (in module gsnn.gsnn.simulate.nx2pyg)
(in module gsnn.simulate.nx2pyg)
nx_to_pyro_model() (in module gsnn.gsnn.simulate.utils)
(in module gsnn.simulate.utils)
O
OcclusionExplainer (class in gsnn.gsnn.interpret)
,
[1]
(class in gsnn.interpret)
,
[1]
OutputEdgeInferer (class in gsnn.gsnn.optim.OutputEdgeInferer)
(class in gsnn.optim.OutputEdgeInferer)
P
plot_diagnostics() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.gsnn.optim.TrainingDiagnostics method)
,
[1]
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.TrainingDiagnostics method)
,
[1]
plot_edge_importance() (in module gsnn.gsnn.interpret.utils)
(in module gsnn.interpret.utils)
plot_explanation_graph() (in module gsnn.gsnn.interpret.plot_explanation_graph)
(in module gsnn.interpret.plot_explanation_graph)
plot_gradient_magnitude_by_layer() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
plot_gradient_ratio_heatmap() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
plot_hairball() (in module gsnn.gsnn.interpret.plot_explanation_graph)
(in module gsnn.interpret.plot_explanation_graph)
plot_node_importance() (in module gsnn.gsnn.interpret.utils)
(in module gsnn.interpret.utils)
plot_summary_statistics() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
plot_vanishing_over_time() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
predict_gnn() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
predict_gsnn() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
predict_nn() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
print_progress_() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.REINFORCE.REINFORCE method)
prob_of() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.REINFORCE.REINFORCE method)
prune() (gsnn.gsnn.models.GSNN.GSNN method)
(gsnn.gsnn.models.SparseLinear.SparseLinear method)
(gsnn.models.GSNN.GSNN method)
(gsnn.models.SparseLinear.SparseLinear method)
pyg2nx() (in module gsnn.gsnn.simulate.nx2pyg)
(in module gsnn.simulate.nx2pyg)
R
randomize() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
regress_out() (in module gsnn.gsnn.models.utils)
(in module gsnn.models.utils)
REINFORCE (class in gsnn.gsnn.optim.REINFORCE)
(class in gsnn.optim.REINFORCE)
ResBlock (class in gsnn.gsnn.models.GSNN)
(class in gsnn.models.GSNN)
reset() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.gsnn.optim.TrainingDiagnostics method)
,
[1]
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.TrainingDiagnostics method)
,
[1]
RewardScaler (class in gsnn.gsnn.optim.RewardScaler)
(class in gsnn.optim.RewardScaler)
root_mean_squared_picp_error() (in module gsnn.gsnn.optim.utils)
(in module gsnn.optim.utils)
run() (gsnn.gsnn.optim.Environment.Environment method)
(gsnn.optim.Environment.Environment method)
S
sample() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.REINFORCE.REINFORCE method)
scale() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.gsnn.optim.RewardScaler.RewardScaler method)
(gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.RewardScaler.RewardScaler method)
set_node_mask() (gsnn.gsnn.models.GSNN.ResBlock method)
(gsnn.models.GSNN.ResBlock method)
SignedMessagePassing (class in gsnn.gsnn.models.GSNN)
(class in gsnn.models.GSNN)
simulate() (in module gsnn.gsnn.simulate)
,
[1]
(in module gsnn.simulate)
,
[1]
simulate_10_in_25_func_10_out_cyclic() (in module gsnn.gsnn.simulate.datasets)
(in module gsnn.simulate.datasets)
simulate_3_in_3_out() (in module gsnn.gsnn.simulate.datasets)
(in module gsnn.simulate.datasets)
SoftmaxGroupNorm (class in gsnn.gsnn.models.SoftmaxGroupNorm)
(class in gsnn.models.SoftmaxGroupNorm)
SparseLinear (class in gsnn.gsnn.models.SparseLinear)
(class in gsnn.models.SparseLinear)
step() (gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.REINFORCE.REINFORCE method)
subset_graph() (in module gsnn.gsnn.proc.subset)
(in module gsnn.proc.subset)
T
TBLogger (class in gsnn.gsnn.models.utils)
(class in gsnn.models.utils)
train() (gsnn.gsnn.optim.Environment.Environment method)
(gsnn.optim.Environment.Environment method)
TrainingDiagnostics (class in gsnn.gsnn.optim)
,
[1]
(class in gsnn.optim)
,
[1]
tune() (gsnn.gsnn.interpret.ContrastiveGSNNExplainer.ContrastiveGSNNExplainer method)
(gsnn.gsnn.interpret.GSNNExplainer method)
,
[1]
(gsnn.interpret.ContrastiveGSNNExplainer.ContrastiveGSNNExplainer method)
(gsnn.interpret.GSNNExplainer method)
,
[1]
U
uniform() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)
update() (gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.gsnn.optim.RewardScaler.RewardScaler method)
(gsnn.gsnn.optim.TrainingDiagnostics method)
,
[1]
(gsnn.optim.GradDiagnostics.GradDiagnostics method)
(gsnn.optim.REINFORCE.REINFORCE method)
(gsnn.optim.RewardScaler.RewardScaler method)
(gsnn.optim.TrainingDiagnostics method)
,
[1]
V
validate() (gsnn.gsnn.optim.Environment.Environment method)
(gsnn.optim.Environment.Environment method)
X
xavier_normal() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)
xavier_uniform() (in module gsnn.gsnn.models.SparseLinear)
(in module gsnn.models.SparseLinear)