Spaces:
Sleeping
Sleeping
import html
import re
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face Hub repository of the checkpoint served by this app.
model_id = "Tesslate/WEBGEN-4B-Preview"

# Load the tokenizer and the weights exactly once, at import time, so every
# request reuses them instead of reloading per call. Weights are loaded in
# bfloat16, and device_map="auto" lets transformers decide device placement.
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)
def _extract_code(text):
    """Return the HTML portion of raw model output *text*.

    Strategies, tried in order:
      1. the contents of the first ```html fence (the closing fence is
         optional, so a half-streamed block is still shown);
      2. the span from the first ``<!DOCTYPE`` -- or, failing that, the first
         ``<html`` -- through the first ``</html>`` (or to the end of the text
         while the document is still being streamed). Matching is
         case-insensitive, so a lowercase ``<!doctype html>`` is kept;
      3. *text* unchanged when no HTML marker has appeared yet.
    """
    fence = "```html"
    fence_at = text.find(fence)
    if fence_at != -1:
        start = fence_at + len(fence)
        end = text.find("```", start)
        body = text[start:] if end == -1 else text[start:end]
        return body.strip()

    for marker in (r"<!doctype", r"<html"):
        opening = re.search(marker, text, re.IGNORECASE)
        if opening is None:
            continue
        document = text[opening.start():]
        # Drop any trailing commentary the model writes after the document.
        closing = re.search(r"</html\s*>", document, re.IGNORECASE)
        if closing is not None:
            document = document[:closing.end()]
        return document.strip()

    return text


def generate_code(prompt):
    """Stream HTML generated by the WEBGEN model for *prompt*.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``. After every decoded chunk, the text produced so
    far is passed through ``_extract_code`` and yielded, so the Gradio code
    view updates live.

    Args:
        prompt: Natural-language description of the page to build. It is
            tokenized as plain text (no chat template is applied).
            NOTE(review): confirm against the model card whether a chat
            template would work better for this checkpoint.

    Yields:
        Best-effort HTML extracted from the output generated so far
        (an empty string for a blank prompt).

    Raises:
        Exception: whatever ``model.generate`` raised on the worker thread,
            re-raised here so Gradio reports it instead of hanging.
    """
    if not prompt or not prompt.strip():
        # Nothing to condition on; don't call generate() with an empty input.
        yield ""
        return

    inputs = tok(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True streams only newly generated tokens, so the user's
    # prompt never leaks into the extracted code or the preview.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=10000,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tok.eos_token_id,
        streamer=streamer,
    )

    errors = []

    def _worker():
        # Run generation; on failure record the error and close the stream.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the caller's thread below
            errors.append(exc)
            # Without this, nothing ends the streamer and the consumer loop
            # below would block forever waiting for more text.
            streamer.end()

    thread = Thread(target=_worker)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield _extract_code(generated_text)
    thread.join()

    if errors:
        raise errors[0]
def _render_preview(code):
    """Wrap generated HTML in a sandboxed iframe for the Preview pane.

    Rendering the raw document in ``gr.HTML`` places it inside the Gradio
    page itself. Its ``<style>`` rules can then restyle the app, and script
    tags inserted as markup are not executed (so, e.g., a Tailwind CDN script
    would not run). ``srcdoc`` renders the page as a separate document.
    ``sandbox="allow-scripts"`` (without ``allow-same-origin``) lets its
    scripts run in an opaque origin that cannot reach the host page.

    Args:
        code: HTML text produced by ``generate_code`` (may be empty).

    Returns:
        An ``<iframe>`` snippet, or ``""`` when there is nothing to show.
    """
    if not code:
        return ""
    # quote=True escapes '"' so the document cannot break out of srcdoc="...".
    escaped = html.escape(code, quote=True)
    return (
        '<iframe sandbox="allow-scripts" '
        'style="width:100%;height:600px;border:0;" '
        f'srcdoc="{escaped}"></iframe>'
    )


with gr.Blocks() as demo:
    gr.Markdown("# 🧪 Text-to-Code Generator")
    gr.Markdown("Generate HTML code from natural language prompts with WEBGEN-4B Preview model")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                value="Make a single-file landing page for 'LatticeDB'. Style: modern, generous whitespace, Tailwind, rounded-xl, soft gradients. Sections: navbar, hero (headline + 2 CTAs), features grid, pricing (3 tiers), FAQ accordion, footer. Constraints: semantic HTML, no external JS.",
                lines=5,
                max_lines=10,
            )
            generate_button = gr.Button("Generate Code")
            code_output = gr.Code(label="Generated HTML", language="html", lines=20, interactive=False)
        with gr.Column():
            html_output = gr.HTML(label="Preview")

    # Stream the code into the editor; once generation finishes, render the
    # final code in the sandboxed preview.
    generate_button.click(
        fn=generate_code,
        inputs=prompt_input,
        outputs=code_output,
    ).then(
        fn=_render_preview,
        inputs=code_output,
        outputs=html_output,
    )

if __name__ == "__main__":
    demo.launch()