Commit d8468622 authored by Yixiong Sun
Browse files

Changed output to save after every module number for continuous validation

parent 14a10753
......@@ -460,56 +460,69 @@ def two_fold_CV(modules_to_test, output_file, shuffle, ss):
:return:
"""
cv_results = {}
TO_SKIP = []
for module_number in modules_to_test:
module_results = []
if module_number in TO_SKIP: # we have done a sibling
print("Module", module_number, "skipped.")
continue
module_sequences = mod_sequences[module_number]
target_sequences = test_sequences[module_number]
# Idea - open pickle results, load previous and continue
try:
cv_results = pickle.load(open(output_file, "rb"))
except IOError:
cv_results = {}
# If the key doesn't exist yet, validate this module; otherwise skip it
if module_number not in cv_results:
module_results = []
module_sequences = mod_sequences[module_number]
target_sequences = test_sequences[module_number]
num_sequences = len(module_sequences)
if num_sequences < 2:
print("Module", module_number, "does not have enough data for Cross Validation, skipped.")
num_sequences = len(module_sequences)
if num_sequences < 2:
print("Module", module_number, "does not have enough data for Cross Validation, skipped.")
indexes = list(range(0, num_sequences))
indexes = list(range(0, num_sequences))
# Shuffle and split indexes in half
random.shuffle(indexes)
# Shuffle and split indexes in half
random.shuffle(indexes)
# TODO: Fix this so it's more logical with higher K-folds
split = [indexes[:num_sequences//2], indexes[num_sequences//2:]]
# TODO: Fix this so it's more logical with higher K-folds
split = [indexes[:num_sequences//2], indexes[num_sequences//2:]]
for j in [0,1]:
for j in [0,1]:
test_indexes = split[1 - j]
train_indexes = split[0 + 1]
modules_train = [module_sequences[i] for i in train_indexes]
modules_test = [module_sequences[i] for i in test_indexes]
target_test = [target_sequences[i] for i in test_indexes]
test_indexes = split[1 - j]
train_indexes = split[0 + 1]
modules_train = [module_sequences[i] for i in train_indexes]
modules_test = [module_sequences[i] for i in test_indexes]
target_test = [target_sequences[i] for i in test_indexes]
# Train the BN
BNs = train_BN(module_number, motif_sequences=modules_train)
# Train the BN
BNs = train_BN(module_number, motif_sequences=modules_train)
# Results = list of scores
module_results = module_results + run_validation(module_number, BNs=BNs, module_sequences=modules_test, target_sequences=target_test,test_indexes=test_indexes, shuffle=shuffle, ss=ss)
# Results = list of scores
module_results = module_results + run_validation(module_number, BNs=BNs, module_sequences=modules_test, target_sequences=target_test,test_indexes=test_indexes, shuffle=shuffle, ss=ss)
# Set results in dict
cv_results[module_number] = module_results
# Save to output file
# Set results in dict
cv_results[module_number] = module_results
pickle.dump(cv_results, open(output_file, "wb"))
# If module is in siblings, skip the rest
if module_number in siblings:
for sib in siblings[module_number]:
TO_SKIP.append(sib)
# Save to output file
pickle.dump(cv_results, open(output_file, "wb"))
......@@ -540,7 +553,7 @@ if __name__ == "__main__":
# Test all the modules
modules_to_test = list(range(0, len(graphs)))
#modules_to_test = [0]
#modules_to_test = [0,1,2,3,4]
# Output file name
file_name = "2fold_cv_" + DATASET_NAME
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment