Files changed (1) hide show
  1. app.py +106 -79
app.py CHANGED
@@ -1,111 +1,138 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
  from datasets import load_dataset, Dataset
4
-
5
- # --------------------------
6
- # Dataset Loader
7
- # --------------------------
 
 
 
 
 
 
 
 
8
  def get_dataset(dataset_name, config_name=None, user_file=None):
9
- if user_file is not None:
10
- with open(user_file, "r", encoding="utf-8") as f:
11
- text_data = f.read().splitlines()
12
- return Dataset.from_dict({"text": text_data})
13
-
14
- if config_name:
15
- return load_dataset(dataset_name, config_name, split="train")
16
- else:
17
- return load_dataset(dataset_name, split="train")
18
-
19
- # --------------------------
20
- # Training Function
21
- # --------------------------
22
- def train_model(model_name, dataset_name, config_name, user_file, output_dir, epochs=1):
23
  try:
24
- dataset = get_dataset(dataset_name, config_name, user_file)
25
- tokenizer = AutoTokenizer.from_pretrained(model_name)
26
-
27
- def tokenize_function(examples):
28
- return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
29
-
30
- tokenized_dataset = dataset.map(tokenize_function, batched=True)
31
- model = AutoModelForCausalLM.from_pretrained(model_name)
32
-
33
- training_args = TrainingArguments(
34
- output_dir=output_dir,
35
- overwrite_output_dir=True,
36
- per_device_train_batch_size=2,
37
- num_train_epochs=epochs,
38
- save_strategy="epoch",
39
- logging_dir="./logs",
40
- logging_steps=10,
41
- )
42
-
43
- trainer = Trainer(
44
- model=model,
45
- args=training_args,
46
- train_dataset=tokenized_dataset,
47
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
49
  trainer.train()
50
  model.save_pretrained(output_dir)
51
  tokenizer.save_pretrained(output_dir)
 
 
 
52
 
53
- return f"βœ… Training complete! Model saved to {output_dir}"
 
 
 
54
 
55
- except Exception as e:
56
- return f"❌ Error: {str(e)}"
 
 
57
 
58
- # --------------------------
59
- # Chatbot with trained model
60
- # --------------------------
61
- def chat_with_model(user_input, model_name="custom_model"):
62
  try:
63
- tokenizer = AutoTokenizer.from_pretrained(model_name)
64
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
65
 
66
- inputs = tokenizer(user_input, return_tensors="pt")
67
- outputs = model.generate(**inputs, max_length=200)
68
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
69
 
 
 
70
  except Exception as e:
71
- return f"⚠️ Model not ready yet. Error: {str(e)}"
72
 
73
- # --------------------------
74
  # Gradio UI
75
- # --------------------------
76
  with gr.Blocks() as demo:
77
- gr.Markdown("## 🧠 Custom AI Model Trainer + Chatbot")
78
 
79
- with gr.Tab("Train a Model"):
80
  model_name = gr.Textbox(label="Base Model (e.g. gpt2, distilgpt2)", value="gpt2")
81
- dataset_name = gr.Textbox(label="Dataset (e.g. wikitext)", value="wikitext")
82
- config_name = gr.Dropdown(
83
- label="Dataset Config (if needed)",
84
- choices=["", "wikitext-103-raw-v1", "wikitext-103-v1", "wikitext-2-raw-v1", "wikitext-2-v1"],
85
- value=""
86
- )
87
- user_file = gr.File(label="Upload TXT Dataset", file_types=[".txt"], type="filepath")
88
- output_dir = gr.Textbox(label="Output Directory", value="custom_model")
89
- epochs = gr.Slider(1, 5, value=1, step=1, label="Epochs")
90
- train_button = gr.Button("πŸš€ Train Model")
91
  train_output = gr.Textbox(label="Training Logs / Status")
92
 
93
  train_button.click(
94
  fn=train_model,
95
- inputs=[model_name, dataset_name, config_name, user_file, output_dir, epochs],
96
- outputs=train_output
97
  )
98
 
99
- with gr.Tab("Chat with Your Model"):
100
- user_input = gr.Textbox(label="Your Message")
101
- chat_output = gr.Textbox(label="Model Response")
102
- chat_button = gr.Button("πŸ’¬ Chat")
103
 
104
  chat_button.click(
105
  fn=chat_with_model,
106
- inputs=[user_input],
107
  outputs=chat_output
108
  )
109
 
110
- if __name__ == "__main__":
111
- demo.launch()
 
1
  import gradio as gr
 
2
  from datasets import load_dataset, Dataset
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ Trainer,
7
+ TrainingArguments,
8
+ pipeline
9
+ )
10
+ import os
11
+
12
+ # -------------------------
13
+ # Helpers
14
+ # -------------------------
15
  def get_dataset(dataset_name, config_name=None, user_file=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  try:
17
+ if user_file is not None:
18
+ with open(user_file, "r", encoding="utf-8") as f:
19
+ text_data = f.read().splitlines()
20
+ return Dataset.from_dict({"text": text_data})
21
+ elif dataset_name:
22
+ return load_dataset(dataset_name, config_name, split="train")
23
+ except Exception as e:
24
+ return None
25
+ return None
26
+
27
+ def train_model(model_name, dataset_name, config_name, user_file, output_dir, epochs, lr):
28
+ dataset = get_dataset(dataset_name, config_name, user_file)
29
+ if dataset is None:
30
+ return "❌ Error: Could not load dataset. Check name or file.", None
31
+
32
+ # Tokenizer
33
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
34
+
35
+ # Fix GPT-2 style models (no pad token)
36
+ if tokenizer.pad_token is None:
37
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
38
+
39
+ def tokenize_function(examples):
40
+ text_key = "text" if "text" in examples else list(examples.keys())[0]
41
+ return tokenizer(examples[text_key],
42
+ truncation=True,
43
+ padding="max_length",
44
+ max_length=128)
45
+
46
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
47
+
48
+ # Model
49
+ model = AutoModelForCausalLM.from_pretrained(model_name)
50
+ model.resize_token_embeddings(len(tokenizer))
51
+
52
+ # Training args
53
+ training_args = TrainingArguments(
54
+ output_dir=output_dir,
55
+ evaluation_strategy="no",
56
+ learning_rate=float(lr),
57
+ per_device_train_batch_size=2,
58
+ num_train_epochs=int(epochs),
59
+ weight_decay=0.01,
60
+ save_strategy="epoch",
61
+ logging_dir=os.path.join(output_dir, "logs"),
62
+ push_to_hub=False
63
+ )
64
+
65
+ trainer = Trainer(
66
+ model=model,
67
+ args=training_args,
68
+ train_dataset=tokenized_dataset
69
+ )
70
 
71
+ try:
72
  trainer.train()
73
  model.save_pretrained(output_dir)
74
  tokenizer.save_pretrained(output_dir)
75
+ return f"βœ… Training complete! Model saved to `{output_dir}`", output_dir
76
+ except Exception as e:
77
+ return f"❌ Training failed: {str(e)}", None
78
 
79
+ # -------------------------
80
+ # Chat interface
81
+ # -------------------------
82
+ chat_history = []
83
 
84
+ def chat_with_model(user_input, model_dir):
85
+ global chat_history
86
+ if not model_dir or not os.path.exists(model_dir):
87
+ return "⚠️ No trained model found. Please train first."
88
 
 
 
 
 
89
  try:
90
+ generator = pipeline("text-generation", model=model_dir)
91
+ conversation = " ".join([f"User: {u}\nAI: {a}" for u, a in chat_history])
92
+ prompt = f"{conversation}\nUser: {user_input}\nAI:"
93
 
94
+ response = generator(prompt, max_length=200, num_return_sequences=1)[0]["generated_text"]
95
+ # Extract AI response after last "AI:"
96
+ ai_reply = response.split("AI:")[-1].strip()
97
 
98
+ chat_history.append((user_input, ai_reply))
99
+ return ai_reply
100
  except Exception as e:
101
+ return f"❌ Chat error: {str(e)}"
102
 
103
+ # -------------------------
104
  # Gradio UI
105
+ # -------------------------
106
  with gr.Blocks() as demo:
107
+ gr.Markdown("# 🧠 Personal AI Model Builder\nTrain + Chat with your own AI assistant.")
108
 
109
+ with gr.Tab("1️⃣ Train Model"):
110
  model_name = gr.Textbox(label="Base Model (e.g. gpt2, distilgpt2)", value="gpt2")
111
+ dataset_name = gr.Textbox(label="HuggingFace Dataset (optional, e.g. wikitext)", value="wikitext")
112
+ config_name = gr.Textbox(label="Dataset Config (e.g. wikitext-2-raw-v1)", value="wikitext-2-raw-v1")
113
+ user_file = gr.File(label="Or Upload Your Own TXT Dataset", file_types=[".txt"], type="filepath")
114
+ output_dir = gr.Textbox(label="Output Directory", value="./custom_model")
115
+ epochs = gr.Number(label="Epochs", value=1, precision=0)
116
+ lr = gr.Textbox(label="Learning Rate", value="5e-5")
117
+
118
+ train_button = gr.Button("πŸš€ Train My Model")
 
 
119
  train_output = gr.Textbox(label="Training Logs / Status")
120
 
121
  train_button.click(
122
  fn=train_model,
123
+ inputs=[model_name, dataset_name, config_name, user_file, output_dir, epochs, lr],
124
+ outputs=[train_output, output_dir]
125
  )
126
 
127
+ with gr.Tab("2️⃣ Chat With Model"):
128
+ chat_input = gr.Textbox(label="Your Message")
129
+ chat_button = gr.Button("πŸ’¬ Send")
130
+ chat_output = gr.Textbox(label="AI Reply")
131
 
132
  chat_button.click(
133
  fn=chat_with_model,
134
+ inputs=[chat_input, output_dir],
135
  outputs=chat_output
136
  )
137
 
138
+ demo.launch()