Spaces: Running on Zero
Initial Commit
- .gitignore +184 -0
- README.md +100 -6
- app.py +635 -0
- requirements.txt +14 -0
- src/inference/inference.py +231 -0
- src/model/layers.py +86 -0
- src/model/transformer.py +281 -0
.gitignore
ADDED
@@ -0,0 +1,184 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# ML/AI specific files
*.pkl
*.pickle
*.h5
*.hdf5
*.ckpt
*.pt
*.pth
*.safetensors

# Model training artifacts
wandb/
runs/
logs/
tensorboard/

# Data directories
data/
datasets/

# Model checkpoints (excluding checkpoint directory as user mentioned git LFS)
# User tracks model files in checkpoints/ with git LFS, so we won't ignore it

# Temporary files
*.tmp
*.temp
.DS_Store
Thumbs.db

# IDE files
.vscode/
.idea/
*.swp
*.swo
*~

# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Gradio temporary files
gradio_cached_examples/
flagged/

# HuggingFace cache
.cache/
cache/

# Local configuration files
config.local.*
.secrets
README.md
CHANGED
@@ -1,14 +1,108 @@
-title:
-emoji:
-colorFrom:
-colorTo:
---
title: Custom LLM - Foundational Language Model
emoji: 🤖
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 5.43.1
app_file: app.py
pinned: false
license: mit
models:
  - gpt2
datasets:
  - tiiuae/falcon-refinedweb
tags:
  - text-generation
  - transformer
  - pytorch
  - custom-model
  - llm
  - foundational-model
short_description: A custom 2B parameter foundational language model with streaming generation
---

# 🤖 Custom LLM - Foundational Language Model

A custom-trained foundational language model with **2 billion parameters**, built with a modern transformer architecture and deployed with streaming text generation.

## 🚀 Features

- **Custom Architecture**: Modern transformer with RoPE (Rotary Position Embedding), RMSNorm, and SwiGLU activation
- **Streaming Generation**: Real-time text generation with token-by-token streaming
- **Flexible Sampling**: Configurable temperature, top-p, top-k, and repetition penalty
- **ZeroGPU Integration**: Optimized for Hugging Face Spaces with GPU acceleration
- **Responsive UI**: Clean, intuitive Gradio interface

## 📊 Model Details

| Specification | Value |
|---------------|-------|
| **Parameters** | ~2 billion |
| **Architecture** | Custom Transformer |
| **Context Length** | 2,048 tokens |
| **Vocab Size** | 50,257 (GPT-2 tokenizer) |
| **Layers** | 24 |
| **Attention Heads** | 32 |
| **Hidden Size** | 2,048 |
| **Intermediate Size** | 8,192 |

## 🏗️ Architecture Components

- **RMSNorm**: Root Mean Square Layer Normalization for better training stability
- **RoPE**: Rotary Position Embeddings for better length extrapolation
- **SwiGLU**: Swish-gated GLU activation function for improved performance
- **Causal Attention**: Standard autoregressive attention mechanism

## 🎯 Training Details

- **Dataset**: Falcon RefinedWeb (curated web text)
- **Training Steps**: 100,000 steps
- **Learning Rate**: 6e-4 with warmup and decay
- **Batch Size**: 32 (4 per device × 8 accumulation steps)
- **Optimization**: AdamW with β1=0.9, β2=0.95
- **Precision**: Mixed precision (FP16)

## 🛠️ Generation Parameters

- **Max Tokens**: Control the length of generated text (1-1024)
- **Temperature**: Sampling randomness (0.1-2.0, higher = more creative)
- **Top-p**: Nucleus sampling threshold (0.1-1.0)
- **Top-k**: Top-k sampling limit (0-200, 0 = disabled)
- **Repetition Penalty**: Reduce repetitive text (1.0-2.0)

## 💡 Usage Tips

1. **For Creative Writing**: Use higher temperature (1.0-1.5) and top-p (0.9-0.95)
2. **For Factual Content**: Use lower temperature (0.3-0.7) and top-p (0.8-0.9)
3. **For Code Generation**: Use temperature ~0.2 with top-k filtering
4. **Longer Context**: The model handles up to 2,048 tokens of context

## 🚨 Limitations

- **Knowledge Cutoff**: Training data knowledge cutoff varies by source
- **Biases**: May reflect biases present in training data
- **Factuality**: Generated content should be verified for factual accuracy
- **Context Window**: Limited to 2,048 tokens (approximately 1,500 words)

## 🔧 Technical Implementation

The model uses a custom PyTorch implementation with:
- Efficient attention mechanisms
- Memory-optimized layer implementations
- Streaming generation with proper token handling
- GPU acceleration via ZeroGPU

## 📝 License

This project is licensed under the MIT License - see the LICENSE file for details.

## 🙏 Acknowledgments

- Hugging Face for the Spaces platform and ZeroGPU infrastructure
- The open-source community for transformer implementations and best practices
- TII UAE for the Falcon RefinedWeb dataset

---

**Note**: This is a foundational language model trained for research and educational purposes. Please use responsibly and be aware of potential biases and limitations.
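The generation parameters above correspond to standard logits post-processing before sampling. The following is a minimal sketch of how temperature, top-k, and top-p are typically chained together; it mirrors the filtering logic in `app.py` below (repetition penalty omitted), and the vocabulary size and random logits are placeholders, not model output.

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.8, top_p=0.9, top_k=50):
    """Pick one token id from a [vocab_size] logits vector."""
    logits = logits / temperature                       # sharpen or flatten the distribution
    if top_k > 0:                                       # keep only the k most likely tokens
        kth = torch.topk(logits, top_k).values[-1]
        logits[logits < kth] = float("-inf")
    if top_p < 1.0:                                     # nucleus: keep the smallest set with mass >= top_p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum > top_p
        remove[1:] = remove[:-1].clone()                # always keep the single most likely token
        remove[0] = False
        logits[sorted_idx[remove]] = float("-inf")
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).item()

next_id = sample_next_token(torch.randn(50257))         # e.g. over the GPT-2 vocabulary
```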
app.py
ADDED
@@ -0,0 +1,635 @@
"""Gradio app for the custom LLM with streaming support and ZeroGPU integration."""

import gradio as gr
import torch
import torch.nn.functional as F
from typing import Iterator, Optional, Union, List
from transformers import AutoTokenizer
import json
import warnings
import sys
from pathlib import Path

# Add src to path
sys.path.append(str(Path(__file__).parent))

warnings.filterwarnings("ignore")

try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False
    # Mock decorator for local testing
    def spaces_decorator(gpu_memory=None):
        def decorator(func):
            return func
        return decorator
    spaces = type('MockSpaces', (), {'GPU': spaces_decorator})

from src.model.transformer import TransformerForCausalLM


class StreamingTextGenerator:
    """Streaming text generation for the custom LLM."""

    def __init__(self, model, tokenizer, device='cuda'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        self.model.eval()

    def generate_stream(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.8,
        top_p: float = 0.9,
        top_k: Optional[int] = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
    ) -> Iterator[str]:
        """Generate text with streaming output."""

        # Tokenize prompt
        inputs = self.tokenizer(
            prompt,
            return_tensors='pt',
            padding=False,
            truncation=True,
            max_length=1024,  # Leave room for generation
        ).to(self.device)

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Initialize generated sequence
        generated_ids = input_ids.clone()
        generated_text = prompt

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Get model predictions
                outputs = self.model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                )

                # Get logits for the last token
                next_token_logits = outputs.logits[0, -1, :].clone()

                # Apply repetition penalty
                if repetition_penalty != 1.0:
                    for token_id in set(generated_ids[0].tolist()):
                        next_token_logits[token_id] /= repetition_penalty

                # Apply temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                # Apply top-k filtering
                if top_k is not None and top_k > 0:
                    top_k_logits, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                    min_top_k = top_k_logits[-1]
                    next_token_logits = torch.where(
                        next_token_logits < min_top_k,
                        torch.full_like(next_token_logits, float('-inf')),
                        next_token_logits
                    )

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = False

                    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample next token
                if do_sample and temperature > 0:
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Check for EOS token
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Append to generated sequence
                generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=-1)

                # Update attention mask
                attention_mask = torch.cat([
                    attention_mask,
                    torch.ones((1, 1), device=self.device, dtype=attention_mask.dtype)
                ], dim=-1)

                # Decode and yield new token
                new_text = self.tokenizer.decode(
                    generated_ids[0],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )

                # Only yield the new part
                if len(new_text) > len(generated_text):
                    generated_text = new_text
                    yield generated_text


def download_model_from_hf():
    """Download model from HuggingFace repository."""
    from huggingface_hub import hf_hub_download
    import os

    model_repo = "dixisouls/VelocityLM"
    cache_dir = Path("model_cache")
    cache_dir.mkdir(exist_ok=True)

    print("📥 Downloading model from HuggingFace...")

    # Download config.json
    config_path = hf_hub_download(
        repo_id=model_repo,
        filename="config.json",
        cache_dir=cache_dir,
        local_files_only=False
    )

    # Download pytorch_model.bin
    model_path = hf_hub_download(
        repo_id=model_repo,
        filename="pytorch_model.bin",
        cache_dir=cache_dir,
        local_files_only=False
    )

    print("✅ Model downloaded successfully!")
    return config_path, model_path


def load_model_and_tokenizer():
    """Load the trained model and tokenizer."""
    import os

    # Check if model exists locally, if not download from HF
    cache_dir = Path("model_cache")
    local_config = None
    local_model = None

    # Try to find cached files
    if cache_dir.exists():
        for root, dirs, files in os.walk(cache_dir):
            if "config.json" in files:
                local_config = Path(root) / "config.json"
            if "pytorch_model.bin" in files:
                local_model = Path(root) / "pytorch_model.bin"

    # Download if not found locally
    if not local_config or not local_model:
        config_path, model_path = download_model_from_hf()
    else:
        config_path = str(local_config)
        model_path = str(local_model)
        print("📂 Using cached model files")

    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Create model config object
    class ModelConfig:
        def __init__(self, config_dict):
            for key, value in config_dict.items():
                setattr(self, key, value)

    model_config = ModelConfig(config['model'])

    # Load model
    print("🔧 Initializing model...")
    model = TransformerForCausalLM(model_config)

    # Load state dict from pytorch_model.bin
    print("📦 Loading model weights...")
    model_state_dict = torch.load(
        model_path,
        map_location='cpu'
    )

    model.load_state_dict(model_state_dict, strict=False)
    print("✅ Model weights loaded!")

    # Load tokenizer
    print("🔤 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name'])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("🎉 Model and tokenizer ready!")
    return model, tokenizer


# Global variables for model and generator
model = None
tokenizer = None
generator = None

def initialize_model():
    """Initialize model and tokenizer."""
    global model, tokenizer, generator

    if model is None:
        print("Loading model and tokenizer...")
        model, tokenizer = load_model_and_tokenizer()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = StreamingTextGenerator(model, tokenizer, device=device)
        print(f"Model loaded on {device}")


@spaces.GPU(duration=120) if HAS_SPACES else lambda x: x
def generate_response(
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.8,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> Iterator[str]:
    """Generate streaming response."""

    # Initialize model if needed
    initialize_model()

    if not prompt.strip():
        yield "Please enter a prompt."
        return

    try:
        # Generate with streaming
        for partial_text in generator.generate_stream(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k if top_k > 0 else None,
            repetition_penalty=repetition_penalty,
            do_sample=temperature > 0,
        ):
            yield partial_text

    except Exception as e:
        yield f"Error generating text: {str(e)}"


# Create Gradio interface
def create_interface():
    """Create the Gradio interface."""

    # Custom CSS for enhanced UI
    custom_css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: 0 auto !important;
    }

    .header-text {
        text-align: center;
        background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        font-size: 2.5em !important;
        font-weight: bold !important;
        margin-bottom: 0.5em !important;
    }

    .subtitle-text {
        text-align: center;
        color: #666;
        font-size: 1.2em !important;
        margin-bottom: 2em !important;
    }

    .parameter-box {
        background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%) !important;
        border-radius: 15px !important;
        padding: 20px !important;
        border: 1px solid #4a5568 !important;
    }

    .parameter-box summary {
        color: #ffffff !important;
        font-weight: bold !important;
        background: rgba(255, 255, 255, 0.1) !important;
        padding: 10px !important;
        border-radius: 10px !important;
    }

    .parameter-box details summary {
        color: #ffffff !important;
        font-weight: bold !important;
    }

    /* Make ALL text white in the parameter box */
    .parameter-box,
    .parameter-box *,
    .parameter-box label,
    .parameter-box span,
    .parameter-box p,
    .parameter-box div,
    .parameter-box small {
        color: #ffffff !important;
    }

    /* Ensure input values are also white */
    .parameter-box input[type="number"],
    .parameter-box .gr-textbox input {
        color: #ffffff !important;
        background: rgba(255, 255, 255, 0.1) !important;
        border: 1px solid #4a5568 !important;
    }

    /* Make the centered description text white too */
    .parameter-box > p {
        color: #ffffff !important;
        text-align: center !important;
    }

    .output-box {
        border-radius: 15px !important;
        border: 1px solid #e1e5e9 !important;
    }

    .generate-btn {
        background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        font-size: 1.1em !important;
        padding: 15px 30px !important;
        border-radius: 25px !important;
        box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
        transition: all 0.3s ease !important;
    }

    .generate-btn:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
    }

    .clear-btn {
        background: linear-gradient(45deg, #ff6b6b 0%, #ee5a24 100%) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        border-radius: 20px !important;
        padding: 10px 20px !important;
        box-shadow: 0 2px 10px rgba(255, 107, 107, 0.3) !important;
    }

    .info-box {
        background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%) !important;
        border-radius: 15px !important;
        padding: 20px !important;
        border: 1px solid #f0c27b !important;
        margin-top: 20px !important;
    }

    .example-box {
        background: linear-gradient(135deg, #e8f5e8 0%, #d4edda 100%) !important;
        border-radius: 15px !important;
        padding: 15px !important;
        border: 1px solid #c3e6cb !important;
    }

    .metric-card {
        background: white !important;
        border-radius: 10px !important;
        padding: 15px !important;
        text-align: center !important;
        box-shadow: 0 2px 10px rgba(0,0,0,0.1) !important;
        border-left: 4px solid #667eea !important;
    }

    .progress-bar {
        background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
    }
    """

    with gr.Blocks(
        title="VelocityLM - Fast Text Generation",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple",
            neutral_hue="gray"
        ),
        css=custom_css
    ) as demo:

        # Header with gradient text
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 2rem;">
            <h1 class="header-text">VelocityLM</h1>
            <p class="subtitle-text">Advanced 2B Parameter Foundational Language Model</p>
            <div style="display: flex; justify-content: center; gap: 2rem; margin: 1.5rem 0;">
                <div class="metric-card">
                    <h3 style="margin: 0; color: #667eea;">2B+</h3>
                    <p style="margin: 5px 0 0 0; color: #666; font-size: 0.9em;">Parameters</p>
                </div>
                <div class="metric-card">
                    <h3 style="margin: 0; color: #667eea;">2048</h3>
                    <p style="margin: 5px 0 0 0; color: #666; font-size: 0.9em;">Context Length</p>
                </div>
            </div>
        </div>
        """)

        gr.Markdown(
            """
            <div style="text-align: center; background: linear-gradient(135deg, #f8f9ff 0%, #e8f0ff 100%);
                        padding: 20px; border-radius: 15px; margin-bottom: 2rem; border: 1px solid #e1e8f7;">
                <p style="margin: 0; font-size: 1.1em; color: #4a5568;">
                    🎯 <strong>Modern Architecture:</strong> RoPE • RMSNorm • SwiGLU • Multi-Head Attention<br>
                    ✨ <strong>Features:</strong> Text Generation • Configurable Sampling • GPU Accelerated
                </p>
            </div>
            """,
            elem_classes=["info-box"]
        )

        with gr.Row(equal_height=True):
            # Input Column
            with gr.Column(scale=2, min_width=400):
                gr.HTML("<div style='margin-bottom: 1rem;'><h3 style='color: #667eea; margin: 0;'>💬 Input Prompt</h3></div>")

                prompt_input = gr.Textbox(
                    lines=6,
                    placeholder="✨ Enter your creative prompt here...\n\nExample: Write a story about a future where AI and humans collaborate to solve climate change...",
                    label="Your Prompt",
                    show_copy_button=True,
                    container=True,
                    elem_classes=["input-box"]
                )

                # Advanced Parameters Section
                with gr.Accordion("🎛️ Advanced Generation Parameters", open=False, elem_classes=["parameter-box"]):
                    gr.HTML("<p style='text-align: center; color: #333; margin-bottom: 1rem;'>Fine-tune your generation settings</p>")

                    with gr.Row():
                        max_new_tokens = gr.Slider(
                            minimum=1,
                            maximum=1024,
                            value=512,
                            step=1,
                            label="🔢 Max New Tokens",
                            info="Maximum number of tokens to generate"
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=0.8,
                            step=0.1,
                            label="🌡️ Temperature",
                            info="Higher = more creative, lower = more focused"
                        )

                    with gr.Row():
                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.9,
                            step=0.05,
                            label="🎯 Top-p",
                            info="Nucleus sampling threshold"
                        )
                        top_k = gr.Slider(
                            minimum=0,
                            maximum=200,
                            value=50,
                            step=5,
                            label="📊 Top-k",
                            info="Top-k sampling limit (0 = disabled)"
                        )

                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.05,
                        label="🔄 Repetition Penalty",
                        info="Reduce repetitive text (higher = less repetition)"
                    )

                # Generate Button with enhanced styling
                gr.HTML("<div style='margin: 1.5rem 0;'>")
                generate_btn = gr.Button(
                    "🚀 Generate Text",
                    variant="primary",
                    size="lg",
                    elem_classes=["generate-btn"],
                    scale=1
                )
                gr.HTML("</div>")

                # Quick Settings Presets
                gr.HTML("<div style='margin-top: 1rem;'><h4 style='color: #667eea; margin-bottom: 0.5rem;'>⚡ Quick Presets</h4></div>")
                with gr.Row():
                    creative_btn = gr.Button("🎨 Creative", size="sm", variant="secondary")
                    balanced_btn = gr.Button("⚖️ Balanced", size="sm", variant="secondary")
                    precise_btn = gr.Button("🎯 Precise", size="sm", variant="secondary")

            # Output Column
            with gr.Column(scale=3, min_width=500):
                gr.HTML("<div style='margin-bottom: 1rem; display: flex; justify-content: space-between; align-items: center;'><h3 style='color: #667eea; margin: 0;'>📝 Generated Output</h3></div>")

                output_text = gr.Textbox(
                    lines=22,
                    label="Generated Text",
                    show_copy_button=True,
                    interactive=False,
                    placeholder="Your generated text will appear here...\n\n✨ Streaming in real-time\n🚀 Powered by custom 2B parameter model",
                    elem_classes=["output-box"],
                    container=True
                )

                # Action buttons
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear All", variant="secondary", elem_classes=["clear-btn"])

        # Enhanced Examples Section
        gr.HTML("<div style='margin: 2rem 0;'><h3 style='color: #667eea; text-align: center; margin-bottom: 1rem;'>🎯 Example Prompts</h3></div>")

        with gr.Accordion("📚 Prompt Examples", open=True, elem_classes=["example-box"]):
            gr.Examples(
                examples=[
                    ["Once upon a time in a distant galaxy, there lived a civilization that had never seen the stars."],
                    ["The old lighthouse keeper noticed something strange about the fog that night."],
                    ["In the depths of the Amazon rainforest, Dr. Martinez made a discovery that would change everything."],
                    ["The last bookstore on Earth was about to close its doors forever when"],
                    ["As the spaceship approached the mysterious planet, the crew realized"],
                    ["The clockmaker's shop had been abandoned for fifty years, but every morning at precisely 9 AM"],
                    ["Deep beneath the city, in tunnels forgotten by time, archaeologist Elena found"],
                    ["The message in a bottle had traveled across three oceans before washing ashore"],
                ],
                inputs=[prompt_input],
                label="Click any example to get started!",
                examples_per_page=4
            )

        # Event handlers for main functionality
        generate_btn.click(
            fn=generate_response,
            inputs=[
                prompt_input,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
            ],
            outputs=[output_text],
            show_progress=True,
        )

        # Preset button handlers
        creative_btn.click(
            fn=lambda: (1.2, 0.95, 40, 1.05),
            outputs=[temperature, top_p, top_k, repetition_penalty]
        )

        balanced_btn.click(
            fn=lambda: (0.8, 0.9, 50, 1.1),
            outputs=[temperature, top_p, top_k, repetition_penalty]
        )

        precise_btn.click(
            fn=lambda: (0.3, 0.8, 20, 1.2),
            outputs=[temperature, top_p, top_k, repetition_penalty]
        )

        # Utility button handlers
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[prompt_input, output_text]
        )

    return demo


if __name__ == "__main__":
    # Initialize for local testing
    demo = create_interface()
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=False,
        debug=False,
    )
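Outside the web UI, the same streaming generator can be driven from Python. A hedged usage sketch, assuming the Space's repository root as the working directory and the pinned requirements installed; the first call downloads the checkpoint from `dixisouls/VelocityLM`:

```python
# Hypothetical programmatic use of the streaming generator defined in app.py above.
from app import generate_response

final = ""
for partial in generate_response(
    "The old lighthouse keeper",
    max_new_tokens=64,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
):
    final = partial          # each yield is the full text generated so far
print(final)
```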
requirements.txt
ADDED
@@ -0,0 +1,14 @@
# Gradio app requirements for HuggingFace Spaces
gradio==4.44.0
spaces==0.29.4

# Core ML dependencies
torch==2.2.0
transformers==4.36.0
tokenizers==0.15.0

# Numerical computing
numpy==1.26.4

# Utilities
tqdm==4.66.1
src/inference/inference.py
ADDED
@@ -0,0 +1,231 @@
"""Text generation utilities for the trained model."""

import torch
import torch.nn.functional as F
from typing import List, Optional, Union
from transformers import AutoTokenizer
import logging

logger = logging.getLogger(__name__)


class TextGenerator:
    """Text generation with various decoding strategies."""

    def __init__(self, model, tokenizer, device='cuda'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = 0.9,
        num_return_sequences: int = 1,
        do_sample: bool = True,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """Generate text from prompt(s)."""

        # Handle single string input
        if isinstance(prompt, str):
            prompts = [prompt]
        else:
            prompts = prompt

        # Tokenize prompts
        inputs = self.tokenizer(
            prompts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Generate
        batch_size = input_ids.shape[0]
        generated_ids = input_ids.clone()

        for _ in range(max_length - input_ids.shape[1]):
            # Get model predictions
            outputs = self.model(
                input_ids=generated_ids,
                attention_mask=attention_mask,
            )

            # Get logits for the last token
            next_token_logits = outputs.logits[:, -1, :]

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated_ids[i].tolist()):
                        next_token_logits[i, token_id] /= repetition_penalty

            # Apply temperature
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Apply top-k filtering
            if top_k is not None:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            # Apply top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample from the distribution
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # Append to generated sequence
            generated_ids = torch.cat([generated_ids, next_tokens.unsqueeze(1)], dim=1)

            # Update attention mask
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((batch_size, 1), device=self.device)
            ], dim=1)

            # Check for EOS token
            if (next_tokens == self.tokenizer.eos_token_id).all():
                break

        # Decode generated sequences
        generated_texts = []
        for i in range(batch_size):
            generated_text = self.tokenizer.decode(
                generated_ids[i],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            generated_texts.append(generated_text)

        return generated_texts

    def beam_search(
        self,
        prompt: str,
        max_length: int = 100,
        num_beams: int = 4,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
    ) -> str:
        """Generate text using beam search."""
        # Implementation of beam search
        # This is a simplified version - full implementation would be more complex

        inputs = self.tokenizer(
            prompt,
            return_tensors='pt',
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # For now, fallback to greedy decoding
        return self.generate(
            prompt,
            max_length=max_length,
            do_sample=False,
            num_return_sequences=1
        )[0]


def load_generator(checkpoint_path: str, device: str = 'cuda'):
    """Load model and create generator."""
    import yaml
    from pathlib import Path
    import sys
    sys.path.append(str(Path(__file__).parent.parent.parent))

    from src.model.transformer import TransformerForCausalLM

    # Load config
    config_path = Path(checkpoint_path) / 'config.json'
    with open(config_path, 'r') as f:
        import json
        config = json.load(f)

    # Create model config
    class ModelConfig:
        def __init__(self, config_dict):
            for key, value in config_dict.items():
                setattr(self, key, value)

    model_config = ModelConfig(config['model'])

    # Load model
    model = TransformerForCausalLM(model_config)
    state_dict = torch.load(
        Path(checkpoint_path) / 'pytorch_model.bin',
        map_location=device
    )
    model.load_state_dict(state_dict)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name'])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create generator
    generator = TextGenerator(model, tokenizer, device)

    return generator


if __name__ == '__main__':
    """Example usage."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--prompt', type=str, required=True, help='Input prompt')
    parser.add_argument('--max-length', type=int, default=100, help='Maximum generation length')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=50, help='Top-k filtering')
    parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) filtering')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')

    args = parser.parse_args()

    # Load generator
    print("Loading model...")
    generator = load_generator(args.checkpoint, args.device)

    # Generate text
    print(f"Prompt: {args.prompt}")
    print("Generating...")

    generated = generator.generate(
        args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    print(f"Generated: {generated[0]}")
src/model/layers.py
ADDED
@@ -0,0 +1,86 @@
"""Custom layers for the transformer model."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

import warnings
warnings.filterwarnings("ignore")

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build cached cos/sin
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

    @staticmethod
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
        cos = cos[position_ids].unsqueeze(1)
        sin = sin[position_ids].unsqueeze(1)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed


class SwiGLU(nn.Module):
    """SwiGLU activation function."""

    def __init__(self, hidden_size, intermediate_size, hidden_act="silu"):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = F.silu if hidden_act == "silu" else F.gelu

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
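A quick shape check for these layers; illustrative only, with arbitrary small sizes rather than the Space's 2,048-dimensional configuration, and assuming the repository root is the working directory so `src.model.layers` is importable:

```python
import torch
from src.model.layers import RMSNorm, RotaryEmbedding, SwiGLU

x = torch.randn(2, 16, 64)                    # [batch, seq_len, hidden]

print(RMSNorm(64)(x).shape)                   # torch.Size([2, 16, 64]); per-token RMS scaling

rope = RotaryEmbedding(dim=32, max_position_embeddings=128)
cos, sin = rope(x, seq_len=16)
print(cos.shape, sin.shape)                   # torch.Size([16, 32]) each: one rotation per position

mlp = SwiGLU(hidden_size=64, intermediate_size=256)
print(mlp(x).shape)                           # torch.Size([2, 16, 64]); gated expand-then-project
```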
src/model/transformer.py
ADDED
@@ -0,0 +1,281 @@
"""State-of-the-art Transformer model implementation."""

import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass

import warnings
warnings.filterwarnings("ignore")

from .layers import RMSNorm, RotaryEmbedding, SwiGLU


@dataclass
class ModelOutput:
    """Model output container."""
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with causal mask and RoPE."""

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0

        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.hidden_size = config.hidden_size

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = q.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = self.rotary_emb.apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Flash attention or standard attention
        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None


class TransformerBlock(nn.Module):
    """Transformer block with RMSNorm and SwiGLU."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CausalSelfAttention(config)
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.hidden_dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = self.hidden_dropout(hidden_states)
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.hidden_dropout(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, present_key_value


class TransformerModel(nn.Module):
    """Main transformer model."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> torch.Tensor:
        batch_size, seq_length = input_ids.shape

        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Create position IDs
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)

        # Create causal mask
        causal_mask = torch.triu(
            torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device),
            diagonal=1
        ).unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

        if attention_mask is not None:
            # Convert padding mask [batch, seq_len] to 4D [batch, 1, 1, seq_len]
            # and combine with causal mask
            expanded_mask = attention_mask[:, None, None, :]  # [batch, 1, 1, seq_len]
            expanded_mask = (1.0 - expanded_mask) * -10000.0  # Convert 0s to -inf
            attention_mask = expanded_mask + causal_mask.expand(input_ids.shape[0], -1, -1, -1)
        else:
            attention_mask = causal_mask.expand(input_ids.shape[0], -1, -1, -1)

        # Forward through layers
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states, _ = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    False,
                    use_reentrant=False,
                )
            else:
                hidden_states, _ = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=None,
                    use_cache=False,
                )

        hidden_states = self.norm(hidden_states)
        return hidden_states


class TransformerForCausalLM(nn.Module):
    """Transformer model with language modeling head."""

    def __init__(self, config):
        super().__init__()
        self.model = TransformerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> ModelOutput:
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return ModelOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=None,
        )

    def gradient_checkpointing_enable(self):
        self.model.gradient_checkpointing = True
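A minimal smoke test for the model class, assuming a config object that simply exposes the attributes referenced in the code above; the field names come from that code, while the tiny sizes are placeholders rather than the deployed 2B configuration:

```python
import torch
from types import SimpleNamespace
from src.model.transformer import TransformerForCausalLM

# Hypothetical config: only the attributes the modules above actually read.
cfg = SimpleNamespace(
    vocab_size=50257, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=256, hidden_act="silu",
    max_position_embeddings=128, rope_theta=10000,
    attention_dropout=0.0, hidden_dropout=0.0,
    layer_norm_eps=1e-6, initializer_range=0.02,
)

model = TransformerForCausalLM(cfg)
ids = torch.randint(0, cfg.vocab_size, (1, 16))      # one sequence of 16 token ids
out = model(input_ids=ids, labels=ids)               # labels trigger the shifted LM loss
print(out.logits.shape, out.loss.item())              # torch.Size([1, 16, 50257]) and a scalar loss
```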