dixisouls committed on
Commit 27b9282 · 1 Parent(s): a12c1bc

Initial Commit

Files changed (7)
  1. .gitignore +184 -0
  2. README.md +100 -6
  3. app.py +635 -0
  4. requirements.txt +14 -0
  5. src/inference/inference.py +231 -0
  6. src/model/layers.py +86 -0
  7. src/model/transformer.py +281 -0
.gitignore ADDED
@@ -0,0 +1,184 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ *.py,cover
49
+ .hypothesis/
50
+ .pytest_cache/
51
+
52
+ # Translations
53
+ *.mo
54
+ *.pot
55
+
56
+ # Django stuff:
57
+ *.log
58
+ local_settings.py
59
+ db.sqlite3
60
+ db.sqlite3-journal
61
+
62
+ # Flask stuff:
63
+ instance/
64
+ .webassets-cache
65
+
66
+ # Scrapy stuff:
67
+ .scrapy
68
+
69
+ # Sphinx documentation
70
+ docs/_build/
71
+
72
+ # PyBuilder
73
+ target/
74
+
75
+ # Jupyter Notebook
76
+ .ipynb_checkpoints
77
+
78
+ # IPython
79
+ profile_default/
80
+ ipython_config.py
81
+
82
+ # pyenv
83
+ .python-version
84
+
85
+ # pipenv
86
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
88
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
89
+ # install all needed dependencies.
90
+ #Pipfile.lock
91
+
92
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
93
+ __pypackages__/
94
+
95
+ # Celery stuff
96
+ celerybeat-schedule
97
+ celerybeat.pid
98
+
99
+ # SageMath parsed files
100
+ *.sage.py
101
+
102
+ # Environments
103
+ .env
104
+ .venv
105
+ env/
106
+ venv/
107
+ ENV/
108
+ env.bak/
109
+ venv.bak/
110
+
111
+ # Spyder project settings
112
+ .spyderproject
113
+ .spyproject
114
+
115
+ # Rope project settings
116
+ .ropeproject
117
+
118
+ # mkdocs documentation
119
+ /site
120
+
121
+ # mypy
122
+ .mypy_cache/
123
+ .dmypy.json
124
+ dmypy.json
125
+
126
+ # Pyre type checker
127
+ .pyre/
128
+
129
+ # ML/AI specific files
130
+ *.pkl
131
+ *.pickle
132
+ *.h5
133
+ *.hdf5
134
+ *.ckpt
135
+ *.pt
136
+ *.pth
137
+ *.safetensors
138
+
139
+ # Model training artifacts
140
+ wandb/
141
+ runs/
142
+ logs/
143
+ tensorboard/
144
+
145
+ # Data directories
146
+ data/
147
+ datasets/
148
+
149
+ # Model checkpoints
150
+ # checkpoints/ is tracked with Git LFS, so it is intentionally not ignored here
151
+
152
+ # Temporary files
153
+ *.tmp
154
+ *.temp
155
+ .DS_Store
156
+ Thumbs.db
157
+
158
+ # IDE files
159
+ .vscode/
160
+ .idea/
161
+ *.swp
162
+ *.swo
163
+ *~
164
+
165
+ # OS generated files
166
+ .DS_Store
167
+ .DS_Store?
168
+ ._*
169
+ .Spotlight-V100
170
+ .Trashes
171
+ ehthumbs.db
172
+ Thumbs.db
173
+
174
+ # Gradio temporary files
175
+ gradio_cached_examples/
176
+ flagged/
177
+
178
+ # HuggingFace cache
179
+ .cache/
180
+ cache/
181
+
182
+ # Local configuration files
183
+ config.local.*
184
+ .secrets
README.md CHANGED
@@ -1,14 +1,108 @@
1
  ---
2
- title: VelocityLM
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.43.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: FoundationalLM for fast text-generation
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Custom LLM - Foundational Language Model
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.43.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ models:
12
+ - gpt2
13
+ datasets:
14
+ - tiiuae/falcon-refinedweb
15
+ tags:
16
+ - text-generation
17
+ - transformer
18
+ - pytorch
19
+ - custom-model
20
+ - llm
21
+ - foundational-model
22
+ short_description: A custom 2B parameter foundational language model with streaming generation
23
  ---
24
 
25
+ # 🤖 Custom LLM - Foundational Language Model
26
+
27
+ A custom-trained foundational language model with **2 billion parameters**, built on a modern transformer architecture and deployed with streaming text generation.
28
+
29
+ ## 🚀 Features
30
+
31
+ - **Custom Architecture**: Modern transformer with RoPE (Rotary Position Embedding), RMSNorm, and SwiGLU activation
32
+ - **Streaming Generation**: Real-time text generation with token-by-token streaming
33
+ - **Flexible Sampling**: Configurable temperature, top-p, top-k, and repetition penalty
34
+ - **ZeroGPU Integration**: Optimized for Hugging Face Spaces with GPU acceleration
35
+ - **Responsive UI**: Clean, intuitive Gradio interface
36
+
37
+ ## 📊 Model Details
38
+
39
+ | Specification | Value |
40
+ |---------------|-------|
41
+ | **Parameters** | ~2 billion |
42
+ | **Architecture** | Custom Transformer |
43
+ | **Context Length** | 2,048 tokens |
44
+ | **Vocab Size** | 50,257 (GPT-2 tokenizer) |
45
+ | **Layers** | 24 |
46
+ | **Attention Heads** | 32 |
47
+ | **Hidden Size** | 2,048 |
48
+ | **Intermediate Size** | 8,192 |
49
+
50
+ ## 🏗️ Architecture Components
51
+
52
+ - **RMSNorm**: Root Mean Square Layer Normalization for better training stability
53
+ - **RoPE**: Rotary Position Embeddings for better length extrapolation
54
+ - **SwiGLU**: Swish-gated linear unit activation for improved performance
55
+ - **Causal Attention**: Standard autoregressive attention mechanism (the per-block layout is sketched below)
56
+
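+ These components are combined in a pre-norm residual layout inside each transformer block: RMSNorm, attention, residual add, then RMSNorm, SwiGLU, and another residual add (see `src/model/transformer.py`). A minimal, runnable sketch of the feed-forward half using the layers from this repository (assumes the repo root is on `PYTHONPATH`; dimensions match the table above):
+
+ ```python
+ import torch
+ from src.model.layers import RMSNorm, SwiGLU
+
+ x = torch.randn(1, 8, 2048)                       # (batch, seq_len, hidden_size)
+ norm = RMSNorm(2048)
+ mlp = SwiGLU(hidden_size=2048, intermediate_size=8192)
+ h = x + mlp(norm(x))                              # pre-norm SwiGLU sub-layer with residual
+ ```
+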
57
+ ## 🎯 Training Details
58
+
59
+ - **Dataset**: Falcon RefinedWeb (curated web text)
60
+ - **Training Steps**: 100,000 steps
61
+ - **Learning Rate**: 6e-4 with warmup and decay
62
+ - **Batch Size**: 32 (4 per device × 8 accumulation steps)
63
+ - **Optimization**: AdamW with β1=0.9, β2=0.95
64
+ - **Precision**: Mixed precision (FP16)
65
+
66
+ ## 🛠️ Generation Parameters
67
+
68
+ - **Max Tokens**: Control the length of generated text (1-1024)
69
+ - **Temperature**: Sampling randomness (0.1-2.0, higher = more creative)
70
+ - **Top-p**: Nucleus sampling threshold (0.1-1.0)
71
+ - **Top-k**: Top-k sampling limit (0-200, 0 = disabled)
72
+ - **Repetition Penalty**: Reduce repetitive text (1.0-2.0); a programmatic sketch of these parameters is shown below
73
+
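+ Programmatically, these sliders map directly onto `StreamingTextGenerator.generate_stream` in `app.py`. A minimal sketch (assumes you run it from this repository with its requirements installed; `load_model_and_tokenizer()` downloads the weights from `dixisouls/VelocityLM`):
+
+ ```python
+ import torch
+ from app import load_model_and_tokenizer, StreamingTextGenerator
+
+ model, tokenizer = load_model_and_tokenizer()
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ gen = StreamingTextGenerator(model, tokenizer, device=device)
+
+ text = ""
+ for text in gen.generate_stream(
+     "The old lighthouse keeper noticed",
+     max_new_tokens=128, temperature=0.8, top_p=0.9,
+     top_k=50, repetition_penalty=1.1,
+ ):
+     pass                      # each yield is the full text generated so far
+ print(text)
+ ```
+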
74
+ ## 💡 Usage Tips
75
+
76
+ 1. **For Creative Writing**: Use higher temperature (1.0-1.5) and top-p (0.9-0.95), e.g. the Creative preset listed below
77
+ 2. **For Factual Content**: Use lower temperature (0.3-0.7) and top-p (0.8-0.9)
78
+ 3. **For Code Generation**: Use temperature ~0.2 with top-k filtering
79
+ 4. **Longer Context**: The model handles up to 2,048 tokens of context
80
+
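+ The app's quick-preset buttons bundle the sliders into three configurations (values taken from `app.py`; the `PRESETS` name is just for illustration):
+
+ ```python
+ PRESETS = {
+     "Creative": dict(temperature=1.2, top_p=0.95, top_k=40, repetition_penalty=1.05),
+     "Balanced": dict(temperature=0.8, top_p=0.90, top_k=50, repetition_penalty=1.10),
+     "Precise":  dict(temperature=0.3, top_p=0.80, top_k=20, repetition_penalty=1.20),
+ }
+ ```
+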
81
+ ## 🚨 Limitations
82
+
83
+ - **Knowledge Cutoff**: Training data knowledge cutoff varies by source
84
+ - **Biases**: May reflect biases present in training data
85
+ - **Factuality**: Generated content should be verified for factual accuracy
86
+ - **Context Window**: Limited to 2,048 tokens (approximately 1,500 words)
87
+
88
+ ## 🔧 Technical Implementation
89
+
90
+ The model uses a custom PyTorch implementation with:
91
+ - Efficient attention mechanisms
92
+ - Memory-optimized layer implementations
93
+ - Streaming generation with proper token handling
94
+ - GPU acceleration via ZeroGPU
95
+
96
+ ## 📝 License
97
+
98
+ This project is licensed under the MIT License - see the LICENSE file for details.
99
+
100
+ ## 🙏 Acknowledgments
101
+
102
+ - Hugging Face for the Spaces platform and ZeroGPU infrastructure
103
+ - The open-source community for transformer implementations and best practices
104
+ - TII UAE for the Falcon RefinedWeb dataset
105
+
106
+ ---
107
+
108
+ **Note**: This is a foundational language model trained for research and educational purposes. Please use responsibly and be aware of potential biases and limitations.
app.py ADDED
@@ -0,0 +1,635 @@
1
+ """Gradio app for the custom LLM with streaming support and ZeroGPU integration."""
2
+
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from typing import Iterator, Optional, Union, List
7
+ from transformers import AutoTokenizer
8
+ import json
9
+ import warnings
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # Add src to path
14
+ sys.path.append(str(Path(__file__).parent))
15
+
16
+ warnings.filterwarnings("ignore")
17
+
18
+ try:
19
+ import spaces
20
+ HAS_SPACES = True
21
+ except ImportError:
22
+ HAS_SPACES = False
23
+ # Mock decorator for local testing
24
+ def spaces_decorator(gpu_memory=None):
25
+ def decorator(func):
26
+ return func
27
+ return decorator
28
+ spaces = type('MockSpaces', (), {'GPU': spaces_decorator})
29
+
30
+ from src.model.transformer import TransformerForCausalLM
31
+
32
+
33
+ class StreamingTextGenerator:
34
+ """Streaming text generation for the custom LLM."""
35
+
36
+ def __init__(self, model, tokenizer, device='cuda'):
37
+ self.model = model
38
+ self.tokenizer = tokenizer
39
+ self.device = device
40
+ self.model.to(device)
41
+ self.model.eval()
42
+
43
+ def generate_stream(
44
+ self,
45
+ prompt: str,
46
+ max_new_tokens: int = 512,
47
+ temperature: float = 0.8,
48
+ top_p: float = 0.9,
49
+ top_k: Optional[int] = 50,
50
+ repetition_penalty: float = 1.1,
51
+ do_sample: bool = True,
52
+ ) -> Iterator[str]:
53
+ """Generate text with streaming output."""
54
+
55
+ # Tokenize prompt
56
+ inputs = self.tokenizer(
57
+ prompt,
58
+ return_tensors='pt',
59
+ padding=False,
60
+ truncation=True,
61
+ max_length=1024, # Leave room for generation
62
+ ).to(self.device)
63
+
64
+ input_ids = inputs['input_ids']
65
+ attention_mask = inputs['attention_mask']
66
+
67
+ # Initialize generated sequence
68
+ generated_ids = input_ids.clone()
69
+ generated_text = prompt
70
+
71
+ with torch.no_grad():
72
+ for step in range(max_new_tokens):
73
+ # Get model predictions
74
+ outputs = self.model(
75
+ input_ids=generated_ids,
76
+ attention_mask=attention_mask,
77
+ )
78
+
79
+ # Get logits for the last token
80
+ next_token_logits = outputs.logits[0, -1, :].clone()
81
+
82
+ # Apply repetition penalty
83
+ if repetition_penalty != 1.0:
84
+ for token_id in set(generated_ids[0].tolist()):
85
+ next_token_logits[token_id] /= repetition_penalty
86
+
87
+ # Apply temperature
88
+ if temperature > 0:
89
+ next_token_logits = next_token_logits / temperature
90
+
91
+ # Apply top-k filtering
92
+ if top_k is not None and top_k > 0:
93
+ top_k_logits, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
94
+ min_top_k = top_k_logits[-1]
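+ # Mask everything strictly below the k-th largest logit (min_top_k) to -inf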
95
+ next_token_logits = torch.where(
96
+ next_token_logits < min_top_k,
97
+ torch.full_like(next_token_logits, float('-inf')),
98
+ next_token_logits
99
+ )
100
+
101
+ # Apply top-p (nucleus) filtering
102
+ if top_p < 1.0:
103
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
104
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
105
+
106
+ # Remove tokens with cumulative probability above threshold
107
+ sorted_indices_to_remove = cumulative_probs > top_p
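+ # Shift the mask right by one so the first token that crosses the threshold is still kept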
108
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
109
+ sorted_indices_to_remove[0] = False
110
+
111
+ indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
112
+ next_token_logits[indices_to_remove] = float('-inf')
113
+
114
+ # Sample next token
115
+ if do_sample and temperature > 0:
116
+ probs = F.softmax(next_token_logits, dim=-1)
117
+ next_token = torch.multinomial(probs, num_samples=1)
118
+ else:
119
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
120
+
121
+ # Check for EOS token
122
+ if next_token.item() == self.tokenizer.eos_token_id:
123
+ break
124
+
125
+ # Append to generated sequence
126
+ generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=-1)
127
+
128
+ # Update attention mask
129
+ attention_mask = torch.cat([
130
+ attention_mask,
131
+ torch.ones((1, 1), device=self.device, dtype=attention_mask.dtype)
132
+ ], dim=-1)
133
+
134
+ # Decode and yield new token
135
+ new_text = self.tokenizer.decode(
136
+ generated_ids[0],
137
+ skip_special_tokens=True,
138
+ clean_up_tokenization_spaces=False
139
+ )
140
+
141
+ # Only yield the new part
142
+ if len(new_text) > len(generated_text):
143
+ generated_text = new_text
144
+ yield generated_text
145
+
146
+
147
+ def download_model_from_hf():
148
+ """Download model from HuggingFace repository."""
149
+ from huggingface_hub import hf_hub_download
150
+ import os
151
+
152
+ model_repo = "dixisouls/VelocityLM"
153
+ cache_dir = Path("model_cache")
154
+ cache_dir.mkdir(exist_ok=True)
155
+
156
+ print("📥 Downloading model from HuggingFace...")
157
+
158
+ # Download config.json
159
+ config_path = hf_hub_download(
160
+ repo_id=model_repo,
161
+ filename="config.json",
162
+ cache_dir=cache_dir,
163
+ local_files_only=False
164
+ )
165
+
166
+ # Download pytorch_model.bin
167
+ model_path = hf_hub_download(
168
+ repo_id=model_repo,
169
+ filename="pytorch_model.bin",
170
+ cache_dir=cache_dir,
171
+ local_files_only=False
172
+ )
173
+
174
+ print("✅ Model downloaded successfully!")
175
+ return config_path, model_path
176
+
177
+
178
+ def load_model_and_tokenizer():
179
+ """Load the trained model and tokenizer."""
180
+ import os
181
+
182
+ # Check if model exists locally, if not download from HF
183
+ cache_dir = Path("model_cache")
184
+ local_config = None
185
+ local_model = None
186
+
187
+ # Try to find cached files
188
+ if cache_dir.exists():
189
+ for root, dirs, files in os.walk(cache_dir):
190
+ if "config.json" in files:
191
+ local_config = Path(root) / "config.json"
192
+ if "pytorch_model.bin" in files:
193
+ local_model = Path(root) / "pytorch_model.bin"
194
+
195
+ # Download if not found locally
196
+ if not local_config or not local_model:
197
+ config_path, model_path = download_model_from_hf()
198
+ else:
199
+ config_path = str(local_config)
200
+ model_path = str(local_model)
201
+ print("📂 Using cached model files")
202
+
203
+ # Load config
204
+ with open(config_path, 'r') as f:
205
+ config = json.load(f)
206
+
207
+ # Create model config object
208
+ class ModelConfig:
209
+ def __init__(self, config_dict):
210
+ for key, value in config_dict.items():
211
+ setattr(self, key, value)
212
+
213
+ model_config = ModelConfig(config['model'])
214
+
215
+ # Load model
216
+ print("🔧 Initializing model...")
217
+ model = TransformerForCausalLM(model_config)
218
+
219
+ # Load state dict from pytorch_model.bin
220
+ print("📦 Loading model weights...")
221
+ model_state_dict = torch.load(
222
+ model_path,
223
+ map_location='cpu'
224
+ )
225
+
226
+ model.load_state_dict(model_state_dict, strict=False)
227
+ print("✅ Model weights loaded!")
228
+
229
+ # Load tokenizer
230
+ print("🔤 Loading tokenizer...")
231
+ tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name'])
232
+ if tokenizer.pad_token is None:
233
+ tokenizer.pad_token = tokenizer.eos_token
234
+
235
+ print("🎉 Model and tokenizer ready!")
236
+ return model, tokenizer
237
+
238
+
239
+ # Global variables for model and generator
240
+ model = None
241
+ tokenizer = None
242
+ generator = None
243
+
244
+ def initialize_model():
245
+ """Initialize model and tokenizer."""
246
+ global model, tokenizer, generator
247
+
248
+ if model is None:
249
+ print("Loading model and tokenizer...")
250
+ model, tokenizer = load_model_and_tokenizer()
251
+ device = "cuda" if torch.cuda.is_available() else "cpu"
252
+ generator = StreamingTextGenerator(model, tokenizer, device=device)
253
+ print(f"Model loaded on {device}")
254
+
255
+
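+ # On Spaces this requests a ZeroGPU worker for up to 120 s; without the `spaces` package the decorator is a no-op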
256
+ @spaces.GPU(duration=120) if HAS_SPACES else lambda x: x
257
+ def generate_response(
258
+ prompt: str,
259
+ max_new_tokens: int = 512,
260
+ temperature: float = 0.8,
261
+ top_p: float = 0.9,
262
+ top_k: int = 50,
263
+ repetition_penalty: float = 1.1,
264
+ ) -> Iterator[str]:
265
+ """Generate streaming response."""
266
+
267
+ # Initialize model if needed
268
+ initialize_model()
269
+
270
+ if not prompt.strip():
271
+ yield "Please enter a prompt."
272
+ return
273
+
274
+ try:
275
+ # Generate with streaming
276
+ for partial_text in generator.generate_stream(
277
+ prompt=prompt,
278
+ max_new_tokens=max_new_tokens,
279
+ temperature=temperature,
280
+ top_p=top_p,
281
+ top_k=top_k if top_k > 0 else None,
282
+ repetition_penalty=repetition_penalty,
283
+ do_sample=temperature > 0,
284
+ ):
285
+ yield partial_text
286
+
287
+ except Exception as e:
288
+ yield f"Error generating text: {str(e)}"
289
+
290
+
291
+ # Create Gradio interface
292
+ def create_interface():
293
+ """Create the Gradio interface."""
294
+
295
+ # Custom CSS for enhanced UI
296
+ custom_css = """
297
+ .gradio-container {
298
+ max-width: 1200px !important;
299
+ margin: 0 auto !important;
300
+ }
301
+
302
+ .header-text {
303
+ text-align: center;
304
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
305
+ -webkit-background-clip: text;
306
+ -webkit-text-fill-color: transparent;
307
+ background-clip: text;
308
+ font-size: 2.5em !important;
309
+ font-weight: bold !important;
310
+ margin-bottom: 0.5em !important;
311
+ }
312
+
313
+ .subtitle-text {
314
+ text-align: center;
315
+ color: #666;
316
+ font-size: 1.2em !important;
317
+ margin-bottom: 2em !important;
318
+ }
319
+
320
+ .parameter-box {
321
+ background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%) !important;
322
+ border-radius: 15px !important;
323
+ padding: 20px !important;
324
+ border: 1px solid #4a5568 !important;
325
+ }
326
+
327
+ .parameter-box summary {
328
+ color: #ffffff !important;
329
+ font-weight: bold !important;
330
+ background: rgba(255, 255, 255, 0.1) !important;
331
+ padding: 10px !important;
332
+ border-radius: 10px !important;
333
+ }
334
+
335
+ .parameter-box details summary {
336
+ color: #ffffff !important;
337
+ font-weight: bold !important;
338
+ }
339
+
340
+ /* Make ALL text white in the parameter box */
341
+ .parameter-box,
342
+ .parameter-box *,
343
+ .parameter-box label,
344
+ .parameter-box span,
345
+ .parameter-box p,
346
+ .parameter-box div,
347
+ .parameter-box small {
348
+ color: #ffffff !important;
349
+ }
350
+
351
+ /* Ensure input values are also white */
352
+ .parameter-box input[type="number"],
353
+ .parameter-box .gr-textbox input {
354
+ color: #ffffff !important;
355
+ background: rgba(255, 255, 255, 0.1) !important;
356
+ border: 1px solid #4a5568 !important;
357
+ }
358
+
359
+ /* Make the centered description text white too */
360
+ .parameter-box > p {
361
+ color: #ffffff !important;
362
+ text-align: center !important;
363
+ }
364
+
365
+ .output-box {
366
+ border-radius: 15px !important;
367
+ border: 1px solid #e1e5e9 !important;
368
+ }
369
+
370
+ .generate-btn {
371
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
372
+ border: none !important;
373
+ color: white !important;
374
+ font-weight: bold !important;
375
+ font-size: 1.1em !important;
376
+ padding: 15px 30px !important;
377
+ border-radius: 25px !important;
378
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
379
+ transition: all 0.3s ease !important;
380
+ }
381
+
382
+ .generate-btn:hover {
383
+ transform: translateY(-2px) !important;
384
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
385
+ }
386
+
387
+ .clear-btn {
388
+ background: linear-gradient(45deg, #ff6b6b 0%, #ee5a24 100%) !important;
389
+ border: none !important;
390
+ color: white !important;
391
+ font-weight: bold !important;
392
+ border-radius: 20px !important;
393
+ padding: 10px 20px !important;
394
+ box-shadow: 0 2px 10px rgba(255, 107, 107, 0.3) !important;
395
+ }
396
+
397
+ .info-box {
398
+ background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%) !important;
399
+ border-radius: 15px !important;
400
+ padding: 20px !important;
401
+ border: 1px solid #f0c27b !important;
402
+ margin-top: 20px !important;
403
+ }
404
+
405
+ .example-box {
406
+ background: linear-gradient(135deg, #e8f5e8 0%, #d4edda 100%) !important;
407
+ border-radius: 15px !important;
408
+ padding: 15px !important;
409
+ border: 1px solid #c3e6cb !important;
410
+ }
411
+
412
+ .metric-card {
413
+ background: white !important;
414
+ border-radius: 10px !important;
415
+ padding: 15px !important;
416
+ text-align: center !important;
417
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1) !important;
418
+ border-left: 4px solid #667eea !important;
419
+ }
420
+
421
+ .progress-bar {
422
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
423
+ }
424
+ """
425
+
426
+ with gr.Blocks(
427
+ title="VelocityLM - Fast Text Generation",
428
+ theme=gr.themes.Soft(
429
+ primary_hue="blue",
430
+ secondary_hue="purple",
431
+ neutral_hue="gray"
432
+ ),
433
+ css=custom_css
434
+ ) as demo:
435
+
436
+ # Header with gradient text
437
+ gr.HTML("""
438
+ <div style="text-align: center; margin-bottom: 2rem;">
439
+ <h1 class="header-text">VelocityLM</h1>
440
+ <p class="subtitle-text">Advanced 2B Parameter Foundational Language Model</p>
441
+ <div style="display: flex; justify-content: center; gap: 2rem; margin: 1.5rem 0;">
442
+ <div class="metric-card">
443
+ <h3 style="margin: 0; color: #667eea;">2B+</h3>
444
+ <p style="margin: 5px 0 0 0; color: #666; font-size: 0.9em;">Parameters</p>
445
+ </div>
446
+ <div class="metric-card">
447
+ <h3 style="margin: 0; color: #667eea;">2048</h3>
448
+ <p style="margin: 5px 0 0 0; color: #666; font-size: 0.9em;">Context Length</p>
449
+ </div>
450
+ </div>
451
+ </div>
452
+ """)
453
+
454
+ gr.Markdown(
455
+ """
456
+ <div style="text-align: center; background: linear-gradient(135deg, #f8f9ff 0%, #e8f0ff 100%);
457
+ padding: 20px; border-radius: 15px; margin-bottom: 2rem; border: 1px solid #e1e8f7;">
458
+ <p style="margin: 0; font-size: 1.1em; color: #4a5568;">
459
+ 🎯 <strong>Modern Architecture:</strong> RoPE • RMSNorm • SwiGLU • Multi-Head Attention<br>
460
+ ✨ <strong>Features:</strong> Text Generation • Configurable Sampling • GPU Accelerated
461
+ </p>
462
+ </div>
463
+ """,
464
+ elem_classes=["info-box"]
465
+ )
466
+
467
+ with gr.Row(equal_height=True):
468
+ # Input Column
469
+ with gr.Column(scale=2, min_width=400):
470
+ gr.HTML("<div style='margin-bottom: 1rem;'><h3 style='color: #667eea; margin: 0;'>💬 Input Prompt</h3></div>")
471
+
472
+ prompt_input = gr.Textbox(
473
+ lines=6,
474
+ placeholder="✨ Enter your creative prompt here...\n\nExample: Write a story about a future where AI and humans collaborate to solve climate change...",
475
+ label="Your Prompt",
476
+ show_copy_button=True,
477
+ container=True,
478
+ elem_classes=["input-box"]
479
+ )
480
+
481
+ # Advanced Parameters Section
482
+ with gr.Accordion("🎛️ Advanced Generation Parameters", open=False, elem_classes=["parameter-box"]):
483
+ gr.HTML("<p style='text-align: center; color: #333; margin-bottom: 1rem;'>Fine-tune your generation settings</p>")
484
+
485
+ with gr.Row():
486
+ max_new_tokens = gr.Slider(
487
+ minimum=1,
488
+ maximum=1024,
489
+ value=512,
490
+ step=1,
491
+ label="🔢 Max New Tokens",
492
+ info="Maximum number of tokens to generate"
493
+ )
494
+ temperature = gr.Slider(
495
+ minimum=0.1,
496
+ maximum=2.0,
497
+ value=0.8,
498
+ step=0.1,
499
+ label="🌡️ Temperature",
500
+ info="Higher = more creative, lower = more focused"
501
+ )
502
+
503
+ with gr.Row():
504
+ top_p = gr.Slider(
505
+ minimum=0.1,
506
+ maximum=1.0,
507
+ value=0.9,
508
+ step=0.05,
509
+ label="🎯 Top-p",
510
+ info="Nucleus sampling threshold"
511
+ )
512
+ top_k = gr.Slider(
513
+ minimum=0,
514
+ maximum=200,
515
+ value=50,
516
+ step=5,
517
+ label="📊 Top-k",
518
+ info="Top-k sampling limit (0 = disabled)"
519
+ )
520
+
521
+ repetition_penalty = gr.Slider(
522
+ minimum=1.0,
523
+ maximum=2.0,
524
+ value=1.1,
525
+ step=0.05,
526
+ label="🔄 Repetition Penalty",
527
+ info="Reduce repetitive text (higher = less repetition)"
528
+ )
529
+
530
+ # Generate Button with enhanced styling
531
+ gr.HTML("<div style='margin: 1.5rem 0;'>")
532
+ generate_btn = gr.Button(
533
+ "🚀 Generate Text",
534
+ variant="primary",
535
+ size="lg",
536
+ elem_classes=["generate-btn"],
537
+ scale=1
538
+ )
539
+ gr.HTML("</div>")
540
+
541
+ # Quick Settings Presets
542
+ gr.HTML("<div style='margin-top: 1rem;'><h4 style='color: #667eea; margin-bottom: 0.5rem;'>⚡ Quick Presets</h4></div>")
543
+ with gr.Row():
544
+ creative_btn = gr.Button("🎨 Creative", size="sm", variant="secondary")
545
+ balanced_btn = gr.Button("⚖️ Balanced", size="sm", variant="secondary")
546
+ precise_btn = gr.Button("🎯 Precise", size="sm", variant="secondary")
547
+
548
+ # Output Column
549
+ with gr.Column(scale=3, min_width=500):
550
+ gr.HTML("<div style='margin-bottom: 1rem; display: flex; justify-content: space-between; align-items: center;'><h3 style='color: #667eea; margin: 0;'>📝 Generated Output</h3></div>")
551
+
552
+ output_text = gr.Textbox(
553
+ lines=22,
554
+ label="Generated Text",
555
+ show_copy_button=True,
556
+ interactive=False,
557
+ placeholder="Your generated text will appear here...\n\n✨ Streaming in real-time\n🚀 Powered by custom 2B parameter model",
558
+ elem_classes=["output-box"],
559
+ container=True
560
+ )
561
+
562
+ # Action buttons
563
+ with gr.Row():
564
+ clear_btn = gr.Button("🗑️ Clear All", variant="secondary", elem_classes=["clear-btn"])
565
+
566
+ # Enhanced Examples Section
567
+ gr.HTML("<div style='margin: 2rem 0;'><h3 style='color: #667eea; text-align: center; margin-bottom: 1rem;'>🎯 Example Prompts</h3></div>")
568
+
569
+ with gr.Accordion("📚 Prompt Examples", open=True, elem_classes=["example-box"]):
570
+ gr.Examples(
571
+ examples=[
572
+ ["Once upon a time in a distant galaxy, there lived a civilization that had never seen the stars."],
573
+ ["The old lighthouse keeper noticed something strange about the fog that night."],
574
+ ["In the depths of the Amazon rainforest, Dr. Martinez made a discovery that would change everything."],
575
+ ["The last bookstore on Earth was about to close its doors forever when"],
576
+ ["As the spaceship approached the mysterious planet, the crew realized"],
577
+ ["The clockmaker's shop had been abandoned for fifty years, but every morning at precisely 9 AM"],
578
+ ["Deep beneath the city, in tunnels forgotten by time, archaeologist Elena found"],
579
+ ["The message in a bottle had traveled across three oceans before washing ashore"],
580
+ ],
581
+ inputs=[prompt_input],
582
+ label="Click any example to get started!",
583
+ examples_per_page=4
584
+ )
585
+
586
+ # Event handlers for main functionality
587
+ generate_btn.click(
588
+ fn=generate_response,
589
+ inputs=[
590
+ prompt_input,
591
+ max_new_tokens,
592
+ temperature,
593
+ top_p,
594
+ top_k,
595
+ repetition_penalty,
596
+ ],
597
+ outputs=[output_text],
598
+ show_progress=True,
599
+ )
600
+
601
+ # Preset button handlers
602
+ creative_btn.click(
603
+ fn=lambda: (1.2, 0.95, 40, 1.05),
604
+ outputs=[temperature, top_p, top_k, repetition_penalty]
605
+ )
606
+
607
+ balanced_btn.click(
608
+ fn=lambda: (0.8, 0.9, 50, 1.1),
609
+ outputs=[temperature, top_p, top_k, repetition_penalty]
610
+ )
611
+
612
+ precise_btn.click(
613
+ fn=lambda: (0.3, 0.8, 20, 1.2),
614
+ outputs=[temperature, top_p, top_k, repetition_penalty]
615
+ )
616
+
617
+ # Utility button handlers
618
+ clear_btn.click(
619
+ fn=lambda: ("", ""),
620
+ outputs=[prompt_input, output_text]
621
+ )
622
+
623
+
624
+ return demo
625
+
626
+
627
+ if __name__ == "__main__":
628
+ # Initialize for local testing
629
+ demo = create_interface()
630
+ demo.launch(
631
+ server_name="127.0.0.1",
632
+ server_port=7860,
633
+ share=False,
634
+ debug=False,
635
+ )
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ # Gradio app requirements for HuggingFace Spaces
2
+ gradio==4.44.0
3
+ spaces==0.29.4
4
+
5
+ # Core ML dependencies
6
+ torch==2.2.0
7
+ transformers==4.36.0
8
+ tokenizers==0.15.0
9
+
10
+ # Numerical computing
11
+ numpy==1.26.4
12
+
13
+ # Utilities
14
+ tqdm==4.66.1
src/inference/inference.py ADDED
@@ -0,0 +1,231 @@
1
+ """Text generation utilities for the trained model."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typing import List, Optional, Union
6
+ from transformers import AutoTokenizer
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class TextGenerator:
13
+ """Text generation with various decoding strategies."""
14
+
15
+ def __init__(self, model, tokenizer, device='cuda'):
16
+ self.model = model
17
+ self.tokenizer = tokenizer
18
+ self.device = device
19
+ self.model.to(device)
20
+ self.model.eval()
21
+
22
+ @torch.no_grad()
23
+ def generate(
24
+ self,
25
+ prompt: Union[str, List[str]],
26
+ max_length: int = 100,
27
+ temperature: float = 1.0,
28
+ top_k: Optional[int] = 50,
29
+ top_p: Optional[float] = 0.9,
30
+ num_return_sequences: int = 1,
31
+ do_sample: bool = True,
32
+ repetition_penalty: float = 1.0,
33
+ ) -> List[str]:
34
+ """Generate text from prompt(s)."""
35
+
36
+ # Handle single string input
37
+ if isinstance(prompt, str):
38
+ prompts = [prompt]
39
+ else:
40
+ prompts = prompt
41
+
42
+ # Tokenize prompts
43
+ inputs = self.tokenizer(
44
+ prompts,
45
+ return_tensors='pt',
46
+ padding=True,
47
+ truncation=True,
48
+ max_length=max_length,
49
+ ).to(self.device)
50
+
51
+ input_ids = inputs['input_ids']
52
+ attention_mask = inputs['attention_mask']
53
+
54
+ # Generate
55
+ batch_size = input_ids.shape[0]
56
+ generated_ids = input_ids.clone()
57
+
58
+ for _ in range(max_length - input_ids.shape[1]):
59
+ # Get model predictions
60
+ outputs = self.model(
61
+ input_ids=generated_ids,
62
+ attention_mask=attention_mask,
63
+ )
64
+
65
+ # Get logits for the last token
66
+ next_token_logits = outputs.logits[:, -1, :]
67
+
68
+ # Apply repetition penalty
69
+ if repetition_penalty != 1.0:
70
+ for i in range(batch_size):
71
+ for token_id in set(generated_ids[i].tolist()):
72
+ next_token_logits[i, token_id] /= repetition_penalty
73
+
74
+ # Apply temperature
75
+ if temperature != 1.0:
76
+ next_token_logits = next_token_logits / temperature
77
+
78
+ # Apply top-k filtering
79
+ if top_k is not None:
80
+ indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
81
+ next_token_logits[indices_to_remove] = float('-inf')
82
+
83
+ # Apply top-p (nucleus) filtering
84
+ if top_p is not None:
85
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
86
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
87
+
88
+ # Remove tokens with cumulative probability above the threshold
89
+ sorted_indices_to_remove = cumulative_probs > top_p
90
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
91
+ sorted_indices_to_remove[..., 0] = 0
92
+
93
+ indices_to_remove = sorted_indices_to_remove.scatter(
94
+ 1, sorted_indices, sorted_indices_to_remove
95
+ )
96
+ next_token_logits[indices_to_remove] = float('-inf')
97
+
98
+ # Sample from the distribution
99
+ if do_sample:
100
+ probs = F.softmax(next_token_logits, dim=-1)
101
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
102
+ else:
103
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
104
+
105
+ # Append to generated sequence
106
+ generated_ids = torch.cat([generated_ids, next_tokens.unsqueeze(1)], dim=1)
107
+
108
+ # Update attention mask
109
+ attention_mask = torch.cat([
110
+ attention_mask,
111
+ torch.ones((batch_size, 1), device=self.device)
112
+ ], dim=1)
113
+
114
+ # Check for EOS token
115
+ if (next_tokens == self.tokenizer.eos_token_id).all():
116
+ break
117
+
118
+ # Decode generated sequences
119
+ generated_texts = []
120
+ for i in range(batch_size):
121
+ generated_text = self.tokenizer.decode(
122
+ generated_ids[i],
123
+ skip_special_tokens=True,
124
+ clean_up_tokenization_spaces=True
125
+ )
126
+ generated_texts.append(generated_text)
127
+
128
+ return generated_texts
129
+
130
+ def beam_search(
131
+ self,
132
+ prompt: str,
133
+ max_length: int = 100,
134
+ num_beams: int = 4,
135
+ length_penalty: float = 1.0,
136
+ early_stopping: bool = True,
137
+ ) -> str:
138
+ """Generate text using beam search."""
139
+ # Implementation of beam search
140
+ # This is a simplified version - full implementation would be more complex
141
+
142
+ inputs = self.tokenizer(
143
+ prompt,
144
+ return_tensors='pt',
145
+ truncation=True,
146
+ max_length=max_length,
147
+ ).to(self.device)
148
+
149
+ # For now, fallback to greedy decoding
150
+ return self.generate(
151
+ prompt,
152
+ max_length=max_length,
153
+ do_sample=False,
154
+ num_return_sequences=1
155
+ )[0]
156
+
157
+
158
+ def load_generator(checkpoint_path: str, device: str = 'cuda'):
159
+ """Load model and create generator."""
160
+ import yaml
161
+ from pathlib import Path
162
+ import sys
163
+ sys.path.append(str(Path(__file__).parent.parent.parent))
164
+
165
+ from src.model.transformer import TransformerForCausalLM
166
+
167
+ # Load config
168
+ config_path = Path(checkpoint_path) / 'config.json'
169
+ with open(config_path, 'r') as f:
170
+ import json
171
+ config = json.load(f)
172
+
173
+ # Create model config
174
+ class ModelConfig:
175
+ def __init__(self, config_dict):
176
+ for key, value in config_dict.items():
177
+ setattr(self, key, value)
178
+
179
+ model_config = ModelConfig(config['model'])
180
+
181
+ # Load model
182
+ model = TransformerForCausalLM(model_config)
183
+ state_dict = torch.load(
184
+ Path(checkpoint_path) / 'pytorch_model.bin',
185
+ map_location=device
186
+ )
187
+ model.load_state_dict(state_dict)
188
+
189
+ # Load tokenizer
190
+ tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name'])
191
+ if tokenizer.pad_token is None:
192
+ tokenizer.pad_token = tokenizer.eos_token
193
+
194
+ # Create generator
195
+ generator = TextGenerator(model, tokenizer, device)
196
+
197
+ return generator
198
+
199
+
200
+ if __name__ == '__main__':
201
+ """Example usage."""
202
+ import argparse
203
+
204
+ parser = argparse.ArgumentParser()
205
+ parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
206
+ parser.add_argument('--prompt', type=str, required=True, help='Input prompt')
207
+ parser.add_argument('--max-length', type=int, default=100, help='Maximum generation length')
208
+ parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
209
+ parser.add_argument('--top-k', type=int, default=50, help='Top-k filtering')
210
+ parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) filtering')
211
+ parser.add_argument('--device', type=str, default='cuda', help='Device to use')
212
+
213
+ args = parser.parse_args()
214
+
215
+ # Load generator
216
+ print("Loading model...")
217
+ generator = load_generator(args.checkpoint, args.device)
218
+
219
+ # Generate text
220
+ print(f"Prompt: {args.prompt}")
221
+ print("Generating...")
222
+
223
+ generated = generator.generate(
224
+ args.prompt,
225
+ max_length=args.max_length,
226
+ temperature=args.temperature,
227
+ top_k=args.top_k,
228
+ top_p=args.top_p,
229
+ )
230
+
231
+ print(f"Generated: {generated[0]}")
src/model/layers.py ADDED
@@ -0,0 +1,86 @@
1
+ """Custom layers for the transformer model."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Tuple
7
+
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ class RMSNorm(nn.Module):
12
+ """Root Mean Square Layer Normalization."""
13
+
14
+ def __init__(self, hidden_size, eps=1e-6):
15
+ super().__init__()
16
+ self.weight = nn.Parameter(torch.ones(hidden_size))
17
+ self.eps = eps
18
+
19
+ def forward(self, hidden_states):
20
+ input_dtype = hidden_states.dtype
21
+ hidden_states = hidden_states.to(torch.float32)
22
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
23
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
24
+ return self.weight * hidden_states.to(input_dtype)
25
+
26
+
27
+ class RotaryEmbedding(nn.Module):
28
+ """Rotary Position Embedding."""
29
+
30
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
31
+ super().__init__()
32
+ self.dim = dim
33
+ self.max_position_embeddings = max_position_embeddings
34
+ self.base = base
35
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
36
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
37
+
38
+ # Build cached cos/sin
39
+ self._set_cos_sin_cache(
40
+ seq_len=max_position_embeddings,
41
+ device=self.inv_freq.device,
42
+ dtype=torch.get_default_dtype()
43
+ )
44
+
45
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
46
+ self.max_seq_len_cached = seq_len
47
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
48
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
49
+ emb = torch.cat((freqs, freqs), dim=-1)
50
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
51
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
52
+
53
+ def forward(self, x, seq_len=None):
54
+ if seq_len > self.max_seq_len_cached:
55
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
56
+ return (
57
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
58
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
59
+ )
60
+
61
+ @staticmethod
62
+ def rotate_half(x):
63
+ x1 = x[..., : x.shape[-1] // 2]
64
+ x2 = x[..., x.shape[-1] // 2 :]
65
+ return torch.cat((-x2, x1), dim=-1)
66
+
67
+ def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
68
+ cos = cos[position_ids].unsqueeze(1)
69
+ sin = sin[position_ids].unsqueeze(1)
70
+ q_embed = (q * cos) + (self.rotate_half(q) * sin)
71
+ k_embed = (k * cos) + (self.rotate_half(k) * sin)
72
+ return q_embed, k_embed
73
+
74
+
75
+ class SwiGLU(nn.Module):
76
+ """SwiGLU activation function."""
77
+
78
+ def __init__(self, hidden_size, intermediate_size, hidden_act="silu"):
79
+ super().__init__()
80
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
81
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
82
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
83
+ self.act_fn = F.silu if hidden_act == "silu" else F.gelu
84
+
85
+ def forward(self, x):
86
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
src/model/transformer.py ADDED
@@ -0,0 +1,281 @@
1
+ """State-of-the-art Transformer model implementation."""
2
+
3
+ import math
4
+ from typing import Optional, Tuple
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn import CrossEntropyLoss
9
+ from dataclasses import dataclass
10
+
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ from .layers import RMSNorm, RotaryEmbedding, SwiGLU
15
+
16
+
17
+ @dataclass
18
+ class ModelOutput:
19
+ """Model output container."""
20
+ loss: Optional[torch.Tensor] = None
21
+ logits: Optional[torch.Tensor] = None
22
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
23
+ attentions: Optional[Tuple[torch.Tensor]] = None
24
+
25
+
26
+ class CausalSelfAttention(nn.Module):
27
+ """Multi-head self-attention with causal mask and RoPE."""
28
+
29
+ def __init__(self, config):
30
+ super().__init__()
31
+ assert config.hidden_size % config.num_attention_heads == 0
32
+
33
+ self.num_attention_heads = config.num_attention_heads
34
+ self.head_dim = config.hidden_size // config.num_attention_heads
35
+ self.hidden_size = config.hidden_size
36
+
37
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
38
+ self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
39
+ self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
40
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
41
+
42
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
43
+ self.rotary_emb = RotaryEmbedding(
44
+ self.head_dim,
45
+ max_position_embeddings=config.max_position_embeddings,
46
+ base=config.rope_theta,
47
+ )
48
+
49
+ def forward(
50
+ self,
51
+ hidden_states: torch.Tensor,
52
+ attention_mask: Optional[torch.Tensor] = None,
53
+ position_ids: Optional[torch.Tensor] = None,
54
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
55
+ use_cache: bool = False,
56
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
57
+ bsz, q_len, _ = hidden_states.size()
58
+
59
+ q = self.q_proj(hidden_states)
60
+ k = self.k_proj(hidden_states)
61
+ v = self.v_proj(hidden_states)
62
+
63
+ q = q.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
64
+ k = k.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
65
+ v = v.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
66
+
67
+ # Apply rotary embeddings
68
+ cos, sin = self.rotary_emb(v, seq_len=q_len)
69
+ q, k = self.rotary_emb.apply_rotary_pos_emb(q, k, cos, sin, position_ids)
70
+
71
+ # Standard scaled dot-product attention (flash attention is not used here)
72
+ attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
73
+
74
+ if attention_mask is not None:
75
+ attn_weights = attn_weights + attention_mask
76
+
77
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
78
+ attn_weights = self.attention_dropout(attn_weights)
79
+
80
+ attn_output = torch.matmul(attn_weights, v)
81
+ attn_output = attn_output.transpose(1, 2).contiguous()
82
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
83
+ attn_output = self.o_proj(attn_output)
84
+
85
+ return attn_output, None
86
+
87
+
88
+ class TransformerBlock(nn.Module):
89
+ """Transformer block with RMSNorm and SwiGLU."""
90
+
91
+ def __init__(self, config):
92
+ super().__init__()
93
+ self.hidden_size = config.hidden_size
94
+ self.self_attn = CausalSelfAttention(config)
95
+ self.mlp = SwiGLU(
96
+ hidden_size=config.hidden_size,
97
+ intermediate_size=config.intermediate_size,
98
+ hidden_act=config.hidden_act,
99
+ )
100
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
101
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
102
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout)
103
+
104
+ def forward(
105
+ self,
106
+ hidden_states: torch.Tensor,
107
+ attention_mask: Optional[torch.Tensor] = None,
108
+ position_ids: Optional[torch.Tensor] = None,
109
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
110
+ use_cache: bool = False,
111
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
112
+ residual = hidden_states
113
+ hidden_states = self.input_layernorm(hidden_states)
114
+
115
+ # Self Attention
116
+ hidden_states, present_key_value = self.self_attn(
117
+ hidden_states=hidden_states,
118
+ attention_mask=attention_mask,
119
+ position_ids=position_ids,
120
+ past_key_value=past_key_value,
121
+ use_cache=use_cache,
122
+ )
123
+ hidden_states = self.hidden_dropout(hidden_states)
124
+ hidden_states = residual + hidden_states
125
+
126
+ # MLP
127
+ residual = hidden_states
128
+ hidden_states = self.post_attention_layernorm(hidden_states)
129
+ hidden_states = self.mlp(hidden_states)
130
+ hidden_states = self.hidden_dropout(hidden_states)
131
+ hidden_states = residual + hidden_states
132
+
133
+ return hidden_states, present_key_value
134
+
135
+
136
+ class TransformerModel(nn.Module):
137
+ """Main transformer model."""
138
+
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ self.config = config
142
+ self.vocab_size = config.vocab_size
143
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
144
+ self.layers = nn.ModuleList(
145
+ [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
146
+ )
147
+ self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
148
+ self.gradient_checkpointing = False
149
+
150
+ # Initialize weights
151
+ self.apply(self._init_weights)
152
+
153
+ def _init_weights(self, module):
154
+ if isinstance(module, nn.Linear):
155
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
156
+ if module.bias is not None:
157
+ torch.nn.init.zeros_(module.bias)
158
+ elif isinstance(module, nn.Embedding):
159
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
160
+
161
+ def get_input_embeddings(self):
162
+ return self.embed_tokens
163
+
164
+ def set_input_embeddings(self, value):
165
+ self.embed_tokens = value
166
+
167
+ def forward(
168
+ self,
169
+ input_ids: torch.LongTensor,
170
+ attention_mask: Optional[torch.Tensor] = None,
171
+ position_ids: Optional[torch.LongTensor] = None,
172
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
173
+ use_cache: Optional[bool] = None,
174
+ output_attentions: bool = False,
175
+ output_hidden_states: bool = False,
176
+ return_dict: bool = True,
177
+ ) -> torch.Tensor:
178
+ batch_size, seq_length = input_ids.shape
179
+
180
+ # Embed tokens
181
+ hidden_states = self.embed_tokens(input_ids)
182
+
183
+ # Create position IDs
184
+ if position_ids is None:
185
+ position_ids = torch.arange(
186
+ seq_length, dtype=torch.long, device=input_ids.device
187
+ ).unsqueeze(0)
188
+
189
+ # Create causal mask
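+ # (upper-triangular -inf so each position can only attend to itself and earlier positions)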
190
+ causal_mask = torch.triu(
191
+ torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device),
192
+ diagonal=1
193
+ ).unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
194
+
195
+ if attention_mask is not None:
196
+ # Convert padding mask [batch, seq_len] to 4D [batch, 1, 1, seq_len]
197
+ # and combine with causal mask
198
+ expanded_mask = attention_mask[:, None, None, :] # [batch, 1, 1, seq_len]
199
+ expanded_mask = (1.0 - expanded_mask) * -10000.0 # Convert 0s to -inf
200
+ attention_mask = expanded_mask + causal_mask.expand(input_ids.shape[0], -1, -1, -1)
201
+ else:
202
+ attention_mask = causal_mask.expand(input_ids.shape[0], -1, -1, -1)
203
+
204
+ # Forward through layers
205
+ for layer in self.layers:
206
+ if self.gradient_checkpointing and self.training:
207
+ hidden_states, _ = torch.utils.checkpoint.checkpoint(
208
+ layer,
209
+ hidden_states,
210
+ attention_mask,
211
+ position_ids,
212
+ None,
213
+ False,
214
+ use_reentrant=False,
215
+ )
216
+ else:
217
+ hidden_states, _ = layer(
218
+ hidden_states,
219
+ attention_mask=attention_mask,
220
+ position_ids=position_ids,
221
+ past_key_value=None,
222
+ use_cache=False,
223
+ )
224
+
225
+ hidden_states = self.norm(hidden_states)
226
+ return hidden_states
227
+
228
+
229
+ class TransformerForCausalLM(nn.Module):
230
+ """Transformer model with language modeling head."""
231
+
232
+ def __init__(self, config):
233
+ super().__init__()
234
+ self.model = TransformerModel(config)
235
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
236
+
237
+ # Tie weights
238
+ self.lm_head.weight = self.model.embed_tokens.weight
239
+
240
+ def forward(
241
+ self,
242
+ input_ids: torch.LongTensor,
243
+ attention_mask: Optional[torch.Tensor] = None,
244
+ position_ids: Optional[torch.LongTensor] = None,
245
+ labels: Optional[torch.LongTensor] = None,
246
+ use_cache: Optional[bool] = None,
247
+ output_attentions: bool = False,
248
+ output_hidden_states: bool = False,
249
+ return_dict: bool = True,
250
+ ) -> ModelOutput:
251
+ hidden_states = self.model(
252
+ input_ids=input_ids,
253
+ attention_mask=attention_mask,
254
+ position_ids=position_ids,
255
+ use_cache=use_cache,
256
+ output_attentions=output_attentions,
257
+ output_hidden_states=output_hidden_states,
258
+ return_dict=return_dict,
259
+ )
260
+
261
+ logits = self.lm_head(hidden_states)
262
+
263
+ loss = None
264
+ if labels is not None:
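+ # Shift so that logits at position t are scored against the token at position t+1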
265
+ shift_logits = logits[..., :-1, :].contiguous()
266
+ shift_labels = labels[..., 1:].contiguous()
267
+ loss_fct = CrossEntropyLoss()
268
+ loss = loss_fct(
269
+ shift_logits.view(-1, shift_logits.size(-1)),
270
+ shift_labels.view(-1)
271
+ )
272
+
273
+ return ModelOutput(
274
+ loss=loss,
275
+ logits=logits,
276
+ hidden_states=hidden_states,
277
+ attentions=None,
278
+ )
279
+
280
+ def gradient_checkpointing_enable(self):
281
+ self.model.gradient_checkpointing = True