# Sweep network depth and record memory usage and runtime of GSNN with and
# without checkpointing (both variants use `share_layers=False`).
# Depends on names defined elsewhere in this file/notebook: `np`, `GSNN`,
# `data`, `kwargs`, `device`, `memory_usage`, `time_usage`, `x_train`.
res = {'layers': [], 'mem_no_ckpt': [], 'time_no_ckpt': [], 'mem_ckpt': [], 'time_ckpt': []}


def _build_gsnn(layers, checkpoint):
    """Build a GSNN with `layers` layers and `share_layers=False`, moved to `device`.

    Both benchmark variants are built here so they differ only in the
    `checkpoint` flag.
    """
    return GSNN(data.edge_index_dict,
                data.node_names_dict,
                share_layers=False,
                checkpoint=checkpoint,
                layers=layers,
                **kwargs).to(device)


layer_grid = np.linspace(5, 100, 10)
for step, layers in enumerate(layer_grid):
    # Show the share of the sweep already completed, not the layer count.
    print(f'progress: {100 * step / len(layer_grid):.2f}%', end='\r')
    layers = int(layers)  # truncates the float grid point to a whole layer count
    model_no_ckpt = _build_gsnn(layers, checkpoint=False)
    model_ckpt = _build_gsnn(layers, checkpoint=True)
    # NOTE(review): both models are alive while either one is measured. If
    # memory_usage reports device-wide usage, this may inflate both numbers.
    # TODO: confirm against memory_usage's implementation.
    res['layers'].append(layers)
    res['mem_no_ckpt'].append(memory_usage(model_no_ckpt, x_train))
    res['time_no_ckpt'].append(time_usage(model_no_ckpt, x_train))
    res['mem_ckpt'].append(memory_usage(model_ckpt, x_train))
    res['time_ckpt'].append(time_usage(model_ckpt, x_train))
# Tabulate the sweep results, then plot them in two side-by-side panels:
# memory on the left, runtime on the right. Each panel overlays the
# non-checkpointed (red) and checkpointed (blue) variants against depth.
res = pd.DataFrame(res)
f, axes = plt.subplots(1, 2, figsize=(6, 3))

variants = (
    ('no_ckpt', 'red', 'no checkpointing'),
    ('ckpt', 'blue', 'checkpointing'),
)
panels = (
    (axes[0], 'mem', 'memory usage (MB)'),
    (axes[1], 'time', 'time (s)'),
)

for panel_ax, metric, y_title in panels:
    for suffix, point_color, legend_text in variants:
        sbn.scatterplot(data=res, x='layers', y=f'{metric}_{suffix}',
                        color=point_color, label=legend_text, ax=panel_ax)
    # Explicit titles replace seaborn's default column-name axis labels.
    panel_ax.set_ylabel(y_title)
    panel_ax.set_xlabel('layers')

plt.tight_layout()
plt.show()
# Mean per-depth relative change (in %) of the checkpointed model versus the
# non-checkpointed one. Negative values mean checkpointing is lower.
mem_percent_change = ((res.mem_ckpt - res.mem_no_ckpt) / res.mem_no_ckpt * 100).mean()
time_percent_change = ((res.time_ckpt - res.time_no_ckpt) / res.time_no_ckpt * 100).mean()

print('with `share_layers=False`:')
# Print the magnitude and let the sign pick the word. Printing the signed
# value next to a fixed "decrease" would turn a saving into "-40.00% decrease".
mem_direction = 'decrease' if mem_percent_change < 0 else 'increase'
time_direction = 'decrease' if time_percent_change < 0 else 'increase'
print(f'\tusing checkpointing on average has a {abs(mem_percent_change):.2f}% {mem_direction} in memory usage')
print(f'\tusing checkpointing on average has a {abs(time_percent_change):.2f}% {time_direction} in runtime')