Commit 29235f10 authored by Carlos GO's avatar Carlos GO
Browse files

tools dir

parent b83befd4
import os, sys
import pickle
import networkx as nx
import matplotlib
matplotlib.rcParams['text.usetex'] = True
import matplotlib.pyplot as plt
import seaborn as sns
# When run as a script from this directory, make the repo root importable
# so the `tools` package resolves.
if __name__ == "__main__":
    sys.path.append("..")
from tools.rna_layout import circular_layout

# LaTeX config: fdsymbol provides the Leontis-Westhof glyphs used in labels.
# NOTE(review): matplotlib >= 3.3 expects 'text.latex.preamble' to be a str,
# not a list — confirm against the pinned matplotlib version.
params = {'text.latex.preamble': [r'\usepackage{fdsymbol}\usepackage{xspace}']}
plt.rc('font', family='serif')
plt.rcParams.update(params)
# Leontis-Westhof edge symbols: first char C/T = cis/trans, then the two
# interacting edges (W = Watson-Crick, S = sugar, H = Hoogsteen).
labels = {
    'CW': r"$\medblackcircle$\xspace",
    'CS': r"$\medblacktriangleright$\xspace",
    'CH': r"$\medblacksquare$\xspace",
    'TW': r"$\medcircle$\xspace",
    'TS': r"$\medtriangleright$\xspace",
    'TH': r"$\medsquare$\xspace"
}


def make_label(s):
    """
    Map an LW label such as 'CWH' to its plot symbol(s).

    Asymmetric pairs (two distinct interacting edges) get both glyphs,
    symmetric pairs a single glyph. Replaces the original assigned lambda
    (PEP 8 E731) with the same behavior.

    :param s: three-character LW label, e.g. 'CWW' or 'TWH'.
    :return: LaTeX symbol string for the edge label.
    :raises KeyError: if the two-character prefix is not a known LW code.
    """
    if len(set(s[1:])) == 2:
        # e.g. 'CWH' -> glyph for 'CW' followed by glyph for 'CH'
        return labels[s[:2]] + labels[s[0::2]]
    return labels[s[:2]]
def rna_draw(nx_g, title="", highlight_edges=None, nt_info=False, node_colors=None, num_clusters=None):
    """
    Draw an RNA graph with the Leontis-Westhof edge symbols.

    :param nx_g: networkx graph with a 'label' attribute on every edge.
    :param title: unused at present; kept for API compatibility.
    :param highlight_edges: optional list of edges drawn with a wide yellow overlay.
    :param nt_info: if True, also draw node labels.
    :param node_colors: optional node colors (defaults to white).
    :param num_clusters: unused; kept for API compatibility.
    :return: None; shows the figure.
    """
    pos = nx.spring_layout(nx_g)

    node_color = 'white' if node_colors is None else node_colors
    nodes = nx.draw_networkx_nodes(nx_g, pos, node_size=150, node_color=node_color, linewidths=2)
    nodes.set_edgecolor('black')
    if nt_info:
        nx.draw_networkx_labels(nx_g, pos, font_color='black')

    edge_labels = {}
    for n1, n2, d in nx_g.edges(data=True):
        try:
            edge_labels[(n1, n2)] = make_label(d['label'])
        # Was a bare `except:`; make_label fails with KeyError for non-LW
        # labels (e.g. 'B53') or TypeError for non-string labels.
        except (KeyError, TypeError):
            if d['label'] == 'B53':
                # Backbone edges get no symbol.
                edge_labels[(n1, n2)] = ''
            else:
                edge_labels[(n1, n2)] = r"{0}".format(d['label'])

    non_bb_edges = [(n1, n2) for n1, n2, d in nx_g.edges(data=True) if d['label'] != 'B53']
    bb_edges = [(n1, n2) for n1, n2, d in nx_g.edges(data=True) if d['label'] == 'B53']
    nx.draw_networkx_edges(nx_g, pos, edgelist=non_bb_edges)
    nx.draw_networkx_edges(nx_g, pos, edgelist=bb_edges, width=2)

    if highlight_edges is not None:
        nx.draw_networkx_edges(nx_g, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5)

    nx.draw_networkx_edge_labels(nx_g, pos, font_size=16,
                                 edge_labels=edge_labels)
    plt.axis('off')
    plt.show()
def rna_draw_pair(graphs, estimated_value=None, highlight_edges=None, node_colors=None, num_clusters=None,
                  similarity=False,
                  true_value=None):
    """
    Draw several RNA graphs side by side with LW edge symbols.

    :param graphs: iterable of networkx graphs with 'label' edge attributes.
    :param estimated_value: predicted similarity/distance shown in the title.
    :param highlight_edges: optional edges drawn with a wide yellow overlay.
    :param node_colors: optional per-graph lists of node colors.
    :param num_clusters: unused; kept for API compatibility.
    :param similarity: if True the title says 'similarity', else 'distance'.
    :param true_value: optional ground-truth value appended to the title.
    :return: None; shows the figure.
    """
    fig, ax = plt.subplots(1, len(graphs), num=1)
    for i, g in enumerate(graphs):
        pos = nx.spring_layout(g)
        node_color = 'grey' if node_colors is None else node_colors[i]
        nodes = nx.draw_networkx_nodes(g, pos, node_size=150, node_color=node_color, linewidths=2, ax=ax[i])
        nodes.set_edgecolor('black')

        edge_labels = {}
        for n1, n2, d in g.edges(data=True):
            try:
                edge_labels[(n1, n2)] = make_label(d['label'])
            # Was a bare `except:`; non-LW labels raise KeyError/TypeError.
            except (KeyError, TypeError):
                if d['label'] == 'B53':
                    edge_labels[(n1, n2)] = ''
                else:
                    edge_labels[(n1, n2)] = r"{0}".format(d['label'])

        non_bb_edges = [(n1, n2) for n1, n2, d in g.edges(data=True) if d['label'] != 'B53']
        bb_edges = [(n1, n2) for n1, n2, d in g.edges(data=True) if d['label'] == 'B53']
        nx.draw_networkx_edges(g, pos, edgelist=non_bb_edges, ax=ax[i])
        nx.draw_networkx_edges(g, pos, edgelist=bb_edges, width=2, ax=ax[i])
        if highlight_edges is not None:
            nx.draw_networkx_edges(g, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5, ax=ax[i])
        nx.draw_networkx_edge_labels(g, pos, font_size=16,
                                     edge_labels=edge_labels, ax=ax[i])
        ax[i].set_axis_off()

    plt.axis('off')
    # Fix: the original ternary bound only the prefix, so the numeric value
    # was dropped from the title whenever similarity=True.
    title = ('similarity : ' if similarity else 'distance : ') + str(estimated_value)
    if true_value is not None:
        title = title + f' true : {true_value}'
    plt.title(title)
    plt.show()
def generic_draw_pair(graphs, title="", highlight_edges=None, node_colors=None, num_clusters=None):
    """
    Draw several graphs side by side with raw string edge labels.

    :param graphs: iterable of networkx graphs with 'label' edge attributes.
    :param title: text appended to the figure title.
    :param highlight_edges: optional edges drawn with a wide yellow overlay.
    :param node_colors: optional per-graph lists of node colors.
    :param num_clusters: unused; kept for API compatibility.
    :return: None; shows the figure.

    NOTE(review): regular edges are never drawn here — only nodes, the
    highlight overlay and the labels. Confirm that is intentional.
    """
    fig, ax = plt.subplots(1, len(graphs), num=1)
    for i, g in enumerate(graphs):
        pos = nx.spring_layout(g)
        node_color = 'grey' if node_colors is None else node_colors[i]
        nodes = nx.draw_networkx_nodes(g, pos, node_size=150, node_color=node_color, linewidths=2, ax=ax[i])
        nodes.set_edgecolor('black')

        edge_labels = {(n1, n2): str(d['label']) for n1, n2, d in g.edges(data=True)}
        if highlight_edges is not None:
            nx.draw_networkx_edges(g, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5, ax=ax[i])
        nx.draw_networkx_edge_labels(g, pos, font_size=16,
                                     edge_labels=edge_labels, ax=ax[i])
        ax[i].set_axis_off()
    plt.axis('off')
    plt.title(f"distance {title}")
    plt.show()
def generic_draw(graph, title="", highlight_edges=None, node_colors=None):
    """
    Draw a single graph with raw string edge labels.

    :param graph: networkx graph with 'label' edge attributes.
    :param title: text appended to the figure title.
    :param highlight_edges: optional edges drawn with a wide yellow overlay.
    :param node_colors: optional node colors, mapped through the Blues colormap.
    :return: None; shows the figure.
    """
    # NOTE(review): two axes are created but only ax[0] is ever used —
    # confirm the empty second panel is intentional.
    fig, ax = plt.subplots(1, 2, num=1)
    pos = nx.spring_layout(graph)
    if node_colors is not None:
        nodes = nx.draw_networkx_nodes(graph, pos, node_size=150, cmap=plt.cm.Blues, node_color=node_colors,
                                       linewidths=2, ax=ax[0])
    else:
        nodes = nx.draw_networkx_nodes(graph, pos, node_size=150, node_color='grey', linewidths=2, ax=ax[0])
    nodes.set_edgecolor('black')

    edge_labels = {(n1, n2): str(d['label']) for n1, n2, d in graph.edges(data=True)}
    if highlight_edges is not None:
        nx.draw_networkx_edges(graph, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5, ax=ax[0])
    nx.draw_networkx_edges(graph, pos, ax=ax[0])
    nx.draw_networkx_edge_labels(graph, pos, font_size=16,
                                 edge_labels=edge_labels, ax=ax[0])
    ax[0].set_axis_off()
    plt.axis('off')
    plt.title(f"motif {title}")
    plt.show()
def ablation_draw():
    """
    Draw the example FMN binding-pocket graph under each ablation mode.

    Loads the same pocket pickle from each ablated dataset directory and
    renders it with `rna_draw`.
    :return: None; shows one figure per mode.
    """
    g_name = "1fmn_#0.1:A:FMN:36.nx_annot.p"
    modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle']
    for m in modes:
        g_dir = "../data/annotated/pockets_nx" + m
        # `with` closes the file handle the original leaked.
        with open(os.path.join(g_dir, g_name), 'rb') as f:
            g, _, _, _ = pickle.load(f)
        rna_draw(g, title=m)
if __name__ == "__main__":
    # Script entry point: render every ablation variant of the example pocket.
    ablation_draw()
import pickle
import os
import itertools
from tqdm import tqdm
import networkx as nx
import torch
import dgl
def get_edge_map(graphs_dir):
    """
    Scan all pickled graphs in `graphs_dir` and build a label -> id map.

    :param graphs_dir: directory of pickle files, each holding (graph, _, _).
    :return: dict mapping each edge label to a stable index (sorted order,
             so the mapping is reproducible across runs).
    """
    edge_labels = set()
    print("Collecting edge labels.")
    for g in tqdm(os.listdir(graphs_dir)):
        # `with` closes the file handle the original leaked.
        with open(os.path.join(graphs_dir, g), 'rb') as f:
            graph, _, _ = pickle.load(f)
        edge_labels |= {e_dict['label'] for _, _, e_dict in graph.edges(data=True)}
    return {label: i for i, label in enumerate(sorted(edge_labels))}
def nx_to_dgl_jacques(graph, edge_map):
    """
    Convert an annotated networkx graph to a DGL graph.

    Edge labels are mapped to integer ids via `edge_map`; node features
    'angles' and 'identity' are concatenated into the initial node
    embedding 'h'. (The original docstring was copied from a dataset
    __getitem__ and did not describe this function.)

    :param graph: networkx graph with 'label' edge attrs and
        'angles'/'identity' node attrs (tensor-convertible).
    :param edge_map: dict mapping edge label -> integer id.
    :return: (undirected networkx graph, DGL graph)
    """
    graph = nx.to_undirected(graph)
    one_hot = {edge: torch.tensor(edge_map[label]) for edge, label in
               nx.get_edge_attributes(graph, 'label').items()}
    nx.set_edge_attributes(graph, name='one_hot', values=one_hot)

    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=graph, edge_attrs=['one_hot'], node_attrs=['angles', 'identity'])

    # Initialize node embeddings from the node features.
    floatid = g_dgl.ndata['identity'].float()
    g_dgl.ndata['h'] = torch.cat([g_dgl.ndata['angles'], floatid], dim=1)
    # Removed stray debug print("HII") from the original.
    return graph, g_dgl
def nx_to_dgl_(graph, edge_map, embed_dim):
    """
    Load a pickled networkx graph and convert it to DGL with constant embeddings.

    :param graph: path to a pickle containing (graph, _, ring).
    :param edge_map: dict mapping edge label -> integer id.
    :param embed_dim: width of the all-ones initial node embedding 'h'.
    :return: (networkx graph, DGL graph)
    """
    # Removed redundant function-local `import torch` / `import dgl`
    # (both already imported at module level), and a dead second `one_hot`
    # dict of tensors that was computed but never applied to the graph.
    with open(graph, 'rb') as f:
        graph, _, ring = pickle.load(f)
    one_hot = {edge: edge_map[label] for edge, label in nx.get_edge_attributes(graph, 'label').items()}
    nx.set_edge_attributes(graph, name='one_hot', values=one_hot)
    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=graph, edge_attrs=['one_hot'])
    n_nodes = len(g_dgl.nodes())
    g_dgl.ndata['h'] = torch.ones((n_nodes, embed_dim))
    return graph, g_dgl
def dgl_to_nx(graph, edge_map):
    """Convert a DGL graph back to networkx, restoring string edge labels."""
    nx_graph = dgl.to_networkx(graph, edge_attrs=['one_hot'])
    # Invert label -> id into id -> label.
    inverse_map = {idx: lbl for lbl, idx in edge_map.items()}
    restored = {}
    for u, v, data in nx_graph.edges(data=True):
        restored[(u, v)] = inverse_map[data['one_hot'].item()]
    nx.set_edge_attributes(nx_graph, restored, 'label')
    return nx_graph
def bfs_expand(G, initial_nodes, depth=2):
    """
    Expand a node set by `depth` hops of breadth-first traversal.

    :param G: graph exposing `neighbors(node)`.
    :param initial_nodes: starting nodes of the motif.
    :param depth: number of hops to expand.
    :return: set of all nodes within `depth` hops, including the start nodes.
    """
    reached = set(initial_nodes)
    frontier = set(initial_nodes)
    for _ in range(depth):
        frontier = {nb for node in frontier for nb in G.neighbors(node)}
        reached |= frontier
    return reached
def bfs(G, initial_node, depth=2):
    """
    Breadth-first generator from `initial_node`.

    Yields the list of newly reached (unvisited) nodes at each successive
    hop, for `depth` hops.

    :param G: graph exposing `neighbors(node)`.
    :param initial_node: node to start from.
    :param depth: number of hops to yield.
    """
    layers = [[initial_node]]
    seen = []
    for hop in range(depth):
        next_ring = []
        for node in layers[hop]:
            seen.append(node)
            # Lazily filtered so `seen` is consulted in traversal order,
            # exactly as the original nested loop did.
            next_ring.extend(nb for nb in G.neighbors(node) if nb not in seen)
        layers.append(next_ring)
        yield next_ring
def graph_ablations(G, mode):
    """
    Build an ablated copy of a binding-site graph by removing/relabeling edges.

    :param G: binding site graph with 'label' edge attributes.
    :param mode: one of 'label-shuffle', 'no-label', 'wc-bb-nc',
        'bb-only', 'wc-bb'.
    :returns: new graph with edges removed/relabeled.
    :raises ValueError: for an unrecognized mode (the original raised a
        confusing NameError instead).

    NOTE(review): nodes with no surviving edges are absent from the copy —
    confirm downstream consumers expect that.
    """
    # Fix: `shuffle` was used below but never imported anywhere in this
    # module, so 'label-shuffle' crashed with NameError.
    from random import shuffle

    H = nx.Graph()

    if mode == 'label-shuffle':
        # Reassign the existing multiset of labels randomly among edges.
        labels = [d['label'] for _, _, d in G.edges(data=True)]
        shuffle(labels)
        for n1, n2, _ in G.edges(data=True):
            H.add_edge(n1, n2, label=labels.pop())
        return H

    if mode == 'no-label':
        for n1, n2, _ in G.edges(data=True):
            H.add_edge(n1, n2, label='X')
        return H

    if mode == 'wc-bb-nc':
        # Keep Watson-Crick and backbone labels; collapse the rest to 'NC'.
        for n1, n2, d in G.edges(data=True):
            label = d['label'] if d['label'] in ['CWW', 'B53'] else 'NC'
            H.add_edge(n1, n2, label=label)
        return H

    if mode in ('bb-only', 'wc-bb'):
        valid_edges = ['B53'] if mode == 'bb-only' else ['B53', 'CWW']
        for n1, n2, d in G.edges(data=True):
            if d['label'] in valid_edges:
                H.add_edge(n1, n2, label=d['label'])
        return H

    raise ValueError(f"Unknown ablation mode: {mode}")
import os
import configparser
from ast import literal_eval
import pickle
from tqdm import tqdm
import torch
import numpy as np
import networkx as nx
from learning.loader import Loader, InferenceLoader
from learning.learn import send_graph_to_device
from learning.rgcn import Model
def remove(name):
    """
    Delete all stored results for experiment `name`.

    Removes the log directory, the trained-model directory, and the .exp
    experiment file.

    :param name: experiment name.
    :return: True on completion.
    """
    import shutil

    base = os.path.join(os.path.dirname(__file__), '../results')
    shutil.rmtree(os.path.join(base, 'logs', name))
    shutil.rmtree(os.path.join(base, 'trained_models', name))
    os.remove(os.path.join(base, 'experiments', name + '.exp'))
    return True
def setup():
    """
    Create the results directory tree used for logging and checkpoints.

    Uses makedirs(exist_ok=True) so the call is idempotent — the original
    os.mkdir chain raised FileExistsError on any re-run.
    :return: None
    """
    script_dir = os.path.dirname(__file__)
    for sub in ('', 'logs', 'trained_models', 'experiments'):
        os.makedirs(os.path.join(script_dir, '../results/', sub), exist_ok=True)
def mkdirs_learning(name, permissive=True):
    """
    Create the logs and trained-model directories for one experiment.

    :param name: experiment name.
    :param permissive: if True, overwrite existing files (good for debugging).
    :return: (log directory path, checkpoint file path)
    """
    from tools.utils import makedir

    log_path = os.path.join('results/logs', name)
    save_path = os.path.join('results/trained_models', name)
    for path in (log_path, save_path):
        makedir(path, permissive)
    return log_path, os.path.join(save_path, name + '.pth')
def load_model(run):
    """
    Load a trained model and its metadata for experiment id `run`.

    :param run: experiment name under ../results/trained_models/.
    :return: (model with restored weights, meta dict)
    """
    # `with` closes the file handle the original leaked.
    with open(f'../results/trained_models/{run}/meta.p', 'rb') as f:
        meta = pickle.load(f)
    edge_map = meta['edge_map']
    num_edge_types = len(edge_map)
    model_dict = torch.load(f'../results/trained_models/{run}/{run}.pth', map_location='cpu')
    model = Model(dims=meta['embedding_dims'], attributor_dims=meta['attributor_dims'], num_rels=num_edge_types,
                  num_bases=-1,
                  device='cpu',
                  pool=meta['pool'])
    model.load_state_dict(model_dict['model_state_dict'])
    return model, meta
def load_data(annotated_path, meta, get_sim_mat=True):
    """
    Build train/test loaders for an annotated dataset.

    :param annotated_path: path to the annotated graphs.
    :param meta: metadata dict providing 'sim_function'.
    :param get_sim_mat: switch off to skip rings/K-matrix computation
        for faster loading.
    :return: (train_loader, test_loader)
    """
    data = Loader(annotated_path=annotated_path,
                  batch_size=1,
                  num_workers=1,
                  sim_function=meta['sim_function'],
                  get_sim_mat=get_sim_mat)
    train_loader, _, test_loader = data.get_data()
    return train_loader, test_loader
def predict(model, loader, max_graphs=10, device='cpu'):
    """
    Compute fingerprints and node embeddings for up to `max_graphs` graphs.

    :param model: trained model returning (fingerprint, node embeddings).
    :param loader: DataLoader yielding (graph, K, fp, graph_index) tuples.
    :param max_graphs: stop after this many graphs. Fix: the original
        accepted this parameter but never honored it.
    :param device: torch device for inference.
    :return: (fingerprint array, stacked node-embedding matrix)
    """
    # Removed unused locals `all_graphs` and `g_inds` from the original.
    Z = []
    fps = []
    model = model.to(device)
    with torch.no_grad():
        for i, (graph, K, fp, graph_index) in tqdm(enumerate(loader), total=len(loader)):
            if i >= max_graphs:
                break
            graph = send_graph_to_device(graph, device)
            fp, z = model(graph)
            Z.append(z.cpu().numpy())
            fps.append(fp.cpu().numpy())
    Z = np.concatenate(Z)
    fps = np.array(fps)
    return fps, Z
def inference_on_dir(run, graph_dir, ini=True, max_graphs=10, get_sim_mat=False,
                     split_mode='test', attributions=False, device='cpu'):
    """
    Load the model for `run` and embed every graph found in `graph_dir`.

    The graph order is loader-dependent and node order inside each graph is
    the sorted one, so the caller must re-align the results.

    :param run: experiment id whose trained model is loaded.
    :param graph_dir: directory of graphs to run inference on.
    :param ini: unused here; kept for API compatibility.
    :param max_graphs: cap on the number of graphs to embed.
    :param get_sim_mat: unused here; kept for API compatibility.
    :param split_mode: unused here; kept for API compatibility.
    :param attributions: unused here; kept for API compatibility.
    :param device: torch device for inference.
    :return: whatever `predict` returns (fingerprints, embeddings).
    """
    model, _ = meta_load_model(run)
    data = InferenceLoader(graph_dir).get_data()
    return predict(model, data, max_graphs=max_graphs, device=device)
def meta_load_model(run):
    """
    Load a trained model and its metadata for run id `run` from models/.

    :param run: run name under models/.
    :return: (model with restored weights, meta dict)
    """
    # `with` closes the file handle the original leaked; the stray debug
    # print(meta) is removed.
    with open(f'models/{run}/meta.p', 'rb') as f:
        meta = pickle.load(f)
    edge_map = meta['edge_map']
    num_edge_types = len(edge_map)
    model_dict = torch.load(f'models/{run}/{run}.pth', map_location='cpu')
    model = Model(dims=meta['embedding_dims'], attributor_dims=meta['attributor_dims'], num_rels=num_edge_types,
                  num_bases=-1, device='cpu')
    model.load_state_dict(model_dict['model_state_dict'])
    return model, meta
def model_from_hparams(hparams):
    """
    Rebuild a trained model from an hparams config and load its weights.

    :param hparams: config object exposing .get(section, key).
    :return: model with restored weights.
    """
    run = hparams.get('argparse', 'name')
    num_rels = len(hparams.get('edges', 'edge_map'))
    checkpoint = torch.load(f'../results/trained_models/{run}/{run}.pth', map_location='cpu')
    model = Model(dims=hparams.get('argparse', 'embedding_dims'),
                  attributor_dims=hparams.get('argparse', 'attributor_dims'),
                  num_rels=num_rels,
                  num_bases=-1,
                  hard_embed=hparams.get('argparse', 'hard_embed'))
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
def data_from_hparams(annotated_path, hparams, get_sim_mat=True):
    """
    Build train/test loaders from an hparams config.

    :param annotated_path: path to the annotated graphs.
    :param hparams: config object exposing .get(section, key).
    :param get_sim_mat: switch off to skip rings/K-matrix computation
        for faster loading.
    :return: (train_loader, test_loader)
    """
    # Removed the unused `dims` local from the original.
    loader = Loader(annotated_path=annotated_path,
                    batch_size=hparams.get('argparse', 'batch_size'),
                    num_workers=1,
                    sim_function=hparams.get('argparse', 'sim_function'),
                    depth=hparams.get('argparse', 'kernel_depth'),
                    hard_embed=hparams.get('argparse', 'hard_embed'),
                    hparams=hparams,
                    get_sim_mat=get_sim_mat)
    train_loader, _, test_loader = loader.get_data()
    return train_loader, test_loader
def get_rgcn_outputs(run, graph_dir, ini=False, max_graphs=100, nc_only=False, get_sim_mat=True):
"""
Load model and get node embeddings.
:params
:get_sim_mat: switches off computation of rings and K matrix for faster loading.
:max_graphs max number of graphs to get embeddings for
"""
from tools.graph_utils import dgl_to_nx
if ini:
hparams = ConfParser(default_path=os.path.join('../results/experiments', f'{run}.exp'))
model = model_from_hparams(hparams)
train_loader, test_loader = data_from_hparams(graph_dir, hparams, get_sim_mat=get_sim_mat)
edge_map = hparams.get('edges', 'edge_map')
similarity = hparams.get('argparse', 'similarity')
else:
model, meta = load_model(run)
train_loader, test_loader = load_data(graph_dir, meta, get_sim_mat=get_sim_mat)
edge_map = meta['edge_map']
similarity = False
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor])
Z = []
fp_mat = []
nx_graphs = []
KS = []
# maps full nodeset index to graph and node index inside graph
node_map = {}
ind = 0
offset = 0
for i, (graph, K, graph_sizes) in enumerate(train_loader):
if i > max_graphs - 1:
break
fp, z = model(graph)
KS.append(K)
fp_mat.append(np.array(fp.detach().numpy()))
for j, emb in enumerate(z.detach().numpy()):
Z.append(np.array(emb))
node_map[ind] = (i, j)
ind += 1
# nx_graphs.append(nx_graph)
nx_g = dgl_to_nx(graph, edge_map)
#assign unique id to graph nodes
nx_g = nx.relabel_nodes(nx_g,{node:offset+k for k,node in enumerate(nx_g.nodes())})
offset += len(nx_g.nodes())
# print(z)
# rna_draw(nx_g)
nx_graphs.append(nx_g)