Commit ceff9cea authored by Roman Sarrazin-Gendron's avatar Roman Sarrazin-Gendron
Browse files

updated validation tools for comparison with v2

parent 8a85d6e3
......@@ -195,9 +195,9 @@ def call_makeBN(mod,dataset,left_out):
excluded = left_out
bayesname = dataset+"_BN_"+str(current_ID)
if excluded != "NONE":
removecall = "rm "+ +"../models/"+bayesname
subprocess.call(removecall, shell=True)
#if excluded != "NONE":
# removecall = "rm " +"../models/"+bayesname
# subprocess.call(removecall, shell=True)
#if not os.path.isfile(bayesname), changed for cross-validation; always remake the model
if not os.path.isfile("../models/"+bayesname):
aln_list = pickle.load(open("../models/"+dataset + "_aligned_modulegraphs.cPickle",'rb'))
......
import os
import pickle
from multiprocessing import Process, Manager
import testSS
import BayesPairing
import random
from Bio import SeqIO
from random import shuffle
# Dataset identifier used when calling BayesPairing and naming output pickles.
DATASET_NAME = "yash"
# Trained module graphs: graphs[i] is the list of training graphs for module i.
graphs = pickle.load(open("../models/yash_one_of_each_graph.cPickle", "rb"))
# pdbs[i]: PDB identifiers ("pdb.chain" strings) for each module's occurrences.
pdbs = pickle.load(open("../models/yash_PDB_names.cPickle", "rb"))
# pdb_positions[i]: motif positions within each PDB sequence for module i.
pdb_positions = pickle.load(open("../models/yash_PDB_positions.cPickle", "rb"))
fasta_path = "pdb_seqres.txt"
# record_dict = SeqIO.to_dict(SeqIO.parse(fasta_path, "fasta"))
# Available pickled full-PDB graphs; first 4 characters of each file name
# are the PDB id, collected into a set for fast membership tests.
listd = os.listdir("../models/all_graphs_pickled")
#print(listd)
PDBlist = set([x[0:4] for x in listd])
#print(PDBlist)
#exit()
# print(PDBlist)
def run_BP(seq, ss, modules_to_parse, dataset, left_out="NONE"):
    """Run BayesPairing on one sequence and return its best candidate hits.

    Parameters
    ----------
    seq : str
        RNA sequence to scan.
    ss : str
        Secondary-structure constraint ("" for none).
    modules_to_parse : list
        Module identifiers to search for.
    dataset : str
        Dataset name. NOTE(review): not forwarded — the module-level
        DATASET_NAME is passed instead, exactly as in the original code.
    left_out : str, optional
        PDB excluded for cross-validation. NOTE(review): not forwarded
        either; "NONE" is always passed. Defaults to "NONE" so the
        4-argument call in test_fasta() no longer raises a TypeError.

    Returns
    -------
    Whatever BayesPairing.returner() produces for this sequence.
    """
    return_dict = BayesPairing.parse_sequence(seq, modules_to_parse, ss, DATASET_NAME, "NONE")
    maxs = BayesPairing.returner(return_dict, seq)
    return maxs
def shuffle_seq(seq):
    """Return a random permutation of the characters of *seq* as a string."""
    chars = list(seq)
    shuffle(chars)
    return "".join(chars)
def get_constraints_from_BN(positions, graph):
    """Map the base pairs of a module graph onto sequence positions.

    ``positions[j]`` is the sequence position matched to the j-th smallest
    node of *graph*.  Returns ``(canonical, non_canonical)`` where
    canonical is a set of (pos_a, pos_b) cWW pairs and non_canonical is a
    set of ((pos_a, pos_b), LABEL) entries for every other non-backbone
    interaction.  An empty *positions* yields ``([], [])``.
    """
    if not positions:
        return ([], [])
    # Partition module edges (kept once, smaller node first) into canonical
    # cWW pairs and other non-backbone interactions.
    bps = []
    ncbps = []
    for edge in graph.edges():
        label = graph.get_edge_data(*edge)['label'].upper()
        if edge[0] < edge[1]:
            if label == "CWW":
                bps.append(edge)
            elif label not in ["B53", "S33", "S55"]:
                ncbps.append((edge, label))
    print('BASE PAIRS')
    print(bps)
    print(ncbps)
    sortednodes = sorted(int(node) for node in graph.nodes())
    print(sortednodes)
    # Sequence position assigned to each module node (nodes in ascending order).
    pos_of = {node: positions[rank] for rank, node in enumerate(sortednodes)}
    real_bps = set()
    for (a, b) in bps:
        real_bps.add((pos_of[a], pos_of[b]))
    real_ncbps = set()
    for (a, b), label in ncbps:
        real_ncbps.add(((pos_of[a], pos_of[b]), label))
    return (real_bps, real_ncbps)
def parse_FR3D(PDB, positions, bps, ncbps, aiming_for):
    """Score predicted *positions* against the reference positions *aiming_for*.

    Returns the fraction of predicted positions that appear in
    *aiming_for* (a float in [0, 1]).  *bps* and *ncbps* are accepted for
    interface compatibility but currently unused.

    BUG FIX: the original divided by len(positions) unconditionally and
    raised ZeroDivisionError on an empty prediction; that case now scores 0.0.
    """
    print("AIMING FOR", aiming_for)
    print("FOUND", positions)
    PDB_name = PDB.upper() + ".nxpickled"  # kept from original; currently unused
    if not positions:
        print("SCORE :", 0.0)
        return 0.0
    score = sum(1 for i in positions if i in aiming_for) / len(positions)
    print("SCORE :", score)
    return score
def compare_to_FR3D(PDB, score, positions, module_graph, chain, aiming_for):
    """Score *positions* for *module_graph* against the FR3D target positions.

    Derives the module's base-pair constraints, then delegates scoring to
    parse_FR3D.  The incoming *score* and *chain* are accepted but unused
    (the original overwrote *score* immediately).
    """
    print("GETTING CONSTRAINTS :", positions, module_graph)
    print(module_graph.edges(data=True))
    bps, ncbps = get_constraints_from_BN(positions, module_graph)
    return parse_FR3D(PDB, positions, bps, ncbps, aiming_for)
def get_seq_ss(PDBid,ex):
    """Load the pickled graph of a PDB and rebuild one chain's sequence.

    ``PDBid`` is "<PDB>_<chain>[...]"; ``ex`` is accepted but never used.
    Returns ``(seq, ss, chain)`` where residues missing from the chain's
    numbering are written as "-"; returns ``("", 0, 0)`` if no pickle
    exists for the PDB.  T is normalised to U.
    NOTE(review): the original indentation was lost in this paste; the
    nesting below is reconstructed — confirm against the repository.
    """
    PDB, chain = PDBid.split("_")[0:2]
    # print(PDB)
    # print("../all_graphs_pickled/" + PDB + ".nxpickled")
    try:
        g = pickle.load(open("../models/all_graphs_pickled/" + PDB + ".nxpickled", "rb"))
    except FileNotFoundError:
        print("PDB FILE NOT FOUND")
        return ("", 0,0)
    seq = ""
    nodes = []
    # Collect (residue number, nucleotide) pairs belonging to the requested chain.
    # Assumes node keys are (chain, residue_number) and data carries "nt" — TODO confirm.
    for node in g.nodes(data=True):
        # print(node)
        # print(node[0][0],chain)
        if node[0][0] == chain:
            nodecode = (node[0][1])
            nodes.append((int(nodecode), node[1]["nt"]))
    sortednodes = sorted(list(nodes))
    nuc_by_node = {}
    missing_nuc = False
    # print("NODES")
    numericals = [x[0] for x in sortednodes]
    # "decalage" (French: shift) counts the padded residue below; write-only.
    decalage = 0
    if 1 not in numericals:
        # Numbering does not start at 1: pad position 1 with a placeholder "N".
        decalage = decalage + 1
        sortednodes.append((1, "N"))
        sortednodes = sorted(sortednodes)
        # missing_nuc=True
        # decalage = decalage +1
        newnodes = []
        # for i in sortednodes:
        #     newnodes.append((i[0],i[1]))
        # sortednodes = sorted(list(newnodes))
        numericals = [x[0] for x in sortednodes]
        # print("MISSING 1", PDBid)
        # print(numericals)
        #
        # sortednodes=sorted(sortednodes)
        # numericals = [x[0]-1 for x in sortednodes]
        # numericals.insert(0,0)
    # else:
    #     print("NOT MISSING", PDBid)
    for i in sortednodes:
        nuc_by_node[i[0]] = i[1]
    # print(sortednodes)
    # Walk residue numbers 1..last; emit "-" for numbers absent from the chain.
    for i in range(1, int(sortednodes[-1][0]) + 1):
        if i not in numericals:
            "NOT IN NODES"  # no-op string expression (leftover debug marker)
            seq = seq + "-"
        else:
            seq = seq + nuc_by_node[i]
    ss = g.graph['ss']
    # print(seq)
    # print("MISSING_NUC",PDBid,missing_nuc)
    if "T" in seq:
        seq = seq.replace("T","U")
    return (seq, ss, chain)
def run_validation(module_to_test, module_seqs, left_out, offset,chain, aiming_for):
    """Run BayesPairing on one module's stored sequences and score each hit.

    ``module_seqs`` maps module id -> list of (PDB_name, seq, ss) tuples;
    only entries whose PDB name contains *left_out* are evaluated.
    Predicted positions are shifted by *offset* before comparison with the
    FR3D reference positions *aiming_for* (which are 1-based; they are
    converted to 0-based below).  Returns
    ``{module_to_test: [[fr3d_score, hit_seq, bp_score], ...]}``.
    NOTE(review): indentation was lost in this paste and has been
    reconstructed; also ``modules_to_test`` is only defined at module
    level under __main__ — confirm both against the repository.
    """
    results = {}
    # for i in modules_to_test:
    #     results[i] = []
    i = module_to_test
    results[i] = []
    for seq_obj in module_seqs[module_to_test]:
        if left_out in seq_obj[0]:
            PDB_name = seq_obj[0]
            seq = seq_obj[1]
            print("SEQUENCE:",seq)
            ss = seq_obj[2]
            seq_list = list(seq)
            control_seq = []
            # Copy the sequence character by character (the random-control
            # substitution below is disabled).
            for n in seq_list:
                # n = random.choice(['A','C','G','U'])
                control_seq.append(n)
            seq = "".join(control_seq)
            # Strip placeholder "N" residues one at a time.
            while "N" in seq:
                seq = list(seq)
                seq.remove("N")
                seq = "".join(seq)
            ss = seq_obj[2]
            maxs = run_BP(str(seq), ss, [i], DATASET_NAME, left_out)
            # maxs = run_BP(str(seq),ss,[i],"NONE")
            print("ALL_RESULTS :",maxs)
            # NOTE(review): k is unused; modules_to_test resolves to the
            # __main__-level global.
            k = modules_to_test.index(i)
            el = 0
            # maxs[0] is presumably a flat [score, positions, extra, ...]
            # triplet list — confirm against BayesPairing.returner.
            while el < (len(maxs[0])):
                print("COMPARING TO FR3D")
                bp_score = maxs[0][el]
                score1 = compare_to_FR3D(PDB_name, maxs[0][el], [position + offset for position in maxs[0][el + 1]],
                                         graphs[i][0],chain=chain,aiming_for=[x-1 for x in aiming_for])
                print(maxs[0][el + 1],score1)
                #if len(maxs)==0:
                #    continue
                # The comprehension's i is scoped to the comprehension in
                # Python 3, so the module id i above is untouched.
                seq1 = "".join([seq[i] for i in maxs[0][el + 1]])
                #ss1 = "".join([ss[i] for i in maxs[0][el + 1]])
                #candidate_results = [score1, seq1, ss1]
                candidate_results = [score1, seq1, bp_score]
                results[i].append(candidate_results)
                el = el + 3
    # Legacy two-candidate scoring, kept for reference:
    # score1 = compare_to_FR3D(PDB_name,maxs[0][0],maxs[0][1],graphs[i][0])
    # score2 = compare_to_FR3D(PDB_name, maxs[0][2], maxs[0][3], graphs[i][0])
    # seq1 = "".join([seq[i] for i in maxs[0][1]])
    # seq2 = "".join([seq[i] for i in maxs[0][3]])
    # ss1 = "".join([ss[i] for i in maxs[0][1]])
    # ss2 = "".join([ss[i] for i in maxs[0][3]])
    # results[i].append((maxs[0][0],seq1,ss1,score1,maxs[0][2],seq2,ss2,score2))
    return results
def cross_validation(modules_to_test):
    """Leave-one-PDB-out validation of every module in *modules_to_test*.

    Pass 1 collects each module's distinct training sequences and counts
    nucleotide frequencies.  Pass 2, for every (module, PDB) occurrence
    with an available pickled PDB graph, rebuilds the chain sequence,
    runs BayesPairing on it (whole sequence for short clean chains, a
    ~200 nt window around the motif otherwise) and scores the hits
    against the known motif positions.  Per-module results are pickled
    incrementally and the full result dict is pickled at the end.
    NOTE(review): the original indentation was lost in this paste; the
    nesting below — in particular the fall-through after the
    short-sequence branch — is reconstructed and must be confirmed
    against the repository.
    """
    # modules_to_test = range(0,len(graphs))
    # modules_to_test = [1,3,4,7,9,23,25,29,71,82,86]
    # modules_to_test = [1]
    seqs = {}
    print("NUMBER OF EXAMPLES")
    As = 0
    Cs = 0
    Gs = 0
    Us = 0
    # Pass 1: training-sequence inventory and nucleotide counts.
    for i in modules_to_test:
        # Modules 199 and 225 are skipped (presumably known-bad — confirm).
        if i in [199,225]:
            continue
        seqs[i] = []
        print("----------------------------------")
        print("MODULE #" + str(i))
        # print("trained on " + str(len(graphs[i])) + "examples")
        t = graphs[i]
        gr = t[1]
        # print("ALL SEQUENCES")
        for j in t:
            seqq = ""
            nod = j.nodes()
            sn = sorted([int(kk) for kk in nod])
            nl = j.nodes(data=True)
            # Read nucleotides in ascending node order.
            for k in sn:
                for node in nl:
                    if node[0] == k:
                        z = node
                seqq = seqq + z[1]['nuc']
                n = z[1]['nuc']
                if n == 'A':
                    As += 1
                if n == 'C':
                    Cs += 1
                if n == 'G':
                    Gs += 1
                if n == 'U':
                    Us += 1
            print(seqq)
            if seqq not in seqs[i]:
                seqs[i].append(seqq)
    # print(As, Cs, Gs, Us)
    tot = (As + Cs + Gs + Us)
    # print(As / tot, Cs / tot, Gs / tot, Us / tot)
    # print(seqs)
    # Space-separated module id list (built but unused here).
    BP_call = ""
    for i in modules_to_test:
        BP_call = BP_call + " " + str(i)
    pdbs_in_modules = {}
    positions_in_modules = {}
    module_seqs = {}
    # with open("list_292pdbs.txt","r") as f:
    #     lines = f.readlines()
    #     non_red = ["_".join(l[:-1].split("\t")) for l in lines]
    # print(non_red)
    crossval_results = {}
    # Pass 2: per-module validation against each occurrence PDB.
    for i in modules_to_test:
        if i in [199,225]:
            continue
        motif_position_in_PDBs = {}
        crossval_results[i] = []
        # pdbs[i] entries are "pdb.chain"; drop the header entry and split.
        pdbs[i] = [x.split(".") for x in pdbs[i][1:]]
        # print(pdbs[i])
        pdbs_in_modules[i] = ["_".join([x[0].upper(), x[1]]) for x in pdbs[i]]
        motif_positions = pdb_positions[i][1:]
        # print(pdbs[i])
        # print("PDBS")
        # print(pdbs_in_modules[i])
        # print(PDBlist)
        # print(motif_positions)
        # positions_in_modules[i] = motif_position_in_PDBs
        # print(motif_position_in_PDBs)
        this_mod_seqs = []
        # PDBs with known parsing problems, skipped below.
        bugged = ["3DEG", "1AQ4", "1H4Q", "IH4S"]
        done = []
        for ind, j in enumerate(pdbs_in_modules[i]):
            print(motif_positions[ind])
            last = sorted(motif_positions[ind][1])[-1]
            first = sorted(motif_positions[ind][1])[0]
            # print("PDB POSITIONS")
            # print(first, last)
            if j[0:4].upper() in PDBlist and j[0:4] not in bugged and j[0:4] not in done:
                #print("CURRENT PDB:", j)
                seq, ss, chain = get_seq_ss(j, graphs[i][ind].nodes(data=True))
                #print("LENGTH OF SEQUENCE:", len(seq))
                pdb_len = len(seq)
                #print("SEQUENCE :", seq)
                # Short, gap-free chain whose motif fits inside: validate on
                # the whole sequence with offset 0.
                if pdb_len in range(10, 150) and "-" not in seq and motif_positions[ind][1][-1]<pdb_len:  # continue
                    done.append(j[0:4])
                    # seq = shuffle_seq(seq)
                    this_mod_seqs = [((j, seq, ss))]
                    # else:i
                    # print(j[0:4].upper())
                    module_seqs[i] = this_mod_seqs
                    # print(module_seqs)
                    # print("SEQUENCE")
                    # print(module_seqs)
                    scores = run_validation(i, module_seqs, j, 0, chain=chain, aiming_for=motif_positions[ind][1])
                    for k in scores:
                        print(scores[k])
                        crossval_results[i].append(scores[k])
                elif pdb_len > 200:
                    continue
                # Fall-through: cut a window of up to ~200 nt around the
                # motif and validate with the corresponding offset.
                new_seq, new_ss = [], []
                offset = max(0, first - 100)
                if last - first < 200:
                    new_seq = seq[max(0, first - 100):last + 100]
                    # new_seq = shuffle_seq(new_seq)
                    # helices = BayesPairing.find_loops(ss)[1]
                    # print("HELICES")
                    # print(helices)
                    # for helix in helices:
                    #     if first > helix[0] and last < helix[1]:
                    #         new_seq = seq[helix[0]:helix[1]]
                    #         new_ss = ss[helix[0]:helix[1]]
                    #         offset = helix[0]
                    #         break
                    # else:
                    #     continue
                print("NEW SEQ LENGTH")
                print(len(new_seq))
                if len(new_seq) > 300:
                    continue
                # filler = 200-(last-first)
                # split = random.randrange(0,filler)
                # offset = first-split
                # new_seq = seq[offset:last+(filler-split)]
                # new_ss = ss[first-split:last+(filler-split)]
                this_mod_seqs.append((j, new_seq, new_ss))
                # else:i
                # print(j[0:4].upper())
                module_seqs[i] = this_mod_seqs
                # print(module_seqs)
                # print("SEQUENCE")
                # print(module_seqs)
                print("RUNNING VALIDATION")
                scores = run_validation(i, module_seqs, j, offset, chain=chain,
                                        aiming_for=motif_positions[ind][1])
                for k in scores:
                    print(k)
                    print(scores[k])
                    crossval_results[i].append(scores[k])
        # Checkpoint this module's results before moving on.
        cv_fn = "yash_rna3dmotif_positions_crossval_" + str(i)
        pickle.dump(crossval_results[i], open(cv_fn, "wb"))
    for i in crossval_results:
        for j in crossval_results[i]:
            print(j)
    #print(crossval_results)
    pickle.dump(crossval_results, open("yash_3dmotif_pos_cv.cPickle", "wb"))
    print('FINAL RESULTS')
    for i in crossval_results:
        print("RESULTS FOR MODULE :", i)
        for j in crossval_results[i]:
            print(j)
def test_fasta(input, modules_to_parse):
    """Score every record of a FASTA file against *modules_to_parse*.

    Sequences longer than 300 nt are skipped (their record keeps an empty
    score dict).  Partial results are checkpointed every 25 records to
    partial_predictions2.cPickle; final scores go to
    prediction_score_3.cPickle.  *input* is a FASTA file path (the
    parameter name shadows the builtin but is kept for compatibility).
    """
    prediction_scores = {}
    # BUG FIX: "rU" mode was removed in Python 3.11; plain "r" is equivalent.
    with open(input, "r") as f:
        r_count = 0
        for record in SeqIO.parse(f, "fasta"):
            r_count = r_count + 1
            if r_count % 25 == 0:
                # Periodic checkpoint so long runs can be inspected/resumed.
                pickle.dump(prediction_scores, open('partial_predictions2.cPickle', 'wb'))
            rec_id = record.id  # renamed from `id` (shadowed the builtin)
            prediction_scores[rec_id] = {}
            seq = str(record.seq).replace("T", "U")
            if len(seq) > 300:
                continue
            # BUG FIX: run_BP takes (seq, ss, modules, dataset, left_out);
            # the original 4-argument call raised a TypeError.
            maxs = run_BP(seq, "", modules_to_parse, "NONE", "NONE")
            print(maxs)
            for ind, module in enumerate(maxs):
                if len(maxs[ind]) > 0:
                    prediction_scores[rec_id][modules_to_parse[ind]] = (maxs[ind][0], maxs[ind][1])
                else:
                    prediction_scores[rec_id][modules_to_parse[ind]] = (0, [])
    pickle.dump(prediction_scores, open("prediction_score_3.cPickle", "wb"))
if __name__ == "__main__":
    # Cross-validate every trained module. modules_to_test is deliberately a
    # module-level name: run_validation() reads it as a global.
    modules_to_test = range(0,len(graphs))
    #modules_to_test = [8, 13, 16, 15, 20, 42, 80, 97, 100, 109]
    #modules_to_test = [4,5,7,20,24,30,31,32,33,34,38,40,44,54,57,60,64,75,76,78,79,98,99,100,101,104,107,108,109,111,116,119,124]
    # cross_validation(modules_to_test)
    # modules_to_test =[0, 5, 8, 26, 6, 18, 19, 21, 24, 25, 39, 53, 54, 58, 119, 122, 135, 146, 162, 168, 191, 194, 198, 216]
    #modules_to_test = range(0, 24)
    #modules_to_test = [8, 10, 13, 15, 16, 20, 29, 32, 34, 42, 49, 58, 69, 80, 88, 97, 100, 105, 109, 111, 125, 126, 128]
    cross_validation(modules_to_test)
    # test_fasta("cdnas.fa",modules_to_test)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment