Commit 205c394e authored by Carlos GO's avatar Carlos GO
Browse files

get pair shuffle violins

parent e3f2dbb0
......@@ -3,7 +3,6 @@ import os
import pickle
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--parallel", default=True, help="If we don't want to run thing in parallel", action='store_false')
parser.add_argument("-da", "--annotated_data", default='samples')
parser.add_argument("-bs", "--batch_size", type=int, default=128, help="choose the batch size")
parser.add_argument("-nw", "--workers", type=int, default=20, help="Number of workers to load data")
......@@ -54,10 +53,7 @@ Hardware settings
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# This is to create an appropriate number of workers, but works too with cpu
if args.parallel:
used_gpus_count = torch.cuda.device_count()
else:
used_gpus_count = 1
used_gpus_count = 1
print(f'Using {used_gpus_count} GPUs')
......
......@@ -58,7 +58,7 @@ def distance_rank(active, pred, decoys, dist_func=jaccard):
rank += 1
return 1- (rank / (len(decoys) + 1))
def decoy_test(model, decoys, edge_map, embed_dim, test_graphlist=None, test_graph_path="../data/annotated/pockets_nx"):
def decoy_test(model, decoys, edge_map, embed_dim, shuffle=False, test_graphlist=None, test_graph_path="../data/annotated/pockets_nx"):
"""
Check performance against decoy set.
decoys --> {'ligand_id', ('expected_FP', [decoy_fps])}
......@@ -79,6 +79,8 @@ def decoy_test(model, decoys, edge_map, embed_dim, test_graphlist=None, test_gra
g,_,_,_ = pickle.load(open(os.path.join(test_graph_path, g_path), 'rb'))
try:
true_id = g_path.split(":")[2]
if shuffle:
true_id = random.choice(list(decoys.keys()))
except:
print(f">> failed on {g_path}")
continue
......@@ -93,12 +95,12 @@ def decoy_test(model, decoys, edge_map, embed_dim, test_graphlist=None, test_gra
return np.mean(ranks), ranks
def ablation_results():
modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle']
modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
decoys = get_decoys()
ranks, methods = [], []
for m in modes:
if m == '':
if m in ['', 'pair-shuffle']:
graph_dir = "../data/annotated/pockets_nx"
run = 'small_no_rec_2'
else:
......@@ -107,6 +109,9 @@ def ablation_results():
edge_map = get_edge_map(graph_dir)
num_edge_types = len(edge_map)
shuffle = False
if m == 'pair-shuffle':
shuffle = True
dims = [32] * 3
# dims = [32]*6
......@@ -117,7 +122,10 @@ def ablation_results():
graph_ids = pickle.load(open(f'../results/{run}/splits.p', 'rb'))
acc, ranks_this = decoy_test(model, decoys, edge_map, 32, test_graphlist=graph_ids['test'], test_graph_path=graph_dir)
acc, ranks_this = decoy_test(model, decoys, edge_map, 32,
test_graphlist=graph_ids['test'],
test_graph_path=graph_dir,
shuffle=shuffle)
ranks.extend(ranks_this)
methods.extend([m]*len(ranks_this))
print("test", 1-acc)
......@@ -130,7 +138,7 @@ def ablation_results():
for artist in ax.findobj(PathCollection):
artist.set_zorder(11)
sns.stripplot(data=df, x='method', y='rank', jitter=True, alpha=0.6)
plt.savefig("../tex/Figs/violins_gcn.pdf", format="pdf")
# plt.savefig("../tex/Figs/violins_gcn.pdf", format="pdf")
plt.show()
if __name__ == "__main__":
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment