[1]:
from matplotlib import pyplot as plt
import seaborn as sbn
import numpy as np
import torch
import pandas as pd

from gsnn.models.GSNN import GSNN
from gsnn.simulate.nx2pyg import nx2pyg
from gsnn.simulate.datasets import simulate_3_in_3_out
import time

# for reproducibility
torch.manual_seed(0)
np.random.seed(0)

%load_ext autoreload
%autoreload 2

import torch
/home/teddy/miniconda3/envs/gsnn-lib/lib/python3.12/site-packages/torch_geometric/typing.py:124: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /home/teddy/miniconda3/envs/gsnn-lib/lib/python3.12/site-packages/torch_sparse/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSs
  warnings.warn(f"An issue occurred while importing 'torch-sparse'. "

Gradient checkpointing and compiling

[2]:
# Toy problem: simulated 3-input / 3-output network, converted to the PyG-style
# structures GSNN consumes.
(G, pos,
 x_train, x_test, y_train, y_test,
 input_nodes, function_nodes, output_nodes) = simulate_3_in_3_out(n_train=10,
                                                                  n_test=1,
                                                                  noise_scale=0.1)
data = nx2pyg(G, input_nodes, function_nodes, output_nodes)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters shared by every GSNN built in the benchmarks below.
kwargs = dict(channels=10,
              bias=True,
              add_function_self_edges=True,
              norm='none',
              residual=True,
              dropout=0.)
[3]:
def memory_usage(model, x_train):
    torch.cuda.reset_peak_memory_stats()
    model(x_train.to(device))
    return torch.cuda.max_memory_allocated() / 1e6

def time_usage(model, x_train, n=50):
    times = []
    for i in range(n):
        start = time.time()
        model(x_train.to(device))
        times.append(time.time() - start)
    return np.mean(times)

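Checkpointing with `share_layers=False`

For increasing network depth, we compare peak GPU memory and mean forward-pass time of GSNN models built with and without `checkpoint=True`.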

[4]:

# Sweep network depth and compare a plain GSNN against one built with
# checkpoint=True. Each model is built, measured and freed before the next one
# is built, so neither model's parameters inflate the other's peak-memory reading.
# NOTE(review): only forward passes are timed; if checkpoint=True wraps layers in
# torch.utils.checkpoint, its recomputation cost lands in backward and is not
# captured here - confirm against the GSNN implementation.
share_layers = False
layer_grid = [int(l) for l in np.linspace(5, 100, 10)]

records = []
for step, layers in enumerate(layer_grid, start=1):
    # progress through the sweep (fraction of depths done), not the layer count
    print(f'progress: {step / len(layer_grid) * 100:.2f}%', end='\r')
    row = {'layers': layers}
    for suffix, checkpoint in (('no_ckpt', False), ('ckpt', True)):
        model = GSNN(data.edge_index_dict, data.node_names_dict,
                     share_layers=share_layers, checkpoint=checkpoint,
                     layers=layers, **kwargs).to(device)
        row[f'mem_{suffix}'] = memory_usage(model, x_train)
        row[f'time_{suffix}'] = time_usage(model, x_train)
        del model  # release this model's parameters before the next is measured
    records.append(row)
print()

res = pd.DataFrame(records, columns=['layers', 'mem_no_ckpt', 'time_no_ckpt', 'mem_ckpt', 'time_ckpt'])

f, axes = plt.subplots(1, 2, figsize=(6, 3))
for metric, ax, ylabel in (('mem', axes[0], 'memory usage (MB)'),
                           ('time', axes[1], 'time (s)')):
    sbn.scatterplot(data=res, x='layers', y=f'{metric}_no_ckpt', color='red', label='no checkpointing', ax=ax)
    sbn.scatterplot(data=res, x='layers', y=f'{metric}_ckpt', color='blue', label='checkpointing', ax=ax)
    ax.set_xlabel('layers')
    ax.set_ylabel(ylabel)
f.suptitle(f'share_layers={share_layers}')
plt.tight_layout()
plt.show()

# Mean per-depth relative change from enabling checkpointing (negative = reduction).
mem_percent_change = ((res.mem_ckpt - res.mem_no_ckpt) / res.mem_no_ckpt * 100).mean()
time_percent_change = ((res.time_ckpt - res.time_no_ckpt) / res.time_no_ckpt * 100).mean()
print(f'with `share_layers={share_layers}`:')
# Signed values avoid the "-54.80% decrease" double negative.
print(f'\tusing checkpointing changes memory usage by {mem_percent_change:+.2f}% on average')
print(f'\tusing checkpointing changes runtime by {time_percent_change:+.2f}% on average')
progress: 100.00%
../_images/tutorials_06_checkpointing_and_compiling_5_1.png
with `share_layers=False`:
        using checkpointing on average has a -54.80% decrease in memory usage
        using checkpointing on average has a 36.12% increase in runtime

Checkpointing with `share_layers=True`

The same depth sweep as above, repeated with `share_layers=True`.

[5]:

# Same depth sweep as the share_layers=False benchmark, with share_layers=True.
# Each model is built, measured and freed before the next one is built, so
# neither model's parameters inflate the other's peak-memory reading.
# NOTE(review): only forward passes are timed; if checkpoint=True wraps layers in
# torch.utils.checkpoint, its recomputation cost lands in backward and is not
# captured here - confirm against the GSNN implementation.
share_layers = True
layer_grid = [int(l) for l in np.linspace(5, 100, 10)]

records = []
for step, layers in enumerate(layer_grid, start=1):
    # progress through the sweep (fraction of depths done), not the layer count
    print(f'progress: {step / len(layer_grid) * 100:.2f}%', end='\r')
    row = {'layers': layers}
    for suffix, checkpoint in (('no_ckpt', False), ('ckpt', True)):
        model = GSNN(data.edge_index_dict, data.node_names_dict,
                     share_layers=share_layers, checkpoint=checkpoint,
                     layers=layers, **kwargs).to(device)
        row[f'mem_{suffix}'] = memory_usage(model, x_train)
        row[f'time_{suffix}'] = time_usage(model, x_train)
        del model  # release this model's parameters before the next is measured
    records.append(row)
print()

res = pd.DataFrame(records, columns=['layers', 'mem_no_ckpt', 'time_no_ckpt', 'mem_ckpt', 'time_ckpt'])

f, axes = plt.subplots(1, 2, figsize=(6, 3))
for metric, ax, ylabel in (('mem', axes[0], 'memory usage (MB)'),
                           ('time', axes[1], 'time (s)')):
    sbn.scatterplot(data=res, x='layers', y=f'{metric}_no_ckpt', color='red', label='no checkpointing', ax=ax)
    sbn.scatterplot(data=res, x='layers', y=f'{metric}_ckpt', color='blue', label='checkpointing', ax=ax)
    ax.set_xlabel('layers')
    ax.set_ylabel(ylabel)
f.suptitle(f'share_layers={share_layers}')
plt.tight_layout()
plt.show()

# Mean per-depth relative change from enabling checkpointing (negative = reduction).
mem_percent_change = ((res.mem_ckpt - res.mem_no_ckpt) / res.mem_no_ckpt * 100).mean()
time_percent_change = ((res.time_ckpt - res.time_no_ckpt) / res.time_no_ckpt * 100).mean()
print(f'with `share_layers={share_layers}`:')
# Signed values avoid the "-77.51% decrease" double negative.
print(f'\tusing checkpointing changes memory usage by {mem_percent_change:+.2f}% on average')
print(f'\tusing checkpointing changes runtime by {time_percent_change:+.2f}% on average')
progress: 100.00%
../_images/tutorials_06_checkpointing_and_compiling_7_1.png
with `share_layers=True`:
        using checkpointing on average has a -77.51% decrease in memory usage
        using checkpointing on average has a 36.19% increase in runtime