Commit a49dbf51 authored by Carlos GO's avatar Carlos GO
Browse files

readme

parent 64203b12
# RNAmigos: RNA 3D Small Molecule Ligand Prediction
# RNAmigos: RNA Small Molecule Ligand Prediction
This repository is an implementation of ligand prediction from an RNA base pairing network.
> Augmented base pairing networks encode RNA-small molecule binding preferences.
> Oliver C., Mallet V., Sarrazin Gendron R., Reinharz V., Hamilton W. L., Moitessier N., Waldispühl J.
> bioRxiv, 2020.
> [[Paper]](https://www.biorxiv.org/content/10.1101/701326v3)
This repository is a graph convolutional network (GCN) implementation of ligand prediction from an RNA 3D site.
We build a graph representation of an RNA site from a 3D structure.
This is fed to a GCN which outputs a vector representation (fingerprint) of a likely ligand.
......@@ -255,7 +255,7 @@ def annotate_all(fp_file="../data/all_ligs_maccs.p", dump_path='../data/annotate
# Script entry point: runs annotate_all sequentially (parallel=False) on
# ../data/pockets_nx_symmetric with mode='fp', once per ablation setting
# ("wc-bb" and "no-label"), each writing to its own dump_path.
if __name__ == '__main__':
# annotate_all(parallel=False, graph_path="../data/pockets_nx_2", dump_path="../data/annotated/pockets_nx_2", ablate="")
annotate_all(parallel=False, graph_path="../data/pockets_nx_symmetric", dump_path="../data/annotated/pockets_nx_symmetric_wc-bb",
ablate="wc-bb", mode='fp')
annotate_all(parallel=False, graph_path="../data/pockets_nx_symmetric", dump_path="../data/annotated/pockets_nx_symmetric_no-label",
ablate="no-label", mode='fp')
# annotate_all(parallel=False, graph_path="../data/pockets_nx_symmetric", dump_path="../data/annotated/pockets_nx_symmetric_clust",
# ablate="", mode='fp', fp_file='../data/fp_dict_8clusters.p')
......@@ -26,6 +26,7 @@ from Bio.PDB import MMCIFParser, NeighborSearch
from learning.rgcn import Model
from rna_classes import *
from post.utils import *
from post.tree_grid_vincent import compute_clustering
from learning.attn import get_attention_map
from learning.utils import dgl_to_nx
......@@ -94,10 +95,14 @@ def decoy_test(model, decoys, edge_map, embed_dim,
ligs = list(decoys.keys())
if majority:
generic = generic_fp("../data/annotated/pockets_nx_symmetric_orig")
true_ids = []
fp_dict = {}
for g_path in test_graphlist:
g,_,_,true_fp = pickle.load(open(os.path.join(test_graph_path, g_path), 'rb'))
try:
true_id = g_path.split(":")[2]
fp_dict[true_id] = true_fp
decoys[true_id]
except:
print(f">> failed on {g_path}")
continue
......@@ -105,6 +110,7 @@ def decoy_test(model, decoys, edge_map, embed_dim,
with torch.no_grad():
fp_pred, _ = model(dgl_graph)
# fp_pred = fp_pred.detach().numpy()
fp_pred = fp_pred.detach().numpy() > 0.5
fp_pred = fp_pred.astype(int)
if majority:
......@@ -113,20 +119,23 @@ def decoy_test(model, decoys, edge_map, embed_dim,
active = decoys[true_id][0]
decs = decoys[true_id][1]
rank = distance_rank(active, fp_pred, decs, dist_func=mse)
sim = jaccard(true_fp, fp_pred)
sim = mse(true_fp, fp_pred)
true_ids.append(true_id)
ranks.append(rank)
sims.append(sim)
return ranks, sims
return ranks, sims, true_ids, fp_dict
# NOTE(review): two `def` lines and two loop variants appear back to back here, so this
# listing seems to interleave the removed and the added side of a diff. The notes below
# describe the variant that takes `methods`.
def wilcoxon_all_pairs(df):
def wilcoxon_all_pairs(df, methods):
"""
Compute pairwise wilcoxon on all runs.

For every ordered pair (method_1, method_2) drawn from `methods`, the `rank` values of
the rows of `df` whose `method` column equals each name are passed to
scipy.stats.wilcoxon, and the two method names are appended to the `wilcoxons` dict.

:param df: frame with at least the columns `method` and `rank`.
:param methods: iterable of values of the `method` column to compare.
"""
from scipy.stats import wilcoxon
wilcoxons = {'method_1': [], 'method_2':[], 'p-value': []}
for method_1, df1 in df.groupby('method'):
for method_2, df2 in df.groupby('method'):
p_val = wilcoxon(df1['rank'], df2['rank'])
# NOTE(review): both loops run over the same list, so every method is also tested
# against itself (method_1 == method_2) — confirm this is intended.
for method_1 in methods:
for method_2 in methods:
vals1 = df.loc[df['method'] == method_1]
vals2 = df.loc[df['method'] == method_2]
# wilcoxon is a paired test; presumably both selections have the same length and
# row order (one row per test pocket) — verify against the caller.
p_val = wilcoxon(vals1['rank'], vals2['rank'], correction=True)
wilcoxons['method_1'].append(method_1)
wilcoxons['method_2'].append(method_2)
......@@ -136,11 +145,12 @@ def wilcoxon_all_pairs(df):
# NOTE(review): fillna returns a new object and its result is not assigned here, so
# wil_df is left unchanged by this line.
wil_df.fillna(0)
# Reshape to a method_1 x method_2 table holding the p-values.
pvals = wil_df.pivot("method_1", "method_2", "p-value")
# NOTE(review): same as above — the filled copy is discarded, pvals keeps its NaNs.
pvals.fillna(0)
mask = np.zeros_like(pvals)
mask[np.triu_indices_from(mask)] = True
g = sns.heatmap(pvals, cmap="Reds_r", annot=True, mask=mask, cbar=True)
g.set_facecolor('grey')
plt.show()
# Emit the p-value table as LaTeX on stdout.
print(pvals.to_latex())
# mask = np.zeros_like(pvals)
# mask[np.triu_indices_from(mask)] = True
# g = sns.heatmap(pvals, cmap="Reds_r", annot=True, mask=mask, cbar=True)
# g.set_facecolor('grey')
# plt.show()
pass
def generic_fp(annot_dir):
"""
......@@ -197,42 +207,74 @@ def make_ridge(df, x='method', y='rank', save=None, show=True):
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)
plt.show()
if save:
plt.savefig(save)
if show:
plt.show()
def make_tree_grid(df, fp_dict, method='htune'):
    """
    Cluster the ligands of one method by fingerprint and mean rank.

    Builds `{ligand id: (fingerprint, mean rank)}` from the rows of `df` whose
    `method` column equals `method`, and hands it to `compute_clustering`.

    :param df: frame with at least the columns `method`, `lig` and `rank`.
    :param fp_dict: dictionary with ligand id as key and fingerprint as value;
                    must contain every ligand id selected from `df`.
    :param method: value of the `method` column to keep.

    NOTE(review): `ablation_results` labels its tuned run 'atune'; the default
    'htune' matches none of the labels listed there — confirm the default.
    """
    df_tree = df.loc[df['method'] == method]
    # Average the `rank` column only. The frame also carries the string-valued
    # `method` column, and a whole-frame groupby mean raises TypeError on
    # recent pandas versions instead of silently dropping it.
    mean_ranks = df_tree.groupby('lig')['rank'].mean()
    lig_dict = {lig: (fp_dict[lig], mean_rank) for lig, mean_rank in mean_ranks.items()}
    compute_clustering(lig_dict)
# Evaluate every ablation mode over all folds with decoy_test and collect ranks,
# distances, method labels and ligand ids for the summary tables.
# NOTE(review): several statements appear twice in slightly different form (e.g. the two
# `modes` lists, `if m in` / `if real in`), so this listing seems to interleave the removed
# and the added side of a diff.
def ablation_results():
# modes = h'', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
# modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
# modes = h'', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle'] # modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
# modes = ['raw', 'warm', 'wc-bb', 'bb', 'majority', 'swap', 'random']
modes = ['raw', 'wc-bb', 'bb', 'majority', 'swap', 'random']
decoys = get_decoys(mode='pdb')
ranks, methods, jaccards = [], [], []
# modes = ['raw', 'tune', 'wc-bb', 'no-label', 'bb', 'majority', 'swap', 'random']
# Each label carries a one-letter prefix (a, b, c, ...) that is stripped again below via
# `real = m[1:]`; presumably the prefix only fixes the alphabetical order of the methods
# in grouped output — confirm.
modes = ['atune', 'braw', 'cwc-bb', 'dbb', 'eno-label', 'fmajority', 'gswap', 'hrandom']
# modes = ['tune', 'wc-bb', 'swap']
# Display names, keyed by the un-prefixed mode name.
title = {'raw': 'ABPN',
'tune': 'ABPN + unsup.',
'wc-bb': 'sec.struc.',
'bb': 'primary. struc.',
'majority': 'majority',
'swap': 'swap',
'no-label': 'no-label',
'random': 'random'
}
ranks, methods, jaccards, ligs = [], [], [], []
graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
decoys = get_decoys(mode='dude', annots_dir='../data/annotated/pockets_nx_symmetric_orig')
# graph_dir = '../data/annotated/pockets_nx_2'
run = 'ismb'
num_folds = 10
num_folds = 10
majority = False
fp_dict = {}
for m in modes:
print(m)
# Drop the one-letter ordering prefix to recover the actual mode name.
real = m[1:]
# Pick the annotated-graph directory and the trained-run name for this mode.
if m in ['raw', 'pair-shuffle']:
if real in ['raw', 'pair-shuffle']:
graph_dir = "../data/annotated/pockets_nx_symmetric_orig"
run = 'ismb-raw'
elif m == 'swap':
elif real == 'swap':
graph_dir = '../data/annotated/pockets_nx_symmetric_scramble_orig'
run = 'ismb-' + m
elif m == 'majority':
run = 'ismb-' + real
elif real == 'majority':
run = 'ismb-raw'
graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
majority = True
elif m == 'random':
elif real == 'random':
graph_dir = '../data/annotated/pockets_nx_symmetric_random_orig'
run = 'random'
elif m == 'warm':
elif real == 'warm':
graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
run = 'ismb-warm'
elif real == 'tune':
# NOTE(review): `==` is a comparison, not an assignment — graph_dir keeps whatever value
# it had before this branch. Looks like `=` was intended; confirm and fix.
graph_dir == '../data/annotated/pockets_nx_symmetric_orig'
run = 'ismb-tune'
elif real == 'tune-2':
# NOTE(review): same `==` vs `=` issue as in the 'tune' branch.
graph_dir == '../data/annotated/pockets_nx_symmetric_orig'
run = 'ismb-tune-2'
else:
graph_dir = "../data/annotated/pockets_nx_symmetric_" + m + "_orig"
run = 'ismb-' + m
graph_dir = "../data/annotated/pockets_nx_symmetric_" + real + "_orig"
run = 'ismb-' + real
print(majority, run, graph_dir)
# One trained model per fold, loaded by the name "<run>_<fold>".
for fold in range(num_folds):
model, meta = load_model(run +"_" + str(fold))
......@@ -242,64 +284,31 @@ def ablation_results():
num_edge_types = len(edge_map)
# Train/test split that was stored next to the trained model of this fold.
graph_ids = pickle.load(open(f'../results/trained_models/{run}_{fold}/splits_{fold}.p', 'rb'))
# graph_ids = pickle.load(open(f'../results/trained_models/{run}/splits.p', 'rb'))
ranks_this,sims_this = decoy_test(model, decoys, edge_map, embed_dim,
shuffle=shuffle,
# Evaluate on the held-out test graphs of this fold.
ranks_this,sims_this, lig_ids, fp_dict_this = decoy_test(model, decoys, edge_map, embed_dim,
nucs=meta['nucs'],
test_graphlist=graph_ids['test'],
test_graph_path=graph_dir,
majority=majority)
# NOTE(review): test_ligs is assigned but not read anywhere in the visible code.
test_ligs = []
# True fingerprints are only collected for the 'tune' mode.
if real == 'tune':
fp_dict.update(fp_dict_this)
ranks.extend(ranks_this)
jaccards.extend(sims_this)
ligs.extend(lig_ids)
# Label every collected result with the (prefixed) mode name.
methods.extend([m]*len(ranks_this))
# decoy distance distribution
# dists = []
# for _,(active, decs) in decoys.items():
# for d in decs:
# dists.append(jaccard(active, d))
# plt.scatter(ranks_this, sims_this)
# plt.xlabel("ranks")
# plt.ylabel("distance")
# plt.show()
# sns.distplot(dists, label='decoy distance')
# sns.distplot(sims_this, label='pred distance')
# plt.xlabel("distance")
# plt.legend()
# plt.show()
# # rank_cut = 0.9
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.4 and r > 0.8]
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.3]
# print(cool)
# print(cool, len(ranks_this))
#4v6q_#0:BB:FME:3001.nx_annot.p
# test_ligs = set([f.split(":")[2] for f in graph_ids['test']])
# train_ligs = set([f.split(":")[2] for f in graph_ids['train']])
# print("ligands not in train set", test_ligs - train_ligs)
# points = []
# tot = len([x for x in ranks_this if x >= rank_cut])
# for sim_cut in np.arange(0,1.1,0.1):
# pos = 0
# for s,r in zip(sims_this, ranks_this):
# if s < sim_cut and r > rank_cut:
# pos += 1
# points.append(pos / tot)
# from sklearn.metrics import auc
# plt.title(f"Top 20% Accuracy {auc(np.arange(0, 1.1, 0.1), points)}, {m}")
# plt.plot(points, label=m)
# plt.plot([x for x in np.arange(0,1.1, 0.1)], '--')
# plt.ylabel("Positives")
# plt.xlabel("Distance threshold")
# plt.xticks(np.arange(10), [0, 0.1, 0.2, 0.3, 0.4, 0.5,0.6, 0.7, 0.9, 1.0])
# plt.legend()
# plt.show()
df = pd.DataFrame({'rank': ranks, 'jaccard': jaccards, 'method':methods})
wilcoxon_all_pairs(df)
# make_ridge(df, x='method', y='rank')
# make_ridge(df, x='method', y='jaccard')
# Reset so the majority baseline does not leak into the following modes.
majority = False
# One row per evaluated test pocket.
# NOTE(review): the 'jaccard' column holds whatever decoy_test returns as `sims`; the
# visible decoy_test lines compute it with mse — confirm the column name still fits.
df = pd.DataFrame({'rank': ranks, 'jaccard': jaccards, 'method':methods, 'lig': ligs})
wilcoxon_all_pairs(df, modes)
# make_tree_grid(df, fp_dict)
# sys.exit()
# NOTE(review): the variable is called `means` but holds the per-method standard
# deviation (.std()).
means = df.groupby('method').std()
print(means.to_latex())
# make_ridge(df, x='method', y='rank', save='../tex/Figs/waves_rank_pdb.pdf')
# make_ridge(df, x='method', y='jaccard', save='../tex/Figs/waves_dist_pdb.pdf')
# make_ridge(df, x='method', y='rank', save=None)
# make_ridge(df, x='method', y='jaccard', save=None)
# make_violins(df, x='method', y='jaccard')
# make_violins(df, x='method', y='rank')
......@@ -376,6 +385,79 @@ def scanning_analyze():
plt.show()
pass
def structure_scanning(pdb, ligname, graph, model, edge_map, embed_dim):
    """
    Given a PDB structure make a prediction for each residue in the structure:
        - build a candidate site around each residue with `get_pocket_graph`
        - convert the residue neighbourhood graph with `nx_to_dgl`
        - get a fingerprint prediction from the model for each

    Only residues named A, U, C, G or `ligname` are scanned; all others are skipped.

    :param pdb: path to the mmCIF file of the structure.
    :param ligname: residue name of the ligand, scanned in addition to the four nucleotides.
    :param graph: graph of the whole structure, handed to `get_pocket_graph`.
    :param model: trained model, called on every converted pocket graph.
    :param edge_map: handed to `nx_to_dgl`.
    :param embed_dim: handed to `nx_to_dgl`.
    :returns: `residue_preds` dictionary with (chain id, residue number) as key and the
              model output thresholded at 0.5 (boolean array) as value.
    """
    from data_processor.build_dataset import get_pocket_graph

    parser = MMCIFParser(QUIET=True)
    # Index 0: only the first model of the structure is scanned.
    structure = parser.get_structure("", pdb)[0]

    # Built once instead of once per residue.
    scanned_names = {'A', 'U', 'C', 'G', ligname}

    residue_preds = {}
    for residue in tqdm(list(structure.get_residues())):
        if residue.resname not in scanned_names:
            continue
        chain_id = residue.get_parent().id
        res_num = residue.id[1]
        res_info = ":".join(["_", chain_id, residue.resname, str(res_num)])
        pocket_graph = get_pocket_graph(pdb, res_info, graph)
        _, dgl_graph = nx_to_dgl(pocket_graph, edge_map, embed_dim)
        # Inference only, so no autograd graph is needed (decoy_test does the same).
        with torch.no_grad():
            # NOTE(review): decoy_test unpacks the model output as `fp_pred, _`,
            # here the fingerprint is taken from the second position — confirm
            # which order matches the model passed to this function.
            _, fp_pred = model(dgl_graph)
        residue_preds[(chain_id, res_num)] = fp_pred.detach().numpy() > 0.5
    return residue_preds
def scanning_analyze():
    """
    Visualize results of scanning on PDB.

    For every annotated pocket file in ../data/annotated/pockets_nx (file names look like
    `1fmn_#0.1:A:FMN:36.nx_annot.p`), run `structure_scanning` on the parent structure
    and show a scatter plot with one point per scanned residue:
        - x: distance between the residue centre and the ligand centre
        - y: `mse` between the ligand's true fingerprint and the residue's prediction

    Structures whose graph has more than 100 nodes are skipped. A structure whose scan
    raises is reported with `print` and skipped as well.
    """
    from data_processor.build_dataset import find_residue, lig_center

    # NOTE(review): ablation_results unpacks `model, meta = load_model(...)`;
    # confirm load_model still returns three values for this call.
    model, edge_map, embed_dim = load_model('small_no_rec_2', '../data/annotated/pockets_nx')

    # Loop-invariant: the fingerprint table and the parser are created once.
    with open("../data/all_ligs_maccs.p", 'rb') as fp_file:
        fp_dict = pickle.load(fp_file)
    parser = MMCIFParser(QUIET=True)

    for f in os.listdir("../data/annotated/pockets_nx"):
        pdbid = f.split("_")[0]
        _, chain, ligname, pos = f.replace(".nx_annot.p", "").split(":")
        pos = int(pos)
        print(chain, ligname, pos)

        with open(f'../data/RNA_Graphs/{pdbid}.pickle', 'rb') as graph_file:
            graph = pickle.load(graph_file)
        if len(graph.nodes()) > 100:
            continue

        cif_path = f'../data/all_rna_prot_lig_2019/{pdbid}.cif'
        try:
            fp_preds = structure_scanning(cif_path, ligname, graph, model, edge_map, embed_dim)
        except Exception as e:
            # Best effort: report the failure and move on to the next pocket.
            print(e)
            continue

        structure = parser.get_structure("", cif_path)[0]
        lig_res = find_residue(structure[chain], pos)
        lig_c = lig_center(lig_res.get_atoms())
        true_fp = fp_dict[ligname]

        dists = []
        jaccards = []
        # Separate names for the scanned residue so the ligand's chain/pos stay intact.
        for (res_chain, res_pos), fp in fp_preds.items():
            r = find_residue(structure[res_chain], res_pos)
            r_center = lig_center(r.get_atoms())
            dists.append(euclidean(r_center, lig_c))
            jaccards.append(mse(true_fp, fp))

        plt.title(f)
        # pyplot has no `distplot`; the two lists are paired per residue and the axis
        # labels below name one quantity each, so they are drawn as a scatter plot.
        plt.scatter(dists, jaccards)
        plt.xlabel("dist to binding site")
        plt.ylabel("dist to fp")
        plt.show()
# Script entry point: runs the ablation evaluation; the scanning analysis is disabled.
if __name__ == "__main__":
# scanning_analyze()
ablation_results()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment