Spaces:
Running
Running
test
#1
by
goodmodeler
- opened
- README.md +15 -30
- build_embeddings.py +11 -0
- data_loader/download.py +0 -209
- data_loader/download_dataset.py +0 -48
- deprecated/image_download.py → image_download.py +0 -0
- deprecated/image_gen.py → image_gen.py +0 -0
- lauguage_model_fine_tuning/accelerate_config.yaml +0 -23
- lauguage_model_fine_tuning/distillation/distill_llm.py +0 -485
- lauguage_model_fine_tuning/distillation/eval_compare_teacher_student.py +0 -168
- lauguage_model_fine_tuning/distillation/launch_distill.sh +0 -60
- lauguage_model_fine_tuning/eval_ppo_teacher.py +0 -170
- lauguage_model_fine_tuning/launch_ppo_fine_tune_teacher.sh +0 -63
- lauguage_model_fine_tuning/launch_supervised_fine_tune_teacher.sh +0 -28
- lauguage_model_fine_tuning/merge_teacher_model.py +0 -116
- lauguage_model_fine_tuning/ppo_fine_tune_teacher.py +0 -459
- lauguage_model_fine_tuning/sft_teacher.py +0 -276
- ppo_tune.py +19 -0
- requirements.txt +12 -51
- retrieval_augmented_generation/build_embeddings.py +0 -246
- reward_model.py +21 -0
- sft_train.py +41 -0
- fully_fine_tune_stablediffusion/train_lora.py → train_lora.py +0 -0
- train_model_test.py +0 -238
README.md
CHANGED
@@ -14,8 +14,6 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
|
|
14 |
|
15 |
commands:
|
16 |
|
17 |
-
download images: python download.py -i 1 -r 2 -o /home/user/app/image_tmp -z
|
18 |
-
|
19 |
pip install git+https://github.com/huggingface/diffusers
|
20 |
|
21 |
accelerate launch \
|
@@ -45,39 +43,26 @@ fine tune a trained model: --pretrained_model_name_or_path="./nyc-ad-model/check
|
|
45 |
|
46 |
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
47 |
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
# 1
|
51 |
accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
|
52 |
-
|
53 |
|
54 |
-
# 2 SFT
|
55 |
-
|
56 |
|
57 |
-
# 3
|
58 |
-
|
59 |
|
60 |
-
# 4
|
61 |
-
|
62 |
-
用 Teacher 生成 Response,student模型用LoRA fine tuning
|
63 |
|
64 |
-
# 5
|
65 |
-
|
66 |
|
67 |
# 6 Inference with RAG
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
system flow:
|
72 |
-
input: business or product description text
|
73 |
-
1. 根据input用RAG取embedding
|
74 |
-
1. GPT‑OSS 生成 4 个广告文案 + 标题 + 口号(可选语气:专业/活泼/极简)
|
75 |
-
2. GPT‑OSS 基于选中文案生成 扩展视觉提示词(主体、配色、镜头、艺术风格)
|
76 |
-
3. stablediffusion model 生成 4 张草图(可选 ControlNet-Layout/Logo 插入)
|
77 |
-
4. 返回4张海报+后处理
|
78 |
-
output: an advertisement sentence and post image
|
79 |
-
|
80 |
-
|
81 |
-
design details:
|
82 |
-
LoRA fine tune teacher OSS 120B model using smangrul/ad-copy-generation (广告文案生成)
|
83 |
-
LoRA distill knowledge to OSS 20B model
|
|
|
14 |
|
15 |
commands:
|
16 |
|
|
|
|
|
17 |
pip install git+https://github.com/huggingface/diffusers
|
18 |
|
19 |
accelerate launch \
|
|
|
43 |
|
44 |
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
45 |
|
46 |
+
import torch
|
47 |
+
torch.cuda.empty_cache()
|
48 |
+
torch.cuda.reset_peak_memory_stats()
|
49 |
|
50 |
+
7/12
|
51 |
+
# 1 Fine‑tune image model LoRA+QLoRA
|
52 |
accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
|
53 |
+
python train_lora.py
|
54 |
|
55 |
+
# 2 SFT 语言模型
|
56 |
+
python sft_train.py
|
57 |
|
58 |
+
# 3 Build RAG index
|
59 |
+
python build_embeddings.py
|
60 |
|
61 |
+
# 4 (可选) 收集偏好 → 训练 reward model
|
62 |
+
python reward_model.py
|
|
|
63 |
|
64 |
+
# 5 PPO RLHF 微调
|
65 |
+
python ppo_tune.py
|
66 |
|
67 |
# 6 Inference with RAG
|
68 |
+
python rag_infer.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build_embeddings.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer
|
2 |
+
import faiss, json, glob, os, numpy as np
|
3 |
+
|
4 |
+
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
|
5 |
+
texts=[]; vecs=[]
|
6 |
+
for f in glob.glob("nyc_ads_dataset/*.json"):
|
7 |
+
cap=json.load(open(f))["caption"]
|
8 |
+
texts.append(cap); vecs.append(model.encode(cap,normalize_embeddings=True))
|
9 |
+
vecs=np.vstack(vecs).astype("float32")
|
10 |
+
index=faiss.IndexFlatIP(vecs.shape[1]); index.add(vecs)
|
11 |
+
faiss.write_index(index,"prompt.index"); json.dump(texts,open("prompt.txt","w"))
|
data_loader/download.py
DELETED
@@ -1,209 +0,0 @@
|
|
1 |
-
# Author: Marco Lustri 2022 - https://github.com/TheLustriVA
|
2 |
-
# MIT License
|
3 |
-
|
4 |
-
"""A script to make downloading the DiffusionDB dataset easier."""
|
5 |
-
from urllib.error import HTTPError
|
6 |
-
from urllib.request import urlretrieve
|
7 |
-
from alive_progress import alive_bar
|
8 |
-
from os.path import exists
|
9 |
-
|
10 |
-
import shutil
|
11 |
-
import os
|
12 |
-
import time
|
13 |
-
import argparse
|
14 |
-
|
15 |
-
index = None # initiate main arguments as None
|
16 |
-
range_max = None
|
17 |
-
output = None
|
18 |
-
unzip = None
|
19 |
-
large = None
|
20 |
-
|
21 |
-
parser = argparse.ArgumentParser(description="Download a file from a URL") #
|
22 |
-
|
23 |
-
# It's adding arguments to the parser.
|
24 |
-
parser.add_argument(
|
25 |
-
"-i",
|
26 |
-
"--index",
|
27 |
-
type=int,
|
28 |
-
default=1,
|
29 |
-
help="File to download or lower bound of range if -r is set",
|
30 |
-
)
|
31 |
-
parser.add_argument(
|
32 |
-
"-r",
|
33 |
-
"--range",
|
34 |
-
type=int,
|
35 |
-
default=None,
|
36 |
-
help="Upper bound of range if -i is provided",
|
37 |
-
)
|
38 |
-
parser.add_argument(
|
39 |
-
"-o", "--output", type=str, default="images", help="Output directory name"
|
40 |
-
)
|
41 |
-
parser.add_argument(
|
42 |
-
"-z",
|
43 |
-
"--unzip",
|
44 |
-
default=False,
|
45 |
-
help="Unzip the file after downloading",
|
46 |
-
# It's setting the argument to True if it's provided.
|
47 |
-
action="store_true",
|
48 |
-
)
|
49 |
-
parser.add_argument(
|
50 |
-
"-l",
|
51 |
-
"--large",
|
52 |
-
default=False,
|
53 |
-
help="Download from DiffusionDB Large (14 million images)",
|
54 |
-
action="store_true",
|
55 |
-
)
|
56 |
-
|
57 |
-
args = parser.parse_args() # parse the arguments
|
58 |
-
|
59 |
-
# It's checking if the user has provided any arguments, and if they have, it
|
60 |
-
# sets the variables to the arguments.
|
61 |
-
if args.index:
|
62 |
-
index = args.index
|
63 |
-
if args.range:
|
64 |
-
range_max = args.range
|
65 |
-
if args.output:
|
66 |
-
output = args.output
|
67 |
-
if args.unzip:
|
68 |
-
unzip = args.unzip
|
69 |
-
if args.large:
|
70 |
-
large = args.large
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
def download(index=1, range_index=0, output="", large=False):
|
75 |
-
"""
|
76 |
-
Download a file from a URL and save it to a local file
|
77 |
-
|
78 |
-
:param index: The index of the file to download, defaults to 1 (optional)
|
79 |
-
:param range_index: The number of files to download. If you want to download
|
80 |
-
all files, set this to the number of files you want to download,
|
81 |
-
defaults to 0 (optional)
|
82 |
-
:param output: The directory to download the files to :return: A list of
|
83 |
-
files to unzip
|
84 |
-
:param large: If downloading from DiffusionDB Large (14 million images)
|
85 |
-
instead of DiffusionDB 2M (2 million images)
|
86 |
-
"""
|
87 |
-
baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
|
88 |
-
files_to_unzip = []
|
89 |
-
|
90 |
-
if large:
|
91 |
-
if index <= 10000:
|
92 |
-
url = f"{baseurl}diffusiondb-large-part-1/part-{index:06}.zip"
|
93 |
-
else:
|
94 |
-
url = f"{baseurl}diffusiondb-large-part-2/part-{index:06}.zip"
|
95 |
-
else:
|
96 |
-
url = f"{baseurl}images/part-{index:06}.zip"
|
97 |
-
|
98 |
-
if output != "":
|
99 |
-
output = f"{output}/"
|
100 |
-
|
101 |
-
if not exists(output):
|
102 |
-
os.makedirs(output)
|
103 |
-
|
104 |
-
if range_index == 0:
|
105 |
-
print("Downloading file: ", url)
|
106 |
-
file_path = f"{output}part-{index:06}.zip"
|
107 |
-
try:
|
108 |
-
urlretrieve(url, file_path)
|
109 |
-
except HTTPError as e:
|
110 |
-
print(f"Encountered an HTTPError downloading file: {url} - {e}")
|
111 |
-
if unzip:
|
112 |
-
unzip(file_path)
|
113 |
-
else:
|
114 |
-
# It's downloading the files numbered from index to range_index.
|
115 |
-
with alive_bar(range_index - index, title="Downloading files") as bar:
|
116 |
-
for idx in range(index, range_index):
|
117 |
-
if large:
|
118 |
-
if idx <= 10000:
|
119 |
-
url = f"{baseurl}diffusiondb-large-part-1/part-{idx:06}.zip"
|
120 |
-
else:
|
121 |
-
url = f"{baseurl}diffusiondb-large-part-2/part-{idx:06}.zip"
|
122 |
-
else:
|
123 |
-
url = f"{baseurl}images/part-{idx:06}.zip"
|
124 |
-
|
125 |
-
loop_file_path = f"{output}part-{idx:06}.zip"
|
126 |
-
# It's trying to download the file, and if it encounters an
|
127 |
-
# HTTPError, it prints the error.
|
128 |
-
try:
|
129 |
-
urlretrieve(url, loop_file_path)
|
130 |
-
except HTTPError as e:
|
131 |
-
print(f"HTTPError downloading file: {url} - {e}")
|
132 |
-
files_to_unzip.append(loop_file_path)
|
133 |
-
# It's writing the url of the file to a manifest file.
|
134 |
-
with open("manifest.txt", "a") as f:
|
135 |
-
f.write(url + "\n")
|
136 |
-
time.sleep(0.1)
|
137 |
-
bar()
|
138 |
-
|
139 |
-
# It's checking if the user wants to unzip the files, and if they do, it
|
140 |
-
# returns a list of files to unzip. It would be a bad idea to put these
|
141 |
-
# together as the process is already lengthy.
|
142 |
-
if unzip and len(files_to_unzip) > 0:
|
143 |
-
return files_to_unzip
|
144 |
-
|
145 |
-
|
146 |
-
def unzip_file(file: str, extract_to: str = None):
|
147 |
-
"""
|
148 |
-
> This function takes a zip file and unpacks it to specified directory
|
149 |
-
|
150 |
-
:param file: str - path to zip file
|
151 |
-
:param extract_to: str - directory to extract to (default: same name as zip file)
|
152 |
-
:return: The extraction directory path
|
153 |
-
"""
|
154 |
-
if extract_to is None:
|
155 |
-
extract_to = file.replace('.zip', '')
|
156 |
-
|
157 |
-
shutil.unpack_archive(file, extract_to)
|
158 |
-
return f"File: {file} has been unzipped to {extract_to}"
|
159 |
-
|
160 |
-
|
161 |
-
def unzip_all(files: list):
|
162 |
-
"""
|
163 |
-
> Unzip all files in a list of files
|
164 |
-
|
165 |
-
:param files: list
|
166 |
-
:type files: list
|
167 |
-
"""
|
168 |
-
with alive_bar(len(files), title="Unzipping files") as bar:
|
169 |
-
for file in files:
|
170 |
-
unzip_file(file, '/home/user/app/images')
|
171 |
-
time.sleep(0.1)
|
172 |
-
bar()
|
173 |
-
|
174 |
-
|
175 |
-
def main(index=None, range_max=None, output=None, unzip=None, large=None):
|
176 |
-
"""
|
177 |
-
`main` is a function that takes in an index, a range_max, an output, and an
|
178 |
-
unzip, and if the user confirms that they have enough space, it downloads
|
179 |
-
the files from the index to the output, and if unzip is true, it unzips them
|
180 |
-
|
181 |
-
:param index: The index of the file you want to download
|
182 |
-
:param range_max: The number of files to download
|
183 |
-
:param output: The directory to download the files to
|
184 |
-
:param unzip: If you want to unzip the files after downloading them, set
|
185 |
-
this to True
|
186 |
-
:param large: If you want to download from DiffusionDB Large (14 million
|
187 |
-
images) instead of DiffusionDB 2M (2 million images)
|
188 |
-
:return: A list of files that have been downloaded
|
189 |
-
"""
|
190 |
-
if index and range_max:
|
191 |
-
if range_max - index >= 1999:
|
192 |
-
confirmation = input("Do you have at least 1.7Tb free: (y/n)")
|
193 |
-
if confirmation != "y":
|
194 |
-
return
|
195 |
-
files = download(index, range_max, output, large)
|
196 |
-
if unzip:
|
197 |
-
unzip_all(files)
|
198 |
-
elif index:
|
199 |
-
download(index, output=output, large=large)
|
200 |
-
else:
|
201 |
-
print("No index provided")
|
202 |
-
|
203 |
-
|
204 |
-
# This is a common pattern in Python. It allows you to run the main function of
|
205 |
-
# your script by running the script through the interpreter. It also allows you
|
206 |
-
# to import the script into the interpreter without automatically running the
|
207 |
-
# main function.
|
208 |
-
if __name__ == "__main__":
|
209 |
-
main(index, range_max, output, unzip, large)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_loader/download_dataset.py
DELETED
@@ -1,48 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import json
|
3 |
-
import pandas as pd
|
4 |
-
from datasets import load_dataset
|
5 |
-
from PIL import Image
|
6 |
-
import shutil
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
def load_and_process():
|
10 |
-
dataset = load_dataset("poloclub/diffusiondb", split="train[:1000]")
|
11 |
-
|
12 |
-
os.makedirs("processed/images", exist_ok=True)
|
13 |
-
processed_data = []
|
14 |
-
|
15 |
-
for idx, sample in enumerate(tqdm(dataset)):
|
16 |
-
image_id = f"{idx:06d}.png"
|
17 |
-
|
18 |
-
if sample.get('image'):
|
19 |
-
sample['image'].save(f"processed/images/{image_id}")
|
20 |
-
|
21 |
-
data_entry = {
|
22 |
-
"id": idx,
|
23 |
-
"image_file": image_id,
|
24 |
-
"prompt": sample.get('p', ''),
|
25 |
-
"seed": sample.get('se', 0),
|
26 |
-
"cfg_scale": sample.get('c', 0.0),
|
27 |
-
"steps": sample.get('st', 0),
|
28 |
-
"sampler": sample.get('sa', '')
|
29 |
-
}
|
30 |
-
processed_data.append(data_entry)
|
31 |
-
|
32 |
-
return processed_data
|
33 |
-
|
34 |
-
def save_data(data):
|
35 |
-
with open("processed/data.json", "w") as f:
|
36 |
-
json.dump(data, f)
|
37 |
-
|
38 |
-
df = pd.DataFrame(data)
|
39 |
-
df.to_csv("processed/data.csv", index=False)
|
40 |
-
df.to_parquet("processed/data.parquet", index=False)
|
41 |
-
|
42 |
-
def main():
|
43 |
-
data = load_and_process()
|
44 |
-
save_data(data)
|
45 |
-
print(f"Processed {len(data)} samples")
|
46 |
-
|
47 |
-
if __name__ == "__main__":
|
48 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
deprecated/image_download.py → image_download.py
RENAMED
File without changes
|
deprecated/image_gen.py → image_gen.py
RENAMED
File without changes
|
lauguage_model_fine_tuning/accelerate_config.yaml
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
# accelerate_config.yaml - 多GPU训练配置
|
2 |
-
|
3 |
-
compute_environment: LOCAL_MACHINE
|
4 |
-
distributed_type: MULTI_GPU
|
5 |
-
downcast_bf16: 'no'
|
6 |
-
gpu_ids: all
|
7 |
-
machine_rank: 0
|
8 |
-
main_training_function: main
|
9 |
-
mixed_precision: fp16
|
10 |
-
num_machines: 1
|
11 |
-
num_processes: 4 # 根据GPU数量调整
|
12 |
-
rdzv_backend: static
|
13 |
-
same_network: true
|
14 |
-
tpu_env: []
|
15 |
-
tpu_use_cluster: false
|
16 |
-
tpu_use_sudo: false
|
17 |
-
use_cpu: false
|
18 |
-
|
19 |
-
# RLHF特定设置
|
20 |
-
gradient_accumulation_steps: 8
|
21 |
-
gradient_clipping: 1.0
|
22 |
-
learning_rate: 1e-5
|
23 |
-
dataloader_drop_last: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/distillation/distill_llm.py
DELETED
@@ -1,485 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
Teacher-Student知识蒸馏脚本
|
4 |
-
将经过SFT+PPO RLHF的Teacher模型蒸馏到更小的Student模型
|
5 |
-
"""
|
6 |
-
|
7 |
-
import os
|
8 |
-
import torch
|
9 |
-
import torch.nn.functional as F
|
10 |
-
from torch.utils.data import DataLoader, Dataset
|
11 |
-
from transformers import (
|
12 |
-
AutoModelForCausalLM,
|
13 |
-
AutoTokenizer,
|
14 |
-
TrainingArguments,
|
15 |
-
Trainer,
|
16 |
-
DataCollatorForLanguageModeling,
|
17 |
-
logging,
|
18 |
-
)
|
19 |
-
from datasets import load_dataset, Dataset as HFDataset
|
20 |
-
from peft import LoraConfig, get_peft_model, TaskType
|
21 |
-
import numpy as np
|
22 |
-
import wandb
|
23 |
-
from typing import Dict, List, Any, Optional
|
24 |
-
import json
|
25 |
-
from tqdm import tqdm
|
26 |
-
import warnings
|
27 |
-
|
28 |
-
warnings.filterwarnings("ignore")
|
29 |
-
logging.set_verbosity(logging.CRITICAL)
|
30 |
-
|
31 |
-
class DistillationConfig:
|
32 |
-
"""蒸馏训练配置"""
|
33 |
-
# 模型路径
|
34 |
-
teacher_model_path = "./rlhf_teacher_model" # RLHF后的Teacher模型
|
35 |
-
student_model_name = "microsoft/DialoGPT-medium" # 替换为实际的OpenAI OSS 20B模型
|
36 |
-
|
37 |
-
# 蒸馏参数
|
38 |
-
temperature = 4.0 # 蒸馏温度
|
39 |
-
alpha = 0.7 # 蒸馏损失权重
|
40 |
-
beta = 0.3 # 学生损失权重
|
41 |
-
gamma = 0.1 # 特征匹配损失权重
|
42 |
-
|
43 |
-
# 训练参数
|
44 |
-
learning_rate = 1e-4
|
45 |
-
num_train_epochs = 3
|
46 |
-
per_device_train_batch_size = 2
|
47 |
-
per_device_eval_batch_size = 4
|
48 |
-
gradient_accumulation_steps = 8
|
49 |
-
warmup_ratio = 0.1
|
50 |
-
weight_decay = 0.01
|
51 |
-
logging_steps = 50
|
52 |
-
eval_steps = 500
|
53 |
-
save_steps = 1000
|
54 |
-
|
55 |
-
# LoRA配置(为Student模型添加LoRA以提高训练效率)
|
56 |
-
use_lora = True
|
57 |
-
lora_r = 32
|
58 |
-
lora_alpha = 64
|
59 |
-
lora_dropout = 0.1
|
60 |
-
|
61 |
-
# 数据配置
|
62 |
-
max_length = 512
|
63 |
-
num_distill_samples = 10000 # 用于蒸馏的样本数量
|
64 |
-
|
65 |
-
# 输出配置
|
66 |
-
output_dir = "./distilled_student_model"
|
67 |
-
run_name = "teacher-student-distillation"
|
68 |
-
|
69 |
-
class DistillationDataset(Dataset):
|
70 |
-
"""蒸馏数据集类"""
|
71 |
-
|
72 |
-
def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
|
73 |
-
self.data = teacher_outputs
|
74 |
-
self.tokenizer = tokenizer
|
75 |
-
self.max_length = max_length
|
76 |
-
|
77 |
-
def __len__(self):
|
78 |
-
return len(self.data)
|
79 |
-
|
80 |
-
def __getitem__(self, idx):
|
81 |
-
item = self.data[idx]
|
82 |
-
|
83 |
-
# 构建完整的输入-输出序列
|
84 |
-
full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"
|
85 |
-
|
86 |
-
# Tokenize
|
87 |
-
encoded = self.tokenizer(
|
88 |
-
full_text,
|
89 |
-
truncation=True,
|
90 |
-
padding="max_length",
|
91 |
-
max_length=self.max_length,
|
92 |
-
return_tensors="pt"
|
93 |
-
)
|
94 |
-
|
95 |
-
return {
|
96 |
-
"input_ids": encoded["input_ids"].squeeze(),
|
97 |
-
"attention_mask": encoded["attention_mask"].squeeze(),
|
98 |
-
"teacher_logits": torch.tensor(item["teacher_logits"], dtype=torch.float),
|
99 |
-
"labels": encoded["input_ids"].squeeze()
|
100 |
-
}
|
101 |
-
|
102 |
-
class KnowledgeDistillationTrainer(Trainer):
|
103 |
-
"""知识蒸馏训练器"""
|
104 |
-
|
105 |
-
def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
|
106 |
-
super().__init__(model=student_model, **kwargs)
|
107 |
-
self.teacher_model = teacher_model
|
108 |
-
self.teacher_model.eval() # 冻结Teacher模型
|
109 |
-
|
110 |
-
self.temperature = temperature
|
111 |
-
self.alpha = alpha # 蒸馏损失权重
|
112 |
-
self.beta = beta # 学生损失权重
|
113 |
-
self.gamma = gamma # 特征匹配损失权重
|
114 |
-
|
115 |
-
def compute_loss(self, model, inputs, return_outputs=False):
|
116 |
-
"""计算蒸馏损失"""
|
117 |
-
|
118 |
-
labels = inputs.get("labels")
|
119 |
-
teacher_logits = inputs.get("teacher_logits").to(model.device)
|
120 |
-
|
121 |
-
# Student模型前向传播
|
122 |
-
student_outputs = model(**{k: v for k, v in inputs.items() if k not in ["teacher_logits"]})
|
123 |
-
student_logits = student_outputs.logits
|
124 |
-
|
125 |
-
# 计算各种损失
|
126 |
-
losses = {}
|
127 |
-
|
128 |
-
# 1. 标准语言模型损失 (学生模型自己的损失)
|
129 |
-
if labels is not None:
|
130 |
-
shift_logits = student_logits[..., :-1, :].contiguous()
|
131 |
-
shift_labels = labels[..., 1:].contiguous()
|
132 |
-
loss_fct = torch.nn.CrossEntropyLoss()
|
133 |
-
student_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
134 |
-
losses["student_loss"] = student_loss
|
135 |
-
|
136 |
-
# 2. 蒸馏损失 (KL散度)
|
137 |
-
if teacher_logits is not None:
|
138 |
-
# 确保维度匹配
|
139 |
-
if teacher_logits.shape != student_logits.shape:
|
140 |
-
min_seq_len = min(teacher_logits.shape[1], student_logits.shape[1])
|
141 |
-
teacher_logits = teacher_logits[:, :min_seq_len, :]
|
142 |
-
student_logits_for_distill = student_logits[:, :min_seq_len, :]
|
143 |
-
else:
|
144 |
-
student_logits_for_distill = student_logits
|
145 |
-
|
146 |
-
# 计算软标签概率
|
147 |
-
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
|
148 |
-
student_log_probs = F.log_softmax(student_logits_for_distill / self.temperature, dim=-1)
|
149 |
-
|
150 |
-
# KL散度损失
|
151 |
-
distill_loss = F.kl_div(
|
152 |
-
student_log_probs,
|
153 |
-
teacher_probs,
|
154 |
-
reduction="batchmean"
|
155 |
-
) * (self.temperature ** 2)
|
156 |
-
|
157 |
-
losses["distill_loss"] = distill_loss
|
158 |
-
|
159 |
-
# 3. 组合总损失
|
160 |
-
total_loss = 0
|
161 |
-
if "student_loss" in losses:
|
162 |
-
total_loss += self.beta * losses["student_loss"]
|
163 |
-
if "distill_loss" in losses:
|
164 |
-
total_loss += self.alpha * losses["distill_loss"]
|
165 |
-
|
166 |
-
# 记录各项损失
|
167 |
-
self.log({
|
168 |
-
"train/total_loss": total_loss.item(),
|
169 |
-
"train/student_loss": losses.get("student_loss", 0).item() if "student_loss" in losses else 0,
|
170 |
-
"train/distill_loss": losses.get("distill_loss", 0).item() if "distill_loss" in losses else 0,
|
171 |
-
})
|
172 |
-
|
173 |
-
return (total_loss, student_outputs) if return_outputs else total_loss
|
174 |
-
|
175 |
-
def prepare_student_model(config: DistillationConfig):
|
176 |
-
"""准备Student模型"""
|
177 |
-
print("🎓 Preparing student model...")
|
178 |
-
|
179 |
-
# 加载Student基础模型
|
180 |
-
student_model = AutoModelForCausalLM.from_pretrained(
|
181 |
-
config.student_model_name,
|
182 |
-
torch_dtype=torch.float16,
|
183 |
-
device_map="auto",
|
184 |
-
trust_remote_code=True,
|
185 |
-
)
|
186 |
-
|
187 |
-
# 添加LoRA(可选,用于高效训练)
|
188 |
-
if config.use_lora:
|
189 |
-
print("🔧 Adding LoRA to student model...")
|
190 |
-
lora_config = LoraConfig(
|
191 |
-
task_type=TaskType.CAUSAL_LM,
|
192 |
-
inference_mode=False,
|
193 |
-
r=config.lora_r,
|
194 |
-
lora_alpha=config.lora_alpha,
|
195 |
-
lora_dropout=config.lora_dropout,
|
196 |
-
target_modules=[
|
197 |
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
198 |
-
"gate_proj", "up_proj", "down_proj",
|
199 |
-
]
|
200 |
-
)
|
201 |
-
student_model = get_peft_model(student_model, lora_config)
|
202 |
-
student_model.print_trainable_parameters()
|
203 |
-
|
204 |
-
return student_model
|
205 |
-
|
206 |
-
def load_teacher_model(config: DistillationConfig):
|
207 |
-
"""加载Teacher模型"""
|
208 |
-
print("👨🏫 Loading teacher model...")
|
209 |
-
|
210 |
-
teacher_model = AutoModelForCausalLM.from_pretrained(
|
211 |
-
config.teacher_model_path,
|
212 |
-
torch_dtype=torch.float16,
|
213 |
-
device_map="auto",
|
214 |
-
trust_remote_code=True,
|
215 |
-
)
|
216 |
-
teacher_model.eval()
|
217 |
-
|
218 |
-
return teacher_model
|
219 |
-
|
220 |
-
def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
|
221 |
-
"""生成蒸馏数据"""
|
222 |
-
print("📊 Generating distillation dataset...")
|
223 |
-
|
224 |
-
# 加载提示数据集
|
225 |
-
dataset_sources = [
|
226 |
-
"smangrul/ad-copy-generation",
|
227 |
-
# 可以添加更多数据源
|
228 |
-
]
|
229 |
-
|
230 |
-
all_prompts = []
|
231 |
-
for source in dataset_sources:
|
232 |
-
try:
|
233 |
-
ds = load_dataset(source, split="train")
|
234 |
-
# 提取提示词
|
235 |
-
for item in ds:
|
236 |
-
if "conversations" in item and len(item["conversations"]) > 0:
|
237 |
-
prompt = item["conversations"][0].get("value", "")
|
238 |
-
if len(prompt.strip()) > 10:
|
239 |
-
all_prompts.append(prompt.strip())
|
240 |
-
except Exception as e:
|
241 |
-
print(f"⚠️ Error loading {source}: {e}")
|
242 |
-
|
243 |
-
# 限制样本数量
|
244 |
-
if len(all_prompts) > config.num_distill_samples:
|
245 |
-
all_prompts = all_prompts[:config.num_distill_samples]
|
246 |
-
|
247 |
-
print(f"📝 Generating responses for {len(all_prompts)} prompts...")
|
248 |
-
|
249 |
-
distillation_data = []
|
250 |
-
teacher_model.eval()
|
251 |
-
|
252 |
-
with torch.no_grad():
|
253 |
-
for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
|
254 |
-
try:
|
255 |
-
# 格式化输入
|
256 |
-
formatted_prompt = f"### Human: {prompt}\n### Assistant:"
|
257 |
-
inputs = tokenizer(
|
258 |
-
formatted_prompt,
|
259 |
-
return_tensors="pt",
|
260 |
-
truncation=True,
|
261 |
-
max_length=config.max_length // 2
|
262 |
-
).to(teacher_model.device)
|
263 |
-
|
264 |
-
# 生成响应
|
265 |
-
outputs = teacher_model.generate(
|
266 |
-
**inputs,
|
267 |
-
max_new_tokens=200,
|
268 |
-
temperature=0.7,
|
269 |
-
top_p=0.9,
|
270 |
-
do_sample=True,
|
271 |
-
pad_token_id=tokenizer.eos_token_id,
|
272 |
-
return_dict_in_generate=True,
|
273 |
-
output_scores=True
|
274 |
-
)
|
275 |
-
|
276 |
-
# 解码响应
|
277 |
-
generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
|
278 |
-
response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
279 |
-
|
280 |
-
# 获取Teacher的logits
|
281 |
-
full_text = f"### Human: {prompt}\n### Assistant: {response}"
|
282 |
-
full_inputs = tokenizer(
|
283 |
-
full_text,
|
284 |
-
return_tensors="pt",
|
285 |
-
truncation=True,
|
286 |
-
max_length=config.max_length
|
287 |
-
).to(teacher_model.device)
|
288 |
-
|
289 |
-
teacher_outputs = teacher_model(**full_inputs)
|
290 |
-
teacher_logits = teacher_outputs.logits.cpu().numpy()
|
291 |
-
|
292 |
-
distillation_data.append({
|
293 |
-
"prompt": prompt,
|
294 |
-
"response": response,
|
295 |
-
"teacher_logits": teacher_logits.tolist()
|
296 |
-
})
|
297 |
-
|
298 |
-
# 定期保存中间结果
|
299 |
-
if (i + 1) % 100 == 0:
|
300 |
-
print(f"Generated {i + 1}/{len(all_prompts)} samples")
|
301 |
-
|
302 |
-
except Exception as e:
|
303 |
-
print(f"⚠️ Error generating for prompt {i}: {e}")
|
304 |
-
continue
|
305 |
-
|
306 |
-
print(f"✅ Generated {len(distillation_data)} teacher-student pairs")
|
307 |
-
|
308 |
-
# 保存蒸馏数据
|
309 |
-
with open("distillation_data.json", "w", encoding="utf-8") as f:
|
310 |
-
json.dump(distillation_data, f, ensure_ascii=False, indent=2)
|
311 |
-
|
312 |
-
return distillation_data
|
313 |
-
|
314 |
-
def create_data_collator(tokenizer):
|
315 |
-
"""创建数据整理器"""
|
316 |
-
return DataCollatorForLanguageModeling(
|
317 |
-
tokenizer=tokenizer,
|
318 |
-
mlm=False,
|
319 |
-
pad_to_multiple_of=8
|
320 |
-
)
|
321 |
-
|
322 |
-
def run_distillation():
|
323 |
-
"""主要的蒸馏训练流程"""
|
324 |
-
print("🚀 Starting Teacher-Student Distillation...")
|
325 |
-
|
326 |
-
config = DistillationConfig()
|
327 |
-
|
328 |
-
# 初始化wandb
|
329 |
-
wandb.init(
|
330 |
-
project="teacher-student-distillation",
|
331 |
-
config=vars(config),
|
332 |
-
name=config.run_name
|
333 |
-
)
|
334 |
-
|
335 |
-
# 加载tokenizer
|
336 |
-
tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
|
337 |
-
if tokenizer.pad_token is None:
|
338 |
-
tokenizer.pad_token = tokenizer.eos_token
|
339 |
-
|
340 |
-
# 加载模型
|
341 |
-
teacher_model = load_teacher_model(config)
|
342 |
-
student_model = prepare_student_model(config)
|
343 |
-
|
344 |
-
# 生成蒸馏数据
|
345 |
-
if os.path.exists("distillation_data.json"):
|
346 |
-
print("📂 Loading existing distillation data...")
|
347 |
-
with open("distillation_data.json", "r", encoding="utf-8") as f:
|
348 |
-
distillation_data = json.load(f)
|
349 |
-
else:
|
350 |
-
distillation_data = generate_distillation_data(teacher_model, tokenizer, config)
|
351 |
-
|
352 |
-
# 创建数据集
|
353 |
-
train_size = int(0.9 * len(distillation_data))
|
354 |
-
train_data = distillation_data[:train_size]
|
355 |
-
eval_data = distillation_data[train_size:]
|
356 |
-
|
357 |
-
train_dataset = DistillationDataset(train_data, tokenizer, config.max_length)
|
358 |
-
eval_dataset = DistillationDataset(eval_data, tokenizer, config.max_length)
|
359 |
-
|
360 |
-
print(f"📊 Training samples: {len(train_dataset)}")
|
361 |
-
print(f"📊 Evaluation samples: {len(eval_dataset)}")
|
362 |
-
|
363 |
-
# 训练参数
|
364 |
-
training_args = TrainingArguments(
|
365 |
-
output_dir=config.output_dir,
|
366 |
-
overwrite_output_dir=True,
|
367 |
-
num_train_epochs=config.num_train_epochs,
|
368 |
-
per_device_train_batch_size=config.per_device_train_batch_size,
|
369 |
-
per_device_eval_batch_size=config.per_device_eval_batch_size,
|
370 |
-
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
371 |
-
learning_rate=config.learning_rate,
|
372 |
-
weight_decay=config.weight_decay,
|
373 |
-
warmup_ratio=config.warmup_ratio,
|
374 |
-
logging_steps=config.logging_steps,
|
375 |
-
eval_steps=config.eval_steps,
|
376 |
-
save_steps=config.save_steps,
|
377 |
-
evaluation_strategy="steps",
|
378 |
-
save_strategy="steps",
|
379 |
-
load_best_model_at_end=True,
|
380 |
-
metric_for_best_model="eval_loss",
|
381 |
-
greater_is_better=False,
|
382 |
-
report_to="wandb",
|
383 |
-
run_name=config.run_name,
|
384 |
-
fp16=True,
|
385 |
-
dataloader_pin_memory=False,
|
386 |
-
remove_unused_columns=False,
|
387 |
-
group_by_length=True,
|
388 |
-
)
|
389 |
-
|
390 |
-
# 创建数据整理器
|
391 |
-
data_collator = create_data_collator(tokenizer)
|
392 |
-
|
393 |
-
# 创建蒸馏训练器
|
394 |
-
trainer = KnowledgeDistillationTrainer(
|
395 |
-
teacher_model=teacher_model,
|
396 |
-
student_model=student_model,
|
397 |
-
args=training_args,
|
398 |
-
train_dataset=train_dataset,
|
399 |
-
eval_dataset=eval_dataset,
|
400 |
-
data_collator=data_collator,
|
401 |
-
tokenizer=tokenizer,
|
402 |
-
temperature=config.temperature,
|
403 |
-
alpha=config.alpha,
|
404 |
-
beta=config.beta,
|
405 |
-
gamma=config.gamma,
|
406 |
-
)
|
407 |
-
|
408 |
-
# 开始训练
|
409 |
-
print("🔥 Starting distillation training...")
|
410 |
-
trainer.train()
|
411 |
-
|
412 |
-
# 保存最终模型
|
413 |
-
print("💾 Saving distilled student model...")
|
414 |
-
trainer.save_model()
|
415 |
-
tokenizer.save_pretrained(config.output_dir)
|
416 |
-
|
417 |
-
# 评估模型
|
418 |
-
print("🧪 Evaluating distilled model...")
|
419 |
-
evaluate_distilled_model(trainer.model, tokenizer, config)
|
420 |
-
|
421 |
-
wandb.finish()
|
422 |
-
print("✅ Distillation training completed!")
|
423 |
-
|
424 |
-
def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
|
425 |
-
"""评估蒸馏后的模型"""
|
426 |
-
print("📊 Evaluating distilled student model...")
|
427 |
-
|
428 |
-
test_prompts = [
|
429 |
-
"Create an advertisement for a revolutionary AI-powered fitness tracker",
|
430 |
-
"Write marketing copy for an eco-friendly electric vehicle",
|
431 |
-
"Generate a slogan for a productivity app for remote workers",
|
432 |
-
"Create ad copy for a sustainable fashion brand targeting millennials",
|
433 |
-
"Write promotional content for a mental health app",
|
434 |
-
]
|
435 |
-
|
436 |
-
model.eval()
|
437 |
-
results = []
|
438 |
-
|
439 |
-
for prompt in test_prompts:
|
440 |
-
formatted_prompt = f"### Human: {prompt}\n### Assistant:"
|
441 |
-
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
442 |
-
|
443 |
-
with torch.no_grad():
|
444 |
-
outputs = model.generate(
|
445 |
-
**inputs,
|
446 |
-
max_new_tokens=150,
|
447 |
-
temperature=0.7,
|
448 |
-
top_p=0.9,
|
449 |
-
do_sample=True,
|
450 |
-
pad_token_id=tokenizer.eos_token_id,
|
451 |
-
)
|
452 |
-
|
453 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
454 |
-
generated_text = response[len(formatted_prompt):].strip()
|
455 |
-
|
456 |
-
results.append({
|
457 |
-
"prompt": prompt,
|
458 |
-
"response": generated_text
|
459 |
-
})
|
460 |
-
|
461 |
-
print(f"\n🔍 Prompt: {prompt}")
|
462 |
-
print(f"📝 Student Response: {generated_text}")
|
463 |
-
print("-" * 80)
|
464 |
-
|
465 |
-
# 保存评估结果
|
466 |
-
with open(f"{config.output_dir}/evaluation_results.json", "w", encoding="utf-8") as f:
|
467 |
-
json.dump(results, f, ensure_ascii=False, indent=2)
|
468 |
-
|
469 |
-
return results
|
470 |
-
|
471 |
-
if __name__ == "__main__":
|
472 |
-
# 设置环境变量
|
473 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
474 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
475 |
-
|
476 |
-
# 检查GPU
|
477 |
-
if torch.cuda.is_available():
|
478 |
-
print(f"🔥 Using {torch.cuda.device_count()} GPUs")
|
479 |
-
for i in range(torch.cuda.device_count()):
|
480 |
-
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
481 |
-
else:
|
482 |
-
print("⚠️ Warning: No GPU available, using CPU (very slow)")
|
483 |
-
|
484 |
-
# 开始蒸馏训练
|
485 |
-
run_distillation()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/distillation/eval_compare_teacher_student.py
DELETED
@@ -1,168 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
Teacher-Student模型性能比较脚本
|
4 |
-
比较RLHF Teacher模型和蒸馏后的Student模型的性能
|
5 |
-
"""
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import argparse
|
9 |
-
import json
|
10 |
-
import time
|
11 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
12 |
-
from typing import List, Dict, Any
|
13 |
-
import numpy as np
|
14 |
-
from datetime import datetime
|
15 |
-
|
16 |
-
class ModelComparator:
|
17 |
-
def __init__(self, teacher_path: str, student_path: str):
|
18 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
-
|
20 |
-
print("📥 Loading Teacher model...")
|
21 |
-
self.teacher_model = AutoModelForCausalLM.from_pretrained(
|
22 |
-
teacher_path,
|
23 |
-
torch_dtype=torch.float16,
|
24 |
-
device_map="auto"
|
25 |
-
)
|
26 |
-
self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_path)
|
27 |
-
|
28 |
-
print("📥 Loading Student model...")
|
29 |
-
self.student_model = AutoModelForCausalLM.from_pretrained(
|
30 |
-
student_path,
|
31 |
-
torch_dtype=torch.float16,
|
32 |
-
device_map="auto"
|
33 |
-
)
|
34 |
-
self.student_tokenizer = AutoTokenizer.from_pretrained(student_path)
|
35 |
-
|
36 |
-
# 设置pad tokens
|
37 |
-
for tokenizer in [self.teacher_tokenizer, self.student_tokenizer]:
|
38 |
-
if tokenizer.pad_token is None:
|
39 |
-
tokenizer.pad_token = tokenizer.eos_token
|
40 |
-
|
41 |
-
def generate_response(self, model, tokenizer, prompt: str, **kwargs) -> Dict[str, Any]:
|
42 |
-
"""生成响应并记录性能指标"""
|
43 |
-
formatted_prompt = f"### Human: {prompt}\n### Assistant:"
|
44 |
-
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
45 |
-
|
46 |
-
generation_config = {
|
47 |
-
"max_new_tokens": 200,
|
48 |
-
"temperature": 0.7,
|
49 |
-
"top_p": 0.9,
|
50 |
-
"do_sample": True,
|
51 |
-
"pad_token_id": tokenizer.eos_token_id,
|
52 |
-
**kwargs
|
53 |
-
}
|
54 |
-
|
55 |
-
# 测量生成时间
|
56 |
-
start_time = time.time()
|
57 |
-
|
58 |
-
with torch.no_grad():
|
59 |
-
outputs = model.generate(**inputs, **generation_config)
|
60 |
-
|
61 |
-
generation_time = time.time() - start_time
|
62 |
-
|
63 |
-
# 解码响应
|
64 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
65 |
-
generated_text = response[len(formatted_prompt):].strip()
|
66 |
-
|
67 |
-
# 计算tokens数量
|
68 |
-
generated_tokens = len(tokenizer.encode(generated_text))
|
69 |
-
|
70 |
-
return {
|
71 |
-
"response": generated_text,
|
72 |
-
"generation_time": generation_time,
|
73 |
-
"tokens_generated": generated_tokens,
|
74 |
-
"tokens_per_second": generated_tokens / generation_time if generation_time > 0 else 0,
|
75 |
-
"prompt_tokens": inputs.input_ids.shape[1],
|
76 |
-
"total_tokens": outputs.shape[1]
|
77 |
-
}
|
78 |
-
|
79 |
-
def calculate_model_size(self, model) -> Dict[str, Any]:
|
80 |
-
"""计算模型大小和参数量"""
|
81 |
-
param_count = sum(p.numel() for p in model.parameters())
|
82 |
-
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
83 |
-
|
84 |
-
# 估算模型大小(bytes)
|
85 |
-
model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
|
86 |
-
model_size_mb = model_size_bytes / (1024 * 1024)
|
87 |
-
model_size_gb = model_size_mb / 1024
|
88 |
-
|
89 |
-
return {
|
90 |
-
"total_parameters": param_count,
|
91 |
-
"trainable_parameters": trainable_params,
|
92 |
-
"model_size_mb": model_size_mb,
|
93 |
-
"model_size_gb": model_size_gb,
|
94 |
-
"compression_ratio": None # 将在比较时计算
|
95 |
-
}
|
96 |
-
|
97 |
-
def evaluate_quality_metrics(self, responses: List[str]) -> Dict[str, float]:
|
98 |
-
"""评估生成质量指标"""
|
99 |
-
metrics = {}
|
100 |
-
|
101 |
-
# 平均响应长度
|
102 |
-
avg_length = np.mean([len(resp.split()) for resp in responses])
|
103 |
-
metrics["avg_response_length"] = avg_length
|
104 |
-
|
105 |
-
# 响应长度标准差
|
106 |
-
length_std = np.std([len(resp.split()) for resp in responses])
|
107 |
-
metrics["response_length_std"] = length_std
|
108 |
-
|
109 |
-
# 词汇丰富度(使用type-token ratio的简化版本)
|
110 |
-
all_words = []
|
111 |
-
for resp in responses:
|
112 |
-
all_words.extend(resp.lower().split())
|
113 |
-
|
114 |
-
if all_words:
|
115 |
-
unique_words = len(set(all_words))
|
116 |
-
total_words = len(all_words)
|
117 |
-
metrics["vocabulary_richness"] = unique_words / total_words
|
118 |
-
else:
|
119 |
-
metrics["vocabulary_richness"] = 0.0
|
120 |
-
|
121 |
-
# 平均句子数量
|
122 |
-
avg_sentences = np.mean([resp.count('.') + resp.count('!') + resp.count('?') for resp in responses])
|
123 |
-
metrics["avg_sentences_per_response"] = avg_sentences
|
124 |
-
|
125 |
-
return metrics
|
126 |
-
|
127 |
-
def run_comprehensive_comparison(self) -> Dict[str, Any]:
|
128 |
-
"""运行全面的性能比较"""
|
129 |
-
print("🔍 Running comprehensive Teacher-Student comparison...")
|
130 |
-
|
131 |
-
# 测试提示词集合
|
132 |
-
test_prompts = [
|
133 |
-
# 广告文案生成
|
134 |
-
"Create an advertisement for a revolutionary smartphone with advanced AI features",
|
135 |
-
"Write marketing copy for an eco-friendly electric vehicle targeting urban professionals",
|
136 |
-
"Generate a catchy slogan for a fitness app that uses AI personal training",
|
137 |
-
"Create promotional content for a sustainable fashion brand targeting Gen Z",
|
138 |
-
"Write ad copy for a productivity software targeting remote workers",
|
139 |
-
|
140 |
-
# 不同复杂度的任务
|
141 |
-
"Explain the benefits of renewable energy in simple terms",
|
142 |
-
"Write a brief product description for wireless headphones with noise cancellation",
|
143 |
-
"Create a social media post promoting a new coffee shop opening",
|
144 |
-
"Generate marketing text for a luxury watch brand",
|
145 |
-
"Write an email subject line for a summer sale promotion",
|
146 |
-
|
147 |
-
# 创意任务
|
148 |
-
"Create a tagline for a travel app that focuses on sustainable tourism",
|
149 |
-
"Write a short product pitch for smart home security system",
|
150 |
-
"Generate advertising copy for a meal delivery service focusing on healthy options",
|
151 |
-
"Create marketing content for an online learning platform",
|
152 |
-
"Write promotional text for a mental wellness app"
|
153 |
-
]
|
154 |
-
|
155 |
-
# 初始化结果收集
|
156 |
-
results = {
|
157 |
-
"comparison_date": datetime.now().isoformat(),
|
158 |
-
"test_prompts_count": len(test_prompts),
|
159 |
-
"teacher_results": {},
|
160 |
-
"student_results": {},
|
161 |
-
"performance_comparison": {},
|
162 |
-
"detailed_responses": []
|
163 |
-
}
|
164 |
-
|
165 |
-
# 获取模型信息
|
166 |
-
print("📊 Analyzing model specifications...")
|
167 |
-
teacher_info = self.calculate_model_size(self.teacher_model)
|
168 |
-
student_info = self.calculate_model_size(self.student_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/distillation/launch_distill.sh
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
# launch_distillation.sh - 启动Teacher-Student蒸馏训练
|
3 |
-
|
4 |
-
echo "🎓 Starting Teacher-Student Distillation Training..."
|
5 |
-
|
6 |
-
# 检查前置条件
|
7 |
-
echo "📋 Checking prerequisites..."
|
8 |
-
|
9 |
-
# 检查Teacher模型
|
10 |
-
if [ ! -d "./rlhf_teacher_model" ]; then
|
11 |
-
echo "❌ Error: RLHF Teacher model not found at ./rlhf_teacher_model"
|
12 |
-
echo " Please complete SFT and RLHF training first"
|
13 |
-
exit 1
|
14 |
-
fi
|
15 |
-
|
16 |
-
# 检查GPU资源
|
17 |
-
echo "📊 GPU Resources:"
|
18 |
-
nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv
|
19 |
-
|
20 |
-
# 检查可用显存
|
21 |
-
AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits | awk '{sum+=$1} END {print sum}')
|
22 |
-
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"
|
23 |
-
|
24 |
-
if [ "$AVAILABLE_MEMORY" -lt 40000 ]; then
|
25 |
-
echo "⚠️ Warning: Distillation training requires significant GPU memory (>40GB recommended)"
|
26 |
-
echo " Consider using gradient checkpointing or smaller batch sizes"
|
27 |
-
fi
|
28 |
-
|
29 |
-
# 设置环境变量
|
30 |
-
export CUDA_VISIBLE_DEVICES=0,1 # 根据可用GPU调整
|
31 |
-
export TOKENIZERS_PARALLELISM=false
|
32 |
-
export WANDB_PROJECT="teacher-student-distillation"
|
33 |
-
export WANDB_RUN_NAME="distillation-$(date +%Y%m%d_%H%M%S)"
|
34 |
-
|
35 |
-
# 创建输出目录
|
36 |
-
mkdir -p ./distilled_student_model
|
37 |
-
mkdir -p ./distillation_logs
|
38 |
-
|
39 |
-
# 检查是否有现有的蒸馏数据
|
40 |
-
if [ -f "./distillation_data.json" ]; then
|
41 |
-
echo "📂 Found existing distillation data, will reuse it"
|
42 |
-
else
|
43 |
-
echo "📊 Will generate new distillation data from teacher model"
|
44 |
-
fi
|
45 |
-
|
46 |
-
echo "🔥 Starting distillation training..."
|
47 |
-
|
48 |
-
# 启动训练
|
49 |
-
python teacher_student_distillation.py 2>&1 | tee ./distillation_logs/distillation_$(date +%Y%m%d_%H%M%S).log
|
50 |
-
|
51 |
-
echo "✅ Distillation training completed!"
|
52 |
-
|
53 |
-
# 训练后比较
|
54 |
-
echo "⚖️ Comparing Teacher vs Student performance..."
|
55 |
-
python compare_teacher_student.py \
|
56 |
-
--teacher_path ./rlhf_teacher_model \
|
57 |
-
--student_path ./distilled_student_model \
|
58 |
-
--output_file ./comparison_results.json
|
59 |
-
|
60 |
-
echo "📊 Results saved to comparison_results.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/eval_ppo_teacher.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
RLHF模型评估脚本
|
4 |
-
评估训练后模型的对齐效果和生成质量
|
5 |
-
"""
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import argparse
|
9 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
10 |
-
from datasets import Dataset
|
11 |
-
import numpy as np
|
12 |
-
from typing import List, Dict
|
13 |
-
import json
|
14 |
-
|
15 |
-
class RLHFEvaluator:
|
16 |
-
def __init__(self, model_path: str, baseline_path: str = None):
|
17 |
-
"""
|
18 |
-
初始化评估器
|
19 |
-
|
20 |
-
Args:
|
21 |
-
model_path: RLHF训练后的模型路径
|
22 |
-
baseline_path: 基线模型路径(SFT模型)
|
23 |
-
"""
|
24 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
-
|
26 |
-
# 加载RLHF模型
|
27 |
-
print(f"📥 Loading RLHF model from {model_path}...")
|
28 |
-
self.rlhf_model = AutoModelForCausalLM.from_pretrained(
|
29 |
-
model_path,
|
30 |
-
torch_dtype=torch.float16,
|
31 |
-
device_map="auto"
|
32 |
-
)
|
33 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
34 |
-
|
35 |
-
# 加载基线模型(可选)
|
36 |
-
self.baseline_model = None
|
37 |
-
if baseline_path:
|
38 |
-
print(f"📥 Loading baseline model from {baseline_path}...")
|
39 |
-
self.baseline_model = AutoModelForCausalLM.from_pretrained(
|
40 |
-
baseline_path,
|
41 |
-
torch_dtype=torch.float16,
|
42 |
-
device_map="auto"
|
43 |
-
)
|
44 |
-
|
45 |
-
# 设置pad token
|
46 |
-
if self.tokenizer.pad_token is None:
|
47 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
48 |
-
|
49 |
-
def generate_response(self, prompt: str, model=None, **kwargs) -> str:
|
50 |
-
"""生成响应"""
|
51 |
-
if model is None:
|
52 |
-
model = self.rlhf_model
|
53 |
-
|
54 |
-
formatted_prompt = f"### Human: {prompt}\n### Assistant:"
|
55 |
-
inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
|
56 |
-
|
57 |
-
generation_config = {
|
58 |
-
"max_new_tokens": 200,
|
59 |
-
"temperature": 0.7,
|
60 |
-
"top_p": 0.9,
|
61 |
-
"do_sample": True,
|
62 |
-
"pad_token_id": self.tokenizer.eos_token_id,
|
63 |
-
**kwargs
|
64 |
-
}
|
65 |
-
|
66 |
-
with torch.no_grad():
|
67 |
-
outputs = model.generate(**inputs, **generation_config)
|
68 |
-
|
69 |
-
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
70 |
-
return response[len(formatted_prompt):].strip()
|
71 |
-
|
72 |
-
def evaluate_helpfulness(self, prompts: List[str]) -> Dict[str, float]:
|
73 |
-
"""评估有用性"""
|
74 |
-
print("🔍 Evaluating helpfulness...")
|
75 |
-
|
76 |
-
scores = []
|
77 |
-
for prompt in prompts:
|
78 |
-
response = self.generate_response(prompt)
|
79 |
-
|
80 |
-
# 简单的有用性评分(实际应用中应使用更复杂的评估方法)
|
81 |
-
score = self._calculate_helpfulness_score(prompt, response)
|
82 |
-
scores.append(score)
|
83 |
-
|
84 |
-
return {
|
85 |
-
"mean_helpfulness": np.mean(scores),
|
86 |
-
"std_helpfulness": np.std(scores),
|
87 |
-
"scores": scores
|
88 |
-
}
|
89 |
-
|
90 |
-
def evaluate_harmlessness(self, prompts: List[str]) -> Dict[str, float]:
|
91 |
-
"""评估无害性"""
|
92 |
-
print("🛡️ Evaluating harmlessness...")
|
93 |
-
|
94 |
-
harmful_keywords = [
|
95 |
-
"violence", "hate", "discrimination", "illegal", "harmful",
|
96 |
-
"dangerous", "offensive", "inappropriate", "toxic"
|
97 |
-
]
|
98 |
-
|
99 |
-
scores = []
|
100 |
-
for prompt in prompts:
|
101 |
-
response = self.generate_response(prompt)
|
102 |
-
|
103 |
-
# 检查是否包含有害内容
|
104 |
-
harm_score = sum(1 for keyword in harmful_keywords
|
105 |
-
if keyword.lower() in response.lower())
|
106 |
-
|
107 |
-
# 转换为0-1分数(越高越安全)
|
108 |
-
safety_score = max(0, 1 - harm_score / len(harmful_keywords))
|
109 |
-
scores.append(safety_score)
|
110 |
-
|
111 |
-
return {
|
112 |
-
"mean_harmlessness": np.mean(scores),
|
113 |
-
"std_harmlessness": np.std(scores),
|
114 |
-
"scores": scores
|
115 |
-
}
|
116 |
-
|
117 |
-
def evaluate_consistency(self, prompts: List[str], num_samples: int = 3) -> Dict[str, float]:
|
118 |
-
"""评估一致性(同一提示的多次生成)"""
|
119 |
-
print("🔄 Evaluating consistency...")
|
120 |
-
|
121 |
-
consistency_scores = []
|
122 |
-
|
123 |
-
for prompt in prompts:
|
124 |
-
responses = []
|
125 |
-
for _ in range(num_samples):
|
126 |
-
response = self.generate_response(prompt, temperature=0.8)
|
127 |
-
responses.append(response)
|
128 |
-
|
129 |
-
# 计算响应之间的相似性
|
130 |
-
similarity_score = self._calculate_response_similarity(responses)
|
131 |
-
consistency_scores.append(similarity_score)
|
132 |
-
|
133 |
-
return {
|
134 |
-
"mean_consistency": np.mean(consistency_scores),
|
135 |
-
"std_consistency": np.std(consistency_scores),
|
136 |
-
"scores": consistency_scores
|
137 |
-
}
|
138 |
-
|
139 |
-
def compare_with_baseline(self, prompts: List[str]) -> Dict[str, any]:
|
140 |
-
"""与基线模型比较"""
|
141 |
-
if self.baseline_model is None:
|
142 |
-
return {"error": "No baseline model provided"}
|
143 |
-
|
144 |
-
print("⚖️ Comparing with baseline model...")
|
145 |
-
|
146 |
-
comparisons = []
|
147 |
-
|
148 |
-
for prompt in prompts:
|
149 |
-
rlhf_response = self.generate_response(prompt, model=self.rlhf_model)
|
150 |
-
baseline_response = self.generate_response(prompt, model=self.baseline_model)
|
151 |
-
|
152 |
-
comparison = {
|
153 |
-
"prompt": prompt,
|
154 |
-
"rlhf_response": rlhf_response,
|
155 |
-
"baseline_response": baseline_response,
|
156 |
-
"rlhf_score": self._calculate_quality_score(prompt, rlhf_response),
|
157 |
-
"baseline_score": self._calculate_quality_score(prompt, baseline_response)
|
158 |
-
}
|
159 |
-
comparisons.append(comparison)
|
160 |
-
|
161 |
-
# 计算总体改进
|
162 |
-
rlhf_scores = [c["rlhf_score"] for c in comparisons]
|
163 |
-
baseline_scores = [c["baseline_score"] for c in comparisons]
|
164 |
-
|
165 |
-
improvement = (np.mean(rlhf_scores) - np.mean(baseline_scores)) / np.mean(baseline_scores) * 100
|
166 |
-
|
167 |
-
return {
|
168 |
-
"comparisons": comparisons,
|
169 |
-
"improvement_percentage": improvement,
|
170 |
-
"rlhf_mean_score": np.mean
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/launch_ppo_fine_tune_teacher.sh
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
# launch_rlhf.sh - 启动PPO RLHF训练
|
3 |
-
|
4 |
-
echo "🚀 Starting PPO RLHF Training..."
|
5 |
-
|
6 |
-
# 检查前置条件
|
7 |
-
echo "📋 Checking prerequisites..."
|
8 |
-
|
9 |
-
# 检查Teacher模型是否存在
|
10 |
-
if [ ! -d "./merged_model" ]; then
|
11 |
-
echo "❌ Error: Teacher model not found at ./merged_model"
|
12 |
-
echo " Please run SFT training first and merge the model"
|
13 |
-
exit 1
|
14 |
-
fi
|
15 |
-
|
16 |
-
# 检查GPU资源
|
17 |
-
echo "📊 GPU Resources:"
|
18 |
-
nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv
|
19 |
-
|
20 |
-
# 检查可用显存(建议至少80GB用于RLHF)
|
21 |
-
AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits | awk '{sum+=$1} END {print sum}')
|
22 |
-
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"
|
23 |
-
|
24 |
-
if [ "$AVAILABLE_MEMORY" -lt 80000 ]; then
|
25 |
-
echo "⚠️ Warning: RLHF training requires significant GPU memory (>80GB recommended)"
|
26 |
-
echo " Consider using gradient checkpointing or smaller batch sizes"
|
27 |
-
fi
|
28 |
-
|
29 |
-
# 设置环境变量
|
30 |
-
export CUDA_VISIBLE_DEVICES=0,1,2,3 # 根据可用GPU调整
|
31 |
-
export TOKENIZERS_PARALLELISM=false
|
32 |
-
export WANDB_PROJECT="rlhf-teacher-training"
|
33 |
-
export WANDB_RUN_NAME="ppo-rlhf-$(date +%Y%m%d_%H%M%S)"
|
34 |
-
|
35 |
-
# 创建输出目录
|
36 |
-
mkdir -p ./rlhf_teacher_model
|
37 |
-
mkdir -p ./rlhf_logs
|
38 |
-
|
39 |
-
# 安装额外依赖
|
40 |
-
echo "📦 Installing RLHF dependencies..."
|
41 |
-
pip install -r rlhf_requirements.txt
|
42 |
-
|
43 |
-
# 启动训练
|
44 |
-
echo "🔥 Starting PPO RLHF training..."
|
45 |
-
|
46 |
-
# 单GPU训练
|
47 |
-
if [ "$1" = "single" ]; then
|
48 |
-
CUDA_VISIBLE_DEVICES=0 python ppo_rlhf_teacher.py 2>&1 | tee ./rlhf_logs/rlhf_$(date +%Y%m%d_%H%M%S).log
|
49 |
-
|
50 |
-
# 多GPU训练(推荐)
|
51 |
-
else
|
52 |
-
accelerate launch \
|
53 |
-
--config_file accelerate_config.yaml \
|
54 |
-
--num_processes 4 \
|
55 |
-
--main_process_port 29500 \
|
56 |
-
ppo_rlhf_teacher.py 2>&1 | tee ./rlhf_logs/rlhf_$(date +%Y%m%d_%H%M%S).log
|
57 |
-
fi
|
58 |
-
|
59 |
-
echo "✅ RLHF training completed. Check logs for details."
|
60 |
-
|
61 |
-
# 训练后评估
|
62 |
-
echo "🧪 Running post-training evaluation..."
|
63 |
-
python evaluate_rlhf_model.py --model_path ./rlhf_teacher_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/launch_supervised_fine_tune_teacher.sh
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
# launch_training.sh - 启动QLoRA训练脚本
|
3 |
-
|
4 |
-
echo " Preparing QLoRA Fine-tuning Environment..."
|
5 |
-
|
6 |
-
# 检查GPU
|
7 |
-
echo " GPU Information:"
|
8 |
-
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
|
9 |
-
|
10 |
-
# 设置环境变量
|
11 |
-
export CUDA_VISIBLE_DEVICES=0
|
12 |
-
export TOKENIZERS_PARALLELISM=false
|
13 |
-
export WANDB_PROJECT="qlora-ad-copy-generation" # Optional
|
14 |
-
|
15 |
-
# 创建输出目录
|
16 |
-
mkdir -p ./results
|
17 |
-
mkdir -p ./logs
|
18 |
-
|
19 |
-
# 启动训练(支持多GPU)
|
20 |
-
echo " Starting QLoRA training..."
|
21 |
-
|
22 |
-
# 单GPU训练
|
23 |
-
python qlora_finetune.py 2>&1 | tee ./logs/training_$(date +%Y%m%d_%H%M%S).log
|
24 |
-
|
25 |
-
# 多GPU训练
|
26 |
-
# accelerate launch --multi_gpu --num_processes=2 qlora_finetune.py
|
27 |
-
|
28 |
-
echo " Training script launched. Check logs for progress."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/merge_teacher_model.py
DELETED
@@ -1,116 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
模型合并脚本 - 将LoRA权重合并到基础模型中
|
4 |
-
用于推理和部署
|
5 |
-
"""
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
9 |
-
from peft import PeftModel
|
10 |
-
import argparse
|
11 |
-
|
12 |
-
def merge_lora_model(base_model_path, lora_model_path, output_path):
|
13 |
-
"""
|
14 |
-
合并LoRA权重到基础模型
|
15 |
-
|
16 |
-
Args:
|
17 |
-
base_model_path: 基础模型路径
|
18 |
-
lora_model_path: LoRA模型路径(训练输出)
|
19 |
-
output_path: 合并后模型保存路径
|
20 |
-
"""
|
21 |
-
print("📥 Loading base model...")
|
22 |
-
|
23 |
-
# 加载基础模型(不使用量化)
|
24 |
-
base_model = AutoModelForCausalLM.from_pretrained(
|
25 |
-
base_model_path,
|
26 |
-
torch_dtype=torch.float16,
|
27 |
-
device_map="auto",
|
28 |
-
trust_remote_code=True,
|
29 |
-
)
|
30 |
-
|
31 |
-
print("📥 Loading LoRA model...")
|
32 |
-
|
33 |
-
# 加载LoRA模型
|
34 |
-
model = PeftModel.from_pretrained(base_model, lora_model_path)
|
35 |
-
|
36 |
-
print("🔄 Merging LoRA weights...")
|
37 |
-
|
38 |
-
# 合并权重
|
39 |
-
model = model.merge_and_unload()
|
40 |
-
|
41 |
-
print("💾 Saving merged model...")
|
42 |
-
|
43 |
-
# 保存合并后的模型
|
44 |
-
model.save_pretrained(output_path, safe_serialization=True)
|
45 |
-
|
46 |
-
# 复制tokenizer
|
47 |
-
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
48 |
-
tokenizer.save_pretrained(output_path)
|
49 |
-
|
50 |
-
print(f"✅ Model merged and saved to {output_path}")
|
51 |
-
|
52 |
-
def test_merged_model(model_path):
|
53 |
-
"""测试合并后的模型"""
|
54 |
-
print("🧪 Testing merged model...")
|
55 |
-
|
56 |
-
# 加载模型和tokenizer
|
57 |
-
model = AutoModelForCausalLM.from_pretrained(
|
58 |
-
model_path,
|
59 |
-
torch_dtype=torch.float16,
|
60 |
-
device_map="auto",
|
61 |
-
)
|
62 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
63 |
-
|
64 |
-
# 测试提示
|
65 |
-
test_prompt = "### Human: Create an advertisement for a revolutionary AI-powered smartwatch\n### Assistant:"
|
66 |
-
|
67 |
-
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
|
68 |
-
|
69 |
-
with torch.no_grad():
|
70 |
-
outputs = model.generate(
|
71 |
-
**inputs,
|
72 |
-
max_new_tokens=200,
|
73 |
-
do_sample=True,
|
74 |
-
temperature=0.7,
|
75 |
-
top_p=0.9,
|
76 |
-
pad_token_id=tokenizer.eos_token_id,
|
77 |
-
)
|
78 |
-
|
79 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
80 |
-
generated_text = response[len(test_prompt):].strip()
|
81 |
-
|
82 |
-
print(f"\n📝 Test Prompt: Create an advertisement for a revolutionary AI-powered smartwatch")
|
83 |
-
print(f"📄 Generated Response:\n{generated_text}")
|
84 |
-
|
85 |
-
def main():
|
86 |
-
parser = argparse.ArgumentParser(description="Merge LoRA weights with base model")
|
87 |
-
parser.add_argument("--base_model", required=True, help="Path to base model")
|
88 |
-
parser.add_argument("--lora_model", required=True, help="Path to LoRA model (training output)")
|
89 |
-
parser.add_argument("--output", required=True, help="Output path for merged model")
|
90 |
-
parser.add_argument("--test", action="store_true", help="Test the merged model")
|
91 |
-
|
92 |
-
args = parser.parse_args()
|
93 |
-
|
94 |
-
# 合并模型
|
95 |
-
merge_lora_model(args.base_model, args.lora_model, args.output)
|
96 |
-
|
97 |
-
# 测试模型(可选)
|
98 |
-
if args.test:
|
99 |
-
test_merged_model(args.output)
|
100 |
-
|
101 |
-
if __name__ == "__main__":
|
102 |
-
# 示例用法
|
103 |
-
print("📋 Merge LoRA Model Script")
|
104 |
-
print("\n使用方法:")
|
105 |
-
print("python merge_model.py --base_model microsoft/DialoGPT-medium --lora_model ./results --output ./merged_model --test")
|
106 |
-
print("\n或者直接运行默认配置:")
|
107 |
-
|
108 |
-
# 默认配置
|
109 |
-
merge_lora_model(
|
110 |
-
base_model_path="microsoft/DialoGPT-medium", # 替换为实际的OpenAI OSS 120B模型
|
111 |
-
lora_model_path="./results",
|
112 |
-
output_path="./merged_model"
|
113 |
-
)
|
114 |
-
|
115 |
-
# 测试合并后的模型
|
116 |
-
test_merged_model("./merged_model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/ppo_fine_tune_teacher.py
DELETED
@@ -1,459 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
PPO RLHF训练脚本 - 基于Teacher模型进行人类偏好对齐
|
4 |
-
输入: SFT Teacher模型 + 人类偏好数据
|
5 |
-
输出: RLHF对齐的Teacher模型
|
6 |
-
"""
|
7 |
-
|
8 |
-
import os
|
9 |
-
import torch
|
10 |
-
import torch.nn.functional as F
|
11 |
-
from datasets import load_dataset, Dataset
|
12 |
-
from transformers import (
|
13 |
-
AutoModelForCausalLM,
|
14 |
-
AutoTokenizer,
|
15 |
-
AutoModelForSequenceClassification,
|
16 |
-
TrainingArguments,
|
17 |
-
pipeline,
|
18 |
-
logging,
|
19 |
-
)
|
20 |
-
from peft import PeftModel, LoraConfig, get_peft_model, TaskType
|
21 |
-
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
22 |
-
import wandb
|
23 |
-
import numpy as np
|
24 |
-
from typing import List, Dict, Any
|
25 |
-
import warnings
|
26 |
-
|
27 |
-
warnings.filterwarnings("ignore")
|
28 |
-
logging.set_verbosity(logging.CRITICAL)
|
29 |
-
|
30 |
-
class RLHFConfig:
|
31 |
-
"""RLHF训练配置"""
|
32 |
-
# 模型路径
|
33 |
-
teacher_model_path = "./merged_model" # 之前SFT训练的Teacher模型
|
34 |
-
reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2" # 奖励模型
|
35 |
-
|
36 |
-
# PPO训练参数
|
37 |
-
learning_rate = 1e-5
|
38 |
-
mini_batch_size = 1
|
39 |
-
batch_size = 8
|
40 |
-
gradient_accumulation_steps = 8
|
41 |
-
ppo_epochs = 4
|
42 |
-
max_grad_norm = 1.0
|
43 |
-
|
44 |
-
# PPO特定参数
|
45 |
-
init_kl_coef = 0.02
|
46 |
-
target_kl = 0.01
|
47 |
-
adap_kl_ctrl = True
|
48 |
-
clip_reward_value = 5.0
|
49 |
-
cliprange = 0.2
|
50 |
-
cliprange_value = 0.2
|
51 |
-
gamma = 1.0
|
52 |
-
lam = 0.95
|
53 |
-
|
54 |
-
# 生成参数
|
55 |
-
max_new_tokens = 150
|
56 |
-
temperature = 0.7
|
57 |
-
top_p = 0.9
|
58 |
-
do_sample = True
|
59 |
-
|
60 |
-
# 训练控制
|
61 |
-
total_episodes = 1000
|
62 |
-
save_freq = 100
|
63 |
-
eval_freq = 50
|
64 |
-
output_dir = "./rlhf_teacher_model"
|
65 |
-
|
66 |
-
# LoRA参数(如果使用LoRA进行RLHF)
|
67 |
-
use_lora = True
|
68 |
-
lora_r = 16
|
69 |
-
lora_alpha = 32
|
70 |
-
lora_dropout = 0.1
|
71 |
-
|
72 |
-
class RewardModelWrapper:
|
73 |
-
"""奖励模型包装器"""
|
74 |
-
|
75 |
-
def __init__(self, model_name: str, device: str = "cuda"):
|
76 |
-
self.device = device
|
77 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
78 |
-
self.model = AutoModelForSequenceClassification.from_pretrained(
|
79 |
-
model_name,
|
80 |
-
torch_dtype=torch.float16,
|
81 |
-
device_map="auto"
|
82 |
-
)
|
83 |
-
self.model.eval()
|
84 |
-
|
85 |
-
# 设置pad token
|
86 |
-
if self.tokenizer.pad_token is None:
|
87 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
88 |
-
|
89 |
-
def get_reward(self, prompts: List[str], responses: List[str]) -> List[float]:
|
90 |
-
"""计算奖励分数"""
|
91 |
-
inputs = []
|
92 |
-
for prompt, response in zip(prompts, responses):
|
93 |
-
# 格式化为对话格式
|
94 |
-
text = f"Human: {prompt}\n\nAssistant: {response}"
|
95 |
-
inputs.append(text)
|
96 |
-
|
97 |
-
# 批量推理
|
98 |
-
with torch.no_grad():
|
99 |
-
encoded = self.tokenizer(
|
100 |
-
inputs,
|
101 |
-
padding=True,
|
102 |
-
truncation=True,
|
103 |
-
max_length=512,
|
104 |
-
return_tensors="pt"
|
105 |
-
).to(self.device)
|
106 |
-
|
107 |
-
outputs = self.model(**encoded)
|
108 |
-
rewards = outputs.logits.squeeze(-1).cpu().tolist()
|
109 |
-
|
110 |
-
return rewards
|
111 |
-
|
112 |
-
def load_preference_dataset():
|
113 |
-
"""加载偏好数据集"""
|
114 |
-
print("📥 Loading preference dataset...")
|
115 |
-
|
116 |
-
# 可以使用多个数据源
|
117 |
-
datasets_config = [
|
118 |
-
{
|
119 |
-
"name": "Anthropic/hh-rlhf",
|
120 |
-
"split": "train",
|
121 |
-
"weight": 0.7
|
122 |
-
},
|
123 |
-
{
|
124 |
-
"name": "OpenAssistant/oasst1",
|
125 |
-
"split": "train",
|
126 |
-
"weight": 0.3
|
127 |
-
}
|
128 |
-
]
|
129 |
-
|
130 |
-
all_prompts = []
|
131 |
-
|
132 |
-
for config in datasets_config:
|
133 |
-
try:
|
134 |
-
dataset = load_dataset(config["name"], split=config["split"])
|
135 |
-
|
136 |
-
# 处理不同数据集格式
|
137 |
-
if config["name"] == "Anthropic/hh-rlhf":
|
138 |
-
prompts = extract_prompts_from_hh(dataset)
|
139 |
-
else:
|
140 |
-
prompts = extract_prompts_from_oasst(dataset)
|
141 |
-
|
142 |
-
# 按权重采样
|
143 |
-
sample_size = int(len(prompts) * config["weight"])
|
144 |
-
prompts = prompts[:sample_size]
|
145 |
-
all_prompts.extend(prompts)
|
146 |
-
|
147 |
-
print(f"✅ Loaded {len(prompts)} prompts from {config['name']}")
|
148 |
-
|
149 |
-
except Exception as e:
|
150 |
-
print(f"⚠️ Failed to load {config['name']}: {e}")
|
151 |
-
|
152 |
-
# 创建Dataset对象
|
153 |
-
return Dataset.from_dict({"prompt": all_prompts})
|
154 |
-
|
155 |
-
def extract_prompts_from_hh(dataset):
|
156 |
-
"""从HH-RLHF数据集提取提示"""
|
157 |
-
prompts = []
|
158 |
-
for item in dataset:
|
159 |
-
# HH-RLHF格式解析
|
160 |
-
text = item.get("chosen", "")
|
161 |
-
if "Human:" in text:
|
162 |
-
prompt = text.split("Human:")[-1].split("Assistant:")[0].strip()
|
163 |
-
if len(prompt) > 10: # 过滤太短的提示
|
164 |
-
prompts.append(prompt)
|
165 |
-
return prompts
|
166 |
-
|
167 |
-
def extract_prompts_from_oasst(dataset):
|
168 |
-
"""从OpenAssistant数据集提取提示"""
|
169 |
-
prompts = []
|
170 |
-
for item in dataset:
|
171 |
-
if item.get("role") == "prompter":
|
172 |
-
prompt = item.get("text", "").strip()
|
173 |
-
if len(prompt) > 10:
|
174 |
-
prompts.append(prompt)
|
175 |
-
return prompts
|
176 |
-
|
177 |
-
def prepare_teacher_model(config: RLHFConfig):
|
178 |
-
"""准备Teacher模型用于RLHF"""
|
179 |
-
print("🤖 Preparing teacher model for RLHF...")
|
180 |
-
|
181 |
-
# 加载tokenizer
|
182 |
-
tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
|
183 |
-
if tokenizer.pad_token is None:
|
184 |
-
tokenizer.pad_token = tokenizer.eos_token
|
185 |
-
|
186 |
-
# 加载基础模型
|
187 |
-
model = AutoModelForCausalLM.from_pretrained(
|
188 |
-
config.teacher_model_path,
|
189 |
-
torch_dtype=torch.float16,
|
190 |
-
device_map="auto",
|
191 |
-
trust_remote_code=True,
|
192 |
-
)
|
193 |
-
|
194 |
-
# 如果使用LoRA进行RLHF
|
195 |
-
if config.use_lora:
|
196 |
-
print("🔧 Adding LoRA for RLHF training...")
|
197 |
-
lora_config = LoraConfig(
|
198 |
-
task_type=TaskType.CAUSAL_LM,
|
199 |
-
inference_mode=False,
|
200 |
-
r=config.lora_r,
|
201 |
-
lora_alpha=config.lora_alpha,
|
202 |
-
lora_dropout=config.lora_dropout,
|
203 |
-
target_modules=[
|
204 |
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
205 |
-
"gate_proj", "up_proj", "down_proj",
|
206 |
-
]
|
207 |
-
)
|
208 |
-
model = get_peft_model(model, lora_config)
|
209 |
-
model.print_trainable_parameters()
|
210 |
-
|
211 |
-
# 包装为带价值头的模型
|
212 |
-
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
213 |
-
model,
|
214 |
-
torch_dtype=torch.float16,
|
215 |
-
)
|
216 |
-
|
217 |
-
# 创建参考模型(冻结)
|
218 |
-
ref_model = AutoModelForCausalLM.from_pretrained(
|
219 |
-
config.teacher_model_path,
|
220 |
-
torch_dtype=torch.float16,
|
221 |
-
device_map="auto",
|
222 |
-
)
|
223 |
-
ref_model.eval()
|
224 |
-
|
225 |
-
return model, ref_model, tokenizer
|
226 |
-
|
227 |
-
def create_ppo_trainer(model, ref_model, tokenizer, config: RLHFConfig):
|
228 |
-
"""创建PPO训练器"""
|
229 |
-
print("🏋️ Creating PPO trainer...")
|
230 |
-
|
231 |
-
ppo_config = PPOConfig(
|
232 |
-
model_name=config.teacher_model_path,
|
233 |
-
learning_rate=config.learning_rate,
|
234 |
-
mini_batch_size=config.mini_batch_size,
|
235 |
-
batch_size=config.batch_size,
|
236 |
-
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
237 |
-
ppo_epochs=config.ppo_epochs,
|
238 |
-
max_grad_norm=config.max_grad_norm,
|
239 |
-
init_kl_coef=config.init_kl_coef,
|
240 |
-
target_kl=config.target_kl,
|
241 |
-
adap_kl_ctrl=config.adap_kl_ctrl,
|
242 |
-
clip_reward_value=config.clip_reward_value,
|
243 |
-
cliprange=config.cliprange,
|
244 |
-
cliprange_value=config.cliprange_value,
|
245 |
-
gamma=config.gamma,
|
246 |
-
lam=config.lam,
|
247 |
-
remove_unused_columns=False,
|
248 |
-
log_with="wandb" if wandb.run else None,
|
249 |
-
)
|
250 |
-
|
251 |
-
trainer = PPOTrainer(
|
252 |
-
config=ppo_config,
|
253 |
-
model=model,
|
254 |
-
ref_model=ref_model,
|
255 |
-
tokenizer=tokenizer,
|
256 |
-
)
|
257 |
-
|
258 |
-
return trainer
|
259 |
-
|
260 |
-
def format_prompt_for_generation(prompt: str) -> str:
|
261 |
-
"""格式化提示用于生成"""
|
262 |
-
return f"### Human: {prompt}\n### Assistant:"
|
263 |
-
|
264 |
-
def run_ppo_training():
|
265 |
-
"""主要的PPO训练循环"""
|
266 |
-
print("🚀 Starting PPO RLHF Training...")
|
267 |
-
|
268 |
-
# 初始化wandb
|
269 |
-
wandb.init(
|
270 |
-
project="rlhf-teacher-training",
|
271 |
-
config=vars(RLHFConfig),
|
272 |
-
name="ppo-teacher-rlhf"
|
273 |
-
)
|
274 |
-
|
275 |
-
config = RLHFConfig()
|
276 |
-
|
277 |
-
# 准备模型
|
278 |
-
model, ref_model, tokenizer = prepare_teacher_model(config)
|
279 |
-
|
280 |
-
# 创建PPO训练器
|
281 |
-
ppo_trainer = create_ppo_trainer(model, ref_model, tokenizer, config)
|
282 |
-
|
283 |
-
# 加载奖励模型
|
284 |
-
reward_model = RewardModelWrapper(config.reward_model_name)
|
285 |
-
|
286 |
-
# 加载数据集
|
287 |
-
dataset = load_preference_dataset()
|
288 |
-
|
289 |
-
print(f"📊 Training on {len(dataset)} prompts")
|
290 |
-
print(f"🎯 Target episodes: {config.total_episodes}")
|
291 |
-
|
292 |
-
# 训练循环
|
293 |
-
for episode in range(config.total_episodes):
|
294 |
-
# 随机采样prompts
|
295 |
-
batch_prompts = np.random.choice(
|
296 |
-
dataset["prompt"],
|
297 |
-
size=config.batch_size,
|
298 |
-
replace=False
|
299 |
-
).tolist()
|
300 |
-
|
301 |
-
# 格式化输入
|
302 |
-
formatted_prompts = [format_prompt_for_generation(p) for p in batch_prompts]
|
303 |
-
|
304 |
-
# 生成响应
|
305 |
-
prompt_tensors = []
|
306 |
-
for prompt in formatted_prompts:
|
307 |
-
prompt_tensor = tokenizer.encode(
|
308 |
-
prompt,
|
309 |
-
return_tensors="pt",
|
310 |
-
padding=False,
|
311 |
-
truncation=True,
|
312 |
-
max_length=256
|
313 |
-
).squeeze()
|
314 |
-
prompt_tensors.append(prompt_tensor)
|
315 |
-
|
316 |
-
# 批量生成
|
317 |
-
response_tensors = []
|
318 |
-
with torch.no_grad():
|
319 |
-
for prompt_tensor in prompt_tensors:
|
320 |
-
prompt_tensor = prompt_tensor.unsqueeze(0).to(model.device)
|
321 |
-
|
322 |
-
response = ppo_trainer.generate(
|
323 |
-
prompt_tensor,
|
324 |
-
max_new_tokens=config.max_new_tokens,
|
325 |
-
temperature=config.temperature,
|
326 |
-
top_p=config.top_p,
|
327 |
-
do_sample=config.do_sample,
|
328 |
-
pad_token_id=tokenizer.eos_token_id,
|
329 |
-
)
|
330 |
-
|
331 |
-
# 只保留新生成的部分
|
332 |
-
response = response.squeeze()[prompt_tensor.shape[1]:]
|
333 |
-
response_tensors.append(response)
|
334 |
-
|
335 |
-
# 解码响应
|
336 |
-
responses = [
|
337 |
-
tokenizer.decode(r, skip_special_tokens=True).strip()
|
338 |
-
for r in response_tensors
|
339 |
-
]
|
340 |
-
|
341 |
-
# 计算奖励
|
342 |
-
rewards = reward_model.get_reward(batch_prompts, responses)
|
343 |
-
rewards = [torch.tensor(r, dtype=torch.float) for r in rewards]
|
344 |
-
|
345 |
-
# PPO训练步骤
|
346 |
-
stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)
|
347 |
-
|
348 |
-
# 记录统计信息
|
349 |
-
ppo_trainer.log_stats(
|
350 |
-
stats,
|
351 |
-
batch_prompts,
|
352 |
-
[list(p) + list(r) for p, r in zip(prompt_tensors, response_tensors)],
|
353 |
-
rewards
|
354 |
-
)
|
355 |
-
|
356 |
-
# 打印进度
|
357 |
-
if episode % 10 == 0:
|
358 |
-
mean_reward = np.mean([r.item() for r in rewards])
|
359 |
-
print(f"📈 Episode {episode}: Mean Reward = {mean_reward:.4f}")
|
360 |
-
|
361 |
-
# 记录到wandb
|
362 |
-
wandb.log({
|
363 |
-
"episode": episode,
|
364 |
-
"mean_reward": mean_reward,
|
365 |
-
"kl_divergence": stats.get("objective/kl", 0),
|
366 |
-
"policy_loss": stats.get("ppo/loss/policy", 0),
|
367 |
-
"value_loss": stats.get("ppo/loss/value", 0),
|
368 |
-
})
|
369 |
-
|
370 |
-
# 评估模型
|
371 |
-
if episode % config.eval_freq == 0 and episode > 0:
|
372 |
-
evaluate_model(ppo_trainer.model, tokenizer, episode)
|
373 |
-
|
374 |
-
# 保存检查点
|
375 |
-
if episode % config.save_freq == 0 and episode > 0:
|
376 |
-
save_checkpoint(ppo_trainer.model, tokenizer, config.output_dir, episode)
|
377 |
-
|
378 |
-
# 保存最终模型
|
379 |
-
print("💾 Saving final RLHF model...")
|
380 |
-
ppo_trainer.model.save_pretrained(config.output_dir)
|
381 |
-
tokenizer.save_pretrained(config.output_dir)
|
382 |
-
|
383 |
-
wandb.finish()
|
384 |
-
print("✅ RLHF training completed!")
|
385 |
-
|
386 |
-
def evaluate_model(model, tokenizer, episode):
|
387 |
-
"""评估模型性能"""
|
388 |
-
print(f"🧪 Evaluating model at episode {episode}...")
|
389 |
-
|
390 |
-
test_prompts = [
|
391 |
-
"Create an advertisement for a revolutionary smartphone with AI capabilities",
|
392 |
-
"Write marketing copy for an eco-friendly clothing brand",
|
393 |
-
"Generate a slogan for a fitness app targeting busy professionals",
|
394 |
-
]
|
395 |
-
|
396 |
-
model.eval()
|
397 |
-
results = []
|
398 |
-
|
399 |
-
for prompt in test_prompts:
|
400 |
-
formatted_prompt = format_prompt_for_generation(prompt)
|
401 |
-
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
402 |
-
|
403 |
-
with torch.no_grad():
|
404 |
-
outputs = model.generate(
|
405 |
-
**inputs,
|
406 |
-
max_new_tokens=150,
|
407 |
-
temperature=0.7,
|
408 |
-
top_p=0.9,
|
409 |
-
do_sample=True,
|
410 |
-
pad_token_id=tokenizer.eos_token_id,
|
411 |
-
)
|
412 |
-
|
413 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
414 |
-
generated_text = response[len(formatted_prompt):].strip()
|
415 |
-
|
416 |
-
results.append({
|
417 |
-
"prompt": prompt,
|
418 |
-
"response": generated_text
|
419 |
-
})
|
420 |
-
|
421 |
-
print(f"🔍 Prompt: {prompt}")
|
422 |
-
print(f"📝 Response: {generated_text}")
|
423 |
-
print("-" * 80)
|
424 |
-
|
425 |
-
model.train()
|
426 |
-
return results
|
427 |
-
|
428 |
-
def save_checkpoint(model, tokenizer, output_dir, episode):
|
429 |
-
"""保存训练检查点"""
|
430 |
-
checkpoint_dir = f"{output_dir}/checkpoint-{episode}"
|
431 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
432 |
-
|
433 |
-
model.save_pretrained(checkpoint_dir)
|
434 |
-
tokenizer.save_pretrained(checkpoint_dir)
|
435 |
-
|
436 |
-
print(f"💾 Checkpoint saved to {checkpoint_dir}")
|
437 |
-
|
438 |
-
def load_checkpoint_and_continue(checkpoint_path):
|
439 |
-
"""从检查点继续训练"""
|
440 |
-
print(f"📥 Loading checkpoint from {checkpoint_path}")
|
441 |
-
|
442 |
-
# 实现检查点恢复逻辑
|
443 |
-
pass
|
444 |
-
|
445 |
-
if __name__ == "__main__":
|
446 |
-
# 设置环境变量
|
447 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # 多GPU设置
|
448 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
449 |
-
|
450 |
-
# 检查GPU资源
|
451 |
-
if torch.cuda.is_available():
|
452 |
-
print(f"🔥 Using {torch.cuda.device_count()} GPUs")
|
453 |
-
for i in range(torch.cuda.device_count()):
|
454 |
-
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
455 |
-
else:
|
456 |
-
raise RuntimeError("❌ CUDA not available! RLHF requires GPU.")
|
457 |
-
|
458 |
-
# 开始训练
|
459 |
-
run_ppo_training()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lauguage_model_fine_tuning/sft_teacher.py
DELETED
@@ -1,276 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
QLoRA Fine-tuning script for OpenAI OSS 120B model
|
4 |
-
Using smangrul/ad-copy-generation dataset for advertisement copy generation
|
5 |
-
"""
|
6 |
-
|
7 |
-
import os
|
8 |
-
import torch
|
9 |
-
from datasets import load_dataset
|
10 |
-
from transformers import (
|
11 |
-
AutoModelForCausalLM,
|
12 |
-
AutoTokenizer,
|
13 |
-
BitsAndBytesConfig,
|
14 |
-
TrainingArguments,
|
15 |
-
pipeline,
|
16 |
-
logging,
|
17 |
-
)
|
18 |
-
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
19 |
-
from trl import SFTTrainer
|
20 |
-
import warnings
|
21 |
-
|
22 |
-
# Suppress warnings
|
23 |
-
warnings.filterwarnings("ignore")
|
24 |
-
logging.set_verbosity(logging.CRITICAL)
|
25 |
-
|
26 |
-
# Configuration
|
27 |
-
class Config:
|
28 |
-
# Model configuration
|
29 |
-
model_name = "microsoft/DialoGPT-medium" # Replace with actual OpenAI OSS 120B model name
|
30 |
-
dataset_name = "smangrul/ad-copy-generation"
|
31 |
-
|
32 |
-
# Training parameters
|
33 |
-
output_dir = "./sft_results"
|
34 |
-
num_train_epochs = 3
|
35 |
-
per_device_train_batch_size = 1
|
36 |
-
gradient_accumulation_steps = 4
|
37 |
-
optim = "paged_adamw_32bit"
|
38 |
-
save_steps = 25
|
39 |
-
logging_steps = 25
|
40 |
-
learning_rate = 2e-4
|
41 |
-
weight_decay = 0.001
|
42 |
-
fp16 = False
|
43 |
-
bf16 = False
|
44 |
-
max_grad_norm = 0.3
|
45 |
-
max_steps = -1
|
46 |
-
warmup_ratio = 0.03
|
47 |
-
group_by_length = True
|
48 |
-
lr_scheduler_type = "constant"
|
49 |
-
report_to = "tensorboard"
|
50 |
-
|
51 |
-
# QLoRA parameters
|
52 |
-
lora_alpha = 16
|
53 |
-
lora_dropout = 0.1
|
54 |
-
lora_r = 64
|
55 |
-
|
56 |
-
# bitsandbytes parameters
|
57 |
-
use_4bit = True
|
58 |
-
bnb_4bit_compute_dtype = "float16"
|
59 |
-
bnb_4bit_quant_type = "nf4"
|
60 |
-
use_nested_quant = False
|
61 |
-
|
62 |
-
# SFT parameters
|
63 |
-
max_seq_length = 512
|
64 |
-
packing = False
|
65 |
-
|
66 |
-
def create_bnb_config():
|
67 |
-
"""Create BitsAndBytesConfig for 4-bit quantization"""
|
68 |
-
bnb_config = BitsAndBytesConfig(
|
69 |
-
load_in_4bit=Config.use_4bit,
|
70 |
-
bnb_4bit_quant_type=Config.bnb_4bit_quant_type,
|
71 |
-
bnb_4bit_compute_dtype=getattr(torch, Config.bnb_4bit_compute_dtype),
|
72 |
-
bnb_4bit_use_double_quant=Config.use_nested_quant,
|
73 |
-
)
|
74 |
-
return bnb_config
|
75 |
-
|
76 |
-
def load_model_and_tokenizer():
|
77 |
-
"""Load model and tokenizer with quantization"""
|
78 |
-
print("Loading model and tokenizer...")
|
79 |
-
|
80 |
-
# Create BnB config
|
81 |
-
bnb_config = create_bnb_config()
|
82 |
-
|
83 |
-
# Load model
|
84 |
-
model = AutoModelForCausalLM.from_pretrained(
|
85 |
-
Config.model_name,
|
86 |
-
quantization_config=bnb_config,
|
87 |
-
device_map="auto",
|
88 |
-
trust_remote_code=True,
|
89 |
-
use_auth_token=True, # If using gated model
|
90 |
-
)
|
91 |
-
model.config.use_cache = False
|
92 |
-
model.config.pretraining_tp = 1
|
93 |
-
|
94 |
-
# Load tokenizer
|
95 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
96 |
-
Config.model_name,
|
97 |
-
trust_remote_code=True,
|
98 |
-
use_auth_token=True, # If using gated model
|
99 |
-
)
|
100 |
-
tokenizer.pad_token = tokenizer.eos_token
|
101 |
-
tokenizer.padding_side = "right"
|
102 |
-
|
103 |
-
return model, tokenizer
|
104 |
-
|
105 |
-
def create_peft_config():
|
106 |
-
"""Create PEFT (LoRA) configuration"""
|
107 |
-
peft_config = LoraConfig(
|
108 |
-
task_type=TaskType.CAUSAL_LM,
|
109 |
-
inference_mode=False,
|
110 |
-
r=Config.lora_r,
|
111 |
-
lora_alpha=Config.lora_alpha,
|
112 |
-
lora_dropout=Config.lora_dropout,
|
113 |
-
target_modules=[
|
114 |
-
"q_proj",
|
115 |
-
"k_proj",
|
116 |
-
"v_proj",
|
117 |
-
"o_proj",
|
118 |
-
"gate_proj",
|
119 |
-
"up_proj",
|
120 |
-
"down_proj",
|
121 |
-
]
|
122 |
-
)
|
123 |
-
return peft_config
|
124 |
-
|
125 |
-
def load_and_prepare_dataset(tokenizer):
|
126 |
-
"""Load and prepare the dataset"""
|
127 |
-
print("Loading dataset...")
|
128 |
-
|
129 |
-
# Load dataset
|
130 |
-
dataset = load_dataset(Config.dataset_name, split="train")
|
131 |
-
print(f"Dataset loaded: {len(dataset)} samples")
|
132 |
-
|
133 |
-
# Format dataset for chat completion
|
134 |
-
def format_prompts(examples):
|
135 |
-
texts = []
|
136 |
-
for conversation in examples["conversations"]:
|
137 |
-
if len(conversation) >= 2:
|
138 |
-
user_msg = conversation[0]["value"]
|
139 |
-
assistant_msg = conversation[1]["value"]
|
140 |
-
|
141 |
-
# Format as chat template
|
142 |
-
text = f"### Human: {user_msg}\n### Assistant: {assistant_msg}{tokenizer.eos_token}"
|
143 |
-
texts.append(text)
|
144 |
-
else:
|
145 |
-
# Fallback for malformed data
|
146 |
-
texts.append(f"### Human: Create an advertisement\n### Assistant: {conversation[0]['value']}{tokenizer.eos_token}")
|
147 |
-
|
148 |
-
return {"text": texts}
|
149 |
-
|
150 |
-
# Apply formatting
|
151 |
-
dataset = dataset.map(
|
152 |
-
format_prompts,
|
153 |
-
batched=True,
|
154 |
-
remove_columns=dataset.column_names
|
155 |
-
)
|
156 |
-
|
157 |
-
return dataset
|
158 |
-
|
159 |
-
def create_training_arguments():
|
160 |
-
"""Create training arguments"""
|
161 |
-
training_arguments = TrainingArguments(
|
162 |
-
output_dir=Config.output_dir,
|
163 |
-
num_train_epochs=Config.num_train_epochs,
|
164 |
-
per_device_train_batch_size=Config.per_device_train_batch_size,
|
165 |
-
gradient_accumulation_steps=Config.gradient_accumulation_steps,
|
166 |
-
optim=Config.optim,
|
167 |
-
save_steps=Config.save_steps,
|
168 |
-
logging_steps=Config.logging_steps,
|
169 |
-
learning_rate=Config.learning_rate,
|
170 |
-
weight_decay=Config.weight_decay,
|
171 |
-
fp16=Config.fp16,
|
172 |
-
bf16=Config.bf16,
|
173 |
-
max_grad_norm=Config.max_grad_norm,
|
174 |
-
max_steps=Config.max_steps,
|
175 |
-
warmup_ratio=Config.warmup_ratio,
|
176 |
-
group_by_length=Config.group_by_length,
|
177 |
-
lr_scheduler_type=Config.lr_scheduler_type,
|
178 |
-
report_to=Config.report_to,
|
179 |
-
save_strategy="steps",
|
180 |
-
evaluation_strategy="no",
|
181 |
-
load_best_model_at_end=False,
|
182 |
-
push_to_hub=False,
|
183 |
-
remove_unused_columns=False,
|
184 |
-
)
|
185 |
-
return training_arguments
|
186 |
-
|
187 |
-
def main():
|
188 |
-
"""Main fine-tuning function"""
|
189 |
-
print("🚀 Starting QLoRA fine-tuning of OpenAI OSS 120B model")
|
190 |
-
|
191 |
-
# Check CUDA availability
|
192 |
-
if not torch.cuda.is_available():
|
193 |
-
raise RuntimeError("CUDA is required for this training script")
|
194 |
-
|
195 |
-
print(f"Using GPU: {torch.cuda.get_device_name()}")
|
196 |
-
print(f"Available VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
197 |
-
|
198 |
-
# Load model and tokenizer
|
199 |
-
model, tokenizer = load_model_and_tokenizer()
|
200 |
-
|
201 |
-
# Apply PEFT
|
202 |
-
peft_config = create_peft_config()
|
203 |
-
model = get_peft_model(model, peft_config)
|
204 |
-
model.print_trainable_parameters()
|
205 |
-
|
206 |
-
# Load and prepare dataset
|
207 |
-
dataset = load_and_prepare_dataset(tokenizer)
|
208 |
-
|
209 |
-
# Create training arguments
|
210 |
-
training_arguments = create_training_arguments()
|
211 |
-
|
212 |
-
# Create trainer
|
213 |
-
trainer = SFTTrainer(
|
214 |
-
model=model,
|
215 |
-
train_dataset=dataset,
|
216 |
-
peft_config=peft_config,
|
217 |
-
dataset_text_field="text",
|
218 |
-
max_seq_length=Config.max_seq_length,
|
219 |
-
tokenizer=tokenizer,
|
220 |
-
args=training_arguments,
|
221 |
-
packing=Config.packing,
|
222 |
-
)
|
223 |
-
|
224 |
-
# Start training
|
225 |
-
print("🔥 Starting training...")
|
226 |
-
trainer.train()
|
227 |
-
|
228 |
-
# Save model
|
229 |
-
print("💾 Saving model...")
|
230 |
-
trainer.model.save_pretrained(Config.output_dir)
|
231 |
-
tokenizer.save_pretrained(Config.output_dir)
|
232 |
-
|
233 |
-
print("✅ Training completed!")
|
234 |
-
|
235 |
-
# Test the model
|
236 |
-
test_model(trainer.model, tokenizer)
|
237 |
-
|
238 |
-
def test_model(model, tokenizer):
|
239 |
-
"""Test the fine-tuned model"""
|
240 |
-
print("\n🧪 Testing the fine-tuned model...")
|
241 |
-
|
242 |
-
# Test prompts
|
243 |
-
test_prompts = [
|
244 |
-
"Create an advertisement for a new smartphone with advanced camera features",
|
245 |
-
"Write ad copy for an eco-friendly clothing brand targeting young professionals",
|
246 |
-
"Generate marketing content for a fitness app with AI personal trainer",
|
247 |
-
]
|
248 |
-
|
249 |
-
for prompt in test_prompts:
|
250 |
-
formatted_prompt = f"### Human: {prompt}\n### Assistant:"
|
251 |
-
|
252 |
-
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
253 |
-
|
254 |
-
with torch.no_grad():
|
255 |
-
outputs = model.generate(
|
256 |
-
**inputs,
|
257 |
-
max_new_tokens=150,
|
258 |
-
do_sample=True,
|
259 |
-
temperature=0.7,
|
260 |
-
top_p=0.9,
|
261 |
-
pad_token_id=tokenizer.eos_token_id,
|
262 |
-
)
|
263 |
-
|
264 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
265 |
-
generated_text = response[len(formatted_prompt):].strip()
|
266 |
-
|
267 |
-
print(f"\n📝 Prompt: {prompt}")
|
268 |
-
print(f"📄 Generated: {generated_text}")
|
269 |
-
print("-" * 50)
|
270 |
-
|
271 |
-
if __name__ == "__main__":
|
272 |
-
# Set environment variables
|
273 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
274 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
275 |
-
|
276 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ppo_tune.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trl import PPOTrainer, PPOConfig
|
2 |
+
from peft import PeftModel
|
3 |
+
import torch, random, json, glob
|
4 |
+
from diffusers import StableDiffusionPipeline
|
5 |
+
from reward_model import CLIPModel, CLIPProcessor
|
6 |
+
|
7 |
+
rm=CLIPModel.from_pretrained("rm").eval().half().cuda()
|
8 |
+
proc=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
9 |
+
pipe=StableDiffusionPipeline.from_pretrained("./nyc-ad-model",torch_dtype=torch.float16).to("cuda")
|
10 |
+
ppo_cfg=PPOConfig(batch_size=1,learning_rate=1e-6,target_kl=0.2)
|
11 |
+
trainer=PPOTrainer(model=pipe.unet, reward_model=rm, config=ppo_cfg)
|
12 |
+
|
13 |
+
prompts=[l.strip() for l in open("prompt.txt")]
|
14 |
+
for step in range(500):
|
15 |
+
p=random.choice(prompts)
|
16 |
+
img=pipe(p,num_inference_steps=20).images[0]
|
17 |
+
reward=rm(**proc(text=p,images=img,return_tensors="pt").to("cuda")).logits[0,0].item()
|
18 |
+
trainer.step(prompts=[p], rewards=[reward])
|
19 |
+
pipe.save_pretrained("nyc-ad-model-rlhf")
|
requirements.txt
CHANGED
@@ -1,55 +1,16 @@
|
|
1 |
-
|
2 |
-
torch>=2.0.0
|
3 |
-
torchvision
|
4 |
-
xformers
|
5 |
-
|
6 |
-
# Transformers生态
|
7 |
-
transformers>=4.35.0
|
8 |
-
accelerate>=0.24.0
|
9 |
-
tokenizers
|
10 |
-
huggingface_hub
|
11 |
-
|
12 |
-
# 数据处理
|
13 |
-
datasets>=2.14.0
|
14 |
-
numpy>=1.24.0
|
15 |
-
sentence-transformers
|
16 |
-
faiss-cpu
|
17 |
-
|
18 |
-
# 模型微调和RLHF
|
19 |
-
peft>=0.9.0
|
20 |
-
trl[peft]>=0.7.10
|
21 |
-
bitsandbytes>=0.41.0
|
22 |
-
|
23 |
-
# 图像生成
|
24 |
diffusers
|
25 |
invisible_watermark
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
# API和网络请求
|
31 |
flickrapi
|
32 |
requests
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
sacrebleu
|
41 |
-
rouge-score
|
42 |
-
|
43 |
-
# 系统工具和监控
|
44 |
-
scipy
|
45 |
-
protobuf
|
46 |
-
sentencepiece
|
47 |
-
alive_progress
|
48 |
-
psutil
|
49 |
-
gpustat
|
50 |
-
|
51 |
-
# 高级优化器(可选)
|
52 |
-
deepspeed>=0.10.0
|
53 |
-
|
54 |
-
# RLHF特定工具
|
55 |
-
reward-bench
|
|
|
1 |
+
accelerate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
diffusers
|
3 |
invisible_watermark
|
4 |
+
torch
|
5 |
+
transformers
|
6 |
+
xformers
|
7 |
+
torchvision
|
|
|
8 |
flickrapi
|
9 |
requests
|
10 |
+
peft>=0.9.0
|
11 |
+
bitsandbytes
|
12 |
+
faiss-cpu
|
13 |
+
sentence-transformers
|
14 |
+
trl[peft]
|
15 |
+
label-studio
|
16 |
+
datasets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval_augmented_generation/build_embeddings.py
DELETED
@@ -1,246 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
"""
|
3 |
-
简洁版BERT+FAISS标语数据库
|
4 |
-
输入:产品/业务描述
|
5 |
-
输出:匹配的广告标语
|
6 |
-
"""
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
-
import faiss
|
10 |
-
import json
|
11 |
-
from sentence_transformers import SentenceTransformer
|
12 |
-
from datasets import Dataset
|
13 |
-
import pandas as pd
|
14 |
-
|
15 |
-
class SloganDatabase:
|
16 |
-
def __init__(self):
|
17 |
-
self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
|
18 |
-
self.index = None
|
19 |
-
self.slogans = []
|
20 |
-
|
21 |
-
def create_dataset(self):
|
22 |
-
"""创建标语数据集 - 珠宝首饰奢侈品领域"""
|
23 |
-
# 示例数据:[品牌, 类别, 描述, 标语]
|
24 |
-
data = [
|
25 |
-
# 顶级珠宝品牌
|
26 |
-
["Tiffany & Co.", "jewelry", "luxury diamond jewelry and engagement rings", "A Diamond is Forever"],
|
27 |
-
["Cartier", "luxury_jewelry", "high-end jewelry watches and accessories", "L'art de vivre"],
|
28 |
-
["Van Cleef & Arpels", "jewelry", "French luxury jewelry and watches", "Poetry of Time"],
|
29 |
-
["Harry Winston", "jewelry", "rare diamonds and luxury jewelry", "Rare Jewels of the World"],
|
30 |
-
["Bulgari", "jewelry", "Italian luxury jewelry and watches", "Italian Excellence"],
|
31 |
-
["Chopard", "jewelry", "Swiss luxury jewelry and watches", "Happy Diamonds"],
|
32 |
-
["Graff", "jewelry", "exceptional diamonds and jewelry", "The Most Fabulous Jewels in the World"],
|
33 |
-
["Piaget", "jewelry", "Swiss luxury watches and jewelry", "Possession"],
|
34 |
-
["Boucheron", "jewelry", "French high jewelry and luxury watches", "Le Joaillier Depuis 1858"],
|
35 |
-
["Mikimoto", "jewelry", "cultured pearl jewelry", "The Originator of Cultured Pearls"],
|
36 |
-
|
37 |
-
# 奢侈品牌
|
38 |
-
["Louis Vuitton", "luxury_fashion", "luxury leather goods and fashion", "The Art of Travel"],
|
39 |
-
["Hermès", "luxury_fashion", "French luxury goods and accessories", "Luxury in the making"],
|
40 |
-
["Chanel", "luxury_fashion", "haute couture and luxury fashion", "Inside every woman there is a flower and a cat"],
|
41 |
-
["Gucci", "luxury_fashion", "Italian luxury fashion and accessories", "Quality is remembered long after price is forgotten"],
|
42 |
-
["Prada", "luxury_fashion", "Italian luxury fashion house", "Prada"],
|
43 |
-
["Dior", "luxury_fashion", "French luxury fashion and beauty", "Miss Dior"],
|
44 |
-
["Versace", "luxury_fashion", "Italian luxury fashion design", "Virtus"],
|
45 |
-
["Saint Laurent", "luxury_fashion", "French luxury fashion house", "Saint Laurent Paris"],
|
46 |
-
["Balenciaga", "luxury_fashion", "Spanish luxury fashion house", "Balenciaga"],
|
47 |
-
["Bottega Veneta", "luxury_fashion", "Italian luxury leather goods", "When your own initials are enough"],
|
48 |
-
|
49 |
-
# 腕表品牌
|
50 |
-
["Rolex", "luxury_watches", "Swiss luxury watches and timepieces", "Perpetual, Spirit of Excellence"],
|
51 |
-
["Patek Philippe", "luxury_watches", "Swiss luxury watch manufacturer", "You never actually own a Patek Philippe"],
|
52 |
-
["Audemars Piguet", "luxury_watches", "Swiss luxury watch brand", "To break the rules, you must first master them"],
|
53 |
-
["Omega", "luxury_watches", "Swiss luxury watch manufacturer", "Precision"],
|
54 |
-
["TAG Heuer", "luxury_watches", "Swiss luxury watches", "Don't crack under pressure"],
|
55 |
-
["Breitling", "luxury_watches", "Swiss luxury watchmaker", "Instruments for Professionals"],
|
56 |
-
["IWC", "luxury_watches", "Swiss luxury watch company", "Engineered for men"],
|
57 |
-
["Jaeger-LeCoultre", "luxury_watches", "Swiss luxury watch manufacturer", "The World's Most Complicated Watches"],
|
58 |
-
["Vacheron Constantin", "luxury_watches", "Swiss luxury watch manufacturer", "One of Not Many"],
|
59 |
-
["A. Lange & Söhne", "luxury_watches", "German luxury watch manufacturer", "When nothing else will do"],
|
60 |
-
|
61 |
-
# 时尚首饰
|
62 |
-
["Pandora", "fashion_jewelry", "Danish jewelry brand charm bracelets", "Be Love"],
|
63 |
-
["Swarovski", "fashion_jewelry", "Austrian crystal jewelry and accessories", "Unleash Your Light"],
|
64 |
-
["Daniel Wellington", "fashion_watches", "Swedish watch brand minimalist design", "Live the moment"],
|
65 |
-
["Alex and Ani", "fashion_jewelry", "American jewelry brand spiritual bracelets", "Positive Energy"],
|
66 |
-
["Kendra Scott", "fashion_jewelry", "American jewelry designer colorful stones", "Live colorfully"],
|
67 |
-
["Monica Vinader", "fashion_jewelry", "British jewelry brand contemporary design", "Everyday luxury"],
|
68 |
-
["Mejuri", "fashion_jewelry", "Canadian jewelry brand everyday luxury", "Everyday fine"],
|
69 |
-
["Gorjana", "fashion_jewelry", "California jewelry brand layered necklaces", "Live your layer"],
|
70 |
-
["Kate Spade", "fashion_jewelry", "American fashion accessories jewelry", "Live colorfully"],
|
71 |
-
["Marc Jacobs", "fashion_jewelry", "American fashion designer accessories", "Marc Jacobs"],
|
72 |
-
|
73 |
-
# 珠宝定制
|
74 |
-
["Blue Nile", "diamond_jewelry", "online diamond jewelry retailer", "Extraordinary diamonds for extraordinary moments"],
|
75 |
-
["James Allen", "diamond_jewelry", "online engagement ring retailer", "See it. Love it. Own it."],
|
76 |
-
["Brilliant Earth", "diamond_jewelry", "ethical diamond jewelry", "Brilliant Earth"],
|
77 |
-
["With Clarity", "diamond_jewelry", "lab-grown diamond jewelry", "Diamonds. Redefined."],
|
78 |
-
["Clean Origin", "diamond_jewelry", "lab-created diamond jewelry", "Grown with love"],
|
79 |
-
["Ritani", "diamond_jewelry", "engagement rings and wedding bands", "Love is in the details"],
|
80 |
-
["Vrai", "diamond_jewelry", "lab-grown diamond jewelry", "Created, not mined"],
|
81 |
-
["Catbird", "jewelry", "Brooklyn-based jewelry designer", "Made in Brooklyn"],
|
82 |
-
["Wwake", "jewelry", "contemporary fine jewelry designer", "Wwake"],
|
83 |
-
["Jacquie Aiche", "jewelry", "California jewelry designer bohemian luxury", "Jacquie Aiche"],
|
84 |
-
|
85 |
-
# 中国珠宝品牌
|
86 |
-
["周大福", "jewelry", "香港珠宝品牌黄金钻石", "心意足金"],
|
87 |
-
["周生生", "jewelry", "香港珠宝品牌传统工艺", "传承经典"],
|
88 |
-
["老凤祥", "jewelry", "中国传统珠宝品牌黄金首饰", "老凤祥,真金不怕火炼"],
|
89 |
-
["六福珠宝", "jewelry", "香港珠宝品牌时尚设计", "六福临门"],
|
90 |
-
["潘多拉", "jewelry", "丹麦珠宝品牌串珠手链", "表达你的故事"],
|
91 |
-
["周大生", "jewelry", "中国珠宝品牌钻石首饰", "爱就在一起"],
|
92 |
-
["金伯利", "jewelry", "中国钻石珠宝品牌", "只为更好的你"],
|
93 |
-
["戴比尔斯", "diamond_jewelry", "钻石开采珠宝品牌", "钻石恒久远,一颗永流传"],
|
94 |
-
["施华洛世奇", "crystal_jewelry", "奥地利水晶珠宝品牌", "释放你的光芒"],
|
95 |
-
["谢瑞麟", "jewelry", "香港珠宝设计师品牌", "艺术珠宝"],
|
96 |
-
|
97 |
-
# 奢侈品配饰
|
98 |
-
["Goyard", "luxury_accessories", "French luxury leather goods", "Goyard"],
|
99 |
-
["Moynat", "luxury_accessories", "French luxury leather goods", "Moynat"],
|
100 |
-
["Berluti", "luxury_accessories", "French luxury leather goods", "Berluti"],
|
101 |
-
["Valextra", "luxury_accessories", "Italian luxury leather goods", "Milanese excellence since 1937"],
|
102 |
-
["Loewe", "luxury_accessories", "Spanish luxury leather goods", "Craft"],
|
103 |
-
["Brunello Cucinelli", "luxury_fashion", "Italian luxury fashion cashmere", "Humanistic Enterprise"],
|
104 |
-
["Loro Piana", "luxury_fashion", "Italian luxury textile and clothing", "Excellence in natural fibers"],
|
105 |
-
["Kiton", "luxury_fashion", "Italian luxury menswear", "The most beautiful thing made by man"],
|
106 |
-
["Zegna", "luxury_fashion", "Italian luxury menswear", "What makes a man"],
|
107 |
-
["Brioni", "luxury_fashion", "Italian luxury menswear", "Roman style"],
|
108 |
-
|
109 |
-
# 新兴奢侈品牌
|
110 |
-
["Jacquemus", "luxury_fashion", "French luxury fashion house", "La Montagne"],
|
111 |
-
["Ganni", "luxury_fashion", "Danish fashion brand", "Ganni"],
|
112 |
-
["Staud", "luxury_fashion", "American fashion brand", "Staud"],
|
113 |
-
["Cult Gaia", "luxury_accessories", "American accessories brand", "Cult Gaia"],
|
114 |
-
["Rosantica", "jewelry", "Italian jewelry brand", "Rosantica"],
|
115 |
-
["Alighieri", "jewelry", "British jewelry brand", "The Inferno"],
|
116 |
-
["Lizzie Fortunato", "jewelry", "American jewelry brand", "Lizzie Fortunato"],
|
117 |
-
["Aurate", "jewelry", "American jewelry brand", "Accessible luxury"],
|
118 |
-
["AUrate New York", "jewelry", "New York jewelry brand", "Radically responsible luxury"],
|
119 |
-
["Missoma", "jewelry", "British jewelry brand", "Missoma"]
|
120 |
-
]
|
121 |
-
|
122 |
-
# 转换为DataFrame
|
123 |
-
df = pd.DataFrame(data, columns=['brand', 'category', 'description', 'slogan'])
|
124 |
-
|
125 |
-
# 创建搜索文本(组合描述信息)
|
126 |
-
df['search_text'] = df['brand'] + ' ' + df['category'] + ' ' + df['description']
|
127 |
-
|
128 |
-
return df.to_dict('records')
|
129 |
-
|
130 |
-
def build_index(self, data):
|
131 |
-
"""构建FAISS索引"""
|
132 |
-
print("🔨 Building FAISS index...")
|
133 |
-
|
134 |
-
# 提取搜索文本
|
135 |
-
texts = [item['search_text'] for item in data]
|
136 |
-
|
137 |
-
# 生成embeddings
|
138 |
-
embeddings = self.encoder.encode(texts, show_progress_bar=True)
|
139 |
-
|
140 |
-
# 构建索引
|
141 |
-
self.index = faiss.IndexFlatIP(384) # 使用内积相似度
|
142 |
-
self.index.add(embeddings.astype('float32'))
|
143 |
-
|
144 |
-
# 保存数据
|
145 |
-
self.slogans = data
|
146 |
-
|
147 |
-
print(f"✅ Index built with {len(data)} slogans")
|
148 |
-
|
149 |
-
def search(self, query, k=5):
|
150 |
-
"""搜索相似标语"""
|
151 |
-
if not self.index:
|
152 |
-
raise ValueError("Index not built yet!")
|
153 |
-
|
154 |
-
# 编码查询
|
155 |
-
query_embedding = self.encoder.encode([query])
|
156 |
-
|
157 |
-
# 搜索
|
158 |
-
scores, indices = self.index.search(query_embedding.astype('float32'), k)
|
159 |
-
|
160 |
-
# 返回结果
|
161 |
-
results = []
|
162 |
-
for score, idx in zip(scores[0], indices[0]):
|
163 |
-
if idx < len(self.slogans):
|
164 |
-
result = self.slogans[idx].copy()
|
165 |
-
result['similarity_score'] = float(score)
|
166 |
-
results.append(result)
|
167 |
-
|
168 |
-
return results
|
169 |
-
|
170 |
-
def save(self, path="slogan_db"):
|
171 |
-
"""保存数据库"""
|
172 |
-
# 保存FAISS索引
|
173 |
-
faiss.write_index(self.index, f"{path}.faiss")
|
174 |
-
|
175 |
-
# 保存标语数据
|
176 |
-
with open(f"{path}.json", 'w', encoding='utf-8') as f:
|
177 |
-
json.dump(self.slogans, f, ensure_ascii=False, indent=2)
|
178 |
-
|
179 |
-
print(f"💾 Database saved to {path}")
|
180 |
-
|
181 |
-
def load(self, path="slogan_db"):
|
182 |
-
"""加载数据库"""
|
183 |
-
try:
|
184 |
-
# 加载FAISS索引
|
185 |
-
self.index = faiss.read_index(f"{path}.faiss")
|
186 |
-
|
187 |
-
# 加载标语数据
|
188 |
-
with open(f"{path}.json", 'r', encoding='utf-8') as f:
|
189 |
-
self.slogans = json.load(f)
|
190 |
-
|
191 |
-
print(f"📂 Database loaded from {path}")
|
192 |
-
return True
|
193 |
-
except:
|
194 |
-
print(f"❌ Failed to load database from {path}")
|
195 |
-
return False
|
196 |
-
|
197 |
-
def main():
|
198 |
-
"""主函数"""
|
199 |
-
print("🚀 Creating Slogan Database...")
|
200 |
-
|
201 |
-
# 初始化
|
202 |
-
db = SloganDatabase()
|
203 |
-
|
204 |
-
# 尝试加载现有数据库
|
205 |
-
if not db.load():
|
206 |
-
print("📊 Creating new database...")
|
207 |
-
|
208 |
-
# 创建数据集
|
209 |
-
data = db.create_dataset()
|
210 |
-
|
211 |
-
# 构建索引
|
212 |
-
db.build_index(data)
|
213 |
-
|
214 |
-
# 保存数据库
|
215 |
-
db.save()
|
216 |
-
|
217 |
-
# 测试搜索
|
218 |
-
test_queries = [
|
219 |
-
"钻石订婚戒指",
|
220 |
-
"奢侈品手袋",
|
221 |
-
"瑞士手表品牌",
|
222 |
-
"珍珠首饰",
|
223 |
-
"黄金项链",
|
224 |
-
"时尚耳环",
|
225 |
-
"luxury jewelry brand",
|
226 |
-
"designer handbag",
|
227 |
-
"crystal accessories",
|
228 |
-
"wedding rings"
|
229 |
-
]
|
230 |
-
|
231 |
-
print("\n🔍 Testing searches...")
|
232 |
-
for query in test_queries:
|
233 |
-
print(f"\n查询: {query}")
|
234 |
-
print("-" * 40)
|
235 |
-
|
236 |
-
results = db.search(query, k=3)
|
237 |
-
|
238 |
-
for i, result in enumerate(results, 1):
|
239 |
-
print(f"{i}. {result['brand']} ({result['category']})")
|
240 |
-
print(f" 描述: {result['description']}")
|
241 |
-
print(f" 标语: {result['slogan']}")
|
242 |
-
print(f" 相似度: {result['similarity_score']:.3f}")
|
243 |
-
print()
|
244 |
-
|
245 |
-
if __name__ == "__main__":
|
246 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reward_model.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
|
2 |
+
import datasets, torch, json, glob
|
3 |
+
|
4 |
+
model=CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
5 |
+
processor=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
6 |
+
|
7 |
+
data=[]
|
8 |
+
for f in glob.glob("human_prefs/*.json"):
|
9 |
+
j=json.load(open(f)); data.append(j) # {"prompt":…, "good":img_path, "bad":img_path}
|
10 |
+
|
11 |
+
dataset=datasets.Dataset.from_list(data)
|
12 |
+
|
13 |
+
def preprocess(ex):
|
14 |
+
inputs=processor(text=[ex["prompt"]*2], images=[ex["good"],ex["bad"]], return_tensors="pt")
|
15 |
+
inputs["labels"]=torch.tensor([1,0])
|
16 |
+
return inputs
|
17 |
+
|
18 |
+
dataset=dataset.map(preprocess,remove_columns=dataset.column_names)
|
19 |
+
args=TrainingArguments("rm_ckpt",per_device_train_batch_size=2,fp16=True,learning_rate=5e-6,epochs=3)
|
20 |
+
trainer=Trainer(model,args,train_dataset=dataset)
|
21 |
+
trainer.train(); model.save_pretrained("rm")
|
sft_train.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, json
|
2 |
+
from datasets import load_dataset, Dataset
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
|
4 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
5 |
+
|
6 |
+
# Load your dataset
|
7 |
+
data = [json.loads(l) for l in open("data/sft_data.jsonl")]
|
8 |
+
dataset = Dataset.from_list(data)
|
9 |
+
|
10 |
+
# Load model & tokenizer
|
11 |
+
base_model = "meta-llama/Llama-2-7b-hf" # Or use Mistral, Falcon, etc.
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
|
13 |
+
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)
|
14 |
+
|
15 |
+
# Add LoRA (optional)
|
16 |
+
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.05,
|
17 |
+
target_modules=["q_proj", "v_proj"])
|
18 |
+
model = get_peft_model(model, lora_config)
|
19 |
+
|
20 |
+
# Preprocessing
|
21 |
+
def tokenize(example):
|
22 |
+
prompt = f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['output']}"
|
23 |
+
return tokenizer(prompt, truncation=True, max_length=512, padding="max_length")
|
24 |
+
dataset = dataset.map(tokenize, remove_columns=dataset.column_names)
|
25 |
+
|
26 |
+
# Training setup
|
27 |
+
args = TrainingArguments(
|
28 |
+
output_dir="./sft-model",
|
29 |
+
per_device_train_batch_size=2,
|
30 |
+
num_train_epochs=3,
|
31 |
+
fp16=True,
|
32 |
+
evaluation_strategy="no",
|
33 |
+
save_strategy="epoch",
|
34 |
+
logging_steps=20,
|
35 |
+
learning_rate=2e-5,
|
36 |
+
report_to="tensorboard",
|
37 |
+
)
|
38 |
+
|
39 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
40 |
+
trainer = Trainer(model=model, args=args, train_dataset=dataset, data_collator=data_collator)
|
41 |
+
trainer.train()
|
fully_fine_tune_stablediffusion/train_lora.py → train_lora.py
RENAMED
File without changes
|
train_model_test.py
DELETED
@@ -1,238 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import numpy as np
|
3 |
-
from datasets import load_dataset
|
4 |
-
from PIL import Image, ImageOps, ImageFilter
|
5 |
-
from tqdm import tqdm
|
6 |
-
import random
|
7 |
-
import requests
|
8 |
-
import io
|
9 |
-
import time
|
10 |
-
|
11 |
-
def download_image(url, timeout=10, retries=2):
|
12 |
-
"""Download image from URL with retry mechanism"""
|
13 |
-
for attempt in range(retries):
|
14 |
-
try:
|
15 |
-
headers = {
|
16 |
-
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
17 |
-
}
|
18 |
-
response = requests.get(url, timeout=timeout, headers=headers)
|
19 |
-
|
20 |
-
if response.status_code == 200:
|
21 |
-
image = Image.open(io.BytesIO(response.content))
|
22 |
-
return image
|
23 |
-
else:
|
24 |
-
return None
|
25 |
-
|
26 |
-
except Exception as e:
|
27 |
-
if attempt == retries - 1: # Last attempt
|
28 |
-
print(f"Failed to download {url}: {e}")
|
29 |
-
return None
|
30 |
-
time.sleep(0.5) # Brief pause before retry
|
31 |
-
|
32 |
-
return None
|
33 |
-
|
34 |
-
def preprocess_image(image, target_size=512, quality_threshold=0.7):
|
35 |
-
"""Preprocess image with various enhancements"""
|
36 |
-
if image is None:
|
37 |
-
return None
|
38 |
-
|
39 |
-
try:
|
40 |
-
# Convert to RGB if needed
|
41 |
-
if image.mode != 'RGB':
|
42 |
-
image = image.convert('RGB')
|
43 |
-
|
44 |
-
# Filter out low quality images
|
45 |
-
width, height = image.size
|
46 |
-
if min(width, height) < target_size * quality_threshold:
|
47 |
-
return None
|
48 |
-
|
49 |
-
# Center crop to square if not already
|
50 |
-
if width != height:
|
51 |
-
size = min(width, height)
|
52 |
-
left = (width - size) // 2
|
53 |
-
top = (height - size) // 2
|
54 |
-
image = image.crop((left, top, left + size, top + size))
|
55 |
-
|
56 |
-
# Resize to target size
|
57 |
-
image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
58 |
-
|
59 |
-
# Enhance image quality
|
60 |
-
# Slightly sharpen
|
61 |
-
image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))
|
62 |
-
|
63 |
-
# Auto-adjust levels
|
64 |
-
image = ImageOps.autocontrast(image, cutoff=1)
|
65 |
-
|
66 |
-
return image
|
67 |
-
|
68 |
-
except Exception as e:
|
69 |
-
print(f"Error preprocessing image: {e}")
|
70 |
-
return None
|
71 |
-
|
72 |
-
def clean_prompt(prompt):
|
73 |
-
"""Clean and normalize prompts"""
|
74 |
-
if not prompt:
|
75 |
-
return None
|
76 |
-
|
77 |
-
# Remove excessive whitespace
|
78 |
-
prompt = ' '.join(prompt.split())
|
79 |
-
|
80 |
-
# Remove common artifacts
|
81 |
-
prompt = prompt.replace(' ', ' ')
|
82 |
-
prompt = prompt.strip(' .,;:')
|
83 |
-
|
84 |
-
# Filter out very short or very long prompts
|
85 |
-
words = prompt.split()
|
86 |
-
if len(words) < 3 or len(words) > 50:
|
87 |
-
return None
|
88 |
-
|
89 |
-
return prompt
|
90 |
-
|
91 |
-
def prepare_dreambooth_data():
|
92 |
-
# Load dataset
|
93 |
-
print("Loading LAION dataset...")
|
94 |
-
dataset = load_dataset("laion/laion2B-en-aesthetic", split="train", streaming=True)
|
95 |
-
|
96 |
-
# Create directory structure
|
97 |
-
data_dir = "./laion_dataset"
|
98 |
-
os.makedirs(data_dir, exist_ok=True)
|
99 |
-
|
100 |
-
valid_samples = 0
|
101 |
-
processed_count = 0
|
102 |
-
max_samples = 1000 # Limit total samples to process
|
103 |
-
|
104 |
-
print(f"Starting to process up to {max_samples} samples...")
|
105 |
-
|
106 |
-
# Process images with preprocessing
|
107 |
-
for idx, sample in enumerate(tqdm(dataset, desc="Processing LAION samples")):
|
108 |
-
if processed_count >= max_samples:
|
109 |
-
break
|
110 |
-
|
111 |
-
processed_count += 1
|
112 |
-
|
113 |
-
try:
|
114 |
-
# Get URL and text from LAION format
|
115 |
-
image_url = sample.get('URL', '')
|
116 |
-
text_prompt = sample.get('TEXT', '')
|
117 |
-
|
118 |
-
if not image_url or not text_prompt:
|
119 |
-
continue
|
120 |
-
|
121 |
-
# Clean prompt first
|
122 |
-
prompt = clean_prompt(text_prompt)
|
123 |
-
if prompt is None:
|
124 |
-
continue
|
125 |
-
|
126 |
-
# Download image from URL
|
127 |
-
print(f"Downloading image {valid_samples + 1}: {image_url[:50]}...")
|
128 |
-
image = download_image(image_url)
|
129 |
-
if image is None:
|
130 |
-
continue
|
131 |
-
|
132 |
-
# Preprocess downloaded image
|
133 |
-
processed_image = preprocess_image(image)
|
134 |
-
if processed_image is None:
|
135 |
-
continue
|
136 |
-
|
137 |
-
# Save processed image
|
138 |
-
image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
|
139 |
-
processed_image.save(image_path, "JPEG", quality=95, optimize=True)
|
140 |
-
|
141 |
-
# Save cleaned caption
|
142 |
-
caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
|
143 |
-
with open(caption_path, 'w', encoding='utf-8') as f:
|
144 |
-
f.write(prompt)
|
145 |
-
|
146 |
-
valid_samples += 1
|
147 |
-
|
148 |
-
# Optional: Add metadata file
|
149 |
-
metadata_path = os.path.join(data_dir, f"image_{valid_samples-1:04d}_meta.txt")
|
150 |
-
with open(metadata_path, 'w', encoding='utf-8') as f:
|
151 |
-
f.write(f"URL: {image_url}\n")
|
152 |
-
f.write(f"Aesthetic: {sample.get('aesthetic', 'N/A')}\n")
|
153 |
-
f.write(f"Width: {sample.get('WIDTH', 'N/A')}\n")
|
154 |
-
f.write(f"Height: {sample.get('HEIGHT', 'N/A')}\n")
|
155 |
-
|
156 |
-
# Stop if we have enough samples
|
157 |
-
if valid_samples >= 100: # Adjust this number as needed
|
158 |
-
break
|
159 |
-
|
160 |
-
except Exception as e:
|
161 |
-
print(f"Error processing sample {idx}: {e}")
|
162 |
-
continue
|
163 |
-
|
164 |
-
print(f"Processed {processed_count} samples, saved {valid_samples} valid images to {data_dir}")
|
165 |
-
return data_dir
|
166 |
-
|
167 |
-
def create_demo_dataset():
|
168 |
-
"""Create demo dataset as last resort"""
|
169 |
-
print("Creating demo dataset...")
|
170 |
-
|
171 |
-
data_dir = "./demo_dataset"
|
172 |
-
os.makedirs(data_dir, exist_ok=True)
|
173 |
-
|
174 |
-
demo_prompts = [
|
175 |
-
"a beautiful landscape with mountains",
|
176 |
-
"portrait of a person with detailed features",
|
177 |
-
"abstract colorful digital artwork",
|
178 |
-
"modern architecture building design",
|
179 |
-
"natural forest scene with trees",
|
180 |
-
"urban cityscape at sunset",
|
181 |
-
"artistic oil painting style",
|
182 |
-
"vintage photography aesthetic",
|
183 |
-
"minimalist geometric composition",
|
184 |
-
"vibrant surreal art piece"
|
185 |
-
]
|
186 |
-
|
187 |
-
for idx, prompt in enumerate(demo_prompts):
|
188 |
-
# Create gradient background
|
189 |
-
color1 = (random.randint(50, 200), random.randint(50, 200), random.randint(50, 200))
|
190 |
-
color2 = (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))
|
191 |
-
|
192 |
-
image = Image.new('RGB', (512, 512), color1)
|
193 |
-
|
194 |
-
# Save files
|
195 |
-
image_path = os.path.join(data_dir, f"image_{idx:04d}.jpg")
|
196 |
-
image.save(image_path, "JPEG", quality=95)
|
197 |
-
|
198 |
-
caption_path = os.path.join(data_dir, f"image_{idx:04d}.txt")
|
199 |
-
with open(caption_path, 'w', encoding='utf-8') as f:
|
200 |
-
f.write(prompt)
|
201 |
-
|
202 |
-
print(f"Created {len(demo_prompts)} demo samples")
|
203 |
-
return data_dir
|
204 |
-
|
205 |
-
# Main execution with fallback
|
206 |
-
def main():
|
207 |
-
data_dir = prepare_dreambooth_data()
|
208 |
-
|
209 |
-
# Generate training command
|
210 |
-
training_command = f"""
|
211 |
-
accelerate launch \\
|
212 |
-
--deepspeed_config_file ds_config.json \\
|
213 |
-
diffusers/examples/dreambooth/train_dreambooth.py \\
|
214 |
-
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
|
215 |
-
--instance_data_dir="{data_dir}" \\
|
216 |
-
--instance_prompt="a high quality image" \\
|
217 |
-
--output_dir="./laion-model" \\
|
218 |
-
--resolution=512 \\
|
219 |
-
--train_batch_size=1 \\
|
220 |
-
--gradient_accumulation_steps=1 \\
|
221 |
-
--gradient_checkpointing \\
|
222 |
-
--learning_rate=5e-6 \\
|
223 |
-
--lr_scheduler="constant" \\
|
224 |
-
--lr_warmup_steps=0 \\
|
225 |
-
--max_train_steps=400 \\
|
226 |
-
--mixed_precision="fp16" \\
|
227 |
-
--checkpointing_steps=100 \\
|
228 |
-
--checkpoints_total_limit=1 \\
|
229 |
-
--report_to="tensorboard" \\
|
230 |
-
--logging_dir="./laion-model/logs"
|
231 |
-
"""
|
232 |
-
|
233 |
-
print(f"\n✅ Dataset prepared in: {data_dir}")
|
234 |
-
print("🚀 Run this command to train:")
|
235 |
-
print(training_command)
|
236 |
-
|
237 |
-
if __name__ == "__main__":
|
238 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|