import sys
import time

import dgl
import torch

# debug modules
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
##

if __name__ == '__main__':
    sys.path.append('../')


def send_graph_to_device(g, device):
    """
    Send dgl graph to device
Carlos GO's avatar
Carlos GO committed
20
    :param g: :param device:
Carlos GO's avatar
first  
Carlos GO committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
    :return:
    """
    # Zero-initialize any node/edge features created later on
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)

    # nodes
    for key in list(g.node_attr_schemes()):
        g.ndata[key] = g.ndata.pop(key).to(device, non_blocking=True)

    # edges
    for key in list(g.edge_attr_schemes()):
        g.edata[key] = g.edata.pop(key).to(device, non_blocking=True)

    return g
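
# Minimal usage sketch for send_graph_to_device (the toy graph and feature
# name 'h' below are illustrative, not part of this project):
#
#   g = dgl.DGLGraph()
#   g.add_nodes(3)
#   g.add_edges([0, 1], [1, 2])
#   g.ndata['h'] = torch.rand(3, 4)
#   g = send_graph_to_device(g, torch.device('cuda:0'))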

def set_gradients(model, embedding=True, attributor=True):
    """
    Toggle requires_grad on the embedding and attributor networks.

    :param embedding: whether the embedding layers should receive gradients
    :param attributor: whether the attributor network should receive gradients
    """
    for name, p in model.named_parameters():
        name = name.split('.')[0]
        # NOTE: grouping 'embeddings' with 'layers' under the `embedding` flag
        # is an assumption; the original code toggled 'embeddings' with the
        # `attributor` flag, which contradicts the log messages in train_model
        # ("Embeddings still ON").
        if name == 'attributor':
            p.requires_grad = attributor
        if name in ('embeddings', 'layers'):
            p.requires_grad = embedding

def print_gradients(model):
    """
    Print the gradient of every named parameter of the model.
    """
    for name, p in model.named_parameters():
        print(name, p.grad)
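
# Sketch of the two-phase schedule these helpers support (assuming a model
# with 'layers', 'embeddings', and 'attributor' submodules, which is what
# set_gradients looks for):
#
#   set_gradients(model, attributor=False)                  # phase 1: embeddings only
#   ...train for `embed_only` epochs...
#   set_gradients(model, embedding=False, attributor=True)  # phase 2: attributor only
#   print_gradients(model)                                  # inspect which grads flow
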
def test(model, test_loader, device, reconstruction_lam, motif_lam):
    """
    Compute accuracy and loss of model over given dataset
    :param model:
    :param test_loader:
    :param test_loss_fn:
    :param device:
    :return:
    """
    model.eval()
    test_loss, motif_loss_tot, recons_loss_tot = 0, 0, 0
    test_size = len(test_loader)
    for graph, K, fp in test_loader:
        # Get data on the devices
        K = K.to(device)
        fp = fp.to(device)
        # Assumption: K holds similarities in [0, 1]; 1 - K is the distance
        # target compared against pairwise embedding distances.
        K = torch.ones_like(K) - K
        graph = send_graph_to_device(graph, device)

        # Do the computations for the forward pass
        with torch.no_grad():
            out, attributions = model(graph)
        loss, reconstruction_loss, motif_loss = compute_loss(model=model, attributions=attributions, fp=fp,
                                                             out=out, K=K, device=device,
                                                             reconstruction_lam=reconstruction_lam,
                                                             motif_lam=motif_lam)
        del K
        del fp
        del graph

        recons_loss_tot += reconstruction_loss.item()
        motif_loss_tot += motif_loss.item()
        test_loss += loss.item()

        del loss
        del reconstruction_loss
        del motif_loss

    return test_loss / test_size, motif_loss_tot / test_size, recons_loss_tot / test_size
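
# Example evaluation call (the loader and lambdas are illustrative); the three
# returned values are averages over the loader:
#
#   avg_total, avg_motif, avg_recons = test(model, test_loader,
#                                           torch.device('cpu'),
#                                           reconstruction_lam=1, motif_lam=1)
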
def compute_loss(model, attributions, out, K, fp,
                 device, reconstruction_lam, motif_lam):
    """
    Compute the total loss and returns scalar value for each contribution of each term. Avoid overwriting loss terms
    :param model:
    :param attributions:
    :param out:
    :param K:
    :param device:
    :param reconstruction_lam:
    :param motif_lam:
    :return:
    """

    # Reconstruction loss: match pairwise embedding distances to the target K
    K_predict = torch.norm(out[:, None] - out, dim=2, p=2)
    reconstruction_loss = torch.nn.MSELoss()(K_predict, K)

    # Motif loss: binary cross-entropy of attributions vs. motif fingerprints
    motif_loss = torch.nn.BCELoss()(attributions, fp)
    loss = reconstruction_lam * reconstruction_loss + motif_lam * motif_loss
    return loss, reconstruction_loss, motif_loss
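
# The reconstruction term builds an (n, n) matrix of pairwise Euclidean
# distances by broadcasting, and the total is the weighted sum
# reconstruction_lam * MSE(K_predict, K) + motif_lam * BCE(attributions, fp).
# A standalone shape check (hypothetical sizes):
#
#   out = torch.rand(5, 8)                                  # 5 nodes, 8-dim embeddings
#   K_predict = torch.norm(out[:, None] - out, dim=2, p=2)  # (5, 5) distances
#   assert K_predict.shape == (5, 5)
#   assert torch.allclose(K_predict.diag(), torch.zeros(5))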


def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
                writer=None, num_epochs=25, wall_time=None,
                reconstruction_lam=1, motif_lam=1, embed_only=-1):
    """
    Performs the entire training routine.
    :param model: (torch.nn.Module): the model to train
    :param criterion: the criterion to use (eg CrossEntropy)
    :param optimizer: the optimizer to use (eg SGD or Adam)
    :param device: the device on which to run
    :param train_loader: dataloader for training
    :param test_loader: dataloader for validation
    :param save_path: where to save the model
    :param writer: a Tensorboard object (defined in utils)
    :param num_epochs: int number of epochs
    :param wall_time: The number of hours you want the model to run
    :param reconstruction_lam: how much to enforce pariwise similarity conservation
    :param motif_lam: how much to enforce motif assignment
    :param embed_only: number of epochs before starting attributor training.
    :return:
    """

    epochs_from_best = 0
    early_stop_threshold = 10

    start_time = time.time()
    best_loss = sys.maxsize

    motif_lam_orig = motif_lam
    reconstruction_lam_orig = reconstruction_lam

    # If the attributor is delayed, start with the attributor OFF.
    # If embed_only <= -1, both networks stay ON throughout.
    if embed_only > -1:
        print("Switching attributor OFF. Embeddings still ON.")
        set_gradients(model, attributor=False)
        motif_lam = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        # Training phase
        model.train()

        # Switch off embedding gradients, turn the attributor on
        if epoch == embed_only:
            print("Switching attributor ON, embeddings OFF.")
            set_gradients(model, embedding=False, attributor=True)
            reconstruction_lam = 0
            motif_lam = motif_lam_orig

        running_loss = 0.0

        time_epoch = time.perf_counter()

        num_batches = len(train_loader)

        for batch_idx, (graph, K, fp) in enumerate(train_loader):

            # Get data on the devices
            batch_size = len(K)
            K = K.to(device)
            fp = fp.to(device)
            # As in test(): turn the similarity kernel into a distance target
            K = torch.ones_like(K) - K
            graph = send_graph_to_device(graph, device)

            # Do the computations for the forward pass
            out, attributions = model(graph)

            # Compute the combined loss
            loss, reconstruction_loss, motif_loss = compute_loss(model=model, attributions=attributions, fp=fp,
                                                                 out=out, K=K, device=device,
                                                                 reconstruction_lam=reconstruction_lam,
                                                                 motif_lam=motif_lam)
            del K
            del fp
            del graph
            # Backward pass
            model.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            batch_loss = loss.item()
            running_loss += batch_loss


            # running_corrects += labels.eq(target.view_as(out)).sum().item()
            if batch_idx % 20 == 0:
                time_elapsed = time.time() - start_time
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}  Time: {:.2f}'.format(
                    epoch + 1,
                    (batch_idx + 1) * batch_size,
                    num_batches * batch_size,
                    100. * (batch_idx + 1) / num_batches,
                    batch_loss,
                    time_elapsed))

                # Tensorboard logging (writer defaults to None, so guard it)
                if writer is not None:
                    step = epoch * num_batches + batch_idx
                    writer.add_scalar("Training batch loss", batch_loss, step)
                    writer.add_scalar("Training reconstruction loss",
                                      reconstruction_loss.item(), step)
                    writer.add_scalar("Training motif loss",
                                      motif_loss.item(), step)

            del loss
            del reconstruction_loss
            del motif_loss

        # Log training metrics
        train_loss = running_loss / num_batches
        if writer is not None:
            writer.add_scalar("Training epoch loss", train_loss, epoch)

        # train_accuracy = running_corrects / num_batches
        # writer.log_scalar("Train accuracy during training", train_accuracy, epoch)

        # Test phase
        test_loss, motif_loss, reconstruction_loss = test(model, test_loader, device, reconstruction_lam, motif_lam)
        print(f"test_loss {test_loss}, reconstruction_loss {reconstruction_loss}, fp loss {motif_loss}")
        if writer is not None:
            writer.add_scalar("Test loss during training", test_loss, epoch)
            writer.add_scalar("Test reconstruction loss", reconstruction_loss, epoch)
            writer.add_scalar("Test motif loss", motif_loss, epoch)

        # writer.log_scalar("Test accuracy during training", test_accuracy, epoch)

        # Checkpointing
        if test_loss < best_loss:
            best_loss = test_loss
            epochs_from_best = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion
            }, save_path)

        # Early stopping
        else:
            epochs_from_best += 1
            if epochs_from_best > early_stop_threshold:
                print('This model was early stopped')
                break

        # Wall-time check
        if wall_time is not None:
            # Break out of the loop if we might go beyond the wall time
            time_elapsed = time.time() - start_time
            if time_elapsed * (1 + 1 / (epoch + 1)) > .95 * wall_time * 3600:
                break
        del test_loss
        del reconstruction_loss
        del motif_loss

    return best_loss
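
# How a training run might be wired up (everything here is hypothetical: the
# model, dataloaders, and writer come from elsewhere in the repo):
#
#   optimizer = torch.optim.Adam(model.parameters())
#   best = train_model(model, torch.nn.MSELoss(), optimizer, device,
#                      train_loader, test_loader, 'results/model.pth',
#                      writer=writer, num_epochs=25,
#                      reconstruction_lam=1, motif_lam=1, embed_only=5)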


def make_predictions(data_loader, model, optimizer, model_weights_path):
    """
    Load trained weights and run inference over a dataset.

    :param data_loader: an iterator over input data
    :param model: an empty model
    :param optimizer: an empty optimizer
    :param model_weights_path: the path of the checkpoint to load
    :return: list of predictions
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(model_weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # load_state_dict does not move the model, so send it to the device the
    # inputs will live on
    model = model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():  # no gradients needed at inference time
        for inputs in data_loader:
            inputs = inputs.to(device)
            predictions.append(model(inputs))
    return predictions
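
# Usage sketch (assumes a freshly built `model` and `optimizer` matching the
# architecture stored in the checkpoint; the path is illustrative):
#
#   preds = make_predictions(data_loader, model, optimizer,
#                            'results/best_model.pth')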


if __name__ == "__main__":
    pass
# parser = argparse.ArgumentParser()
# parser.add_argument('--data_dir', default='../data/testset')
# parser.add_argument('--out_dir', default='Submissions/')
# parser.add_argument(
#     '--model_path', default='results/base_wr_lr01best_model.pth')
# args = parser.parse_args()
# make_predictions(args.data_dir, args.out_dir, args.model_path)