jasvir-singh1021 committed on
Commit
e915946
·
verified ·
1 Parent(s): 591185f

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +68 -0
train_model.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
3
+ import torch
4
+ import numpy as np
5
+
6
# Checkpoint name shared by the tokenizer and the classifier.
model_name = "distilbert-base-uncased"

# Emotion columns used as targets; one model output per entry (multi-label).
emotions = [
    "Anger",
    "Love",
    "Fear",
    "Joy",
    "Sadness",
    "Surprise",
]

# Load the dataset identified by this repository id.
dataset = load_dataset("imranraad/github-emotion-love")

# Tokenizer that matches the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+
16
def tokenize(batch):
    """Tokenize the ``modified_comment`` texts of a batch.

    Every text is truncated or padded to exactly 128 tokens so that all
    rows end up with the same length.

    Args:
        batch: Mapping of column name to a list of values; must contain
            the ``modified_comment`` column.

    Returns:
        Whatever the module-level ``tokenizer`` returns for these texts.
    """
    texts = batch['modified_comment']
    return tokenizer(
        texts,
        max_length=128,
        truncation=True,
        padding='max_length',
    )


# Run the tokenizer over the dataset, many rows per call.
dataset = dataset.map(tokenize, batched=True)
20
+
21
def format_labels(batch, label_names=None):
    """Gather the per-emotion columns of a batch into float label vectors.

    The classifier in this script is set up for multi-label
    classification, whose binary cross-entropy loss needs floating-point
    targets. Each value is therefore converted with ``float`` instead of
    being passed through in whatever type the dataset stores.

    Args:
        batch: Mapping of column name to a list of per-row values.
        label_names: Names of the columns to gather, in output order.
            Defaults to the module-level ``emotions`` list.

    Returns:
        The same ``batch`` object with an added ``"labels"`` entry holding
        one list of floats per row, ordered like ``label_names``.
    """
    names = emotions if label_names is None else label_names
    columns = [batch[name] for name in names]
    # zip(*columns) turns the per-emotion columns into per-row tuples.
    batch["labels"] = [[float(value) for value in row] for row in zip(*columns)]
    return batch
25
+
26
# Add the "labels" column produced by format_labels.
dataset = dataset.map(format_labels, batched=True)

# Sequence classifier with one output per emotion, in multi-label mode.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    problem_type="multi_label_classification",
    num_labels=len(emotions),
)

# Hyperparameters plus the per-epoch evaluation and checkpoint schedule.
# NOTE(review): newer transformers releases spell `evaluation_strategy` as
# `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./model",
    logging_dir="./logs",
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy="epoch",
)
47
+
48
def compute_metrics(pred):
    """Compute element-wise accuracy for multi-label predictions.

    Args:
        pred: A ``(logits, labels)`` pair that unpacks into two array-likes
            of matching shape.

    Returns:
        A dict ``{"accuracy": value}`` where ``value`` is the fraction of
        individual label slots whose thresholded prediction equals the
        target.
    """
    raw_scores, targets = pred
    probabilities = torch.sigmoid(torch.tensor(raw_scores))
    # A label counts as predicted only when its probability exceeds 0.5.
    predicted = probabilities.gt(0.5).float()
    matches = predicted.eq(torch.tensor(targets)).float()
    return {"accuracy": matches.mean().item()}
56
+
57
# Wire the model, data splits, tokenizer and metric into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

# Fine-tune, then write the final model to ./model.
trainer.train()
trainer.save_model("./model")