Files changed (1) hide show
  1. app.py +108 -38
app.py CHANGED
@@ -1,51 +1,121 @@
1
  import gradio as gr
2
- from datasets import load_dataset
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
 
 
 
 
 
 
 
 
4
 
5
def train_model(model_name, dataset_name, text_column="text"):
    """Fine-tune a causal language model on a Hugging Face dataset.

    Loads the tokenizer and model for ``model_name``, tokenizes the
    ``text_column`` of ``dataset_name``, trains for one epoch on the
    ``"train"`` split and saves model and tokenizer to ``./custom_model``.

    Args:
        model_name: Model id or path passed to ``from_pretrained``.
        dataset_name: Dataset id passed to ``load_dataset``.
        text_column: Name of the column holding the raw text.

    Returns:
        A status string for the UI: a success message, or an error message
        carrying the exception text. Exceptions are not propagated, so the
        caller always receives something it can display.
    """
    # Local import keeps this function independent of the module-level
    # import list, which does not bring in the collator.
    from transformers import DataCollatorForLanguageModeling

    save_dir = "./custom_model"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # GPT-2 style tokenizers define no pad token, and padding raises
        # without one; reusing EOS is the usual workaround.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        dataset = load_dataset(dataset_name)

        def tokenize(batch):
            return tokenizer(batch[text_column], padding="max_length", truncation=True)

        tokenized = dataset.map(tokenize, batched=True)

        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=1,
            per_device_train_batch_size=2,
            save_steps=10,
            save_total_limit=1,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized["train"],
            # Causal-LM collator derives `labels` from `input_ids`; without
            # labels the model returns no loss for the Trainer to optimise.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )

        trainer.train()
        model.save_pretrained(save_dir)
        # Save the tokenizer next to the weights so the directory can be
        # reloaded on its own.
        tokenizer.save_pretrained(save_dir)

        return f"✅ Training complete! Model saved at {save_dir}"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return f"❌ Error: {e}"
37
 
38
# Single-page UI: two text inputs, a button that runs train_model, and a
# status box that shows the string train_model returns.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 AI Model Builder (Hugging Face Space)")
    model_name = gr.Textbox(value="gpt2", label="Base model")
    # NOTE(review): the "wikitext" dataset is published in several
    # configurations and load_dataset() normally needs one named explicitly;
    # confirm this default actually loads.
    dataset_name = gr.Textbox(value="wikitext", label="Dataset name (from HF Datasets)")
    train_button = gr.Button("Train Model")
    output = gr.Textbox(label="Status")

    train_button.click(
        fn=train_model,
        inputs=[model_name, dataset_name],
        outputs=output,
    )

# Training runs far longer than a normal request; the queue keeps the call
# alive instead of depending on a single long HTTP request.
demo.queue()
demo.launch()
 
1
  import gradio as gr
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer,
5
+ Trainer,
6
+ TrainingArguments,
7
+ DataCollatorForLanguageModeling,
8
+ )
9
+ from datasets import load_dataset, Dataset
10
+ import torch
11
+ import os
12
 
13
# Default directory for the fine-tuned model: it pre-fills the
# "Output Directory" field and is the path chat_with_model() loads from.
MODEL_DIR = "./custom_model"
 
 
15
 
16
+ # ---------- Dataset Handling ----------
17
def get_dataset(dataset_name, config_name=None, user_file=None):
    """Return the training data, preferring an uploaded file over the Hub.

    Args:
        dataset_name: Dataset id passed to ``load_dataset``. Ignored when
            ``user_file`` is given.
        config_name: Optional dataset configuration name. ``None``, empty or
            whitespace-only values mean "no configuration".
        user_file: Optional uploaded text file, either a path string or an
            object whose ``name`` attribute holds the path.

    Returns:
        For an upload, a ``Dataset`` with a single ``"text"`` column holding
        one row per non-blank line. Otherwise whatever ``load_dataset``
        returns for the requested dataset.

    Raises:
        ValueError: If the uploaded file contains no non-blank lines.
    """
    if user_file is not None:
        # Depending on how the upload widget is configured the value is a
        # temp-file wrapper (path in `.name`) or already a plain path string.
        path = getattr(user_file, "name", user_file)
        with open(path, "r", encoding="utf-8") as f:
            # Blank lines would become empty training examples, so drop them.
            text_data = [line for line in f.read().splitlines() if line.strip()]
        if not text_data:
            raise ValueError("Uploaded file contains no text lines to train on.")
        return Dataset.from_dict({"text": text_data})

    # A text box left blank (or holding only spaces) must not be forwarded
    # as a real configuration name.
    config = (config_name or "").strip()
    return load_dataset(dataset_name, config or None)
28
 
29
+ # ---------- Training ----------
30
def train_model(model_name, dataset_name, config_name, user_file, epochs, output_dir):
    """Fine-tune a causal language model and save it to ``output_dir``.

    Args:
        model_name: Model id or path passed to ``from_pretrained``.
        dataset_name: Hub dataset id, used when no file is uploaded.
        config_name: Optional dataset configuration name.
        user_file: Optional uploaded text file; takes precedence over
            ``dataset_name`` (see ``get_dataset``).
        epochs: Number of training epochs. An empty value or anything below 1
            falls back to 1.
        output_dir: Directory for checkpoints, the final model and tokenizer.
            Note that ``chat_with_model`` always loads from ``MODEL_DIR``, so
            a model saved elsewhere is not picked up by the chat tab.

    Returns:
        A success message naming the output directory.

    Raises:
        ValueError: If the dataset is split into named parts but has no
            ``"train"`` split.
    """
    # Local import: the module only imports `load_dataset` and `Dataset`.
    from datasets import DatasetDict

    # Load tokenizer & model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # GPT-2 style tokenizers (including the default distilgpt2) define no pad
    # token, and padding raises without one; reusing EOS is the usual fix.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = get_dataset(dataset_name, config_name, user_file)

    # Tokenize
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Hub datasets arrive as a dict of splits, uploads as a single Dataset.
    # Check the type explicitly: a membership test on a plain Dataset is not
    # a split lookup.
    if isinstance(tokenized_dataset, DatasetDict):
        if "train" not in tokenized_dataset:
            raise ValueError(
                f"Dataset has no 'train' split; available splits: {list(tokenized_dataset)}"
            )
        train_split = tokenized_dataset["train"]
    else:
        train_split = tokenized_dataset

    # Data collator: builds causal-LM labels from input_ids.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # The UI number field yields None when cleared; never train for < 1 epoch.
    num_epochs = max(1, int(epochs)) if epochs else 1

    # Training args. No evaluation strategy is passed: "no" is the default,
    # and the keyword was renamed across transformers releases.
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        save_total_limit=1,
        logging_steps=5,
    )

    # Trainer. The tokenizer is not handed over (that keyword is deprecated
    # in newer releases); the collator is explicit and the tokenizer is saved
    # explicitly below.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    return f"✅ Training complete! Model saved to {output_dir}"
73
+
74
+ # ---------- Chat ----------
75
def chat_with_model(prompt, history):
    """Generate a continuation of ``prompt`` with the model in ``MODEL_DIR``.

    Args:
        prompt: The user's message.
        history: Previous chat turns. Accepted for the UI callback signature
            but not used for generation.

    Returns:
        The decoded model output (prompt followed by its continuation), or a
        warning string when no trained model is available or the prompt is
        empty.
    """
    # Checkpoints may create the directory before the final model is written,
    # so look for the saved config rather than the bare directory.
    if not os.path.isfile(os.path.join(MODEL_DIR, "config.json")):
        return "⚠️ No trained model found yet. Train one first!"
    if not prompt or not prompt.strip():
        # An empty prompt tokenizes to nothing, leaving nothing to continue.
        return "⚠️ Please enter a message."

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)

    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # max_new_tokens budgets the generated part only; max_length also
        # counts the prompt, so a long prompt would leave no room to generate.
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
87
+
88
+ # ---------- UI ----------
89
# Two-tab UI: "Train Model" collects the training settings and runs
# train_model; "Chat with Model" sends messages to chat_with_model.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Custom AI Model Builder")
    gr.Markdown("Train and chat with your **own model** directly in Hugging Face.")

    with gr.Tab("Train Model"):
        model_name = gr.Textbox(label="Base Model (e.g. gpt2, distilgpt2, codeparrot-small)", value="distilgpt2")
        dataset_name = gr.Textbox(label="Dataset Name (HuggingFace hub, e.g. wikitext, imdb)", value="wikitext")
        config_name = gr.Textbox(label="Config (optional, e.g. wikitext-2-raw-v1)", value="wikitext-2-raw-v1")
        # NOTE(review): `type="file"` matches get_dataset(), which reads the
        # upload's `.name`; newer Gradio releases may expect a different
        # `type` value - confirm against the pinned version.
        user_file = gr.File(label="Or Upload Your Own TXT Dataset", file_types=[".txt"], type="file")
        epochs = gr.Number(label="Epochs", value=1, precision=0)
        output_dir = gr.Textbox(label="Output Directory", value=MODEL_DIR)
        train_button = gr.Button("🚀 Start Training")
        train_output = gr.Textbox(label="Training Logs")

        train_button.click(
            train_model,
            inputs=[model_name, dataset_name, config_name, user_file, epochs, output_dir],
            outputs=train_output,
        )

    with gr.Tab("Chat with Model"):
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Message")
        send = gr.Button("Send")

        def respond(message, chat_history):
            """Append one (message, reply) turn and clear the input box."""
            # Work on a copy and tolerate a missing history so the component's
            # own state is never mutated in place.
            chat_history = list(chat_history or [])
            response = chat_with_model(message, chat_history)
            chat_history.append((message, response))
            return "", chat_history

        send.click(respond, [msg, chatbot], [msg, chatbot])
        # Pressing Enter in the message box behaves like the Send button.
        msg.submit(respond, [msg, chatbot], [msg, chatbot])

# Training runs far longer than a normal request; the queue keeps the call
# alive instead of depending on a single long HTTP request.
demo.queue()
demo.launch()