erermeev-d commited on
Commit
0d80f56
·
1 Parent(s): 75a5562

Added random seed fixing

Browse files
Files changed (2) hide show
  1. exp/gnn/train.py +5 -1
  2. exp/gnn/utils.py +12 -0
exp/gnn/train.py CHANGED
@@ -15,11 +15,14 @@ from exp.evaluate import evaluate_recsys
15
  from exp.gnn.model import GNNModel
16
  from exp.gnn.loss import nt_xent_loss
17
  from exp.gnn.utils import (
18
- prepare_graphs, LRSchedule,
19
  sample_item_batch, inference_model)
20
 
21
 
22
  def prepare_gnn_embeddings(config):
 
 
 
23
  ### Prepare graph
24
  bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
25
  bipartite_graph = bipartite_graph.to(config["device"])
@@ -127,6 +130,7 @@ if __name__ == "__main__":
127
  parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
128
 
129
  # Misc
 
130
  parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
131
  parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
132
  parser.add_argument("--wandb_name", type=str, help="WandB run name")
 
15
  from exp.gnn.model import GNNModel
16
  from exp.gnn.loss import nt_xent_loss
17
  from exp.gnn.utils import (
18
+ prepare_graphs, LRSchedule, fix_random,
19
  sample_item_batch, inference_model)
20
 
21
 
22
  def prepare_gnn_embeddings(config):
23
+ ### Fix random seed
24
+ fix_random(config["seed"])
25
+
26
  ### Prepare graph
27
  bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
28
  bipartite_graph = bipartite_graph.to(config["device"])
 
130
  parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
131
 
132
  # Misc
133
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
134
  parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
135
  parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
136
  parser.add_argument("--wandb_name", type=str, help="WandB run name")
exp/gnn/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import dgl
3
  import pandas as pd
@@ -7,6 +9,16 @@ from tqdm.auto import tqdm
7
  from exp.utils import normalize_embeddings
8
 
9
 
 
 
 
 
 
 
 
 
 
 
10
  class LRSchedule:
11
  def __init__(self, total_steps, warmup_steps, final_factor):
12
  self._total_steps = total_steps
 
1
+ import random
2
+
3
  import torch
4
  import dgl
5
  import pandas as pd
 
9
  from exp.utils import normalize_embeddings
10
 
11
 
12
+ def fix_random(seed):
13
+ dgl.seed(seed)
14
+ torch.random.manual_seed(seed)
15
+ np.random.seed(seed)
16
+ random.seed(seed)
17
+ if torch.cuda.is_available():
18
+ torch.cuda.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+
21
+
22
  class LRSchedule:
23
  def __init__(self, total_steps, warmup_steps, final_factor):
24
  self._total_steps = total_steps