# GNNSubNet.py
# Authors: Bastian Pfeifer <https://github.com/pievos101>, Marcus D. Bloice <https://github.com/mdbloice>
from urllib.parse import _NetlocResultMixinStr
import numpy as np
import random
#from scipy.sparse.extract import find
from scipy.sparse import find
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch.nn.modules import conv
from torch_geometric import data
from torch_geometric.data import DataLoader, Batch
from pathlib import Path
import copy
from tqdm import tqdm
import os
import requests
import pandas as pd
import io
#from collections.abc import Mapping
from torch_geometric.data.data import Data
from torch_geometric.loader import DataLoader
from .gnn_training_utils import check_if_graph_is_connected, pass_data_iteratively
from .dataset import generate, load_OMICS_dataset, convert_to_s2vgraph
from .gnn_explainer import GNNExplainer
from .graphcnn import GraphCNN
from .graphcheb import GraphCheb, ChebConvNet, test_model_acc, test_model
from .community_detection import find_communities
from .edge_importance import calc_edge_importance
from torch_geometric.nn.conv.cheb_conv import ChebConv
[docs]class GNNSubNet(object):
"""
The class GNNSubSet represents the main user API for the
GNN-SubNet package.
"""
def __init__(self, location=None, ppi=None, features=None, target=None, cutoff=950, normalize=True) -> None:
self.location = location
self.ppi = ppi
self.features = features
self.target = target
self.dataset = None
self.model_status = None
self.model = None
self.gene_names = None
self.accuracy = None
self.confusion_matrix = None
self.test_loss = None
# Flags for internal use (hidden from user)
self._explainer_run = False
if ppi == None:
return None
dataset, gene_names = load_OMICS_dataset(self.ppi, self.features, self.target, True, cutoff, normalize)
# Check whether graph is connected
check = check_if_graph_is_connected(dataset[0].edge_index)
print("Graph is connected ", check)
if check == False:
print("Calculate subgraph ...")
dataset, gene_names = load_OMICS_dataset(self.ppi, self.features, self.target, False, cutoff, normalize)
check = check_if_graph_is_connected(dataset[0].edge_index)
print("Graph is connected ", check)
#print('\n')
print('##################')
print("# DATASET LOADED #")
print('##################')
#print('\n')
self.dataset = dataset
self.true_class = None
self.gene_names = gene_names
self.s2v_test_dataset = None
self.edges = np.transpose(np.array(dataset[0].edge_index))
self.edge_mask = None
self.node_mask = None
self.node_mask_matrix = None
self.modules = None
self.module_importances = None
[docs] def summary(self):
"""
Print a summary for the GNNSubSet object's current state.
"""
print("")
print("Number of nodes:", len(self.dataset[0].x))
print("Number of edges:", self.edges.shape[0])
print("Number of modalities:",self.dataset[0].x.shape[1])
[docs] def train(self, epoch_nr = 20, method="graphcnn", learning_rate=0.01):
if method=="chebconv":
print("chebconv for training ...")
self.train_chebconv(epoch_nr = epoch_nr)
self.classifier="chebconv"
if method=="graphcnn":
print("graphcnn for training ...")
self.train_graphcnn(epoch_nr = epoch_nr, learning_rate=learning_rate)
self.classifier="graphcnn"
if method=="graphcheb":
print("graphcheb for training ...")
self.train_graphcheb(epoch_nr = epoch_nr)
self.classifier="graphcheb"
if method=="chebnet":
print("chebnet for training ...")
self.train_chebnet(epoch_nr = epoch_nr)
self.classifier="chebnet"
[docs] def explain(self, n_runs=1, classifier="graphcnn", communities=True):
if self.classifier=="chebconv":
self.explain_chebconv(n_runs=n_runs, communities=communities)
if self.classifier=="graphcnn":
self.explain_graphcnn(n_runs=n_runs, communities=communities)
if self.classifier=="graphcheb":
self.explain_graphcheb(n_runs=n_runs, communities=communities)
if self.classifier=="chebnet":
self.explain_graphcheb(n_runs=n_runs, communities=communities)
[docs] def predict(self, gnnsubnet_test, classifier="graphcnn"):
if self.classifier=="chebconv":
pred = self.predict_chebconv(gnnsubnet_test=gnnsubnet_test)
if self.classifier=="graphcnn":
pred = self.predict_graphcnn(gnnsubnet_test=gnnsubnet_test)
if self.classifier=="graphcheb":
pred = self.predict_graphcheb(gnnsubnet_test=gnnsubnet_test)
if self.classifier=="chebnet":
pred = self.predict_graphcheb(gnnsubnet_test=gnnsubnet_test)
pred = np.array(pred)
pred = pred.reshape(1, pred.size)
return pred
[docs] def train_chebnet(self, epoch_nr=25, shuffle=True, weights=False,
hidden_channels=10,
K=10,
layers_nr=1,
num_classes=2):
"""
---
"""
use_weights = False
dataset = self.dataset
gene_names = self.gene_names
graphs_class_0_list = []
graphs_class_1_list = []
for graph in dataset:
if graph.y.detach().cpu().numpy() == 0:
graphs_class_0_list.append(graph)
else:
graphs_class_1_list.append(graph)
graphs_class_0_len = len(graphs_class_0_list)
graphs_class_1_len = len(graphs_class_1_list)
print(f"Graphs class 0: {graphs_class_0_len}, Graphs class 1: {graphs_class_1_len}")
########################################################################################################################
# Downsampling of the class that contains more elements ===========================================================
# ########################################################################################################################
if graphs_class_0_len >= graphs_class_1_len:
random_graphs_class_0_list = random.sample(graphs_class_0_list, graphs_class_1_len)
balanced_dataset_list = graphs_class_1_list + random_graphs_class_0_list
if graphs_class_0_len < graphs_class_1_len:
random_graphs_class_1_list = random.sample(graphs_class_1_list, graphs_class_0_len)
balanced_dataset_list = graphs_class_0_list + random_graphs_class_1_list
# print(len(random_graphs_class_0_list))
# print(len(random_graphs_class_1_list))
random.shuffle(balanced_dataset_list)
print(f"Length of balanced dataset list: {len(balanced_dataset_list)}")
list_len = len(balanced_dataset_list)
# print(list_len)
train_set_len = int(list_len * 4 / 5)
train_dataset_list = balanced_dataset_list[:train_set_len]
test_dataset_list = balanced_dataset_list[train_set_len:]
train_graph_class_0_nr = 0
train_graph_class_1_nr = 0
for graph in train_dataset_list:
if graph.y.detach().cpu().numpy() == 0:
train_graph_class_0_nr += 1
else:
train_graph_class_1_nr += 1
print(f"Train graph class 0: {train_graph_class_0_nr}, train graph class 1: {train_graph_class_1_nr}")
test_graph_class_0_nr = 0
test_graph_class_1_nr = 0
for graph in test_dataset_list:
if graph.y.detach().cpu().numpy() == 0:
test_graph_class_0_nr += 1
else:
test_graph_class_1_nr += 1
print(f"Validation graph class 0: {test_graph_class_0_nr}, validation graph class 1: {test_graph_class_1_nr}")
# s2v_train_dataset = convert_to_s2vgraph(train_dataset_list)
# s2v_test_dataset = convert_to_s2vgraph(test_dataset_list)
s2v_train_dataset = train_dataset_list
s2v_test_dataset = test_dataset_list
model_path = 'omics_model.pth'
no_of_features = dataset[0].x.shape[1]
nodes_per_graph_nr = dataset[0].x.shape[0]
print("\tnodes_per_graph_nr", nodes_per_graph_nr)
input_dim = no_of_features
n_classes = 2
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ChebConvNet(input_channels=1, n_features=nodes_per_graph_nr, n_channels=2, n_classes=2, K=8, n_layers=1)
#print(model)
#model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-6)
# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9,
# last_epoch=-1)
criterion = torch.nn.CrossEntropyLoss()
model.train()
min_loss = 50
best_model = ChebConvNet(input_channels=1, n_features=nodes_per_graph_nr, n_channels=2, n_classes=2, K=8, n_layers=1)
#best_model.to(device)
# best_model = ChebConv(in_channels=1, out_channels=2, K=10)
min_val_loss = 1000000.0
n_epochs_stop = 25
epochs_no_improve = 0
batch_size = 100
train_loader = DataLoader(s2v_train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(s2v_test_dataset, batch_size=batch_size, shuffle=False)
for epoch in range(epoch_nr):
running_loss = 0.0
steps = 0
model.train()
# data_pbar_loader = tqdm(train_loader, unit='batch')
for data in train_loader:
out = model(x=data.x, edge_index=data.edge_index, batch=data.batch)
loss = criterion(out, data.y) # Compute the loss.
loss.backward() # Derive gradients.
optimizer.step() # Update parameters based on gradients.
optimizer.zero_grad() # Clear gradients.
running_loss += loss.item()
steps += 1
epoch_loss = running_loss / steps
model.eval()
acc_train = test_model_acc(train_loader, model)
val_loss, val_acc, _ , _ = test_model(test_loader, model, criterion)
print()
print(f'Epoch: {epoch:03d}, Train Loss: {epoch_loss:.4f}, Train Acc: {acc_train:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}', end='\t')
# print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
# print(f"Train Acc {acc_train:.4f}")
# data_pbar_loader.set_description('epoch: %d' % (epoch))
# val_loss = 0
#
# tr = DataLoader(s2v_test_dataset, batch_size=len(s2v_test_dataset), shuffle=False)
# for vv in tr:
# # print("\toutput test")
# output = model(vv.x, vv.edge_index, vv.batch)
# # print("\toutput", output)
#
# # output = pass_data_iteratively(model, s2v_test_dataset)
#
# pred = output.max(1, keepdim=True)[1]
# labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
# if use_weights:
# loss = nn.CrossEntropyLoss(weight=weight)(output, labels)
# else:
# loss = nn.CrossEntropyLoss()(output, labels)
# val_loss += loss
# print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))
if val_loss < min_val_loss and epoch > 2: # to go through at least 2 epochs
print(f"Saving best model with validation loss {val_loss:.4f}", end="")
best_model = copy.deepcopy(model)
epochs_no_improve = 0
min_val_loss = val_loss
else:
epochs_no_improve += 1
# Check early stopping condition
if epochs_no_improve == n_epochs_stop:
print('Early stopping!')
# model.load_state_dict(best_model.state_dict())
break
#
# confusion_array = []
# true_class_array = []
# predicted_class_array = []
# model.eval()
# correct = 0
# true_class_array = []
# predicted_class_array = []
# loading the parameters of the best model
model.load_state_dict(best_model.state_dict())
_, _, true_labels, predicted_labels = test_model(test_loader, model, criterion)
# test_loss = 0
#
# model.load_state_dict(best_model.state_dict())
#
# tr = DataLoader(s2v_test_dataset, batch_size=len(s2v_test_dataset), shuffle=False)
# for vv in tr:
# output = model(vv.x, vv.edge_index, vv.batch)
#
# # output = pass_data_iteratively(model, s2v_test_dataset)
# output = np.array(output.detach())
# predicted_class = output.argmax(1, keepdims=True)
#
# predicted_class = list(predicted_class)
#
# labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
# correct = torch.tensor(np.array(predicted_class)).eq(
# labels.view_as(torch.tensor(np.array(predicted_class)))).sum().item()
confusion_matrix_gnn = confusion_matrix(true_labels, predicted_labels)
print("\nConfusion matrix (Validation set):\n")
print(confusion_matrix_gnn)
from sklearn.metrics import balanced_accuracy_score
acc_bal = balanced_accuracy_score(true_labels, predicted_labels)
print("Validation balanced accuracy: {}".format(acc_bal))
model.train()
self.model_status = 'Trained'
self.model = copy.deepcopy(model)
self.accuracy = acc_bal
self.confusion_matrix = confusion_matrix_gnn
# self.test_loss = test_loss
self.s2v_test_dataset = s2v_test_dataset
self.predictions = predicted_labels
self.true_class = true_labels
[docs] def train_graphcheb(self, epoch_nr = 20, shuffle=True, weights=False,
hidden_channels=7,
K=5,
layers_nr=2,
num_classes=2):
"""
---
"""
use_weights = False
dataset = self.dataset
gene_names = self.gene_names
graphs_class_0_list = []
graphs_class_1_list = []
for graph in dataset:
if graph.y.numpy() == 0:
graphs_class_0_list.append(graph)
else:
graphs_class_1_list.append(graph)
graphs_class_0_len = len(graphs_class_0_list)
graphs_class_1_len = len(graphs_class_1_list)
print(f"Graphs class 0: {graphs_class_0_len}, Graphs class 1: {graphs_class_1_len}")
########################################################################################################################
# Downsampling of the class that contains more elements ===========================================================
# ########################################################################################################################
if graphs_class_0_len >= graphs_class_1_len:
random_graphs_class_0_list = random.sample(graphs_class_0_list, graphs_class_1_len)
balanced_dataset_list = graphs_class_1_list + random_graphs_class_0_list
if graphs_class_0_len < graphs_class_1_len:
random_graphs_class_1_list = random.sample(graphs_class_1_list, graphs_class_0_len)
balanced_dataset_list = graphs_class_0_list + random_graphs_class_1_list
#print(len(random_graphs_class_0_list))
#print(len(random_graphs_class_1_list))
random.shuffle(balanced_dataset_list)
print(f"Length of balanced dataset list: {len(balanced_dataset_list)}")
list_len = len(balanced_dataset_list)
#print(list_len)
train_set_len = int(list_len * 4 / 5)
train_dataset_list = balanced_dataset_list[:train_set_len]
test_dataset_list = balanced_dataset_list[train_set_len:]
train_graph_class_0_nr = 0
train_graph_class_1_nr = 0
for graph in train_dataset_list:
if graph.y.numpy() == 0:
train_graph_class_0_nr += 1
else:
train_graph_class_1_nr += 1
print(f"Train graph class 0: {train_graph_class_0_nr}, train graph class 1: {train_graph_class_1_nr}")
test_graph_class_0_nr = 0
test_graph_class_1_nr = 0
for graph in test_dataset_list:
if graph.y.numpy() == 0:
test_graph_class_0_nr += 1
else:
test_graph_class_1_nr += 1
print(f"Validation graph class 0: {test_graph_class_0_nr}, validation graph class 1: {test_graph_class_1_nr}")
#s2v_train_dataset = convert_to_s2vgraph(train_dataset_list)
#s2v_test_dataset = convert_to_s2vgraph(test_dataset_list)
s2v_train_dataset = train_dataset_list
s2v_test_dataset = test_dataset_list
model_path = 'omics_model.pth'
no_of_features = dataset[0].x.shape[1]
nodes_per_graph_nr = dataset[0].x.shape[0]
input_dim = no_of_features
n_classes = 2
model = GraphCheb(
num_node_features=input_dim,
hidden_channels=hidden_channels,
K=K,
layers_nr=layers_nr,
num_classes=2)
opt = torch.optim.Adam(model.parameters(), lr = 0.1)
load_model = False
if load_model:
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['state_dict'])
opt = checkpoint['optimizer']
model.train()
min_loss = 50
best_model = GraphCheb(
num_node_features=input_dim,
hidden_channels=hidden_channels,
K=K,
layers_nr=1,
num_classes=2)
min_val_loss = 1000000
n_epochs_stop = 10
epochs_no_improve = 0
steps_per_epoch = 35
for epoch in range(epoch_nr):
model.train()
pbar = tqdm(range(steps_per_epoch), unit='batch')
epoch_loss = 0
for pos in pbar:
selected_idx = np.random.permutation(len(s2v_train_dataset))[:32]
batch_graph_x = [s2v_train_dataset[idx] for idx in selected_idx]
batch_graph = DataLoader(batch_graph_x, batch_size=32, shuffle=False)
for batch_graph_y in batch_graph:
logits = model(batch_graph_y.x, batch_graph_y.edge_index, batch_graph_y.batch)
labels = torch.LongTensor([graph.y for graph in batch_graph_x])
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(logits,labels)
else:
loss = nn.CrossEntropyLoss()(logits,labels)
opt.zero_grad()
loss.backward()
opt.step()
epoch_loss += loss.detach().item()
epoch_loss /= steps_per_epoch
model.eval()
tr = DataLoader(s2v_train_dataset, batch_size=len(s2v_train_dataset), shuffle=False)
for vv in tr:
output = model(vv.x, vv.edge_index, vv.batch)
#output = pass_data_iteratively(model, s2v_train_dataset)
predicted_class = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.y for graph in s2v_train_dataset])
correct = predicted_class.eq(labels.view_as(predicted_class)).sum().item()
acc_train = correct / float(len(s2v_train_dataset))
print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
print(f"Train Acc {acc_train:.4f}")
pbar.set_description('epoch: %d' % (epoch))
val_loss = 0
tr = DataLoader(s2v_test_dataset, batch_size=len(s2v_test_dataset), shuffle=False)
for vv in tr:
output = model(vv.x, vv.edge_index, vv.batch)
#output = pass_data_iteratively(model, s2v_test_dataset)
pred = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(output,labels)
else:
loss = nn.CrossEntropyLoss()(output,labels)
val_loss += loss
print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))
if val_loss < min_val_loss:
print(f"Saving best model with validation loss {val_loss}")
best_model = copy.deepcopy(model)
epochs_no_improve = 0
min_val_loss = val_loss
#if acc_train > 0.75:
# opt = torch.optim.Adam(model.parameters(), lr = 0.01)
#if acc_train > 0.85:
# opt = torch.optim.Adam(model.parameters(), lr = 0.001)
else:
epochs_no_improve += 1
# Check early stopping condition
if epochs_no_improve == n_epochs_stop:
print('Early stopping!')
model.load_state_dict(best_model.state_dict())
break
confusion_array = []
true_class_array = []
predicted_class_array = []
model.eval()
correct = 0
true_class_array = []
predicted_class_array = []
test_loss = 0
model.load_state_dict(best_model.state_dict())
tr = DataLoader(s2v_test_dataset, batch_size=len(s2v_test_dataset), shuffle=False)
for vv in tr:
output = model(vv.x, vv.edge_index, vv.batch)
#output = pass_data_iteratively(model, s2v_test_dataset)
output = np.array(output.detach())
predicted_class = output.argmax(1, keepdims=True)
predicted_class = list(predicted_class)
labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
correct = torch.tensor(np.array(predicted_class)).eq(labels.view_as(torch.tensor(np.array(predicted_class)))).sum().item()
confusion_matrix_gnn = confusion_matrix(labels, predicted_class)
print("\nConfusion matrix (Validation set):\n")
print(confusion_matrix_gnn)
from sklearn.metrics import balanced_accuracy_score
acc_bal = balanced_accuracy_score(labels, predicted_class)
print("Validation accuracy: {}".format(acc_bal))
model.train()
self.model_status = 'Trained'
self.model = copy.deepcopy(model)
self.accuracy = acc_bal
self.confusion_matrix = confusion_matrix_gnn
#self.test_loss = test_loss
self.s2v_test_dataset = s2v_test_dataset
self.predictions = predicted_class_array
self.true_class = labels
[docs] def train_chebconv(self, epoch_nr = 20, shuffle=True, weights=False):
"""
Train the GNN model on the data provided during initialisation.
"""
use_weights = False
dataset = self.dataset
gene_names = self.gene_names
graphs_class_0_list = []
graphs_class_1_list = []
for graph in dataset:
if graph.y.numpy() == 0:
graphs_class_0_list.append(graph)
else:
graphs_class_1_list.append(graph)
graphs_class_0_len = len(graphs_class_0_list)
graphs_class_1_len = len(graphs_class_1_list)
print(f"Graphs class 0: {graphs_class_0_len}, Graphs class 1: {graphs_class_1_len}")
########################################################################################################################
# Downsampling of the class that contains more elements ===========================================================
# ########################################################################################################################
if graphs_class_0_len >= graphs_class_1_len:
random_graphs_class_0_list = random.sample(graphs_class_0_list, graphs_class_1_len)
balanced_dataset_list = graphs_class_1_list + random_graphs_class_0_list
if graphs_class_0_len < graphs_class_1_len:
random_graphs_class_1_list = random.sample(graphs_class_1_list, graphs_class_0_len)
balanced_dataset_list = graphs_class_0_list + random_graphs_class_1_list
random.shuffle(balanced_dataset_list)
print(f"Length of balanced dataset list: {len(balanced_dataset_list)}")
list_len = len(balanced_dataset_list)
#print(list_len)
train_set_len = int(list_len * 4 / 5)
train_dataset_list = balanced_dataset_list[:train_set_len]
test_dataset_list = balanced_dataset_list[train_set_len:]
train_graph_class_0_nr = 0
train_graph_class_1_nr = 0
for graph in train_dataset_list:
if graph.y.numpy() == 0:
train_graph_class_0_nr += 1
else:
train_graph_class_1_nr += 1
print(f"Train graph class 0: {train_graph_class_0_nr}, train graph class 1: {train_graph_class_1_nr}")
test_graph_class_0_nr = 0
test_graph_class_1_nr = 0
for graph in test_dataset_list:
if graph.y.numpy() == 0:
test_graph_class_0_nr += 1
else:
test_graph_class_1_nr += 1
print(f"Validation graph class 0: {test_graph_class_0_nr}, validation graph class 1: {test_graph_class_1_nr}")
# for ChebConv()
s2v_train_dataset = train_dataset_list
s2v_test_dataset = test_dataset_list
model_path = 'omics_model.pth'
no_of_features = dataset[0].x.shape[1]
nodes_per_graph_nr = dataset[0].x.shape[0]
input_dim = no_of_features
n_classes = 2
#model = GraphCNN(num_layers, num_mlp_layers, input_dim, 32, n_classes, 0.5, True, graph_pooling_type, neighbor_pooling_type, 0)
model = ChebConv(input_dim, n_classes, 10)
opt = torch.optim.Adam(model.parameters(), lr = 0.1)
load_model = False
if load_model:
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['state_dict'])
opt = checkpoint['optimizer']
model.train()
#min_loss = 50000
#best_model = GraphCNN(num_layers, num_mlp_layers, input_dim, 32, n_classes, 0.5, True, graph_pooling_type, neighbor_pooling_type, 0)
best_model = ChebConv(input_dim, n_classes, 10)
min_val_loss = 1000000
n_epochs_stop = 7
epochs_no_improve = 0
steps_per_epoch = 35
for epoch in range(epoch_nr):
model.train()
pbar = tqdm(range(steps_per_epoch), unit='batch')
epoch_loss = 0
for pos in pbar:
selected_idx = np.random.permutation(len(s2v_train_dataset))[:30]
batch_graph = [s2v_train_dataset[idx] for idx in selected_idx]
logits=[]
for g in batch_graph:
logits.append(model(x=g.x, edge_index=g.edge_index).max(0)[0])
#logits.append(model(x=g.x, edge_index=g.edge_index).mean(0))
logits = torch.reshape(torch.cat(logits,0),(30,2))
labels = torch.LongTensor([graph.y for graph in batch_graph])
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(logits,labels)
else:
loss = nn.CrossEntropyLoss()(logits,labels)
opt.zero_grad()
loss.backward()
opt.step()
epoch_loss += loss.detach().item()
epoch_loss /= steps_per_epoch
model.eval()
output = []
for graphs in s2v_train_dataset:
output.append(model(x=graphs.x, edge_index=graphs.edge_index).max(0)[0])
#output.append(model(x=graphs.x, edge_index=graphs.edge_index).mean(0))
output = torch.reshape(torch.cat(output,0),(len(output),2))
output = np.array(output.detach())
predicted_class = output.argmax(1, keepdims=True)
predicted_class = list(predicted_class)
labels = torch.LongTensor([graph.y for graph in s2v_train_dataset])
correct = torch.tensor(np.array(predicted_class)).eq(labels.view_as(torch.tensor(np.array(predicted_class)))).sum().item()
acc_train = correct / len(s2v_train_dataset)
print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
print(f"Train Acc {acc_train:.4f}")
pbar.set_description('epoch: %d' % (epoch))
val_loss = 0
output = []
for graphs in s2v_test_dataset:
output.append(model(x=graphs.x, edge_index=graphs.edge_index).max(0)[0])
#output.append(model(x=graphs.x, edge_index=graphs.edge_index).mean(0))
output = torch.reshape(torch.cat(output,0),(len(output),2))
labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(output,labels)
else:
loss = nn.CrossEntropyLoss()(output,labels)
val_loss += loss
print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))
if val_loss < min_val_loss:
print(f"Saving best model with validation loss {val_loss}")
best_model = copy.deepcopy(model)
epochs_no_improve = 0
min_val_loss = val_loss
else:
epochs_no_improve += 1
# Check early stopping condition
if epochs_no_improve == n_epochs_stop:
print('Early stopping!')
model.load_state_dict(best_model.state_dict())
break
confusion_array = []
true_class_array = []
predicted_class_array = []
model.eval()
correct = 0
true_class_array = []
predicted_class_array = []
test_loss = 0
model.load_state_dict(best_model.state_dict())
output = []
for graphs in s2v_test_dataset:
output.append(model(x=graphs.x, edge_index=graphs.edge_index).max(0)[0])
#output.append(model(x=graphs.x, edge_index=graphs.edge_index).mean(0))
output = torch.reshape(torch.cat(output,0),(len(output),2))
output = np.array(output.detach())
predicted_class = output.argmax(1, keepdims=True)
predicted_class = list(predicted_class)
labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
correct = torch.tensor(np.array(predicted_class)).eq(labels.view_as(torch.tensor(np.array(predicted_class)))).sum().item()
confusion_matrix_gnn = confusion_matrix(labels, predicted_class)
print("\nConfusion matrix (Validation set):\n")
print(confusion_matrix_gnn)
from sklearn.metrics import balanced_accuracy_score
acc_bal = balanced_accuracy_score(labels, predicted_class)
print("Validation accuracy: {}".format(acc_bal))
model.train()
self.model_status = 'Trained'
self.model = copy.deepcopy(model)
self.accuracy = acc_bal
self.confusion_matrix = confusion_matrix_gnn
#self.test_loss = test_loss
self.s2v_test_dataset = s2v_test_dataset
self.predictions = predicted_class_array
self.true_class = labels
#model = GraphCNN(5, 2, input_dim, 32, n_classes, 0.5, True, 'sum1', 'sum', 0)
[docs] def train_graphcnn(self, num_layers=2, num_mlp_layers=2, epoch_nr = 20, shuffle=True, weights=False, graph_pooling_type='sum1', neighbor_pooling_type ='sum', learning_rate=0.1):
"""
Train the GNN model on the data provided during initialisation.
num_layers: number of layers in the neural networks (INCLUDING the input layer)
num_mlp_layers: number of layers in mlps (EXCLUDING the input layer)
graph_pooling_type: how to aggregate entire nodes in a graph (mean, average)
neighbor_pooling_type: *sum*! how to aggregate neighbors (mean, average, or max)
"""
use_weights = False
dataset = self.dataset
gene_names = self.gene_names
graphs_class_0_list = []
graphs_class_1_list = []
for graph in dataset:
if graph.y.numpy() == 0:
graphs_class_0_list.append(graph)
else:
graphs_class_1_list.append(graph)
graphs_class_0_len = len(graphs_class_0_list)
graphs_class_1_len = len(graphs_class_1_list)
print(f"Graphs class 0: {graphs_class_0_len}, Graphs class 1: {graphs_class_1_len}")
########################################################################################################################
# Downsampling of the class that contains more elements ===========================================================
# ########################################################################################################################
if graphs_class_0_len >= graphs_class_1_len:
random_graphs_class_0_list = random.sample(graphs_class_0_list, graphs_class_1_len)
balanced_dataset_list = graphs_class_1_list + random_graphs_class_0_list
if graphs_class_0_len < graphs_class_1_len:
random_graphs_class_1_list = random.sample(graphs_class_1_list, graphs_class_0_len)
balanced_dataset_list = graphs_class_0_list + random_graphs_class_1_list
#print(len(random_graphs_class_0_list))
#print(len(random_graphs_class_1_list))
random.shuffle(balanced_dataset_list)
print(f"Length of balanced dataset list: {len(balanced_dataset_list)}")
list_len = len(balanced_dataset_list)
#print(list_len)
train_set_len = int(list_len * 4 / 5)
train_dataset_list = balanced_dataset_list[:train_set_len]
test_dataset_list = balanced_dataset_list[train_set_len:]
train_graph_class_0_nr = 0
train_graph_class_1_nr = 0
for graph in train_dataset_list:
if graph.y.numpy() == 0:
train_graph_class_0_nr += 1
else:
train_graph_class_1_nr += 1
print(f"Train graph class 0: {train_graph_class_0_nr}, train graph class 1: {train_graph_class_1_nr}")
test_graph_class_0_nr = 0
test_graph_class_1_nr = 0
for graph in test_dataset_list:
if graph.y.numpy() == 0:
test_graph_class_0_nr += 1
else:
test_graph_class_1_nr += 1
print(f"Validation graph class 0: {test_graph_class_0_nr}, validation graph class 1: {test_graph_class_1_nr}")
s2v_train_dataset = convert_to_s2vgraph(train_dataset_list)
s2v_test_dataset = convert_to_s2vgraph(test_dataset_list)
# TRAIN GNN -------------------------------------------------- #
#count = 0
#for item in dataset:
# count += item.y.item()
#weight = torch.tensor([count/len(dataset), 1-count/len(dataset)])
#print(count/len(dataset), 1-count/len(dataset))
model_path = 'omics_model.pth'
no_of_features = dataset[0].x.shape[1]
nodes_per_graph_nr = dataset[0].x.shape[0]
#print(len(dataset), len(dataset)*0.2)
#s2v_dataset = convert_to_s2vgraph(dataset)
#train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=123)
#s2v_train_dataset = convert_to_s2vgraph(train_dataset)
#s2v_test_dataset = convert_to_s2vgraph(test_dataset)
#s2v_train_dataset, s2v_test_dataset = train_test_split(s2v_dataset, test_size=0.2, random_state=123)
input_dim = no_of_features
n_classes = 2
model = GraphCNN(num_layers, num_mlp_layers, input_dim, 32, n_classes, 0.5, True, graph_pooling_type, neighbor_pooling_type, 0)
opt = torch.optim.Adam(model.parameters(), lr = learning_rate)
load_model = False
if load_model:
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['state_dict'])
opt = checkpoint['optimizer']
model.train()
min_loss = 50
best_model = GraphCNN(num_layers, num_mlp_layers, input_dim, 32, n_classes, 0.5, True, graph_pooling_type, neighbor_pooling_type, 0)
min_val_loss = 1000000
n_epochs_stop = 10
epochs_no_improve = 0
steps_per_epoch = 35
for epoch in range(epoch_nr):
model.train()
pbar = tqdm(range(steps_per_epoch), unit='batch')
epoch_loss = 0
for pos in pbar:
selected_idx = np.random.permutation(len(s2v_train_dataset))[:32]
batch_graph = [s2v_train_dataset[idx] for idx in selected_idx]
logits = model(batch_graph)
labels = torch.LongTensor([graph.label for graph in batch_graph])
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(logits,labels)
else:
loss = nn.CrossEntropyLoss()(logits,labels)
opt.zero_grad()
loss.backward()
opt.step()
epoch_loss += loss.detach().item()
epoch_loss /= steps_per_epoch
model.eval()
output = pass_data_iteratively(model, s2v_train_dataset)
predicted_class = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.label for graph in s2v_train_dataset])
correct = predicted_class.eq(labels.view_as(predicted_class)).sum().item()
acc_train = correct / float(len(s2v_train_dataset))
print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
print(f"Train Acc {acc_train:.4f}")
pbar.set_description('epoch: %d' % (epoch))
val_loss = 0
output = pass_data_iteratively(model, s2v_test_dataset)
pred = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.label for graph in s2v_test_dataset])
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(output,labels)
else:
loss = nn.CrossEntropyLoss()(output,labels)
val_loss += loss
print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))
if val_loss < min_val_loss:
print(f"Saving best model with validation loss {val_loss}")
best_model = copy.deepcopy(model)
epochs_no_improve = 0
min_val_loss = val_loss
else:
epochs_no_improve += 1
# Check early stopping condition
if epochs_no_improve == n_epochs_stop:
print('Early stopping!')
model.load_state_dict(best_model.state_dict())
break
confusion_array = []
true_class_array = []
predicted_class_array = []
model.eval()
correct = 0
true_class_array = []
predicted_class_array = []
test_loss = 0
model.load_state_dict(best_model.state_dict())
output = pass_data_iteratively(model, s2v_test_dataset)
predicted_class = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.label for graph in s2v_test_dataset])
correct = predicted_class.eq(labels.view_as(predicted_class)).sum().item()
acc_test = correct / float(len(s2v_test_dataset))
if use_weights:
loss = nn.CrossEntropyLoss(weight=weight)(output,labels)
else:
loss = nn.CrossEntropyLoss()(output,labels)
test_loss = loss
predicted_class_array = np.append(predicted_class_array, predicted_class)
true_class_array = np.append(true_class_array, labels)
confusion_matrix_gnn = confusion_matrix(true_class_array, predicted_class_array)
print("\nConfusion matrix (Validation set):\n")
print(confusion_matrix_gnn)
counter = 0
for it, i in zip(predicted_class_array, range(len(predicted_class_array))):
if it == true_class_array[i]:
counter += 1
accuracy = counter/len(true_class_array) * 100
print("Validation accuracy: {}%".format(accuracy))
print("Validation loss {}".format(test_loss))
checkpoint = {
'state_dict': best_model.state_dict(),
'optimizer': opt.state_dict()
}
torch.save(checkpoint, model_path)
model.train()
self.model_status = 'Trained'
self.model = copy.deepcopy(model)
self.accuracy = accuracy
self.confusion_matrix = confusion_matrix_gnn
self.test_loss = test_loss
self.s2v_test_dataset = s2v_test_dataset
self.predictions = predicted_class_array
self.true_class = true_class_array
[docs] def explain_graphcheb(self, n_runs=10, explainer_lambda=0.8, communities=True, save_to_disk=False):
"""
Explain the model's results.
"""
############################################
# Run the Explainer
############################################
LOC = self.location
model = self.model
s2v_test_dataset = self.s2v_test_dataset
dataset = self.dataset
gene_names = self.gene_names
print("")
print("------- Run the Explainer -------")
print("")
no_of_runs = n_runs
lamda = 0.8 # not used!
ems = []
NODE_MASK = list()
for idx in range(no_of_runs):
print(f'Explainer::Iteration {idx+1} of {no_of_runs}')
exp = GNNExplainer(model, epochs=300)
em = exp.explain_graph_modified_cheb2(s2v_test_dataset, lamda)
#Path(f"{path}/{sigma}/modified_gnn").mkdir(parents=True, exist_ok=True)
gnn_feature_masks = np.reshape(em, (len(em), -1))
NODE_MASK.append(np.array(gnn_feature_masks.sigmoid()))
np.savetxt(f'{LOC}/gnn_feature_masks{idx}.csv', gnn_feature_masks.sigmoid(), delimiter=',', fmt='%.3f')
#np.savetxt(f'{path}/{sigma}/modified_gnn/gnn_feature_masks{idx}.csv', gnn_feature_masks.sigmoid(), delimiter=',', fmt='%.3f')
gnn_edge_masks = calc_edge_importance(gnn_feature_masks, dataset[0].edge_index)
np.savetxt(f'{LOC}/gnn_edge_masks{idx}.csv', gnn_edge_masks.sigmoid(), delimiter=',', fmt='%.3f')
#np.savetxt(f'{path}/{sigma}/modified_gnn/gnn_edge_masks{idx}.csv', gnn_edge_masks.sigmoid(), delimiter=',', fmt='%.3f')
ems.append(gnn_edge_masks.sigmoid().numpy())
ems = np.array(ems)
mean_em = ems.mean(0)
# OUTPUT -- Save Edge Masks
np.savetxt(f'{LOC}/edge_masks.txt', mean_em, delimiter=',', fmt='%.5f')
self.edge_mask = mean_em
self.node_mask_matrix = np.concatenate(NODE_MASK,1)
self.node_mask = np.concatenate(NODE_MASK,1).mean(1)
self._explainer_run = True
###############################################
# Perform Community Detection
###############################################
if communities:
avg_mask, coms = find_communities(f'{LOC}/edge_index.txt', f'{LOC}/edge_masks.txt')
self.modules = coms
self.module_importances = avg_mask
np.savetxt(f'{LOC}/communities_scores.txt', avg_mask, delimiter=',', fmt='%.3f')
filePath = f'{LOC}/communities.txt'
if os.path.exists(filePath):
os.remove(filePath)
f = open(f'{LOC}/communities.txt', "a")
for idx in range(len(avg_mask)):
s_com = ','.join(str(e) for e in coms[idx])
f.write(s_com + '\n')
f.close()
# Write gene_names to file
textfile = open(f'{LOC}/gene_names.txt', "w")
for element in gene_names:
listToStr = ''.join(map(str, element))
textfile.write(listToStr + "\n")
textfile.close()
self._explainer_run = True
[docs] def explain_chebconv(self, n_runs=10, explainer_lambda=0.8, communities=True, save_to_disk=False):
"""
Explain the model's results.
"""
############################################
# Run the Explainer
############################################
LOC = self.location
model = self.model
s2v_test_dataset = self.s2v_test_dataset
dataset = self.dataset
gene_names = self.gene_names
print("")
print("------- Run the Explainer -------")
print("")
no_of_runs = n_runs
lamda = 0.8 # not used!
ems = []
NODE_MASK = list()
for idx in range(no_of_runs):
print(f'Explainer::Iteration {idx+1} of {no_of_runs}')
exp = GNNExplainer(model, epochs=300)
em = exp.explain_graph_modified_cheb(s2v_test_dataset, lamda)
#Path(f"{path}/{sigma}/modified_gnn").mkdir(parents=True, exist_ok=True)
gnn_feature_masks = np.reshape(em, (len(em), -1))
NODE_MASK.append(np.array(gnn_feature_masks.sigmoid()))
np.savetxt(f'{LOC}/gnn_feature_masks{idx}.csv', gnn_feature_masks.sigmoid(), delimiter=',', fmt='%.3f')
#np.savetxt(f'{path}/{sigma}/modified_gnn/gnn_feature_masks{idx}.csv', gnn_feature_masks.sigmoid(), delimiter=',', fmt='%.3f')
gnn_edge_masks = calc_edge_importance(gnn_feature_masks, dataset[0].edge_index)
np.savetxt(f'{LOC}/gnn_edge_masks{idx}.csv', gnn_edge_masks.sigmoid(), delimiter=',', fmt='%.3f')
#np.savetxt(f'{path}/{sigma}/modified_gnn/gnn_edge_masks{idx}.csv', gnn_edge_masks.sigmoid(), delimiter=',', fmt='%.3f')
ems.append(gnn_edge_masks.sigmoid().numpy())
ems = np.array(ems)
mean_em = ems.mean(0)
# OUTPUT -- Save Edge Masks
np.savetxt(f'{LOC}/edge_masks.txt', mean_em, delimiter=',', fmt='%.5f')
self.edge_mask = mean_em
self.node_mask_matrix = np.concatenate(NODE_MASK,1)
self.node_mask = np.concatenate(NODE_MASK,1).mean(1)
self._explainer_run = True
###############################################
# Perform Community Detection
###############################################
if communities:
avg_mask, coms = find_communities(f'{LOC}/edge_index.txt', f'{LOC}/edge_masks.txt')
self.modules = coms
self.module_importances = avg_mask
np.savetxt(f'{LOC}/communities_scores.txt', avg_mask, delimiter=',', fmt='%.3f')
filePath = f'{LOC}/communities.txt'
if os.path.exists(filePath):
os.remove(filePath)
f = open(f'{LOC}/communities.txt', "a")
for idx in range(len(avg_mask)):
s_com = ','.join(str(e) for e in coms[idx])
f.write(s_com + '\n')
f.close()
# Write gene_names to file
textfile = open(f'{LOC}/gene_names.txt', "w")
for element in gene_names:
listToStr = ''.join(map(str, element))
textfile.write(listToStr + "\n")
textfile.close()
self._explainer_run = True
[docs] def explain_graphcnn(self, n_runs=10, explainer_lambda=0.8, communities=True, save_to_disk=False):
"""
Explain the model's results.
"""
############################################
# Run the Explainer
############################################
LOC = self.location
model = self.model
s2v_test_dataset = self.s2v_test_dataset
dataset = self.dataset
gene_names = self.gene_names
print("")
print("------- Run the Explainer -------")
print("")
no_of_runs = n_runs
lamda = 0.8 # not used!
ems = []
NODE_MASK = list()
for idx in range(no_of_runs):
print(f'Explainer::Iteration {idx+1} of {no_of_runs}')
exp = GNNExplainer(model, epochs=300)
em = exp.explain_graph_modified_s2v(s2v_test_dataset, lamda)
#Path(f"{path}/{sigma}/modified_gnn").mkdir(parents=True, exist_ok=True)
gnn_feature_masks = np.reshape(em, (len(em), -1))
NODE_MASK.append(np.array(gnn_feature_masks.sigmoid()))
np.savetxt(f'{LOC}/gnn_feature_masks{idx}.csv', gnn_feature_masks.sigmoid(), delimiter=',', fmt='%.3f')
#np.savetxt(f'{path}/{sigma}/modified_gnn/gnn_feature_masks{idx}.csv', gnn_feature_masks.sigmoid(), delimiter=',', fmt='%.3f')
gnn_edge_masks = calc_edge_importance(gnn_feature_masks, dataset[0].edge_index)
np.savetxt(f'{LOC}/gnn_edge_masks{idx}.csv', gnn_edge_masks.sigmoid(), delimiter=',', fmt='%.3f')
#np.savetxt(f'{path}/{sigma}/modified_gnn/gnn_edge_masks{idx}.csv', gnn_edge_masks.sigmoid(), delimiter=',', fmt='%.3f')
ems.append(gnn_edge_masks.sigmoid().numpy())
ems = np.array(ems)
mean_em = ems.mean(0)
# OUTPUT -- Save Edge Masks
np.savetxt(f'{LOC}/edge_masks.txt', mean_em, delimiter=',', fmt='%.5f')
self.edge_mask = mean_em
self.node_mask_matrix = np.concatenate(NODE_MASK,1)
self.node_mask = np.concatenate(NODE_MASK,1).mean(1)
self._explainer_run = True
###############################################
# Perform Community Detection
###############################################
if communities:
avg_mask, coms = find_communities(f'{LOC}/edge_index.txt', f'{LOC}/edge_masks.txt')
self.modules = coms
self.module_importances = avg_mask
np.savetxt(f'{LOC}/communities_scores.txt', avg_mask, delimiter=',', fmt='%.3f')
filePath = f'{LOC}/communities.txt'
if os.path.exists(filePath):
os.remove(filePath)
f = open(f'{LOC}/communities.txt', "a")
for idx in range(len(avg_mask)):
s_com = ','.join(str(e) for e in coms[idx])
f.write(s_com + '\n')
f.close()
# Write gene_names to file
textfile = open(f'{LOC}/gene_names.txt', "w")
for element in gene_names:
listToStr = ''.join(map(str, element))
textfile.write(listToStr + "\n")
textfile.close()
self._explainer_run = True
[docs] def predict_graphcheb(self, gnnsubnet_test):
confusion_array = []
true_class_array = []
predicted_class_array = []
s2v_test_dataset = gnnsubnet_test.dataset
model = self.model
model.eval()
tr = DataLoader(s2v_test_dataset, batch_size=len(s2v_test_dataset), shuffle=False)
for vv in tr:
output = model(vv.x, vv.edge_index, vv.batch)
output = np.array(output.detach())
predicted_class = output.argmax(1, keepdims=True)
predicted_class = list(predicted_class)
labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
correct = torch.tensor(np.array(predicted_class)).eq(labels.view_as(torch.tensor(np.array(predicted_class)))).sum().item()
confusion_matrix_gnn = confusion_matrix(labels, predicted_class)
print("\nConfusion matrix (Validation set):\n")
print(confusion_matrix_gnn)
from sklearn.metrics import balanced_accuracy_score
acc_bal = balanced_accuracy_score(labels, predicted_class)
print("Validation accuracy: {}".format(acc_bal))
self.predictions_test = predicted_class
self.true_class_test = labels
self.accuracy_test = acc_bal
self.confusion_matrix_test = confusion_matrix_gnn
return predicted_class
[docs] def predict_chebconv(self, gnnsubnet_test):
confusion_array = []
true_class_array = []
predicted_class_array = []
s2v_test_dataset = gnnsubnet_test.dataset
model = self.model
model.eval()
output = []
for graphs in s2v_test_dataset:
output.append(model(x=graphs.x, edge_index=graphs.edge_index).max(0)[0])
#output.append(model(x=graphs.x, edge_index=graphs.edge_index).mean(0))
output = torch.reshape(torch.cat(output,0),(len(output),2))
output = np.array(output.detach())
predicted_class = output.argmax(1, keepdims=True)
predicted_class = list(predicted_class)
labels = torch.LongTensor([graph.y for graph in s2v_test_dataset])
correct = torch.tensor(np.array(predicted_class)).eq(labels.view_as(torch.tensor(np.array(predicted_class)))).sum().item()
confusion_matrix_gnn = confusion_matrix(labels, predicted_class)
print("\nConfusion matrix (Validation set):\n")
print(confusion_matrix_gnn)
from sklearn.metrics import balanced_accuracy_score
acc_bal = balanced_accuracy_score(labels, predicted_class)
print("Validation accuracy: {}".format(acc_bal))
self.predictions_test = predicted_class
self.true_class_test = labels
self.accuracy_test = acc_bal
self.confusion_matrix_test = confusion_matrix_gnn
return predicted_class
[docs] def predict_graphcnn(self, gnnsubnet_test):
confusion_array = []
true_class_array = []
predicted_class_array = []
s2v_test_dataset = convert_to_s2vgraph(gnnsubnet_test.dataset)
model = self.model
model.eval()
output = pass_data_iteratively(model, s2v_test_dataset)
predicted_class = output.max(1, keepdim=True)[1]
labels = torch.LongTensor([graph.label for graph in s2v_test_dataset])
correct = predicted_class.eq(labels.view_as(predicted_class)).sum().item()
acc_test = correct / float(len(s2v_test_dataset))
#if use_weights:
# loss = nn.CrossEntropyLoss(weight=weight)(output,labels)
#else:
# loss = nn.CrossEntropyLoss()(output,labels)
#test_loss = loss
predicted_class_array = np.append(predicted_class_array, predicted_class)
true_class_array = np.append(true_class_array, labels)
confusion_matrix_gnn = confusion_matrix(true_class_array, predicted_class_array)
print("\nConfusion matrix:\n")
print(confusion_matrix_gnn)
counter = 0
for it, i in zip(predicted_class_array, range(len(predicted_class_array))):
if it == true_class_array[i]:
counter += 1
accuracy = counter/len(true_class_array) * 100
print("Accuracy: {}%".format(accuracy))
self.predictions_test = predicted_class_array
self.true_class_test = true_class_array
self.accuracy_test = accuracy
self.confusion_matrix_test = confusion_matrix_gnn
return predicted_class_array
[docs] def download_TCGA(self, save_to_disk=False) -> None:
"""
Warning: Currently not implemented!
Download some sample TCGA data. Running this function will download
approximately 100MB of data.
"""
base_url = 'https://raw.githubusercontent.com/pievos101/GNN-SubNet/python-package/TCGA/' # CHANGE THIS URL WHEN BRANCH MERGES TO MAIN
KIDNEY_RANDOM_Methy_FEATURES_filename = 'KIDNEY_RANDOM_Methy_FEATURES.txt'
KIDNEY_RANDOM_PPI_filename = 'KIDNEY_RANDOM_PPI.txt'
KIDNEY_RANDOM_TARGET_filename = 'KIDNEY_RANDOM_TARGET.txt'
KIDNEY_RANDOM_mRNA_FEATURES_filename = 'KIDNEY_RANDOM_mRNA_FEATURES.txt'
# For testing let's use KIDNEY_RANDOM_Methy_FEATURES and store in memory.
raw = requests.get(base_url + KIDNEY_RANDOM_Methy_FEATURES_filename, stream=True)
self.KIDNEY_RANDOM_Methy_FEATURES = np.asarray(pd.read_csv(io.BytesIO(raw.content), delimiter=' '))
# Clear some memory
raw = None
return None