ahk-d committed on
Commit 86069e9 · verified
1 Parent(s): 2b6e0d1

Create gpt_min.py

Files changed (1)
gpt_min.py +102 -0
gpt_min.py ADDED
@@ -0,0 +1,102 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        # Project to multi-head queries, keys, values: (B, n_head, T, head_dim)
        q = self.query(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.key(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with a causal (upper-triangular) mask
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        # Merge heads back to (B, T, C) and apply the output projection
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.output(out)


class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(n_embd, 4 * n_embd)
        self.fc2 = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Standard transformer feed-forward: expand 4x, GELU, project back
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.dropout(x)


class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        self.attention = CausalSelfAttention(n_embd, n_head, dropout)
        self.mlp = MLP(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-LayerNorm residual block: x + Attn(LN(x)), then x + MLP(LN(x))
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPTModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, chunk_size, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.chunk_size = chunk_size
        self.token_embeddings = nn.Embedding(vocab_size, n_embd)
        self.position_embeddings = nn.Embedding(chunk_size, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.output_projection = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, input_tokens):
        B, T = input_tokens.shape
        assert T <= self.chunk_size, f"Input length {T} > chunk_size {self.chunk_size}"
        # Sum token and learned position embeddings, then run the transformer stack
        tok = self.token_embeddings(input_tokens)
        pos = self.position_embeddings(torch.arange(T, device=input_tokens.device))
        x = self.dropout(tok + pos)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.output_projection(x)
        return logits

    def generate(self, context_ids, max_tokens, temperature=0.7, top_k=50):
        self.eval()
        generated = list(context_ids)
        device = next(self.parameters()).device
        with torch.no_grad():
            for _ in range(max_tokens):
                # Condition on at most the last chunk_size tokens
                inp = torch.tensor(generated[-self.chunk_size:], dtype=torch.long, device=device).unsqueeze(0)
                logits = self.forward(inp)[0, -1, :]
                if temperature and temperature > 0:
                    logits = logits / temperature
                if top_k and top_k > 0:
                    # Keep only the top_k logits; mask the rest to -inf before sampling
                    tk_vals, tk_idx = torch.topk(logits, min(top_k, logits.size(-1)))
                    filtered = torch.full_like(logits, float('-inf'))
                    filtered[tk_idx] = tk_vals
                    logits = filtered
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, 1).item()
                generated.append(next_id)
        return generated
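For context, a minimal usage sketch of the module added by this commit follows. The hyperparameter values, the random byte-level tokens, and the prompt ids are illustrative assumptions only; the commit itself does not fix any of them.

import torch
from gpt_min import GPTModel

# Hypothetical, small hyperparameters chosen for illustration
model = GPTModel(vocab_size=256, n_embd=128, n_head=4, n_layer=2, chunk_size=64)

# Forward pass: (batch, time) token ids -> (batch, time, vocab_size) logits
tokens = torch.randint(0, 256, (2, 16))
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 256])

# Sampling from the (untrained) model: prime with a few ids, generate 20 more
out = model.generate([1, 2, 3], max_tokens=20, temperature=0.8, top_k=40)
print(out)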