Commit 39d75522 authored by Carlos GO

decoys in training almost ready

parent 49311fe1
@@ -14,6 +14,7 @@ if __name__ == '__main__':
     sys.path.append('../')
     from learning.utils import dgl_to_nx
+    from learning.decoy_utils import *
     from post.drawing import rna_draw
 def send_graph_to_device(g, device):
@@ -60,7 +61,7 @@ def print_gradients(model):
         name, p = param
         print(name, p.grad)
     pass
-def test(model, test_loader, device):
+def test(model, test_loader, device, decoys=None):
     """
     Compute accuracy and loss of model over given dataset
     :param model:
@@ -71,8 +72,9 @@ def test(model, test_loader, device):
     """
     model.eval()
     test_loss, motif_loss_tot, recons_loss_tot = (0,) * 3
+    all_graphs = test_loader.dataset.dataset.all_graphs
     test_size = len(test_loader)
-    for batch_idx, (graph, K, fp) in enumerate(test_loader):
+    for batch_idx, (graph, K, fp, idx) in enumerate(test_loader):
         # Get data on the devices
         K = K.to(device)
         fp = fp.to(device)
@@ -83,6 +85,10 @@ def test(model, test_loader, device):
         with torch.no_grad():
             fp_pred, embeddings = model(graph)
             loss = model.compute_loss(fp, fp_pred)
+            for i, f in zip(idx, fp_pred):
+                true_lig = all_graphs[i.item()].split(":")[2]
+                rank, sim = decoy_test(f, true_lig, decoys)
         del K
         del fp
         del graph
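The loop added here relies on a decoy_test helper pulled in by the new `from learning.decoy_utils import *` import; that module is not shown in this diff. Judging from the call site, it takes a predicted fingerprint, the true ligand id (the third `:`-separated field of the graph file name), and a decoy collection, and returns the true ligand's rank plus a similarity score. A minimal sketch of that interface, assuming a {ligand_id: (candidate_ids, candidate_fps)} layout for decoys and Tanimoto similarity, neither of which is confirmed by this commit:

import numpy as np

def decoy_test(fp_pred, true_ligand, decoys):
    """Hypothetical sketch: rank the true ligand among decoys by fingerprint similarity.

    fp_pred:     predicted fingerprint (torch tensor) for one pocket
    true_ligand: ligand id parsed from the graph name
    decoys:      assumed mapping {ligand_id: (candidate_ids, candidate_fps)},
                 where candidate_ids includes the true ligand itself
    """
    pred = (fp_pred.detach().cpu().numpy() > 0.5).astype(int)
    candidate_ids, candidate_fps = decoys[true_ligand]
    # Tanimoto similarity between the binarized prediction and each candidate
    sims = []
    for cand in candidate_fps:
        inter = np.logical_and(pred, cand).sum()
        union = np.logical_or(pred, cand).sum()
        sims.append(inter / union if union else 0.0)
    true_pos = candidate_ids.index(true_ligand)
    order = np.argsort(sims)[::-1]          # best-scoring candidates first
    rank = int(np.where(order == true_pos)[0][0])
    return rank, sims[true_pos]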
@@ -95,7 +101,8 @@ def test(model, test_loader, device):
 def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
                 writer=None, num_epochs=25, wall_time=None,
-                reconstruction_lam=1, motif_lam=1, embed_only=-1):
+                reconstruction_lam=1, motif_lam=1, embed_only=-1,
+                decoys=None):
     """
     Performs the entire training routine.
     :param model: (torch.nn.Module): the model to train
@@ -152,7 +159,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
        num_batches = len(train_loader)
-        for batch_idx, (graph, K, fp) in enumerate(train_loader):
+        for batch_idx, (graph, K, fp, idx) in enumerate(train_loader):
            # Get data on the devices
            batch_size = len(K)
@@ -201,7 +208,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
        # writer.log_scalar("Train accuracy during training", train_accuracy, epoch)
        # Test phase
-        test_loss = test(model, test_loader, device)
+        test_loss = test(model, test_loader, device, decoys=decoys)
        print(">> test loss ", test_loss)
        writer.add_scalar("Test loss during training", test_loss, epoch)
...
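train_model now simply forwards decoys to test each epoch. A hypothetical call site (load_decoys and its return shape are placeholders, not part of this commit) could look like:

decoys = load_decoys()  # assumed helper returning {ligand_id: (candidate_ids, candidate_fps)}
train_model(model, criterion, optimizer, device,
            train_loader, test_loader, save_path,
            writer=writer, num_epochs=25,
            reconstruction_lam=1, motif_lam=1,
            decoys=decoys)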
@@ -91,10 +91,10 @@ class V1(Dataset):
        if self.get_sim_mat:
            # put the rings in same order as the dgl graph
            ring = dict(sorted(ring.items()))
-            return g_dgl, ring, fp
+            return g_dgl, ring, fp, [idx]
        else:
-            return g_dgl, fp
+            return g_dgl, fp, [idx]
    def _get_edge_data(self):
        """
@@ -125,20 +125,21 @@ def collate_wrapper(node_sim_func, get_sim_mat=True):
            # The input `samples` is a list of pairs
            # (graph, label).
            # print(len(samples))
-            graphs, rings, fp = map(list, zip(*samples))
+            graphs, rings, fp, idx = map(list, zip(*samples))
             fp = np.array(fp)
+            idx = np.array(idx)
             batched_graph = dgl.batch(graphs)
             K = k_block_list(rings, node_sim_func)
-            return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float()
+            return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float(), torch.from_numpy(idx)
     else:
         def collate_block(samples):
            # The input `samples` is a list of pairs
            # (graph, label).
            # print(len(samples))
-            graphs, _, fp = map(list, zip(*samples))
+            graphs, _, fp, idx = map(list, zip(*samples))
             fp = np.array(fp)
             batched_graph = dgl.batch(graphs)
-            return batched_graph, [1 for _ in samples], torch.from_numpy(fp)
+            return batched_graph, [1 for _ in samples], torch.from_numpy(fp), torch.from_numpy(idx)
     return collate_block
 class Loader():
...
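Each __getitem__ now returns its dataset index wrapped in a one-element list, so zip(*samples) hands the collate functions a list of [idx] lists and np.array(idx) yields a (batch, 1) array, matching the i.item() lookup in test(). Note that the else branch above passes the raw Python list straight to torch.from_numpy, which only accepts NumPy arrays and would raise a TypeError; mirroring the idx = np.array(idx) line from the first branch fixes it. A small self-contained trace of the intended flow (sample values are illustrative):

import numpy as np
import torch

# Two fake samples shaped like (graph, ring, fp, [idx]); strings stand in for dgl graphs.
samples = [("g0", None, [0.0, 1.0], [7]),
           ("g1", None, [1.0, 0.0], [12])]
graphs, _, fp, idx = map(list, zip(*samples))
fp = np.array(fp)
idx = np.array(idx)                 # the conversion the else branch is missing
idx_tensor = torch.from_numpy(idx)  # shape (2, 1)
# test() recovers each scalar index with i.item():
assert [i.item() for i in idx_tensor] == [7, 12]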
@@ -100,7 +100,6 @@ print('Created data loader')
 Model loading
 '''
-#increase output embeddings by 1 for nuc info
 if args.nucs:
     dim_add = 1
...