README.md CHANGED
@@ -14,8 +14,6 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
14
 
15
  commands:
16
 
17
- download images: python download.py -i 1 -r 2 -o /home/user/app/image_tmp -z
18
-
19
  pip install git+https://github.com/huggingface/diffusers
20
 
21
  accelerate launch \
@@ -45,39 +43,26 @@ fine tune a trained model: --pretrained_model_name_or_path="./nyc-ad-model/check
45
 
46
  export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
47
 
 
 
 
48
 
49
- pipeline:
50
- # 1 Fully Fine‑tune image model with ZeRO
51
  accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
52
- fully_fine_tine_stablediffusion
53
 
54
- # 2 SFT 120B OSS 语言模型 with QLoRA
55
- lauguage_model_fine_tuning
56
 
57
- # 3 RLHF PPO 120B OSS 语言模型 with QLoRA : 训练 reward model
58
- lauguage_model_fine_tuning
59
 
60
- # 4 distill 120B OSS模型给20B OSS模型
61
- lauguage_model_fine_tuning
62
- 用 Teacher 生成 Response,student模型用LoRA fine tuning
63
 
64
- # 5 Build RAG index embedding table
65
- retrieval_augmented_generation
66
 
67
  # 6 Inference with RAG
68
- inference.py
69
-
70
-
71
- system flow:
72
- input: business or product description text
73
- 1. 根据input用RAG取embedding
74
- 1. GPT‑OSS 生成 4 个广告文案 + 标题 + 口号(可选语气:专业/活泼/极简)
75
- 2. GPT‑OSS 基于选中文案生成 扩展视觉提示词(主体、配色、镜头、艺术风格)
76
- 3. stablediffusion model 生成 4 张草图(可选 ControlNet-Layout/Logo 插入)
77
- 4. 返回4张海报+后处理
78
- output: an advertisement sentence and post image
79
-
80
-
81
- design details:
82
- LoRA fine tune teacher OSS 120B model using smangrul/ad-copy-generation (广告文案生成)
83
- LoRA distill knowledge to OSS 20B model
 
14
 
15
  commands:
16
 
 
 
17
  pip install git+https://github.com/huggingface/diffusers
18
 
19
  accelerate launch \
 
43
 
44
  export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
45
 
46
+ import torch
47
+ torch.cuda.empty_cache()
48
+ torch.cuda.reset_peak_memory_stats()
49
 
50
+ 7/12
51
+ # 1 Fine‑tune image model LoRA+QLoRA
52
  accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
53
+ python train_lora.py
54
 
55
+ # 2 SFT 语言模型
56
+ python sft_train.py
57
 
58
+ # 3 Build RAG index
59
+ python build_embeddings.py
60
 
61
+ # 4 (可选) 收集偏好 训练 reward model
62
+ python reward_model.py
 
63
 
64
+ # 5 PPO RLHF 微调
65
+ python ppo_tune.py
66
 
67
  # 6 Inference with RAG
68
+ python rag_infer.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build_embeddings.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build a FAISS inner-product index over the captions in ``nyc_ads_dataset``.

Writes two files to the current directory: ``prompt.index`` (the FAISS index)
and ``prompt.txt`` (a JSON list of captions in the same order as the index
rows).
"""
import glob
import json
import os

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Sort the paths so the row order of the index is reproducible between runs.
texts = []
for path in sorted(glob.glob("nyc_ads_dataset/*.json")):
    # Use a context manager so each file handle is closed promptly.
    with open(path, encoding="utf-8") as fh:
        texts.append(json.load(fh)["caption"])

if not texts:
    # Without this guard the failure surfaces later as an opaque numpy error.
    raise SystemExit("No caption files found under nyc_ads_dataset/; nothing to index.")

# Encode all captions in one call so the model can batch them internally.
# Embeddings are normalised, so inner product equals cosine similarity.
vecs = np.asarray(model.encode(texts, normalize_embeddings=True), dtype="float32")

index = faiss.IndexFlatIP(vecs.shape[1])
index.add(vecs)
faiss.write_index(index, "prompt.index")
with open("prompt.txt", "w") as fh:
    json.dump(texts, fh)
data_loader/download.py DELETED
@@ -1,209 +0,0 @@
1
- # Author: Marco Lustri 2022 - https://github.com/TheLustriVA
2
- # MIT License
3
-
4
- """A script to make downloading the DiffusionDB dataset easier."""
5
- from urllib.error import HTTPError
6
- from urllib.request import urlretrieve
7
- from alive_progress import alive_bar
8
- from os.path import exists
9
-
10
- import shutil
11
- import os
12
- import time
13
- import argparse
14
-
15
# Module-level settings consumed by main(); each one stays None unless the
# corresponding command-line value is truthy.
index = None
range_max = None
output = None
unzip = None
large = None

parser = argparse.ArgumentParser(description="Download a file from a URL")

# Numeric options: a single file index, or the bounds of a range.
parser.add_argument(
    "-i",
    "--index",
    type=int,
    default=1,
    help="File to download or lower bound of range if -r is set",
)
parser.add_argument(
    "-r",
    "--range",
    type=int,
    default=None,
    help="Upper bound of range if -i is provided",
)

# Destination directory for the downloaded archives.
parser.add_argument(
    "-o", "--output", type=str, default="images", help="Output directory name"
)

# Boolean switches: present on the command line means True.
parser.add_argument(
    "-z",
    "--unzip",
    default=False,
    help="Unzip the file after downloading",
    action="store_true",
)
parser.add_argument(
    "-l",
    "--large",
    default=False,
    help="Download from DiffusionDB Large (14 million images)",
    action="store_true",
)

args = parser.parse_args()

# Copy only truthy values; falsy ones (0, "", False, None) leave the None default.
index = args.index if args.index else None
range_max = args.range if args.range else None
output = args.output if args.output else None
unzip = args.unzip if args.unzip else None
large = args.large if args.large else None
71
-
72
-
73
-
74
def _part_url(baseurl, idx, large):
    """Return the URL of DiffusionDB archive number *idx*."""
    if not large:
        return f"{baseurl}images/part-{idx:06}.zip"
    # DiffusionDB Large is split over two folders; parts above 10000 live in the second.
    folder = "diffusiondb-large-part-1" if idx <= 10000 else "diffusiondb-large-part-2"
    return f"{baseurl}{folder}/part-{idx:06}.zip"


def download(index=1, range_index=0, output="", large=False):
    """
    Download one DiffusionDB archive, or a range of them, to a local directory.

    Reads the module-level ``unzip`` flag to decide whether archives should be
    extracted (single file) or reported back for extraction (range).

    :param index: The index of the file to download, or the first index of the
        range when ``range_index`` is non-zero, defaults to 1 (optional)
    :param range_index: Exclusive upper bound of the range to download; 0 means
        "download only ``index``", defaults to 0 (optional)
    :param output: The directory to download the files to; an empty string
        means the current directory
    :param large: If downloading from DiffusionDB Large (14 million images)
        instead of DiffusionDB 2M (2 million images)
    :return: When the ``unzip`` flag is set, the list of successfully
        downloaded archive paths from a range download (empty for a
        single-file download); otherwise ``None``
    """
    baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
    files_to_unzip = []

    if output != "":
        output = f"{output}/"
        # Only create a directory when one was named: os.makedirs("") raises.
        os.makedirs(output, exist_ok=True)

    if range_index == 0:
        url = _part_url(baseurl, index, large)
        print("Downloading file: ", url)
        file_path = f"{output}part-{index:06}.zip"
        try:
            urlretrieve(url, file_path)
        except HTTPError as e:
            print(f"Encountered an HTTPError downloading file: {url} - {e}")
        else:
            # `unzip` is the boolean CLI flag, not a function; the extraction
            # helper is unzip_file. Only extract when the download succeeded.
            if unzip:
                unzip_file(file_path)
    else:
        # Download the files numbered from index up to (not including) range_index.
        with alive_bar(range_index - index, title="Downloading files") as bar:
            for idx in range(index, range_index):
                url = _part_url(baseurl, idx, large)
                loop_file_path = f"{output}part-{idx:06}.zip"
                try:
                    urlretrieve(url, loop_file_path)
                except HTTPError as e:
                    print(f"HTTPError downloading file: {url} - {e}")
                else:
                    # Record only archives that actually exist on disk, so a
                    # later unzip pass does not trip over missing files.
                    files_to_unzip.append(loop_file_path)
                    with open("manifest.txt", "a") as f:
                        f.write(url + "\n")
                time.sleep(0.1)
                bar()

    # Extraction is left to the caller for ranges because the download itself
    # is already lengthy. Always hand back a list when unzip is requested so
    # the caller never receives None where it expects an iterable.
    if unzip:
        return files_to_unzip
    return None
144
-
145
-
146
def unzip_file(file: str, extract_to: str = None):
    """
    Unpack the archive *file* and describe where it was extracted.

    :param file: str - path to the archive
    :param extract_to: str - destination directory; when omitted, the archive
        path with every '.zip' occurrence removed is used
    :return: A human-readable message naming the archive and the destination
    """
    target = file.replace('.zip', '') if extract_to is None else extract_to
    shutil.unpack_archive(file, target)
    return f"File: {file} has been unzipped to {target}"
159
-
160
-
161
def unzip_all(files: list, extract_to: str = '/home/user/app/images'):
    """
    Unzip every archive in *files* into one directory, showing a progress bar.

    :param files: archive paths to extract; ``None`` or an empty list is a no-op
    :type files: list
    :param extract_to: directory all archives are extracted into; defaults to
        the location this script previously hard-coded
    """
    if not files:
        # Nothing was downloaded (or the caller passed None): skip the bar.
        return
    with alive_bar(len(files), title="Unzipping files") as bar:
        for file in files:
            unzip_file(file, extract_to)
            time.sleep(0.1)
            bar()
173
-
174
-
175
def main(index=None, range_max=None, output=None, unzip=None, large=None):
    """
    Download one archive or a range of archives, optionally unzipping them.

    With both ``index`` and ``range_max`` set, the range is downloaded (after
    a disk-space confirmation for very large ranges) and, if ``unzip`` is
    truthy, the downloaded archives are extracted. With only ``index`` set, a
    single archive is downloaded. Otherwise a message is printed.

    :param index: The index of the file you want to download
    :param range_max: Upper bound passed to ``download`` as ``range_index``
    :param output: The directory to download the files to
    :param unzip: If you want to unzip the files after downloading them, set
        this to True
    :param large: If you want to download from DiffusionDB Large (14 million
        images) instead of DiffusionDB 2M (2 million images)
    :return: None
    """
    if not index:
        print("No index provided")
        return

    if not range_max:
        download(index, output=output, large=large)
        return

    if range_max - index >= 1999:
        confirmation = input("Do you have at least 1.7Tb free: (y/n)")
        if confirmation != "y":
            return

    files = download(index, range_max, output, large)
    # download() may hand back None or an empty list when nothing was
    # fetched; only start the unzip pass when there is something to extract.
    if unzip and files:
        unzip_all(files)


# Run the downloader only when this file is executed as a script, so that
# importing the module does not call main().
if __name__ == "__main__":
    main(index, range_max, output, unzip, large)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data_loader/download_dataset.py DELETED
@@ -1,48 +0,0 @@
1
- import os
2
- import json
3
- import pandas as pd
4
- from datasets import load_dataset
5
- from PIL import Image
6
- import shutil
7
- from tqdm import tqdm
8
-
9
def _first_present(sample, names, default):
    """Return the value of the first key in *names* that *sample* holds, else *default*."""
    for name in names:
        value = sample.get(name)
        if value is not None:
            return value
    return default


def load_and_process():
    """
    Load the first 1000 DiffusionDB training samples and save their images.

    Each image is written to ``processed/images/<idx>.png`` and one metadata
    record per sample is collected.

    :return: list of dicts with keys ``id``, ``image_file``, ``prompt``,
        ``seed``, ``cfg_scale``, ``steps`` and ``sampler``
    """
    dataset = load_dataset("poloclub/diffusiondb", split="train[:1000]")

    os.makedirs("processed/images", exist_ok=True)
    processed_data = []

    for idx, sample in enumerate(tqdm(dataset)):
        image_id = f"{idx:06d}.png"

        if sample.get('image'):
            sample['image'].save(f"processed/images/{image_id}")

        # The abbreviated keys ('p', 'se', ...) are the ones used by the
        # part-*.json metadata files; rows loaded through `datasets` are
        # expected to use the long names. Accept either so the metadata is
        # not silently replaced by defaults.
        # NOTE(review): confirm the column names against the dataset card.
        processed_data.append({
            "id": idx,
            "image_file": image_id,
            "prompt": _first_present(sample, ("prompt", "p"), ''),
            "seed": _first_present(sample, ("seed", "se"), 0),
            "cfg_scale": _first_present(sample, ("cfg", "c"), 0.0),
            "steps": _first_present(sample, ("step", "st"), 0),
            "sampler": _first_present(sample, ("sampler", "sa"), ''),
        })

    return processed_data
33
-
34
def save_data(data):
    """
    Write the metadata records to ``processed/`` as JSON, CSV and Parquet.

    :param data: list of per-sample dicts, as returned by ``load_and_process``
    """
    # Do not rely on an earlier step having created the directory.
    os.makedirs("processed", exist_ok=True)

    with open("processed/data.json", "w") as f:
        json.dump(data, f)

    df = pd.DataFrame(data)
    df.to_csv("processed/data.csv", index=False)
    df.to_parquet("processed/data.parquet", index=False)
41
-
42
def main():
    """Process the dataset, persist the records, and report how many were handled."""
    records = load_and_process()
    save_data(records)
    count = len(records)
    print(f"Processed {count} samples")


# Script entry point.
if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deprecated/image_download.py → image_download.py RENAMED
File without changes
deprecated/image_gen.py → image_gen.py RENAMED
File without changes
lauguage_model_fine_tuning/accelerate_config.yaml DELETED
@@ -1,23 +0,0 @@
1
- # accelerate_config.yaml - 多GPU训练配置
2
-
3
- compute_environment: LOCAL_MACHINE
4
- distributed_type: MULTI_GPU
5
- downcast_bf16: 'no'
6
- gpu_ids: all
7
- machine_rank: 0
8
- main_training_function: main
9
- mixed_precision: fp16
10
- num_machines: 1
11
- num_processes: 4 # 根据GPU数量调整
12
- rdzv_backend: static
13
- same_network: true
14
- tpu_env: []
15
- tpu_use_cluster: false
16
- tpu_use_sudo: false
17
- use_cpu: false
18
-
19
- # RLHF特定设置
20
- gradient_accumulation_steps: 8
21
- gradient_clipping: 1.0
22
- learning_rate: 1e-5
23
- dataloader_drop_last: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/distillation/distill_llm.py DELETED
@@ -1,485 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Teacher-Student知识蒸馏脚本
4
- 将经过SFT+PPO RLHF的Teacher模型蒸馏到更小的Student模型
5
- """
6
-
7
- import os
8
- import torch
9
- import torch.nn.functional as F
10
- from torch.utils.data import DataLoader, Dataset
11
- from transformers import (
12
- AutoModelForCausalLM,
13
- AutoTokenizer,
14
- TrainingArguments,
15
- Trainer,
16
- DataCollatorForLanguageModeling,
17
- logging,
18
- )
19
- from datasets import load_dataset, Dataset as HFDataset
20
- from peft import LoraConfig, get_peft_model, TaskType
21
- import numpy as np
22
- import wandb
23
- from typing import Dict, List, Any, Optional
24
- import json
25
- from tqdm import tqdm
26
- import warnings
27
-
28
- warnings.filterwarnings("ignore")
29
- logging.set_verbosity(logging.CRITICAL)
30
-
31
class DistillationConfig:
    """Settings for the teacher-to-student distillation run (all class-level)."""

    # --- Model locations ---
    teacher_model_path = "./rlhf_teacher_model"  # teacher checkpoint produced by RLHF
    student_model_name = "microsoft/DialoGPT-medium"  # placeholder; swap in the real OpenAI OSS 20B model

    # --- Distillation hyper-parameters ---
    temperature = 4.0  # softmax temperature used for the soft targets
    alpha = 0.7  # weight of the distillation (KL) loss
    beta = 0.3  # weight of the student's own language-modelling loss
    gamma = 0.1  # weight of the feature-matching loss

    # --- Optimisation ---
    learning_rate = 1e-4
    num_train_epochs = 3
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 4
    gradient_accumulation_steps = 8
    warmup_ratio = 0.1
    weight_decay = 0.01
    logging_steps = 50
    eval_steps = 500
    save_steps = 1000

    # --- LoRA (attached to the student to make training cheaper) ---
    use_lora = True
    lora_r = 32
    lora_alpha = 64
    lora_dropout = 0.1

    # --- Data ---
    max_length = 512
    num_distill_samples = 10000  # number of prompts used for distillation

    # --- Output ---
    output_dir = "./distilled_student_model"
    run_name = "teacher-student-distillation"
68
-
69
class DistillationDataset(Dataset):
    """Dataset of teacher-generated (prompt, response, logits) records.

    Each record is a dict with the keys ``prompt``, ``response`` and
    ``teacher_logits``.
    """

    def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
        """Store the records, the tokenizer, and the fixed sequence length."""
        self.data = teacher_outputs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record *idx* and return the tensors needed for training.

        Returns a dict with ``input_ids``, ``attention_mask``, ``labels`` and
        ``teacher_logits``. Label positions where the attention mask is 0 are
        set to -100, the default ``ignore_index`` of the cross-entropy loss.
        """
        item = self.data[idx]

        # Same "### Human / ### Assistant" layout the teacher was prompted with.
        full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"

        encoded = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Drop only the batch dimension; a bare squeeze() would also collapse
        # a sequence dimension of length 1.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Independent copy so masking the labels cannot alter input_ids, and
        # padding positions do not contribute to the language-model loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        teacher_logits = torch.tensor(item["teacher_logits"], dtype=torch.float)
        # Logits saved from a single forward pass carry a leading batch
        # dimension of 1; remove it so a batched tensor is (batch, seq, vocab).
        # NOTE(review): sequences of different length still cannot be stacked
        # by a default collator - confirm how batches are meant to be built.
        if teacher_logits.dim() == 3 and teacher_logits.size(0) == 1:
            teacher_logits = teacher_logits.squeeze(0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "teacher_logits": teacher_logits,
            "labels": labels,
        }
101
-
102
class KnowledgeDistillationTrainer(Trainer):
    """Trainer whose loss mixes the student's LM loss with a KL term against teacher logits."""

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
        """Wrap *student_model* in a Trainer and keep the teacher and loss weights.

        Remaining keyword arguments are forwarded to ``Trainer.__init__``.
        """
        super().__init__(model=student_model, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()  # inference mode for the teacher

        self.temperature = temperature
        self.alpha = alpha  # weight of the distillation loss
        self.beta = beta  # weight of the student's own loss
        self.gamma = gamma  # weight of the feature-matching loss (not used by compute_loss)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute ``beta * student_loss + alpha * distill_loss``.

        ``inputs`` may contain ``labels`` (enables the student loss) and
        ``teacher_logits`` (enables the distillation loss); at least one must
        be present. ``num_items_in_batch`` and extra keyword arguments are
        accepted for compatibility with Trainer versions that pass them, and
        are ignored.

        Returns the total loss, or ``(total_loss, student_outputs)`` when
        ``return_outputs`` is true.

        Raises:
            ValueError: if neither loss term can be computed, or the teacher
                and student vocabulary sizes differ.
        """
        labels = inputs.get("labels")
        teacher_logits = inputs.get("teacher_logits")
        attention_mask = inputs.get("attention_mask")

        # Forward pass of the student; teacher_logits is not a model argument.
        student_outputs = model(**{k: v for k, v in inputs.items() if k not in ["teacher_logits"]})
        student_logits = student_outputs.logits

        losses = {}

        # 1. Standard next-token loss of the student.
        if labels is not None:
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            losses["student_loss"] = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        # 2. Distillation loss: KL divergence to the temperature-softened teacher.
        if teacher_logits is not None:
            teacher_logits = teacher_logits.to(student_logits.device)
            if teacher_logits.shape[-1] != student_logits.shape[-1]:
                raise ValueError(
                    "Teacher and student vocabulary sizes differ: "
                    f"{teacher_logits.shape[-1]} vs {student_logits.shape[-1]}"
                )

            # Align on the shorter sequence length.
            seq_len = min(teacher_logits.shape[1], student_logits.shape[1])
            teacher_slice = teacher_logits[:, :seq_len, :]
            student_slice = student_logits[:, :seq_len, :]

            teacher_probs = F.softmax(teacher_slice / self.temperature, dim=-1)
            student_log_probs = F.log_softmax(student_slice / self.temperature, dim=-1)

            # Per-token KL, averaged over real tokens. "batchmean" on a 3-D
            # tensor would divide by the batch size only, so the loss would
            # grow with sequence length and include padding positions.
            token_kl = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1)
            if attention_mask is not None:
                weights = attention_mask[:, :seq_len].to(token_kl.dtype)
                distill_loss = (token_kl * weights).sum() / weights.sum().clamp(min=1)
            else:
                distill_loss = token_kl.mean()

            losses["distill_loss"] = distill_loss * (self.temperature ** 2)

        if not losses:
            raise ValueError("compute_loss needs 'labels' and/or 'teacher_logits' in inputs")

        # 3. Weighted sum of the available terms.
        total_loss = 0
        if "student_loss" in losses:
            total_loss = total_loss + self.beta * losses["student_loss"]
        if "distill_loss" in losses:
            total_loss = total_loss + self.alpha * losses["distill_loss"]

        # Log each term.
        self.log({
            "train/total_loss": total_loss.item(),
            "train/student_loss": losses["student_loss"].item() if "student_loss" in losses else 0,
            "train/distill_loss": losses["distill_loss"].item() if "distill_loss" in losses else 0,
        })

        return (total_loss, student_outputs) if return_outputs else total_loss
174
-
175
def prepare_student_model(config: DistillationConfig):
    """Load the student model named in *config* and, if enabled, wrap it with LoRA adapters.

    Returns the loaded model, or its PEFT-wrapped version when
    ``config.use_lora`` is true.
    """
    print("🎓 Preparing student model...")

    # Base student weights, loaded in half precision.
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    if not config.use_lora:
        return student_model

    print("🔧 Adding LoRA to student model...")
    # NOTE(review): these module names must exist in the chosen student
    # architecture - verify them against the actual model.
    adapted_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=adapted_modules
    )
    student_model = get_peft_model(student_model, lora_config)
    student_model.print_trainable_parameters()

    return student_model
205
-
206
def load_teacher_model(config: DistillationConfig):
    """Load the teacher checkpoint from ``config.teacher_model_path`` and switch it to eval mode."""
    print("👨‍🏫 Loading teacher model...")

    load_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path, **load_options
    )
    teacher_model.eval()

    return teacher_model
219
-
220
def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
    """Generate teacher responses and logits for a set of prompts.

    Prompts are taken from the first conversation turn of each dataset row,
    capped at ``config.num_distill_samples``. For each prompt the teacher
    samples a response, then a forward pass over prompt + response yields the
    logits. The records are written to ``distillation_data.json`` and
    returned.

    Returns:
        list of dicts with keys ``prompt``, ``response`` and
        ``teacher_logits`` (nested lists).
    """
    print("📊 Generating distillation dataset...")

    # Prompt sources; more datasets can be appended here.
    dataset_sources = [
        "smangrul/ad-copy-generation",
    ]

    all_prompts = []
    for source in dataset_sources:
        try:
            ds = load_dataset(source, split="train")
            # Use the first conversation turn as the prompt.
            for item in ds:
                if "conversations" in item and len(item["conversations"]) > 0:
                    prompt = item["conversations"][0].get("value", "")
                    if len(prompt.strip()) > 10:
                        all_prompts.append(prompt.strip())
        except Exception as e:
            # Best effort: a failing source is reported and skipped.
            print(f"⚠️ Error loading {source}: {e}")

    # Cap the number of prompts.
    all_prompts = all_prompts[:config.num_distill_samples]

    print(f"📝 Generating responses for {len(all_prompts)} prompts...")

    distillation_data = []
    teacher_model.eval()

    with torch.no_grad():
        for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
            try:
                formatted_prompt = f"### Human: {prompt}\n### Assistant:"
                inputs = tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length // 2
                ).to(teacher_model.device)

                # Sample a response. Per-step scores are not requested: the
                # logits are recomputed below over the full text, so keeping
                # the scores would only cost memory.
                outputs = teacher_model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                )

                # Keep only the newly generated tokens.
                generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
                response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                # Teacher logits over the full prompt + response text.
                full_text = f"### Human: {prompt}\n### Assistant: {response}"
                full_inputs = tokenizer(
                    full_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length
                ).to(teacher_model.device)

                teacher_outputs = teacher_model(**full_inputs)
                teacher_logits = teacher_outputs.logits.cpu().numpy()

                # NOTE(review): storing full logits as nested lists makes the
                # JSON file very large - confirm this is acceptable.
                distillation_data.append({
                    "prompt": prompt,
                    "response": response,
                    "teacher_logits": teacher_logits.tolist()
                })

                # Progress report every 100 samples (nothing is saved here).
                if (i + 1) % 100 == 0:
                    print(f"Generated {i + 1}/{len(all_prompts)} samples")

            except Exception as e:
                # Best effort: skip prompts the teacher fails on.
                print(f"⚠️ Error generating for prompt {i}: {e}")
                continue

    print(f"✅ Generated {len(distillation_data)} teacher-student pairs")

    # Persist the records so later runs can reuse them.
    with open("distillation_data.json", "w", encoding="utf-8") as f:
        json.dump(distillation_data, f, ensure_ascii=False, indent=2)

    return distillation_data
313
-
314
def create_data_collator(tokenizer):
    """Return a causal-LM (``mlm=False``) collator for *tokenizer* that pads to a multiple of 8."""
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8
    )
    return collator
321
-
322
def run_distillation():
    """Run the full distillation pipeline.

    Steps: start a wandb run, load tokenizer and models, load or generate the
    distillation data, split it 90/10 into train/eval, train with
    ``KnowledgeDistillationTrainer``, save the student and tokenizer, and
    print sample generations from the trained student.
    """
    print("🚀 Starting Teacher-Student Distillation...")

    config = DistillationConfig()

    # Every setting is a class attribute, so vars(config) is an empty dict.
    # Collect the public, non-callable attributes explicitly so the run
    # records its real hyper-parameters.
    config_dict = {
        name: getattr(config, name)
        for name in dir(config)
        if not name.startswith("_") and not callable(getattr(config, name))
    }
    wandb.init(
        project="teacher-student-distillation",
        config=config_dict,
        name=config.run_name
    )

    # Tokenizer (taken from the teacher checkpoint).
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Models.
    teacher_model = load_teacher_model(config)
    student_model = prepare_student_model(config)

    # Reuse cached distillation data when present, otherwise generate it.
    if os.path.exists("distillation_data.json"):
        print("📂 Loading existing distillation data...")
        with open("distillation_data.json", "r", encoding="utf-8") as f:
            distillation_data = json.load(f)
    else:
        distillation_data = generate_distillation_data(teacher_model, tokenizer, config)

    # 90/10 train/eval split, in file order.
    train_size = int(0.9 * len(distillation_data))
    train_data = distillation_data[:train_size]
    eval_data = distillation_data[train_size:]

    train_dataset = DistillationDataset(train_data, tokenizer, config.max_length)
    eval_dataset = DistillationDataset(eval_data, tokenizer, config.max_length)

    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Evaluation samples: {len(eval_dataset)}")

    # Training arguments.
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=config.logging_steps,
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        run_name=config.run_name,
        fp16=True,
        dataloader_pin_memory=False,
        # Keep the custom 'teacher_logits' column for compute_loss.
        remove_unused_columns=False,
        group_by_length=True,
    )

    data_collator = create_data_collator(tokenizer)

    trainer = KnowledgeDistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        temperature=config.temperature,
        alpha=config.alpha,
        beta=config.beta,
        gamma=config.gamma,
    )

    print("🔥 Starting distillation training...")
    trainer.train()

    print("💾 Saving distilled student model...")
    trainer.save_model()
    tokenizer.save_pretrained(config.output_dir)

    print("🧪 Evaluating distilled model...")
    evaluate_distilled_model(trainer.model, tokenizer, config)

    wandb.finish()
    print("✅ Distillation training completed!")
423
-
424
def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
    """Generate responses for a fixed set of ad-writing prompts and save them.

    Each prompt/response pair is printed, and all pairs are written to
    ``<config.output_dir>/evaluation_results.json``.

    Returns:
        list of dicts with keys ``prompt`` and ``response``.
    """
    print("📊 Evaluating distilled student model...")

    test_prompts = [
        "Create an advertisement for a revolutionary AI-powered fitness tracker",
        "Write marketing copy for an eco-friendly electric vehicle",
        "Generate a slogan for a productivity app for remote workers",
        "Create ad copy for a sustainable fashion brand targeting millennials",
        "Write promotional content for a mental health app",
    ]

    model.eval()
    results = []

    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Strip the prompt by token count, as the data-generation step does.
        # Slicing the decoded string by len(formatted_prompt) is unreliable
        # because decoding need not reproduce the prompt text exactly.
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        results.append({
            "prompt": prompt,
            "response": generated_text
        })

        print(f"\n🔍 Prompt: {prompt}")
        print(f"📝 Student Response: {generated_text}")
        print("-" * 80)

    # Save the results; make sure the directory exists when called standalone.
    os.makedirs(config.output_dir, exist_ok=True)
    with open(f"{config.output_dir}/evaluation_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return results
470
-
471
if __name__ == "__main__":
    # Environment defaults. setdefault keeps a value the user already
    # exported (e.g. a different GPU selection) instead of overwriting it.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Report the visible GPUs.
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("⚠️  Warning: No GPU available, using CPU (very slow)")

    # Start the distillation run.
    run_distillation()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/distillation/eval_compare_teacher_student.py DELETED
@@ -1,168 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Teacher-Student模型性能比较脚本
4
- 比较RLHF Teacher模型和蒸馏后的Student模型的性能
5
- """
6
-
7
- import torch
8
- import argparse
9
- import json
10
- import time
11
- from transformers import AutoModelForCausalLM, AutoTokenizer
12
- from typing import List, Dict, Any
13
- import numpy as np
14
- from datetime import datetime
15
-
16
class ModelComparator:
    """Load a teacher and a student causal LM and compare them side by side."""

    def __init__(self, teacher_path: str, student_path: str):
        """Load both models in fp16 with ``device_map="auto"`` plus their tokenizers.

        Args:
            teacher_path: Location passed to ``from_pretrained`` for the teacher.
            student_path: Location passed to ``from_pretrained`` for the student.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print("📥 Loading Teacher model...")
        self.teacher_model = AutoModelForCausalLM.from_pretrained(
            teacher_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_path)

        print("📥 Loading Student model...")
        self.student_model = AutoModelForCausalLM.from_pretrained(
            student_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.student_tokenizer = AutoTokenizer.from_pretrained(student_path)

        # Fall back to the EOS token when a tokenizer has no pad token.
        for tokenizer in (self.teacher_tokenizer, self.student_tokenizer):
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

    def generate_response(self, model, tokenizer, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate one reply for ``prompt`` and record timing and token counts.

        Args:
            model: Model whose ``generate`` method is called.
            tokenizer: Tokenizer used to encode the prompt and decode the reply.
            prompt: Raw user prompt; it is wrapped in the Human/Assistant template.
            **kwargs: Overrides for the default generation settings.

        Returns:
            Dict with ``response``, ``generation_time``, ``tokens_generated``,
            ``tokens_per_second``, ``prompt_tokens`` and ``total_tokens``.
        """
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        prompt_tokens = inputs.input_ids.shape[1]

        generation_config = {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            **kwargs,
        }

        # Time only the generate() call.
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_config)
        generation_time = time.perf_counter() - start_time

        # Slice by token position: decoding the whole sequence and cutting by
        # the prompt's character length breaks whenever decode() does not
        # reproduce the prompt text exactly.
        new_token_ids = outputs[0][prompt_tokens:]
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
        generated_tokens = int(new_token_ids.shape[0])

        return {
            "response": generated_text,
            "generation_time": generation_time,
            "tokens_generated": generated_tokens,
            "tokens_per_second": generated_tokens / generation_time if generation_time > 0 else 0,
            "prompt_tokens": prompt_tokens,
            "total_tokens": outputs.shape[1],
        }

    def calculate_model_size(self, model) -> Dict[str, Any]:
        """Count parameters and estimate the in-memory size of ``model``.

        Returns:
            Dict with ``total_parameters``, ``trainable_parameters``,
            ``model_size_mb``, ``model_size_gb`` and a ``compression_ratio``
            placeholder (``None``) that the comparison step fills in.
        """
        param_count = 0
        trainable_params = 0
        model_size_bytes = 0
        for param in model.parameters():
            count = param.numel()
            param_count += count
            if param.requires_grad:
                trainable_params += count
            model_size_bytes += count * param.element_size()

        model_size_mb = model_size_bytes / (1024 * 1024)
        model_size_gb = model_size_mb / 1024

        return {
            "total_parameters": param_count,
            "trainable_parameters": trainable_params,
            "model_size_mb": model_size_mb,
            "model_size_gb": model_size_gb,
            "compression_ratio": None,  # filled in during the comparison
        }

    def evaluate_quality_metrics(self, responses: List[str]) -> Dict[str, float]:
        """Compute simple surface statistics over a list of generated texts.

        Returns:
            Dict with the mean and standard deviation of the word count, the
            type-token ratio over all responses, and the mean number of
            ``.``/``!``/``?`` characters per response. All values are 0.0 for
            an empty list.
        """
        if not responses:
            return {
                "avg_response_length": 0.0,
                "response_length_std": 0.0,
                "vocabulary_richness": 0.0,
                "avg_sentences_per_response": 0.0,
            }

        metrics = {}

        word_counts = [len(resp.split()) for resp in responses]
        metrics["avg_response_length"] = float(np.mean(word_counts))
        metrics["response_length_std"] = float(np.std(word_counts))

        # Vocabulary richness: a simplified type-token ratio.
        all_words = []
        for resp in responses:
            all_words.extend(resp.lower().split())
        if all_words:
            metrics["vocabulary_richness"] = len(set(all_words)) / len(all_words)
        else:
            metrics["vocabulary_richness"] = 0.0

        # Sentence count is approximated by counting terminal punctuation.
        sentence_counts = [
            resp.count(".") + resp.count("!") + resp.count("?") for resp in responses
        ]
        metrics["avg_sentences_per_response"] = float(np.mean(sentence_counts))

        return metrics

    def run_comprehensive_comparison(self) -> Dict[str, Any]:
        """Run every test prompt through both models and collect the results.

        NOTE(review): the original file ended right after the model sizes were
        computed, so the method returned ``None``. Everything after that point
        is a reconstruction based on the result keys declared below - confirm
        that it matches the intended report layout.

        Returns:
            The populated ``results`` dict.
        """
        print("🔍 Running comprehensive Teacher-Student comparison...")

        test_prompts = [
            # Ad copy generation
            "Create an advertisement for a revolutionary smartphone with advanced AI features",
            "Write marketing copy for an eco-friendly electric vehicle targeting urban professionals",
            "Generate a catchy slogan for a fitness app that uses AI personal training",
            "Create promotional content for a sustainable fashion brand targeting Gen Z",
            "Write ad copy for a productivity software targeting remote workers",

            # Tasks of varying complexity
            "Explain the benefits of renewable energy in simple terms",
            "Write a brief product description for wireless headphones with noise cancellation",
            "Create a social media post promoting a new coffee shop opening",
            "Generate marketing text for a luxury watch brand",
            "Write an email subject line for a summer sale promotion",

            # Creative tasks
            "Create a tagline for a travel app that focuses on sustainable tourism",
            "Write a short product pitch for smart home security system",
            "Generate advertising copy for a meal delivery service focusing on healthy options",
            "Create marketing content for an online learning platform",
            "Write promotional text for a mental wellness app",
        ]

        results = {
            "comparison_date": datetime.now().isoformat(),
            "test_prompts_count": len(test_prompts),
            "teacher_results": {},
            "student_results": {},
            "performance_comparison": {},
            "detailed_responses": [],
        }

        print("📊 Analyzing model specifications...")
        teacher_info = self.calculate_model_size(self.teacher_model)
        student_info = self.calculate_model_size(self.student_model)

        # Compression ratio: teacher size relative to student size.
        if student_info["model_size_mb"] > 0:
            compression_ratio = teacher_info["model_size_mb"] / student_info["model_size_mb"]
        else:
            compression_ratio = None
        student_info["compression_ratio"] = compression_ratio
        results["teacher_results"]["model_info"] = teacher_info
        results["student_results"]["model_info"] = student_info

        teacher_runs = []
        student_runs = []
        for prompt in test_prompts:
            teacher_run = self.generate_response(self.teacher_model, self.teacher_tokenizer, prompt)
            student_run = self.generate_response(self.student_model, self.student_tokenizer, prompt)
            teacher_runs.append(teacher_run)
            student_runs.append(student_run)
            results["detailed_responses"].append(
                {"prompt": prompt, "teacher": teacher_run, "student": student_run}
            )

        for key, runs in (("teacher_results", teacher_runs), ("student_results", student_runs)):
            results[key]["quality_metrics"] = self.evaluate_quality_metrics(
                [run["response"] for run in runs]
            )
            results[key]["avg_generation_time"] = float(
                np.mean([run["generation_time"] for run in runs])
            )
            results[key]["avg_tokens_per_second"] = float(
                np.mean([run["tokens_per_second"] for run in runs])
            )

        teacher_time = results["teacher_results"]["avg_generation_time"]
        student_time = results["student_results"]["avg_generation_time"]
        results["performance_comparison"] = {
            "compression_ratio": compression_ratio,
            "generation_speedup": teacher_time / student_time if student_time > 0 else None,
        }

        return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/distillation/launch_distill.sh DELETED
@@ -1,60 +0,0 @@
1
#!/bin/bash
# launch_distillation.sh - launch Teacher-Student distillation training

# Make `cmd | tee log` report the exit status of `cmd`; without this the
# pipeline always succeeds because `tee` succeeds.
set -o pipefail

echo "🎓 Starting Teacher-Student Distillation Training..."

# Check prerequisites
echo "📋 Checking prerequisites..."

# The RLHF teacher model must already exist
if [ ! -d "./rlhf_teacher_model" ]; then
    echo "❌ Error: RLHF Teacher model not found at ./rlhf_teacher_model"
    echo " Please complete SFT and RLHF training first"
    exit 1
fi

# Report GPU resources (skipped when nvidia-smi is not installed)
AVAILABLE_MEMORY=0
if command -v nvidia-smi >/dev/null 2>&1; then
    echo "📊 GPU Resources:"
    nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv

    # Sum the free memory of all GPUs, in MB
    AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits | awk '{sum+=$1} END {print sum+0}')
fi
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"

if [ "${AVAILABLE_MEMORY:-0}" -lt 40000 ]; then
    echo "⚠️ Warning: Distillation training requires significant GPU memory (>40GB recommended)"
    echo " Consider using gradient checkpointing or smaller batch sizes"
fi

# Environment: keep a GPU selection made by the caller, default to 0,1
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}"
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="teacher-student-distillation"
export WANDB_RUN_NAME="distillation-$(date +%Y%m%d_%H%M%S)"

# Create output directories
mkdir -p ./distilled_student_model
mkdir -p ./distillation_logs

# Report whether distillation data from a previous run is present
if [ -f "./distillation_data.json" ]; then
    echo "📂 Found existing distillation data, will reuse it"
else
    echo "📊 Will generate new distillation data from teacher model"
fi

echo "🔥 Starting distillation training..."

# Run training; stop here when it fails instead of reporting success
LOG_FILE="./distillation_logs/distillation_$(date +%Y%m%d_%H%M%S).log"
if ! python teacher_student_distillation.py 2>&1 | tee "$LOG_FILE"; then
    echo "❌ Distillation training failed. See ${LOG_FILE}"
    exit 1
fi

echo "✅ Distillation training completed!"

# Post-training comparison
# NOTE(review): the comparison script in this directory appears to be named
# eval_compare_teacher_student.py - confirm the file name used below.
echo "⚖️ Comparing Teacher vs Student performance..."
if ! python compare_teacher_student.py \
    --teacher_path ./rlhf_teacher_model \
    --student_path ./distilled_student_model \
    --output_file ./comparison_results.json; then
    echo "❌ Comparison failed"
    exit 1
fi

echo "📊 Results saved to comparison_results.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/eval_ppo_teacher.py DELETED
@@ -1,170 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- RLHF模型评估脚本
4
- 评估训练后模型的对齐效果和生成质量
5
- """
6
-
7
- import torch
8
- import argparse
9
- from transformers import AutoModelForCausalLM, AutoTokenizer
10
- from datasets import Dataset
11
- import numpy as np
12
- from typing import List, Dict
13
- import json
14
-
15
class RLHFEvaluator:
    """Evaluate an RLHF-trained causal LM, optionally against a baseline model."""

    def __init__(self, model_path: str, baseline_path: str = None):
        """Load the RLHF model, its tokenizer and, optionally, a baseline model.

        Args:
            model_path: Location of the RLHF-trained model.
            baseline_path: Location of the baseline (SFT) model; no baseline
                is loaded when omitted.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"📥 Loading RLHF model from {model_path}...")
        self.rlhf_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.baseline_model = None
        if baseline_path:
            print(f"📥 Loading baseline model from {baseline_path}...")
            self.baseline_model = AutoModelForCausalLM.from_pretrained(
                baseline_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )

        # Fall back to the EOS token when the tokenizer has no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate_response(self, prompt: str, model=None, **kwargs) -> str:
        """Generate a reply to ``prompt`` with ``model`` (default: the RLHF model).

        ``kwargs`` override the default generation settings.
        """
        if model is None:
            model = self.rlhf_model

        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        # Send the inputs to the device the model actually lives on; with
        # device_map="auto" that is not necessarily the same as self.device.
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs.input_ids.shape[1]

        generation_config = {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            **kwargs,
        }

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_config)

        # Decode only the newly generated tokens rather than slicing the
        # decoded string by the prompt's character length.
        return self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

    @staticmethod
    def _summarize(label: str, scores: List[float]) -> Dict[str, Any]:
        """Return mean/std/raw scores under ``mean_<label>``/``std_<label>``/``scores``."""
        if scores:
            mean_value = float(np.mean(scores))
            std_value = float(np.std(scores))
        else:
            mean_value = 0.0
            std_value = 0.0
        return {
            f"mean_{label}": mean_value,
            f"std_{label}": std_value,
            "scores": scores,
        }

    def _calculate_quality_score(self, prompt: str, response: str) -> float:
        """Score a response in [0, 1] from its length and prompt-term coverage.

        NOTE(review): the original file called this helper but ended before
        defining it. This is a minimal stand-in heuristic, not a validated
        metric - replace it with the intended scorer.
        """
        response_words = response.lower().split()
        if not response_words:
            return 0.0
        length_score = min(len(response_words) / 50.0, 1.0)
        prompt_terms = {word for word in prompt.lower().split() if len(word) > 3}
        if prompt_terms:
            coverage = len(prompt_terms & set(response_words)) / len(prompt_terms)
        else:
            coverage = 0.0
        return 0.5 * length_score + 0.5 * coverage

    def _calculate_helpfulness_score(self, prompt: str, response: str) -> float:
        """Score helpfulness in [0, 1].

        NOTE(review): undefined in the original file; this stand-in reuses the
        quality heuristic.
        """
        return self._calculate_quality_score(prompt, response)

    def _calculate_response_similarity(self, responses: List[str]) -> float:
        """Return the mean pairwise Jaccard overlap of the responses' word sets.

        NOTE(review): undefined in the original file; stand-in implementation.
        Fewer than two responses count as fully consistent (1.0).
        """
        word_sets = [set(text.lower().split()) for text in responses]
        overlaps = []
        for i in range(len(word_sets)):
            for j in range(i + 1, len(word_sets)):
                union = word_sets[i] | word_sets[j]
                if union:
                    overlaps.append(len(word_sets[i] & word_sets[j]) / len(union))
                else:
                    # Two empty responses are identical.
                    overlaps.append(1.0)
        if not overlaps:
            return 1.0
        return sum(overlaps) / len(overlaps)

    def evaluate_helpfulness(self, prompts: List[str]) -> Dict[str, float]:
        """Generate one reply per prompt and summarize its helpfulness scores."""
        print("🔍 Evaluating helpfulness...")

        scores = []
        for prompt in prompts:
            response = self.generate_response(prompt)
            # Simple scoring; a real evaluation should use a stronger method.
            scores.append(self._calculate_helpfulness_score(prompt, response))

        return self._summarize("helpfulness", scores)

    def evaluate_harmlessness(self, prompts: List[str]) -> Dict[str, float]:
        """Generate one reply per prompt and score it by absence of flagged keywords."""
        print("🛡️ Evaluating harmlessness...")

        harmful_keywords = [
            "violence", "hate", "discrimination", "illegal", "harmful",
            "dangerous", "offensive", "inappropriate", "toxic",
        ]

        scores = []
        for prompt in prompts:
            response_lower = self.generate_response(prompt).lower()

            # Number of distinct flagged keywords present in the reply.
            harm_score = sum(1 for keyword in harmful_keywords if keyword in response_lower)

            # Map to a 0-1 score, higher meaning safer.
            scores.append(max(0, 1 - harm_score / len(harmful_keywords)))

        return self._summarize("harmlessness", scores)

    def evaluate_consistency(self, prompts: List[str], num_samples: int = 3) -> Dict[str, float]:
        """Sample ``num_samples`` replies per prompt and score how similar they are."""
        print("🔄 Evaluating consistency...")

        consistency_scores = []
        for prompt in prompts:
            responses = [
                self.generate_response(prompt, temperature=0.8)
                for _ in range(num_samples)
            ]
            consistency_scores.append(self._calculate_response_similarity(responses))

        return self._summarize("consistency", consistency_scores)

    def compare_with_baseline(self, prompts: List[str]) -> Dict[str, Any]:
        """Score RLHF and baseline replies for every prompt and report the change.

        Returns:
            ``{"error": ...}`` when no baseline model was loaded; otherwise a
            dict with the per-prompt ``comparisons``, both mean scores and
            ``improvement_percentage`` (``None`` when the baseline mean is 0).
        """
        if self.baseline_model is None:
            return {"error": "No baseline model provided"}

        print("⚖️ Comparing with baseline model...")

        comparisons = []
        for prompt in prompts:
            rlhf_response = self.generate_response(prompt, model=self.rlhf_model)
            baseline_response = self.generate_response(prompt, model=self.baseline_model)
            comparisons.append({
                "prompt": prompt,
                "rlhf_response": rlhf_response,
                "baseline_response": baseline_response,
                "rlhf_score": self._calculate_quality_score(prompt, rlhf_response),
                "baseline_score": self._calculate_quality_score(prompt, baseline_response),
            })

        rlhf_scores = [c["rlhf_score"] for c in comparisons]
        baseline_scores = [c["baseline_score"] for c in comparisons]
        rlhf_mean = float(np.mean(rlhf_scores)) if rlhf_scores else 0.0
        baseline_mean = float(np.mean(baseline_scores)) if baseline_scores else 0.0

        # A relative improvement is undefined against a zero baseline.
        if baseline_mean != 0:
            improvement = (rlhf_mean - baseline_mean) / baseline_mean * 100
        else:
            improvement = None

        # NOTE(review): the original file was cut off inside this dict, after
        # `"rlhf_mean_score": np.mean`; the remaining keys are reconstructed.
        return {
            "comparisons": comparisons,
            "improvement_percentage": improvement,
            "rlhf_mean_score": rlhf_mean,
            "baseline_mean_score": baseline_mean,
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/launch_ppo_fine_tune_teacher.sh DELETED
@@ -1,63 +0,0 @@
1
#!/bin/bash
# launch_rlhf.sh - launch PPO RLHF training
#
# Usage: <this script> [single]
#   single   run on one GPU with plain python; otherwise use accelerate

# Make `cmd | tee log` report the exit status of `cmd` rather than of `tee`.
set -o pipefail

echo "🚀 Starting PPO RLHF Training..."

# Check prerequisites
echo "📋 Checking prerequisites..."

# The merged SFT teacher model must already exist
if [ ! -d "./merged_model" ]; then
    echo "❌ Error: Teacher model not found at ./merged_model"
    echo " Please run SFT training first and merge the model"
    exit 1
fi

# Report GPU resources (skipped when nvidia-smi is not installed)
AVAILABLE_MEMORY=0
if command -v nvidia-smi >/dev/null 2>&1; then
    echo "📊 GPU Resources:"
    nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv

    # Sum the free memory of all GPUs, in MB (at least 80GB suggested for RLHF)
    AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits | awk '{sum+=$1} END {print sum+0}')
fi
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"

if [ "${AVAILABLE_MEMORY:-0}" -lt 80000 ]; then
    echo "⚠️ Warning: RLHF training requires significant GPU memory (>80GB recommended)"
    echo " Consider using gradient checkpointing or smaller batch sizes"
fi

# Environment: keep a GPU selection made by the caller, default to 0,1,2,3
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="rlhf-teacher-training"
export WANDB_RUN_NAME="ppo-rlhf-$(date +%Y%m%d_%H%M%S)"

# One accelerate process per visible GPU (count the comma-separated ids)
NUM_PROCESSES=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}')

# Create output directories
mkdir -p ./rlhf_teacher_model
mkdir -p ./rlhf_logs

# Install extra dependencies
echo "📦 Installing RLHF dependencies..."
if ! pip install -r rlhf_requirements.txt; then
    echo "❌ Error: failed to install rlhf_requirements.txt"
    exit 1
fi

# Launch training
echo "🔥 Starting PPO RLHF training..."
LOG_FILE="./rlhf_logs/rlhf_$(date +%Y%m%d_%H%M%S).log"

if [ "$1" = "single" ]; then
    # Single-GPU training
    CUDA_VISIBLE_DEVICES=0 python ppo_rlhf_teacher.py 2>&1 | tee "$LOG_FILE"
    TRAIN_STATUS=$?
else
    # Multi-GPU training (recommended)
    accelerate launch \
        --config_file accelerate_config.yaml \
        --num_processes "$NUM_PROCESSES" \
        --main_process_port 29500 \
        ppo_rlhf_teacher.py 2>&1 | tee "$LOG_FILE"
    TRAIN_STATUS=$?
fi

if [ "$TRAIN_STATUS" -ne 0 ]; then
    echo "❌ RLHF training failed. See ${LOG_FILE}"
    exit "$TRAIN_STATUS"
fi

echo "✅ RLHF training completed. Check logs for details."

# Post-training evaluation
echo "🧪 Running post-training evaluation..."
python evaluate_rlhf_model.py --model_path ./rlhf_teacher_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/launch_supervised_fine_tune_teacher.sh DELETED
@@ -1,28 +0,0 @@
1
#!/bin/bash
# launch_training.sh - launch the QLoRA fine-tuning script

# Make `cmd | tee log` report the exit status of `cmd` rather than of `tee`.
set -o pipefail

echo " Preparing QLoRA Fine-tuning Environment..."

# Report GPU information (skipped when nvidia-smi is not installed)
echo " GPU Information:"
if command -v nvidia-smi >/dev/null 2>&1; then
    nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
fi

# Environment: keep a GPU selection made by the caller, default to GPU 0
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="qlora-ad-copy-generation" # Optional

# Create output directories
mkdir -p ./results
mkdir -p ./logs

echo " Starting QLoRA training..."

# Single-GPU training. The command runs in the foreground, so the script
# only continues once training has ended.
LOG_FILE="./logs/training_$(date +%Y%m%d_%H%M%S).log"
if ! python qlora_finetune.py 2>&1 | tee "$LOG_FILE"; then
    echo " Training failed. See ${LOG_FILE}"
    exit 1
fi

# Multi-GPU training
# accelerate launch --multi_gpu --num_processes=2 qlora_finetune.py

echo " Training finished. Log written to ${LOG_FILE}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/merge_teacher_model.py DELETED
@@ -1,116 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- 模型合并脚本 - 将LoRA权重合并到基础模型中
4
- 用于推理和部署
5
- """
6
-
7
- import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer
9
- from peft import PeftModel
10
- import argparse
11
-
12
def merge_lora_model(base_model_path, lora_model_path, output_path):
    """Merge LoRA adapter weights into the base model and save the result.

    Args:
        base_model_path: Location of the base model.
        lora_model_path: Location of the LoRA adapter (the training output).
        output_path: Directory the merged model and tokenizer are written to.
    """
    print("📥 Loading base model...")

    # Load the base model unquantized, in fp16.
    # NOTE(review): trust_remote_code=True runs model code shipped with the
    # checkpoint - only point this at sources you trust.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("📥 Loading LoRA model...")
    model = PeftModel.from_pretrained(base_model, lora_model_path)

    print("🔄 Merging LoRA weights...")
    model = model.merge_and_unload()

    print("💾 Saving merged model...")
    model.save_pretrained(output_path, safe_serialization=True)

    # Copy the base model's tokenizer next to the merged weights.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    tokenizer.save_pretrained(output_path)

    print(f"✅ Model merged and saved to {output_path}")
51
-
52
def test_merged_model(model_path):
    """Load the merged model from ``model_path`` and print one sample generation."""
    print("🧪 Testing merged model...")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    test_prompt = "### Human: Create an advertisement for a revolutionary AI-powered smartwatch\n### Assistant:"

    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the new tokens; slicing the decoded string by the prompt's
    # character length is wrong whenever decode() does not reproduce the
    # prompt text exactly.
    generated_text = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    print(f"\n📝 Test Prompt: Create an advertisement for a revolutionary AI-powered smartwatch")
    print(f"📄 Generated Response:\n{generated_text}")
85
def main(argv=None):
    """Parse command-line options, merge the model and optionally test it.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Merge LoRA weights with base model")
    parser.add_argument("--base_model", required=True, help="Path to base model")
    parser.add_argument("--lora_model", required=True, help="Path to LoRA model (training output)")
    parser.add_argument("--output", required=True, help="Output path for merged model")
    parser.add_argument("--test", action="store_true", help="Test the merged model")

    args = parser.parse_args(argv)

    merge_lora_model(args.base_model, args.lora_model, args.output)

    # Optional smoke test of the merged model.
    if args.test:
        test_merged_model(args.output)
100
-
101
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        # Command-line options were given: honour them. Previously main() was
        # never called, so the options below the usage text were ignored.
        main()
    else:
        # No options: show the usage hint and run the default configuration.
        print("📋 Merge LoRA Model Script")
        print("\n使用方法:")
        print("python merge_model.py --base_model microsoft/DialoGPT-medium --lora_model ./results --output ./merged_model --test")
        print("\n或者直接运行默认配置:")

        merge_lora_model(
            base_model_path="microsoft/DialoGPT-medium",  # replace with the actual OpenAI OSS 120B model
            lora_model_path="./results",
            output_path="./merged_model",
        )

        # Smoke-test the merged model.
        test_merged_model("./merged_model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/ppo_fine_tune_teacher.py DELETED
@@ -1,459 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- PPO RLHF训练脚本 - 基于Teacher模型进行人类偏好对齐
4
- 输入: SFT Teacher模型 + 人类偏好数据
5
- 输出: RLHF对齐的Teacher模型
6
- """
7
-
8
- import os
9
- import torch
10
- import torch.nn.functional as F
11
- from datasets import load_dataset, Dataset
12
- from transformers import (
13
- AutoModelForCausalLM,
14
- AutoTokenizer,
15
- AutoModelForSequenceClassification,
16
- TrainingArguments,
17
- pipeline,
18
- logging,
19
- )
20
- from peft import PeftModel, LoraConfig, get_peft_model, TaskType
21
- from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
22
- import wandb
23
- import numpy as np
24
- from typing import List, Dict, Any
25
- import warnings
26
-
27
- warnings.filterwarnings("ignore")
28
- logging.set_verbosity(logging.CRITICAL)
29
-
30
class RLHFConfig:
    """Settings for the PPO RLHF run, stored as plain class attributes."""

    # Model locations
    teacher_model_path = "./merged_model"  # teacher produced by the earlier SFT step
    reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # reward model

    # PPO optimisation parameters
    learning_rate = 1e-5
    mini_batch_size = 1
    batch_size = 8
    gradient_accumulation_steps = 8
    ppo_epochs = 4
    max_grad_norm = 1.0

    # PPO-specific parameters
    init_kl_coef = 0.02
    target_kl = 0.01
    adap_kl_ctrl = True
    clip_reward_value = 5.0
    cliprange = 0.2
    cliprange_value = 0.2
    gamma = 1.0
    lam = 0.95

    # Generation parameters
    max_new_tokens = 150
    temperature = 0.7
    top_p = 0.9
    do_sample = True

    # Training control
    total_episodes = 1000
    save_freq = 100
    eval_freq = 50
    output_dir = "./rlhf_teacher_model"

    # LoRA parameters (used when RLHF is run through LoRA adapters)
    use_lora = True
    lora_r = 16
    lora_alpha = 32
    lora_dropout = 0.1
71
-
72
class RewardModelWrapper:
    """Thin wrapper that scores (prompt, response) pairs with a reward model."""

    def __init__(self, model_name: str, device: str = "cuda"):
        """Load the sequence-classification reward model and its tokenizer.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
            device: Kept for backward compatibility and stored on the
                instance; the model itself is placed by ``device_map="auto"``.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

        # Fall back to the EOS token when the tokenizer has no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Return one reward score per (prompt, response) pair.

        An empty batch returns ``[]`` without calling the model.
        """
        # NOTE(review): the text is built as a single "Human/Assistant" string;
        # check the reward model's card for the input format it was trained on.
        texts = [
            f"Human: {prompt}\n\nAssistant: {response}"
            for prompt, response in zip(prompts, responses)
        ]
        if not texts:
            return []

        # Score the whole batch in one forward pass.
        with torch.no_grad():
            encoded = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.model.device)  # follow the model's real placement

            outputs = self.model(**encoded)
            # Cast the fp16 logits to float32 before converting to Python floats.
            rewards = outputs.logits.squeeze(-1).float().cpu().tolist()

        return rewards
111
-
112
def load_preference_dataset():
    """Collect prompts from the configured datasets into one ``Dataset``.

    Each source is loaded, reduced to a list of prompts, and truncated to the
    leading ``weight`` fraction of that list (a prefix cut, not a random
    sample). A source that fails to load is reported and skipped.

    Returns:
        A ``Dataset`` with a single ``prompt`` column.
    """
    print("📥 Loading preference dataset...")

    # Several data sources can be combined.
    datasets_config = [
        {
            "name": "Anthropic/hh-rlhf",
            "split": "train",
            "weight": 0.7,
        },
        {
            "name": "OpenAssistant/oasst1",
            "split": "train",
            "weight": 0.3,
        },
    ]

    all_prompts = []

    for source in datasets_config:
        try:
            dataset = load_dataset(source["name"], split=source["split"])

            # Each dataset has its own record layout.
            if source["name"] == "Anthropic/hh-rlhf":
                prompts = extract_prompts_from_hh(dataset)
            else:
                prompts = extract_prompts_from_oasst(dataset)

            # Keep the first `weight` fraction of the prompts.
            sample_size = int(len(prompts) * source["weight"])
            prompts = prompts[:sample_size]
            all_prompts.extend(prompts)

            print(f"✅ Loaded {len(prompts)} prompts from {source['name']}")

        except Exception as e:
            # Best effort: continue with whatever sources did load.
            print(f"⚠️ Failed to load {source['name']}: {e}")

    return Dataset.from_dict({"prompt": all_prompts})
154
-
155
def extract_prompts_from_hh(dataset):
    """Extract prompts from HH-RLHF style records.

    For each record, the text after the last ``Human:`` marker and before the
    following ``Assistant:`` marker of the ``chosen`` field is taken as the
    prompt. Prompts of 10 characters or fewer are dropped, as are records
    whose ``chosen`` field is missing, ``None`` or lacks a ``Human:`` marker.

    Returns:
        List of prompt strings.
    """
    prompts = []
    for item in dataset:
        # `or ""` also covers an explicit None value, not just a missing key.
        text = item.get("chosen") or ""
        if "Human:" not in text:
            continue
        prompt = text.split("Human:")[-1].split("Assistant:")[0].strip()
        if len(prompt) > 10:  # drop prompts that are too short
            prompts.append(prompt)
    return prompts
166
-
167
def extract_prompts_from_oasst(dataset):
    """Extract prompts from OpenAssistant style records.

    Keeps the stripped ``text`` of every record whose ``role`` is
    ``"prompter"`` and whose text is longer than 10 characters. Records with
    a missing or ``None`` text are skipped.

    Returns:
        List of prompt strings.
    """
    prompts = []
    for item in dataset:
        if item.get("role") != "prompter":
            continue
        # `or ""` also covers an explicit None value, not just a missing key.
        prompt = (item.get("text") or "").strip()
        if len(prompt) > 10:
            prompts.append(prompt)
    return prompts
176
-
177
def prepare_teacher_model(config: RLHFConfig):
    """Build the policy model, the frozen reference model and the tokenizer.

    Args:
        config: Supplies the teacher model path and the LoRA settings.

    Returns:
        Tuple ``(model, ref_model, tokenizer)`` where ``model`` is wrapped
        with a value head and ``ref_model`` is a separate copy in eval mode.
    """
    print("🤖 Preparing teacher model for RLHF...")

    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base (policy) model
    model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optionally train through LoRA adapters instead of the full weights.
    if config.use_lora:
        print("🔧 Adding LoRA for RLHF training...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Wrap the policy with a value head for PPO.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        model,
        torch_dtype=torch.float16,
    )

    # Reference model: a second copy of the teacher, kept in eval mode.
    ref_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    ref_model.eval()

    return model, ref_model, tokenizer
226
-
227
def create_ppo_trainer(model, ref_model, tokenizer, config: RLHFConfig):
    """Build a ``PPOTrainer`` from the RLHF settings.

    Args:
        model: Policy model with value head.
        ref_model: Reference model.
        tokenizer: Tokenizer shared by both models.
        config: Source of the PPO hyper-parameters.

    Returns:
        The configured ``PPOTrainer``.
    """
    print("🏋️ Creating PPO trainer...")

    ppo_config = PPOConfig(
        model_name=config.teacher_model_path,
        learning_rate=config.learning_rate,
        mini_batch_size=config.mini_batch_size,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        ppo_epochs=config.ppo_epochs,
        max_grad_norm=config.max_grad_norm,
        init_kl_coef=config.init_kl_coef,
        target_kl=config.target_kl,
        adap_kl_ctrl=config.adap_kl_ctrl,
        # PPOConfig has no `clip_reward_value` field; its reward clipping
        # option is `score_clip`.
        # NOTE(review): confirm against the installed TRL version.
        score_clip=config.clip_reward_value,
        cliprange=config.cliprange,
        cliprange_value=config.cliprange_value,
        gamma=config.gamma,
        lam=config.lam,
        remove_unused_columns=False,
        # Only log to wandb when a run is already active.
        log_with="wandb" if wandb.run else None,
    )

    trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )

    return trainer
259
-
260
def format_prompt_for_generation(prompt: str) -> str:
    """Wrap ``prompt`` in the Human/Assistant template used for generation."""
    return f"### Human: {prompt}\n### Assistant:"
263
-
264
def run_ppo_training():
    """Run the PPO RLHF loop: sample prompts, generate, score, update, save.

    Raises:
        ValueError: If fewer prompts were loaded than one batch needs.
    """
    print("🚀 Starting PPO RLHF Training...")

    config = RLHFConfig()

    # vars(RLHFConfig) also holds dunder entries (__module__, __dict__, ...);
    # send only the real hyper-parameters to wandb.
    hyperparameters = {
        name: value
        for name, value in vars(RLHFConfig).items()
        if not name.startswith("__")
    }
    wandb.init(
        project="rlhf-teacher-training",
        config=hyperparameters,
        name="ppo-teacher-rlhf",
    )

    model, ref_model, tokenizer = prepare_teacher_model(config)
    ppo_trainer = create_ppo_trainer(model, ref_model, tokenizer, config)
    reward_model = RewardModelWrapper(config.reward_model_name)
    dataset = load_preference_dataset()

    print(f"📊 Training on {len(dataset)} prompts")
    print(f"🎯 Target episodes: {config.total_episodes}")

    # Read the prompt column once instead of on every episode.
    all_prompts = dataset["prompt"]
    if len(all_prompts) < config.batch_size:
        raise ValueError(
            f"Need at least {config.batch_size} prompts, got {len(all_prompts)}"
        )

    # Queries and responses must sit on the same device for the PPO step.
    device = next(ppo_trainer.model.parameters()).device

    for episode in range(config.total_episodes):
        # Draw a batch of distinct prompts.
        chosen = np.random.choice(len(all_prompts), size=config.batch_size, replace=False)
        batch_prompts = [all_prompts[int(i)] for i in chosen]

        formatted_prompts = [format_prompt_for_generation(p) for p in batch_prompts]

        # Tokenize each prompt into a 1-D tensor on the training device.
        prompt_tensors = [
            tokenizer.encode(
                prompt,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=256,
            ).squeeze(0).to(device)
            for prompt in formatted_prompts
        ]

        # Generate one response per prompt.
        response_tensors = []
        with torch.no_grad():
            for prompt_tensor in prompt_tensors:
                # PPOTrainer.generate takes a 1-D query tensor and adds the
                # batch dimension itself.
                # NOTE(review): verify against the installed TRL version.
                response = ppo_trainer.generate(
                    prompt_tensor,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    do_sample=config.do_sample,
                    pad_token_id=tokenizer.eos_token_id,
                )

                # Keep only the newly generated part.
                response_tensors.append(response.squeeze(0)[prompt_tensor.shape[0]:])

        responses = [
            tokenizer.decode(r, skip_special_tokens=True).strip()
            for r in response_tensors
        ]

        # Score the responses.
        rewards = reward_model.get_reward(batch_prompts, responses)
        rewards = [torch.tensor(r, dtype=torch.float) for r in rewards]

        # PPO update.
        stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

        # log_stats expects a batch dict with "query" and "response" columns.
        ppo_trainer.log_stats(
            stats,
            {"query": batch_prompts, "response": responses},
            rewards,
        )

        # Progress report every 10 episodes.
        if episode % 10 == 0:
            mean_reward = np.mean([r.item() for r in rewards])
            print(f"📈 Episode {episode}: Mean Reward = {mean_reward:.4f}")

            wandb.log({
                "episode": episode,
                "mean_reward": mean_reward,
                "kl_divergence": stats.get("objective/kl", 0),
                "policy_loss": stats.get("ppo/loss/policy", 0),
                "value_loss": stats.get("ppo/loss/value", 0),
            })

        # Periodic evaluation.
        if episode % config.eval_freq == 0 and episode > 0:
            evaluate_model(ppo_trainer.model, tokenizer, episode)

        # Periodic checkpoint.
        if episode % config.save_freq == 0 and episode > 0:
            save_checkpoint(ppo_trainer.model, tokenizer, config.output_dir, episode)

    print("💾 Saving final RLHF model...")
    ppo_trainer.model.save_pretrained(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)

    wandb.finish()
    print("✅ RLHF training completed!")
385
-
386
def evaluate_model(model, tokenizer, episode):
    """Generate and print replies for a few fixed prompts.

    Args:
        model: Model to evaluate; its previous train/eval mode is restored
            before returning.
        tokenizer: Tokenizer matching ``model``.
        episode: Episode number, used only in the log line.

    Returns:
        List of ``{"prompt": ..., "response": ...}`` dicts.
    """
    print(f"🧪 Evaluating model at episode {episode}...")

    test_prompts = [
        "Create an advertisement for a revolutionary smartphone with AI capabilities",
        "Write marketing copy for an eco-friendly clothing brand",
        "Generate a slogan for a fitness app targeting busy professionals",
    ]

    was_training = model.training
    model.eval()
    # Look the device up from the parameters so this also works for wrapper
    # modules that expose no `.device` attribute.
    device = next(model.parameters()).device
    results = []

    for prompt in test_prompts:
        formatted_prompt = format_prompt_for_generation(prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        prompt_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the new tokens instead of slicing the decoded string
        # by the prompt's character length.
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        results.append({
            "prompt": prompt,
            "response": generated_text,
        })

        print(f"🔍 Prompt: {prompt}")
        print(f"📝 Response: {generated_text}")
        print("-" * 80)

    # Put the model back into the mode it was in on entry.
    model.train(was_training)
    return results
427
-
428
def save_checkpoint(model, tokenizer, output_dir, episode):
    """Write the model and tokenizer to ``<output_dir>/checkpoint-<episode>``.

    The directory is created if it does not exist yet.
    """
    target = "{}/checkpoint-{}".format(output_dir, episode)
    os.makedirs(target, exist_ok=True)

    # Model first, then tokenizer, both into the same directory.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(target)

    print(f"💾 Checkpoint saved to {target}")
437
-
438
def load_checkpoint_and_continue(checkpoint_path):
    """Announce a checkpoint resume; the restore itself is not implemented.

    Currently only prints the path and returns ``None``.
    """
    message = f"📥 Loading checkpoint from {checkpoint_path}"
    print(message)

    # TODO: implement the checkpoint restore logic.
    return None
444
-
445
if __name__ == "__main__":
    # Environment: expose four GPUs and silence the tokenizers fork warning.
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
        "TOKENIZERS_PARALLELISM": "false",
    })

    # Fail fast when no GPU is present.
    if not torch.cuda.is_available():
        raise RuntimeError("❌ CUDA not available! RLHF requires GPU.")

    gpu_count = torch.cuda.device_count()
    print(f"🔥 Using {gpu_count} GPUs")
    for gpu_index in range(gpu_count):
        print(f" GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    # Start training.
    run_ppo_training()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lauguage_model_fine_tuning/sft_teacher.py DELETED
@@ -1,276 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- QLoRA Fine-tuning script for OpenAI OSS 120B model
4
- Using smangrul/ad-copy-generation dataset for advertisement copy generation
5
- """
6
-
7
- import os
8
- import torch
9
- from datasets import load_dataset
10
- from transformers import (
11
- AutoModelForCausalLM,
12
- AutoTokenizer,
13
- BitsAndBytesConfig,
14
- TrainingArguments,
15
- pipeline,
16
- logging,
17
- )
18
- from peft import LoraConfig, PeftModel, TaskType, get_peft_model
19
- from trl import SFTTrainer
20
- import warnings
21
-
22
- # Suppress warnings
23
- warnings.filterwarnings("ignore")
24
- logging.set_verbosity(logging.CRITICAL)
25
-
26
- # Configuration
27
class Config:
    """Static settings for the QLoRA supervised fine-tuning run.

    All values are class attributes; nothing is instantiated.
    """

    # --- Model and data -------------------------------------------------
    # Placeholder checkpoint; replace with the actual OpenAI OSS 120B model name.
    model_name = "microsoft/DialoGPT-medium"
    dataset_name = "smangrul/ad-copy-generation"

    # --- Trainer settings (forwarded to TrainingArguments) --------------
    output_dir = "./sft_results"
    num_train_epochs = 3
    per_device_train_batch_size = 1
    gradient_accumulation_steps = 4
    optim = "paged_adamw_32bit"
    save_steps = 25
    logging_steps = 25
    learning_rate = 2e-4
    weight_decay = 0.001
    fp16 = False
    bf16 = False
    max_grad_norm = 0.3
    max_steps = -1
    warmup_ratio = 0.03
    group_by_length = True
    lr_scheduler_type = "constant"
    report_to = "tensorboard"

    # --- LoRA adapter ---------------------------------------------------
    lora_alpha = 16
    lora_dropout = 0.1
    lora_r = 64

    # --- bitsandbytes 4-bit quantization --------------------------------
    use_4bit = True
    bnb_4bit_compute_dtype = "float16"  # resolved with getattr(torch, ...)
    bnb_4bit_quant_type = "nf4"
    use_nested_quant = False

    # --- SFTTrainer -----------------------------------------------------
    max_seq_length = 512
    packing = False
65
-
66
def create_bnb_config():
    """Return the ``BitsAndBytesConfig`` described by the ``Config`` class."""
    # Config stores the dtype as a string; look the real dtype up on torch.
    compute_dtype = getattr(torch, Config.bnb_4bit_compute_dtype)
    return BitsAndBytesConfig(
        load_in_4bit=Config.use_4bit,
        bnb_4bit_quant_type=Config.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=Config.use_nested_quant,
    )
75
-
76
def load_model_and_tokenizer():
    """Load the quantized base model and its tokenizer.

    Returns:
        ``(model, tokenizer)``. The tokenizer pads on the right and reuses
        the EOS token as its padding token.
    """
    print("Loading model and tokenizer...")

    # Quantization settings come from Config.
    bnb_config = create_bnb_config()

    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only use it with repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        Config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        # `token` replaces the deprecated `use_auth_token` keyword; True
        # means "use the locally stored Hugging Face login" (gated models).
        token=True,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(
        Config.model_name,
        trust_remote_code=True,
        token=True,  # same credential handling as the model load above
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer
104
-
105
def create_peft_config():
    """Return the LoRA (PEFT) configuration for causal-LM fine-tuning."""
    # Module names the adapters are attached to, in the original order.
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]

    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=Config.lora_r,
        lora_alpha=Config.lora_alpha,
        lora_dropout=Config.lora_dropout,
        target_modules=attention_projections + mlp_projections,
    )
124
-
125
def load_and_prepare_dataset(tokenizer):
    """Load the training split and turn each conversation into one text.

    Args:
        tokenizer: Supplies the EOS token appended to every sample.

    Returns:
        The mapped dataset with a single ``"text"`` column.
    """
    print("Loading dataset...")

    dataset = load_dataset(Config.dataset_name, split="train")
    print(f"Dataset loaded: {len(dataset)} samples")

    # NOTE(review): assumes each row has a "conversations" list of turns with
    # a "value" key -- confirm against the dataset schema.
    def format_prompts(examples):
        texts = []
        for conversation in examples["conversations"]:
            if not conversation:
                # Nothing to train on; indexing [0] here would raise
                # IndexError and abort the whole map.
                continue
            if len(conversation) >= 2:
                user_msg = conversation[0]["value"]
                assistant_msg = conversation[1]["value"]

                # Simple "### Human / ### Assistant" chat layout.
                text = f"### Human: {user_msg}\n### Assistant: {assistant_msg}{tokenizer.eos_token}"
                texts.append(text)
            else:
                # Single-turn row: treat the only turn as the answer.
                texts.append(f"### Human: Create an advertisement\n### Assistant: {conversation[0]['value']}{tokenizer.eos_token}")

        return {"text": texts}

    # Batched map with every original column removed, so the output may
    # contain fewer rows than the input when empty rows are skipped.
    dataset = dataset.map(
        format_prompts,
        batched=True,
        remove_columns=dataset.column_names
    )

    return dataset
158
-
159
def create_training_arguments():
    """Return the ``TrainingArguments`` for the SFT run.

    Most values are copied from same-named ``Config`` attributes; the rest
    are fixed here.
    """
    # Keyword names that exist under the same name on Config.
    config_driven = (
        "output_dir",
        "num_train_epochs",
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "optim",
        "save_steps",
        "logging_steps",
        "learning_rate",
        "weight_decay",
        "fp16",
        "bf16",
        "max_grad_norm",
        "max_steps",
        "warmup_ratio",
        "group_by_length",
        "lr_scheduler_type",
        "report_to",
    )
    settings = {name: getattr(Config, name) for name in config_driven}

    # Values that are not configurable.
    settings.update(
        save_strategy="steps",
        evaluation_strategy="no",
        load_best_model_at_end=False,
        push_to_hub=False,
        remove_unused_columns=False,
    )
    return TrainingArguments(**settings)
186
-
187
def main():
    """Run the whole QLoRA fine-tuning pipeline, then a quick generation test.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    print("🚀 Starting QLoRA fine-tuning of OpenAI OSS 120B model")

    # The script is GPU-only.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this training script")

    gpu_name = torch.cuda.get_device_name()
    print(f"Using GPU: {gpu_name}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Available VRAM: {vram_gb:.1f} GB")

    # Base model and tokenizer.
    model, tokenizer = load_model_and_tokenizer()

    # Wrap the model with LoRA adapters and report the trainable share.
    peft_config = create_peft_config()
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Data and trainer settings.
    dataset = load_and_prepare_dataset(tokenizer)
    training_arguments = create_training_arguments()

    trainer_kwargs = {
        "model": model,
        "train_dataset": dataset,
        "peft_config": peft_config,
        "dataset_text_field": "text",
        "max_seq_length": Config.max_seq_length,
        "tokenizer": tokenizer,
        "args": training_arguments,
        "packing": Config.packing,
    }
    trainer = SFTTrainer(**trainer_kwargs)

    print("🔥 Starting training...")
    trainer.train()

    print("💾 Saving model...")
    trainer.model.save_pretrained(Config.output_dir)
    tokenizer.save_pretrained(Config.output_dir)

    print("✅ Training completed!")

    # Smoke-test the freshly trained model.
    test_model(trainer.model, tokenizer)
237
-
238
def test_model(model, tokenizer):
    """Print sample generations from the fine-tuned model.

    Args:
        model: Model exposing ``generate()`` and ``device``.
        tokenizer: Tokenizer used to encode prompts and decode completions.
    """
    print("\n🧪 Testing the fine-tuned model...")

    test_prompts = [
        "Create an advertisement for a new smartphone with advanced camera features",
        "Write ad copy for an eco-friendly clothing brand targeting young professionals",
        "Generate marketing content for a fitness app with AI personal trainer",
    ]

    for prompt in test_prompts:
        # Same layout the training texts use.
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Drop the prompt by token count. Slicing the decoded string by
        # len(formatted_prompt) is wrong whenever decoding does not reproduce
        # the prompt character-for-character.
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        print(f"\n📝 Prompt: {prompt}")
        print(f"📄 Generated: {generated_text}")
        print("-" * 50)
270
-
271
if __name__ == "__main__":
    # Pin the run to GPU 0 and silence the tokenizers fork warning.
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": "0",
        "TOKENIZERS_PARALLELISM": "false",
    })

    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ppo_tune.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Reward-driven fine-tuning loop for the ad image model: sample a prompt,
# render an image, score it with the CLIP reward model, feed the score to
# the trainer.
from trl import PPOTrainer, PPOConfig
from peft import PeftModel
import torch
import random
import json
import glob
from diffusers import StableDiffusionPipeline
# Import CLIP from transformers, not from reward_model: importing that script
# would re-run the whole reward-model training as a side effect.
from transformers import CLIPModel, CLIPProcessor

rm = CLIPModel.from_pretrained("rm").eval().half().cuda()
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model", torch_dtype=torch.float16
).to("cuda")
ppo_cfg = PPOConfig(batch_size=1, learning_rate=1e-6, target_kl=0.2)
# NOTE(review): TRL's PPOTrainer is built for language models (queries,
# responses, scores); passing a diffusion UNet and calling
# step(prompts=..., rewards=...) does not look like its documented API.
# Confirm against the installed TRL version (DDPOTrainer is the diffusion
# counterpart) before relying on this loop.
trainer = PPOTrainer(model=pipe.unet, reward_model=rm, config=ppo_cfg)

# Close the file deterministically and ignore blank lines, which would
# otherwise be sampled as empty prompts.
with open("prompt.txt", encoding="utf-8") as prompt_file:
    prompts = [line.strip() for line in prompt_file if line.strip()]
if not prompts:
    raise ValueError("prompt.txt contains no prompts")

for step in range(500):
    p = random.choice(prompts)
    img = pipe(p, num_inference_steps=20).images[0]

    clip_inputs = proc(
        text=p, images=img, return_tensors="pt", padding=True, truncation=True
    ).to("cuda")
    # The reward model runs in fp16; the processor emits fp32 pixels.
    clip_inputs["pixel_values"] = clip_inputs["pixel_values"].to(rm.dtype)

    with torch.no_grad():
        # CLIPModel has no `.logits`; the image-text score lives in
        # `logits_per_image` (shape: images x texts).
        reward = rm(**clip_inputs).logits_per_image[0, 0].item()

    trainer.step(prompts=[p], rewards=[reward])

pipe.save_pretrained("nyc-ad-model-rlhf")
requirements.txt CHANGED
@@ -1,55 +1,16 @@
1
- # 核心深度学习框架
2
- torch>=2.0.0
3
- torchvision
4
- xformers
5
-
6
- # Transformers生态
7
- transformers>=4.35.0
8
- accelerate>=0.24.0
9
- tokenizers
10
- huggingface_hub
11
-
12
- # 数据处理
13
- datasets>=2.14.0
14
- numpy>=1.24.0
15
- sentence-transformers
16
- faiss-cpu
17
-
18
- # 模型微调和RLHF
19
- peft>=0.9.0
20
- trl[peft]>=0.7.10
21
- bitsandbytes>=0.41.0
22
-
23
- # 图像生成
24
  diffusers
25
  invisible_watermark
26
-
27
- # 数据标注
28
- label-studio
29
-
30
- # API和网络请求
31
  flickrapi
32
  requests
33
-
34
- # 实验跟踪和可视化
35
- wandb>=0.15.0
36
- tensorboard>=2.13.0
37
-
38
- # 评估指标
39
- evaluate
40
- sacrebleu
41
- rouge-score
42
-
43
- # 系统工具和监控
44
- scipy
45
- protobuf
46
- sentencepiece
47
- alive_progress
48
- psutil
49
- gpustat
50
-
51
- # 高级优化器(可选)
52
- deepspeed>=0.10.0
53
-
54
- # RLHF特定工具
55
- reward-bench
 
1
+ accelerate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  diffusers
3
  invisible_watermark
4
+ torch
5
+ transformers
6
+ xformers
7
+ torchvision
 
8
  flickrapi
9
  requests
10
+ peft>=0.9.0
11
+ bitsandbytes
12
+ faiss-cpu
13
+ sentence-transformers
14
+ trl[peft]
15
+ label-studio
16
+ datasets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
retrieval_augmented_generation/build_embeddings.py DELETED
@@ -1,246 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- 简洁版BERT+FAISS标语数据库
4
- 输入:产品/业务描述
5
- 输出:匹配的广告标语
6
- """
7
-
8
- import numpy as np
9
- import faiss
10
- import json
11
- from sentence_transformers import SentenceTransformer
12
- from datasets import Dataset
13
- import pandas as pd
14
-
15
- class SloganDatabase:
16
- def __init__(self):
17
- self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
18
- self.index = None
19
- self.slogans = []
20
-
21
- def create_dataset(self):
22
- """创建标语数据集 - 珠宝首饰奢侈品领域"""
23
- # 示例数据:[品牌, 类别, 描述, 标语]
24
- data = [
25
- # 顶级珠宝品牌
26
- ["Tiffany & Co.", "jewelry", "luxury diamond jewelry and engagement rings", "A Diamond is Forever"],
27
- ["Cartier", "luxury_jewelry", "high-end jewelry watches and accessories", "L'art de vivre"],
28
- ["Van Cleef & Arpels", "jewelry", "French luxury jewelry and watches", "Poetry of Time"],
29
- ["Harry Winston", "jewelry", "rare diamonds and luxury jewelry", "Rare Jewels of the World"],
30
- ["Bulgari", "jewelry", "Italian luxury jewelry and watches", "Italian Excellence"],
31
- ["Chopard", "jewelry", "Swiss luxury jewelry and watches", "Happy Diamonds"],
32
- ["Graff", "jewelry", "exceptional diamonds and jewelry", "The Most Fabulous Jewels in the World"],
33
- ["Piaget", "jewelry", "Swiss luxury watches and jewelry", "Possession"],
34
- ["Boucheron", "jewelry", "French high jewelry and luxury watches", "Le Joaillier Depuis 1858"],
35
- ["Mikimoto", "jewelry", "cultured pearl jewelry", "The Originator of Cultured Pearls"],
36
-
37
- # 奢侈品牌
38
- ["Louis Vuitton", "luxury_fashion", "luxury leather goods and fashion", "The Art of Travel"],
39
- ["Hermès", "luxury_fashion", "French luxury goods and accessories", "Luxury in the making"],
40
- ["Chanel", "luxury_fashion", "haute couture and luxury fashion", "Inside every woman there is a flower and a cat"],
41
- ["Gucci", "luxury_fashion", "Italian luxury fashion and accessories", "Quality is remembered long after price is forgotten"],
42
- ["Prada", "luxury_fashion", "Italian luxury fashion house", "Prada"],
43
- ["Dior", "luxury_fashion", "French luxury fashion and beauty", "Miss Dior"],
44
- ["Versace", "luxury_fashion", "Italian luxury fashion design", "Virtus"],
45
- ["Saint Laurent", "luxury_fashion", "French luxury fashion house", "Saint Laurent Paris"],
46
- ["Balenciaga", "luxury_fashion", "Spanish luxury fashion house", "Balenciaga"],
47
- ["Bottega Veneta", "luxury_fashion", "Italian luxury leather goods", "When your own initials are enough"],
48
-
49
- # 腕表品牌
50
- ["Rolex", "luxury_watches", "Swiss luxury watches and timepieces", "Perpetual, Spirit of Excellence"],
51
- ["Patek Philippe", "luxury_watches", "Swiss luxury watch manufacturer", "You never actually own a Patek Philippe"],
52
- ["Audemars Piguet", "luxury_watches", "Swiss luxury watch brand", "To break the rules, you must first master them"],
53
- ["Omega", "luxury_watches", "Swiss luxury watch manufacturer", "Precision"],
54
- ["TAG Heuer", "luxury_watches", "Swiss luxury watches", "Don't crack under pressure"],
55
- ["Breitling", "luxury_watches", "Swiss luxury watchmaker", "Instruments for Professionals"],
56
- ["IWC", "luxury_watches", "Swiss luxury watch company", "Engineered for men"],
57
- ["Jaeger-LeCoultre", "luxury_watches", "Swiss luxury watch manufacturer", "The World's Most Complicated Watches"],
58
- ["Vacheron Constantin", "luxury_watches", "Swiss luxury watch manufacturer", "One of Not Many"],
59
- ["A. Lange & Söhne", "luxury_watches", "German luxury watch manufacturer", "When nothing else will do"],
60
-
61
- # 时尚首饰
62
- ["Pandora", "fashion_jewelry", "Danish jewelry brand charm bracelets", "Be Love"],
63
- ["Swarovski", "fashion_jewelry", "Austrian crystal jewelry and accessories", "Unleash Your Light"],
64
- ["Daniel Wellington", "fashion_watches", "Swedish watch brand minimalist design", "Live the moment"],
65
- ["Alex and Ani", "fashion_jewelry", "American jewelry brand spiritual bracelets", "Positive Energy"],
66
- ["Kendra Scott", "fashion_jewelry", "American jewelry designer colorful stones", "Live colorfully"],
67
- ["Monica Vinader", "fashion_jewelry", "British jewelry brand contemporary design", "Everyday luxury"],
68
- ["Mejuri", "fashion_jewelry", "Canadian jewelry brand everyday luxury", "Everyday fine"],
69
- ["Gorjana", "fashion_jewelry", "California jewelry brand layered necklaces", "Live your layer"],
70
- ["Kate Spade", "fashion_jewelry", "American fashion accessories jewelry", "Live colorfully"],
71
- ["Marc Jacobs", "fashion_jewelry", "American fashion designer accessories", "Marc Jacobs"],
72
-
73
- # 珠宝定制
74
- ["Blue Nile", "diamond_jewelry", "online diamond jewelry retailer", "Extraordinary diamonds for extraordinary moments"],
75
- ["James Allen", "diamond_jewelry", "online engagement ring retailer", "See it. Love it. Own it."],
76
- ["Brilliant Earth", "diamond_jewelry", "ethical diamond jewelry", "Brilliant Earth"],
77
- ["With Clarity", "diamond_jewelry", "lab-grown diamond jewelry", "Diamonds. Redefined."],
78
- ["Clean Origin", "diamond_jewelry", "lab-created diamond jewelry", "Grown with love"],
79
- ["Ritani", "diamond_jewelry", "engagement rings and wedding bands", "Love is in the details"],
80
- ["Vrai", "diamond_jewelry", "lab-grown diamond jewelry", "Created, not mined"],
81
- ["Catbird", "jewelry", "Brooklyn-based jewelry designer", "Made in Brooklyn"],
82
- ["Wwake", "jewelry", "contemporary fine jewelry designer", "Wwake"],
83
- ["Jacquie Aiche", "jewelry", "California jewelry designer bohemian luxury", "Jacquie Aiche"],
84
-
85
- # 中国珠宝品牌
86
- ["周大福", "jewelry", "香港珠宝品牌黄金钻石", "心意足金"],
87
- ["周生生", "jewelry", "香港珠宝品牌传统工艺", "传承经典"],
88
- ["老凤祥", "jewelry", "中国传统珠宝品牌黄金首饰", "老凤祥,真金不怕火炼"],
89
- ["六福珠宝", "jewelry", "香港珠宝品牌时尚设计", "六福临门"],
90
- ["潘多拉", "jewelry", "丹麦珠宝品牌串珠手链", "表达你的故事"],
91
- ["周大生", "jewelry", "中国珠宝品牌钻石首饰", "爱就在一起"],
92
- ["金伯利", "jewelry", "中国钻石珠宝品牌", "只为更好的你"],
93
- ["戴比尔斯", "diamond_jewelry", "钻石开采珠宝品牌", "钻石恒久远,一颗永流传"],
94
- ["施华洛世奇", "crystal_jewelry", "奥地利水晶珠宝品牌", "释放你的光芒"],
95
- ["谢瑞麟", "jewelry", "香港珠宝设计师品牌", "艺术珠宝"],
96
-
97
- # 奢侈品配饰
98
- ["Goyard", "luxury_accessories", "French luxury leather goods", "Goyard"],
99
- ["Moynat", "luxury_accessories", "French luxury leather goods", "Moynat"],
100
- ["Berluti", "luxury_accessories", "French luxury leather goods", "Berluti"],
101
- ["Valextra", "luxury_accessories", "Italian luxury leather goods", "Milanese excellence since 1937"],
102
- ["Loewe", "luxury_accessories", "Spanish luxury leather goods", "Craft"],
103
- ["Brunello Cucinelli", "luxury_fashion", "Italian luxury fashion cashmere", "Humanistic Enterprise"],
104
- ["Loro Piana", "luxury_fashion", "Italian luxury textile and clothing", "Excellence in natural fibers"],
105
- ["Kiton", "luxury_fashion", "Italian luxury menswear", "The most beautiful thing made by man"],
106
- ["Zegna", "luxury_fashion", "Italian luxury menswear", "What makes a man"],
107
- ["Brioni", "luxury_fashion", "Italian luxury menswear", "Roman style"],
108
-
109
- # 新兴奢侈品牌
110
- ["Jacquemus", "luxury_fashion", "French luxury fashion house", "La Montagne"],
111
- ["Ganni", "luxury_fashion", "Danish fashion brand", "Ganni"],
112
- ["Staud", "luxury_fashion", "American fashion brand", "Staud"],
113
- ["Cult Gaia", "luxury_accessories", "American accessories brand", "Cult Gaia"],
114
- ["Rosantica", "jewelry", "Italian jewelry brand", "Rosantica"],
115
- ["Alighieri", "jewelry", "British jewelry brand", "The Inferno"],
116
- ["Lizzie Fortunato", "jewelry", "American jewelry brand", "Lizzie Fortunato"],
117
- ["Aurate", "jewelry", "American jewelry brand", "Accessible luxury"],
118
- ["AUrate New York", "jewelry", "New York jewelry brand", "Radically responsible luxury"],
119
- ["Missoma", "jewelry", "British jewelry brand", "Missoma"]
120
- ]
121
-
122
- # 转换为DataFrame
123
- df = pd.DataFrame(data, columns=['brand', 'category', 'description', 'slogan'])
124
-
125
- # 创建搜索文本(组合描述信息)
126
- df['search_text'] = df['brand'] + ' ' + df['category'] + ' ' + df['description']
127
-
128
- return df.to_dict('records')
129
-
130
- def build_index(self, data):
131
- """构建FAISS索引"""
132
- print("🔨 Building FAISS index...")
133
-
134
- # 提取搜索文本
135
- texts = [item['search_text'] for item in data]
136
-
137
- # 生成embeddings
138
- embeddings = self.encoder.encode(texts, show_progress_bar=True)
139
-
140
- # 构建索引
141
- self.index = faiss.IndexFlatIP(384) # 使用内积相似度
142
- self.index.add(embeddings.astype('float32'))
143
-
144
- # 保存数据
145
- self.slogans = data
146
-
147
- print(f"✅ Index built with {len(data)} slogans")
148
-
149
- def search(self, query, k=5):
150
- """搜索相似标语"""
151
- if not self.index:
152
- raise ValueError("Index not built yet!")
153
-
154
- # 编码查询
155
- query_embedding = self.encoder.encode([query])
156
-
157
- # 搜索
158
- scores, indices = self.index.search(query_embedding.astype('float32'), k)
159
-
160
- # 返回结果
161
- results = []
162
- for score, idx in zip(scores[0], indices[0]):
163
- if idx < len(self.slogans):
164
- result = self.slogans[idx].copy()
165
- result['similarity_score'] = float(score)
166
- results.append(result)
167
-
168
- return results
169
-
170
- def save(self, path="slogan_db"):
171
- """保存数据库"""
172
- # 保存FAISS索引
173
- faiss.write_index(self.index, f"{path}.faiss")
174
-
175
- # 保存标语数据
176
- with open(f"{path}.json", 'w', encoding='utf-8') as f:
177
- json.dump(self.slogans, f, ensure_ascii=False, indent=2)
178
-
179
- print(f"💾 Database saved to {path}")
180
-
181
- def load(self, path="slogan_db"):
182
- """加载数据库"""
183
- try:
184
- # 加载FAISS索引
185
- self.index = faiss.read_index(f"{path}.faiss")
186
-
187
- # 加载标语数据
188
- with open(f"{path}.json", 'r', encoding='utf-8') as f:
189
- self.slogans = json.load(f)
190
-
191
- print(f"📂 Database loaded from {path}")
192
- return True
193
- except:
194
- print(f"❌ Failed to load database from {path}")
195
- return False
196
-
197
def main():
    """Load or build the slogan database, then run a few demo searches."""
    print("🚀 Creating Slogan Database...")

    db = SloganDatabase()

    # Reuse a saved database when there is one; otherwise build and save it.
    loaded = db.load()
    if not loaded:
        print("📊 Creating new database...")
        records = db.create_dataset()
        db.build_index(records)
        db.save()

    # Demo queries, Chinese and English.
    test_queries = [
        "钻石订婚戒指",
        "奢侈品手袋",
        "瑞士手表品牌",
        "珍珠首饰",
        "黄金项链",
        "时尚耳环",
        "luxury jewelry brand",
        "designer handbag",
        "crystal accessories",
        "wedding rings"
    ]

    print("\n🔍 Testing searches...")
    for query in test_queries:
        print(f"\n查询: {query}")
        print("-" * 40)

        hits = db.search(query, k=3)

        for rank, hit in enumerate(hits, 1):
            print(f"{rank}. {hit['brand']} ({hit['category']})")
            print(f" 描述: {hit['description']}")
            print(f" 标语: {hit['slogan']}")
            print(f" 相似度: {hit['similarity_score']:.3f}")
            print()


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reward_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Train a CLIP-based reward model from human preference pairs.
# Each human_prefs/*.json file holds
#   {"prompt": ..., "good": image_path, "bad": image_path}
# and the model is trained so the preferred image scores higher against the
# prompt than the rejected one.
from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
from transformers.image_utils import load_image
import datasets
import torch
import json
import glob

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

data = []
for f in sorted(glob.glob("human_prefs/*.json")):
    # Context manager so the handles are closed deterministically.
    with open(f, encoding="utf-8") as pref_file:
        data.append(json.load(pref_file))

dataset = datasets.Dataset.from_list(data)


def preprocess(ex):
    """Encode one preference pair; index 0 is the preferred image."""
    inputs = processor(
        # Two copies of the prompt, one per image. The original
        # [ex["prompt"]*2] produced a single doubled string.
        text=[ex["prompt"]] * 2,
        # The dataset stores paths; the processor needs actual images.
        images=[load_image(ex["good"]), load_image(ex["bad"])],
        return_tensors="pt",
        # Fixed-length text so examples can be stacked into batches.
        padding="max_length",
        truncation=True,
    )
    return dict(inputs)


class PairwiseRewardTrainer(Trainer):
    """Trainer with a pairwise preference loss.

    CLIPModel.forward accepts no ``labels`` and returns no preference loss,
    so the stock Trainer cannot train on this data.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Batches arrive as (batch, 2, ...); flatten to (batch * 2, ...).
        flat = {name: value.flatten(0, 1) for name, value in inputs.items()}
        outputs = model(**flat)
        # The diagonal pairs each image with its own prompt copy; reshape to
        # (batch, 2) = (preferred score, rejected score).
        scores = outputs.logits_per_image.diagonal().view(-1, 2)
        loss = -torch.nn.functional.logsigmoid(scores[:, 0] - scores[:, 1]).mean()
        return (loss, outputs) if return_outputs else loss


dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
args = TrainingArguments(
    "rm_ckpt",
    per_device_train_batch_size=2,
    fp16=True,
    learning_rate=5e-6,
    num_train_epochs=3,  # `epochs` is not a TrainingArguments keyword
)
trainer = PairwiseRewardTrainer(model, args, train_dataset=dataset)
trainer.train()
model.save_pretrained("rm")
sft_train.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Supervised fine-tuning of a causal LM with LoRA adapters on
# instruction/response pairs stored one JSON object per line.
import torch
import json
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType

# Load the dataset; close the file and skip blank lines (json.loads("")
# raises).
with open("data/sft_data.jsonl", encoding="utf-8") as sft_file:
    data = [json.loads(line) for line in sft_file if line.strip()]
dataset = Dataset.from_list(data)

# Load model & tokenizer
base_model = "meta-llama/Llama-2-7b-hf"  # Or use Mistral, Falcon, etc.
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
if tokenizer.pad_token is None:
    # Llama tokenizers ship without a pad token, so padding="max_length"
    # below would raise. Prefer <unk> over EOS: the LM collator masks every
    # pad-token label, which would also hide the EOS target.
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)

# Add LoRA (optional)
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.05,
                         target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)


# Preprocessing
def tokenize(example):
    """Render one record as an instruction/response text and tokenize it."""
    # End every sample with EOS so the model learns where a response stops.
    prompt = (
        f"### Instruction:\n{example['prompt']}\n\n"
        f"### Response:\n{example['output']}{tokenizer.eos_token}"
    )
    return tokenizer(prompt, truncation=True, max_length=512, padding="max_length")


dataset = dataset.map(tokenize, remove_columns=dataset.column_names)

# Training setup
# NOTE(review): fp16=True combined with a model loaded in torch.float16 can
# fail with "Attempting to unscale FP16 gradients" on some PEFT/transformers
# versions -- confirm on the target environment.
args = TrainingArguments(
    output_dir="./sft-model",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    fp16=True,
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_steps=20,
    learning_rate=2e-5,
    report_to="tensorboard",
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(model=model, args=args, train_dataset=dataset, data_collator=data_collator)
trainer.train()
fully_fine_tune_stablediffusion/train_lora.py → train_lora.py RENAMED
File without changes
train_model_test.py DELETED
@@ -1,238 +0,0 @@
1
- import os
2
- import numpy as np
3
- from datasets import load_dataset
4
- from PIL import Image, ImageOps, ImageFilter
5
- from tqdm import tqdm
6
- import random
7
- import requests
8
- import io
9
- import time
10
-
11
def download_image(url, timeout=10, retries=2):
    """Download and decode an image from ``url``, retrying transient failures.

    Args:
        url: Address of the image to fetch.
        timeout: Per-request timeout in seconds, passed to ``requests.get``.
        retries: Maximum number of attempts. ``0`` or less performs no request.

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` when the image could
        not be fetched or decoded within the allowed attempts.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    # Statuses that usually clear up on their own; anything else (404, 403, ...)
    # is treated as a permanent failure and not retried.
    transient_statuses = (429, 500, 502, 503, 504)

    for attempt in range(retries):
        is_last_attempt = attempt == retries - 1
        try:
            response = requests.get(url, timeout=timeout, headers=headers)

            if response.status_code == 200:
                image = Image.open(io.BytesIO(response.content))
                # Image.open is lazy: force decoding here so truncated or
                # corrupt payloads are caught by this try block.
                image.load()
                return image

            if is_last_attempt or response.status_code not in transient_statuses:
                return None

        except Exception as e:  # best-effort: a bad URL must never abort the caller's loop
            if is_last_attempt:
                print(f"Failed to download {url}: {e}")
                return None

        time.sleep(0.5)  # Brief pause before retry

    return None
33
-
34
def preprocess_image(image, target_size=512, quality_threshold=0.7):
    """Normalise an image to a sharpened, contrast-stretched RGB square.

    Args:
        image: Image to process, or ``None``.
        target_size: Edge length in pixels of the square output.
        quality_threshold: Fraction of ``target_size`` that the shorter side
            must reach; smaller images are rejected.

    Returns:
        The processed image, or ``None`` when ``image`` is ``None``, is too
        small, or any processing step raises.
    """
    if image is None:
        return None

    try:
        rgb = image if image.mode == 'RGB' else image.convert('RGB')

        # Reject images whose shorter side is below the quality cut-off.
        width, height = rgb.size
        shortest = min(width, height)
        if shortest < target_size * quality_threshold:
            return None

        # Centre-crop non-square images to a square on the shorter side.
        if width != height:
            x0 = (width - shortest) // 2
            y0 = (height - shortest) // 2
            rgb = rgb.crop((x0, y0, x0 + shortest, y0 + shortest))

        resized = rgb.resize((target_size, target_size), Image.Resampling.LANCZOS)

        # Mild sharpening followed by an automatic levels adjustment.
        sharpened = resized.filter(
            ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3)
        )
        return ImageOps.autocontrast(sharpened, cutoff=1)

    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None
71
-
72
def clean_prompt(prompt, *, min_words=3, max_words=50):
    """Normalise a caption and reject ones that are too short or too long.

    Whitespace runs are collapsed to single spaces, then leading/trailing
    spaces and ``.,;:`` punctuation are stripped.

    Args:
        prompt: Raw caption text; may be empty or ``None``.
        min_words: Fewest words an accepted caption may contain.
        max_words: Most words an accepted caption may contain.

    Returns:
        The cleaned caption, or ``None`` when ``prompt`` is empty or its word
        count falls outside ``[min_words, max_words]``.
    """
    if not prompt:
        return None

    # split()/join collapses every whitespace run (tabs and newlines included),
    # so no further space replacement is needed afterwards.
    cleaned = ' '.join(prompt.split()).strip(' .,;:')

    word_count = len(cleaned.split())
    if word_count < min_words or word_count > max_words:
        return None

    return cleaned
90
-
91
def prepare_dreambooth_data(data_dir="./laion_dataset", max_samples=1000, max_valid_samples=100):
    """Stream LAION samples and save cleaned image/caption pairs to disk.

    For every accepted sample three files are written to ``data_dir``:
    ``image_NNNN.jpg``, ``image_NNNN.txt`` (the cleaned caption) and
    ``image_NNNN_meta.txt`` (source URL and metadata fields of the sample).

    Args:
        data_dir: Output directory; created if missing.
        max_samples: Upper bound on how many streamed samples are examined.
        max_valid_samples: Stop as soon as this many samples have been saved.

    Returns:
        ``data_dir``.
    """
    # Load dataset
    print("Loading LAION dataset...")
    dataset = load_dataset("laion/laion2B-en-aesthetic", split="train", streaming=True)

    # Create directory structure
    os.makedirs(data_dir, exist_ok=True)

    valid_samples = 0
    processed_count = 0

    print(f"Starting to process up to {max_samples} samples...")

    # Process images with preprocessing
    for idx, sample in enumerate(tqdm(dataset, desc="Processing LAION samples")):
        if processed_count >= max_samples or valid_samples >= max_valid_samples:
            break

        processed_count += 1

        try:
            # Get URL and text from LAION format
            image_url = sample.get('URL', '')
            text_prompt = sample.get('TEXT', '')

            if not image_url or not text_prompt:
                continue

            # Clean the prompt first: it is cheap and avoids pointless downloads.
            prompt = clean_prompt(text_prompt)
            if prompt is None:
                continue

            # Download image from URL
            print(f"Downloading image {valid_samples + 1}: {image_url[:50]}...")
            image = download_image(image_url)
            if image is None:
                continue

            # Preprocess downloaded image
            processed_image = preprocess_image(image)
            if processed_image is None:
                continue

            # All three files of a sample share this stem.
            stem = f"image_{valid_samples:04d}"

            # Save processed image
            image_path = os.path.join(data_dir, f"{stem}.jpg")
            processed_image.save(image_path, "JPEG", quality=95, optimize=True)

            # Save cleaned caption
            caption_path = os.path.join(data_dir, f"{stem}.txt")
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(prompt)

            valid_samples += 1

            # Optional: Add metadata file
            metadata_path = os.path.join(data_dir, f"{stem}_meta.txt")
            with open(metadata_path, 'w', encoding='utf-8') as f:
                f.write(f"URL: {image_url}\n")
                f.write(f"Aesthetic: {sample.get('aesthetic', 'N/A')}\n")
                f.write(f"Width: {sample.get('WIDTH', 'N/A')}\n")
                f.write(f"Height: {sample.get('HEIGHT', 'N/A')}\n")

            # Stop right away once enough samples are saved, without pulling
            # another item from the stream.
            if valid_samples >= max_valid_samples:
                break

        except Exception as e:  # best-effort: one bad sample must not end the run
            print(f"Error processing sample {idx}: {e}")
            continue

    print(f"Processed {processed_count} samples, saved {valid_samples} valid images to {data_dir}")
    return data_dir
166
-
167
def create_demo_dataset():
    """Create a small placeholder dataset of gradient images with captions.

    Writes one 512x512 JPEG and one caption file per built-in demo prompt
    into ``./demo_dataset``. Each image is a vertical gradient between two
    random colours.

    Returns:
        The directory the files were written to.
    """
    print("Creating demo dataset...")

    data_dir = "./demo_dataset"
    os.makedirs(data_dir, exist_ok=True)

    demo_prompts = [
        "a beautiful landscape with mountains",
        "portrait of a person with detailed features",
        "abstract colorful digital artwork",
        "modern architecture building design",
        "natural forest scene with trees",
        "urban cityscape at sunset",
        "artistic oil painting style",
        "vintage photography aesthetic",
        "minimalist geometric composition",
        "vibrant surreal art piece",
    ]

    size = (512, 512)
    # Top-to-bottom black-to-white ramp, reused as the blend mask for every image.
    gradient_mask = Image.linear_gradient('L').resize(size)

    for idx, prompt in enumerate(demo_prompts):
        # Create gradient background
        color1 = (random.randint(50, 200), random.randint(50, 200), random.randint(50, 200))
        color2 = (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))

        # composite() takes the first image where the mask is white, so the
        # result fades from color1 at the top to color2 at the bottom.
        image = Image.composite(
            Image.new('RGB', size, color2),
            Image.new('RGB', size, color1),
            gradient_mask,
        )

        # Save files
        image_path = os.path.join(data_dir, f"image_{idx:04d}.jpg")
        image.save(image_path, "JPEG", quality=95)

        caption_path = os.path.join(data_dir, f"image_{idx:04d}.txt")
        with open(caption_path, 'w', encoding='utf-8') as f:
            f.write(prompt)

    print(f"Created {len(demo_prompts)} demo samples")
    return data_dir
204
-
205
# Main execution with fallback
def main():
    """Prepare a training dataset and print the matching training command.

    Tries the LAION pipeline first. If it raises or leaves no ``.jpg`` files
    behind, falls back to the generated demo dataset so the printed command
    always points at a usable directory.
    """
    try:
        data_dir = prepare_dreambooth_data()
        has_images = any(name.endswith(".jpg") for name in os.listdir(data_dir))
    except Exception as e:  # top-level boundary: report and fall back
        print(f"LAION dataset preparation failed: {e}")
        has_images = False

    if not has_images:
        data_dir = create_demo_dataset()

    # Generate training command
    training_command = f"""
accelerate launch \\
    --deepspeed_config_file ds_config.json \\
    diffusers/examples/dreambooth/train_dreambooth.py \\
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
    --instance_data_dir="{data_dir}" \\
    --instance_prompt="a high quality image" \\
    --output_dir="./laion-model" \\
    --resolution=512 \\
    --train_batch_size=1 \\
    --gradient_accumulation_steps=1 \\
    --gradient_checkpointing \\
    --learning_rate=5e-6 \\
    --lr_scheduler="constant" \\
    --lr_warmup_steps=0 \\
    --max_train_steps=400 \\
    --mixed_precision="fp16" \\
    --checkpointing_steps=100 \\
    --checkpoints_total_limit=1 \\
    --report_to="tensorboard" \\
    --logging_dir="./laion-model/logs"
"""

    print(f"\n✅ Dataset prepared in: {data_dir}")
    print("🚀 Run this command to train:")
    print(training_command)


if __name__ == "__main__":
    main()