Commit 39d75522 authored by Carlos GO
Browse files

decoys in training almost ready

parent 49311fe1
......@@ -14,6 +14,7 @@ if __name__ == '__main__':
sys.path.append('../')
from learning.utils import dgl_to_nx
from learning.decoy_utils import *
from post.drawing import rna_draw
def send_graph_to_device(g, device):
......@@ -60,7 +61,7 @@ def print_gradients(model):
name, p = param
print(name, p.grad)
pass
def test(model, test_loader, device):
def test(model, test_loader, device, decoys=None):
"""
Compute accuracy and loss of model over given dataset
:param model:
......@@ -71,8 +72,9 @@ def test(model, test_loader, device):
"""
model.eval()
test_loss, motif_loss_tot, recons_loss_tot = (0,) * 3
all_graphs = test_loader.dataset.dataset.all_graphs
test_size = len(test_loader)
for batch_idx, (graph, K, fp) in enumerate(test_loader):
for batch_idx, (graph, K, fp, idx) in enumerate(test_loader):
# Get data on the devices
K = K.to(device)
fp = fp.to(device)
......@@ -83,6 +85,10 @@ def test(model, test_loader, device):
with torch.no_grad():
fp_pred, embeddings = model(graph)
loss = model.compute_loss(fp, fp_pred)
for i,f in zip(idx, fp_pred):
true_lig = all_graphs[i.item()].split(":")[2]
rank,sim = decoy_test(f, true_lig, decoys)
del K
del fp
del graph
......@@ -95,7 +101,8 @@ def test(model, test_loader, device):
def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
writer=None, num_epochs=25, wall_time=None,
reconstruction_lam=1, motif_lam=1, embed_only=-1):
reconstruction_lam=1, motif_lam=1, embed_only=-1,
decoys=None):
"""
Performs the entire training routine.
:param model: (torch.nn.Module): the model to train
......@@ -152,7 +159,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
num_batches = len(train_loader)
for batch_idx, (graph, K, fp) in enumerate(train_loader):
for batch_idx, (graph, K, fp, idx) in enumerate(train_loader):
# Get data on the devices
batch_size = len(K)
......@@ -201,7 +208,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# writer.log_scalar("Train accuracy during training", train_accuracy, epoch)
# Test phase
test_loss = test(model, test_loader, device)
test_loss = test(model, test_loader, device, decoys=decoys)
print(">> test loss ", test_loss)
writer.add_scalar("Test loss during training", test_loss, epoch)
......
......@@ -91,10 +91,10 @@ class V1(Dataset):
if self.get_sim_mat:
# put the rings in same order as the dgl graph
ring = dict(sorted(ring.items()))
return g_dgl, ring, fp
return g_dgl, ring, fp, [idx]
else:
return g_dgl, fp
return g_dgl, fp, [idx]
def _get_edge_data(self):
"""
......@@ -125,20 +125,21 @@ def collate_wrapper(node_sim_func, get_sim_mat=True):
# The input `samples` is a list of pairs
# (graph, label).
# print(len(samples))
graphs, rings, fp = map(list, zip(*samples))
graphs, rings, fp, idx = map(list, zip(*samples))
fp = np.array(fp)
idx = np.array(idx)
batched_graph = dgl.batch(graphs)
K = k_block_list(rings, node_sim_func)
return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float()
return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float(), torch.from_numpy(idx)
else:
def collate_block(samples):
    """Collate a list of (graph, ring, fp, idx) samples into one batch.

    This is the no-similarity-matrix branch: the ring data is ignored and a
    dummy list of 1s is returned in place of the K matrix.

    :param samples: list of tuples (dgl graph, ring, fingerprint, [idx])
    :return: (batched dgl graph, dummy K list, fp tensor, idx tensor)
    """
    # The input `samples` is a list of pairs (graph, label).
    graphs, _, fp, idx = map(list, zip(*samples))
    fp = np.array(fp)
    # Bug fix: torch.from_numpy requires an ndarray; without this
    # conversion (present in the sim-mat branch) it raises TypeError.
    idx = np.array(idx)
    batched_graph = dgl.batch(graphs)
    return batched_graph, [1 for _ in samples], torch.from_numpy(fp), torch.from_numpy(idx)
return collate_block
class Loader():
......
......@@ -100,7 +100,6 @@ print('Created data loader')
Model loading
'''
#increase output embeddings by 1 for nuc info
if args.nucs:
dim_add = 1
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment