Commit a662214
Parent(s): b7af310
Upload 27 files

- dreambooth-for-diffusion/.gitignore +17 -0
- dreambooth-for-diffusion/README.md +217 -0
- dreambooth-for-diffusion/back_train.sh +2 -0
- dreambooth-for-diffusion/ckpt_models/model.yaml +69 -0
- dreambooth-for-diffusion/ckpt_models/put_your_ckpt_models_here.txt +0 -0
- dreambooth-for-diffusion/datasets/put_datasets_here.txt +0 -0
- dreambooth-for-diffusion/other/something others.txt +0 -0
- dreambooth-for-diffusion/test_model.py +28 -0
- dreambooth-for-diffusion/test_prompts_object.txt +2 -0
- dreambooth-for-diffusion/test_prompts_style.txt +3 -0
- dreambooth-for-diffusion/tools/ckpt2diffusers.py +835 -0
- dreambooth-for-diffusion/tools/ckpt2diffusers_old.py +619 -0
- dreambooth-for-diffusion/tools/ckpt_merge.py +56 -0
- dreambooth-for-diffusion/tools/ckpt_prune.py +14 -0
- dreambooth-for-diffusion/tools/deepdanbooru-models/put_deepdanbooru_model_here.txt +0 -0
- dreambooth-for-diffusion/tools/diagnose_tensorboard.py +570 -0
- dreambooth-for-diffusion/tools/diffusers2ckpt.py +234 -0
- dreambooth-for-diffusion/tools/handle_images.py +82 -0
- dreambooth-for-diffusion/tools/label_images.py +152 -0
- dreambooth-for-diffusion/tools/test_cuda.py +2 -0
- dreambooth-for-diffusion/tools/train_dreambooth.py +784 -0
- dreambooth-for-diffusion/tools/train_textual_inversion.py +572 -0
- dreambooth-for-diffusion/tools/upload_cos.py +19 -0
- dreambooth-for-diffusion/train_object.sh +79 -0
- dreambooth-for-diffusion/train_style.sh +62 -0
- dreambooth-for-diffusion/train_textual_inversion.sh +29 -0
- dreambooth-for-diffusion/运行.ipynb +452 -0
dreambooth-for-diffusion/.gitignore
ADDED
@@ -0,0 +1,17 @@
__pycache__
.ipynb_checkpoints
*/.ipynb_checkpoints
*.ckpt
*.pt
*.whl
*.log
*.png
*.jpg
nohup.out
/datasets
/model
/new-*
/log
/output*
/tools/deepdanbooru-models/*
/tools/diffusers-models/*
dreambooth-for-diffusion/README.md
ADDED
@@ -0,0 +1,217 @@
# Dreambooth Stable Diffusion integrated training environment
If you are on an autodl machine, you can create an instance directly from the prepackaged image and use it out of the box.
It also works locally or on other servers, but you will need to install some pip packages by hand.

## How to run
Run it directly on autodl from the image: https://www.codewithgpu.com/i/CrazyBoyM/dreambooth-for-diffusion/dreambooth-for-diffusion

If you are not comfortable training from notebook code, you can also use the prepackaged webui image (includes the stable Dreambooth and DreamArtist training plugins, already fixed):
https://www.codewithgpu.com/i/CrazyBoyM/sd_dreambooth_extension_webui/dreambooth-dreamartist-for-webui

## Note
This project is intended only for studying and testing AI techniques.
Do not use it to train models that generate inappropriate or infringing images.

## About the project
The packaged image on autodl is named: dreambooth-for-diffusion
You can select it as a public algorithm image when creating an instance.
It was packaged on an A5000 machine in autodl's Inner Mongolia A zone; if you hit problems you cannot solve yourself, please use the same environment.
Baicai (the author) tested as much as possible while writing this guide, but cannot guarantee every step is fully covered.
If you run into small errors, try fixing them by hand, or check the latest README on the project page.
Project page: https://github.com/CrazyBoyM/dreambooth-for-diffusion

If you get stuck, look for the training demo video for this guide on the author's bilibili page: https://space.bilibili.com/291593914
(The video had not been made yet at the time of writing.)

## Strongly recommended
1. Connect to the server with VS Code's remote SSH feature for a better training experience; autodl's built-in notebook is also fine and has file upload/download.
2. (Important) First move the whole dreambooth-for-diffusion folder from /root/ to /root/autodl-tmp/ (the data disk) to avoid filling up the system disk.

## Enter the working directory
```
cd /root/autodl-tmp/dreambooth-for-diffusion
```

## Convert a ckpt checkpoint into official diffusers weights
Two base models are bundled; pick one according to the characteristics of your dataset.
- sd_1-5.ckpt leans toward photorealistic output
- nd_lastest.ckpt leans toward anime-style output
Convert the anime model:
```
# This step takes roughly one minute
!python tools/ckpt2diffusers.py \
    --checkpoint_path=./ckpt_models/nd_lastest.ckpt \
    --dump_path=./model \
    --vae_path=./ckpt_models/animevae.pt \
    --original_config_file=./ckpt_models/model.yaml \
    --scheduler_type="ddim"
```
Convert the photorealistic model:
```
# This step takes roughly one minute
!python tools/ckpt2diffusers.py \
    --checkpoint_path=./ckpt_models/sd_1-5.ckpt \
    --dump_path=./model \
    --original_config_file=./ckpt_models/model.yaml \
    --scheduler_type="ddim"
```
The two paths above are your ckpt file and the output path for the converted weights, respectively.

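A quick way to check that the conversion produced a usable folder is to load it with the same pipeline class used by test_model.py. This is a minimal sketch, assuming the command above wrote the diffusers weights to ./model and a CUDA GPU is available:
```
from diffusers import StableDiffusionPipeline
import torch

# Load the freshly converted diffusers folder and render one image as a smoke test.
pipe = StableDiffusionPipeline.from_pretrained(
    "./model", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
image = pipe("a cute girl, blue eyes, brown hair", num_inference_steps=30).images[0]
image.save("convert_check.png")
```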
## Convert diffusers weights back into a ckpt checkpoint
```
python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel_half.ckpt --half
```
To save in float16 precision, add the --half flag; the checkpoint size is roughly halved.

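For intuition about what --half buys you: the weights are stored as float16 instead of float32, so every floating-point tensor takes half the bytes. The sketch below only illustrates that idea; it is not tools/diffusers2ckpt.py, and the file names are made up:
```
import torch

# Hypothetical paths, for illustration only.
ckpt = torch.load("./ckpt_models/full_precision.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

# Cast floating-point tensors to float16; non-float entries are kept as-is.
half = {k: (v.half() if torch.is_tensor(v) and v.is_floating_point() else v)
        for k, v in state_dict.items()}
torch.save({"state_dict": half}, "./ckpt_models/full_precision_half.ckpt")
```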
## Prepare the dataset
Prepare a dataset that matches your training task.
### Crop images to 512*512
tools/handle_images.py contains a batch-processing script you can use as a reference;
it automatically center-crops the images and resizes them.
```
python tools/handle_images.py ./datasets/test ./datasets/test2 --width=512 --height=512
```
test is the folder of unprocessed source images, test2 is the output path for the processed images.
To convert transparent-background PNGs into black/white-background JPGs, add the --png flag.

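The crop itself is simple; the sketch below shows the center-crop-and-resize idea with Pillow. It is a simplified stand-in for tools/handle_images.py, not the script itself, and the example paths are hypothetical:
```
from PIL import Image

def center_crop_resize(path, size=512):
    # Crop the largest centered square, then resize it to size x size.
    img = Image.open(path).convert("RGB")
    w, h = img.size
    side = min(w, h)
    left, top = (w - side) // 2, (h - side) // 2
    square = img.crop((left, top, left + side, top + side))
    return square.resize((size, size), Image.LANCZOS)

center_crop_resize("./datasets/test/example.jpg").save("./datasets/test2/example.jpg")
```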
### Automatic image tagging
Use deepdanbooru to generate tag labels.
```
!python tools/label_images.py --path=./datasets/test2
```
The --path argument is the folder of images you want to tag.

Note: if deepdanbooru cannot be found, you can build it yourself from this repository:
https://github.com/KichangKim/DeepDanbooru

A prebuilt wheel is also provided in the other folder:
```
pip install other/deepdanbooru-1.0.0-py3-none-any.whl
```

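After tagging, each image should have a caption to feed into training. The sketch below assumes label_images.py writes one .txt tag file next to each image with the same base name (an assumption; check the script's actual output if it differs):
```
from pathlib import Path

dataset = Path("./datasets/test2")
pairs = []
for img in sorted(list(dataset.glob("*.jpg")) + list(dataset.glob("*.png"))):
    caption = img.with_suffix(".txt")  # assumed naming convention
    if caption.exists():
        pairs.append((img.name, caption.read_text(encoding="utf-8").strip()))

print(f"{len(pairs)} image/caption pairs ready for training")
```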
## Training and a summary of common commands
### Configure the training environment (optional)
If you are not running directly on the prepackaged image, do the following setup:
```
pip install accelerate
```
Run the command below and choose local machine, NO, NO:
```
accelerate config
```

### Start training
Open the train_*.sh scripts and refer to the parameter explanations inside.
To train a specific person or object:
(3~5 images of the specific subject with a consistent style are recommended)

```
sh train_object.sh
```

To fine-tune your own large model:
(3000+ images with as much diversity as possible are recommended; the data determines the quality of the trained model)
```
sh train_style.sh
```
Training speed on an A5000 is roughly 8 minutes per 1000 steps.

### Test the trained model
Open test_model.py, edit its model_path and prompt, then run:
```
python test_model.py
```

### Other common commands
To train as a background task:
```
nohup sh train_style.sh &
```
Running it in the background overnight like this is recommended; you do not have to worry about a dropped connection stopping training.
Baicai's personal money-saving trick:
```
nohup sh back_train.sh &
```
(The machine shuts down automatically once training finishes.)

The training log is written to nohup.out, which you can open in VS Code or download to inspect.
Show the last ten lines of the log:
```
tail -n 10 nohup.out
```

Check current disk usage:
(Remember to clean up files you no longer need; it is easy to fill tens of GB of disk, which makes saving the model fail!)
```
df -h
```

## If you are running on another server without the integrated environment
Install any packages it complains are missing, for example:
```
pip install diffusers
pip install ftfy
pip install tensorflow-gpu
pip install pytorch_lightning
pip install OmegaConf
... and a few others
```

## Academic acceleration (optional)
If pulling content from git is very slow, the following may help.
Run the command for the zone your machine is in:
```
# Instances in Beijing zone A
export http_proxy=http://100.72.64.19:12798 && export https_proxy=http://100.72.64.19:12798

# Instances in Inner Mongolia zone A
export http_proxy=http://192.168.1.174:12798 && export https_proxy=http://192.168.1.174:12798

# Instances in Quanzhou zone A
export http_proxy=http://10.55.146.88:12798 && export https_proxy=http://10.55.146.88:12798
```

## xformers (optional)
Since training and inference are already fast on the A5000, it is not installed.
If you use a different GPU or really need it, you can build it from:
https://github.com/facebookresearch/xformers
(Guessing you might want to try it, a prebuilt version has been placed in the other directory.)
Note: PyTorch must be upgraded to 1.12.x or later before it can be installed and used. (Update: the upgrade is already done and xformers is installed for you.)

## Upgrade PyTorch to 1.12.x
```
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```

# Notes on using autodl

## Migrating data between servers
After shutting a machine down, you will often find its resources occupied when you try to start it again, and you have to open another machine.
For a machine that is already shut down, the menu offers a "copy data across instances" feature,
which conveniently syncs the contents of /root/autodl-tmp to another running machine (another reason to keep your working files there).
(Note: this only works between machines in the same zone.)
Data migration guide: https://www.autodl.com/docs/migrate_instance/

## Ways to transfer files
### Option 1: VS Code
Drag files to upload or download directly in VS Code; slow, but the simplest.

### Option 2: autodl's user netdisk
Initialize a netdisk in the same zone on autodl's netdisk page, then restart the server instance.
A new folder /root/autodl-nas/ will appear, and you can upload weights and data through the web page.
After training, move the generated weight files into that path and download them from the web page.
(The page: https://www.autodl.com/console/netdisk)
Note: the netdisk must be initialized in the same zone as the server.

### Option 3: object storage
If you have access to COS or OSS, you can also use them to relay files; it is faster.
A reference script for uploading to COS is included in the tools folder; use it if you know what you are doing.

The autodl site also documents some recommended methods: https://www.autodl.com/docs/scp/

# Other
Thanks to the diffusers, deepdanbooru and other open-source projects.
The style-training code is modified from nbardy's PR.
Part of the tagging code comes from crosstyan, Nyanko Lepsoni and AUTOMATIC1111.
If you are interested, feel free to join the QQ group to discuss: 455521885
Packaged and organized by - Baicai (白菜)
dreambooth-for-diffusion/back_train.sh
ADDED
@@ -0,0 +1,2 @@
# Money-saver: shut the machine down once training finishes normally
sh train_style.sh && shutdown
dreambooth-for-diffusion/ckpt_models/model.yaml
ADDED
@@ -0,0 +1,69 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
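This YAML is the standard Stable Diffusion v1 inference config; tools/ckpt2diffusers.py loads it with OmegaConf and reads the UNet and VAE sub-configs from it. A minimal sketch of that access pattern:
```
from omegaconf import OmegaConf

cfg = OmegaConf.load("./ckpt_models/model.yaml")
unet_params = cfg.model.params.unet_config.params
vae_params = cfg.model.params.first_stage_config.params.ddconfig
print(unet_params.context_dim, unet_params.model_channels)  # 768, 320
print(vae_params.resolution)  # 512
```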
dreambooth-for-diffusion/ckpt_models/put_your_ckpt_models_here.txt
ADDED
File without changes
dreambooth-for-diffusion/datasets/put_datasets_here.txt
ADDED
File without changes
dreambooth-for-diffusion/other/something others.txt
ADDED
File without changes
dreambooth-for-diffusion/test_model.py
ADDED
@@ -0,0 +1,28 @@
from diffusers import StableDiffusionPipeline
import torch
from diffusers import DDIMScheduler

model_path = "./new_model"
prompt = "a cute girl, blue eyes, brown hair"
torch.manual_seed(123123123)

pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    scheduler=DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=True,
    ),
    safety_checker=None
)

# def dummy(images, **kwargs):
#     return images, False
# pipe.safety_checker = dummy
pipe = pipe.to("cuda")
images = pipe(prompt, width=512, height=512, num_inference_steps=30, num_images_per_prompt=3).images
for i, image in enumerate(images):
    image.save(f"test-{i}.png")
dreambooth-for-diffusion/test_prompts_object.txt
ADDED
@@ -0,0 +1,2 @@
a photo of <xxx> dog
a photo of dog
dreambooth-for-diffusion/test_prompts_style.txt
ADDED
@@ -0,0 +1,3 @@
a cute girl, blue eyes, brown hair
a cute girl, blue eyes, blue hair
a cute boy, green eyes, brown hair
dreambooth-for-diffusion/tools/ckpt2diffusers.py
ADDED
@@ -0,0 +1,835 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the LDM checkpoints. """
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import os
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
try:
|
24 |
+
from omegaconf import OmegaConf
|
25 |
+
except ImportError:
|
26 |
+
raise ImportError(
|
27 |
+
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
28 |
+
)
|
29 |
+
|
30 |
+
from diffusers import (
|
31 |
+
AutoencoderKL,
|
32 |
+
DDIMScheduler,
|
33 |
+
LDMTextToImagePipeline,
|
34 |
+
LMSDiscreteScheduler,
|
35 |
+
PNDMScheduler,
|
36 |
+
StableDiffusionPipeline,
|
37 |
+
UNet2DConditionModel,
|
38 |
+
)
|
39 |
+
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
40 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
41 |
+
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
|
42 |
+
|
43 |
+
script_path = os.path.realpath(__file__)
|
44 |
+
default_model_path = os.path.join(os.path.dirname(script_path), "diffusers-models")
|
45 |
+
|
46 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
47 |
+
"""
|
48 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
49 |
+
"""
|
50 |
+
if n_shave_prefix_segments >= 0:
|
51 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
52 |
+
else:
|
53 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
54 |
+
|
55 |
+
|
56 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
57 |
+
"""
|
58 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
59 |
+
"""
|
60 |
+
mapping = []
|
61 |
+
for old_item in old_list:
|
62 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
63 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
64 |
+
|
65 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
66 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
67 |
+
|
68 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
69 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
70 |
+
|
71 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
72 |
+
|
73 |
+
mapping.append({"old": old_item, "new": new_item})
|
74 |
+
|
75 |
+
return mapping
|
76 |
+
|
77 |
+
|
78 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
79 |
+
"""
|
80 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
81 |
+
"""
|
82 |
+
mapping = []
|
83 |
+
for old_item in old_list:
|
84 |
+
new_item = old_item
|
85 |
+
|
86 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
87 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
88 |
+
|
89 |
+
mapping.append({"old": old_item, "new": new_item})
|
90 |
+
|
91 |
+
return mapping
|
92 |
+
|
93 |
+
|
94 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
95 |
+
"""
|
96 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
97 |
+
"""
|
98 |
+
mapping = []
|
99 |
+
for old_item in old_list:
|
100 |
+
new_item = old_item
|
101 |
+
|
102 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
103 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
104 |
+
|
105 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
106 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
107 |
+
|
108 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
109 |
+
|
110 |
+
mapping.append({"old": old_item, "new": new_item})
|
111 |
+
|
112 |
+
return mapping
|
113 |
+
|
114 |
+
|
115 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
116 |
+
"""
|
117 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
118 |
+
"""
|
119 |
+
mapping = []
|
120 |
+
for old_item in old_list:
|
121 |
+
new_item = old_item
|
122 |
+
|
123 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
124 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
125 |
+
|
126 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
127 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
128 |
+
|
129 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
130 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
131 |
+
|
132 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
133 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
134 |
+
|
135 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
136 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
137 |
+
|
138 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
139 |
+
|
140 |
+
mapping.append({"old": old_item, "new": new_item})
|
141 |
+
|
142 |
+
return mapping
|
143 |
+
|
144 |
+
|
145 |
+
def assign_to_checkpoint(
|
146 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
147 |
+
):
|
148 |
+
"""
|
149 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
150 |
+
to them. It splits attention layers, and takes into account additional replacements
|
151 |
+
that may arise.
|
152 |
+
|
153 |
+
Assigns the weights to the new checkpoint.
|
154 |
+
"""
|
155 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
156 |
+
|
157 |
+
# Splits the attention layers into three variables.
|
158 |
+
if attention_paths_to_split is not None:
|
159 |
+
for path, path_map in attention_paths_to_split.items():
|
160 |
+
old_tensor = old_checkpoint[path]
|
161 |
+
channels = old_tensor.shape[0] // 3
|
162 |
+
|
163 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
164 |
+
|
165 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
166 |
+
|
167 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
168 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
169 |
+
|
170 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
171 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
172 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
173 |
+
|
174 |
+
for path in paths:
|
175 |
+
new_path = path["new"]
|
176 |
+
|
177 |
+
# These have already been assigned
|
178 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
179 |
+
continue
|
180 |
+
|
181 |
+
# Global renaming happens here
|
182 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
183 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
184 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
185 |
+
|
186 |
+
if additional_replacements is not None:
|
187 |
+
for replacement in additional_replacements:
|
188 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
189 |
+
|
190 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
191 |
+
if "proj_attn.weight" in new_path:
|
192 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
193 |
+
else:
|
194 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
195 |
+
|
196 |
+
|
197 |
+
def conv_attn_to_linear(checkpoint):
|
198 |
+
keys = list(checkpoint.keys())
|
199 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
200 |
+
for key in keys:
|
201 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
202 |
+
if checkpoint[key].ndim > 2:
|
203 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
204 |
+
elif "proj_attn.weight" in key:
|
205 |
+
if checkpoint[key].ndim > 2:
|
206 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
207 |
+
|
208 |
+
|
209 |
+
def create_unet_diffusers_config(original_config):
|
210 |
+
"""
|
211 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
212 |
+
"""
|
213 |
+
unet_params = original_config.model.params.unet_config.params
|
214 |
+
|
215 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
216 |
+
|
217 |
+
down_block_types = []
|
218 |
+
resolution = 1
|
219 |
+
for i in range(len(block_out_channels)):
|
220 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
221 |
+
down_block_types.append(block_type)
|
222 |
+
if i != len(block_out_channels) - 1:
|
223 |
+
resolution *= 2
|
224 |
+
|
225 |
+
up_block_types = []
|
226 |
+
for i in range(len(block_out_channels)):
|
227 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
228 |
+
up_block_types.append(block_type)
|
229 |
+
resolution //= 2
|
230 |
+
|
231 |
+
config = dict(
|
232 |
+
sample_size=unet_params.image_size,
|
233 |
+
in_channels=unet_params.in_channels,
|
234 |
+
out_channels=unet_params.out_channels,
|
235 |
+
down_block_types=tuple(down_block_types),
|
236 |
+
up_block_types=tuple(up_block_types),
|
237 |
+
block_out_channels=tuple(block_out_channels),
|
238 |
+
layers_per_block=unet_params.num_res_blocks,
|
239 |
+
cross_attention_dim=unet_params.context_dim,
|
240 |
+
attention_head_dim=unet_params.num_heads,
|
241 |
+
)
|
242 |
+
|
243 |
+
return config
|
244 |
+
|
245 |
+
|
246 |
+
def create_vae_diffusers_config(original_config):
|
247 |
+
"""
|
248 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
249 |
+
"""
|
250 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
251 |
+
_ = original_config.model.params.first_stage_config.params.embed_dim
|
252 |
+
|
253 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
254 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
255 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
256 |
+
|
257 |
+
config = dict(
|
258 |
+
sample_size=vae_params.resolution,
|
259 |
+
in_channels=vae_params.in_channels,
|
260 |
+
out_channels=vae_params.out_ch,
|
261 |
+
down_block_types=tuple(down_block_types),
|
262 |
+
up_block_types=tuple(up_block_types),
|
263 |
+
block_out_channels=tuple(block_out_channels),
|
264 |
+
latent_channels=vae_params.z_channels,
|
265 |
+
layers_per_block=vae_params.num_res_blocks,
|
266 |
+
)
|
267 |
+
return config
|
268 |
+
|
269 |
+
|
270 |
+
def create_diffusers_schedular(original_config):
|
271 |
+
schedular = DDIMScheduler(
|
272 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
273 |
+
beta_start=original_config.model.params.linear_start,
|
274 |
+
beta_end=original_config.model.params.linear_end,
|
275 |
+
beta_schedule="scaled_linear",
|
276 |
+
)
|
277 |
+
return schedular
|
278 |
+
|
279 |
+
|
280 |
+
def create_ldm_bert_config(original_config):
|
281 |
+
bert_params = original_config.model.parms.cond_stage_config.params
|
282 |
+
config = LDMBertConfig(
|
283 |
+
d_model=bert_params.n_embed,
|
284 |
+
encoder_layers=bert_params.n_layer,
|
285 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
286 |
+
)
|
287 |
+
return config
|
288 |
+
|
289 |
+
|
290 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
291 |
+
"""
|
292 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
293 |
+
"""
|
294 |
+
|
295 |
+
# extract state_dict for UNet
|
296 |
+
unet_state_dict = {}
|
297 |
+
keys = list(checkpoint.keys())
|
298 |
+
|
299 |
+
unet_key = "model.diffusion_model."
|
300 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
301 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
302 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
303 |
+
if extract_ema:
|
304 |
+
print(
|
305 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
306 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
307 |
+
)
|
308 |
+
for key in keys:
|
309 |
+
if key.startswith("model.diffusion_model"):
|
310 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
311 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
312 |
+
else:
|
313 |
+
print(
|
314 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
315 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
316 |
+
)
|
317 |
+
|
318 |
+
keys = list(checkpoint.keys())
|
319 |
+
for key in keys:
|
320 |
+
if key.startswith(unet_key):
|
321 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
322 |
+
|
323 |
+
new_checkpoint = {"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"],
|
324 |
+
"time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
|
325 |
+
"time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"],
|
326 |
+
"time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
|
327 |
+
"conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
|
328 |
+
"conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
|
329 |
+
"conv_norm_out.weight": unet_state_dict["out.0.weight"],
|
330 |
+
"conv_norm_out.bias": unet_state_dict["out.0.bias"],
|
331 |
+
"conv_out.weight": unet_state_dict["out.2.weight"],
|
332 |
+
"conv_out.bias": unet_state_dict["out.2.bias"]}
|
333 |
+
|
334 |
+
# Retrieves the keys for the input blocks only
|
335 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
336 |
+
input_blocks = {
|
337 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
338 |
+
for layer_id in range(num_input_blocks)
|
339 |
+
}
|
340 |
+
|
341 |
+
# Retrieves the keys for the middle blocks only
|
342 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
343 |
+
middle_blocks = {
|
344 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
345 |
+
for layer_id in range(num_middle_blocks)
|
346 |
+
}
|
347 |
+
|
348 |
+
# Retrieves the keys for the output blocks only
|
349 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
350 |
+
output_blocks = {
|
351 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
352 |
+
for layer_id in range(num_output_blocks)
|
353 |
+
}
|
354 |
+
|
355 |
+
for i in range(1, num_input_blocks):
|
356 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
357 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
358 |
+
|
359 |
+
resnets = [
|
360 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
361 |
+
]
|
362 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
363 |
+
|
364 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
365 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
366 |
+
f"input_blocks.{i}.0.op.weight"
|
367 |
+
)
|
368 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
369 |
+
f"input_blocks.{i}.0.op.bias"
|
370 |
+
)
|
371 |
+
|
372 |
+
paths = renew_resnet_paths(resnets)
|
373 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
374 |
+
assign_to_checkpoint(
|
375 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
376 |
+
)
|
377 |
+
|
378 |
+
if len(attentions):
|
379 |
+
paths = renew_attention_paths(attentions)
|
380 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
381 |
+
assign_to_checkpoint(
|
382 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
383 |
+
)
|
384 |
+
|
385 |
+
resnet_0 = middle_blocks[0]
|
386 |
+
attentions = middle_blocks[1]
|
387 |
+
resnet_1 = middle_blocks[2]
|
388 |
+
|
389 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
390 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
391 |
+
|
392 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
393 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
394 |
+
|
395 |
+
attentions_paths = renew_attention_paths(attentions)
|
396 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
397 |
+
assign_to_checkpoint(
|
398 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
399 |
+
)
|
400 |
+
|
401 |
+
for i in range(num_output_blocks):
|
402 |
+
block_id = i // (config["layers_per_block"] + 1)
|
403 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
404 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
405 |
+
output_block_list = {}
|
406 |
+
|
407 |
+
for layer in output_block_layers:
|
408 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
409 |
+
if layer_id in output_block_list:
|
410 |
+
output_block_list[layer_id].append(layer_name)
|
411 |
+
else:
|
412 |
+
output_block_list[layer_id] = [layer_name]
|
413 |
+
|
414 |
+
if len(output_block_list) > 1:
|
415 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
416 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
417 |
+
|
418 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
419 |
+
paths = renew_resnet_paths(resnets)
|
420 |
+
|
421 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
422 |
+
assign_to_checkpoint(
|
423 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
424 |
+
)
|
425 |
+
|
426 |
+
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
427 |
+
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
428 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
429 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
430 |
+
]
|
431 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
432 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
433 |
+
]
|
434 |
+
|
435 |
+
# Clear attentions as they have been attributed above.
|
436 |
+
if len(attentions) == 2:
|
437 |
+
attentions = []
|
438 |
+
|
439 |
+
if len(attentions):
|
440 |
+
paths = renew_attention_paths(attentions)
|
441 |
+
meta_path = {
|
442 |
+
"old": f"output_blocks.{i}.1",
|
443 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
444 |
+
}
|
445 |
+
assign_to_checkpoint(
|
446 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
447 |
+
)
|
448 |
+
else:
|
449 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
450 |
+
for path in resnet_0_paths:
|
451 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
452 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
453 |
+
|
454 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
455 |
+
|
456 |
+
return new_checkpoint
|
457 |
+
|
458 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
459 |
+
# extract state dict for VAE
|
460 |
+
vae_state_dict = {}
|
461 |
+
vae_key = "first_stage_model."
|
462 |
+
keys = list(checkpoint.keys())
|
463 |
+
for key in keys:
|
464 |
+
if key.startswith(vae_key):
|
465 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
466 |
+
|
467 |
+
new_checkpoint = {}
|
468 |
+
|
469 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
470 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
471 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
472 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
473 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
474 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
475 |
+
|
476 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
477 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
478 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
479 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
480 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
481 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
482 |
+
|
483 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
484 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
485 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
486 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
487 |
+
|
488 |
+
|
489 |
+
# Retrieves the keys for the encoder down blocks only
|
490 |
+
num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer})
|
491 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
|
492 |
+
|
493 |
+
# Retrieves the keys for the decoder up blocks only
|
494 |
+
num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer})
|
495 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
496 |
+
|
497 |
+
|
498 |
+
for i in range(num_down_blocks):
|
499 |
+
resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key]
|
500 |
+
|
501 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
502 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
|
503 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
|
504 |
+
|
505 |
+
paths = renew_vae_resnet_paths(resnets)
|
506 |
+
meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'}
|
507 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
508 |
+
|
509 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
510 |
+
num_mid_res_blocks = 2
|
511 |
+
for i in range(1, num_mid_res_blocks + 1):
|
512 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
513 |
+
|
514 |
+
paths = renew_vae_resnet_paths(resnets)
|
515 |
+
meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
|
516 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
517 |
+
|
518 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
519 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
520 |
+
meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
|
521 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
522 |
+
conv_attn_to_linear(new_checkpoint)
|
523 |
+
|
524 |
+
for i in range(num_up_blocks):
|
525 |
+
block_id = num_up_blocks - 1 - i
|
526 |
+
resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key]
|
527 |
+
|
528 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
529 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
|
530 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
|
531 |
+
|
532 |
+
paths = renew_vae_resnet_paths(resnets)
|
533 |
+
meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'}
|
534 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
535 |
+
|
536 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
537 |
+
num_mid_res_blocks = 2
|
538 |
+
for i in range(1, num_mid_res_blocks + 1):
|
539 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
540 |
+
|
541 |
+
paths = renew_vae_resnet_paths(resnets)
|
542 |
+
meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
|
543 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
544 |
+
|
545 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
546 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
547 |
+
meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
|
548 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
549 |
+
conv_attn_to_linear(new_checkpoint)
|
550 |
+
return new_checkpoint
|
551 |
+
|
552 |
+
|
553 |
+
def convert_ldm_vae(vae_path, config):
|
554 |
+
vae_state_dict = torch.load(vae_path)['state_dict']
|
555 |
+
|
556 |
+
new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
|
557 |
+
"encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
|
558 |
+
"encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
|
559 |
+
"encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
|
560 |
+
"encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"],
|
561 |
+
"encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
|
562 |
+
"decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
|
563 |
+
"decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
|
564 |
+
"decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
|
565 |
+
"decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
|
566 |
+
"decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"],
|
567 |
+
"decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
|
568 |
+
"quant_conv.weight": vae_state_dict["quant_conv.weight"],
|
569 |
+
"quant_conv.bias": vae_state_dict["quant_conv.bias"],
|
570 |
+
"post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
|
571 |
+
"post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"]}
|
572 |
+
|
573 |
+
|
574 |
+
# Retrieves the keys for the encoder down blocks only
|
575 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
576 |
+
down_blocks = {
|
577 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
578 |
+
}
|
579 |
+
|
580 |
+
# Retrieves the keys for the decoder up blocks only
|
581 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
582 |
+
up_blocks = {
|
583 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
584 |
+
}
|
585 |
+
|
586 |
+
for i in range(num_down_blocks):
|
587 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
588 |
+
|
589 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
590 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
591 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
592 |
+
)
|
593 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
594 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
595 |
+
)
|
596 |
+
|
597 |
+
paths = renew_vae_resnet_paths(resnets)
|
598 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
599 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
600 |
+
|
601 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
602 |
+
num_mid_res_blocks = 2
|
603 |
+
for i in range(1, num_mid_res_blocks + 1):
|
604 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
605 |
+
|
606 |
+
paths = renew_vae_resnet_paths(resnets)
|
607 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
608 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
609 |
+
|
610 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
611 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
612 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
613 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
614 |
+
conv_attn_to_linear(new_checkpoint)
|
615 |
+
|
616 |
+
for i in range(num_up_blocks):
|
617 |
+
block_id = num_up_blocks - 1 - i
|
618 |
+
resnets = [
|
619 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
620 |
+
]
|
621 |
+
|
622 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
623 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
624 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
625 |
+
]
|
626 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
627 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
628 |
+
]
|
629 |
+
|
630 |
+
paths = renew_vae_resnet_paths(resnets)
|
631 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
632 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
633 |
+
|
634 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
635 |
+
num_mid_res_blocks = 2
|
636 |
+
for i in range(1, num_mid_res_blocks + 1):
|
637 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
638 |
+
|
639 |
+
paths = renew_vae_resnet_paths(resnets)
|
640 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
641 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
642 |
+
|
643 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
644 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
645 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
646 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
647 |
+
conv_attn_to_linear(new_checkpoint)
|
648 |
+
return new_checkpoint
|
649 |
+
|
650 |
+
|
651 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
652 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
653 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
654 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
655 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
656 |
+
|
657 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
658 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
659 |
+
|
660 |
+
def _copy_linear(hf_linear, pt_linear):
|
661 |
+
hf_linear.weight = pt_linear.weight
|
662 |
+
hf_linear.bias = pt_linear.bias
|
663 |
+
|
664 |
+
def _copy_layer(hf_layer, pt_layer):
|
665 |
+
# copy layer norms
|
666 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
667 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
668 |
+
|
669 |
+
# copy attn
|
670 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
671 |
+
|
672 |
+
# copy MLP
|
673 |
+
pt_mlp = pt_layer[1][1]
|
674 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
675 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
676 |
+
|
677 |
+
def _copy_layers(hf_layers, pt_layers):
|
678 |
+
for i, hf_layer in enumerate(hf_layers):
|
679 |
+
if i != 0:
|
680 |
+
i += i
|
681 |
+
pt_layer = pt_layers[i : i + 2]
|
682 |
+
_copy_layer(hf_layer, pt_layer)
|
683 |
+
|
684 |
+
hf_model = LDMBertModel(config).eval()
|
685 |
+
|
686 |
+
# copy embeds
|
687 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
688 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
689 |
+
|
690 |
+
# copy layer norm
|
691 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
692 |
+
|
693 |
+
# copy hidden layers
|
694 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
695 |
+
|
696 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
697 |
+
|
698 |
+
return hf_model
|
699 |
+
|
700 |
+
|
701 |
+
def convert_ldm_clip_checkpoint(checkpoint):
|
702 |
+
if os.path.exists(default_model_path):
|
703 |
+
text_model = CLIPTextModel.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
|
704 |
+
else:
|
705 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
706 |
+
|
707 |
+
keys = list(checkpoint.keys())
|
708 |
+
|
709 |
+
text_model_dict = {}
|
710 |
+
|
711 |
+
for key in keys:
|
712 |
+
if key.startswith("cond_stage_model.transformer"):
|
713 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
714 |
+
|
715 |
+
text_model.load_state_dict(text_model_dict, strict=False)
|
716 |
+
|
717 |
+
return text_model
|
718 |
+
|
719 |
+
|
720 |
+
if __name__ == "__main__":
|
721 |
+
parser = argparse.ArgumentParser()
|
722 |
+
|
723 |
+
parser.add_argument(
|
724 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
725 |
+
)
|
726 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
727 |
+
parser.add_argument(
|
728 |
+
"--vae_path", default=None, type=str, help="Path to the vae to convert."
|
729 |
+
)
|
730 |
+
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
731 |
+
parser.add_argument(
|
732 |
+
"--original_config_file",
|
733 |
+
default=None,
|
734 |
+
type=str,
|
735 |
+
help="The YAML config file corresponding to the original architecture.",
|
736 |
+
)
|
737 |
+
parser.add_argument(
|
738 |
+
"--scheduler_type",
|
739 |
+
default="pndm",
|
740 |
+
type=str,
|
741 |
+
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
|
742 |
+
)
|
743 |
+
parser.add_argument(
|
744 |
+
"--extract_ema",
|
745 |
+
action="store_true",
|
746 |
+
default=False,
|
747 |
+
help=(
|
748 |
+
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
749 |
+
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
750 |
+
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
751 |
+
),
|
752 |
+
)
|
753 |
+
|
754 |
+
args = parser.parse_args()
|
755 |
+
|
756 |
+
if args.original_config_file is None:
|
757 |
+
os.system(
|
758 |
+
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
759 |
+
)
|
760 |
+
args.original_config_file = "./v1-inference.yaml"
|
761 |
+
|
762 |
+
original_config = OmegaConf.load(args.original_config_file)
|
763 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cuda")
|
764 |
+
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
765 |
+
|
766 |
+
num_train_timesteps = original_config.model.params.timesteps
|
767 |
+
beta_start = original_config.model.params.linear_start
|
768 |
+
beta_end = original_config.model.params.linear_end
|
769 |
+
if args.scheduler_type == "pndm":
|
770 |
+
scheduler = PNDMScheduler(
|
771 |
+
beta_end=beta_end,
|
772 |
+
beta_schedule="scaled_linear",
|
773 |
+
beta_start=beta_start,
|
774 |
+
num_train_timesteps=num_train_timesteps,
|
775 |
+
skip_prk_steps=True,
|
776 |
+
)
|
777 |
+
elif args.scheduler_type == "lms":
|
778 |
+
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
|
779 |
+
elif args.scheduler_type == "ddim":
|
780 |
+
scheduler = DDIMScheduler(
|
781 |
+
beta_start=beta_start,
|
782 |
+
beta_end=beta_end,
|
783 |
+
beta_schedule="scaled_linear",
|
784 |
+
clip_sample=False,
|
785 |
+
set_alpha_to_one=False,
|
786 |
+
)
|
787 |
+
else:
|
788 |
+
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
|
789 |
+
|
790 |
+
# Convert the UNet2DConditionModel model.
|
791 |
+
unet_config = create_unet_diffusers_config(original_config)
|
792 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
793 |
+
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
|
794 |
+
)
|
795 |
+
|
796 |
+
unet = UNet2DConditionModel(**unet_config)
|
797 |
+
unet.load_state_dict(converted_unet_checkpoint)
|
798 |
+
|
799 |
+
# Convert the VAE model.
|
800 |
+
if args.vae_path:
|
801 |
+
vae_config = create_vae_diffusers_config(original_config)
|
802 |
+
converted_vae_checkpoint = convert_ldm_vae(args.vae_path, vae_config)
|
803 |
+
else:
|
804 |
+
vae_config = create_vae_diffusers_config(original_config)
|
805 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
806 |
+
|
807 |
+
vae = AutoencoderKL(**vae_config)
|
808 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
809 |
+
|
810 |
+
# Convert the text model.
|
811 |
+
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
812 |
+
if text_model_type == "FrozenCLIPEmbedder":
|
813 |
+
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
814 |
+
if os.path.exists(default_model_path):
|
815 |
+
tokenizer = CLIPTokenizer.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
|
816 |
+
else:
|
817 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
818 |
+
#safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
819 |
+
#feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
820 |
+
pipe = StableDiffusionPipeline(
|
821 |
+
vae=vae,
|
822 |
+
text_encoder=text_model,
|
823 |
+
tokenizer=tokenizer,
|
824 |
+
unet=unet,
|
825 |
+
scheduler=scheduler,
|
826 |
+
safety_checker=None,
|
827 |
+
feature_extractor=None,
|
828 |
+
)
|
829 |
+
else:
|
830 |
+
text_config = create_ldm_bert_config(original_config)
|
831 |
+
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
832 |
+
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
833 |
+
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
834 |
+
|
835 |
+
pipe.save_pretrained(args.dump_path)
|
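Once this script finishes, the folder passed as dump_path is a regular diffusers model directory. A minimal sketch of loading it for a quick inference test, assuming the model was dumped to ./model and a CUDA GPU is available (both the path and the prompt below are placeholders, not values taken from this repository):

    # Quick sanity check of a converted model (illustrative sketch only)
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("./model", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("test.png")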
dreambooth-for-diffusion/tools/ckpt2diffusers_old.py
ADDED
@@ -0,0 +1,619 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the LDM checkpoints. """
|
16 |
+
|
17 |
+
import argparse, os
|
18 |
+
import torch
|
19 |
+
|
20 |
+
try:
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
except ImportError:
|
23 |
+
raise ImportError("OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`.")
|
24 |
+
|
25 |
+
from transformers import BertTokenizerFast, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel
|
26 |
+
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler
|
27 |
+
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel, LDMBertConfig
|
28 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
29 |
+
|
30 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
31 |
+
"""
|
32 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
33 |
+
"""
|
34 |
+
if n_shave_prefix_segments >= 0:
|
35 |
+
return '.'.join(path.split('.')[n_shave_prefix_segments:])
|
36 |
+
else:
|
37 |
+
return '.'.join(path.split('.')[:n_shave_prefix_segments])
|
38 |
+
|
39 |
+
|
40 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
41 |
+
"""
|
42 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
43 |
+
"""
|
44 |
+
mapping = []
|
45 |
+
for old_item in old_list:
|
46 |
+
new_item = old_item.replace('in_layers.0', 'norm1')
|
47 |
+
new_item = new_item.replace('in_layers.2', 'conv1')
|
48 |
+
|
49 |
+
new_item = new_item.replace('out_layers.0', 'norm2')
|
50 |
+
new_item = new_item.replace('out_layers.3', 'conv2')
|
51 |
+
|
52 |
+
new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
|
53 |
+
new_item = new_item.replace('skip_connection', 'conv_shortcut')
|
54 |
+
|
55 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
56 |
+
|
57 |
+
mapping.append({'old': old_item, 'new': new_item})
|
58 |
+
|
59 |
+
return mapping
|
60 |
+
|
61 |
+
|
62 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
63 |
+
"""
|
64 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
65 |
+
"""
|
66 |
+
mapping = []
|
67 |
+
for old_item in old_list:
|
68 |
+
new_item = old_item
|
69 |
+
|
70 |
+
new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
|
71 |
+
|
72 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
73 |
+
|
74 |
+
mapping.append({'old': old_item, 'new': new_item})
|
75 |
+
|
76 |
+
return mapping
|
77 |
+
|
78 |
+
|
79 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
80 |
+
"""
|
81 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
82 |
+
"""
|
83 |
+
mapping = []
|
84 |
+
for old_item in old_list:
|
85 |
+
new_item = old_item
|
86 |
+
|
87 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
88 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
89 |
+
|
90 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
91 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
92 |
+
|
93 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
94 |
+
|
95 |
+
mapping.append({'old': old_item, 'new': new_item})
|
96 |
+
|
97 |
+
return mapping
|
98 |
+
|
99 |
+
|
100 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
101 |
+
"""
|
102 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
103 |
+
"""
|
104 |
+
mapping = []
|
105 |
+
for old_item in old_list:
|
106 |
+
new_item = old_item
|
107 |
+
|
108 |
+
new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
109 |
+
new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
110 |
+
|
111 |
+
new_item = new_item.replace('q.weight', 'query.weight')
|
112 |
+
new_item = new_item.replace('q.bias', 'query.bias')
|
113 |
+
|
114 |
+
new_item = new_item.replace('k.weight', 'key.weight')
|
115 |
+
new_item = new_item.replace('k.bias', 'key.bias')
|
116 |
+
|
117 |
+
new_item = new_item.replace('v.weight', 'value.weight')
|
118 |
+
new_item = new_item.replace('v.bias', 'value.bias')
|
119 |
+
|
120 |
+
new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
121 |
+
new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
122 |
+
|
123 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
124 |
+
|
125 |
+
mapping.append({'old': old_item, 'new': new_item})
|
126 |
+
|
127 |
+
return mapping
|
128 |
+
|
129 |
+
|
130 |
+
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
|
131 |
+
"""
|
132 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
133 |
+
to them. It splits attention layers, and takes into account additional replacements
|
134 |
+
that may arise.
|
135 |
+
|
136 |
+
Assigns the weights to the new checkpoint.
|
137 |
+
"""
|
138 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
139 |
+
|
140 |
+
# Splits the attention layers into three variables.
|
141 |
+
if attention_paths_to_split is not None:
|
142 |
+
for path, path_map in attention_paths_to_split.items():
|
143 |
+
old_tensor = old_checkpoint[path]
|
144 |
+
channels = old_tensor.shape[0] // 3
|
145 |
+
|
146 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
147 |
+
|
148 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
149 |
+
|
150 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
151 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
152 |
+
|
153 |
+
checkpoint[path_map['query']] = query.reshape(target_shape)
|
154 |
+
checkpoint[path_map['key']] = key.reshape(target_shape)
|
155 |
+
checkpoint[path_map['value']] = value.reshape(target_shape)
|
156 |
+
|
157 |
+
for path in paths:
|
158 |
+
new_path = path['new']
|
159 |
+
|
160 |
+
# These have already been assigned
|
161 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
162 |
+
continue
|
163 |
+
|
164 |
+
# Global renaming happens here
|
165 |
+
new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0')
|
166 |
+
new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0')
|
167 |
+
new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1')
|
168 |
+
|
169 |
+
if additional_replacements is not None:
|
170 |
+
for replacement in additional_replacements:
|
171 |
+
new_path = new_path.replace(replacement['old'], replacement['new'])
|
172 |
+
|
173 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
174 |
+
if "proj_attn.weight" in new_path:
|
175 |
+
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
|
176 |
+
else:
|
177 |
+
checkpoint[new_path] = old_checkpoint[path['old']]
|
178 |
+
|
179 |
+
|
180 |
+
def conv_attn_to_linear(checkpoint):
|
181 |
+
keys = list(checkpoint.keys())
|
182 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
183 |
+
for key in keys:
|
184 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
185 |
+
if checkpoint[key].ndim > 2:
|
186 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
187 |
+
elif "proj_attn.weight" in key:
|
188 |
+
if checkpoint[key].ndim > 2:
|
189 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
190 |
+
|
191 |
+
|
192 |
+
def create_unet_diffusers_config(original_config):
|
193 |
+
"""
|
194 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
195 |
+
"""
|
196 |
+
unet_params = original_config.model.params.unet_config.params
|
197 |
+
|
198 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
199 |
+
|
200 |
+
down_block_types = []
|
201 |
+
resolution = 1
|
202 |
+
for i in range(len(block_out_channels)):
|
203 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
204 |
+
down_block_types.append(block_type)
|
205 |
+
if i != len(block_out_channels) - 1:
|
206 |
+
resolution *= 2
|
207 |
+
|
208 |
+
up_block_types = []
|
209 |
+
for i in range(len(block_out_channels)):
|
210 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
211 |
+
up_block_types.append(block_type)
|
212 |
+
resolution //= 2
|
213 |
+
|
214 |
+
config = dict(
|
215 |
+
sample_size=unet_params.image_size,
|
216 |
+
in_channels=unet_params.in_channels,
|
217 |
+
out_channels=unet_params.out_channels,
|
218 |
+
down_block_types=tuple(down_block_types),
|
219 |
+
up_block_types=tuple(up_block_types),
|
220 |
+
block_out_channels=tuple(block_out_channels),
|
221 |
+
layers_per_block=unet_params.num_res_blocks,
|
222 |
+
cross_attention_dim=unet_params.context_dim,
|
223 |
+
attention_head_dim=unet_params.num_heads,
|
224 |
+
)
|
225 |
+
|
226 |
+
return config
|
227 |
+
|
228 |
+
|
229 |
+
def create_vae_diffusers_config(original_config):
|
230 |
+
"""
|
231 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
232 |
+
"""
|
233 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
234 |
+
latent_channels = original_config.model.params.first_stage_config.params.embed_dim
|
235 |
+
|
236 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
237 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
238 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
239 |
+
|
240 |
+
config = dict(
|
241 |
+
sample_size=vae_params.resolution,
|
242 |
+
in_channels=vae_params.in_channels,
|
243 |
+
out_channels=vae_params.out_ch,
|
244 |
+
down_block_types=tuple(down_block_types),
|
245 |
+
up_block_types=tuple(up_block_types),
|
246 |
+
block_out_channels=tuple(block_out_channels),
|
247 |
+
latent_channels=vae_params.z_channels,
|
248 |
+
layers_per_block=vae_params.num_res_blocks,
|
249 |
+
)
|
250 |
+
return config
|
251 |
+
|
252 |
+
|
253 |
+
def create_diffusers_schedular(original_config):
|
254 |
+
schedular = DDIMScheduler(
|
255 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
256 |
+
beta_start=original_config.model.params.linear_start,
|
257 |
+
beta_end=original_config.model.params.linear_end,
|
258 |
+
beta_schedule="scaled_linear",
|
259 |
+
)
|
260 |
+
return schedular
|
261 |
+
|
262 |
+
|
263 |
+
def create_ldm_bert_config(original_config):
|
264 |
+
bert_params = original_config.model.params.cond_stage_config.params
|
265 |
+
config = LDMBertConfig(
|
266 |
+
d_model=bert_params.n_embed,
|
267 |
+
encoder_layers=bert_params.n_layer,
|
268 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
269 |
+
)
|
270 |
+
return config
|
271 |
+
|
272 |
+
|
273 |
+
def convert_ldm_unet_checkpoint(checkpoint, config):
|
274 |
+
"""
|
275 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
276 |
+
"""
|
277 |
+
|
278 |
+
# extract state_dict for UNet
|
279 |
+
unet_state_dict = {}
|
280 |
+
unet_key = "model.diffusion_model."
|
281 |
+
keys = list(checkpoint.keys())
|
282 |
+
for key in keys:
|
283 |
+
if key.startswith(unet_key):
|
284 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
285 |
+
|
286 |
+
new_checkpoint = {}
|
287 |
+
|
288 |
+
new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict['time_embed.0.weight']
|
289 |
+
new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict['time_embed.0.bias']
|
290 |
+
new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict['time_embed.2.weight']
|
291 |
+
new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict['time_embed.2.bias']
|
292 |
+
|
293 |
+
new_checkpoint['conv_in.weight'] = unet_state_dict['input_blocks.0.0.weight']
|
294 |
+
new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']
|
295 |
+
|
296 |
+
new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
|
297 |
+
new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
|
298 |
+
new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
|
299 |
+
new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']
|
300 |
+
|
301 |
+
# Retrieves the keys for the input blocks only
|
302 |
+
num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'input_blocks' in layer})
|
303 |
+
input_blocks = {layer_id: [key for key in unet_state_dict if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
|
304 |
+
|
305 |
+
# Retrieves the keys for the middle blocks only
|
306 |
+
num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'middle_block' in layer})
|
307 |
+
middle_blocks = {layer_id: [key for key in unet_state_dict if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
|
308 |
+
|
309 |
+
# Retrieves the keys for the output blocks only
|
310 |
+
num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'output_blocks' in layer})
|
311 |
+
output_blocks = {layer_id: [key for key in unet_state_dict if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
|
312 |
+
|
313 |
+
for i in range(1, num_input_blocks):
|
314 |
+
block_id = (i - 1) // (config['layers_per_block'] + 1)
|
315 |
+
layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)
|
316 |
+
|
317 |
+
resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key and f'input_blocks.{i}.0.op' not in key]
|
318 |
+
attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
|
319 |
+
|
320 |
+
if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
|
321 |
+
new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.weight'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
|
322 |
+
new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.bias'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')
|
323 |
+
|
324 |
+
paths = renew_resnet_paths(resnets)
|
325 |
+
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
326 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
327 |
+
|
328 |
+
if len(attentions):
|
329 |
+
paths = renew_attention_paths(attentions)
|
330 |
+
meta_path = {'old': f'input_blocks.{i}.1', 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}'}
|
331 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
332 |
+
|
333 |
+
|
334 |
+
resnet_0 = middle_blocks[0]
|
335 |
+
attentions = middle_blocks[1]
|
336 |
+
resnet_1 = middle_blocks[2]
|
337 |
+
|
338 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
339 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
340 |
+
|
341 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
342 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
343 |
+
|
344 |
+
attentions_paths = renew_attention_paths(attentions)
|
345 |
+
meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
|
346 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
347 |
+
|
348 |
+
for i in range(num_output_blocks):
|
349 |
+
block_id = i // (config['layers_per_block'] + 1)
|
350 |
+
layer_in_block_id = i % (config['layers_per_block'] + 1)
|
351 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
352 |
+
output_block_list = {}
|
353 |
+
|
354 |
+
for layer in output_block_layers:
|
355 |
+
layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
|
356 |
+
if layer_id in output_block_list:
|
357 |
+
output_block_list[layer_id].append(layer_name)
|
358 |
+
else:
|
359 |
+
output_block_list[layer_id] = [layer_name]
|
360 |
+
|
361 |
+
if len(output_block_list) > 1:
|
362 |
+
resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
|
363 |
+
attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
|
364 |
+
|
365 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
366 |
+
paths = renew_resnet_paths(resnets)
|
367 |
+
|
368 |
+
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
369 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
370 |
+
|
371 |
+
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
372 |
+
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
373 |
+
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
|
374 |
+
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']
|
375 |
+
|
376 |
+
# Clear attentions as they have been attributed above.
|
377 |
+
if len(attentions) == 2:
|
378 |
+
attentions = []
|
379 |
+
|
380 |
+
if len(attentions):
|
381 |
+
paths = renew_attention_paths(attentions)
|
382 |
+
meta_path = {
|
383 |
+
'old': f'output_blocks.{i}.1',
|
384 |
+
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
|
385 |
+
}
|
386 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
387 |
+
else:
|
388 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
389 |
+
for path in resnet_0_paths:
|
390 |
+
old_path = '.'.join(['output_blocks', str(i), path['old']])
|
391 |
+
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
|
392 |
+
|
393 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
394 |
+
|
395 |
+
return new_checkpoint
|
396 |
+
|
397 |
+
|
398 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
399 |
+
# extract state dict for VAE
|
400 |
+
vae_state_dict = {}
|
401 |
+
vae_key = "first_stage_model."
|
402 |
+
keys = list(checkpoint.keys())
|
403 |
+
for key in keys:
|
404 |
+
if key.startswith(vae_key):
|
405 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
406 |
+
|
407 |
+
new_checkpoint = {}
|
408 |
+
|
409 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
410 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
411 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
412 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
413 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
414 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
415 |
+
|
416 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
417 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
418 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
419 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
420 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
421 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
422 |
+
|
423 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
424 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
425 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
426 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
427 |
+
|
428 |
+
|
429 |
+
# Retrieves the keys for the encoder down blocks only
|
430 |
+
num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer})
|
431 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
|
432 |
+
|
433 |
+
# Retrieves the keys for the decoder up blocks only
|
434 |
+
num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer})
|
435 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
436 |
+
|
437 |
+
|
438 |
+
for i in range(num_down_blocks):
|
439 |
+
resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key]
|
440 |
+
|
441 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
442 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
|
443 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
|
444 |
+
|
445 |
+
paths = renew_vae_resnet_paths(resnets)
|
446 |
+
meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'}
|
447 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
448 |
+
|
449 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
450 |
+
num_mid_res_blocks = 2
|
451 |
+
for i in range(1, num_mid_res_blocks + 1):
|
452 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
453 |
+
|
454 |
+
paths = renew_vae_resnet_paths(resnets)
|
455 |
+
meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
|
456 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
457 |
+
|
458 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
459 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
460 |
+
meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
|
461 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
462 |
+
conv_attn_to_linear(new_checkpoint)
|
463 |
+
|
464 |
+
for i in range(num_up_blocks):
|
465 |
+
block_id = num_up_blocks - 1 - i
|
466 |
+
resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key]
|
467 |
+
|
468 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
|
470 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
|
471 |
+
|
472 |
+
paths = renew_vae_resnet_paths(resnets)
|
473 |
+
meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'}
|
474 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
475 |
+
|
476 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
477 |
+
num_mid_res_blocks = 2
|
478 |
+
for i in range(1, num_mid_res_blocks + 1):
|
479 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
480 |
+
|
481 |
+
paths = renew_vae_resnet_paths(resnets)
|
482 |
+
meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
|
483 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
484 |
+
|
485 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
486 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
487 |
+
meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
|
488 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
489 |
+
conv_attn_to_linear(new_checkpoint)
|
490 |
+
return new_checkpoint
|
491 |
+
|
492 |
+
|
493 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
494 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
495 |
+
|
496 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
497 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
498 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
499 |
+
|
500 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
501 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
502 |
+
|
503 |
+
|
504 |
+
def _copy_linear(hf_linear, pt_linear):
|
505 |
+
hf_linear.weight = pt_linear.weight
|
506 |
+
hf_linear.bias = pt_linear.bias
|
507 |
+
|
508 |
+
|
509 |
+
def _copy_layer(hf_layer, pt_layer):
|
510 |
+
# copy layer norms
|
511 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
512 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
513 |
+
|
514 |
+
# copy attn
|
515 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
516 |
+
|
517 |
+
# copy MLP
|
518 |
+
pt_mlp = pt_layer[1][1]
|
519 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
520 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
521 |
+
|
522 |
+
|
523 |
+
def _copy_layers(hf_layers, pt_layers):
|
524 |
+
for i, hf_layer in enumerate(hf_layers):
|
525 |
+
if i != 0: i += i
|
526 |
+
pt_layer = pt_layers[i:i+2]
|
527 |
+
_copy_layer(hf_layer, pt_layer)
|
528 |
+
|
529 |
+
hf_model = LDMBertModel(config).eval()
|
530 |
+
|
531 |
+
# copy embeds
|
532 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
533 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
534 |
+
|
535 |
+
# copy layer norm
|
536 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
537 |
+
|
538 |
+
# copy hidden layers
|
539 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
540 |
+
|
541 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
542 |
+
|
543 |
+
return hf_model
|
544 |
+
|
545 |
+
|
546 |
+
|
547 |
+
if __name__ == "__main__":
|
548 |
+
parser = argparse.ArgumentParser()
|
549 |
+
|
550 |
+
parser.add_argument(
|
551 |
+
"checkpoint_path", default='./model.ckpt', type=str, help="Path to the checkpoint to convert."
|
552 |
+
)
|
553 |
+
|
554 |
+
|
555 |
+
parser.add_argument(
|
556 |
+
"dump_path", default='./model', type=str, help="Path to the output model."
|
557 |
+
)
|
558 |
+
|
559 |
+
parser.add_argument(
|
560 |
+
"--original_config_file",
|
561 |
+
default='./ckpt_models/model.yaml',
|
562 |
+
type=str,
|
563 |
+
required=False,
|
564 |
+
help="The YAML config file corresponding to the original architecture.",
|
565 |
+
)
|
566 |
+
|
567 |
+
args = parser.parse_args()
|
568 |
+
|
569 |
+
original_config = OmegaConf.load(args.original_config_file)
|
570 |
+
|
571 |
+
checkpoint = torch.load(args.checkpoint_path)["state_dict"]
|
572 |
+
|
573 |
+
# Convert the UNet2DConditionModel model.
|
574 |
+
unet_config = create_unet_diffusers_config(original_config)
|
575 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
|
576 |
+
|
577 |
+
unet = UNet2DConditionModel(**unet_config)
|
578 |
+
unet.load_state_dict(converted_unet_checkpoint)
|
579 |
+
|
580 |
+
# Convert the VAE model.
|
581 |
+
vae_config = create_vae_diffusers_config(original_config)
|
582 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
583 |
+
|
584 |
+
vae = AutoencoderKL(**vae_config)
|
585 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
586 |
+
|
587 |
+
|
588 |
+
|
589 |
+
# Convert the text model.
|
590 |
+
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
591 |
+
|
592 |
+
script_path = os.path.realpath(__file__)
|
593 |
+
default_model_path = os.path.join(os.path.dirname(script_path), "diffusers-models")
|
594 |
+
|
595 |
+
try:
|
596 |
+
text_model = CLIPTextModel.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
|
597 |
+
tokenizer = CLIPTokenizer.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
|
598 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(os.path.join(default_model_path, "safety-checker"))
|
599 |
+
|
600 |
+
except Exception as e:
|
601 |
+
print(e)
|
602 |
+
print("Could not load the default text model. Auto downloading...")
|
603 |
+
if text_model_type == "FrozenCLIPEmbedder":
|
604 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
605 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
606 |
+
else:
|
607 |
+
# TODO: update the convert function to use the state_dict without the model instance.
|
608 |
+
text_config = create_ldm_bert_config(original_config)
|
609 |
+
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
610 |
+
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
611 |
+
|
612 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker')
|
613 |
+
|
614 |
+
scheduler = create_diffusers_schedular(original_config)
|
615 |
+
|
616 |
+
scheduler = create_diffusers_schedular(original_config)
|
617 |
+
feature_extractor = CLIPFeatureExtractor()
|
618 |
+
pipe = StableDiffusionPipeline(vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor)
|
619 |
+
pipe.save_pretrained(args.dump_path)
|
dreambooth-for-diffusion/tools/ckpt_merge.py
ADDED
@@ -0,0 +1,56 @@
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
parser = argparse.ArgumentParser(description="Merge two models")
|
7 |
+
parser.add_argument("model_0", type=str, help="Path to model 0")
|
8 |
+
parser.add_argument("model_1", type=str, help="Path to model 1")
|
9 |
+
parser.add_argument("--alpha", type=float, help="Alpha value, optional, defaults to 0.5", default=0.5, required=False)
|
10 |
+
parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False)
|
11 |
+
parser.add_argument("--device", type=str, help="Device to use, defaults to cpu", default="cpu", required=False)
|
12 |
+
parser.add_argument("--without_vae", action="store_true", help="Do not merge VAE", required=False)
|
13 |
+
|
14 |
+
args = parser.parse_args()
|
15 |
+
|
16 |
+
device = args.device
|
17 |
+
model_0 = torch.load(args.model_0, map_location=device)
|
18 |
+
model_1 = torch.load(args.model_1, map_location=device)
|
19 |
+
theta_0 = model_0["state_dict"]
|
20 |
+
theta_1 = model_1["state_dict"]
|
21 |
+
alpha = args.alpha
|
22 |
+
|
23 |
+
output_file = f'{args.output}-{str(alpha)[2:] + "0"}.ckpt'
|
24 |
+
|
25 |
+
# check if output file already exists, ask to overwrite
|
26 |
+
if os.path.isfile(output_file):
|
27 |
+
print("Output file already exists. Overwrite? (y/n)")
|
28 |
+
while True:
|
29 |
+
overwrite = input()
|
30 |
+
if overwrite == "y":
|
31 |
+
break
|
32 |
+
elif overwrite == "n":
|
33 |
+
print("Exiting...")
|
34 |
+
exit()
|
35 |
+
else:
|
36 |
+
print("Please enter y or n")
|
37 |
+
|
38 |
+
|
39 |
+
for key in tqdm(theta_0.keys(), desc="Stage 1/2"):
|
40 |
+
# skip VAE model parameters to get better results (tested on anime models)
|
41 |
+
# for anime models, merging the VAE weights tends to give worse results (dark and blurry images)
|
42 |
+
if args.without_vae and "first_stage_model" in key:
|
43 |
+
continue
|
44 |
+
|
45 |
+
if "model" in key and key in theta_1:
|
46 |
+
theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]
|
47 |
+
|
48 |
+
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
|
49 |
+
if "model" in key and key not in theta_0:
|
50 |
+
theta_0[key] = theta_1[key]
|
51 |
+
|
52 |
+
print("Saving...")
|
53 |
+
|
54 |
+
torch.save({"state_dict": theta_0}, output_file)
|
55 |
+
|
56 |
+
print("Done!")
|
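The merge above is a plain linear interpolation of matching keys, theta = (1 - alpha) * theta_0 + alpha * theta_1. A toy sketch of the same idea on two in-memory state dicts (the tensors are made-up values, used only to illustrate the formula):

    import torch

    alpha = 0.5
    theta_0 = {"model.w": torch.tensor([1.0, 2.0])}
    theta_1 = {"model.w": torch.tensor([3.0, 4.0])}

    merged = {}
    for key, value in theta_0.items():
        if key in theta_1:
            # weighted sum of the two checkpoints, same formula as the script above
            merged[key] = (1 - alpha) * value + alpha * theta_1[key]
        else:
            merged[key] = value

    print(merged["model.w"])  # tensor([2., 3.]) for alpha = 0.5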
dreambooth-for-diffusion/tools/ckpt_prune.py
ADDED
@@ -0,0 +1,14 @@
1 |
+
sd = torch.load(model_path, map_location="cpu")
|
2 |
+
if "state_dict" not in sd:
|
3 |
+
pruned_sd = {
|
4 |
+
"state_dict": dict(),
|
5 |
+
}
|
6 |
+
else:
|
7 |
+
pruned_sd = dict()
|
8 |
+
for k in sd.keys():
|
9 |
+
if k != "optimizer_states":
|
10 |
+
if "state_dict" not in sd:
|
11 |
+
pruned_sd["state_dict"][k] = sd[k]
|
12 |
+
else:
|
13 |
+
pruned_sd[k] = sd[k]
|
14 |
+
torch.save(pruned_sd, "model-pruned.ckpt")
|
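As committed, ckpt_prune.py is a fragment: it uses torch and model_path without importing or defining them. A self-contained sketch of the same pruning logic is below; taking the checkpoint path from the command line is an assumption added here, not part of the original file:

    import sys
    import torch

    model_path = sys.argv[1] if len(sys.argv) > 1 else "model.ckpt"
    sd = torch.load(model_path, map_location="cpu")

    # Drop optimizer_states, wrapping bare checkpoints in a top-level state_dict key
    if "state_dict" not in sd:
        pruned_sd = {"state_dict": {k: v for k, v in sd.items() if k != "optimizer_states"}}
    else:
        pruned_sd = {k: v for k, v in sd.items() if k != "optimizer_states"}

    torch.save(pruned_sd, "model-pruned.ckpt")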
dreambooth-for-diffusion/tools/deepdanbooru-models/put_deepdanbooru_model_here.txt
ADDED
File without changes
|
dreambooth-for-diffusion/tools/diagnose_tensorboard.py
ADDED
@@ -0,0 +1,570 @@
1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Self-diagnosis script for TensorBoard.
|
16 |
+
|
17 |
+
Instructions: Save this script to your local machine, then execute it in
|
18 |
+
the same environment (virtualenv, Conda, etc.) from which you normally
|
19 |
+
run TensorBoard. Read the output and follow the directions.
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
# This script may only depend on the Python standard library. It is not
|
24 |
+
# built with Bazel and should not assume any third-party dependencies.
|
25 |
+
import dataclasses
|
26 |
+
import errno
|
27 |
+
import functools
|
28 |
+
import hashlib
|
29 |
+
import inspect
|
30 |
+
import logging
|
31 |
+
import os
|
32 |
+
import pipes
|
33 |
+
import shlex
|
34 |
+
import socket
|
35 |
+
import subprocess
|
36 |
+
import sys
|
37 |
+
import tempfile
|
38 |
+
import textwrap
|
39 |
+
import traceback
|
40 |
+
|
41 |
+
|
42 |
+
# A *check* is a function (of no arguments) that performs a diagnostic,
|
43 |
+
# writes log messages, and optionally yields suggestions. Each check
|
44 |
+
# runs in isolation; exceptions will be caught and reported.
|
45 |
+
CHECKS = []
|
46 |
+
|
47 |
+
|
48 |
+
@dataclasses.dataclass(frozen=True)
|
49 |
+
class Suggestion:
|
50 |
+
"""A suggestion to the end user.
|
51 |
+
|
52 |
+
Attributes:
|
53 |
+
headline: A short description, like "Turn it off and on again". Should be
|
54 |
+
imperative with no trailing punctuation. May contain inline Markdown.
|
55 |
+
description: A full enumeration of the steps that the user should take to
|
56 |
+
accept the suggestion. Within this string, prose should be formatted
|
57 |
+
with `reflow`. May contain Markdown.
|
58 |
+
"""
|
59 |
+
|
60 |
+
headline: str
|
61 |
+
description: str
|
62 |
+
|
63 |
+
|
64 |
+
def check(fn):
|
65 |
+
"""Decorator to register a function as a check.
|
66 |
+
|
67 |
+
Checks are run in the order in which they are registered.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
fn: A function that takes no arguments and either returns `None` or
|
71 |
+
returns a generator of `Suggestion`s. (The ability to return
|
72 |
+
`None` is to work around the awkwardness of defining empty
|
73 |
+
generator functions in Python.)
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
A wrapped version of `fn` that returns a generator of `Suggestion`s.
|
77 |
+
"""
|
78 |
+
|
79 |
+
@functools.wraps(fn)
|
80 |
+
def wrapper():
|
81 |
+
result = fn()
|
82 |
+
return iter(()) if result is None else result
|
83 |
+
|
84 |
+
CHECKS.append(wrapper)
|
85 |
+
return wrapper
|
86 |
+
|
87 |
+
|
88 |
+
def reflow(paragraph):
|
89 |
+
return textwrap.fill(textwrap.dedent(paragraph).strip())
|
90 |
+
|
91 |
+
|
92 |
+
def pip(args):
|
93 |
+
"""Invoke command-line Pip with the specified args.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
A bytestring containing the output of Pip.
|
97 |
+
"""
|
98 |
+
# Suppress the Python 2.7 deprecation warning.
|
99 |
+
PYTHONWARNINGS_KEY = "PYTHONWARNINGS"
|
100 |
+
old_pythonwarnings = os.environ.get(PYTHONWARNINGS_KEY)
|
101 |
+
new_pythonwarnings = "%s%s" % (
|
102 |
+
"ignore:DEPRECATION",
|
103 |
+
",%s" % old_pythonwarnings if old_pythonwarnings else "",
|
104 |
+
)
|
105 |
+
command = [sys.executable, "-m", "pip", "--disable-pip-version-check"]
|
106 |
+
command.extend(args)
|
107 |
+
try:
|
108 |
+
os.environ[PYTHONWARNINGS_KEY] = new_pythonwarnings
|
109 |
+
return subprocess.check_output(command)
|
110 |
+
finally:
|
111 |
+
if old_pythonwarnings is None:
|
112 |
+
del os.environ[PYTHONWARNINGS_KEY]
|
113 |
+
else:
|
114 |
+
os.environ[PYTHONWARNINGS_KEY] = old_pythonwarnings
|
115 |
+
|
116 |
+
|
117 |
+
def which(name):
|
118 |
+
"""Return the path to a binary, or `None` if it's not on the path.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
A bytestring.
|
122 |
+
"""
|
123 |
+
binary = "where" if os.name == "nt" else "which"
|
124 |
+
try:
|
125 |
+
return subprocess.check_output([binary, name])
|
126 |
+
except subprocess.CalledProcessError:
|
127 |
+
return None
|
128 |
+
|
129 |
+
|
130 |
+
def sgetattr(attr, default):
|
131 |
+
"""Get an attribute off the `socket` module, or use a default."""
|
132 |
+
sentinel = object()
|
133 |
+
result = getattr(socket, attr, sentinel)
|
134 |
+
if result is sentinel:
|
135 |
+
print("socket.%s does not exist" % attr)
|
136 |
+
return default
|
137 |
+
else:
|
138 |
+
print("socket.%s = %r" % (attr, result))
|
139 |
+
return result
|
140 |
+
|
141 |
+
|
142 |
+
@check
|
143 |
+
def autoidentify():
|
144 |
+
"""Print the Git hash of this version of `diagnose_tensorboard.py`.
|
145 |
+
|
146 |
+
Given this hash, use `git cat-file blob HASH` to recover the
|
147 |
+
relevant version of the script.
|
148 |
+
"""
|
149 |
+
module = sys.modules[__name__]
|
150 |
+
try:
|
151 |
+
source = inspect.getsource(module).encode("utf-8")
|
152 |
+
except TypeError:
|
153 |
+
logging.info("diagnose_tensorboard.py source unavailable")
|
154 |
+
else:
|
155 |
+
# Git inserts a length-prefix before hashing; cf. `git-hash-object`.
|
156 |
+
blob = b"blob %d\0%s" % (len(source), source)
|
157 |
+
hash = hashlib.sha1(blob).hexdigest()
|
158 |
+
logging.info("diagnose_tensorboard.py version %s", hash)
|
159 |
+
|
160 |
+
|
161 |
+
@check
|
162 |
+
def general():
|
163 |
+
logging.info("sys.version_info: %s", sys.version_info)
|
164 |
+
logging.info("os.name: %s", os.name)
|
165 |
+
na = type("N/A", (object,), {"__repr__": lambda self: "N/A"})
|
166 |
+
logging.info(
|
167 |
+
"os.uname(): %r",
|
168 |
+
getattr(os, "uname", na)(),
|
169 |
+
)
|
170 |
+
logging.info(
|
171 |
+
"sys.getwindowsversion(): %r",
|
172 |
+
getattr(sys, "getwindowsversion", na)(),
|
173 |
+
)
|
174 |
+
|
175 |
+
|
176 |
+
@check
|
177 |
+
def package_management():
|
178 |
+
conda_meta = os.path.join(sys.prefix, "conda-meta")
|
179 |
+
logging.info("has conda-meta: %s", os.path.exists(conda_meta))
|
180 |
+
logging.info("$VIRTUAL_ENV: %r", os.environ.get("VIRTUAL_ENV"))
|
181 |
+
|
182 |
+
|
183 |
+
@check
|
184 |
+
def installed_packages():
|
185 |
+
freeze = pip(["freeze", "--all"]).decode("utf-8").splitlines()
|
186 |
+
packages = {line.split("==")[0]: line for line in freeze}
|
187 |
+
packages_set = frozenset(packages)
|
188 |
+
|
189 |
+
# For each of the following families, expect exactly one package to be
|
190 |
+
# installed.
|
191 |
+
expect_unique = [
|
192 |
+
frozenset(
|
193 |
+
[
|
194 |
+
"tensorboard",
|
195 |
+
"tb-nightly",
|
196 |
+
"tensorflow-tensorboard",
|
197 |
+
]
|
198 |
+
),
|
199 |
+
frozenset(
|
200 |
+
[
|
201 |
+
"tensorflow",
|
202 |
+
"tensorflow-gpu",
|
203 |
+
"tf-nightly",
|
204 |
+
"tf-nightly-2.0-preview",
|
205 |
+
"tf-nightly-gpu",
|
206 |
+
"tf-nightly-gpu-2.0-preview",
|
207 |
+
]
|
208 |
+
),
|
209 |
+
frozenset(
|
210 |
+
[
|
211 |
+
"tensorflow-estimator",
|
212 |
+
"tensorflow-estimator-2.0-preview",
|
213 |
+
"tf-estimator-nightly",
|
214 |
+
]
|
215 |
+
),
|
216 |
+
]
|
217 |
+
salient_extras = frozenset(["tensorboard-data-server"])
|
218 |
+
|
219 |
+
found_conflict = False
|
220 |
+
for family in expect_unique:
|
221 |
+
actual = family & packages_set
|
222 |
+
for package in actual:
|
223 |
+
logging.info("installed: %s", packages[package])
|
224 |
+
if len(actual) == 0:
|
225 |
+
logging.warning("no installation among: %s", sorted(family))
|
226 |
+
elif len(actual) > 1:
|
227 |
+
logging.warning("conflicting installations: %s", sorted(actual))
|
228 |
+
found_conflict = True
|
229 |
+
for package in sorted(salient_extras & packages_set):
|
230 |
+
logging.info("installed: %s", packages[package])
|
231 |
+
|
232 |
+
if found_conflict:
|
233 |
+
preamble = reflow(
|
234 |
+
"""
|
235 |
+
Conflicting package installations found. Depending on the order
|
236 |
+
of installations and uninstallations, behavior may be undefined.
|
237 |
+
Please uninstall ALL versions of TensorFlow and TensorBoard,
|
238 |
+
then reinstall ONLY the desired version of TensorFlow, which
|
239 |
+
will transitively pull in the proper version of TensorBoard. (If
|
240 |
+
you use TensorBoard without TensorFlow, just reinstall the
|
241 |
+
appropriate version of TensorBoard directly.)
|
242 |
+
"""
|
243 |
+
)
|
244 |
+
packages_to_uninstall = sorted(
|
245 |
+
frozenset().union(*expect_unique) & packages_set
|
246 |
+
)
|
247 |
+
commands = [
|
248 |
+
"pip uninstall %s" % " ".join(packages_to_uninstall),
|
249 |
+
"pip install tensorflow # or `tensorflow-gpu`, or `tf-nightly`, ...",
|
250 |
+
]
|
251 |
+
message = "%s\n\nNamely:\n\n%s" % (
|
252 |
+
preamble,
|
253 |
+
"\n".join("\t%s" % c for c in commands),
|
254 |
+
)
|
255 |
+
yield Suggestion("Fix conflicting installations", message)
|
256 |
+
|
257 |
+
wit_version = packages.get("tensorboard-plugin-wit")
|
258 |
+
if wit_version == "tensorboard-plugin-wit==1.6.0.post2":
|
259 |
+
# This is only incompatible with TensorBoard prior to 2.2.0, but
|
260 |
+
# we just issue a blanket warning so that we don't have to pull
|
261 |
+
# in a `pkg_resources` dep to parse the version.
|
262 |
+
preamble = reflow(
|
263 |
+
"""
|
264 |
+
Versions of the What-If Tool (`tensorboard-plugin-wit`)
|
265 |
+
prior to 1.6.0.post3 are incompatible with some versions of
|
266 |
+
TensorBoard. Please upgrade this package to the latest
|
267 |
+
version to resolve any startup errors:
|
268 |
+
"""
|
269 |
+
)
|
270 |
+
command = "pip install -U tensorboard-plugin-wit"
|
271 |
+
message = "%s\n\n\t%s" % (preamble, command)
|
272 |
+
yield Suggestion("Upgrade `tensorboard-plugin-wit`", message)
|
273 |
+
|
274 |
+
|
275 |
+
@check
|
276 |
+
def tensorboard_python_version():
|
277 |
+
from tensorboard import version
|
278 |
+
|
279 |
+
logging.info("tensorboard.version.VERSION: %r", version.VERSION)
|
280 |
+
|
281 |
+
|
282 |
+
@check
|
283 |
+
def tensorflow_python_version():
|
284 |
+
import tensorflow as tf
|
285 |
+
|
286 |
+
logging.info("tensorflow.__version__: %r", tf.__version__)
|
287 |
+
logging.info("tensorflow.__git_version__: %r", tf.__git_version__)
|
288 |
+
|
289 |
+
|
290 |
+
@check
|
291 |
+
def tensorboard_data_server_version():
|
292 |
+
try:
|
293 |
+
import tensorboard_data_server
|
294 |
+
except ImportError:
|
295 |
+
logging.info("no data server installed")
|
296 |
+
return
|
297 |
+
|
298 |
+
path = tensorboard_data_server.server_binary()
|
299 |
+
logging.info("data server binary: %r", path)
|
300 |
+
if path is None:
|
301 |
+
return
|
302 |
+
|
303 |
+
try:
|
304 |
+
subprocess_output = subprocess.run(
|
305 |
+
[path, "--version"],
|
306 |
+
capture_output=True,
|
307 |
+
check=True,
|
308 |
+
)
|
309 |
+
except subprocess.CalledProcessError as e:
|
310 |
+
logging.info("failed to check binary version: %s", e)
|
311 |
+
else:
|
312 |
+
logging.info(
|
313 |
+
"data server binary version: %s", subprocess_output.stdout.strip()
|
314 |
+
)
|
315 |
+
|
316 |
+
|
317 |
+
@check
|
318 |
+
def tensorboard_binary_path():
|
319 |
+
logging.info("which tensorboard: %r", which("tensorboard"))
|
320 |
+
|
321 |
+
|
322 |
+
@check
|
323 |
+
def addrinfos():
|
324 |
+
sgetattr("has_ipv6", None)
|
325 |
+
family = sgetattr("AF_UNSPEC", 0)
|
326 |
+
socktype = sgetattr("SOCK_STREAM", 0)
|
327 |
+
protocol = 0
|
328 |
+
flags_loopback = sgetattr("AI_ADDRCONFIG", 0)
|
329 |
+
flags_wildcard = sgetattr("AI_PASSIVE", 0)
|
330 |
+
|
331 |
+
hints_loopback = (family, socktype, protocol, flags_loopback)
|
332 |
+
infos_loopback = socket.getaddrinfo(None, 0, *hints_loopback)
|
333 |
+
print("Loopback flags: %r" % (flags_loopback,))
|
334 |
+
print("Loopback infos: %r" % (infos_loopback,))
|
335 |
+
|
336 |
+
hints_wildcard = (family, socktype, protocol, flags_wildcard)
|
337 |
+
infos_wildcard = socket.getaddrinfo(None, 0, *hints_wildcard)
|
338 |
+
print("Wildcard flags: %r" % (flags_wildcard,))
|
339 |
+
print("Wildcard infos: %r" % (infos_wildcard,))
|
340 |
+
|
341 |
+
|
342 |
+
@check
|
343 |
+
def readable_fqdn():
|
344 |
+
# May raise `UnicodeDecodeError` for non-ASCII hostnames:
|
345 |
+
# https://github.com/tensorflow/tensorboard/issues/682
|
346 |
+
try:
|
347 |
+
logging.info("socket.getfqdn(): %r", socket.getfqdn())
|
348 |
+
except UnicodeDecodeError as e:
|
349 |
+
try:
|
350 |
+
binary_hostname = subprocess.check_output(["hostname"]).strip()
|
351 |
+
except subprocess.CalledProcessError:
|
352 |
+
binary_hostname = b"<unavailable>"
|
353 |
+
is_non_ascii = not all(
|
354 |
+
0x20
|
355 |
+
<= (ord(c) if not isinstance(c, int) else c)
|
356 |
+
<= 0x7E # Python 2
|
357 |
+
for c in binary_hostname
|
358 |
+
)
|
359 |
+
if is_non_ascii:
|
360 |
+
message = reflow(
|
361 |
+
"""
|
362 |
+
Your computer's hostname, %r, contains bytes outside of the
|
363 |
+
                printable ASCII range. Some versions of Python have trouble
                working with such names (https://bugs.python.org/issue26227).
                Consider changing to a hostname that only contains printable
                ASCII bytes.
                """
                % (binary_hostname,)
            )
            yield Suggestion("Use an ASCII hostname", message)
        else:
            message = reflow(
                """
                Python can't read your computer's hostname, %r. This can occur
                if the hostname contains non-ASCII bytes
                (https://bugs.python.org/issue26227). Consider changing your
                hostname, rebooting your machine, and rerunning this diagnosis
                script to see if the problem is resolved.
                """
                % (binary_hostname,)
            )
            yield Suggestion("Use a simpler hostname", message)
        raise e


@check
def stat_tensorboardinfo():
    # We don't use `manager._get_info_dir`, because (a) that requires
    # TensorBoard, and (b) that creates the directory if it doesn't exist.
    path = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
    logging.info("directory: %s", path)
    try:
        stat_result = os.stat(path)
    except OSError as e:
        if e.errno == errno.ENOENT:
            # No problem; this is just fine.
            logging.info(".tensorboard-info directory does not exist")
            return
        else:
            raise
    logging.info("os.stat(...): %r", stat_result)
    logging.info("mode: 0o%o", stat_result.st_mode)
    if stat_result.st_mode & 0o777 != 0o777:
        preamble = reflow(
            """
            The ".tensorboard-info" directory was created by an old version
            of TensorBoard, and its permissions are not set correctly; see
            issue #2010. Change that directory to be world-accessible (may
            require superuser privilege):
            """
        )
        # This error should only appear on Unices, so it's okay to use
        # Unix-specific utilities and shell syntax.
        quote = getattr(shlex, "quote", None) or pipes.quote  # Python <3.3
        command = "chmod 777 %s" % quote(path)
        message = "%s\n\n\t%s" % (preamble, command)
        yield Suggestion('Fix permissions on "%s"' % path, message)


@check
def source_trees_without_genfiles():
    roots = list(sys.path)
    if "" not in roots:
        # Catch problems that would occur in a Python interactive shell
        # (where `""` is prepended to `sys.path`) but not when
        # `diagnose_tensorboard.py` is run as a standalone script.
        roots.insert(0, "")

    def has_tensorboard(root):
        return os.path.isfile(os.path.join(root, "tensorboard", "__init__.py"))

    def has_genfiles(root):
        sample_genfile = os.path.join("compat", "proto", "summary_pb2.py")
        return os.path.isfile(os.path.join(root, "tensorboard", sample_genfile))

    def is_bad(root):
        return has_tensorboard(root) and not has_genfiles(root)

    tensorboard_roots = [root for root in roots if has_tensorboard(root)]
    bad_roots = [root for root in roots if is_bad(root)]

    logging.info(
        "tensorboard_roots (%d): %r; bad_roots (%d): %r",
        len(tensorboard_roots),
        tensorboard_roots,
        len(bad_roots),
        bad_roots,
    )

    if bad_roots:
        if bad_roots == [""]:
            message = reflow(
                """
                Your current directory contains a `tensorboard` Python package
                that does not include generated files. This can happen if your
                current directory includes the TensorBoard source tree (e.g.,
                you are in the TensorBoard Git repository). Consider changing
                to a different directory.
                """
            )
        else:
            preamble = reflow(
                """
                Your Python path contains a `tensorboard` package that does
                not include generated files. This can happen if your current
                directory includes the TensorBoard source tree (e.g., you are
                in the TensorBoard Git repository). The following directories
                from your Python path may be problematic:
                """
            )
            roots = []
            realpaths_seen = set()
            for root in bad_roots:
                label = repr(root) if root else "current directory"
                realpath = os.path.realpath(root)
                if realpath in realpaths_seen:
                    # virtualenvs on Ubuntu install to both `lib` and `local/lib`;
                    # explicitly call out such duplicates to avoid confusion.
                    label += " (duplicate underlying directory)"
                realpaths_seen.add(realpath)
                roots.append(label)
            message = "%s\n\n%s" % (
                preamble,
                "\n".join("  - %s" % s for s in roots),
            )
        yield Suggestion(
            "Avoid `tensorboard` packages without genfiles", message
        )


# Prefer to include this check last, as its output is long.
@check
def full_pip_freeze():
    logging.info(
        "pip freeze --all:\n%s", pip(["freeze", "--all"]).decode("utf-8")
    )


def set_up_logging():
    # Manually install handlers to prevent TensorFlow from stomping the
    # default configuration if it's imported:
    # https://github.com/tensorflow/tensorflow/issues/28147
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    logger.addHandler(handler)


def main():
    set_up_logging()

    print("### Diagnostics")
    print()

    print("<details>")
    print("<summary>Diagnostics output</summary>")
    print()

    markdown_code_fence = "``````"  # seems likely to be sufficient
    print(markdown_code_fence)
    suggestions = []
    for (i, check) in enumerate(CHECKS):
        if i > 0:
            print()
        print("--- check: %s" % check.__name__)
        try:
            suggestions.extend(check())
        except Exception:
            traceback.print_exc(file=sys.stdout)
            pass
    print(markdown_code_fence)
    print()
    print("</details>")

    for suggestion in suggestions:
        print()
        print("### Suggestion: %s" % suggestion.headline)
        print()
        print(suggestion.description)

    print()
    print("### Next steps")
    print()
    if suggestions:
        print(
            reflow(
                """
                Please try each suggestion enumerated above to determine whether
                it solves your problem. If none of these suggestions works,
                please copy ALL of the above output, including the lines
                containing only backticks, into your GitHub issue or comment. Be
                sure to redact any sensitive information.
                """
            )
        )
    else:
        print(
            reflow(
                """
                No action items identified. Please copy ALL of the above output,
                including the lines containing only backticks, into your GitHub
                issue or comment. Be sure to redact any sensitive information.
                """
            )
        )


if __name__ == "__main__":
    main()
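The checks above all follow the same pattern: a generator decorated with `@check` logs what it finds and yields `Suggestion` objects only when something looks wrong. A minimal sketch of an additional check, not part of the repository, assuming the `@check`, `Suggestion`, and `reflow` helpers defined earlier in diagnose_tensorboard.py; the check name and environment variable inspected here are illustrative only:

# Illustrative sketch only: an extra diagnostic following the same @check pattern.
@check
def cuda_visible_devices():
    value = os.environ.get("CUDA_VISIBLE_DEVICES")
    logging.info("CUDA_VISIBLE_DEVICES: %r", value)
    if value == "":
        message = reflow(
            """
            CUDA_VISIBLE_DEVICES is set to an empty string, so no GPU will be
            visible to training scripts. Unset it or list the GPU indices you
            want to use.
            """
        )
        yield Suggestion("Check CUDA_VISIBLE_DEVICES", message)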
dreambooth-for-diffusion/tools/diffusers2ckpt.py
ADDED
@@ -0,0 +1,234 @@
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.

import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("model_path", default=None, type=str, help="Path to the model to convert.")
    parser.add_argument("checkpoint_path", default=None, type=str, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
dreambooth-for-diffusion/tools/handle_images.py
ADDED
@@ -0,0 +1,82 @@
import os, cv2, argparse
import numpy as np

# Change a transparent background to white
def transparence2white(img):
    sp=img.shape
    width=sp[0]
    height=sp[1]
    for yh in range(height):
        for xw in range(width):
            color_d=img[xw,yh]
            if(color_d[3]==0):
                img[xw,yh]=[255,255,255,255]
    return img

# Change a transparent background to black
def transparence2black(img):
    sp = img.shape
    width = sp[0]
    height = sp[1]
    for yh in range(height):
        for xw in range(width):
            color_d = img[xw, yh]
            if (color_d[3] == 0):
                img[xw, yh] = [0, 0, 0, 255]
    return img

# Center crop
def center_crop(img, crop_size):
    h, w = img.shape[:2]
    th, tw = crop_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return img[i:i + th, j:j + tw]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("origin_image_path", default=None, type=str, help="Path to the images to convert.")
    parser.add_argument("output_image_path", default=None, type=str, help="Path to the output images.")
    parser.add_argument("--width", default=512, type=int, help="Width of the output images.")
    parser.add_argument("--height", default=512, type=int, help="Height of the output images.")
    parser.add_argument("--png", action="store_true", help="Convert a transparent background to white/black.")

    args = parser.parse_args()

    path = args.origin_image_path
    save_path = args.output_image_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    else:
        print('The folder already exists, please check the path.')

    # Only read png, jpg, jpeg, bmp and webp files
    allow_suffix = ['png', 'jpg', 'jpeg', 'bmp', 'webp']
    image_list = os.listdir(path)
    image_list = [os.path.join(path, image) for image in image_list if image.split('.')[-1] in allow_suffix]

    for file, i in zip(image_list, range(1, len(image_list)+1)):
        print('Processing image: %s' % file)
        try:
            img = cv2.imread(file, -1)

            # Center-crop the image to a 1:1 aspect ratio; crop_size is the shorter side
            crop_size = min(img.shape[:2])
            img = center_crop(img, (crop_size, crop_size))

            # Resize the image to the target size (512x512 by default)
            width = args.width
            height = args.height
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

            # If the image has an alpha channel, convert the transparent background to white or black
            if args.png:
                img = transparence2black(img)

            cv2.imwrite(os.path.join(save_path, str(i).zfill(4) + ".jpg"), img)
        except Exception as e:
            print(e)
            os.remove(file)  # remove the invalid image ("file" is already a full path)
            print("Removed invalid image: " + file)
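For reference, a quick sketch, using NumPy only and made-up sizes, of what the center crop in handle_images.py does before resizing: it keeps the middle crop_size x crop_size window of the image, so a 768x512 portrait image loses 128 rows from the top and bottom.

# Illustrative sketch of the center-crop arithmetic above (dummy 768x512 image).
import numpy as np

img = np.zeros((768, 512, 3), dtype=np.uint8)   # H=768, W=512
crop = min(img.shape[:2])                       # 512, the shorter side
i = int(round((768 - crop) / 2.))               # 128 rows trimmed from the top
j = int(round((512 - crop) / 2.))               # 0 columns trimmed
print(img[i:i + crop, j:j + crop].shape)        # (512, 512, 3)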
dreambooth-for-diffusion/tools/label_images.py
ADDED
@@ -0,0 +1,152 @@
# from AUTOMATIC1111
# maybe modified by Nyanko Lepsoni
# modified by crosstyan
import os.path
import re
import tempfile
import argparse
import glob
import zipfile
import deepdanbooru as dd
import tensorflow as tf
import numpy as np

from basicsr.utils.download_util import load_file_from_url
from PIL import Image
from tqdm import tqdm

re_special = re.compile(r"([\\()])")

def get_deepbooru_tags_model(model_path: str):
    if not os.path.exists(os.path.join(model_path, "project.json")):
        is_abs = os.path.isabs(model_path)
        if not is_abs:
            model_path = os.path.abspath(model_path)

        load_file_from_url(
            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
            model_path,
        )
        with zipfile.ZipFile(
            os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r"
        ) as zip_ref:
            zip_ref.extractall(model_path)
        os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))

    tags = dd.project.load_tags_from_project(model_path)
    model = dd.project.load_model_from_project(model_path, compile_model=False)
    return model, tags


def get_deepbooru_tags_from_model(
    model,
    tags,
    pil_image,
    threshold,
    alpha_sort=False,
    use_spaces=True,
    use_escape=True,
    include_ranks=False,
):
    width = model.input_shape[2]
    height = model.input_shape[1]
    image = np.array(pil_image)
    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.0
    image_shape = image.shape
    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))

    y = model.predict(image)[0]

    result_dict = {}

    for i, tag in enumerate(tags):
        result_dict[tag] = y[i]

    unsorted_tags_in_threshold = []
    result_tags_print = []
    for tag in tags:
        if result_dict[tag] >= threshold:
            if tag.startswith("rating:"):
                continue
            unsorted_tags_in_threshold.append((result_dict[tag], tag))
            result_tags_print.append(f"{result_dict[tag]} {tag}")

    # sort tags
    result_tags_out = []
    sort_ndx = 0
    if alpha_sort:
        sort_ndx = 1

    # sort in reverse by likelihood (or normally for alpha), and format tag text as requested
    unsorted_tags_in_threshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
    for weight, tag in unsorted_tags_in_threshold:
        tag_outformat = tag
        if use_spaces:
            tag_outformat = tag_outformat.replace("_", " ")
        if use_escape:
            tag_outformat = re.sub(re_special, r"\\\1", tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{weight:.3f})"

        result_tags_out.append(tag_outformat)

    # print("\n".join(sorted(result_tags_print, reverse=True)))

    return ", ".join(result_tags_out)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, default=".")
    parser.add_argument("--threshold", type=float, default=0.75)  # was type=int, which rejects fractional thresholds
    parser.add_argument("--alpha_sort", type=bool, default=False)
    parser.add_argument("--use_spaces", type=bool, default=True)
    parser.add_argument("--use_escape", type=bool, default=True)
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--include_ranks", type=bool, default=False)

    args = parser.parse_args()

    global model_path
    model_path: str
    if args.model_path == "":
        script_path = os.path.realpath(__file__)
        default_model_path = os.path.join(os.path.dirname(script_path), "deepdanbooru-models")
        # print("No model path specified, using default model path: {}".format(default_model_path))
        model_path = default_model_path
    else:
        model_path = args.model_path

    types = ('*.jpg', '*.png', '*.jpeg', '*.gif', '*.webp', '*.bmp')
    files_grabbed = []
    for files in types:
        files_grabbed.extend(glob.glob(os.path.join(args.path, files)))
        # print(glob.glob(args.path + files))

    model, tags = get_deepbooru_tags_model(model_path)
    for image_path in tqdm(files_grabbed, desc="Processing"):
        image = Image.open(image_path).convert("RGB")
        prompt = get_deepbooru_tags_from_model(
            model,
            tags,
            image,
            args.threshold,
            alpha_sort=args.alpha_sort,
            use_spaces=args.use_spaces,
            use_escape=args.use_escape,
            include_ranks=args.include_ranks,
        )
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        txt_filename = os.path.join(args.path, f"{image_name}.txt")
        # print(f"writing {txt_filename}: {prompt}")
        with open(txt_filename, 'w') as f:
            f.write(prompt)
dreambooth-for-diffusion/tools/test_cuda.py
ADDED
@@ -0,0 +1,2 @@
import torch
print(torch.cuda.is_available())
dreambooth-for-diffusion/tools/train_dreambooth.py
ADDED
@@ -0,0 +1,784 @@
import argparse
import hashlib
import itertools
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--use_filename_as_label", action="store_true", help="Uses the filename as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance"
    )
    parser.add_argument(
        "--use_txt_as_label", action="store_true", help="Uses the filename.txt file's content as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--log_with",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"]
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--save_model_every_n_steps", type=int)
    parser.add_argument("--auto_test_model", action="store_true", help="Whether or not to automatically test the model after saving it")
    parser.add_argument("--test_prompt", type=str, default="A photo of a cat", help="The prompt to use for testing the model.")
    parser.add_argument("--test_prompts_file", type=str, default=None, help="The file containing the prompts to use for testing the model. Example: test_prompts.txt, each line is a prompt")
    parser.add_argument("--test_negative_prompt", type=str, default="", help="The negative prompt to use for testing the model.")
    parser.add_argument("--test_seed", type=int, default=42, help="The seed to use for testing the model.")
    parser.add_argument("--test_num_per_prompt", type=int, default=1, help="The number of images to generate per prompt.")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")

    return args

# turns a path into a filename without the extension
def get_filename(path):
    return path.stem

def get_label_from_txt(path):
    txt_path = path.with_suffix(".txt")  # get the path to the .txt file
    if txt_path.exists():
        with open(txt_path, "r") as f:
            return f.read()
    else:
        return ""

class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        use_filename_as_label=False,
        use_txt_as_label=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(self.instance_data_root.glob("*.jpg")) + list(self.instance_data_root.glob("*.png"))
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self.use_filename_as_label = use_filename_as_label
        self.use_txt_as_label = use_txt_as_label
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.glob("*.jpg")) + list(self.class_data_root.glob("*.png"))
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        path = self.instance_images_path[index % self.num_instance_images]
        prompt = get_filename(path) if self.use_filename_as_label else self.instance_prompt
        prompt = get_label_from_txt(path) if self.use_txt_as_label else prompt

        print("prompt", prompt)

        instance_image = Image.open(path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

def test_model(folder, args):
    if args.test_prompts_file is not None:
        with open(args.test_prompts_file, "r") as f:
            prompts = f.read().splitlines()
    else:
        prompts = [args.test_prompt]

    test_path = os.path.join(folder, "test")
    if not os.path.exists(test_path):
        os.makedirs(test_path)

    print("Testing the model...")
    from diffusers import DDIMScheduler

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
    pipeline = StableDiffusionPipeline.from_pretrained(
        folder,
        torch_dtype=torch_dtype,
        safety_checker=None,
        load_in_8bit=True,
        scheduler=DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        ),
    )
    pipeline.set_progress_bar_config(disable=True)
    pipeline.enable_attention_slicing()
    pipeline = pipeline.to(device)

    torch.manual_seed(args.test_seed)
    with torch.autocast('cuda'):
        for prompt in prompts:
            print(f"Generating test images for prompt: {prompt}")
            test_images = pipeline(
                prompt=prompt,
                width=512,
                height=512,
                negative_prompt=args.test_negative_prompt,
                num_inference_steps=30,
                num_images_per_prompt=args.test_num_per_prompt,
            ).images

            for index, image in enumerate(test_images):
                image.save(f"{test_path}/{prompt}_{index}.png")

    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"Test completed. The examples are saved in {test_path}")


def save_model(accelerator, unet, text_encoder, args, step=None):
    unet = accelerator.unwrap_model(unet)
    text_encoder = accelerator.unwrap_model(text_encoder)

    if step == None:
        folder = args.output_dir
    else:
        folder = args.output_dir + "-Step-" + str(step)

    print("Saving Model Checkpoint...")
    print("Directory: " + folder)

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=unet,
            text_encoder=text_encoder,
            revision=args.revision,
        )
        pipeline.save_pretrained(folder)
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if args.auto_test_model:
            print("Testing Model...")
            test_model(folder, args)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


def main(args):
    logging_dir = Path(args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.log_with,
        logging_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        use_filename_as_label=args.use_filename_as_label,
        use_txt_as_label=args.use_txt_as_label,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
747 |
+
accelerator.backward(loss)
|
748 |
+
if accelerator.sync_gradients:
|
749 |
+
params_to_clip = (
|
750 |
+
itertools.chain(unet.parameters(), text_encoder.parameters())
|
751 |
+
if args.train_text_encoder
|
752 |
+
else unet.parameters()
|
753 |
+
)
|
754 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
755 |
+
optimizer.step()
|
756 |
+
lr_scheduler.step()
|
757 |
+
optimizer.zero_grad()
|
758 |
+
|
759 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
760 |
+
if accelerator.sync_gradients:
|
761 |
+
progress_bar.update(1)
|
762 |
+
global_step += 1
|
763 |
+
|
764 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
765 |
+
progress_bar.set_postfix(**logs)
|
766 |
+
accelerator.log(logs, step=global_step)
|
767 |
+
|
768 |
+
if global_step >= args.max_train_steps:
|
769 |
+
break
|
770 |
+
|
771 |
+
|
772 |
+
if args.save_model_every_n_steps != None and (global_step % args.save_model_every_n_steps) == 0:
|
773 |
+
save_model(accelerator, unet, text_encoder, args, global_step)
|
774 |
+
|
775 |
+
accelerator.wait_for_everyone()
|
776 |
+
|
777 |
+
save_model(accelerator, unet, text_encoder, args, step=None)
|
778 |
+
|
779 |
+
accelerator.end_training()
|
780 |
+
|
781 |
+
|
782 |
+
if __name__ == "__main__":
|
783 |
+
args = parse_args()
|
784 |
+
main(args)
|
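A quick aside on the prior-preservation split above: because collate_fn appends the class examples after the instance examples along the batch dimension, torch.chunk(..., 2, dim=0) recovers the instance half and the class half in that order. A minimal, self-contained sketch with toy tensors (not part of the training script) to illustrate:

import torch

# Stand-ins for instance and class (prior-preservation) latents.
instance = torch.zeros(2, 4)
class_examples = torch.ones(2, 4)

# collate_fn concatenates instance examples first, then class examples.
batch = torch.cat([instance, class_examples], dim=0)

# The same split the training loop performs on noise_pred and noise.
instance_half, class_half = torch.chunk(batch, 2, dim=0)
assert torch.equal(instance_half, instance)
assert torch.equal(class_half, class_examples)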
dreambooth-for-diffusion/tools/train_textual_inversion.py
ADDED
@@ -0,0 +1,572 @@
import argparse
import itertools
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
# import torch
import oneflow as torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)


def save_progress(text_encoder, placeholder_token_id, accelerator, args):
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w, = (
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def freeze_params(params):
    for param in params:
        param.requires_grad = False


def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer and add the placeholder token as a additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    # Freeze all parameters except for the token embeddings in text encoder
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    freeze_params(params_to_freeze)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()
    unet.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                # if accelerator.num_processes > 1:
                #     grads = text_encoder.module.get_input_embeddings().weight.grad
                # else:
                #     grads = text_encoder.get_input_embeddings().weight.grad
                grads = text_encoder.module.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_progress(text_encoder, placeholder_token_id, accelerator, args)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline(
            text_encoder=accelerator.unwrap_model(text_encoder),
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        )
        pipeline.save_pretrained(args.output_dir)
        # Also save the newly trained embeddings
        save_progress(text_encoder, placeholder_token_id, accelerator, args)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
dreambooth-for-diffusion/tools/upload_cos.py
ADDED
@@ -0,0 +1,19 @@
# -*- coding: UTF-8 -*-
# by ruochen
# Requires: pip install -U cos-python-sdk-v5
from qcloud_cos import CosConfig
from qcloud_cos import CosS3Client

secret_id = 'abc123'  # replace with your secretId
secret_key = 'abc123'  # replace with your secretKey
region = 'ap-guangzhou'  # replace with your Region

config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
client = CosS3Client(config)

response = client.upload_file(
    Bucket='xxx',  # replace with your bucket name
    LocalFilePath='../ckpt_models/newModel.ckpt',  # path of the local file
    Key='newModel.ckpt',  # file name after upload
)
print(response['ETag'])
dreambooth-for-diffusion/train_object.sh
ADDED
@@ -0,0 +1,79 @@
# For training a specific object/person (only a single label is needed)
export MODEL_NAME="./model"
export INSTANCE_DIR="./datasets/test2"
export OUTPUT_DIR="./new_model"
export CLASS_DIR="./datasets/class" # Folder for the class (prior-preservation) images generated by the model; do not change
export LOG_DIR="/root/tf-logs"
export TEST_PROMPTS_FILE="./test_prompts_object.txt"

rm -rf $CLASS_DIR/* # If you are training a different object/person than last time, clear this folder first; otherwise this line can be commented out (add # in front)
rm -rf $LOG_DIR/*

accelerate launch tools/train_dreambooth.py \
    --train_text_encoder \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --mixed_precision="fp16" \
    --instance_data_dir=$INSTANCE_DIR \
    --instance_prompt="a photo of <xxx> dog" \
    --with_prior_preservation --prior_loss_weight=1.0 \
    --class_prompt="a photo of dog" \
    --class_data_dir=$CLASS_DIR \
    --num_class_images=200 \
    --output_dir=$OUTPUT_DIR \
    --logging_dir=$LOG_DIR \
    --center_crop \
    --resolution=512 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=1 --gradient_checkpointing \
    --use_8bit_adam \
    --learning_rate=2e-6 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --auto_test_model \
    --test_prompts_file=$TEST_PROMPTS_FILE \
    --test_seed=123 \
    --test_num_per_prompt=3 \
    --max_train_steps=1000 \
    --save_model_every_n_steps=500

# If you increase max_train_steps, remember to increase save_model_every_n_steps as well,
# otherwise the disk can easily fill up partway through.

# Key parameters:
# The main ones
# --train_text_encoder  also train the text encoder
# --mixed_precision="fp16"  mixed-precision training
# - center_crop
#   whether to crop the images; generally needed if your dataset is not square
# - resolution
#   image resolution, usually 512; this option automatically rescales the input images,
#   and combined with center_crop it crops to a square and resizes to 512*512
# - instance_prompt
#   use this if you want to train a specific person/subject,
#   e.g. --instance_prompt="a photo of <xxx> girl"
# - class_prompt
#   if you are training a particular class, this option may improve results somewhat
# - use_txt_as_label
#   whether to read a txt file with the same name as the image as its label;
#   useful if you are training the overall image style of the whole model.
#   This option ignores whatever is passed via instance_prompt.
# - learning_rate
#   learning rate, usually 2e-6; the key parameter to tune during training.
#   Too large and the model will not converge; too small and training slows down.
# - lr_scheduler, one of constant, linear, cosine, cosine_with_restarts, cosine_with_hard_restarts
#   learning-rate schedule, usually constant (no adjustment). For large datasets you can try the others,
#   but the model may not converge and the learning rate may need retuning.
# - lr_warmup_steps: can be ignored when using constant;
#   for the other schedulers it can be set to 0 (no warmup),
#   or to another value, e.g. 1000, meaning the learning rate ramps from 0 up to learning_rate over the first 1000 steps.
#   Usually unnecessary, unless your dataset is large and convergence is slow.
# - max_train_steps
#   maximum number of training steps, usually 1000; increase it for larger datasets
# - save_model_every_n_steps
#   save the model every N steps, handy for inspecting intermediate results to find the best checkpoint, and for resuming training

# --with_prior_preservation, --prior_loss_weight=1.0 enable prior preservation and set the prior loss weight.
# If you have few training samples, these two options can improve results and help prevent overfitting
# (i.e. generated images being too similar to the training images).

# --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt
# automatically test the model (after every save_model_every_n_steps steps), the prompts file to test, the random seed, and the number of images per prompt.
# The generated sample images are saved in the test folder under the model output directory.
dreambooth-for-diffusion/train_style.sh
ADDED
@@ -0,0 +1,62 @@
# Mainly for training style / general drawing ability (each image needs a matching caption)
export MODEL_NAME="./model"
export INSTANCE_DIR="./datasets/test2"
export OUTPUT_DIR="./new_model"
export LOG_DIR="/root/tf-logs"
export TEST_PROMPTS_FILE="./test_prompts_style.txt"

rm -rf $LOG_DIR/*

accelerate launch tools/train_dreambooth.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --mixed_precision="fp16" \
    --instance_data_dir=$INSTANCE_DIR \
    --use_txt_as_label \
    --output_dir=$OUTPUT_DIR \
    --logging_dir=$LOG_DIR \
    --center_crop \
    --resolution=768 \
    --train_batch_size=1 \
    --use_8bit_adam \
    --gradient_accumulation_steps=1 --gradient_checkpointing \
    --learning_rate=2e-6 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --max_train_steps=1000 \
    --save_model_every_n_steps=500 \
    --auto_test_model \
    --test_prompts_file=$TEST_PROMPTS_FILE \
    --test_seed=123 \
    --test_num_per_prompt=3

# If you increase max_train_steps, remember to increase save_model_every_n_steps too, otherwise the disk can easily fill up partway through.

# Key parameters:
# The main ones
# --train_text_encoder  also train the text encoder
# --mixed_precision="fp16"  mixed-precision training
# - center_crop
#   whether to crop the images; generally needed if your dataset is not square
# - resolution
#   image resolution, usually 512; this option automatically rescales the input images,
#   and combined with center_crop it crops to a square and resizes to 512*512
# - instance_prompt
#   use this if you want to train a specific person/subject,
#   e.g. --instance_prompt="a photo of <xxx> girl"
# - use_txt_as_label
#   whether to read a txt file with the same name as the image as its label;
#   useful if you are training the overall image style of the whole model.
#   This option ignores whatever is passed via instance_prompt.
# - learning_rate
#   learning rate, usually 2e-6; the key parameter to tune during training.
#   Too large and the model will not converge; too small and training slows down.
# - max_train_steps
#   maximum number of training steps, usually 1000; increase it for larger datasets
# - save_model_every_n_steps
#   save the model every N steps, handy for inspecting intermediate results to find the best checkpoint, and for resuming training

# --train_text_encoder  # train the text encoder in addition to the image generator

# --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt
# automatically test the model (after every save_model_every_n_steps steps), the prompts file to test, the random seed, and the number of images per prompt.
# The generated sample images are saved in the test folder under the model output directory.
dreambooth-for-diffusion/train_textual_inversion.sh
ADDED
@@ -0,0 +1,29 @@
# This is another way to finetune a model, called textual inversion. Results are so-so; one copy is bundled for reference only.
# Note: the concept embedding trained by this method can only be used with diffusers. It is not yet supported by inference frameworks outside diffusers (such as webui).
#!/sbin/bash
export LOG_DIR="/root/tf-logs"

accelerate launch ./tools/train_textual_inversion.py \
    --pretrained_model_name_or_path="./model/" \
    --train_data_dir="./datasets/test" \
    --learnable_property="style" \
    --placeholder_token="<xxx-girl>" --initializer_token="girl" \
    --resolution=512 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --learning_rate=5.0e-04 --scale_lr \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --save_steps=200 \
    --max_train_steps=3000 \
    --mixed_precision="fp16" \
    --logging_dir=$LOG_DIR \
    --output_dir="output_model"

# --learnable_property: "style" trains a specific style, "object" trains a specific object/person.
# --placeholder_token is the placeholder used during training, --initializer_token is the word used to initialize it.
# --resolution is the training resolution, --train_batch_size is the per-device batch size, --gradient_accumulation_steps is the number of gradient accumulation steps.
# --learning_rate is the training learning rate, --scale_lr controls whether the learning rate is scaled, --lr_scheduler is the LR scheduler, --lr_warmup_steps is the number of LR warmup steps.
# --save_steps is how often the embedding is saved, --max_train_steps is the maximum number of training steps, --mixed_precision is the mixed-precision mode.
# --logging_dir is the log directory, --output_dir is where the model is saved.
# --pretrained_model_name_or_path is the path to the pretrained model, --train_data_dir is the training data path; it must be a folder containing the processed images.
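As the note above says, the learned concept embedding is only usable from diffusers for now. A minimal sketch of how learned_embeds.bin could be loaded for inference, assuming the converted base model lives in ./model and the output directory from the script above; the paths, the placeholder token and the prompt are placeholders, not part of the repo:

import torch
from diffusers import StableDiffusionPipeline

# Assumed paths: ./model is the converted diffusers base model, output_model is the training output dir.
pipe = StableDiffusionPipeline.from_pretrained("./model")
learned = torch.load("output_model/learned_embeds.bin", map_location="cpu")

# Register the placeholder token and copy its trained embedding into the text encoder.
for token, embedding in learned.items():
    pipe.tokenizer.add_tokens(token)
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    token_id = pipe.tokenizer.convert_tokens_to_ids(token)
    pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

pipe.to("cuda")
image = pipe("a painting in the style of <xxx-girl>").images[0]
image.save("ti_test.png")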
dreambooth-for-diffusion/运行.ipynb
ADDED
@@ -0,0 +1,452 @@
# Dreambooth Stable Diffusion integrated training environment
If your machine is on autodl, you can create an instance directly from the packaged image and use it out of the box.
It also works on a local machine or another server, but you will need to install some pip packages by hand.

## Note
This project is only for learning and testing AI techniques.
Do not use it to train or generate inappropriate or infringing image content.

## About the project
The packaged image on autodl is named: dreambooth-for-diffusion
It can be selected directly as a public algorithm image when creating an instance.
It was packaged on an A5000 machine in autodl's Inner Mongolia zone A; if you hit problems you cannot solve yourself, please use the same environment.
As much as possible was tested while writing this tutorial, but not every corner case is covered.
Small errors can usually be fixed by hand, or check the latest README at the project page.
Project page: https://github.com/CrazyBoyM/dreambooth-for-diffusion

## #Strongly recommended
1. Connect to this server with vscode's ssh feature for a better training experience; autodl's built-in notebook is also fine and has file upload/download.
(vscode + autodl tutorial: https://www.autodl.com/docs/vscode/ )
### 2. (Important) Move the whole train folder to /root/autodl-tmp/ (the data disk) before training, to avoid filling up the system disk
Some machines also have small data disks; keep an eye on this and pick a suitable machine or expand the disk.

If you run into problems, see the training demo video for this tutorial on bilibili: https://space.bilibili.com/291593914
(the video had not been made yet at the time of writing)

## Migrating data between servers
After a shutdown you will often find the machine's resources taken, and you can only start another machine.
For a machine that is already shut down, the menu has a "cross-instance data copy" feature
that conveniently syncs the contents of /root/autodl-tmp to another running machine (so keeping your working files there is recommended).
(Note: it only works between machines in the same zone.)
Data migration tutorial: https://www.autodl.com/docs/migrate_instance/

### This file is the online notebook version
See 教程.md in the root directory for the detailed tutorial and parameter explanations.
To run a linux command in a notebook cell, prefix it with ! (exclamation mark).
A [*] in front of a code cell means that step is running, not that it is stuck.

# Author's foreword

Linux command to compress a folder into a single archive:
```
!zip xx.zip -r ./xxx
```
Extract an archive into a folder:
```
!unzip xx.zip -d xxx
```
You may need these when uploading or downloading datasets.

Other basic linux commands: https://www.autodl.com/docs/linux/

For speeding up file upload/download, see the approaches recommended in the official docs: https://www.autodl.com/docs/scp/

### First, enter the working folder (remember to move the dreambooth-for-diffusion folder to autodl-tmp first)

%cd /root/autodl-tmp/dreambooth-for-diffusion

# Preparing the dataset
Please refer to 教程.md for the details of uploading and processing your own dataset.
dreambooth-for-diffusion/datasets/test holds 16 sample images for learning and testing only, to help you see what the code below does.

## One-click cropping
### Batch center crop plus size, format and background handling
./datasets/test is the raw image folder; upload your own images and change it.
width and height should be multiples of 8; remember to update the parameters in the training script accordingly.
(On devices with less than 20 GB of VRAM, train with data at a resolution below 768, e.g. 512.)
To turn transparent PNGs into a solid background, add the --png flag; see the corresponding code file for details.

!python tools/handle_images.py ./datasets/test ./datasets/test2 --width=768 --height=768

## One-click labeling
### Batch automatic image captioning
Uses deepdanbooru to generate tag files. (Only works well for purely anime-style images; label other styles by hand.)
./datasets/test2 holds the images to be labeled; change it to your own path as needed.

# This step takes a while depending on how many files need labeling (about 10 minutes for 6000 images in testing)
!python tools/label_images.py --path=./datasets/test2

## Converting a ckpt checkpoint to official diffusers weights
The output goes to dreambooth-for-diffusion/model
Replace ./ckpt_models/sd_1-5.ckpt with the path to your own weight file.

To convert a photorealistic model:

# This step takes about one minute
!python tools/ckpt2diffusers.py \
    --checkpoint_path=./ckpt_models/sd_1-5.ckpt \
    --dump_path=./model \
    --original_config_file=./ckpt_models/model.yaml \
    --scheduler_type="ddim"

To convert an anime-style model:

# This step takes about one minute
!python tools/ckpt2diffusers.py \
    --checkpoint_path=./ckpt_models/nd_lastest.ckpt \
    --dump_path=./model \
    --vae_path=./ckpt_models/animevae.pt \
    --original_config_file=./ckpt_models/model.yaml \
    --scheduler_type="ddim"

If you need to convert a certain special model (7 GB) and hit an error, nd_lastest.ckpt in ckpt_models is the file you need.
If you prefer to convert it manually, a copy of ckpt_prune.py is provided under ./tools for reference.

# Training the Unet and text encoder
The training scripts below automatically start a tensorboard logging process; see https://www.autodl.com/docs/tensorboard/ for the entry point.
The tensorboard dashboard helps analyze how the loss falls over the training steps.
If the output is too long for your taste, append &> log.txt to any of the commands below to dump all output into that file.
```
!sh train_style.sh &> log.txt
```
This code package has been tested on A5000 and 3090; if you hit problems on certain machines, try uninstalling the compiled xformers:
```
!pip uninstall xformers
```

### If you need to train a specific person or object:
(3-5 images of the specific subject with a consistent style are recommended)
Open train_object.sh and adjust the parameters inside.

# Logs appear in tensorboard only after about ten minutes (the first ten minutes are spent generating class/prior images)
!sh train_object.sh

### If you want to train a style:
(3000+ images with as much variety as possible are recommended; the data decides the quality of the trained model)
Open train_style.sh and adjust the parameters inside.
Measured speed is roughly 8 minutes per 1000 steps.

# Normal training shows logs in tensorboard immediately
!sh train_style.sh

See 教程.md for how to train in the background.

Money-saving training (shuts the machine down automatically once training succeeds; good for long runs trained overnight):

!sh back_train.sh

## Extra: training Textual inversion

!sh train_textual_inversion.sh

### Testing the trained model
Open dreambooth-for-diffusion/test_model.py, edit model_path and prompt inside, then run the test below.
It generates images on the left as test-1, 2, 3.png.

# About 5-10 s
!python test_model.py

### Converting diffusers weights back to a ckpt checkpoint
The output file is dreambooth-for-diffusion/ckpt_models/newModel.ckpt

Plain save:

!python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel.ckpt

Add --half to the command below to save in float16 half precision; the weight file roughly halves in size (about 2 GB) with essentially the same quality.

!python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel_half.ckpt --half

Download the ckpt file and go have fun~

If you have questions, join the XDiffusion QQ Group: 455521885

### Remember to clean up unneeded intermediate weights and files regularly, otherwise the disk fills up easily
Most issues are documented in detail in 教程.md, which also covers manually deploying this all-in-one training code package on non-autodl machines.

# Examples of cleaning up files
!rm -rf ./model*  # delete the model file/folder in the current directory
!rm -rf ./new_*   # delete all model folders starting with new_ in the current directory
# !rm -rf ./datasets/test2  # delete the test2 dataset in datasets

(Notebook kernel: Python 3 (ipykernel), Python 3.8.10; nbformat 4.)
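The notebook above asks you to edit model_path and prompt in test_model.py before running it. The actual test_model.py ships with the repo but is not shown in this section; as a rough, hedged sketch of what such a test might look like (the model path, prompt, and step count below are assumptions, not the repo's file):

import torch
from diffusers import StableDiffusionPipeline

model_path = "./new_model"        # assumed: output directory written by the training scripts
prompt = "a photo of <xxx> dog"   # assumed: the instance prompt used during training

# Load the trained diffusers model and render a few test images.
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")
for i in range(3):
    image = pipe(prompt, num_inference_steps=30).images[0]
    image.save(f"test-{i + 1}.png")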