Commit 70cdc345 authored by Carlos GO's avatar Carlos GO
Browse files

decoy testing in model done

parent 39d75522
......@@ -86,9 +86,12 @@ def test(model, test_loader, device, decoys=None):
fp_pred, embeddings = model(graph)
loss = model.compute_loss(fp, fp_pred)
enrichments = []
for i,f in zip(idx, fp_pred):
true_lig = all_graphs[i.item()].split(":")[2]
rank,sim = decoy_test(f, true_lig, decoys)
enrichments.append(rank)
decoy_ranks = np.mean(enrichments)
del K
del fp
del graph
......@@ -97,7 +100,7 @@ def test(model, test_loader, device, decoys=None):
del loss
return test_loss / test_size
return test_loss / test_size, decoy_ranks
def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
writer=None, num_epochs=25, wall_time=None,
......@@ -123,6 +126,8 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
edge_map = train_loader.dataset.dataset.edge_map
decoys = get_decoys(mode='pdb', annots_dir=train_loader.dataset.dataset.path)
epochs_from_best = 0
early_stop_threshold = 10
......@@ -208,8 +213,9 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# writer.log_scalar("Train accuracy during training", train_accuracy, epoch)
# Test phase
test_loss = test(model, test_loader, device, decoys=decoys)
test_loss, enrichments = test(model, test_loader, device, decoys=decoys)
print(">> test loss ", test_loss)
print(">> test enrichments", enrichments)
writer.add_scalar("Test loss during training", test_loss, epoch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment