erermeev-d
commited on
Commit
·
0d80f56
1
Parent(s):
75a5562
Added random seed fixing
Browse files- exp/gnn/train.py +5 -1
- exp/gnn/utils.py +12 -0
exp/gnn/train.py
CHANGED
@@ -15,11 +15,14 @@ from exp.evaluate import evaluate_recsys
|
|
15 |
from exp.gnn.model import GNNModel
|
16 |
from exp.gnn.loss import nt_xent_loss
|
17 |
from exp.gnn.utils import (
|
18 |
-
prepare_graphs, LRSchedule,
|
19 |
sample_item_batch, inference_model)
|
20 |
|
21 |
|
22 |
def prepare_gnn_embeddings(config):
|
|
|
|
|
|
|
23 |
### Prepare graph
|
24 |
bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
|
25 |
bipartite_graph = bipartite_graph.to(config["device"])
|
@@ -127,6 +130,7 @@ if __name__ == "__main__":
|
|
127 |
parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
|
128 |
|
129 |
# Misc
|
|
|
130 |
parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
|
131 |
parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
|
132 |
parser.add_argument("--wandb_name", type=str, help="WandB run name")
|
|
|
15 |
from exp.gnn.model import GNNModel
|
16 |
from exp.gnn.loss import nt_xent_loss
|
17 |
from exp.gnn.utils import (
|
18 |
+
prepare_graphs, LRSchedule, fix_random,
|
19 |
sample_item_batch, inference_model)
|
20 |
|
21 |
|
22 |
def prepare_gnn_embeddings(config):
|
23 |
+
### Fix random seed
|
24 |
+
fix_random(config["seed"])
|
25 |
+
|
26 |
### Prepare graph
|
27 |
bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
|
28 |
bipartite_graph = bipartite_graph.to(config["device"])
|
|
|
130 |
parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
|
131 |
|
132 |
# Misc
|
133 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
|
134 |
parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
|
135 |
parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
|
136 |
parser.add_argument("--wandb_name", type=str, help="WandB run name")
|
exp/gnn/utils.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
import dgl
|
3 |
import pandas as pd
|
@@ -7,6 +9,16 @@ from tqdm.auto import tqdm
|
|
7 |
from exp.utils import normalize_embeddings
|
8 |
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
class LRSchedule:
|
11 |
def __init__(self, total_steps, warmup_steps, final_factor):
|
12 |
self._total_steps = total_steps
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
import torch
|
4 |
import dgl
|
5 |
import pandas as pd
|
|
|
9 |
from exp.utils import normalize_embeddings
|
10 |
|
11 |
|
12 |
+
def fix_random(seed):
|
13 |
+
dgl.seed(seed)
|
14 |
+
torch.random.manual_seed(seed)
|
15 |
+
np.random.seed(seed)
|
16 |
+
random.seed(seed)
|
17 |
+
if torch.cuda.is_available():
|
18 |
+
torch.cuda.manual_seed(seed)
|
19 |
+
torch.cuda.manual_seed_all(seed)
|
20 |
+
|
21 |
+
|
22 |
class LRSchedule:
|
23 |
def __init__(self, total_steps, warmup_steps, final_factor):
|
24 |
self._total_steps = total_steps
|