Create gpt_min.py
gpt_min.py  +102 -0
gpt_min.py
ADDED
@@ -0,0 +1,102 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
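
# Multi-head self-attention with a causal mask built from an upper-triangular
# matrix, so each position can attend only to itself and earlier positions.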
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.key(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.output(out)
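
# Position-wise feed-forward network: expand to 4x the embedding width, apply
# GELU, project back down, with dropout after each of the two stages.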
class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(n_embd, 4 * n_embd)
        self.fc2 = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.dropout(x)
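
# Pre-norm transformer block: LayerNorm feeds each sub-layer, and both the
# attention and MLP outputs are added back through residual connections.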
class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        self.attention = CausalSelfAttention(n_embd, n_head, dropout)
        self.mlp = MLP(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
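
# Decoder-only GPT-style model: token plus learned positional embeddings, a
# stack of transformer blocks, a final LayerNorm, and a linear head over the
# vocabulary.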
class GPTModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, chunk_size, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.chunk_size = chunk_size
        self.token_embeddings = nn.Embedding(vocab_size, n_embd)
        self.position_embeddings = nn.Embedding(chunk_size, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.output_projection = nn.Linear(n_embd, vocab_size, bias=False)
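
    # Maps (B, T) token ids to (B, T, vocab_size) logits; T must not exceed
    # chunk_size because positional embeddings are only defined up to that length.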
    def forward(self, input_tokens):
        B, T = input_tokens.shape
        assert T <= self.chunk_size, f"Input length {T} > chunk_size {self.chunk_size}"
        tok = self.token_embeddings(input_tokens)
        pos = self.position_embeddings(torch.arange(T, device=input_tokens.device))
        x = self.dropout(tok + pos)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.output_projection(x)
        return logits
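
    # Autoregressive sampling: repeatedly feed the last chunk_size tokens,
    # rescale the final-step logits by temperature, keep only the top_k
    # candidates, and sample the next id from the resulting distribution.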
    def generate(self, context_ids, max_tokens, temperature=0.7, top_k=50):
        self.eval()
        generated = list(context_ids)
        device = next(self.parameters()).device
        with torch.no_grad():
            for _ in range(max_tokens):
                inp = torch.tensor(generated[-self.chunk_size:], dtype=torch.long, device=device).unsqueeze(0)
                logits = self.forward(inp)[0, -1, :]
                if temperature and temperature > 0:
                    logits = logits / temperature
                if top_k and top_k > 0:
                    tk_vals, tk_idx = torch.topk(logits, min(top_k, logits.size(-1)))
                    filtered = torch.full_like(logits, float('-inf'))
                    filtered[tk_idx] = tk_vals
                    logits = filtered
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, 1).item()
                generated.append(next_id)
        return generated
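
Not part of this commit, but as a quick reference, a minimal usage sketch for the file above (the hyperparameter values are illustrative, not taken from the Space):

import torch
from gpt_min import GPTModel

# Toy configuration; any values with n_embd divisible by n_head work.
model = GPTModel(vocab_size=256, n_embd=128, n_head=4, n_layer=2, chunk_size=64)

# Forward pass: (batch, time) token ids -> (batch, time, vocab_size) logits.
tokens = torch.randint(0, 256, (2, 16))
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 256])

# Sample 20 token ids from the (untrained) model, seeded with a short prompt.
out = model.generate([1, 2, 3], max_tokens=20, temperature=0.7, top_k=50)
print(out)  # the prompt ids followed by 20 sampled ids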