Commit 36c49781 authored by Yixiong Sun's avatar Yixiong Sun
Browse files

Updated BN construction to use json dataset and optimized running of call_makeBN

parent 50fe5983
......@@ -23,6 +23,7 @@ import math
from Bio import AlignIO
from folding import Fold
import RNA
import json
Lambda = 0.35
md = RNA.md()
......@@ -35,10 +36,17 @@ def boltzmann(e):
def setup_BN(modules, BNs, dataset,left_out,leave_out_sequence=False,left_out_sequence="",Lambda=0.35,Theta=1, verbose=False,indexes=[]):
# Load dataset here for optimmization
with open("../models/" + dataset + ".json") as f:
modules_data = json.load(f)
# Create the BN if needed
for mod in modules:
if mod not in BNs:
print("making Bayes Net for module",mod)
BNs[mod] = makeBN.call_makeBN(mod, dataset,left_out,leave_out_sequence,left_out_sequence, Lambda, Theta,indexes=indexes)
# Pass in the module_data
BNs[mod] = makeBN.call_makeBN(mod, dataset, modules_data[str(mod)], Lambda, Theta)
if verbose:
print("Bayes Net dataset:", BNs.keys())
return BNs
......
......@@ -8,6 +8,7 @@ import BN_builder
import networkx as nx
import subprocess
from matplotlib import pyplot as plt
import json
USE_RFAM_SEQS=True
......@@ -244,10 +245,77 @@ def make_graph_from_carnaval(g, alignment):
return motif
def call_makeBN(mod,dataset,left_out, leave_out_seq = False, left_out_seq = "", Lambda=1, Theta=1,indexes=[],kfold=False,retrain=False):
def call_makeBN(mod, dataset, module_data, Lambda=1, Theta=1,retrain=False):
ok_indexes = []
current_ID = mod
excluded = left_out
#excluded = left_out
# Make the BN given the specified module
nodes = [tuple(x) for x in module_data["master_graph"]["nodes"]]
edges = [tuple(x) for x in module_data["master_graph"]["edges"]]
# build the networkx graphs from the data
# Only build the first graph or master graph?
g = nx.DiGraph()
g.add_nodes_from(nodes)
g.add_edges_from(edges)
# extract the motif sequences
motif_seqs = []
for training_set in module_data["training_set"]:
full_seq = training_set["seq"]
seq_pos = training_set["seq_pos"]
# Small error in dataset, will fix later
try:
motif_seq =[full_seq[x] for x in seq_pos]
motif_seqs.append("".join(motif_seq))
except:
pass
# Train the model if it doesn't exist or needs to be retrained
# Check if module already exists
if (os.path.isfile("../models/" + dataset + "_models.pickle")) and not retrain:
nets = pickle.load(open("../models/" + dataset + "_models.pickle", "rb"))
if mod in nets:
return nets[mod]
else:
aln = {}
for n in sorted(list(g.nodes())):
aln[n] = n
alignment = get_alignment(g, aln, motif_seqs)
else:
try:
existing_models = pickle.load(open("../models/" + dataset + "_models.pickle", 'rb'))
except:
existing_models = {}
if current_ID in existing_models:
return existing_models[current_ID]
else:
aln = {}
for n in sorted(list(g.nodes())):
aln[n] = n
alignment = get_alignment(g, aln, motif_seqs)
#TODO: test_seqs variable is outdated/not used, should be removed from the BN functions
test_seqs = []
motif = make_graph_from_carnaval(g, alignment)
pwm = BN_builder.build_pwm(sorted(list(motif.nodes())), alignment)
BN = BN_builder.build_BN(motif, pwm, alignment)
BN.from_alignment_dataset(dataset, mod, [g], alignment, test_seqs, [], Lambda, Theta)
return BN
# Below is old code with pickle dataset + CV
#TODO: Delete the old code
'''
g_list = pickle.load(open("../models/"+dataset + "_one_of_each_graph.cPickle",'rb'))
seq_list = pickle.load(open("../models/"+dataset + "_sequences.pickle",'rb'))
if type(seq_list) == list:
......@@ -338,18 +406,12 @@ def call_makeBN(mod,dataset,left_out, leave_out_seq = False, left_out_seq = "",
for n in sorted(list(g.nodes())):
aln[n] = n
alignment = get_alignment(g,aln,extra_seqs)
'''
motif = make_graph_from_carnaval(g, alignment)
pwm = BN_builder.build_pwm(sorted(list(motif.nodes())),alignment)
BN = BN_builder.build_BN(motif,pwm,alignment)
BN.from_alignment_dataset(dataset, mod, [g], alignment, test_seqs, [], Lambda, Theta)
return BN
# Temporary placeholder function for CV for building a Bayesian Network
def make_BN(module, dataset, graphs, motif_sequences, Lambda=0.35, Theta=1):
......
......@@ -12,6 +12,7 @@ from functools import reduce
from Bio import AlignIO
import os.path
from operator import itemgetter
import json
def unpick(dataset,direc,typ):
file_path = "../"+direc+"/" + dataset + "_"+typ
......@@ -57,8 +58,16 @@ def run_BP(seq, ss, modules_to_parse, dataset, left_out, aln=False, t=-5, sample
#siblings = pickle.load(open("../models/" + dataset + "_siblings.pickle", "rb"))
siblings = unpick(dataset,"models","siblings.pickle")
#siblings = unpick(dataset,"models","siblings.pickle")
# load the siblings from the json dataset, siblings should be a dict in the format {0: [1,2,3], 1:[0,2,3]}...
with open("../models/" + dataset + ".json") as f:
module_siblings = json.load(f)
# map the json dict to proper format
siblings = {int(k): v["siblings"] for k, v in module_siblings.items()}
# instead of mapping the dict to the sibling dict, pass the dataset through and process after
noSiblings = process_siblings(return_dict,siblings)
return noSiblings
......@@ -91,6 +100,7 @@ def process_siblings(results,siblings):
def parse_sequences(input,modules_to_check=[],dataset="",ss="",m=4,n=3,sm=0.3,mc=3,p=25000,sw=1,t=15.7,w=200,s=100,sscons=2):
if dataset=="":
dataset="rna3dmotif"
graphs = pickle.load(open("../models/" + dataset + "_one_of_each_graph.cPickle", "rb"))
if len(modules_to_check) == 0:
modules_to_check = range(len(graphs))
......@@ -661,7 +671,12 @@ if __name__ == "__main__":
# we load the modules from the dataset to get the number of modules available.
#graphs = pickle.load(open("../models/" + dataset + "_one_of_each_graph.cPickle", "rb"))
graphs = unpick(dataset,"models","one_of_each_graph.cPickle")
#graphs = unpick(dataset,"models","one_of_each_graph.cPickle")
# updated load from json
with open("../models/" + dataset +".json") as f:
modules = json.load(f)
if "mod" in arguments:
modules_to_check = [int(input_number) for input_number in arguments["mod"]]
else:
......@@ -671,7 +686,7 @@ if __name__ == "__main__":
#for i in excluded_modules:
# if i in modules_to_check:
# modules_to_check.remove(i)
modules_to_check = range(len(graphs))
modules_to_check = range(len(modules))
#timer.sleep(5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment