gsnn.simulate

Functions

simulate(G, n_train, n_test, input_nodes, ...)

Generate samples from a synthetic graph-structured data-generation process.

gsnn.simulate.simulate(G, n_train: int, n_test: int, input_nodes, output_nodes, *, noise_scale: float = 1.0, special_functions: Optional[Dict] = None, signed_edges: Optional[Dict[tuple, int]] = None)

Generate samples from a synthetic graph-structured data-generation process.

The function takes a directed NetworkX graph that represents causal relationships between input, function, and output nodes. It converts the graph into a Pyro probabilistic program (via gsnn.simulate.utils.nx_to_pyro_model) and then draws i.i.d. samples from that model.
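To make the idea of sampling from a graph-structured process concrete, here is a minimal standard-library sketch of ancestral sampling on a DAG. It is not the library's implementation, which goes through Pyro. The helper name sample_linear_gaussian, the unit edge weights, and the node names are assumptions made for illustration only.

```python
import random
from graphlib import TopologicalSorter


def sample_linear_gaussian(parents, n_samples, noise_scale=1.0, seed=0):
    """Draw i.i.d. samples from a toy linear-Gaussian model on a DAG.

    `parents` maps each node name to the list of its parent node names.
    Each node equals the sum of its parents plus Gaussian noise; the unit
    edge weights are an assumption of this sketch, not of gsnn.
    """
    rng = random.Random(seed)
    # TopologicalSorter takes node -> predecessors, so parents come first.
    order = list(TopologicalSorter(parents).static_order())
    samples = []
    for _ in range(n_samples):
        values = {}
        for node in order:
            mean = sum(values[p] for p in parents.get(node, []))
            values[node] = mean + rng.gauss(0.0, noise_scale)
        samples.append(values)
    return samples


# Hypothetical three-node chain: in0 -> f0 -> out0
parents = {"in0": [], "f0": ["in0"], "out0": ["f0"]}
samples = sample_linear_gaussian(parents, n_samples=5, noise_scale=0.0)

# Arrange the draws into rows whose columns follow the node lists.
input_nodes, output_nodes = ["in0"], ["out0"]
x = [[s[n] for n in input_nodes] for s in samples]
y = [[s[n] for n in output_nodes] for s in samples]
```

Each sample is drawn independently by visiting nodes in topological order, so every node sees its parents' values before its own value is generated.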

Parameters:
  • G (networkx.DiGraph) – Directed graph encoding the Bayesian network structure.

  • n_train (int) – Number of training instances to simulate.

  • n_test (int) – Number of test instances to simulate.

  • input_nodes (list[str]) – Ordered list of node names that are treated as inputs (observed variables).

  • output_nodes (list[str]) – Ordered list of node names that are treated as outputs (targets).

  • noise_scale (float, optional) – Standard deviation of the additive Gaussian noise term used for every conditional distribution that has no special function attached. Default: 1.0.

  • special_functions (dict[str, callable], optional) – Mapping from node name to a Python callable that overrides the default linear relationship for that node. Each callable must have the signature f(parent_values: list) -> Tensor, where parent_values is a list of the parent node values.
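As an illustration of the special_functions signature described above, the following sketch defines a callable that multiplies its parent values together. The function name product_of_parents and the node name "f0" are hypothetical. The code relies only on the `*` operator, so it should behave elementwise on tensors, but it is demonstrated here with plain floats.

```python
from functools import reduce
from operator import mul


def product_of_parents(parent_values: list):
    """Return the product of all parent values.

    Assumes the node has at least one parent; reduce() raises a
    TypeError on an empty list.
    """
    return reduce(mul, parent_values)


# Hypothetical mapping that could be passed as `special_functions`;
# "f0" stands in for a node name present in the graph.
special_functions = {"f0": product_of_parents}

result = product_of_parents([2.0, 3.0, 0.5])
```

A nonlinear override like this replaces the default linear relationship for the named node only; other nodes are unaffected.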

Shapes:
  • x_train: \((n_{\text{train}}, |\text{inputs}|)\)

  • y_train: \((n_{\text{train}}, |\text{outputs}|)\)

  • x_test: \((n_{\text{test}}, |\text{inputs}|)\)

  • y_test: \((n_{\text{test}}, |\text{outputs}|)\)

Returns:

(x_train, y_train, x_test, y_test), where each element is a dense NumPy array whose columns are ordered according to input_nodes (for x_train and x_test) or output_nodes (for y_train and y_test).

Return type:

Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
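The documented shapes can be turned into a quick sanity check on the returned tuple. The helper below, expected_shapes, is a hypothetical name and not part of gsnn. It only computes the shapes stated in the Shapes section from the call arguments.

```python
def expected_shapes(n_train, n_test, input_nodes, output_nodes):
    """Return the documented (rows, columns) shape of each returned array."""
    return {
        "x_train": (n_train, len(input_nodes)),
        "y_train": (n_train, len(output_nodes)),
        "x_test": (n_test, len(input_nodes)),
        "y_test": (n_test, len(output_nodes)),
    }


# Hypothetical node names used purely for illustration.
shapes = expected_shapes(
    n_train=100,
    n_test=20,
    input_nodes=["in0", "in1"],
    output_nodes=["out0"],
)
```

After a call to simulate, each entry can be compared against the matching array, for example shapes["x_train"] against x_train.shape.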

Modules

gsnn.simulate.datasets

gsnn.simulate.graph_comparison

Graph comparison utilities for evaluating shared dependencies between graphs.

gsnn.simulate.nx2pyg

gsnn.simulate.simulate(G, n_train, n_test, ...)

Generate samples from a synthetic graph-structured data-generation process.

gsnn.simulate.utils