Commit 2b95014c authored by Carlos GO's avatar Carlos GO
Browse files

tested

parent b646cadb
......@@ -32,6 +32,18 @@ See README in each folder for details on how to use each component.
## Usage
### Extracting the annotated data
Extract the annotated data archive and move the extracted folder into `data/annotated`, creating `data/annotated` first if it does not exist yet.
```
cd data
tar -xzvf pockets_nx_symmetric_orig.tar.gz
cd ..
mkdir data/annotated
mv data/pockets_nx_symmetric_orig data/annotated
```
### Loading a trained model
......@@ -55,10 +67,14 @@ Making predictions for every graph in a folder.
```
from tools.learning_utils import inference_on_dir
graph_folder = "data/predict_plz"
Y = inference_on_dir("models/rnamigos", graph_dir)
graph_dir = "data/annotated/pockets_nx_symmetric_orig"
fp_pred,_ = inference_on_dir("rnamigos", graph_dir)
```
`fp_pred` is an N x 166 matrix, where N is the number of graphs in `graph_dir` and each column corresponds to a fingerprint index.
Each entry is a raw probability, so if you want a binary fingerprint, do as above and apply the `>0.5` filter.
### Training your own model
A basic example is training on the annotated graphs inside `data/annotated` on default settings.
......
......@@ -146,8 +146,9 @@ def collate_wrapper(node_sim_func, get_sim_mat=True):
# The input `samples` is a list of pairs
# (graph, label).
# print(len(samples))
graphs, _, fp, idx = map(list, zip(*samples))
graphs, fp, idx = map(list, zip(*samples))
fp = np.array(fp)
idx = np.array(idx)
batched_graph = dgl.batch(graphs)
return batched_graph, [1 for _ in samples], torch.from_numpy(fp), torch.from_numpy(idx)
return collate_block
......@@ -160,7 +161,7 @@ class Loader():
sim_function="R_1",
shuffle=False,
seed=0,
get_sim_mat=True,
get_sim_mat=False,
nucs=True,
depth=3):
"""
......@@ -174,6 +175,7 @@ class Loader():
:param siamese: for the batch siamese technique
:param full_siamese: for the true siamese one
"""
self.all_graphs = sorted(os.listdir(annotated_path))
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset = V1(annotated_path=annotated_path,
......@@ -226,7 +228,25 @@ class Loader():
# return train_loader, valid_loader, test_loader
yield train_loader, test_loader
#full loader
class InferenceLoader(Loader):
    """Loader variant that exposes the whole dataset through one DataLoader.

    Unlike the parent's ``get_data``, this one takes no arguments and
    returns a single, non-shuffled ``DataLoader`` over ``self.dataset``.
    """

    def __init__(self,
                 annotated_path,
                 batch_size=5,
                 num_workers=20):
        """Build the parent loader, then reset the dataset's graph listing.

        :param annotated_path: directory whose entries are listed as graphs
        :param batch_size: batch size handed to the ``DataLoader``
        :param num_workers: worker count handed to the ``DataLoader``
        """
        super().__init__(annotated_path=annotated_path,
                         batch_size=batch_size,
                         num_workers=num_workers)
        # Overwrite the dataset's graph list with the sorted directory
        # listing so iteration order is deterministic.
        graph_names = sorted(os.listdir(annotated_path))
        self.dataset.all_graphs = graph_names

    def get_data(self):
        """Return one non-shuffled ``DataLoader`` covering ``self.dataset``.

        The collate function is built with no node similarity function and
        with ``get_sim_mat=False``.
        """
        collate_fn = collate_wrapper(None, get_sim_mat=False)
        loader_kwargs = dict(dataset=self.dataset,
                             shuffle=False,
                             batch_size=self.batch_size,
                             num_workers=self.num_workers,
                             collate_fn=collate_fn)
        return DataLoader(**loader_kwargs)
if __name__ == '__main__':
loader = Loader(shuffle=False,seed=99, batch_size=1, num_workers=1)
data = loader.get_data(k_fold=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment