Commit 2e2fef5c authored by Roman Sarrazin-Gendron's avatar Roman Sarrazin-Gendron
Browse files

added validation methods

parent 22cd100f
This diff is collapsed.
import os
import pickle
from multiprocessing import Process, Manager
import testSS
import BayesPairing
import random
from Bio import SeqIO
from random import shuffle
# Dataset identifier forwarded to BayesPairing.parse_sequence by run_BP.
DATASET_NAME = "bp2_rna3dmotif"
# Trained model data.  NOTE(review): pickle.load is only safe because these
# files are produced by a trusted build step; pickle must never be used on
# untrusted input.
graphs = pickle.load(open("../models/bp2_rna3dmotif_one_of_each_graph.cPickle", "rb"))
pdbs = pickle.load(open("../models/bp2_rna3dmotif_PDB_names.cPickle", "rb"))
pdb_positions = pickle.load(open("../models/bp2_rna3dmotif_PDB_positions.cPickle", "rb"))
fasta_path = "pdb_seqres.txt"
# record_dict = SeqIO.to_dict(SeqIO.parse(fasta_path, "fasta"))
listd = os.listdir("../models/all_graphs_pickled")
#print(listd)
# Four-letter PDB codes for which a pickled structure graph exists locally.
PDBlist = set([x[0:4] for x in listd])
#print(PDBlist)
#exit()
# print(PDBlist)
def run_BP(seq, ss, modules_to_parse, dataset, left_out, leave_out_sequence = False,left_out_sequence=""):
    """Run BayesPairing on one sequence and return its top candidates.

    NOTE(review): the `dataset` parameter is not used; the module-level
    DATASET_NAME is passed to parse_sequence instead — confirm intended.
    """
    parsed = BayesPairing.parse_sequence(
        seq,
        modules_to_parse,
        ss,
        DATASET_NAME,
        left_out,
        leave_out_sequence=leave_out_sequence,
        left_out_sequence=left_out_sequence,
    )
    return BayesPairing.returner(parsed, seq)
def shuffle_seq(seq):
    """Return a random permutation of the characters of *seq*."""
    chars = list(seq)
    shuffle(chars)
    return "".join(chars)
def get_constraints_from_BN(positions, graph):
    """Map the base pairs of a module graph onto sequence positions.

    *graph* is a networkx-style graph whose nodes are residue indices and
    whose edges carry a 'label' attribute ("CWW" = canonical pair; "B53",
    "S33", "S55" = backbone/stacking, ignored; everything else is a
    non-canonical pair).  *positions* lists the sequence positions matched
    to the module's nodes in sorted-node order — assumes len(positions)
    equals the node count; TODO confirm with callers.

    Returns (canonical, noncanonical): a set of (pos_a, pos_b) tuples and
    a set of ((pos_a, pos_b), LABEL) tuples.  When *positions* is empty,
    returns ([], []) — plain lists, not sets.
    """
    if len(positions) > 0:
        constraints = []  # unused; kept as-is
        bps = []          # canonical (CWW) edges, as node-index pairs
        ncbps = []        # non-canonical edges with their uppercased labels
        bp_types = []     # unused; kept as-is
        real_bps = []     # canonical pairs mapped to sequence positions
        real_ncbps = []   # non-canonical pairs mapped to sequence positions
        for i in graph.edges():
            # print(i,graph.get_edge_data(*i))
            # i[0] < i[1] keeps each undirected edge only once.
            if (graph.get_edge_data(*i)['label'].upper() == "CWW") and i[0] < i[1]:
                bps.append(i)
            elif (graph.get_edge_data(*i)['label'].upper() not in ["B53","S33","S55"]) and i[0] < i[1]:
                # print(graph.get_edge_data(*i))
                ncbps.append((i, graph.get_edge_data(*i)['label'].upper()))
        print('BASE PAIRS')
        print(bps)
        print(ncbps)
        nodes = []
        for i in graph.nodes():
            nodes.append(int(i))
        sortednodes = sorted(nodes)
        print(sortednodes)
        # j indexes *positions*; sortednodes[j] is the matching module node.
        for j in range(len(sortednodes)):
            n = sortednodes[j]
            for bp in bps:
                (a, b) = bp
                if n == a:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(b)
                    partner_node = positions[partner_ind]
                    real_bps.append((pairing_node, partner_node))
                elif n == b:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(a)
                    partner_node = positions[partner_ind]
                    # Partner of node `a` stored first so both traversal
                    # directions produce the same tuple (deduped by set()).
                    real_bps.append((partner_node, pairing_node))
            for bp in ncbps:
                (a, b) = bp[0]
                # print(a,b)
                if n == a:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(b)
                    partner_node = positions[partner_ind]
                    real_ncbps.append(((pairing_node, partner_node), bp[1]))
                elif n == b:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(a)
                    partner_node = positions[partner_ind]
                    real_ncbps.append(((partner_node, pairing_node), bp[1]))
        return (set(real_bps), set(real_ncbps))
    else:
        return ([], [])
def parse_FR3D(PDB, positions, bps, ncbps, aiming_for):
    """Score a predicted module placement against the true positions.

    Parameters
    ----------
    PDB : str
        PDB identifier (not read here; kept for interface compatibility).
    positions : list[int]
        Sequence positions predicted for the module.
    bps, ncbps
        Base-pair constraints from get_constraints_from_BN; currently
        unused but kept for interface compatibility with compare_to_FR3D.
    aiming_for : iterable[int]
        Ground-truth positions of the module.

    Returns
    -------
    float
        Fraction of predicted positions that fall inside *aiming_for*;
        0 for an empty prediction.
    """
    print("AIMING FOR", aiming_for)
    print("FOUND", positions)
    # Fix: an empty prediction previously raised ZeroDivisionError on
    # score / max_score; report a score of 0 instead.
    if not positions:
        print("SCORE :", 0)
        return 0
    max_score = len(positions)
    score = sum(1 for i in positions if i in aiming_for)
    score = score / max_score
    print("SCORE :", score)
    return score
def compare_to_FR3D(PDB, score, positions, module_graph,chain,aiming_for):
    """Derive base-pair constraints for *positions* from the module graph,
    then score the placement against *aiming_for* via parse_FR3D.

    The incoming *score* argument is ignored; the FR3D score is returned.
    """
    print("GETTING CONSTRAINTS :", positions,module_graph)
    print(module_graph.edges(data=True))
    canonical, noncanonical = get_constraints_from_BN(positions, module_graph)
    return parse_FR3D(PDB, positions, canonical, noncanonical, aiming_for)
def get_seq_ss(PDBid, ex):
    """Extract the sequence and secondary structure for one PDB chain.

    *PDBid* has the form "<pdb>_<chain>[...]"; the pickled structure graph
    is loaded from ../models/all_graphs_pickled/.  *ex* is unused.

    Returns (seq, ss, chain) on success, ("", 0, 0) when the pickled file
    is missing, and ("", "", "") when reconstruction fails.
    """
    PDB, chain = PDBid.split("_")[0:2]
    # print(PDB)
    # print("../all_graphs_pickled/" + PDB + ".nxpickled")
    try:
        g = pickle.load(open("../models/all_graphs_pickled/" + PDB + ".nxpickled", "rb"))
    except FileNotFoundError:
        print("PDB FILE NOT FOUND")
        return ("", 0,0)
    seq = ""
    nodes = []
    # Collect (residue number, nucleotide) for the requested chain only;
    # node keys appear to be (chain, residue_number) tuples — TODO confirm.
    for node in g.nodes(data=True):
        if node[0][0] == chain:
            nodecode = node[0][1]
            if node[1]["nt"]!= "TU":
                nodes.append((int(nodecode), node[1]["nt"]))
            else:
                # "TU" entries are recorded as uracil.
                nodes.append((int(nodecode), "U"))
    sortednodes = sorted(list(nodes))
    #print("FIRST NODE:",sortednodes[0])
    nuc_by_node = {}
    missing_nuc = False  # unused; kept as-is
    for i in sortednodes:
        nuc_by_node[i[0]] = i[1]
    #print(sortednodes)
    try:
        # Walk residue numbers 1..max and rebuild the chain sequence.
        for i in range(1, int(sortednodes[-1][0]) + 1):
            if i not in nuc_by_node.keys() :
                if ("A" in seq or "G" in seq or "C" in seq or "U" in seq):
                    # Missing interior residue: appending "" means gaps are
                    # silently dropped (this shifts downstream 1-based
                    # positions) — should be "N" or a gap character, left
                    # empty here to avoid crashing downstream code.
                    seq = seq + ""
                    #seq = seq + "N"
            else:
                seq = seq + nuc_by_node[i]
        if chain in g.graph["ss"]:
            ss = g.graph['ss'][chain]
        else:
            ss = ""
        if "T" in seq:
            seq = seq.replace("T","U")
    except:
        # Bare except (kept as-is): any failure — e.g. an empty chain makes
        # sortednodes[-1] raise IndexError — yields the empty result.
        return ("","","")
    return (seq, ss, chain)
def run_validation(module_to_test, module_seqs, left_out, offset,chain, aiming_for, leave_out_sequence=False,left_out_sequence=""):
    """Run BayesPairing on the held-out sequence(s) of one module and
    score every candidate hit against the known motif positions.

    module_to_test : module index (key into module_seqs / graphs).
    module_seqs    : {module: [(PDB_name, seq, ss), ...]}.
    left_out       : PDB identifier of the left-out structure; only
                     entries whose name contains it are validated.
    offset         : shift added to predicted positions (non-zero when
                     prediction ran on a window of a long chain).
    chain          : chain identifier, forwarded to compare_to_FR3D.
    aiming_for     : 1-based true motif positions.

    Returns {module: [[fr3d_score, candidate_seq, bp_score], ...]}.
    """
    results = {}
    # for i in modules_to_test:
    #     results[i] = []
    i = module_to_test
    results[i] = []
    for seq_obj in module_seqs[module_to_test]:
        if left_out in seq_obj[0]:
            PDB_name = seq_obj[0]
            seq = seq_obj[1]
            print("SEQUENCE:",seq)
            ss = seq_obj[2]
            # Copy the sequence character by character; a commented-out
            # variant randomized it as a negative control.
            seq_list = list(seq)
            control_seq = []
            for n in seq_list:
                # n = random.choice(['A','C','G','U'])
                control_seq.append(n)
            seq = "".join(control_seq)
            # Strip 'N' placeholders before prediction.
            while "N" in seq:
                seq = list(seq)
                seq.remove("N")
                seq = "".join(seq)
            ss = seq_obj[2]
            maxs = run_BP(str(seq), ss, [i], DATASET_NAME, left_out, leave_out_sequence= leave_out_sequence,left_out_sequence = left_out_sequence)
            # maxs = run_BP(str(seq),ss,[i],"NONE")
            print("ALL_RESULTS :",maxs)
            # NOTE(review): modules_to_test is a global defined under
            # __main__; k is assigned but never used.
            k = modules_to_test.index(i)
            # maxs[0] appears to be a flat list laid out in strides of 3:
            # score, positions, <third field unused here> — TODO confirm
            # against BayesPairing.returner.
            el = 0
            while el < (len(maxs[0])):
                print("COMPARING TO FR3D")
                bp_score = maxs[0][el]
                # aiming_for is 1-based; compare_to_FR3D receives 0-based targets.
                score1 = compare_to_FR3D(PDB_name, maxs[0][el], [position + offset for position in maxs[0][el + 1]],
                                         graphs[i][0],chain=chain,aiming_for=[x-1 for x in aiming_for])
                print(maxs[0][el + 1],score1)
                # Comprehension variable i shadows the module index only
                # inside the comprehension scope (Python 3 semantics).
                seq1 = "".join([seq[i] for i in maxs[0][el + 1]])
                candidate_results = [score1, seq1, bp_score]
                results[i].append(candidate_results)
                el = el + 3
    return results
def cross_validation(modules_to_test):
    """Leave-one-out cross-validation over the given module indices.

    Phase 1 gathers, for every module, the distinct motif sequences in its
    training graphs (and tallies global nucleotide counts).  Phase 2
    locates each module's motif inside its parent PDB chain sequence and
    calls run_validation on it.  Results are printed and pickled to
    "bp1_crossval_loo_seq<i>" per module and "bp1_crossval_loo_seq.cPickle"
    overall.
    """
    # modules_to_test = range(0,len(graphs))
    # modules_to_test = [1,3,4,7,9,23,25,29,71,82,86]
    # modules_to_test = [1]
    seqs = {}
    print("NUMBER OF EXAMPLES")
    # Global nucleotide tallies over all module occurrences.
    As = 0
    Cs = 0
    Gs = 0
    Us = 0
    for i in modules_to_test:
        # NOTE(review): modules 199 and 225 never get a seqs[] entry, so
        # the len(seqs[i]) test in phase 2 would raise KeyError if they
        # are present in modules_to_test — confirm intended.
        if i in [199,225]:
            continue
        seqs[i] = []
        print("----------------------------------")
        print("MODULE #" + str(i))
        t = graphs[i]
        gr = t[1]  # unused; kept as-is
        for j in t:
            # Rebuild the motif sequence by visiting nodes in sorted order.
            seqq = ""
            nod = j.nodes()
            sn = sorted([int(kk) for kk in nod])
            nl = j.nodes(data=True)
            for k in sn:
                for node in nl:
                    if node[0] == k:
                        z = node
                seqq = seqq + z[1]['nuc']
                n = z[1]['nuc']
                if n == 'A':
                    As += 1
                if n == 'C':
                    Cs += 1
                if n == 'G':
                    Gs += 1
                if n == 'U':
                    Us += 1
            print(seqq)
            if seqq not in seqs[i]:
                seqs[i].append(seqq)
    tot = (As + Cs + Gs + Us)  # total nucleotide count (debug only)
    BP_call = ""  # space-separated module list; not used downstream
    for i in modules_to_test:
        BP_call = BP_call + " " + str(i)
    pdbs_in_modules = {}
    positions_in_modules = {}  # unused; kept as-is
    module_seqs = {}
    crossval_results = {}
    for i in modules_to_test:
        # Leave-one-out needs at least two distinct sequences.
        if len(seqs[i])<=1:
            print("only one sequence for module; skipping")
            continue
        motif_position_in_PDBs = {}  # unused; kept as-is
        crossval_results[i] = []
        # pdbs[i] entries look like "<pdb>.<chain>..."; the first entry is
        # dropped.  NOTE(review): this rebinds the module-level pdbs[i] in
        # place, so cross_validation is not re-runnable on the same data.
        pdbs[i] = [x.split(".") for x in pdbs[i][1:]]
        pdbs_in_modules[i] = ["_".join([x[0].upper(), x[1]]) for x in pdbs[i]]
        motif_positions = pdb_positions[i][0:]
        this_mod_seqs = []
        # PDB entries known to break parsing.
        bugged = ["3DEG", "1AQ4", "1H4Q", "IH4S"]
        done = []  # only read; the append below is commented out
        for ind, j in enumerate(pdbs_in_modules[i]):
            print(motif_positions[ind])
            last = sorted(motif_positions[ind][1])[-1]
            first = sorted(motif_positions[ind][1])[0]
            if j[0:4].upper() in PDBlist and j[0:4] not in bugged and j[0:4] not in done:
                seq, ss, chain = get_seq_ss(j, graphs[i][ind].nodes(data=True))
                motif_seq = "".join([x[1]["nuc"] for x in sorted(graphs[i][ind].nodes(data=True))])
                real_motif_seq = ""
                # Verify that the recorded (1-based) positions really spell
                # the motif inside the extracted chain sequence.
                if motif_positions[ind][1][-1] < len(seq):
                    real_motif_seq = "".join([seq[x-1] for x in motif_positions[ind][1]])
                if motif_seq != real_motif_seq:
                    continue
                pdb_len = len(seq)
                # Short chains (10..199 nt) are validated whole.
                if pdb_len in range(10, 200) and "-" not in seq and motif_positions[ind][1][-1]<pdb_len:
                    #done.append(j[0:4])
                    # seq = shuffle_seq(seq)
                    this_mod_seqs = [((j, seq, ss))]
                    module_seqs[i] = this_mod_seqs
                    scores = run_validation(i, module_seqs, j, 0, chain=chain, aiming_for=motif_positions[ind][1], leave_out_sequence = True, left_out_sequence = motif_seq)
                    for k in scores:
                        print(scores[k])
                        crossval_results[i].append(scores[k])
                elif pdb_len > 200:
                    # Long chains: validate on a ~200 nt window around the
                    # motif; `offset` re-aligns predicted positions later.
                    new_seq, new_ss = [], []
                    offset = max(0, first - 100)
                    if last - first < 200:
                        new_seq = seq[max(0, first - 100):last + 100]
                    print("NEW SEQ LENGTH")
                    print(len(new_seq))
                    if len(new_seq) > 250 or len(new_seq)<10:
                        continue
                    this_mod_seqs.append((j, new_seq, new_ss))
                    module_seqs[i] = this_mod_seqs
                    print("RUNNING VALIDATION")
                    scores = run_validation(i, module_seqs, j, offset, chain=chain,
                                            aiming_for=motif_positions[ind][1], leave_out_sequence = True, left_out_sequence = motif_seq)
                    for k in scores:
                        print(k)
                        print(scores[k])
                        crossval_results[i].append(scores[k])
        # Checkpoint this module's results to its own pickle.
        cv_fn = "bp1_crossval_loo_seq" + str(i)
        pickle.dump(crossval_results[i], open(cv_fn, "wb"))
    for i in crossval_results:
        for j in crossval_results[i]:
            print(j)
    pickle.dump(crossval_results, open("bp1_crossval_loo_seq.cPickle", "wb"))
    print('FINAL RESULTS')
    for i in crossval_results:
        print("RESULTS FOR MODULE :", i)
        for j in crossval_results[i]:
            print(j)
def test_fasta(input, modules_to_parse):
    """Predict module occurrences for every record of a FASTA file.

    Parameters
    ----------
    input : str
        Path to a FASTA file; sequences have T converted to U.
    modules_to_parse : list[int]
        Module indices to search for in each sequence.

    Side effects: dumps partial predictions to 'partial_predictions2.cPickle'
    every 25 records, and the full {record_id: {module: (score, positions)}}
    dict to 'prediction_score_3.cPickle' at the end.
    """
    prediction_scores = {}
    # Fix: "rU" mode was deprecated since Python 3.4 and removed in 3.11
    # (raises ValueError); plain "r" already uses universal newlines.
    with open(input, "r") as f:
        r_count = 0
        for record in SeqIO.parse(f, "fasta"):
            r_count = r_count + 1
            # Periodic checkpoint so long runs can be inspected/resumed.
            if r_count % 25 == 0:
                pickle.dump(prediction_scores, open('partial_predictions2.cPickle', 'wb'))
            record_id = record.id  # renamed from `id` to avoid shadowing the builtin
            prediction_scores[record_id] = {}
            seq = str(record.seq).replace("T", "U")
            # Skip sequences too long to scan.
            if len(seq) > 300:
                continue
            # Fix: run_BP now requires a `left_out` argument; the previous
            # 4-argument call raised TypeError.  "NONE" mirrors the dataset
            # sentinel used in the old call — TODO confirm intended value.
            maxs = run_BP(seq, "", modules_to_parse, "NONE", "NONE")
            print(maxs)
            for ind, module in enumerate(maxs):
                if len(maxs[ind]) > 0:
                    prediction_scores[record_id][modules_to_parse[ind]] = (maxs[ind][0], maxs[ind][1])
                else:
                    prediction_scores[record_id][modules_to_parse[ind]] = (0, [])
    pickle.dump(prediction_scores, open("prediction_score_3.cPickle", "wb"))
if __name__ == "__main__":
    # Cross-validate every trained module.  Earlier experiments with
    # hand-picked subsets are kept below, commented out.
    modules_to_test = range(0,len(graphs))
    #modules_to_test = [294]
    #modules_to_test = [4,5,7,20,24,30,31,32,33,34,38,40,44,54,57,60,64,75,76,78,79,98,99,100,101,104,107,108,109,111,116,119,124]
    # cross_validation(modules_to_test)
    # modules_to_test =[0, 5, 8, 26, 6, 18, 19, 21, 24, 25, 39, 53, 54, 58, 119, 122, 135, 146, 162, 168, 191, 194, 198, 216]
    #modules_to_test = range(0, 24)
    #modules_to_test = [8, 10, 13, 15, 16, 20, 29, 32, 34, 42, 49, 58, 69, 80, 88, 97, 100, 105, 109, 111, 125, 126, 128]
    cross_validation(modules_to_test)
    # test_fasta("cdnas.fa",modules_to_test)
import os
import pickle
from multiprocessing import Process, Manager
import testSS
import BayesPairing
import random
from Bio import SeqIO
from random import shuffle
# Dataset identifier forwarded to BayesPairing.parse_sequence by run_BP.
DATASET_NAME = "bp2_rna3dmotif"
# Trained model data; pickle is only safe here because these files come
# from a trusted build step.
graphs = pickle.load(open("../models/bp2_rna3dmotif_one_of_each_graph.cPickle", "rb"))
pdbs = pickle.load(open("../models/bp2_rna3dmotif_PDB_names.cPickle", "rb"))
pdb_positions = pickle.load(open("../models/bp2_rna3dmotif_PDB_positions.cPickle", "rb"))
fasta_path = "pdb_seqres.txt"
# record_dict = SeqIO.to_dict(SeqIO.parse(fasta_path, "fasta"))
listd = os.listdir("../models/all_graphs_pickled")
#print(listd)
# Four-letter PDB codes for which a pickled structure graph exists locally.
PDBlist = set([x[0:4] for x in listd])
#print(PDBlist)
#exit()
# print(PDBlist)
def run_BP(seq, ss, modules_to_parse, dataset, left_out):
    """Run BayesPairing on one sequence and return its top candidates.

    NOTE(review): the `dataset` parameter is not used; the module-level
    DATASET_NAME is passed to parse_sequence instead — confirm intended.
    """
    return_dict = BayesPairing.parse_sequence(seq, modules_to_parse,ss, DATASET_NAME, left_out)
    maxs = BayesPairing.returner(return_dict, seq, ss)
    return maxs
def shuffle_seq(seq):
    """Return a random permutation of the characters of *seq*."""
    chars = list(seq)
    shuffle(chars)
    return "".join(chars)
def get_constraints_from_BN(positions, graph):
    """Map the base pairs of a module graph onto sequence positions.

    Edges labeled "CWW" are canonical pairs; "B53"/"S33"/"S55" are
    ignored; all other labels count as non-canonical pairs.  *positions*
    lists the sequence positions matched to the module's nodes in
    sorted-node order — assumes len(positions) equals the node count;
    TODO confirm with callers.

    Returns (canonical, noncanonical) as sets of (pos_a, pos_b) and
    ((pos_a, pos_b), LABEL) tuples; returns ([], []) — plain lists —
    when *positions* is empty.
    """
    if len(positions) > 0:
        constraints = []  # unused; kept as-is
        bps = []          # canonical (CWW) edges, as node-index pairs
        ncbps = []        # non-canonical edges with their uppercased labels
        bp_types = []     # unused; kept as-is
        real_bps = []     # canonical pairs mapped to sequence positions
        real_ncbps = []   # non-canonical pairs mapped to sequence positions
        for i in graph.edges():
            # print(i,graph.get_edge_data(*i))
            # i[0] < i[1] keeps each undirected edge only once.
            if (graph.get_edge_data(*i)['label'].upper() == "CWW") and i[0] < i[1]:
                bps.append(i)
            elif (graph.get_edge_data(*i)['label'].upper() not in ["B53","S33","S55"]) and i[0] < i[1]:
                # print(graph.get_edge_data(*i))
                ncbps.append((i, graph.get_edge_data(*i)['label'].upper()))
        print('BASE PAIRS')
        print(bps)
        print(ncbps)
        nodes = []
        for i in graph.nodes():
            nodes.append(int(i))
        sortednodes = sorted(nodes)
        print(sortednodes)
        # j indexes *positions*; sortednodes[j] is the matching module node.
        for j in range(len(sortednodes)):
            n = sortednodes[j]
            for bp in bps:
                (a, b) = bp
                if n == a:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(b)
                    partner_node = positions[partner_ind]
                    real_bps.append((pairing_node, partner_node))
                elif n == b:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(a)
                    partner_node = positions[partner_ind]
                    # Lower-node partner stored first so both traversal
                    # directions produce the same tuple (deduped by set()).
                    real_bps.append((partner_node, pairing_node))
            for bp in ncbps:
                (a, b) = bp[0]
                # print(a,b)
                if n == a:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(b)
                    partner_node = positions[partner_ind]
                    real_ncbps.append(((pairing_node, partner_node), bp[1]))
                elif n == b:
                    pairing_node = positions[j]
                    partner_ind = sortednodes.index(a)
                    partner_node = positions[partner_ind]
                    real_ncbps.append(((partner_node, pairing_node), bp[1]))
        return (set(real_bps), set(real_ncbps))
    else:
        return ([], [])
def parse_FR3D(PDB, positions, bps, ncbps, aiming_for):
    """Score a predicted module placement against the true positions.

    Parameters
    ----------
    PDB : str
        PDB identifier (not read here; kept for interface compatibility).
    positions : list[int]
        Sequence positions predicted for the module.
    bps, ncbps
        Base-pair constraints from get_constraints_from_BN; currently
        unused but kept for interface compatibility with compare_to_FR3D.
    aiming_for : iterable[int]
        Ground-truth positions of the module.

    Returns
    -------
    float
        Fraction of predicted positions that fall inside *aiming_for*;
        0 for an empty prediction.
    """
    print("AIMING FOR", aiming_for)
    print("FOUND", positions)
    # Fix: an empty prediction previously raised ZeroDivisionError on
    # score / max_score; report a score of 0 instead.
    if not positions:
        print("SCORE :", 0)
        return 0
    max_score = len(positions)
    score = sum(1 for i in positions if i in aiming_for)
    score = score / max_score
    print("SCORE :", score)
    return score
def compare_to_FR3D(PDB, score, positions, module_graph,chain,aiming_for):
    """Derive base-pair constraints for *positions* from the module graph,
    then score the placement against *aiming_for* via parse_FR3D.

    NOTE(review): the incoming *score* argument is ignored and overwritten;
    *chain* is accepted but unused here.
    """
    print("GETTING CONSTRAINTS :", positions,module_graph)
    print(module_graph.edges(data=True))
    #exit()
    bps, ncbps = get_constraints_from_BN(positions, module_graph)
    score = parse_FR3D(PDB, positions, bps, ncbps,aiming_for)
    return score
def get_seq_ss(PDBid,ex):
PDB, chain = PDBid.split("_")[0:2]
# print(PDB)
# print("../all_graphs_pickled/" + PDB + ".nxpickled")
try:
g = pickle.load(open("../models/all_graphs_pickled/" + PDB + ".nxpickled", "rb"))
except FileNotFoundError:
print("PDB FILE NOT FOUND")
return ("", 0,0)
seq = ""
nodes = []
for node in g.nodes(data=True):
#print(node)
# print(node[0][0],chain)
if node[0][0] == chain:
nodecode = node[0][1]
if node[1]["nt"]!= "TU":
nodes.append((int(nodecode), node[1]["nt"]))
else:
nodes.append((int(nodecode), "U"))
sortednodes = sorted(list(nodes))
#print("FIRST NODE:",sortednodes[0])
nuc_by_node = {}
missing_nuc = False
# print("NODES")
for i in sortednodes:
nuc_by_node[i[0]] = i[1]
#print(sortednodes)
try:
for i in range(1, int(sortednodes[-1][0]) + 1):
if i not in nuc_by_node.keys() :
if ("A" in seq or "G" in seq or "C" in seq or "U" in seq):
seq = seq + "" #should be N or gap, trying not ot crash shit.
#seq = seq + "N"
else:
seq = seq + nuc_by_node[i]
if chain in g.graph["ss"]:
ss = g.graph['ss'][chain]
else:
ss = ""
# print(seq)
# print("MISSING_NUC",PDBid,missing_nuc)
if "T" in seq:
seq = seq.replace("T","U")
#exit()
#print(seq)
#exit(0)
except:
return ("","","")
return (seq