Files changed (1)
  1. app.py +133 -101
app.py CHANGED
@@ -1,124 +1,156 @@
 import gradio as gr
 from datasets import load_dataset
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    Trainer,
-    TrainingArguments,
-    DataCollatorForLanguageModeling,
-)
-import torch
-
-
-# Map specialization → dataset + base model
-SPECIALIZATIONS = {
-    "Coding Assistant": {
-        "dataset": "codeparrot/github-code",
-        "model": "EleutherAI/gpt-neo-125M",
-    },
-    "Cybersecurity Helper": {
-        "dataset": "wikitext",
-        "model": "distilgpt2",  # placeholder dataset, replace with cybersecurity text later
-    },
-    "App/Web Developer": {
-        "dataset": "wikitext",
-        "model": "gpt2",
-    },
-    "General Problem Solver": {
-        "dataset": "wikitext",
-        "model": "gpt2",
-    },
-}
-
-
-def train_model(specialization, epochs, lr):
-    try:
-        spec = SPECIALIZATIONS.get(specialization, SPECIALIZATIONS["General Problem Solver"])
-
-        dataset_name = spec["dataset"]
-        model_name = spec["model"]
 
-        # Load dataset
         dataset = load_dataset(dataset_name)
-
-        # Load tokenizer & model
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name)
 
-        def tokenize_function(examples):
-            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
-
-        tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
-
-        # Data collator
-        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-        # Training args
         training_args = TrainingArguments(
-            output_dir="./results",
-            eval_strategy="epoch",
-            learning_rate=lr,
             per_device_train_batch_size=2,
-            per_device_eval_batch_size=2,
-            num_train_epochs=epochs,
-            weight_decay=0.01,
-            save_strategy="no",
-            logging_dir="./logs",
-            logging_steps=10,
-        )
-
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"],
-            tokenizer=tokenizer,
-            data_collator=data_collator,
-        )
 
         trainer.train()
-
-        return f"✅ Training complete for {specialization} model ({model_name}) with {epochs} epochs, lr={lr}"
     except Exception as e:
-        return f"❌ Error: {str(e)}"
 
-
-# Inference / Chat Function
-def chat_fn(prompt, specialization):
     try:
-        spec = SPECIALIZATIONS.get(specialization, SPECIALIZATIONS["General Problem Solver"])
-        model_name = spec["model"]
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForCausalLM.from_pretrained(model_name)
 
         inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(**inputs, max_length=200)
-        return tokenizer.decode(outputs[0], skip_special_tokens=True)
-    except Exception as e:
-        return f"❌ Chat error: {str(e)}"
 
 
-# Gradio UI
-with gr.Blocks() as demo:
-    gr.Markdown("# 🚀 Custom AI Model Builder & Assistant")
 
-    with gr.Tab("1️⃣ Train Custom Model"):
-        specialization = gr.Radio(
-            list(SPECIALIZATIONS.keys()),
-            label="What do you want your AI to specialize in?",
-            value="General Problem Solver",
-        )
-        epochs = gr.Slider(1, 10, value=1, step=1, label="Training Epochs")
-        lr = gr.Slider(1e-6, 5e-4, value=5e-5, step=1e-6, label="Learning Rate")
-        train_button = gr.Button("🚀 Start Training")
-        output_log = gr.Textbox(label="Training Log")
-        train_button.click(train_model, inputs=[specialization, epochs, lr], outputs=output_log)
-
-    with gr.Tab("2️⃣ Chat with Your Model"):
-        chat_specialization = gr.Dropdown(list(SPECIALIZATIONS.keys()), value="General Problem Solver", label="Model Type")
-        prompt = gr.Textbox(label="Ask me anything", placeholder="Type your question here...")
-        chat_button = gr.Button("💬 Generate Response")
         chat_output = gr.Textbox(label="Response")
-        chat_button.click(chat_fn, inputs=[prompt, chat_specialization], outputs=chat_output)
 
-demo.launch(server_name="0.0.0.0", server_port=7860)
+import os
+import json
 import gradio as gr
+from datetime import datetime
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
 from datasets import load_dataset
 
+# ========= MEMORY MANAGEMENT =========
+MEMORY_DIR = "memories"
+MODEL_DIR = "models"
+
+os.makedirs(MEMORY_DIR, exist_ok=True)
+os.makedirs(MODEL_DIR, exist_ok=True)
+
+def get_memory_file(model_name):
+    safe_name = model_name.replace("/", "_")
+    return os.path.join(MEMORY_DIR, f"{safe_name}_memory.json")
+
+def load_memory(model_name):
+    filepath = get_memory_file(model_name)
+    if os.path.exists(filepath):
+        with open(filepath, "r") as f:
+            return json.load(f)
+    return []
+
+def save_memory(model_name, memory_data):
+    filepath = get_memory_file(model_name)
+    with open(filepath, "w") as f:
+        json.dump(memory_data, f, indent=2)
+
+def append_memory(model_name, role, content):
+    memory = load_memory(model_name)
+    memory.append({
+        "timestamp": datetime.now().isoformat(),
+        "role": role,
+        "content": content
+    })
+    save_memory(model_name, memory)
+
+def clear_memory(model_name):
+    filepath = get_memory_file(model_name)
+    if os.path.exists(filepath):
+        os.remove(filepath)
+    return f"Memory cleared for {model_name}."
+
+def download_memory(model_name):
+    filepath = get_memory_file(model_name)
+    if os.path.exists(filepath):
+        return filepath
+    return None
+
+def upload_memory(model_name, file_obj):
+    if file_obj is None:
+        return "No file uploaded."
+    # read through a context manager so the upload's file handle is closed
+    with open(file_obj.name) as f:
+        new_data = json.load(f)
+    save_memory(model_name, new_data)
+    return f"Memory replaced for {model_name}."
+
+def merge_memory(model_name, file_obj):
+    if file_obj is None:
+        return "No file uploaded."
+    current = load_memory(model_name)
+    with open(file_obj.name) as f:
+        new_data = json.load(f)
+    merged = current + new_data
+    save_memory(model_name, merged)
+    return f"Memory merged for {model_name}."
+
+# ========= MODEL MANAGEMENT =========
+def train_model(model_name, dataset_name, epochs, output_dir):
+    try:
         dataset = load_dataset(dataset_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name)
 
+        # GPT-2-style checkpoints ship without a pad token; fall back to EOS
+        # so that padding="max_length" below does not raise
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        def tokenize(batch):
+            return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)
 
+        dataset = dataset.map(tokenize, batched=True)
         training_args = TrainingArguments(
+            output_dir=output_dir,
+            overwrite_output_dir=True,
             per_device_train_batch_size=2,
+            num_train_epochs=int(epochs),
+            save_strategy="epoch",
+            logging_dir=f"{output_dir}/logs",
         )
 
+        # mlm=False makes the collator copy input_ids into labels, which the
+        # Trainer needs to compute a causal-LM loss
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+        trainer = Trainer(model=model, args=training_args, train_dataset=dataset["train"], data_collator=data_collator)
         trainer.train()
+        model.save_pretrained(output_dir)
+        tokenizer.save_pretrained(output_dir)
+        return f"Training complete. Model saved to {output_dir}"
     except Exception as e:
+        return f"Error: {str(e)}"
 
+def chat_with_model(model_name, prompt):
+    try:
+        # prefer a locally fine-tuned checkpoint under MODEL_DIR, else fall back to the Hub
+        model_path = os.path.join(MODEL_DIR, model_name.replace("/", "_"))
+        if os.path.exists(model_path):
+            model = AutoModelForCausalLM.from_pretrained(model_path)
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         inputs = tokenizer(prompt, return_tensors="pt")
+        outputs = model.generate(**inputs, max_length=256, do_sample=True, temperature=0.7)
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        append_memory(model_name, "user", prompt)
+        append_memory(model_name, "assistant", response)
+
+        return response
+    except Exception as e:
+        return f"Error: {str(e)}"
 
+# ========= INTERFACE =========
+with gr.Blocks() as demo:
+    gr.Markdown("# 🤖 My AI Model Builder\nTrain, fine-tune, test, and manage AI models with memory.")
+
+    with gr.Tab("Train Model"):
+        model_name = gr.Textbox(label="Base Model (Hugging Face Hub ID)", value="gpt2")
+        dataset_name = gr.Textbox(label="Dataset Name (Hugging Face Dataset ID)", value="wikitext")
+        epochs = gr.Number(label="Epochs", value=1, precision=0)
+        output_dir = gr.Textbox(label="Output Directory", value="models/custom_model")
+        train_btn = gr.Button("Train Model")
+        train_output = gr.Textbox(label="Training Status")
+        train_btn.click(train_model, inputs=[model_name, dataset_name, epochs, output_dir], outputs=train_output)
+
+    with gr.Tab("Test Models / Chat"):
+        chat_model = gr.Textbox(label="Model Name", value="gpt2")
+        user_prompt = gr.Textbox(label="Enter Prompt")
+        chat_btn = gr.Button("Chat")
         chat_output = gr.Textbox(label="Response")
+        chat_btn.click(chat_with_model, inputs=[chat_model, user_prompt], outputs=chat_output)
+
+    with gr.Tab("Memory Management"):
+        mem_model = gr.Textbox(label="Model Name", value="gpt2")
+        view_btn = gr.Button("View Memory")
+        memory_output = gr.JSON(label="Memory Log")
+        view_btn.click(load_memory, inputs=[mem_model], outputs=memory_output)
+
+        with gr.Row():
+            dl_btn = gr.Button("Download Memory")
+            up_btn = gr.File(label="Upload Memory JSON")
+            merge_btn = gr.File(label="Merge Memory JSON")
+
+        dl_file = gr.File()
+        dl_btn.click(download_memory, inputs=[mem_model], outputs=dl_file)
+        up_btn.upload(upload_memory, inputs=[mem_model, up_btn], outputs=memory_output)
+        merge_btn.upload(merge_memory, inputs=[mem_model, merge_btn], outputs=memory_output)
+
+        clear_btn = gr.Button("Clear Memory")
+        clear_btn.click(clear_memory, inputs=[mem_model], outputs=memory_output)
 
+demo.launch()
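
One caveat with the Train tab defaults: `wikitext` is a multi-config dataset on the Hub, so `load_dataset("wikitext")` raises an error asking for a config name. A minimal sketch of a loader that would accept an optional config from the same textbox; the `name:config` convention and the `load_dataset_flexible` helper are illustrative, not part of this change:

```python
from datasets import load_dataset

def load_dataset_flexible(spec):
    """Accept either "dataset" or "dataset:config" from the UI textbox."""
    if ":" in spec:
        name, config = spec.split(":", 1)
        return load_dataset(name, config)
    return load_dataset(spec)

# e.g. the small raw WikiText variant:
dataset = load_dataset_flexible("wikitext:wikitext-2-raw-v1")
```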
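Because `chat_with_model` checks `MODEL_DIR` before falling back to the Hub, a model fine-tuned into the default Output Directory `models/custom_model` is addressed from the chat tab by its folder name rather than a Hub ID:

```python
# "custom_model" resolves to models/custom_model and loads the weights
# saved by train_model; any unknown name falls through to the Hub.
response = chat_with_model("custom_model", "Write a haiku about gradients.")
```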
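Memory is persisted per model as a flat JSON list of `{timestamp, role, content}` records, but `chat_with_model` only writes it; nothing is fed back at generation time, so every prompt is answered statelessly. A sketch of how the stored history could be folded into the prompt, reusing the `load_memory` defined above; the `build_prompt` helper and the role-prefix format are illustrative assumptions:

```python
def build_prompt(model_name, user_prompt, max_turns=5):
    """Prefix the new prompt with the last few stored exchanges."""
    history = load_memory(model_name)[-2 * max_turns:]  # user/assistant pairs
    lines = [f"{m['role']}: {m['content']}" for m in history]
    lines.append(f"user: {user_prompt}")
    lines.append("assistant:")
    return "\n".join(lines)
```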
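`upload_memory` and `merge_memory` save whatever JSON they are handed. A small guard, illustrative rather than part of this change, would keep a malformed upload from corrupting the log that `load_memory` consumers expect:

```python
def validate_memory(data):
    """Require a list of {timestamp, role, content}-style records."""
    if not isinstance(data, list):
        raise ValueError("memory must be a JSON list")
    for record in data:
        if not isinstance(record, dict) or not {"role", "content"} <= record.keys():
            raise ValueError(f"malformed memory record: {record!r}")
    return data
```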