Commit 71b10ddd authored by Carlos GO's avatar Carlos GO
Browse files

kfold check

parent bf963188
......@@ -181,15 +181,16 @@ class Loader():
if k_fold > 1:
from sklearn.model_selection import KFold
for train_indices, test_indices in kf.split(np.array(indices)):
kf = KFold(n_splits=k_fold)
for train_indices, test_indices in kf.split(np.array(indices), np.array(indices)):
train_set = Subset(self.dataset, train_indices)
test_set = Subset(self.dataset, test_indices)
collate_block = collate_wrapper(self.dataset.node_sim_func)
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=self.batch_size,
train_loader = DataLoader(dataset=train_set, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
test_loader = DataLoader(dataset=test_set, shuffle=True, batch_size=self.batch_size,
test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
yield train_loader, test_loader
......@@ -220,7 +221,8 @@ class Loader():
if __name__ == '__main__':
loader = Loader(shuffle=False,seed=99, batch_size=1, num_workers=1)
train,test = loader.get_data()
for t in train:
data = loader.get_data(k_fold=5)
for train, test in data:
print(len(train), len(test))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment