Files changed (1)
  1. app.py +103 -117
app.py CHANGED
@@ -1,138 +1,124 @@
  import gradio as gr
- from datasets import load_dataset, Dataset
  from transformers import (
      AutoTokenizer,
      AutoModelForCausalLM,
      Trainer,
      TrainingArguments,
-     pipeline
  )
- import os
-
- # -------------------------
- # Helpers
- # -------------------------
- def get_dataset(dataset_name, config_name=None, user_file=None):
      try:
-         if user_file is not None:
-             with open(user_file, "r", encoding="utf-8") as f:
-                 text_data = f.read().splitlines()
-             return Dataset.from_dict({"text": text_data})
-         elif dataset_name:
-             return load_dataset(dataset_name, config_name, split="train")
-     except Exception as e:
-         return None
-     return None
-
- def train_model(model_name, dataset_name, config_name, user_file, output_dir, epochs, lr):
-     dataset = get_dataset(dataset_name, config_name, user_file)
-     if dataset is None:
-         return "❌ Error: Could not load dataset. Check name or file.", None
-
-     # Tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-     # Fix GPT-2 style models (no pad token)
-     if tokenizer.pad_token is None:
-         tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-
-     def tokenize_function(examples):
-         text_key = "text" if "text" in examples else list(examples.keys())[0]
-         return tokenizer(examples[text_key],
-                          truncation=True,
-                          padding="max_length",
-                          max_length=128)
-
-     tokenized_dataset = dataset.map(tokenize_function, batched=True)
-
-     # Model
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-     model.resize_token_embeddings(len(tokenizer))
-
-     # Training args
-     training_args = TrainingArguments(
-         output_dir=output_dir,
-         evaluation_strategy="no",
-         learning_rate=float(lr),
-         per_device_train_batch_size=2,
-         num_train_epochs=int(epochs),
-         weight_decay=0.01,
-         save_strategy="epoch",
-         logging_dir=os.path.join(output_dir, "logs"),
-         push_to_hub=False
-     )
-
-     trainer = Trainer(
-         model=model,
-         args=training_args,
-         train_dataset=tokenized_dataset
-     )
-
-     try:
          trainer.train()
-         model.save_pretrained(output_dir)
-         tokenizer.save_pretrained(output_dir)
-         return f"✅ Training complete! Model saved to `{output_dir}`", output_dir
-     except Exception as e:
-         return f"❌ Training failed: {str(e)}", None
-
- # -------------------------
- # Chat interface
- # -------------------------
- chat_history = []
-
- def chat_with_model(user_input, model_dir):
-     global chat_history
-     if not model_dir or not os.path.exists(model_dir):
-         return "⚠️ No trained model found. Please train first."
-
      try:
-         generator = pipeline("text-generation", model=model_dir)
-         conversation = " ".join([f"User: {u}\nAI: {a}" for u, a in chat_history])
-         prompt = f"{conversation}\nUser: {user_input}\nAI:"
-
-         response = generator(prompt, max_length=200, num_return_sequences=1)[0]["generated_text"]
-         # Extract AI response after last "AI:"
-         ai_reply = response.split("AI:")[-1].strip()
-
-         chat_history.append((user_input, ai_reply))
-         return ai_reply
      except Exception as e:
          return f"❌ Chat error: {str(e)}"
-
- # -------------------------
  # Gradio UI
- # -------------------------
  with gr.Blocks() as demo:
-     gr.Markdown("# 🧠 Personal AI Model Builder\nTrain + Chat with your own AI assistant.")
-
-     with gr.Tab("1️⃣ Train Model"):
-         model_name = gr.Textbox(label="Base Model (e.g. gpt2, distilgpt2)", value="gpt2")
-         dataset_name = gr.Textbox(label="HuggingFace Dataset (optional, e.g. wikitext)", value="wikitext")
-         config_name = gr.Textbox(label="Dataset Config (e.g. wikitext-2-raw-v1)", value="wikitext-2-raw-v1")
-         user_file = gr.File(label="Or Upload Your Own TXT Dataset", file_types=[".txt"], type="filepath")
-         output_dir = gr.Textbox(label="Output Directory", value="./custom_model")
-         epochs = gr.Number(label="Epochs", value=1, precision=0)
-         lr = gr.Textbox(label="Learning Rate", value="5e-5")
-
-         train_button = gr.Button("🚀 Train My Model")
-         train_output = gr.Textbox(label="Training Logs / Status")
-
-         train_button.click(
-             fn=train_model,
-             inputs=[model_name, dataset_name, config_name, user_file, output_dir, epochs, lr],
-             outputs=[train_output, output_dir]
-         )
-
-     with gr.Tab("2️⃣ Chat With Model"):
-         chat_input = gr.Textbox(label="Your Message")
-         chat_button = gr.Button("💬 Send")
-         chat_output = gr.Textbox(label="AI Reply")
-
-         chat_button.click(
-             fn=chat_with_model,
-             inputs=[chat_input, output_dir],
-             outputs=chat_output
          )
-
- demo.launch()
  import gradio as gr
+ from datasets import load_dataset
  from transformers import (
      AutoTokenizer,
      AutoModelForCausalLM,
      Trainer,
      TrainingArguments,
+     DataCollatorForLanguageModeling,
  )
+ import torch
+
+
+ # Map specialization → dataset (with config) + base model
+ SPECIALIZATIONS = {
+     "Coding Assistant": {
+         "dataset": "codeparrot/github-code",
+         "config": None,
+         "model": "EleutherAI/gpt-neo-125M",
+     },
+     "Cybersecurity Helper": {
+         "dataset": "wikitext",
+         "config": "wikitext-2-raw-v1",
+         "model": "distilgpt2",  # placeholder dataset, replace with cybersecurity text later
+     },
+     "App/Web Developer": {
+         "dataset": "wikitext",
+         "config": "wikitext-2-raw-v1",
+         "model": "gpt2",
+     },
+     "General Problem Solver": {
+         "dataset": "wikitext",
+         "config": "wikitext-2-raw-v1",
+         "model": "gpt2",
+     },
+ }
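+
+ # Simple in-memory cache of models fine-tuned in this session. Nothing is
+ # written to disk (save_strategy="no" below), so the chat tab reads from here.
+ TRAINED_MODELS = {}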
+
+
+ def train_model(specialization, epochs, lr):
      try:
+         spec = SPECIALIZATIONS.get(specialization, SPECIALIZATIONS["General Problem Solver"])
+
+         dataset_name = spec["dataset"]
+         model_name = spec["model"]
+
+         # Load dataset (wikitext needs an explicit config such as "wikitext-2-raw-v1")
+         dataset = load_dataset(dataset_name, spec.get("config"))
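+         # NOTE: codeparrot/github-code is very large; for a quick demo, a slice
+         # such as load_dataset(dataset_name, split="train[:1%]") is far more practical.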
+
+         # Load tokenizer & model
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(model_name)
+
+         # GPT-2 style tokenizers ship without a pad token; padding="max_length"
+         # below fails without one, so reuse the EOS token for padding.
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         def tokenize_function(examples):
+             # Not every dataset stores its text in a "text" column
+             # (codeparrot/github-code uses "code"), so fall back to the first column.
+             text_key = "text" if "text" in examples else list(examples.keys())[0]
+             return tokenizer(examples[text_key], truncation=True, padding="max_length", max_length=128)
+
+         tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
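+         # dataset is a DatasetDict, so map() tokenizes every split it contains
+         # (train, validation, ...), not just the training split.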
+
+         # Data collator
+         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
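+         # (mlm=False selects the causal-LM objective: the collator copies
+         # input_ids into labels for next-token prediction.)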
+
+         # Training args
+         has_validation = "validation" in tokenized_datasets
+         training_args = TrainingArguments(
+             output_dir="./results",
+             eval_strategy="epoch" if has_validation else "no",  # not every dataset ships a validation split
+             learning_rate=lr,
+             per_device_train_batch_size=2,
+             per_device_eval_batch_size=2,
+             num_train_epochs=epochs,
+             weight_decay=0.01,
+             save_strategy="no",  # keep the demo light: no checkpoints on disk
+             logging_dir="./logs",
+             logging_steps=10,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=tokenized_datasets["train"],
+             eval_dataset=tokenized_datasets["validation"] if has_validation else None,
+             tokenizer=tokenizer,
+             data_collator=data_collator,
+         )
+
          trainer.train()
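+
+         # Keep the fine-tuned weights in memory for the chat tab; with
+         # save_strategy="no" this cache is the only place they survive.
+         TRAINED_MODELS[specialization] = (model, tokenizer)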
+
+         return f"✅ Training complete for {specialization} model ({model_name}) with {epochs} epochs, lr={lr}"
+     except Exception as e:
+         return f"❌ Error: {str(e)}"
 
+
+
+ # Inference / Chat Function
+ def chat_fn(prompt, specialization):
      try:
+         # Prefer weights fine-tuned in this session; otherwise fall back to the base model.
+         if specialization in TRAINED_MODELS:
+             model, tokenizer = TRAINED_MODELS[specialization]
+         else:
+             spec = SPECIALIZATIONS.get(specialization, SPECIALIZATIONS["General Problem Solver"])
+             tokenizer = AutoTokenizer.from_pretrained(spec["model"])
+             model = AutoModelForCausalLM.from_pretrained(spec["model"])
+
+         inputs = tokenizer(prompt, return_tensors="pt")
+         outputs = model.generate(**inputs, max_length=200)
+         return tokenizer.decode(outputs[0], skip_special_tokens=True)
      except Exception as e:
          return f"❌ Chat error: {str(e)}"
+
+
  # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# 🚀 Custom AI Model Builder & Assistant")
+
+     with gr.Tab("1️⃣ Train Custom Model"):
+         specialization = gr.Radio(
+             list(SPECIALIZATIONS.keys()),
+             label="What do you want your AI to specialize in?",
+             value="General Problem Solver",
          )
+         epochs = gr.Slider(1, 10, value=1, step=1, label="Training Epochs")
+         lr = gr.Slider(1e-6, 5e-4, value=5e-5, step=1e-6, label="Learning Rate")
+         train_button = gr.Button("🚀 Start Training")
+         output_log = gr.Textbox(label="Training Log")
+         train_button.click(train_model, inputs=[specialization, epochs, lr], outputs=output_log)
+
+     with gr.Tab("2️⃣ Chat with Your Model"):
+         chat_specialization = gr.Dropdown(list(SPECIALIZATIONS.keys()), value="General Problem Solver", label="Model Type")
+         prompt = gr.Textbox(label="Ask me anything", placeholder="Type your question here...")
+         chat_button = gr.Button("💬 Generate Response")
+         chat_output = gr.Textbox(label="Response")
+         chat_button.click(chat_fn, inputs=[prompt, chat_specialization], outputs=chat_output)
+
+ demo.launch(server_name="0.0.0.0", server_port=7860)