Commit 49311fe1 authored by Carlos GO

Add k-fold cross-validation to the learning main script

parent 71b10ddd
@@ -2,6 +2,8 @@ import argparse
import os, sys
import pickle
import numpy as np
cwd = os.getcwd()
if cwd.endswith('learn'):
    sys.path.append('../')
@@ -13,7 +15,6 @@ parser.add_argument("-p", "--parallel", default=True, help="If we don't want to
parser.add_argument("-da", "--annotated_data", default='pockets_nx_symmetric')
parser.add_argument("-bs", "--batch_size", type=int, default=8, help="choose the batch size")
parser.add_argument("-nw", "--workers", type=int, default=20, help="Number of workers to load data")
parser.add_argument("-wt", "--wall_time", type=int, default=None, help="Max time to run the model")
parser.add_argument("-n", "--name", type=str, default='default_name', help="Name for the logs")
parser.add_argument("-t", "--timed", help="to use timed learn", action='store_true')
parser.add_argument("-ep", "--num_epochs", type=int, help="number of epochs to train", default=3)
@@ -28,6 +29,7 @@ parser.add_argument('-pw', '--pos_weight', type=int, default=0, help='Weight for
parser.add_argument('-po', '--pool', type=str, default='sum', help='Pooling function to use.')
parser.add_argument("-nu", "--nucs", default=True, help="Use nucleotide IDs for learn", action='store_false')
parser.add_argument('-rs', '--seed', type=int, default=0, help='Random seed to use (if > 0, else no seed is set).')
parser.add_argument('-kf', '--kfold', type=int, default=0, help='Do k-fold cross-validation and run decoys on each fold.')
args = parser.parse_args()
@@ -89,12 +91,10 @@ loader = Loader(annotated_path=annotated_path,
                sim_function=args.sim_function,
                nucs=args.nucs)
train_loader, _, test_loader = loader.get_data()
if len(train_loader) == 0 and len(test_loader) == 0:
    raise ValueError('Not enough data points for the given batch size.')
print('Created data loader')
'''
Model loading
@@ -114,89 +114,75 @@ if dims[-1] != attributor_dims[0] - dim_add:
motif_lam = args.motif_lam
reconstruction_lam = args.reconstruction_lam
model = Model(dims, device, attributor_dims=attributor_dims,
              num_rels=loader.num_edge_types,
              num_bases=-1, pool=args.pool,
              pos_weight=args.pos_weight,
              nucs=nucs)
# if pre-trained, initialize matching layers
if args.warm_start:
    print("warm starting")
    m = torch.load(args.warm_start, map_location='cpu')['model_state_dict']
    # remove keys not related to embeddings
    for k in list(m.keys()):
        if 'embedder' not in k:
            print("killing ", k)
            del m[k]
    missing = model.load_state_dict(m, strict=False)
    print(missing)
model = model.to(device)
# for param_tensor in model.state_dict():
#     if 'embedder' in param_tensor:
#         print(param_tensor, "\t", model.state_dict()[param_tensor])
print(f'Using {model.__class__} as model')
if used_gpus_count > 1:
    model = torch.nn.DataParallel(model)
'''
Optimizer instantiation
'''
criterion = torch.nn.BCELoss()
optimizer = optim.Adam(model.parameters())
# print(list(model.named_parameters()))
# raise ValueError
# optimizer.add_param_group({'param': model.embeddings})
# optimizer = optim.SGD(model.parameters(), lr=1)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
'''
Experiment Setup
'''
name = args.name
result_folder, save_path = mkdirs(name)
print(save_path)
writer = SummaryWriter(result_folder)
print(f'Saving result in {result_folder}/{name}')
meta = {k:getattr(args, k) for k in dir(args) if not k.startswith("_")}
meta['edge_map'] = train_loader.dataset.dataset.edge_map
#save metainfo
pickle.dump(meta, open(os.path.join(result_folder, 'meta.p'), 'wb'))
import numpy as np
all_graphs = np.array(test_loader.dataset.dataset.all_graphs)
test_inds = test_loader.dataset.indices
train_inds = train_loader.dataset.indices
pickle.dump({'test': all_graphs[test_inds], 'train': all_graphs[train_inds]},
            open(os.path.join(result_folder, 'splits.p'), 'wb'))
wall_time = args.wall_time
'''
Run
'''
num_epochs = args.num_epochs
learn.train_model(model=model,
                  criterion=criterion,
                  optimizer=optimizer,
                  device=device,
                  train_loader=train_loader,
                  test_loader=test_loader,
                  save_path=save_path,
                  writer=writer,
                  num_epochs=num_epochs,
                  wall_time=wall_time,
                  reconstruction_lam=reconstruction_lam,
                  motif_lam=motif_lam,
                  embed_only=args.embed_only)
data = loader.get_data(k_fold=args.kfold)
# one training run per fold, suffixed with the fold index
for k, (train_loader, test_loader) in enumerate(data):
    model = Model(dims, device, attributor_dims=attributor_dims,
                  num_rels=loader.num_edge_types,
                  num_bases=-1, pool=args.pool,
                  pos_weight=args.pos_weight,
                  nucs=args.nucs)

    # if pre-trained, initialize matching layers
    if args.warm_start:
        print("warm starting")
        m = torch.load(args.warm_start, map_location='cpu')['model_state_dict']
        # remove keys not related to embeddings
        # NB: named `key`, not `k`, so the fold index is not clobbered
        for key in list(m.keys()):
            if 'embedder' not in key:
                print("killing ", key)
                del m[key]
        missing = model.load_state_dict(m, strict=False)
        print(missing)

    model = model.to(device)
    print(f'Using {model.__class__} as model')

    '''
    Optimizer instantiation
    '''
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())

    '''
    Experiment Setup
    '''
    name = f"{args.name}_{k}"
    result_folder, save_path = mkdirs(name)
    print(save_path)
    writer = SummaryWriter(result_folder)
    print(f'Saving result in {result_folder}/{name}')

    meta = {arg: getattr(args, arg) for arg in dir(args) if not arg.startswith("_")}
    meta['edge_map'] = train_loader.dataset.dataset.edge_map
    # save metainfo
    pickle.dump(meta, open(os.path.join(result_folder, 'meta.p'), 'wb'))

    all_graphs = np.array(test_loader.dataset.dataset.all_graphs)
    test_inds = test_loader.dataset.indices
    train_inds = train_loader.dataset.indices
    pickle.dump({'test': all_graphs[test_inds], 'train': all_graphs[train_inds]},
                open(os.path.join(result_folder, f'splits_{k}.p'), 'wb'))

    '''
    Run
    '''
    num_epochs = args.num_epochs
    learn.train_model(model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      device=device,
                      train_loader=train_loader,
                      test_loader=test_loader,
                      save_path=save_path,
                      writer=writer,
                      num_epochs=num_epochs,
                      reconstruction_lam=reconstruction_lam,
                      motif_lam=motif_lam,
                      embed_only=args.embed_only)
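For reference, the loop above assumes Loader.get_data(k_fold=...) yields one (train_loader, test_loader) pair per fold, and a single split when no k-fold is requested. The following is a minimal sketch of that contract, not the repository's actual implementation: the batch_size and num_workers defaults and the 80/20 fallback split are illustrative assumptions.

import numpy as np
from torch.utils.data import DataLoader, Subset

def get_data(dataset, k_fold=0, batch_size=8, num_workers=20, seed=0):
    # Yield (train_loader, test_loader) pairs: one per fold when k_fold > 1,
    # otherwise a single shuffled 80/20 split. Sketch only.
    indices = np.arange(len(dataset))
    np.random.RandomState(seed).shuffle(indices)

    def loaders(train_idx, test_idx):
        # Subset exposes .indices and .dataset, matching the
        # test_loader.dataset.indices / .dataset.all_graphs accesses above.
        train = DataLoader(Subset(dataset, list(train_idx)),
                           batch_size=batch_size, num_workers=num_workers,
                           shuffle=True)
        test = DataLoader(Subset(dataset, list(test_idx)),
                          batch_size=batch_size, num_workers=num_workers)
        return train, test

    if k_fold > 1:
        folds = np.array_split(indices, k_fold)
        for i in range(k_fold):
            train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
            yield loaders(train_idx, folds[i])
    else:
        cut = int(0.8 * len(indices))
        yield loaders(indices[:cut], indices[cut:])

Under this contract, running the script with -kf 5 trains five models and writes splits_0.p through splits_4.p, with logs under name_0 through name_4.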