andzhang01 committed
Commit a662214 · 1 Parent(s): b7af310

Upload 27 files

Files changed (27)
  1. dreambooth-for-diffusion/.gitignore +17 -0
  2. dreambooth-for-diffusion/README.md +217 -0
  3. dreambooth-for-diffusion/back_train.sh +2 -0
  4. dreambooth-for-diffusion/ckpt_models/model.yaml +69 -0
  5. dreambooth-for-diffusion/ckpt_models/put_your_ckpt_models_here.txt +0 -0
  6. dreambooth-for-diffusion/datasets/put_datasets_here.txt +0 -0
  7. dreambooth-for-diffusion/other/something others.txt +0 -0
  8. dreambooth-for-diffusion/test_model.py +28 -0
  9. dreambooth-for-diffusion/test_prompts_object.txt +2 -0
  10. dreambooth-for-diffusion/test_prompts_style.txt +3 -0
  11. dreambooth-for-diffusion/tools/ckpt2diffusers.py +835 -0
  12. dreambooth-for-diffusion/tools/ckpt2diffusers_old.py +619 -0
  13. dreambooth-for-diffusion/tools/ckpt_merge.py +56 -0
  14. dreambooth-for-diffusion/tools/ckpt_prune.py +14 -0
  15. dreambooth-for-diffusion/tools/deepdanbooru-models/put_deepdanbooru_model_here.txt +0 -0
  16. dreambooth-for-diffusion/tools/diagnose_tensorboard.py +570 -0
  17. dreambooth-for-diffusion/tools/diffusers2ckpt.py +234 -0
  18. dreambooth-for-diffusion/tools/handle_images.py +82 -0
  19. dreambooth-for-diffusion/tools/label_images.py +152 -0
  20. dreambooth-for-diffusion/tools/test_cuda.py +2 -0
  21. dreambooth-for-diffusion/tools/train_dreambooth.py +784 -0
  22. dreambooth-for-diffusion/tools/train_textual_inversion.py +572 -0
  23. dreambooth-for-diffusion/tools/upload_cos.py +19 -0
  24. dreambooth-for-diffusion/train_object.sh +79 -0
  25. dreambooth-for-diffusion/train_style.sh +62 -0
  26. dreambooth-for-diffusion/train_textual_inversion.sh +29 -0
  27. dreambooth-for-diffusion/运行.ipynb +452 -0
dreambooth-for-diffusion/.gitignore ADDED
@@ -0,0 +1,17 @@
+ __pycache__
+ .ipynb_checkpoints
+ */.ipynb_checkpoints
+ *.ckpt
+ *.pt
+ *.whl
+ *.log
+ *.png
+ *.jpg
+ nohup.out
+ /datasets
+ /model
+ /new-*
+ /log
+ /output*
+ /tools/deepdanbooru-models/*
+ /tools/diffusers-models/*
dreambooth-for-diffusion/README.md ADDED
@@ -0,0 +1,217 @@
+ # Dreambooth Stable Diffusion integrated training environment
+ If your machine is on autodl, you can create an instance directly from the packaged image and use it out of the box.
+ It also works locally or on other servers, but you will need to install a few pip packages by hand.
+
+ ## How to run
+ Run it directly on autodl with the image: https://www.codewithgpu.com/i/CrazyBoyM/dreambooth-for-diffusion/dreambooth-for-diffusion
+
+ If you are not comfortable with notebook-style training, you can also use the packaged webui online image (includes the stable Dreambooth and DreamArtist training plugins, already fixed):
+ https://www.codewithgpu.com/i/CrazyBoyM/sd_dreambooth_extension_webui/dreambooth-dreamartist-for-webui
+
+ ## Notice
+ This project is intended only for learning and testing AI techniques.
+ Do not use it to train models that generate harmful or infringing images.
+
+ ## About the project
+ The packaged image on autodl is named dreambooth-for-diffusion;
+ you can select it as a public algorithm image when creating an instance.
+ It was packaged on an A5000 machine in autodl's Inner Mongolia zone A; if you hit problems you cannot solve yourself, please use the same environment.
+ 白菜 tested as much as possible while writing this tutorial, but cannot guarantee that every step is fully covered.
+ If you run into small errors, try to fix them by hand, or check the latest README on the git project page.
+ Project page: https://github.com/CrazyBoyM/dreambooth-for-diffusion
+
+ If you get stuck, look for the training demo video for this tutorial on the bilibili homepage: https://space.bilibili.com/291593914
+ (The video had not been made yet at the time of writing.)
+
+ ## Strongly recommended
+ 1. Connect to the server with VSCode's remote SSH feature for a better training experience; autodl's built-in notebook is also fine and supports file upload/download.
+ 2. (Important) First move the whole dreambooth-for-diffusion folder from /root/ to /root/autodl-tmp/ (the data disk) to avoid filling up the system disk.
+
+ ## Enter the working directory
+ ```
+ cd /root/autodl-tmp/dreambooth-for-diffusion
+ ```
+
+ ## Convert a ckpt checkpoint to official diffusers weights
+ Two base models are bundled; choose one according to your dataset.
+ - sd_1-5.ckpt leans toward photorealistic styles
+ - nd_lastest.ckpt leans toward anime styles
+ Convert the anime model:
+ ```
+ # This step takes about one minute
+ !python tools/ckpt2diffusers.py \
+     --checkpoint_path=./ckpt_models/nd_lastest.ckpt \
+     --dump_path=./model \
+     --vae_path=./ckpt_models/animevae.pt \
+     --original_config_file=./ckpt_models/model.yaml \
+     --scheduler_type="ddim"
+ ```
+ Convert the photorealistic model:
+ ```
+ # This step takes about one minute
+ !python tools/ckpt2diffusers.py \
+     --checkpoint_path=./ckpt_models/sd_1-5.ckpt \
+     --dump_path=./model \
+     --original_config_file=./ckpt_models/model.yaml \
+     --scheduler_type="ddim"
+ ```
+ --checkpoint_path is your ckpt file and --dump_path is the output directory for the converted weights.
+
+ ## Convert diffusers weights back to a ckpt checkpoint
+ ```
+ python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel_half.ckpt --half
+ ```
+ To save in float16 precision, add the --half flag; the checkpoint size is roughly halved (the sketch below shows the idea).
+
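As a rough illustration of what --half means (this is not the actual tools/diffusers2ckpt.py, only a minimal sketch with assumed file names): every floating-point tensor in the checkpoint's state dict is cast to float16 before saving.

```python
import torch

# Minimal sketch of the --half idea (assumed paths; the real logic lives in tools/diffusers2ckpt.py)
ckpt = torch.load("model.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)                # some checkpoints wrap weights in "state_dict"
state_dict = {
    k: v.half() if torch.is_tensor(v) and v.is_floating_point() else v
    for k, v in state_dict.items()
}
torch.save({"state_dict": state_dict}, "model_half.ckpt")  # roughly half the original file size
```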
+ ## Prepare the dataset
+ Prepare a dataset that matches your training task.
+ ### Crop images to 512*512
+ tools/handle_images.py contains a batch-processing script you can use as a reference;
+ it automatically center-crops and resizes images (a minimal sketch of the idea follows the command below).
+ ```
+ python tools/handle_images.py ./datasets/test ./datasets/test2 --width=512 --height=512
+ ```
+ test is the folder of unprocessed source images, test2 is the output folder for the processed images.
+ To convert transparent-background PNGs into black/white-background JPGs, add the --png flag.
+
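For reference, a minimal single-image sketch of center-crop + resize with Pillow (illustrative only; the file names are hypothetical and the real batch logic is in tools/handle_images.py):

```python
from PIL import Image

def center_crop_resize(src, dst, width=512, height=512):
    """Square center crop, then resize to the target resolution."""
    img = Image.open(src).convert("RGB")                  # drop alpha if present
    w, h = img.size
    side = min(w, h)
    left, top = (w - side) // 2, (h - side) // 2
    img = img.crop((left, top, left + side, top + side))  # centered square
    img.resize((width, height), Image.LANCZOS).save(dst, quality=95)

# hypothetical example files
center_crop_resize("./datasets/test/0001.jpg", "./datasets/test2/0001.jpg")
```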
+ ### Automatic image tagging
+ Use deepdanbooru to generate tag labels.
+ ```
+ !python tools/label_images.py --path=./datasets/test2
+ ```
+ The --path argument is the folder of images you want to tag.
+
+ Note: if deepdanbooru cannot be found, you can build it yourself from this repository:
+ https://github.com/KichangKim/DeepDanbooru
+
+ A prebuilt wheel is also provided in the other folder:
+ ```
+ pip install other/deepdanbooru-1.0.0-py3-none-any.whl
+ ```
+
+ ## Training and common commands
+ ### Configure the training environment (optional)
+ If you are not using the packaged image, do the following setup first:
+ ```
+ pip install accelerate
+ ```
+ Run the following command and answer: local machine, NO, NO:
+ ```
+ accelerate config
+ ```
+
+ ### Start training
+ Open train_object.sh / train_style.sh and read the parameter descriptions inside.
+ To train a specific person or object:
+ (3~5 images of the specific subject in a consistent style are recommended)
+
+ ```
+ sh train_object.sh
+ ```
+
+ To finetune your own large model:
+ (3000+ images with as much diversity as possible are recommended; the data determines the quality of the trained model)
+ ```
+ sh train_style.sh
+ ```
+ On an A5000, training runs at roughly 8 minutes per 1000 steps.
+
+ ### Test the trained model
+ Open test_model.py, edit its model_path and prompt (see the snippet below), then run:
+ ```
+ python test_model.py
+ ```
+
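For reference, the two lines to edit sit near the top of test_model.py and look like this (the values shown are the defaults shipped with this project):

```python
model_path = "./new_model"                       # point this at your converted diffusers weights
prompt = "a cute girl, blue eyes, brown hair"    # prompt used to generate the test images
```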
+ ### Other common commands
+ To run training as a background job:
+ ```
+ nohup sh train_style.sh &
+ ```
+ Recommended for overnight runs; a dropped connection will not stop training.
+ 白菜's personal money-saving trick:
+ ```
+ nohup sh back_train.sh &
+ ```
+ (The machine shuts down automatically once training finishes.)
+
+ Training logs are written to nohup.out, which you can open in VSCode or download.
+ Show the last ten lines of the log:
+ ```
+ tail -n 10 nohup.out
+ ```
+
+ Check current disk usage:
+ (remember to delete files you no longer need; a few dozen GB of leftovers often fill the disk and make model saving fail!!)
+ ```
+ df -h
+ ```
+
+ ## If you are on another server without the integrated environment
+ Install any missing packages yourself:
+ ```
+ pip install diffusers
+ pip install ftfy
+ pip install tensorflow-gpu
+ pip install pytorch_lightning
+ pip install OmegaConf
+ ... and a few others
+ ```
+
+ ## Academic acceleration (optional)
+ If pulling content from git is slow, the following may help.
+ Run the command that matches your machine's region:
+ ```
+ # Instances in Beijing zone A
+ export http_proxy=http://100.72.64.19:12798 && export https_proxy=http://100.72.64.19:12798
+
+ # Instances in Inner Mongolia zone A
+ export http_proxy=http://192.168.1.174:12798 && export https_proxy=http://192.168.1.174:12798
+
+ # Instances in Quanzhou zone A
+ export http_proxy=http://10.55.146.88:12798 && export https_proxy=http://10.55.146.88:12798
+ ```
+
+ ## xformers (optional)
+ Training and inference on the A5000 are already fast enough in practice, so it is not installed.
+ If you use another GPU or really need it, you can build it from:
+ https://github.com/facebookresearch/xformers
+ (Guessing you might want to try it, a prebuilt version has already been placed in the other directory.)
+ Note: pytorch must be upgraded to 1.12.x or later before it can be installed. (Update: the image already has pytorch upgraded and xformers installed.)
+
+ ## Upgrade pytorch to 1.12.x
+ ```
+ pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
+ ```
+
+ # Notes on using autodl
+
+ ## Migrating data between servers
+ After shutting a machine down, you may find its resources occupied when you try to restart it, and you have to open a new machine.
+ For a powered-off machine, however, the menu offers a "cross-instance data copy" feature
+ that conveniently syncs the contents of /root/autodl-tmp to another running machine (another reason to keep your working files there).
+ (Note: it only works between machines in the same region.)
+ Data migration guide: https://www.autodl.com/docs/migrate_instance/
+
+ ## Ways to transfer files
+ ### Option 1: VSCode
+ Drag files to upload/download directly in VSCode; slow, but the simplest.
+
+ ### Option 2: autodl's user netdisk
+ Initialize a netdisk in the same region on autodl's netdisk page, then restart the server instance.
+ A new folder /root/autodl-nas/ appears, and you can upload weights and data from the web page.
+ After training, move the generated weight files into that path and download them from the web page.
+ (Web page: https://www.autodl.com/console/netdisk)
+ Note: the netdisk must be initialized in the same region as the server.
+
+ ### Option 3: object storage
+ If you have access, you can also transfer files through COS or OSS, which is faster.
+ A reference script for uploading to COS is included in the tools folder (a rough sketch follows); use it if you have experience with object storage.
+
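As a rough sketch of what such an upload looks like with Tencent's cos-python-sdk-v5 (the region, keys and bucket name are placeholders, and the bundled tools/upload_cos.py remains the actual reference):

```python
from qcloud_cos import CosConfig, CosS3Client

# Placeholder credentials and bucket; fill in your own COS settings.
config = CosConfig(Region="ap-shanghai",
                   SecretId="YOUR_SECRET_ID",
                   SecretKey="YOUR_SECRET_KEY")
client = CosS3Client(config)

# Upload a trained checkpoint to the bucket.
client.upload_file(Bucket="your-bucket-1250000000",
                   LocalFilePath="./ckpt_models/newModel_half.ckpt",
                   Key="newModel_half.ckpt")
```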
+ The autodl site also documents some recommended methods: https://www.autodl.com/docs/scp/
+
+ # Other
+ Thanks to open-source projects such as diffusers and deepdanbooru.
+ The style training code is modified from nbardy's PR.
+ Part of the tagging code comes from crosstyan, Nyanko Lepsoni, and AUTOMATIC1111.
+ If you are interested, feel free to join the QQ group to discuss: 455521885
+ Packaged and organized by 白菜
dreambooth-for-diffusion/back_train.sh ADDED
@@ -0,0 +1,2 @@
+ # Money saver: shut the machine down once training finishes normally
+ sh train_style.sh && shutdown
dreambooth-for-diffusion/ckpt_models/model.yaml ADDED
@@ -0,0 +1,69 @@
+ model:
+   base_learning_rate: 1.0e-04
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "jpg"
+     cond_stage_key: "txt"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: false   # Note: different from the one we trained before
+     conditioning_key: crossattn
+     monitor: val/loss_simple_ema
+     scale_factor: 0.18215
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 10000 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 512
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
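This matches the standard Stable Diffusion v1 inference config. tools/ckpt2diffusers.py loads it with OmegaConf and reads the sub-configs roughly like this (a minimal sketch, not the full conversion script):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("./ckpt_models/model.yaml")
unet_params = cfg.model.params.unet_config.params                  # e.g. context_dim == 768
vae_params = cfg.model.params.first_stage_config.params.ddconfig   # e.g. resolution == 512
print(unet_params.model_channels, vae_params.z_channels)           # 320 4
```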
dreambooth-for-diffusion/ckpt_models/put_your_ckpt_models_here.txt ADDED
File without changes
dreambooth-for-diffusion/datasets/put_datasets_here.txt ADDED
File without changes
dreambooth-for-diffusion/other/something others.txt ADDED
File without changes
dreambooth-for-diffusion/test_model.py ADDED
@@ -0,0 +1,28 @@
+ from diffusers import StableDiffusionPipeline
+ import torch
+ from diffusers import DDIMScheduler
+
+ model_path = "./new_model"
+ prompt = "a cute girl, blue eyes, brown hair"
+ torch.manual_seed(123123123)
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+     model_path,
+     torch_dtype=torch.float16,
+     scheduler=DDIMScheduler(
+         beta_start=0.00085,
+         beta_end=0.012,
+         beta_schedule="scaled_linear",
+         clip_sample=False,
+         set_alpha_to_one=True,
+     ),
+     safety_checker=None
+ )
+
+ # def dummy(images, **kwargs):
+ #     return images, False
+ # pipe.safety_checker = dummy
+ pipe = pipe.to("cuda")
+ images = pipe(prompt, width=512, height=512, num_inference_steps=30, num_images_per_prompt=3).images
+ for i, image in enumerate(images):
+     image.save(f"test-{i}.png")
dreambooth-for-diffusion/test_prompts_object.txt ADDED
@@ -0,0 +1,2 @@
+ a photo of <xxx> dog
+ a photo of dog
dreambooth-for-diffusion/test_prompts_style.txt ADDED
@@ -0,0 +1,3 @@
+ a cute girl, blue eyes, brown hair
+ a cute girl, blue eyes, blue hair
+ a cute boy, green eyes, brown hair
dreambooth-for-diffusion/tools/ckpt2diffusers.py ADDED
@@ -0,0 +1,835 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the LDM checkpoints. """
16
+
17
+ import argparse
18
+ import os
19
+
20
+ import torch
21
+
22
+
23
+ try:
24
+ from omegaconf import OmegaConf
25
+ except ImportError:
26
+ raise ImportError(
27
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
28
+ )
29
+
30
+ from diffusers import (
31
+ AutoencoderKL,
32
+ DDIMScheduler,
33
+ LDMTextToImagePipeline,
34
+ LMSDiscreteScheduler,
35
+ PNDMScheduler,
36
+ StableDiffusionPipeline,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
40
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
41
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
42
+
43
+ script_path = os.path.realpath(__file__)
44
+ default_model_path = os.path.join(os.path.dirname(script_path), "diffusers-models")
45
+
46
+ def shave_segments(path, n_shave_prefix_segments=1):
47
+ """
48
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
49
+ """
50
+ if n_shave_prefix_segments >= 0:
51
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
52
+ else:
53
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
54
+
55
+
56
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
57
+ """
58
+ Updates paths inside resnets to the new naming scheme (local renaming)
59
+ """
60
+ mapping = []
61
+ for old_item in old_list:
62
+ new_item = old_item.replace("in_layers.0", "norm1")
63
+ new_item = new_item.replace("in_layers.2", "conv1")
64
+
65
+ new_item = new_item.replace("out_layers.0", "norm2")
66
+ new_item = new_item.replace("out_layers.3", "conv2")
67
+
68
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
69
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
70
+
71
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
72
+
73
+ mapping.append({"old": old_item, "new": new_item})
74
+
75
+ return mapping
76
+
77
+
78
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
79
+ """
80
+ Updates paths inside resnets to the new naming scheme (local renaming)
81
+ """
82
+ mapping = []
83
+ for old_item in old_list:
84
+ new_item = old_item
85
+
86
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
87
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
88
+
89
+ mapping.append({"old": old_item, "new": new_item})
90
+
91
+ return mapping
92
+
93
+
94
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
95
+ """
96
+ Updates paths inside attentions to the new naming scheme (local renaming)
97
+ """
98
+ mapping = []
99
+ for old_item in old_list:
100
+ new_item = old_item
101
+
102
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
103
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
104
+
105
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
106
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
107
+
108
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
109
+
110
+ mapping.append({"old": old_item, "new": new_item})
111
+
112
+ return mapping
113
+
114
+
115
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
116
+ """
117
+ Updates paths inside attentions to the new naming scheme (local renaming)
118
+ """
119
+ mapping = []
120
+ for old_item in old_list:
121
+ new_item = old_item
122
+
123
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
124
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
125
+
126
+ new_item = new_item.replace("q.weight", "query.weight")
127
+ new_item = new_item.replace("q.bias", "query.bias")
128
+
129
+ new_item = new_item.replace("k.weight", "key.weight")
130
+ new_item = new_item.replace("k.bias", "key.bias")
131
+
132
+ new_item = new_item.replace("v.weight", "value.weight")
133
+ new_item = new_item.replace("v.bias", "value.bias")
134
+
135
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
136
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
137
+
138
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
139
+
140
+ mapping.append({"old": old_item, "new": new_item})
141
+
142
+ return mapping
143
+
144
+
145
+ def assign_to_checkpoint(
146
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
147
+ ):
148
+ """
149
+ This does the final conversion step: take locally converted weights and apply a global renaming
150
+ to them. It splits attention layers, and takes into account additional replacements
151
+ that may arise.
152
+
153
+ Assigns the weights to the new checkpoint.
154
+ """
155
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
156
+
157
+ # Splits the attention layers into three variables.
158
+ if attention_paths_to_split is not None:
159
+ for path, path_map in attention_paths_to_split.items():
160
+ old_tensor = old_checkpoint[path]
161
+ channels = old_tensor.shape[0] // 3
162
+
163
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
164
+
165
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
166
+
167
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
168
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
169
+
170
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
171
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
172
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
173
+
174
+ for path in paths:
175
+ new_path = path["new"]
176
+
177
+ # These have already been assigned
178
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
179
+ continue
180
+
181
+ # Global renaming happens here
182
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
183
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
184
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
185
+
186
+ if additional_replacements is not None:
187
+ for replacement in additional_replacements:
188
+ new_path = new_path.replace(replacement["old"], replacement["new"])
189
+
190
+ # proj_attn.weight has to be converted from conv 1D to linear
191
+ if "proj_attn.weight" in new_path:
192
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
193
+ else:
194
+ checkpoint[new_path] = old_checkpoint[path["old"]]
195
+
196
+
197
+ def conv_attn_to_linear(checkpoint):
198
+ keys = list(checkpoint.keys())
199
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
200
+ for key in keys:
201
+ if ".".join(key.split(".")[-2:]) in attn_keys:
202
+ if checkpoint[key].ndim > 2:
203
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
204
+ elif "proj_attn.weight" in key:
205
+ if checkpoint[key].ndim > 2:
206
+ checkpoint[key] = checkpoint[key][:, :, 0]
207
+
208
+
209
+ def create_unet_diffusers_config(original_config):
210
+ """
211
+ Creates a config for the diffusers based on the config of the LDM model.
212
+ """
213
+ unet_params = original_config.model.params.unet_config.params
214
+
215
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
216
+
217
+ down_block_types = []
218
+ resolution = 1
219
+ for i in range(len(block_out_channels)):
220
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
221
+ down_block_types.append(block_type)
222
+ if i != len(block_out_channels) - 1:
223
+ resolution *= 2
224
+
225
+ up_block_types = []
226
+ for i in range(len(block_out_channels)):
227
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
228
+ up_block_types.append(block_type)
229
+ resolution //= 2
230
+
231
+ config = dict(
232
+ sample_size=unet_params.image_size,
233
+ in_channels=unet_params.in_channels,
234
+ out_channels=unet_params.out_channels,
235
+ down_block_types=tuple(down_block_types),
236
+ up_block_types=tuple(up_block_types),
237
+ block_out_channels=tuple(block_out_channels),
238
+ layers_per_block=unet_params.num_res_blocks,
239
+ cross_attention_dim=unet_params.context_dim,
240
+ attention_head_dim=unet_params.num_heads,
241
+ )
242
+
243
+ return config
244
+
245
+
246
+ def create_vae_diffusers_config(original_config):
247
+ """
248
+ Creates a config for the diffusers based on the config of the LDM model.
249
+ """
250
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
251
+ _ = original_config.model.params.first_stage_config.params.embed_dim
252
+
253
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
254
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
255
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
256
+
257
+ config = dict(
258
+ sample_size=vae_params.resolution,
259
+ in_channels=vae_params.in_channels,
260
+ out_channels=vae_params.out_ch,
261
+ down_block_types=tuple(down_block_types),
262
+ up_block_types=tuple(up_block_types),
263
+ block_out_channels=tuple(block_out_channels),
264
+ latent_channels=vae_params.z_channels,
265
+ layers_per_block=vae_params.num_res_blocks,
266
+ )
267
+ return config
268
+
269
+
270
+ def create_diffusers_schedular(original_config):
271
+ schedular = DDIMScheduler(
272
+ num_train_timesteps=original_config.model.params.timesteps,
273
+ beta_start=original_config.model.params.linear_start,
274
+ beta_end=original_config.model.params.linear_end,
275
+ beta_schedule="scaled_linear",
276
+ )
277
+ return schedular
278
+
279
+
280
+ def create_ldm_bert_config(original_config):
281
+ bert_params = original_config.model.parms.cond_stage_config.params
282
+ config = LDMBertConfig(
283
+ d_model=bert_params.n_embed,
284
+ encoder_layers=bert_params.n_layer,
285
+ encoder_ffn_dim=bert_params.n_embed * 4,
286
+ )
287
+ return config
288
+
289
+
290
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
291
+ """
292
+ Takes a state dict and a config, and returns a converted checkpoint.
293
+ """
294
+
295
+ # extract state_dict for UNet
296
+ unet_state_dict = {}
297
+ keys = list(checkpoint.keys())
298
+
299
+ unet_key = "model.diffusion_model."
300
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
301
+ if sum(k.startswith("model_ema") for k in keys) > 100:
302
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
303
+ if extract_ema:
304
+ print(
305
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
306
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
307
+ )
308
+ for key in keys:
309
+ if key.startswith("model.diffusion_model"):
310
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
311
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
312
+ else:
313
+ print(
314
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
315
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
316
+ )
317
+
318
+ keys = list(checkpoint.keys())
319
+ for key in keys:
320
+ if key.startswith(unet_key):
321
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
322
+
323
+ new_checkpoint = {"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"],
324
+ "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
325
+ "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"],
326
+ "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
327
+ "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
328
+ "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
329
+ "conv_norm_out.weight": unet_state_dict["out.0.weight"],
330
+ "conv_norm_out.bias": unet_state_dict["out.0.bias"],
331
+ "conv_out.weight": unet_state_dict["out.2.weight"],
332
+ "conv_out.bias": unet_state_dict["out.2.bias"]}
333
+
334
+ # Retrieves the keys for the input blocks only
335
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
336
+ input_blocks = {
337
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
338
+ for layer_id in range(num_input_blocks)
339
+ }
340
+
341
+ # Retrieves the keys for the middle blocks only
342
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
343
+ middle_blocks = {
344
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
345
+ for layer_id in range(num_middle_blocks)
346
+ }
347
+
348
+ # Retrieves the keys for the output blocks only
349
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
350
+ output_blocks = {
351
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
352
+ for layer_id in range(num_output_blocks)
353
+ }
354
+
355
+ for i in range(1, num_input_blocks):
356
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
357
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
358
+
359
+ resnets = [
360
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
361
+ ]
362
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
363
+
364
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
365
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
366
+ f"input_blocks.{i}.0.op.weight"
367
+ )
368
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
369
+ f"input_blocks.{i}.0.op.bias"
370
+ )
371
+
372
+ paths = renew_resnet_paths(resnets)
373
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
374
+ assign_to_checkpoint(
375
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
376
+ )
377
+
378
+ if len(attentions):
379
+ paths = renew_attention_paths(attentions)
380
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
381
+ assign_to_checkpoint(
382
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
383
+ )
384
+
385
+ resnet_0 = middle_blocks[0]
386
+ attentions = middle_blocks[1]
387
+ resnet_1 = middle_blocks[2]
388
+
389
+ resnet_0_paths = renew_resnet_paths(resnet_0)
390
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
391
+
392
+ resnet_1_paths = renew_resnet_paths(resnet_1)
393
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
394
+
395
+ attentions_paths = renew_attention_paths(attentions)
396
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
397
+ assign_to_checkpoint(
398
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
399
+ )
400
+
401
+ for i in range(num_output_blocks):
402
+ block_id = i // (config["layers_per_block"] + 1)
403
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
404
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
405
+ output_block_list = {}
406
+
407
+ for layer in output_block_layers:
408
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
409
+ if layer_id in output_block_list:
410
+ output_block_list[layer_id].append(layer_name)
411
+ else:
412
+ output_block_list[layer_id] = [layer_name]
413
+
414
+ if len(output_block_list) > 1:
415
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
416
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
417
+
418
+ resnet_0_paths = renew_resnet_paths(resnets)
419
+ paths = renew_resnet_paths(resnets)
420
+
421
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
422
+ assign_to_checkpoint(
423
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
424
+ )
425
+
426
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
427
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
428
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
429
+ f"output_blocks.{i}.{index}.conv.weight"
430
+ ]
431
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
432
+ f"output_blocks.{i}.{index}.conv.bias"
433
+ ]
434
+
435
+ # Clear attentions as they have been attributed above.
436
+ if len(attentions) == 2:
437
+ attentions = []
438
+
439
+ if len(attentions):
440
+ paths = renew_attention_paths(attentions)
441
+ meta_path = {
442
+ "old": f"output_blocks.{i}.1",
443
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
444
+ }
445
+ assign_to_checkpoint(
446
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
447
+ )
448
+ else:
449
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
450
+ for path in resnet_0_paths:
451
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
452
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
453
+
454
+ new_checkpoint[new_path] = unet_state_dict[old_path]
455
+
456
+ return new_checkpoint
457
+
458
+ def convert_ldm_vae_checkpoint(checkpoint, config):
459
+ # extract state dict for VAE
460
+ vae_state_dict = {}
461
+ vae_key = "first_stage_model."
462
+ keys = list(checkpoint.keys())
463
+ for key in keys:
464
+ if key.startswith(vae_key):
465
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
466
+
467
+ new_checkpoint = {}
468
+
469
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
470
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
471
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
472
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
473
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
474
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
475
+
476
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
477
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
478
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
479
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
480
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
481
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
482
+
483
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
484
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
485
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
486
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
487
+
488
+
489
+ # Retrieves the keys for the encoder down blocks only
490
+ num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer})
491
+ down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
492
+
493
+ # Retrieves the keys for the decoder up blocks only
494
+ num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer})
495
+ up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
496
+
497
+
498
+ for i in range(num_down_blocks):
499
+ resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key]
500
+
501
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
502
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
503
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
504
+
505
+ paths = renew_vae_resnet_paths(resnets)
506
+ meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'}
507
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
508
+
509
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
510
+ num_mid_res_blocks = 2
511
+ for i in range(1, num_mid_res_blocks + 1):
512
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
513
+
514
+ paths = renew_vae_resnet_paths(resnets)
515
+ meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
516
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
517
+
518
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
519
+ paths = renew_vae_attention_paths(mid_attentions)
520
+ meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
521
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
522
+ conv_attn_to_linear(new_checkpoint)
523
+
524
+ for i in range(num_up_blocks):
525
+ block_id = num_up_blocks - 1 - i
526
+ resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key]
527
+
528
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
529
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
530
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
531
+
532
+ paths = renew_vae_resnet_paths(resnets)
533
+ meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'}
534
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
535
+
536
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
537
+ num_mid_res_blocks = 2
538
+ for i in range(1, num_mid_res_blocks + 1):
539
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
540
+
541
+ paths = renew_vae_resnet_paths(resnets)
542
+ meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
543
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
544
+
545
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
546
+ paths = renew_vae_attention_paths(mid_attentions)
547
+ meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
548
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
549
+ conv_attn_to_linear(new_checkpoint)
550
+ return new_checkpoint
551
+
552
+
553
+ def convert_ldm_vae(vae_path, config):
554
+ vae_state_dict = torch.load(vae_path)['state_dict']
555
+
556
+ new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
557
+ "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
558
+ "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
559
+ "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
560
+ "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"],
561
+ "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
562
+ "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
563
+ "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
564
+ "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
565
+ "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
566
+ "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"],
567
+ "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
568
+ "quant_conv.weight": vae_state_dict["quant_conv.weight"],
569
+ "quant_conv.bias": vae_state_dict["quant_conv.bias"],
570
+ "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
571
+ "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"]}
572
+
573
+
574
+ # Retrieves the keys for the encoder down blocks only
575
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
576
+ down_blocks = {
577
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
578
+ }
579
+
580
+ # Retrieves the keys for the decoder up blocks only
581
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
582
+ up_blocks = {
583
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
584
+ }
585
+
586
+ for i in range(num_down_blocks):
587
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
588
+
589
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
590
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
591
+ f"encoder.down.{i}.downsample.conv.weight"
592
+ )
593
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
594
+ f"encoder.down.{i}.downsample.conv.bias"
595
+ )
596
+
597
+ paths = renew_vae_resnet_paths(resnets)
598
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
599
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
600
+
601
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
602
+ num_mid_res_blocks = 2
603
+ for i in range(1, num_mid_res_blocks + 1):
604
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
605
+
606
+ paths = renew_vae_resnet_paths(resnets)
607
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
608
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
609
+
610
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
611
+ paths = renew_vae_attention_paths(mid_attentions)
612
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
613
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
614
+ conv_attn_to_linear(new_checkpoint)
615
+
616
+ for i in range(num_up_blocks):
617
+ block_id = num_up_blocks - 1 - i
618
+ resnets = [
619
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
620
+ ]
621
+
622
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
623
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
624
+ f"decoder.up.{block_id}.upsample.conv.weight"
625
+ ]
626
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
627
+ f"decoder.up.{block_id}.upsample.conv.bias"
628
+ ]
629
+
630
+ paths = renew_vae_resnet_paths(resnets)
631
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
632
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
633
+
634
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
635
+ num_mid_res_blocks = 2
636
+ for i in range(1, num_mid_res_blocks + 1):
637
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
638
+
639
+ paths = renew_vae_resnet_paths(resnets)
640
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
641
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
642
+
643
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
644
+ paths = renew_vae_attention_paths(mid_attentions)
645
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
646
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
647
+ conv_attn_to_linear(new_checkpoint)
648
+ return new_checkpoint
649
+
650
+
651
+ def convert_ldm_bert_checkpoint(checkpoint, config):
652
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
653
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
654
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
655
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
656
+
657
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
658
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
659
+
660
+ def _copy_linear(hf_linear, pt_linear):
661
+ hf_linear.weight = pt_linear.weight
662
+ hf_linear.bias = pt_linear.bias
663
+
664
+ def _copy_layer(hf_layer, pt_layer):
665
+ # copy layer norms
666
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
667
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
668
+
669
+ # copy attn
670
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
671
+
672
+ # copy MLP
673
+ pt_mlp = pt_layer[1][1]
674
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
675
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
676
+
677
+ def _copy_layers(hf_layers, pt_layers):
678
+ for i, hf_layer in enumerate(hf_layers):
679
+ if i != 0:
680
+ i += i
681
+ pt_layer = pt_layers[i : i + 2]
682
+ _copy_layer(hf_layer, pt_layer)
683
+
684
+ hf_model = LDMBertModel(config).eval()
685
+
686
+ # copy embeds
687
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
688
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
689
+
690
+ # copy layer norm
691
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
692
+
693
+ # copy hidden layers
694
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
695
+
696
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
697
+
698
+ return hf_model
699
+
700
+
701
+ def convert_ldm_clip_checkpoint(checkpoint):
702
+ if os.path.exists(default_model_path):
703
+ text_model = CLIPTextModel.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
704
+ else:
705
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
706
+
707
+ keys = list(checkpoint.keys())
708
+
709
+ text_model_dict = {}
710
+
711
+ for key in keys:
712
+ if key.startswith("cond_stage_model.transformer"):
713
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
714
+
715
+ text_model.load_state_dict(text_model_dict, strict=False)
716
+
717
+ return text_model
718
+
719
+
720
+ if __name__ == "__main__":
721
+ parser = argparse.ArgumentParser()
722
+
723
+ parser.add_argument(
724
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
725
+ )
726
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
727
+ parser.add_argument(
728
+ "--vae_path", default=None, type=str, help="Path to the vae to convert."
729
+ )
730
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
731
+ parser.add_argument(
732
+ "--original_config_file",
733
+ default=None,
734
+ type=str,
735
+ help="The YAML config file corresponding to the original architecture.",
736
+ )
737
+ parser.add_argument(
738
+ "--scheduler_type",
739
+ default="pndm",
740
+ type=str,
741
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
742
+ )
743
+ parser.add_argument(
744
+ "--extract_ema",
745
+ action="store_true",
746
+ default=False,
747
+ help=(
748
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
749
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
750
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
751
+ ),
752
+ )
753
+
754
+ args = parser.parse_args()
755
+
756
+ if args.original_config_file is None:
757
+ os.system(
758
+ "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
759
+ )
760
+ args.original_config_file = "./v1-inference.yaml"
761
+
762
+ original_config = OmegaConf.load(args.original_config_file)
763
+ checkpoint = torch.load(args.checkpoint_path, map_location="cuda")
764
+ checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
765
+
766
+ num_train_timesteps = original_config.model.params.timesteps
767
+ beta_start = original_config.model.params.linear_start
768
+ beta_end = original_config.model.params.linear_end
769
+ if args.scheduler_type == "pndm":
770
+ scheduler = PNDMScheduler(
771
+ beta_end=beta_end,
772
+ beta_schedule="scaled_linear",
773
+ beta_start=beta_start,
774
+ num_train_timesteps=num_train_timesteps,
775
+ skip_prk_steps=True,
776
+ )
777
+ elif args.scheduler_type == "lms":
778
+ scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
779
+ elif args.scheduler_type == "ddim":
780
+ scheduler = DDIMScheduler(
781
+ beta_start=beta_start,
782
+ beta_end=beta_end,
783
+ beta_schedule="scaled_linear",
784
+ clip_sample=False,
785
+ set_alpha_to_one=False,
786
+ )
787
+ else:
788
+ raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
789
+
790
+ # Convert the UNet2DConditionModel model.
791
+ unet_config = create_unet_diffusers_config(original_config)
792
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
793
+ checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
794
+ )
795
+
796
+ unet = UNet2DConditionModel(**unet_config)
797
+ unet.load_state_dict(converted_unet_checkpoint)
798
+
799
+ # Convert the VAE model.
800
+ if args.vae_path:
801
+ vae_config = create_vae_diffusers_config(original_config)
802
+ converted_vae_checkpoint = convert_ldm_vae(args.vae_path, vae_config)
803
+ else:
804
+ vae_config = create_vae_diffusers_config(original_config)
805
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
806
+
807
+ vae = AutoencoderKL(**vae_config)
808
+ vae.load_state_dict(converted_vae_checkpoint)
809
+
810
+ # Convert the text model.
811
+ text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
812
+ if text_model_type == "FrozenCLIPEmbedder":
813
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
814
+ if os.path.exists(default_model_path):
815
+ tokenizer = CLIPTokenizer.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
816
+ else:
817
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
818
+ #safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
819
+ #feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
820
+ pipe = StableDiffusionPipeline(
821
+ vae=vae,
822
+ text_encoder=text_model,
823
+ tokenizer=tokenizer,
824
+ unet=unet,
825
+ scheduler=scheduler,
826
+ safety_checker=None,
827
+ feature_extractor=None,
828
+ )
829
+ else:
830
+ text_config = create_ldm_bert_config(original_config)
831
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
832
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
833
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
834
+
835
+ pipe.save_pretrained(args.dump_path)
dreambooth-for-diffusion/tools/ckpt2diffusers_old.py ADDED
@@ -0,0 +1,619 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the LDM checkpoints. """
16
+
17
+ import argparse, os
18
+ import torch
19
+
20
+ try:
21
+ from omegaconf import OmegaConf
22
+ except ImportError:
23
+ raise ImportError("OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`.")
24
+
25
+ from transformers import BertTokenizerFast, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel
26
+ from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler
27
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel, LDMBertConfig
28
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
29
+
30
+ def shave_segments(path, n_shave_prefix_segments=1):
31
+ """
32
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
33
+ """
34
+ if n_shave_prefix_segments >= 0:
35
+ return '.'.join(path.split('.')[n_shave_prefix_segments:])
36
+ else:
37
+ return '.'.join(path.split('.')[:n_shave_prefix_segments])
38
+
39
+
40
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
41
+ """
42
+ Updates paths inside resnets to the new naming scheme (local renaming)
43
+ """
44
+ mapping = []
45
+ for old_item in old_list:
46
+ new_item = old_item.replace('in_layers.0', 'norm1')
47
+ new_item = new_item.replace('in_layers.2', 'conv1')
48
+
49
+ new_item = new_item.replace('out_layers.0', 'norm2')
50
+ new_item = new_item.replace('out_layers.3', 'conv2')
51
+
52
+ new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
53
+ new_item = new_item.replace('skip_connection', 'conv_shortcut')
54
+
55
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
56
+
57
+ mapping.append({'old': old_item, 'new': new_item})
58
+
59
+ return mapping
60
+
61
+
62
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
63
+ """
64
+ Updates paths inside resnets to the new naming scheme (local renaming)
65
+ """
66
+ mapping = []
67
+ for old_item in old_list:
68
+ new_item = old_item
69
+
70
+ new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
71
+
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
+
74
+ mapping.append({'old': old_item, 'new': new_item})
75
+
76
+ return mapping
77
+
78
+
79
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside attentions to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
88
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
89
+
90
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
91
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
92
+
93
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
94
+
95
+ mapping.append({'old': old_item, 'new': new_item})
96
+
97
+ return mapping
98
+
99
+
100
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
101
+ """
102
+ Updates paths inside attentions to the new naming scheme (local renaming)
103
+ """
104
+ mapping = []
105
+ for old_item in old_list:
106
+ new_item = old_item
107
+
108
+ new_item = new_item.replace('norm.weight', 'group_norm.weight')
109
+ new_item = new_item.replace('norm.bias', 'group_norm.bias')
110
+
111
+ new_item = new_item.replace('q.weight', 'query.weight')
112
+ new_item = new_item.replace('q.bias', 'query.bias')
113
+
114
+ new_item = new_item.replace('k.weight', 'key.weight')
115
+ new_item = new_item.replace('k.bias', 'key.bias')
116
+
117
+ new_item = new_item.replace('v.weight', 'value.weight')
118
+ new_item = new_item.replace('v.bias', 'value.bias')
119
+
120
+ new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
121
+ new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
122
+
123
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
124
+
125
+ mapping.append({'old': old_item, 'new': new_item})
126
+
127
+ return mapping
128
+
129
+
130
+ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
131
+ """
132
+ This does the final conversion step: take locally converted weights and apply a global renaming
133
+ to them. It splits attention layers, and takes into account additional replacements
134
+ that may arise.
135
+
136
+ Assigns the weights to the new checkpoint.
137
+ """
138
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
139
+
140
+ # Splits the attention layers into three variables.
141
+ if attention_paths_to_split is not None:
142
+ for path, path_map in attention_paths_to_split.items():
143
+ old_tensor = old_checkpoint[path]
144
+ channels = old_tensor.shape[0] // 3
145
+
146
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
147
+
148
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
149
+
150
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
151
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
152
+
153
+ checkpoint[path_map['query']] = query.reshape(target_shape)
154
+ checkpoint[path_map['key']] = key.reshape(target_shape)
155
+ checkpoint[path_map['value']] = value.reshape(target_shape)
156
+
157
+ for path in paths:
158
+ new_path = path['new']
159
+
160
+ # These have already been assigned
161
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
162
+ continue
163
+
164
+ # Global renaming happens here
165
+ new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0')
166
+ new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0')
167
+ new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1')
168
+
169
+ if additional_replacements is not None:
170
+ for replacement in additional_replacements:
171
+ new_path = new_path.replace(replacement['old'], replacement['new'])
172
+
173
+ # proj_attn.weight has to be converted from conv 1D to linear
174
+ if "proj_attn.weight" in new_path:
175
+ checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
176
+ else:
177
+ checkpoint[new_path] = old_checkpoint[path['old']]
178
+
179
+
180
+ def conv_attn_to_linear(checkpoint):
181
+ keys = list(checkpoint.keys())
182
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
183
+ for key in keys:
184
+ if ".".join(key.split(".")[-2:]) in attn_keys:
185
+ if checkpoint[key].ndim > 2:
186
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
187
+ elif "proj_attn.weight" in key:
188
+ if checkpoint[key].ndim > 2:
189
+ checkpoint[key] = checkpoint[key][:, :, 0]
190
+
191
+
192
+ def create_unet_diffusers_config(original_config):
193
+ """
194
+ Creates a config for the diffusers based on the config of the LDM model.
195
+ """
196
+ unet_params = original_config.model.params.unet_config.params
197
+
198
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
199
+
200
+ down_block_types = []
201
+ resolution = 1
202
+ for i in range(len(block_out_channels)):
203
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
204
+ down_block_types.append(block_type)
205
+ if i != len(block_out_channels) - 1:
206
+ resolution *= 2
207
+
208
+ up_block_types = []
209
+ for i in range(len(block_out_channels)):
210
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
211
+ up_block_types.append(block_type)
212
+ resolution //= 2
213
+
214
+ config = dict(
215
+ sample_size=unet_params.image_size,
216
+ in_channels=unet_params.in_channels,
217
+ out_channels=unet_params.out_channels,
218
+ down_block_types=tuple(down_block_types),
219
+ up_block_types=tuple(up_block_types),
220
+ block_out_channels=tuple(block_out_channels),
221
+ layers_per_block=unet_params.num_res_blocks,
222
+ cross_attention_dim=unet_params.context_dim,
223
+ attention_head_dim=unet_params.num_heads,
224
+ )
225
+
226
+ return config
227
+
228
+
229
+ def create_vae_diffusers_config(original_config):
230
+ """
231
+ Creates a config for the diffusers based on the config of the LDM model.
232
+ """
233
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
234
+ latent_channels = original_config.model.params.first_stage_config.params.embed_dim
235
+
236
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
237
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
238
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
239
+
240
+ config = dict(
241
+ sample_size=vae_params.resolution,
242
+ in_channels=vae_params.in_channels,
243
+ out_channels=vae_params.out_ch,
244
+ down_block_types=tuple(down_block_types),
245
+ up_block_types=tuple(up_block_types),
246
+ block_out_channels=tuple(block_out_channels),
247
+ latent_channels=vae_params.z_channels,
248
+ layers_per_block=vae_params.num_res_blocks,
249
+ )
250
+ return config
251
+
252
+
253
+ def create_diffusers_schedular(original_config):
254
+ schedular = DDIMScheduler(
255
+ num_train_timesteps=original_config.model.params.timesteps,
256
+ beta_start=original_config.model.params.linear_start,
257
+ beta_end=original_config.model.params.linear_end,
258
+ beta_schedule="scaled_linear",
259
+ )
260
+ return schedular
261
+
262
+
263
+ def create_ldm_bert_config(original_config):
264
+ bert_params = original_config.model.params.cond_stage_config.params
265
+ config = LDMBertConfig(
266
+ d_model=bert_params.n_embed,
267
+ encoder_layers=bert_params.n_layer,
268
+ encoder_ffn_dim=bert_params.n_embed * 4,
269
+ )
270
+ return config
271
+
272
+
273
+ def convert_ldm_unet_checkpoint(checkpoint, config):
274
+ """
275
+ Takes a state dict and a config, and returns a converted checkpoint.
276
+ """
277
+
278
+ # extract state_dict for UNet
279
+ unet_state_dict = {}
280
+ unet_key = "model.diffusion_model."
281
+ keys = list(checkpoint.keys())
282
+ for key in keys:
283
+ if key.startswith(unet_key):
284
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
285
+
286
+ new_checkpoint = {}
287
+
288
+ new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict['time_embed.0.weight']
289
+ new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict['time_embed.0.bias']
290
+ new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict['time_embed.2.weight']
291
+ new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict['time_embed.2.bias']
292
+
293
+ new_checkpoint['conv_in.weight'] = unet_state_dict['input_blocks.0.0.weight']
294
+ new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']
295
+
296
+ new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
297
+ new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
298
+ new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
299
+ new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']
300
+
301
+ # Retrieves the keys for the input blocks only
302
+ num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'input_blocks' in layer})
303
+ input_blocks = {layer_id: [key for key in unet_state_dict if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
304
+
305
+ # Retrieves the keys for the middle blocks only
306
+ num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'middle_block' in layer})
307
+ middle_blocks = {layer_id: [key for key in unet_state_dict if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
308
+
309
+ # Retrieves the keys for the output blocks only
310
+ num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'output_blocks' in layer})
311
+ output_blocks = {layer_id: [key for key in unet_state_dict if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
312
+
313
+ for i in range(1, num_input_blocks):
314
+ block_id = (i - 1) // (config['layers_per_block'] + 1)
315
+ layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)
316
+
317
+ resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key and f'input_blocks.{i}.0.op' not in key]
318
+ attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
319
+
320
+ if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
321
+ new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.weight'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
322
+ new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.bias'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')
323
+
324
+ paths = renew_resnet_paths(resnets)
325
+ meta_path = {'old': f'input_blocks.{i}.0', 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}'}
326
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
327
+
328
+ if len(attentions):
329
+ paths = renew_attention_paths(attentions)
330
+ meta_path = {'old': f'input_blocks.{i}.1', 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}'}
331
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
332
+
333
+
334
+ resnet_0 = middle_blocks[0]
335
+ attentions = middle_blocks[1]
336
+ resnet_1 = middle_blocks[2]
337
+
338
+ resnet_0_paths = renew_resnet_paths(resnet_0)
339
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
340
+
341
+ resnet_1_paths = renew_resnet_paths(resnet_1)
342
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
343
+
344
+ attentions_paths = renew_attention_paths(attentions)
345
+ meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
346
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
347
+
348
+ for i in range(num_output_blocks):
349
+ block_id = i // (config['layers_per_block'] + 1)
350
+ layer_in_block_id = i % (config['layers_per_block'] + 1)
351
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
352
+ output_block_list = {}
353
+
354
+ for layer in output_block_layers:
355
+ layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
356
+ if layer_id in output_block_list:
357
+ output_block_list[layer_id].append(layer_name)
358
+ else:
359
+ output_block_list[layer_id] = [layer_name]
360
+
361
+ if len(output_block_list) > 1:
362
+ resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
363
+ attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
364
+
365
+ resnet_0_paths = renew_resnet_paths(resnets)
366
+ paths = renew_resnet_paths(resnets)
367
+
368
+ meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
369
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
370
+
371
+ if ['conv.weight', 'conv.bias'] in output_block_list.values():
372
+ index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
373
+ new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
374
+ new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']
375
+
376
+ # Clear attentions as they have been attributed above.
377
+ if len(attentions) == 2:
378
+ attentions = []
379
+
380
+ if len(attentions):
381
+ paths = renew_attention_paths(attentions)
382
+ meta_path = {
383
+ 'old': f'output_blocks.{i}.1',
384
+ 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
385
+ }
386
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
387
+ else:
388
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
389
+ for path in resnet_0_paths:
390
+ old_path = '.'.join(['output_blocks', str(i), path['old']])
391
+ new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
392
+
393
+ new_checkpoint[new_path] = unet_state_dict[old_path]
394
+
395
+ return new_checkpoint
396
+
397
+
398
+ def convert_ldm_vae_checkpoint(checkpoint, config):
399
+ # extract state dict for VAE
400
+ vae_state_dict = {}
401
+ vae_key = "first_stage_model."
402
+ keys = list(checkpoint.keys())
403
+ for key in keys:
404
+ if key.startswith(vae_key):
405
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
406
+
407
+ new_checkpoint = {}
408
+
409
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
410
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
411
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
412
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
413
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
414
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
415
+
416
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
417
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
418
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
419
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
420
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
421
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
422
+
423
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
424
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
425
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
426
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
427
+
428
+
429
+ # Retrieves the keys for the encoder down blocks only
430
+ num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer})
431
+ down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
432
+
433
+ # Retrieves the keys for the decoder up blocks only
434
+ num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer})
435
+ up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
436
+
437
+
438
+ for i in range(num_down_blocks):
439
+ resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key]
440
+
441
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
442
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
443
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
444
+
445
+ paths = renew_vae_resnet_paths(resnets)
446
+ meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'}
447
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
448
+
449
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
450
+ num_mid_res_blocks = 2
451
+ for i in range(1, num_mid_res_blocks + 1):
452
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
453
+
454
+ paths = renew_vae_resnet_paths(resnets)
455
+ meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
456
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
457
+
458
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
459
+ paths = renew_vae_attention_paths(mid_attentions)
460
+ meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
461
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
462
+ conv_attn_to_linear(new_checkpoint)
463
+
464
+ for i in range(num_up_blocks):
465
+ block_id = num_up_blocks - 1 - i
466
+ resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key]
467
+
468
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
469
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
470
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
471
+
472
+ paths = renew_vae_resnet_paths(resnets)
473
+ meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'}
474
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
475
+
476
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
477
+ num_mid_res_blocks = 2
478
+ for i in range(1, num_mid_res_blocks + 1):
479
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
480
+
481
+ paths = renew_vae_resnet_paths(resnets)
482
+ meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
483
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
484
+
485
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
486
+ paths = renew_vae_attention_paths(mid_attentions)
487
+ meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
488
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
489
+ conv_attn_to_linear(new_checkpoint)
490
+ return new_checkpoint
491
+
492
+
493
+ def convert_ldm_bert_checkpoint(checkpoint, config):
494
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
495
+
496
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
497
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
498
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
499
+
500
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
501
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
502
+
503
+
504
+ def _copy_linear(hf_linear, pt_linear):
505
+ hf_linear.weight = pt_linear.weight
506
+ hf_linear.bias = pt_linear.bias
507
+
508
+
509
+ def _copy_layer(hf_layer, pt_layer):
510
+ # copy layer norms
511
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
512
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
513
+
514
+ # copy attn
515
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
516
+
517
+ # copy MLP
518
+ pt_mlp = pt_layer[1][1]
519
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
520
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
521
+
522
+
523
+ def _copy_layers(hf_layers, pt_layers):
524
+ for i, hf_layer in enumerate(hf_layers):
525
+ if i != 0: i += i
526
+ pt_layer = pt_layers[i:i+2]
527
+ _copy_layer(hf_layer, pt_layer)
528
+
529
+ hf_model = LDMBertModel(config).eval()
530
+
531
+ # copy embeds
532
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
533
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
534
+
535
+ # copy layer norm
536
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
537
+
538
+ # copy hidden layers
539
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
540
+
541
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
542
+
543
+ return hf_model
544
+
545
+
546
+
547
+ if __name__ == "__main__":
548
+ parser = argparse.ArgumentParser()
549
+
550
+ parser.add_argument(
551
+ "checkpoint_path", default='./model.ckpt', type=str, help="Path to the checkpoint to convert."
552
+ )
553
+
554
+
555
+ parser.add_argument(
556
+ "dump_path", default='./model', type=str, help="Path to the output model."
557
+ )
558
+
559
+ parser.add_argument(
560
+ "--original_config_file",
561
+ default='./ckpt_models/model.yaml',
562
+ type=str,
563
+ required=False,
564
+ help="The YAML config file corresponding to the original architecture.",
565
+ )
566
+
567
+ args = parser.parse_args()
568
+
569
+ original_config = OmegaConf.load(args.original_config_file)
570
+
571
+ checkpoint = torch.load(args.checkpoint_path)["state_dict"]
572
+
573
+ # Convert the UNet2DConditionModel model.
574
+ unet_config = create_unet_diffusers_config(original_config)
575
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
576
+
577
+ unet = UNet2DConditionModel(**unet_config)
578
+ unet.load_state_dict(converted_unet_checkpoint)
579
+
580
+ # Convert the VAE model.
581
+ vae_config = create_vae_diffusers_config(original_config)
582
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
583
+
584
+ vae = AutoencoderKL(**vae_config)
585
+ vae.load_state_dict(converted_vae_checkpoint)
586
+
587
+
588
+
589
+ # Convert the text model.
590
+ text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
591
+
592
+ script_path = os.path.realpath(__file__)
593
+ default_model_path = os.path.join(os.path.dirname(script_path), "diffusers-models")
594
+
595
+ try:
596
+ text_model = CLIPTextModel.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
597
+ tokenizer = CLIPTokenizer.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14"))
598
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(os.path.join(default_model_path, "safety-checker"))
599
+
600
+ except Exception as e:
601
+ print(e)
602
+ print("Could not load the default text model. Auto downloading...")
603
+ if text_model_type == "FrozenCLIPEmbedder":
604
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
605
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
606
+ else:
607
+ # TODO: update the convert function to use the state_dict without the model instance.
608
+ text_config = create_ldm_bert_config(original_config)
609
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
610
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
611
+
612
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker')
613
+
614
+ scheduler = create_diffusers_schedular(original_config)
615
+
616
+ scheduler = create_diffusers_schedular(original_config)
617
+ feature_extractor = CLIPFeatureExtractor()
618
+ pipe = StableDiffusionPipeline(vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor)
619
+ pipe.save_pretrained(args.dump_path)
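Once the script above has written the diffusers-format folder to dump_path, that folder can be loaded directly for inference. A minimal sketch, assuming the output directory is ./model and a CUDA GPU is available:

import torch
from diffusers import StableDiffusionPipeline

# "./model" is the dump_path passed to ckpt2diffusers.py; adjust to your own output folder.
pipe = StableDiffusionPipeline.from_pretrained("./model", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("sample.png")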
dreambooth-for-diffusion/tools/ckpt_merge.py ADDED
@@ -0,0 +1,56 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ parser = argparse.ArgumentParser(description="Merge two models")
7
+ parser.add_argument("model_0", type=str, help="Path to model 0")
8
+ parser.add_argument("model_1", type=str, help="Path to model 1")
9
+ parser.add_argument("--alpha", type=float, help="Alpha value, optional, defaults to 0.5", default=0.5, required=False)
10
+ parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False)
11
+ parser.add_argument("--device", type=str, help="Device to use, defaults to cpu", default="cpu", required=False)
12
+ parser.add_argument("--without_vae", action="store_true", help="Do not merge VAE", required=False)
13
+
14
+ args = parser.parse_args()
15
+
16
+ device = args.device
17
+ model_0 = torch.load(args.model_0, map_location=device)
18
+ model_1 = torch.load(args.model_1, map_location=device)
19
+ theta_0 = model_0["state_dict"]
20
+ theta_1 = model_1["state_dict"]
21
+ alpha = args.alpha
22
+
23
+ output_file = f'{args.output}-{str(alpha)[2:] + "0"}.ckpt'
24
+
25
+ # check if output file already exists, ask to overwrite
26
+ if os.path.isfile(output_file):
27
+ print("Output file already exists. Overwrite? (y/n)")
28
+ while True:
29
+ overwrite = input()
30
+ if overwrite == "y":
31
+ break
32
+ elif overwrite == "n":
33
+ print("Exiting...")
34
+ exit()
35
+ else:
36
+ print("Please enter y or n")
37
+
38
+
39
+ for key in tqdm(theta_0.keys(), desc="Stage 1/2"):
40
+ # Optionally skip the VAE parameters: for anime models, merging the VAE
41
+ # tends to make results worse (darker and blurrier), hence the --without_vae flag.
42
+ if args.without_vae and "first_stage_model" in key:
43
+ continue
44
+
45
+ if "model" in key and key in theta_1:
46
+ theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]
47
+
48
+ for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
49
+ if "model" in key and key not in theta_0:
50
+ theta_0[key] = theta_1[key]
51
+
52
+ print("Saving...")
53
+
54
+ torch.save({"state_dict": theta_0}, output_file)
55
+
56
+ print("Done!")
dreambooth-for-diffusion/tools/ckpt_prune.py ADDED
@@ -0,0 +1,14 @@
1
+ sd = torch.load(model_path, map_location="cpu")
2
+ if "state_dict" not in sd:
3
+ pruned_sd = {
4
+ "state_dict": dict(),
5
+ }
6
+ else:
7
+ pruned_sd = dict()
8
+ for k in sd.keys():
9
+ if k != "optimizer_states":
10
+ if "state_dict" not in sd:
11
+ pruned_sd["state_dict"][k] = sd[k]
12
+ else:
13
+ pruned_sd[k] = sd[k]
14
+ torch.save(pruned_sd, "model-pruned.ckpt")
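As listed, ckpt_prune.py has no imports and model_path is never defined, so it reads as a snippet to adapt rather than a runnable tool. A self-contained version might look like the following sketch; taking the checkpoint path from the command line is an assumption, not something the original file specifies:

import sys
import torch

# Hypothetical CLI argument; the original snippet leaves model_path undefined.
model_path = sys.argv[1] if len(sys.argv) > 1 else "model.ckpt"

sd = torch.load(model_path, map_location="cpu")
if "state_dict" not in sd:
    pruned_sd = {"state_dict": dict()}
else:
    pruned_sd = dict()
for k in sd.keys():
    # Drop the optimizer state, which is what makes training checkpoints so large.
    if k != "optimizer_states":
        if "state_dict" not in sd:
            pruned_sd["state_dict"][k] = sd[k]
        else:
            pruned_sd[k] = sd[k]
torch.save(pruned_sd, "model-pruned.ckpt")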
dreambooth-for-diffusion/tools/deepdanbooru-models/put_deepdanbooru_model_here.txt ADDED
File without changes
dreambooth-for-diffusion/tools/diagnose_tensorboard.py ADDED
@@ -0,0 +1,570 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Self-diagnosis script for TensorBoard.
16
+
17
+ Instructions: Save this script to your local machine, then execute it in
18
+ the same environment (virtualenv, Conda, etc.) from which you normally
19
+ run TensorBoard. Read the output and follow the directions.
20
+ """
21
+
22
+
23
+ # This script may only depend on the Python standard library. It is not
24
+ # built with Bazel and should not assume any third-party dependencies.
25
+ import dataclasses
26
+ import errno
27
+ import functools
28
+ import hashlib
29
+ import inspect
30
+ import logging
31
+ import os
32
+ import pipes
33
+ import shlex
34
+ import socket
35
+ import subprocess
36
+ import sys
37
+ import tempfile
38
+ import textwrap
39
+ import traceback
40
+
41
+
42
+ # A *check* is a function (of no arguments) that performs a diagnostic,
43
+ # writes log messages, and optionally yields suggestions. Each check
44
+ # runs in isolation; exceptions will be caught and reported.
45
+ CHECKS = []
46
+
47
+
48
+ @dataclasses.dataclass(frozen=True)
49
+ class Suggestion:
50
+ """A suggestion to the end user.
51
+
52
+ Attributes:
53
+ headline: A short description, like "Turn it off and on again". Should be
54
+ imperative with no trailing punctuation. May contain inline Markdown.
55
+ description: A full enumeration of the steps that the user should take to
56
+ accept the suggestion. Within this string, prose should be formatted
57
+ with `reflow`. May contain Markdown.
58
+ """
59
+
60
+ headline: str
61
+ description: str
62
+
63
+
64
+ def check(fn):
65
+ """Decorator to register a function as a check.
66
+
67
+ Checks are run in the order in which they are registered.
68
+
69
+ Args:
70
+ fn: A function that takes no arguments and either returns `None` or
71
+ returns a generator of `Suggestion`s. (The ability to return
72
+ `None` is to work around the awkwardness of defining empty
73
+ generator functions in Python.)
74
+
75
+ Returns:
76
+ A wrapped version of `fn` that returns a generator of `Suggestion`s.
77
+ """
78
+
79
+ @functools.wraps(fn)
80
+ def wrapper():
81
+ result = fn()
82
+ return iter(()) if result is None else result
83
+
84
+ CHECKS.append(wrapper)
85
+ return wrapper
86
+
87
+
88
+ def reflow(paragraph):
89
+ return textwrap.fill(textwrap.dedent(paragraph).strip())
90
+
91
+
92
+ def pip(args):
93
+ """Invoke command-line Pip with the specified args.
94
+
95
+ Returns:
96
+ A bytestring containing the output of Pip.
97
+ """
98
+ # Suppress the Python 2.7 deprecation warning.
99
+ PYTHONWARNINGS_KEY = "PYTHONWARNINGS"
100
+ old_pythonwarnings = os.environ.get(PYTHONWARNINGS_KEY)
101
+ new_pythonwarnings = "%s%s" % (
102
+ "ignore:DEPRECATION",
103
+ ",%s" % old_pythonwarnings if old_pythonwarnings else "",
104
+ )
105
+ command = [sys.executable, "-m", "pip", "--disable-pip-version-check"]
106
+ command.extend(args)
107
+ try:
108
+ os.environ[PYTHONWARNINGS_KEY] = new_pythonwarnings
109
+ return subprocess.check_output(command)
110
+ finally:
111
+ if old_pythonwarnings is None:
112
+ del os.environ[PYTHONWARNINGS_KEY]
113
+ else:
114
+ os.environ[PYTHONWARNINGS_KEY] = old_pythonwarnings
115
+
116
+
117
+ def which(name):
118
+ """Return the path to a binary, or `None` if it's not on the path.
119
+
120
+ Returns:
121
+ A bytestring.
122
+ """
123
+ binary = "where" if os.name == "nt" else "which"
124
+ try:
125
+ return subprocess.check_output([binary, name])
126
+ except subprocess.CalledProcessError:
127
+ return None
128
+
129
+
130
+ def sgetattr(attr, default):
131
+ """Get an attribute off the `socket` module, or use a default."""
132
+ sentinel = object()
133
+ result = getattr(socket, attr, sentinel)
134
+ if result is sentinel:
135
+ print("socket.%s does not exist" % attr)
136
+ return default
137
+ else:
138
+ print("socket.%s = %r" % (attr, result))
139
+ return result
140
+
141
+
142
+ @check
143
+ def autoidentify():
144
+ """Print the Git hash of this version of `diagnose_tensorboard.py`.
145
+
146
+ Given this hash, use `git cat-file blob HASH` to recover the
147
+ relevant version of the script.
148
+ """
149
+ module = sys.modules[__name__]
150
+ try:
151
+ source = inspect.getsource(module).encode("utf-8")
152
+ except TypeError:
153
+ logging.info("diagnose_tensorboard.py source unavailable")
154
+ else:
155
+ # Git inserts a length-prefix before hashing; cf. `git-hash-object`.
156
+ blob = b"blob %d\0%s" % (len(source), source)
157
+ hash = hashlib.sha1(blob).hexdigest()
158
+ logging.info("diagnose_tensorboard.py version %s", hash)
159
+
160
+
161
+ @check
162
+ def general():
163
+ logging.info("sys.version_info: %s", sys.version_info)
164
+ logging.info("os.name: %s", os.name)
165
+ na = type("N/A", (object,), {"__repr__": lambda self: "N/A"})
166
+ logging.info(
167
+ "os.uname(): %r",
168
+ getattr(os, "uname", na)(),
169
+ )
170
+ logging.info(
171
+ "sys.getwindowsversion(): %r",
172
+ getattr(sys, "getwindowsversion", na)(),
173
+ )
174
+
175
+
176
+ @check
177
+ def package_management():
178
+ conda_meta = os.path.join(sys.prefix, "conda-meta")
179
+ logging.info("has conda-meta: %s", os.path.exists(conda_meta))
180
+ logging.info("$VIRTUAL_ENV: %r", os.environ.get("VIRTUAL_ENV"))
181
+
182
+
183
+ @check
184
+ def installed_packages():
185
+ freeze = pip(["freeze", "--all"]).decode("utf-8").splitlines()
186
+ packages = {line.split("==")[0]: line for line in freeze}
187
+ packages_set = frozenset(packages)
188
+
189
+ # For each of the following families, expect exactly one package to be
190
+ # installed.
191
+ expect_unique = [
192
+ frozenset(
193
+ [
194
+ "tensorboard",
195
+ "tb-nightly",
196
+ "tensorflow-tensorboard",
197
+ ]
198
+ ),
199
+ frozenset(
200
+ [
201
+ "tensorflow",
202
+ "tensorflow-gpu",
203
+ "tf-nightly",
204
+ "tf-nightly-2.0-preview",
205
+ "tf-nightly-gpu",
206
+ "tf-nightly-gpu-2.0-preview",
207
+ ]
208
+ ),
209
+ frozenset(
210
+ [
211
+ "tensorflow-estimator",
212
+ "tensorflow-estimator-2.0-preview",
213
+ "tf-estimator-nightly",
214
+ ]
215
+ ),
216
+ ]
217
+ salient_extras = frozenset(["tensorboard-data-server"])
218
+
219
+ found_conflict = False
220
+ for family in expect_unique:
221
+ actual = family & packages_set
222
+ for package in actual:
223
+ logging.info("installed: %s", packages[package])
224
+ if len(actual) == 0:
225
+ logging.warning("no installation among: %s", sorted(family))
226
+ elif len(actual) > 1:
227
+ logging.warning("conflicting installations: %s", sorted(actual))
228
+ found_conflict = True
229
+ for package in sorted(salient_extras & packages_set):
230
+ logging.info("installed: %s", packages[package])
231
+
232
+ if found_conflict:
233
+ preamble = reflow(
234
+ """
235
+ Conflicting package installations found. Depending on the order
236
+ of installations and uninstallations, behavior may be undefined.
237
+ Please uninstall ALL versions of TensorFlow and TensorBoard,
238
+ then reinstall ONLY the desired version of TensorFlow, which
239
+ will transitively pull in the proper version of TensorBoard. (If
240
+ you use TensorBoard without TensorFlow, just reinstall the
241
+ appropriate version of TensorBoard directly.)
242
+ """
243
+ )
244
+ packages_to_uninstall = sorted(
245
+ frozenset().union(*expect_unique) & packages_set
246
+ )
247
+ commands = [
248
+ "pip uninstall %s" % " ".join(packages_to_uninstall),
249
+ "pip install tensorflow # or `tensorflow-gpu`, or `tf-nightly`, ...",
250
+ ]
251
+ message = "%s\n\nNamely:\n\n%s" % (
252
+ preamble,
253
+ "\n".join("\t%s" % c for c in commands),
254
+ )
255
+ yield Suggestion("Fix conflicting installations", message)
256
+
257
+ wit_version = packages.get("tensorboard-plugin-wit")
258
+ if wit_version == "tensorboard-plugin-wit==1.6.0.post2":
259
+ # This is only incompatible with TensorBoard prior to 2.2.0, but
260
+ # we just issue a blanket warning so that we don't have to pull
261
+ # in a `pkg_resources` dep to parse the version.
262
+ preamble = reflow(
263
+ """
264
+ Versions of the What-If Tool (`tensorboard-plugin-wit`)
265
+ prior to 1.6.0.post3 are incompatible with some versions of
266
+ TensorBoard. Please upgrade this package to the latest
267
+ version to resolve any startup errors:
268
+ """
269
+ )
270
+ command = "pip install -U tensorboard-plugin-wit"
271
+ message = "%s\n\n\t%s" % (preamble, command)
272
+ yield Suggestion("Upgrade `tensorboard-plugin-wit`", message)
273
+
274
+
275
+ @check
276
+ def tensorboard_python_version():
277
+ from tensorboard import version
278
+
279
+ logging.info("tensorboard.version.VERSION: %r", version.VERSION)
280
+
281
+
282
+ @check
283
+ def tensorflow_python_version():
284
+ import tensorflow as tf
285
+
286
+ logging.info("tensorflow.__version__: %r", tf.__version__)
287
+ logging.info("tensorflow.__git_version__: %r", tf.__git_version__)
288
+
289
+
290
+ @check
291
+ def tensorboard_data_server_version():
292
+ try:
293
+ import tensorboard_data_server
294
+ except ImportError:
295
+ logging.info("no data server installed")
296
+ return
297
+
298
+ path = tensorboard_data_server.server_binary()
299
+ logging.info("data server binary: %r", path)
300
+ if path is None:
301
+ return
302
+
303
+ try:
304
+ subprocess_output = subprocess.run(
305
+ [path, "--version"],
306
+ capture_output=True,
307
+ check=True,
308
+ )
309
+ except subprocess.CalledProcessError as e:
310
+ logging.info("failed to check binary version: %s", e)
311
+ else:
312
+ logging.info(
313
+ "data server binary version: %s", subprocess_output.stdout.strip()
314
+ )
315
+
316
+
317
+ @check
318
+ def tensorboard_binary_path():
319
+ logging.info("which tensorboard: %r", which("tensorboard"))
320
+
321
+
322
+ @check
323
+ def addrinfos():
324
+ sgetattr("has_ipv6", None)
325
+ family = sgetattr("AF_UNSPEC", 0)
326
+ socktype = sgetattr("SOCK_STREAM", 0)
327
+ protocol = 0
328
+ flags_loopback = sgetattr("AI_ADDRCONFIG", 0)
329
+ flags_wildcard = sgetattr("AI_PASSIVE", 0)
330
+
331
+ hints_loopback = (family, socktype, protocol, flags_loopback)
332
+ infos_loopback = socket.getaddrinfo(None, 0, *hints_loopback)
333
+ print("Loopback flags: %r" % (flags_loopback,))
334
+ print("Loopback infos: %r" % (infos_loopback,))
335
+
336
+ hints_wildcard = (family, socktype, protocol, flags_wildcard)
337
+ infos_wildcard = socket.getaddrinfo(None, 0, *hints_wildcard)
338
+ print("Wildcard flags: %r" % (flags_wildcard,))
339
+ print("Wildcard infos: %r" % (infos_wildcard,))
340
+
341
+
342
+ @check
343
+ def readable_fqdn():
344
+ # May raise `UnicodeDecodeError` for non-ASCII hostnames:
345
+ # https://github.com/tensorflow/tensorboard/issues/682
346
+ try:
347
+ logging.info("socket.getfqdn(): %r", socket.getfqdn())
348
+ except UnicodeDecodeError as e:
349
+ try:
350
+ binary_hostname = subprocess.check_output(["hostname"]).strip()
351
+ except subprocess.CalledProcessError:
352
+ binary_hostname = b"<unavailable>"
353
+ is_non_ascii = not all(
354
+ 0x20
355
+ <= (ord(c) if not isinstance(c, int) else c)
356
+ <= 0x7E # Python 2
357
+ for c in binary_hostname
358
+ )
359
+ if is_non_ascii:
360
+ message = reflow(
361
+ """
362
+ Your computer's hostname, %r, contains bytes outside of the
363
+ printable ASCII range. Some versions of Python have trouble
364
+ working with such names (https://bugs.python.org/issue26227).
365
+ Consider changing to a hostname that only contains printable
366
+ ASCII bytes.
367
+ """
368
+ % (binary_hostname,)
369
+ )
370
+ yield Suggestion("Use an ASCII hostname", message)
371
+ else:
372
+ message = reflow(
373
+ """
374
+ Python can't read your computer's hostname, %r. This can occur
375
+ if the hostname contains non-ASCII bytes
376
+ (https://bugs.python.org/issue26227). Consider changing your
377
+ hostname, rebooting your machine, and rerunning this diagnosis
378
+ script to see if the problem is resolved.
379
+ """
380
+ % (binary_hostname,)
381
+ )
382
+ yield Suggestion("Use a simpler hostname", message)
383
+ raise e
384
+
385
+
386
+ @check
387
+ def stat_tensorboardinfo():
388
+ # We don't use `manager._get_info_dir`, because (a) that requires
389
+ # TensorBoard, and (b) that creates the directory if it doesn't exist.
390
+ path = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
391
+ logging.info("directory: %s", path)
392
+ try:
393
+ stat_result = os.stat(path)
394
+ except OSError as e:
395
+ if e.errno == errno.ENOENT:
396
+ # No problem; this is just fine.
397
+ logging.info(".tensorboard-info directory does not exist")
398
+ return
399
+ else:
400
+ raise
401
+ logging.info("os.stat(...): %r", stat_result)
402
+ logging.info("mode: 0o%o", stat_result.st_mode)
403
+ if stat_result.st_mode & 0o777 != 0o777:
404
+ preamble = reflow(
405
+ """
406
+ The ".tensorboard-info" directory was created by an old version
407
+ of TensorBoard, and its permissions are not set correctly; see
408
+ issue #2010. Change that directory to be world-accessible (may
409
+ require superuser privilege):
410
+ """
411
+ )
412
+ # This error should only appear on Unices, so it's okay to use
413
+ # Unix-specific utilities and shell syntax.
414
+ quote = getattr(shlex, "quote", None) or pipes.quote # Python <3.3
415
+ command = "chmod 777 %s" % quote(path)
416
+ message = "%s\n\n\t%s" % (preamble, command)
417
+ yield Suggestion('Fix permissions on "%s"' % path, message)
418
+
419
+
420
+ @check
421
+ def source_trees_without_genfiles():
422
+ roots = list(sys.path)
423
+ if "" not in roots:
424
+ # Catch problems that would occur in a Python interactive shell
425
+ # (where `""` is prepended to `sys.path`) but not when
426
+ # `diagnose_tensorboard.py` is run as a standalone script.
427
+ roots.insert(0, "")
428
+
429
+ def has_tensorboard(root):
430
+ return os.path.isfile(os.path.join(root, "tensorboard", "__init__.py"))
431
+
432
+ def has_genfiles(root):
433
+ sample_genfile = os.path.join("compat", "proto", "summary_pb2.py")
434
+ return os.path.isfile(os.path.join(root, "tensorboard", sample_genfile))
435
+
436
+ def is_bad(root):
437
+ return has_tensorboard(root) and not has_genfiles(root)
438
+
439
+ tensorboard_roots = [root for root in roots if has_tensorboard(root)]
440
+ bad_roots = [root for root in roots if is_bad(root)]
441
+
442
+ logging.info(
443
+ "tensorboard_roots (%d): %r; bad_roots (%d): %r",
444
+ len(tensorboard_roots),
445
+ tensorboard_roots,
446
+ len(bad_roots),
447
+ bad_roots,
448
+ )
449
+
450
+ if bad_roots:
451
+ if bad_roots == [""]:
452
+ message = reflow(
453
+ """
454
+ Your current directory contains a `tensorboard` Python package
455
+ that does not include generated files. This can happen if your
456
+ current directory includes the TensorBoard source tree (e.g.,
457
+ you are in the TensorBoard Git repository). Consider changing
458
+ to a different directory.
459
+ """
460
+ )
461
+ else:
462
+ preamble = reflow(
463
+ """
464
+ Your Python path contains a `tensorboard` package that does
465
+ not include generated files. This can happen if your current
466
+ directory includes the TensorBoard source tree (e.g., you are
467
+ in the TensorBoard Git repository). The following directories
468
+ from your Python path may be problematic:
469
+ """
470
+ )
471
+ roots = []
472
+ realpaths_seen = set()
473
+ for root in bad_roots:
474
+ label = repr(root) if root else "current directory"
475
+ realpath = os.path.realpath(root)
476
+ if realpath in realpaths_seen:
477
+ # virtualenvs on Ubuntu install to both `lib` and `local/lib`;
478
+ # explicitly call out such duplicates to avoid confusion.
479
+ label += " (duplicate underlying directory)"
480
+ realpaths_seen.add(realpath)
481
+ roots.append(label)
482
+ message = "%s\n\n%s" % (
483
+ preamble,
484
+ "\n".join(" - %s" % s for s in roots),
485
+ )
486
+ yield Suggestion(
487
+ "Avoid `tensorboard` packages without genfiles", message
488
+ )
489
+
490
+
491
+ # Prefer to include this check last, as its output is long.
492
+ @check
493
+ def full_pip_freeze():
494
+ logging.info(
495
+ "pip freeze --all:\n%s", pip(["freeze", "--all"]).decode("utf-8")
496
+ )
497
+
498
+
499
+ def set_up_logging():
500
+ # Manually install handlers to prevent TensorFlow from stomping the
501
+ # default configuration if it's imported:
502
+ # https://github.com/tensorflow/tensorflow/issues/28147
503
+ logger = logging.getLogger()
504
+ logger.setLevel(logging.INFO)
505
+ handler = logging.StreamHandler(sys.stdout)
506
+ handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
507
+ logger.addHandler(handler)
508
+
509
+
510
+ def main():
511
+ set_up_logging()
512
+
513
+ print("### Diagnostics")
514
+ print()
515
+
516
+ print("<details>")
517
+ print("<summary>Diagnostics output</summary>")
518
+ print()
519
+
520
+ markdown_code_fence = "``````" # seems likely to be sufficient
521
+ print(markdown_code_fence)
522
+ suggestions = []
523
+ for (i, check) in enumerate(CHECKS):
524
+ if i > 0:
525
+ print()
526
+ print("--- check: %s" % check.__name__)
527
+ try:
528
+ suggestions.extend(check())
529
+ except Exception:
530
+ traceback.print_exc(file=sys.stdout)
531
+ pass
532
+ print(markdown_code_fence)
533
+ print()
534
+ print("</details>")
535
+
536
+ for suggestion in suggestions:
537
+ print()
538
+ print("### Suggestion: %s" % suggestion.headline)
539
+ print()
540
+ print(suggestion.description)
541
+
542
+ print()
543
+ print("### Next steps")
544
+ print()
545
+ if suggestions:
546
+ print(
547
+ reflow(
548
+ """
549
+ Please try each suggestion enumerated above to determine whether
550
+ it solves your problem. If none of these suggestions works,
551
+ please copy ALL of the above output, including the lines
552
+ containing only backticks, into your GitHub issue or comment. Be
553
+ sure to redact any sensitive information.
554
+ """
555
+ )
556
+ )
557
+ else:
558
+ print(
559
+ reflow(
560
+ """
561
+ No action items identified. Please copy ALL of the above output,
562
+ including the lines containing only backticks, into your GitHub
563
+ issue or comment. Be sure to redact any sensitive information.
564
+ """
565
+ )
566
+ )
567
+
568
+
569
+ if __name__ == "__main__":
570
+ main()
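The script registers each diagnostic through the @check decorator and prints any Suggestion a check yields. A hypothetical extra check, written against the same API and meant to be added inside diagnose_tensorboard.py itself (it is an illustration, not part of the shipped file), would slot in the same way:

@check
def torch_cuda():
    import torch

    logging.info("torch.__version__: %r", torch.__version__)
    if not torch.cuda.is_available():
        yield Suggestion(
            "Install a CUDA-enabled PyTorch build",
            reflow(
                """
                torch.cuda.is_available() returned False, so training would fall
                back to the CPU. Reinstall PyTorch with CUDA support if a GPU is
                expected on this machine.
                """
            ),
        )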
dreambooth-for-diffusion/tools/diffusers2ckpt.py ADDED
@@ -0,0 +1,234 @@
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+
8
+ import torch
9
+
10
+
11
+ # =================#
12
+ # UNet Conversion #
13
+ # =================#
14
+
15
+ unet_conversion_map = [
16
+ # (stable-diffusion, HF Diffusers)
17
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
18
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
19
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
20
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
21
+ ("input_blocks.0.0.weight", "conv_in.weight"),
22
+ ("input_blocks.0.0.bias", "conv_in.bias"),
23
+ ("out.0.weight", "conv_norm_out.weight"),
24
+ ("out.0.bias", "conv_norm_out.bias"),
25
+ ("out.2.weight", "conv_out.weight"),
26
+ ("out.2.bias", "conv_out.bias"),
27
+ ]
28
+
29
+ unet_conversion_map_resnet = [
30
+ # (stable-diffusion, HF Diffusers)
31
+ ("in_layers.0", "norm1"),
32
+ ("in_layers.2", "conv1"),
33
+ ("out_layers.0", "norm2"),
34
+ ("out_layers.3", "conv2"),
35
+ ("emb_layers.1", "time_emb_proj"),
36
+ ("skip_connection", "conv_shortcut"),
37
+ ]
38
+
39
+ unet_conversion_map_layer = []
40
+ # hardcoded number of downblocks and resnets/attentions...
41
+ # would need smarter logic for other networks.
42
+ for i in range(4):
43
+ # loop over downblocks/upblocks
44
+
45
+ for j in range(2):
46
+ # loop over resnets/attentions for downblocks
47
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
48
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
49
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
50
+
51
+ if i < 3:
52
+ # no attention layers in down_blocks.3
53
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
54
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
55
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
56
+
57
+ for j in range(3):
58
+ # loop over resnets/attentions for upblocks
59
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
60
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
61
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
62
+
63
+ if i > 0:
64
+ # no attention layers in up_blocks.0
65
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
66
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
67
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
68
+
69
+ if i < 3:
70
+ # no downsample in down_blocks.3
71
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
72
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
73
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
74
+
75
+ # no upsample in up_blocks.3
76
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
77
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
78
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
79
+
80
+ hf_mid_atn_prefix = "mid_block.attentions.0."
81
+ sd_mid_atn_prefix = "middle_block.1."
82
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
83
+
84
+ for j in range(2):
85
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
86
+ sd_mid_res_prefix = f"middle_block.{2*j}."
87
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
88
+
89
+
90
+ def convert_unet_state_dict(unet_state_dict):
91
+ # buyer beware: this is a *brittle* function,
92
+ # and correct output requires that all of these pieces interact in
93
+ # the exact order in which I have arranged them.
94
+ mapping = {k: k for k in unet_state_dict.keys()}
95
+ for sd_name, hf_name in unet_conversion_map:
96
+ mapping[hf_name] = sd_name
97
+ for k, v in mapping.items():
98
+ if "resnets" in k:
99
+ for sd_part, hf_part in unet_conversion_map_resnet:
100
+ v = v.replace(hf_part, sd_part)
101
+ mapping[k] = v
102
+ for k, v in mapping.items():
103
+ for sd_part, hf_part in unet_conversion_map_layer:
104
+ v = v.replace(hf_part, sd_part)
105
+ mapping[k] = v
106
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
107
+ return new_state_dict
108
+
109
+
110
+ # ================#
111
+ # VAE Conversion #
112
+ # ================#
113
+
114
+ vae_conversion_map = [
115
+ # (stable-diffusion, HF Diffusers)
116
+ ("nin_shortcut", "conv_shortcut"),
117
+ ("norm_out", "conv_norm_out"),
118
+ ("mid.attn_1.", "mid_block.attentions.0."),
119
+ ]
120
+
121
+ for i in range(4):
122
+ # down_blocks have two resnets
123
+ for j in range(2):
124
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
125
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
126
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
127
+
128
+ if i < 3:
129
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
130
+ sd_downsample_prefix = f"down.{i}.downsample."
131
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
132
+
133
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
134
+ sd_upsample_prefix = f"up.{3-i}.upsample."
135
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
136
+
137
+ # up_blocks have three resnets
138
+ # also, up blocks in hf are numbered in reverse from sd
139
+ for j in range(3):
140
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
141
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
142
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
143
+
144
+ # this part accounts for mid blocks in both the encoder and the decoder
145
+ for i in range(2):
146
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
147
+ sd_mid_res_prefix = f"mid.block_{i+1}."
148
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
149
+
150
+
151
+ vae_conversion_map_attn = [
152
+ # (stable-diffusion, HF Diffusers)
153
+ ("norm.", "group_norm."),
154
+ ("q.", "query."),
155
+ ("k.", "key."),
156
+ ("v.", "value."),
157
+ ("proj_out.", "proj_attn."),
158
+ ]
159
+
160
+
161
+ def reshape_weight_for_sd(w):
162
+ # convert HF linear weights to SD conv2d weights
163
+ return w.reshape(*w.shape, 1, 1)
164
+
165
+
166
+ def convert_vae_state_dict(vae_state_dict):
167
+ mapping = {k: k for k in vae_state_dict.keys()}
168
+ for k, v in mapping.items():
169
+ for sd_part, hf_part in vae_conversion_map:
170
+ v = v.replace(hf_part, sd_part)
171
+ mapping[k] = v
172
+ for k, v in mapping.items():
173
+ if "attentions" in k:
174
+ for sd_part, hf_part in vae_conversion_map_attn:
175
+ v = v.replace(hf_part, sd_part)
176
+ mapping[k] = v
177
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
178
+ weights_to_convert = ["q", "k", "v", "proj_out"]
179
+ for k, v in new_state_dict.items():
180
+ for weight_name in weights_to_convert:
181
+ if f"mid.attn_1.{weight_name}.weight" in k:
182
+ print(f"Reshaping {k} for SD format")
183
+ new_state_dict[k] = reshape_weight_for_sd(v)
184
+ return new_state_dict
185
+
186
+
187
+ # =========================#
188
+ # Text Encoder Conversion #
189
+ # =========================#
190
+ # pretty much a no-op
191
+
192
+
193
+ def convert_text_enc_state_dict(text_enc_dict):
194
+ return text_enc_dict
195
+
196
+
197
+ if __name__ == "__main__":
198
+ parser = argparse.ArgumentParser()
199
+
200
+ parser.add_argument("model_path", default=None, type=str, help="Path to the model to convert.")
201
+ parser.add_argument("checkpoint_path", default=None, type=str, help="Path to the output model.")
202
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
203
+
204
+ args = parser.parse_args()
205
+
206
+ assert args.model_path is not None, "Must provide a model path!"
207
+
208
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
209
+
210
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
211
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
212
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
213
+
214
+ # Convert the UNet model
215
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
216
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
217
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
218
+
219
+ # Convert the VAE model
220
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
221
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
222
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
223
+
224
+ # Convert the text encoder model
225
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
226
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
227
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
228
+
229
+ # Put together new checkpoint
230
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
231
+ if args.half:
232
+ state_dict = {k: v.half() for k, v in state_dict.items()}
233
+ state_dict = {"state_dict": state_dict}
234
+ torch.save(state_dict, args.checkpoint_path)
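The resulting .ckpt keeps everything under a single state_dict, with the UNet, VAE and text encoder distinguished only by their key prefixes. A quick sanity-check sketch, assuming the output path used above was ./model.ckpt:

import torch

ckpt = torch.load("./model.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# The three prefixes written by diffusers2ckpt.py above.
for prefix in ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.transformer."):
    count = sum(1 for k in state_dict if k.startswith(prefix))
    print(prefix, count, "tensors")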
dreambooth-for-diffusion/tools/handle_images.py ADDED
@@ -0,0 +1,82 @@
1
+ import os, cv2, argparse
2
+ import numpy as np
3
+
4
+ # Turn a transparent background white
5
+ def transparence2white(img):
6
+ sp=img.shape
7
+ width=sp[0]
8
+ height=sp[1]
9
+ for yh in range(height):
10
+ for xw in range(width):
11
+ color_d=img[xw,yh]
12
+ if(color_d[3]==0):
13
+ img[xw,yh]=[255,255,255,255]
14
+ return img
15
+
16
+ # Convert a transparent background to black
17
+ def transparence2black(img):
18
+ sp = img.shape
19
+ width = sp[0]
20
+ height = sp[1]
21
+ for yh in range(height):
22
+ for xw in range(width):
23
+ color_d = img[xw, yh]
24
+ if (color_d[3] == 0):
25
+ img[xw, yh] = [0, 0, 0, 255]
26
+ return img
27
+
28
+ # Center crop
29
+ def center_crop(img, crop_size):
30
+ h, w = img.shape[:2]
31
+ th, tw = crop_size
32
+ i = int(round((h - th) / 2.))
33
+ j = int(round((w - tw) / 2.))
34
+ return img[i:i + th, j:j + tw]
35
+
36
+ if __name__ == '__main__':
37
+ parser = argparse.ArgumentParser()
38
+
39
+ parser.add_argument("origin_image_path", default=None, type=str, help="Path to the images to convert.")
40
+ parser.add_argument("output_image_path", default=None, type=str, help="Path to the output images.")
41
+ parser.add_argument("--width", default=512, type=int, help="Width of the output images.")
42
+ parser.add_argument("--height", default=512, type=int, help="Height of the output images.")
43
+ parser.add_argument("--png", action="store_true", help="convert the transparent background to white/black.")
44
+
45
+
46
+ args = parser.parse_args()
47
+
48
+ path = args.origin_image_path
49
+ save_path = args.output_image_path
50
+ if not os.path.exists(save_path):
51
+ os.makedirs(save_path)
52
+ else:
53
+ print('The folder already exists, please check the path.')
54
+
55
+ # Only read png, jpg, jpeg, bmp and webp files
56
+ allow_suffix = ['png', 'jpg', 'jpeg', 'bmp', 'webp']
57
+ image_list = os.listdir(path)
58
+ image_list = [os.path.join(path, image) for image in image_list if image.split('.')[-1].lower() in allow_suffix]
59
+
60
+ for file, i in zip(image_list, range(1, len(image_list)+1)):
61
+ print('Processing image: %s' % file)
62
+ try:
63
+ img = cv2.imread(file, -1)
64
+
65
+ # Center crop the image to a 1:1 aspect ratio; crop_size is the shorter side
66
+ crop_size = min(img.shape[:2])
67
+ img = center_crop(img, (crop_size, crop_size))
68
+
69
+ # Resize the image to the target size (512x512 by default)
70
+ width = args.width
71
+ height = args.height
72
+ img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
73
+
74
+ # If the image has an alpha channel, convert the transparent background to black
75
+ if args.png:
76
+ img = transparence2black(img)
77
+
78
+ cv2.imwrite(os.path.join(save_path, str(i).zfill(4) + ".jpg"), img)
79
+ except Exception as e:
80
+ print(e)
81
+ os.remove(file)  # delete the invalid image (file is already a full path)
82
+ print("删除无效图片: " + path+file)
dreambooth-for-diffusion/tools/label_images.py ADDED
@@ -0,0 +1,152 @@
1
+ # from AUTOMATC1111
2
+ # maybe modified by Nyanko Lepsoni
3
+ # modified by crosstyan
4
+ import os.path
5
+ import re
6
+ import tempfile
7
+ import argparse
8
+ import glob
9
+ import zipfile
10
+ import deepdanbooru as dd
11
+ import tensorflow as tf
12
+ import numpy as np
13
+
14
+ from basicsr.utils.download_util import load_file_from_url
15
+ from PIL import Image
16
+ from tqdm import tqdm
17
+
18
+ re_special = re.compile(r"([\\()])")
19
+
20
+ def get_deepbooru_tags_model(model_path: str):
21
+ if not os.path.exists(os.path.join(model_path, "project.json")):
22
+ is_abs = os.path.isabs(model_path)
23
+ if not is_abs:
24
+ model_path = os.path.abspath(model_path)
25
+
26
+ load_file_from_url(
27
+ r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
28
+ model_path,
29
+ )
30
+ with zipfile.ZipFile(
31
+ os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r"
32
+ ) as zip_ref:
33
+ zip_ref.extractall(model_path)
34
+ os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
35
+
36
+ tags = dd.project.load_tags_from_project(model_path)
37
+ model = dd.project.load_model_from_project(model_path, compile_model=False)
38
+ return model, tags
39
+
40
+
41
+ def get_deepbooru_tags_from_model(
42
+ model,
43
+ tags,
44
+ pil_image,
45
+ threshold,
46
+ alpha_sort=False,
47
+ use_spaces=True,
48
+ use_escape=True,
49
+ include_ranks=False,
50
+ ):
51
+ width = model.input_shape[2]
52
+ height = model.input_shape[1]
53
+ image = np.array(pil_image)
54
+ image = tf.image.resize(
55
+ image,
56
+ size=(height, width),
57
+ method=tf.image.ResizeMethod.AREA,
58
+ preserve_aspect_ratio=True,
59
+ )
60
+ image = image.numpy() # EagerTensor to np.array
61
+ image = dd.image.transform_and_pad_image(image, width, height)
62
+ image = image / 255.0
63
+ image_shape = image.shape
64
+ image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
65
+
66
+ y = model.predict(image)[0]
67
+
68
+ result_dict = {}
69
+
70
+ for i, tag in enumerate(tags):
71
+ result_dict[tag] = y[i]
72
+
73
+ unsorted_tags_in_threshold = []
74
+ result_tags_print = []
75
+ for tag in tags:
76
+ if result_dict[tag] >= threshold:
77
+ if tag.startswith("rating:"):
78
+ continue
79
+ unsorted_tags_in_threshold.append((result_dict[tag], tag))
80
+ result_tags_print.append(f"{result_dict[tag]} {tag}")
81
+
82
+ # sort tags
83
+ result_tags_out = []
84
+ sort_ndx = 0
85
+ if alpha_sort:
86
+ sort_ndx = 1
87
+
88
+ # sort by likelihood (descending) or alphabetically when alpha_sort is set, then format each tag as requested
89
+ unsorted_tags_in_threshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
90
+ for weight, tag in unsorted_tags_in_theshold:
91
+ tag_outformat = tag
92
+ if use_spaces:
93
+ tag_outformat = tag_outformat.replace("_", " ")
94
+ if use_escape:
95
+ tag_outformat = re.sub(re_special, r"\\\1", tag_outformat)
96
+ if include_ranks:
97
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
98
+
99
+ result_tags_out.append(tag_outformat)
100
+
101
+ # print("\n".join(sorted(result_tags_print, reverse=True)))
102
+
103
+ return ", ".join(result_tags_out)
104
+
105
+
106
+ if __name__ == "__main__":
107
+ parser = argparse.ArgumentParser()
108
+ parser.add_argument("--path", type=str, default=".")
109
+ parser.add_argument("--threshold", type=int, default=0.75)
110
+ parser.add_argument("--alpha_sort", type=bool, default=False)
111
+ parser.add_argument("--use_spaces", type=bool, default=True)
112
+ parser.add_argument("--use_escape", type=bool, default=True)
113
+ parser.add_argument("--model_path", type=str, default="")
114
+ parser.add_argument("--include_ranks", type=bool, default=False)
115
+
116
+ args = parser.parse_args()
117
+
118
+ global model_path
119
+ model_path:str
120
+ if args.model_path == "":
121
+ script_path = os.path.realpath(__file__)
122
+ default_model_path = os.path.join(os.path.dirname(script_path), "deepdanbooru-models")
123
+ # print("No model path specified, using default model path: {}".format(default_model_path))
124
+ model_path = default_model_path
125
+ else:
126
+ model_path = args.model_path
127
+
128
+ types = ('*.jpg', '*.png', '*.jpeg', '*.gif', '*.webp', '*.bmp')
129
+ files_grabbed = []
130
+ for files in types:
131
+ files_grabbed.extend(glob.glob(os.path.join(args.path, files)))
132
+ # print(glob.glob(args.path + files))
133
+
134
+ model, tags = get_deepbooru_tags_model(model_path)
135
+ for image_path in tqdm(files_grabbed, desc="Processing"):
136
+ image = Image.open(image_path).convert("RGB")
137
+ prompt = get_deepbooru_tags_from_model(
138
+ model,
139
+ tags,
140
+ image,
141
+ args.threshold,
142
+ alpha_sort=args.alpha_sort,
143
+ use_spaces=args.use_spaces,
144
+ use_escape=args.use_escape,
145
+ include_ranks=args.include_ranks,
146
+ )
147
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
148
+ txt_filename = os.path.join(args.path, f"{image_name}.txt")
149
+ # print(f"writing {txt_filename}: {prompt}")
150
+ with open(txt_filename, 'w') as f:
151
+ f.write(prompt)
152
+
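
The tag post-processing above (threshold filter, dropping `rating:` tags, sorting and escaping) can be illustrated without downloading the DeepDanbooru model. A minimal sketch with made-up scores standing in for the `model.predict()` output:

```python
import re

re_special = re.compile(r"([\\()])")

# Hypothetical scores standing in for the DeepDanbooru predictions.
result_dict = {"1girl": 0.98, "solo": 0.91, "smile_(happy)": 0.80, "rating:safe": 0.99, "hat": 0.40}
threshold = 0.75

kept = [(score, tag) for tag, score in result_dict.items()
        if score >= threshold and not tag.startswith("rating:")]
kept.sort(key=lambda x: x[0], reverse=True)      # most likely tags first

tags_out = []
for score, tag in kept:
    tag = tag.replace("_", " ")                  # use_spaces
    tag = re.sub(re_special, r"\\\1", tag)       # use_escape
    tags_out.append(tag)

print(", ".join(tags_out))                       # 1girl, solo, smile \(happy\)
```
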
dreambooth-for-diffusion/tools/test_cuda.py ADDED
@@ -0,0 +1,2 @@
1
+ import torch
2
+ print(torch.cuda.is_available())
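
A slightly more verbose environment check is sometimes useful when CUDA is missing; this sketch only uses standard `torch.cuda` calls:

```python
import torch

if torch.cuda.is_available():
    print(f"CUDA available with {torch.cuda.device_count()} device(s)")
    print("Device 0:", torch.cuda.get_device_name(0))
else:
    print("CUDA not available; training would fall back to the CPU and be extremely slow.")
```
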
dreambooth-for-diffusion/tools/train_dreambooth.py ADDED
@@ -0,0 +1,784 @@
1
+ import argparse
2
+ import hashlib
3
+ import itertools
4
+ import math
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch.utils.data import Dataset
13
+
14
+ from accelerate import Accelerator
15
+ from accelerate.logging import get_logger
16
+ from accelerate.utils import set_seed
17
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
18
+ from diffusers.optimization import get_scheduler
19
+ from huggingface_hub import HfFolder, Repository, whoami
20
+ from PIL import Image
21
+ from torchvision import transforms
22
+ from tqdm.auto import tqdm
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
+ def parse_args(input_args=None):
30
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
31
+ parser.add_argument(
32
+ "--pretrained_model_name_or_path",
33
+ type=str,
34
+ default=None,
35
+ required=True,
36
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
37
+ )
38
+ parser.add_argument(
39
+ "--revision",
40
+ type=str,
41
+ default=None,
42
+ required=False,
43
+ help="Revision of pretrained model identifier from huggingface.co/models.",
44
+ )
45
+ parser.add_argument(
46
+ "--tokenizer_name",
47
+ type=str,
48
+ default=None,
49
+ help="Pretrained tokenizer name or path if not the same as model_name",
50
+ )
51
+ parser.add_argument(
52
+ "--instance_data_dir",
53
+ type=str,
54
+ default=None,
55
+ required=True,
56
+ help="A folder containing the training data of instance images.",
57
+ )
58
+ parser.add_argument(
59
+ "--class_data_dir",
60
+ type=str,
61
+ default=None,
62
+ required=False,
63
+ help="A folder containing the training data of class images.",
64
+ )
65
+ parser.add_argument(
66
+ "--instance_prompt",
67
+ type=str,
68
+ default=None,
69
+ help="The prompt with identifier specifying the instance",
70
+ )
71
+ parser.add_argument(
72
+ "--class_prompt",
73
+ type=str,
74
+ default=None,
75
+ help="The prompt to specify images in the same class as provided instance images.",
76
+ )
77
+ parser.add_argument(
78
+ "--with_prior_preservation",
79
+ default=False,
80
+ action="store_true",
81
+ help="Flag to add prior preservation loss.",
82
+ )
83
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
84
+ parser.add_argument(
85
+ "--num_class_images",
86
+ type=int,
87
+ default=100,
88
+ help=(
89
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
90
+ " sampled with class_prompt."
91
+ ),
92
+ )
93
+ parser.add_argument(
94
+ "--output_dir",
95
+ type=str,
96
+ default="text-inversion-model",
97
+ help="The output directory where the model predictions and checkpoints will be written.",
98
+ )
99
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
100
+ parser.add_argument(
101
+ "--resolution",
102
+ type=int,
103
+ default=512,
104
+ help=(
105
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
106
+ " resolution"
107
+ ),
108
+ )
109
+ parser.add_argument(
110
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
111
+ )
112
+ parser.add_argument(
113
+ "--use_filename_as_label", action="store_true", help="Uses the filename as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance"
114
+ )
115
+ parser.add_argument(
116
+ "--use_txt_as_label", action="store_true", help="Uses the filename.txt file's content as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance"
117
+ )
118
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
119
+ parser.add_argument(
120
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
121
+ )
122
+ parser.add_argument(
123
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
124
+ )
125
+ parser.add_argument("--num_train_epochs", type=int, default=1)
126
+ parser.add_argument(
127
+ "--max_train_steps",
128
+ type=int,
129
+ default=None,
130
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
131
+ )
132
+ parser.add_argument(
133
+ "--gradient_accumulation_steps",
134
+ type=int,
135
+ default=1,
136
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
137
+ )
138
+ parser.add_argument(
139
+ "--gradient_checkpointing",
140
+ action="store_true",
141
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
142
+ )
143
+ parser.add_argument(
144
+ "--learning_rate",
145
+ type=float,
146
+ default=5e-6,
147
+ help="Initial learning rate (after the potential warmup period) to use.",
148
+ )
149
+ parser.add_argument(
150
+ "--scale_lr",
151
+ action="store_true",
152
+ default=False,
153
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
154
+ )
155
+ parser.add_argument(
156
+ "--lr_scheduler",
157
+ type=str,
158
+ default="constant",
159
+ help=(
160
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
161
+ ' "constant", "constant_with_warmup"]'
162
+ ),
163
+ )
164
+ parser.add_argument(
165
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
166
+ )
167
+ parser.add_argument(
168
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
169
+ )
170
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
171
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
172
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
173
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
174
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
175
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
176
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
177
+ parser.add_argument(
178
+ "--hub_model_id",
179
+ type=str,
180
+ default=None,
181
+ help="The name of the repository to keep in sync with the local `output_dir`.",
182
+ )
183
+ parser.add_argument(
184
+ "--logging_dir",
185
+ type=str,
186
+ default="logs",
187
+ help=(
188
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
189
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
190
+ ),
191
+ )
192
+ parser.add_argument(
193
+ "--log_with",
194
+ type=str,
195
+ default="tensorboard",
196
+ choices=["tensorboard", "wandb"]
197
+ )
198
+ parser.add_argument(
199
+ "--mixed_precision",
200
+ type=str,
201
+ default="no",
202
+ choices=["no", "fp16", "bf16"],
203
+ help=(
204
+ "Whether to use mixed precision. Choose"
205
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
206
+ "and an Nvidia Ampere GPU."
207
+ ),
208
+ )
209
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
210
+ parser.add_argument("--save_model_every_n_steps", type=int)
211
+ parser.add_argument("--auto_test_model", action="store_true", help="Whether or not to automatically test the model after saving it")
212
+ parser.add_argument("--test_prompt", type=str, default="A photo of a cat", help="The prompt to use for testing the model.")
213
+ parser.add_argument("--test_prompts_file", type=str, default=None, help="The file containing the prompts to use for testing the model.example: test_prompts.txt, each line is a prompt")
214
+ parser.add_argument("--test_negative_prompt", type=str, default="", help="The negative prompt to use for testing the model.")
215
+ parser.add_argument("--test_seed", type=int, default=42, help="The seed to use for testing the model.")
216
+ parser.add_argument("--test_num_per_prompt", type=int, default=1, help="The number of images to generate per prompt.")
217
+
218
+ if input_args is not None:
219
+ args = parser.parse_args(input_args)
220
+ else:
221
+ args = parser.parse_args()
222
+
223
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
224
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
225
+ args.local_rank = env_local_rank
226
+
227
+ if args.instance_data_dir is None:
228
+ raise ValueError("You must specify a train data directory.")
229
+
230
+ if args.with_prior_preservation:
231
+ if args.class_data_dir is None:
232
+ raise ValueError("You must specify a data directory for class images.")
233
+ if args.class_prompt is None:
234
+ raise ValueError("You must specify prompt for class images.")
235
+
236
+ return args
237
+
238
+ # turns a path into a filename without the extension
239
+ def get_filename(path):
240
+ return path.stem
241
+
242
+ def get_label_from_txt(path):
243
+ txt_path = path.with_suffix(".txt") # get the path to the .txt file
244
+ if txt_path.exists():
245
+ with open(txt_path, "r") as f:
246
+ return f.read()
247
+ else:
248
+ return ""
249
+
250
+ class DreamBoothDataset(Dataset):
251
+ """
252
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
253
+ It pre-processes the images and tokenizes the prompts.
254
+ """
255
+
256
+ def __init__(
257
+ self,
258
+ instance_data_root,
259
+ instance_prompt,
260
+ tokenizer,
261
+ class_data_root=None,
262
+ class_prompt=None,
263
+ size=512,
264
+ center_crop=False,
265
+ use_filename_as_label=False,
266
+ use_txt_as_label=False,
267
+ ):
268
+ self.size = size
269
+ self.center_crop = center_crop
270
+ self.tokenizer = tokenizer
271
+
272
+ self.instance_data_root = Path(instance_data_root)
273
+ if not self.instance_data_root.exists():
274
+ raise ValueError("Instance images root doesn't exists.")
275
+
276
+ self.instance_images_path = list(self.instance_data_root.glob("*.jpg")) + list(self.instance_data_root.glob("*.png"))
277
+ self.num_instance_images = len(self.instance_images_path)
278
+ self.instance_prompt = instance_prompt
279
+ self.use_filename_as_label = use_filename_as_label
280
+ self.use_txt_as_label = use_txt_as_label
281
+ self._length = self.num_instance_images
282
+
283
+ if class_data_root is not None:
284
+ self.class_data_root = Path(class_data_root)
285
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
286
+ self.class_images_path = list(self.class_data_root.glob("*.jpg")) + list(self.class_data_root.glob("*.png"))
287
+ self.num_class_images = len(self.class_images_path)
288
+ self._length = max(self.num_class_images, self.num_instance_images)
289
+ self.class_prompt = class_prompt
290
+ else:
291
+ self.class_data_root = None
292
+
293
+ self.image_transforms = transforms.Compose(
294
+ [
295
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
296
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
297
+ transforms.ToTensor(),
298
+ transforms.Normalize([0.5], [0.5]),
299
+ ]
300
+ )
301
+
302
+ def __len__(self):
303
+ return self._length
304
+
305
+ def __getitem__(self, index):
306
+ example = {}
307
+ path = self.instance_images_path[index % self.num_instance_images]
308
+ prompt = get_filename(path) if self.use_filename_as_label else self.instance_prompt
309
+ prompt = get_label_from_txt(path) if self.use_txt_as_label else prompt
310
+
311
+ print("prompt", prompt)
312
+
313
+ instance_image = Image.open(path)
314
+ if not instance_image.mode == "RGB":
315
+ instance_image = instance_image.convert("RGB")
316
+ example["instance_images"] = self.image_transforms(instance_image)
317
+ example["instance_prompt_ids"] = self.tokenizer(
318
+ prompt,
319
+ padding="do_not_pad",
320
+ truncation=True,
321
+ max_length=self.tokenizer.model_max_length,
322
+ ).input_ids
323
+
324
+ if self.class_data_root:
325
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
326
+ if not class_image.mode == "RGB":
327
+ class_image = class_image.convert("RGB")
328
+ example["class_images"] = self.image_transforms(class_image)
329
+ example["class_prompt_ids"] = self.tokenizer(
330
+ self.class_prompt,
331
+ padding="do_not_pad",
332
+ truncation=True,
333
+ max_length=self.tokenizer.model_max_length,
334
+ ).input_ids
335
+
336
+ return example
337
+
338
+
339
+ class PromptDataset(Dataset):
340
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
341
+
342
+ def __init__(self, prompt, num_samples):
343
+ self.prompt = prompt
344
+ self.num_samples = num_samples
345
+
346
+ def __len__(self):
347
+ return self.num_samples
348
+
349
+ def __getitem__(self, index):
350
+ example = {}
351
+ example["prompt"] = self.prompt
352
+ example["index"] = index
353
+ return example
354
+
355
+
356
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
357
+ if token is None:
358
+ token = HfFolder.get_token()
359
+ if organization is None:
360
+ username = whoami(token)["name"]
361
+ return f"{username}/{model_id}"
362
+ else:
363
+ return f"{organization}/{model_id}"
364
+
365
+ def test_model(folder, args):
366
+ if args.test_prompts_file is not None:
367
+ with open(args.test_prompts_file, "r") as f:
368
+ prompts = f.read().splitlines()
369
+ else:
370
+ prompts = [args.test_prompt]
371
+
372
+ test_path = os.path.join(folder, "test")
373
+ if not os.path.exists(test_path):
374
+ os.makedirs(test_path)
375
+
376
+ print("Testing the model...")
377
+ from diffusers import DDIMScheduler
378
+
379
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
380
+ torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
381
+ pipeline = StableDiffusionPipeline.from_pretrained(
382
+ folder,
383
+ torch_dtype=torch_dtype,
384
+ safety_checker=None,
385
+ load_in_8bit=True,
386
+ scheduler = DDIMScheduler(
387
+ beta_start=0.00085,
388
+ beta_end=0.012,
389
+ beta_schedule="scaled_linear",
390
+ clip_sample=False,
391
+ set_alpha_to_one=False,
392
+ ),
393
+ )
394
+ pipeline.set_progress_bar_config(disable=True)
395
+ pipeline.enable_attention_slicing()
396
+ pipeline = pipeline.to(device)
397
+
398
+ torch.manual_seed(args.test_seed)
399
+ with torch.autocast('cuda'):
400
+ for prompt in prompts:
401
+ print(f"Generating test images for prompt: {prompt}")
402
+ test_images = pipeline(
403
+ prompt=prompt,
404
+ width=512,
405
+ height=512,
406
+ negative_prompt=args.test_negative_prompt,
407
+ num_inference_steps=30,
408
+ num_images_per_prompt=args.test_num_per_prompt,
409
+ ).images
410
+
411
+ for index, image in enumerate(test_images):
412
+ image.save(f"{test_path}/{prompt}_{index}.png")
413
+
414
+ del pipeline
415
+ if torch.cuda.is_available():
416
+ torch.cuda.empty_cache()
417
+
418
+ print(f"Test completed.The examples are saved in {test_path}")
419
+
420
+
421
+ def save_model(accelerator, unet, text_encoder, args, step=None):
422
+ unet = accelerator.unwrap_model(unet)
423
+ text_encoder = accelerator.unwrap_model(text_encoder)
424
+
425
+ if step is None:
426
+ folder = args.output_dir
427
+ else:
428
+ folder = args.output_dir + "-Step-" + str(step)
429
+
430
+ print("Saving Model Checkpoint...")
431
+ print("Directory: " + folder)
432
+
433
+ # Create the pipeline using the trained modules and save it.
434
+ if accelerator.is_main_process:
435
+ pipeline = StableDiffusionPipeline.from_pretrained(
436
+ args.pretrained_model_name_or_path,
437
+ unet=unet,
438
+ text_encoder=text_encoder,
439
+ revision=args.revision,
440
+ )
441
+ pipeline.save_pretrained(folder)
442
+ del pipeline
443
+ if torch.cuda.is_available():
444
+ torch.cuda.empty_cache()
445
+
446
+ if args.auto_test_model:
447
+ print("Testing Model...")
448
+ test_model(folder, args)
449
+
450
+ if args.push_to_hub:
451
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
452
+
453
+
454
+ def main(args):
455
+ logging_dir = Path(args.logging_dir)
456
+
457
+ accelerator = Accelerator(
458
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
459
+ mixed_precision=args.mixed_precision,
460
+ log_with=args.log_with,
461
+ logging_dir=logging_dir,
462
+ )
463
+
464
+
465
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
466
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
467
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
468
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
469
+ raise ValueError(
470
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
471
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
472
+ )
473
+
474
+ if args.seed is not None:
475
+ set_seed(args.seed)
476
+
477
+ if args.with_prior_preservation:
478
+ class_images_dir = Path(args.class_data_dir)
479
+ if not class_images_dir.exists():
480
+ class_images_dir.mkdir(parents=True)
481
+ cur_class_images = len(list(class_images_dir.iterdir()))
482
+
483
+ if cur_class_images < args.num_class_images:
484
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
485
+ pipeline = StableDiffusionPipeline.from_pretrained(
486
+ args.pretrained_model_name_or_path,
487
+ torch_dtype=torch_dtype,
488
+ safety_checker=None,
489
+ revision=args.revision,
490
+ )
491
+ pipeline.set_progress_bar_config(disable=True)
492
+
493
+ num_new_images = args.num_class_images - cur_class_images
494
+ logger.info(f"Number of class images to sample: {num_new_images}.")
495
+
496
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
497
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
498
+
499
+ sample_dataloader = accelerator.prepare(sample_dataloader)
500
+ pipeline.to(accelerator.device)
501
+
502
+ for example in tqdm(
503
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
504
+ ):
505
+ images = pipeline(example["prompt"]).images
506
+
507
+ for i, image in enumerate(images):
508
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
509
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
510
+ image.save(image_filename)
511
+
512
+ del pipeline
513
+ if torch.cuda.is_available():
514
+ torch.cuda.empty_cache()
515
+
516
+ # Handle the repository creation
517
+ if accelerator.is_main_process:
518
+ if args.push_to_hub:
519
+ if args.hub_model_id is None:
520
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
521
+ else:
522
+ repo_name = args.hub_model_id
523
+ repo = Repository(args.output_dir, clone_from=repo_name)
524
+
525
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
526
+ if "step_*" not in gitignore:
527
+ gitignore.write("step_*\n")
528
+ if "epoch_*" not in gitignore:
529
+ gitignore.write("epoch_*\n")
530
+ elif args.output_dir is not None:
531
+ os.makedirs(args.output_dir, exist_ok=True)
532
+
533
+ # Load the tokenizer
534
+ if args.tokenizer_name:
535
+ tokenizer = CLIPTokenizer.from_pretrained(
536
+ args.tokenizer_name,
537
+ revision=args.revision,
538
+ )
539
+ elif args.pretrained_model_name_or_path:
540
+ tokenizer = CLIPTokenizer.from_pretrained(
541
+ args.pretrained_model_name_or_path,
542
+ subfolder="tokenizer",
543
+ revision=args.revision,
544
+ )
545
+
546
+ # Load models and create wrapper for stable diffusion
547
+ text_encoder = CLIPTextModel.from_pretrained(
548
+ args.pretrained_model_name_or_path,
549
+ subfolder="text_encoder",
550
+ revision=args.revision,
551
+ )
552
+ vae = AutoencoderKL.from_pretrained(
553
+ args.pretrained_model_name_or_path,
554
+ subfolder="vae",
555
+ revision=args.revision,
556
+ )
557
+ unet = UNet2DConditionModel.from_pretrained(
558
+ args.pretrained_model_name_or_path,
559
+ subfolder="unet",
560
+ revision=args.revision,
561
+ )
562
+
563
+ vae.requires_grad_(False)
564
+ if not args.train_text_encoder:
565
+ text_encoder.requires_grad_(False)
566
+
567
+ if args.gradient_checkpointing:
568
+ unet.enable_gradient_checkpointing()
569
+ if args.train_text_encoder:
570
+ text_encoder.gradient_checkpointing_enable()
571
+
572
+ if args.scale_lr:
573
+ args.learning_rate = (
574
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
575
+ )
576
+
577
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
578
+ if args.use_8bit_adam:
579
+ try:
580
+ import bitsandbytes as bnb
581
+ except ImportError:
582
+ raise ImportError(
583
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
584
+ )
585
+
586
+ optimizer_class = bnb.optim.AdamW8bit
587
+ else:
588
+ optimizer_class = torch.optim.AdamW
589
+
590
+ params_to_optimize = (
591
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
592
+ )
593
+ optimizer = optimizer_class(
594
+ params_to_optimize,
595
+ lr=args.learning_rate,
596
+ betas=(args.adam_beta1, args.adam_beta2),
597
+ weight_decay=args.adam_weight_decay,
598
+ eps=args.adam_epsilon,
599
+ )
600
+
601
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
602
+
603
+ train_dataset = DreamBoothDataset(
604
+ instance_data_root=args.instance_data_dir,
605
+ instance_prompt=args.instance_prompt,
606
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
607
+ class_prompt=args.class_prompt,
608
+ tokenizer=tokenizer,
609
+ size=args.resolution,
610
+ center_crop=args.center_crop,
611
+ use_filename_as_label=args.use_filename_as_label,
612
+ use_txt_as_label=args.use_txt_as_label,
613
+ )
614
+
615
+ def collate_fn(examples):
616
+ input_ids = [example["instance_prompt_ids"] for example in examples]
617
+ pixel_values = [example["instance_images"] for example in examples]
618
+
619
+ # Concat class and instance examples for prior preservation.
620
+ # We do this to avoid doing two forward passes.
621
+ if args.with_prior_preservation:
622
+ input_ids += [example["class_prompt_ids"] for example in examples]
623
+ pixel_values += [example["class_images"] for example in examples]
624
+
625
+ pixel_values = torch.stack(pixel_values)
626
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
627
+
628
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
629
+
630
+ batch = {
631
+ "input_ids": input_ids,
632
+ "pixel_values": pixel_values,
633
+ }
634
+ return batch
635
+
636
+ train_dataloader = torch.utils.data.DataLoader(
637
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
638
+ )
639
+
640
+ # Scheduler and math around the number of training steps.
641
+ overrode_max_train_steps = False
642
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
643
+ if args.max_train_steps is None:
644
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
645
+ overrode_max_train_steps = True
646
+
647
+ lr_scheduler = get_scheduler(
648
+ args.lr_scheduler,
649
+ optimizer=optimizer,
650
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
651
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
652
+ )
653
+
654
+ if args.train_text_encoder:
655
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
656
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
657
+ )
658
+ else:
659
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
660
+ unet, optimizer, train_dataloader, lr_scheduler
661
+ )
662
+
663
+ weight_dtype = torch.float32
664
+ if args.mixed_precision == "fp16":
665
+ weight_dtype = torch.float16
666
+ elif args.mixed_precision == "bf16":
667
+ weight_dtype = torch.bfloat16
668
+
669
+ # Move text_encoder and vae to gpu.
670
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
671
+ # as these models are only used for inference, keeping weights in full precision is not required.
672
+ vae.to(accelerator.device, dtype=weight_dtype)
673
+ if not args.train_text_encoder:
674
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
675
+
676
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
677
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
678
+ if overrode_max_train_steps:
679
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
680
+ # Afterwards we recalculate our number of training epochs
681
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
682
+
683
+ # We need to initialize the trackers we use, and also store our configuration.
684
+ # The trackers initialize automatically on the main process.
685
+ if accelerator.is_main_process:
686
+ accelerator.init_trackers("dreambooth", config=vars(args))
687
+
688
+ # Train!
689
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
690
+
691
+ logger.info("***** Running training *****")
692
+ logger.info(f" Num examples = {len(train_dataset)}")
693
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
694
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
695
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
696
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
697
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
698
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
699
+ # Only show the progress bar once on each machine.
700
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
701
+ progress_bar.set_description("Steps")
702
+ global_step = 0
703
+
704
+ for epoch in range(args.num_train_epochs):
705
+ unet.train()
706
+ if args.train_text_encoder:
707
+ text_encoder.train()
708
+ for step, batch in enumerate(train_dataloader):
709
+ with accelerator.accumulate(unet):
710
+ # Convert images to latent space
711
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
712
+ latents = latents * 0.18215
713
+
714
+ # Sample noise that we'll add to the latents
715
+ noise = torch.randn_like(latents)
716
+ bsz = latents.shape[0]
717
+ # Sample a random timestep for each image
718
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
719
+ timesteps = timesteps.long()
720
+
721
+ # Add noise to the latents according to the noise magnitude at each timestep
722
+ # (this is the forward diffusion process)
723
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
724
+
725
+ # Get the text embedding for conditioning
726
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
727
+
728
+ # Predict the noise residual
729
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
730
+
731
+ if args.with_prior_preservation:
732
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
733
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
734
+ noise, noise_prior = torch.chunk(noise, 2, dim=0)
735
+
736
+ # Compute instance loss
737
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
738
+
739
+ # Compute prior loss
740
+ prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
741
+
742
+ # Add the prior loss to the instance loss.
743
+ loss = loss + args.prior_loss_weight * prior_loss
744
+ else:
745
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
746
+
747
+ accelerator.backward(loss)
748
+ if accelerator.sync_gradients:
749
+ params_to_clip = (
750
+ itertools.chain(unet.parameters(), text_encoder.parameters())
751
+ if args.train_text_encoder
752
+ else unet.parameters()
753
+ )
754
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
755
+ optimizer.step()
756
+ lr_scheduler.step()
757
+ optimizer.zero_grad()
758
+
759
+ # Checks if the accelerator has performed an optimization step behind the scenes
760
+ if accelerator.sync_gradients:
761
+ progress_bar.update(1)
762
+ global_step += 1
763
+
764
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
765
+ progress_bar.set_postfix(**logs)
766
+ accelerator.log(logs, step=global_step)
767
+
768
+ if global_step >= args.max_train_steps:
769
+ break
770
+
771
+
772
+ if args.save_model_every_n_steps is not None and (global_step % args.save_model_every_n_steps) == 0:
773
+ save_model(accelerator, unet, text_encoder, args, global_step)
774
+
775
+ accelerator.wait_for_everyone()
776
+
777
+ save_model(accelerator, unet, text_encoder, args, step=None)
778
+
779
+ accelerator.end_training()
780
+
781
+
782
+ if __name__ == "__main__":
783
+ args = parse_args()
784
+ main(args)
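
To make the prior-preservation arithmetic in the training loop easier to follow: `collate_fn` appends the class examples after the instance examples, so the UNet output is chunked back into two halves before the instance and prior losses are combined. A minimal sketch on dummy tensors (the shapes and the weight value are illustrative only):

```python
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0                          # matches the --prior_loss_weight default
bsz = 2                                          # instance examples per device (illustrative)

# The combined batch holds instance examples first, then class examples.
noise_pred = torch.randn(2 * bsz, 4, 64, 64)     # stand-in for the UNet prediction
noise = torch.randn(2 * bsz, 4, 64, 64)          # stand-in for the added noise

noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

instance_loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
loss = instance_loss + prior_loss_weight * prior_loss
print(loss.item())
```
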
dreambooth-for-diffusion/tools/train_textual_inversion.py ADDED
@@ -0,0 +1,572 @@
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ import random
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ # import torch
11
+ import oneflow as torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from torch.utils.data import Dataset
15
+
16
+ import PIL
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import set_seed
20
+ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
21
+ from diffusers.optimization import get_scheduler
22
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
23
+ from huggingface_hub import HfFolder, Repository, whoami
24
+ from PIL import Image
25
+ from torchvision import transforms
26
+ from tqdm.auto import tqdm
27
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
28
+
29
+
30
+ logger = get_logger(__name__)
31
+
32
+
33
+ def save_progress(text_encoder, placeholder_token_id, accelerator, args):
34
+ logger.info("Saving embeddings")
35
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
36
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
37
+ torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
38
+
39
+
40
+ def parse_args():
41
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
42
+ parser.add_argument(
43
+ "--save_steps",
44
+ type=int,
45
+ default=500,
46
+ help="Save learned_embeds.bin every X updates steps.",
47
+ )
48
+ parser.add_argument(
49
+ "--pretrained_model_name_or_path",
50
+ type=str,
51
+ default=None,
52
+ required=True,
53
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
54
+ )
55
+ parser.add_argument(
56
+ "--tokenizer_name",
57
+ type=str,
58
+ default=None,
59
+ help="Pretrained tokenizer name or path if not the same as model_name",
60
+ )
61
+ parser.add_argument(
62
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
63
+ )
64
+ parser.add_argument(
65
+ "--placeholder_token",
66
+ type=str,
67
+ default=None,
68
+ required=True,
69
+ help="A token to use as a placeholder for the concept.",
70
+ )
71
+ parser.add_argument(
72
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
73
+ )
74
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
75
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
76
+ parser.add_argument(
77
+ "--output_dir",
78
+ type=str,
79
+ default="text-inversion-model",
80
+ help="The output directory where the model predictions and checkpoints will be written.",
81
+ )
82
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
83
+ parser.add_argument(
84
+ "--resolution",
85
+ type=int,
86
+ default=512,
87
+ help=(
88
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
89
+ " resolution"
90
+ ),
91
+ )
92
+ parser.add_argument(
93
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
94
+ )
95
+ parser.add_argument(
96
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
97
+ )
98
+ parser.add_argument("--num_train_epochs", type=int, default=100)
99
+ parser.add_argument(
100
+ "--max_train_steps",
101
+ type=int,
102
+ default=5000,
103
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
104
+ )
105
+ parser.add_argument(
106
+ "--gradient_accumulation_steps",
107
+ type=int,
108
+ default=1,
109
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
110
+ )
111
+ parser.add_argument(
112
+ "--learning_rate",
113
+ type=float,
114
+ default=1e-4,
115
+ help="Initial learning rate (after the potential warmup period) to use.",
116
+ )
117
+ parser.add_argument(
118
+ "--scale_lr",
119
+ action="store_true",
120
+ default=True,
121
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
122
+ )
123
+ parser.add_argument(
124
+ "--lr_scheduler",
125
+ type=str,
126
+ default="constant",
127
+ help=(
128
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
129
+ ' "constant", "constant_with_warmup"]'
130
+ ),
131
+ )
132
+ parser.add_argument(
133
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
134
+ )
135
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
136
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
137
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
138
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
139
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
140
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
141
+ parser.add_argument(
142
+ "--hub_model_id",
143
+ type=str,
144
+ default=None,
145
+ help="The name of the repository to keep in sync with the local `output_dir`.",
146
+ )
147
+ parser.add_argument(
148
+ "--logging_dir",
149
+ type=str,
150
+ default="logs",
151
+ help=(
152
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
153
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
154
+ ),
155
+ )
156
+ parser.add_argument(
157
+ "--mixed_precision",
158
+ type=str,
159
+ default="no",
160
+ choices=["no", "fp16", "bf16"],
161
+ help=(
162
+ "Whether to use mixed precision. Choose"
163
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
164
+ "and an Nvidia Ampere GPU."
165
+ ),
166
+ )
167
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
168
+
169
+ args = parser.parse_args()
170
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
171
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
172
+ args.local_rank = env_local_rank
173
+
174
+ if args.train_data_dir is None:
175
+ raise ValueError("You must specify a train data directory.")
176
+
177
+ return args
178
+
179
+
180
+ imagenet_templates_small = [
181
+ "a photo of a {}",
182
+ "a rendering of a {}",
183
+ "a cropped photo of the {}",
184
+ "the photo of a {}",
185
+ "a photo of a clean {}",
186
+ "a photo of a dirty {}",
187
+ "a dark photo of the {}",
188
+ "a photo of my {}",
189
+ "a photo of the cool {}",
190
+ "a close-up photo of a {}",
191
+ "a bright photo of the {}",
192
+ "a cropped photo of a {}",
193
+ "a photo of the {}",
194
+ "a good photo of the {}",
195
+ "a photo of one {}",
196
+ "a close-up photo of the {}",
197
+ "a rendition of the {}",
198
+ "a photo of the clean {}",
199
+ "a rendition of a {}",
200
+ "a photo of a nice {}",
201
+ "a good photo of a {}",
202
+ "a photo of the nice {}",
203
+ "a photo of the small {}",
204
+ "a photo of the weird {}",
205
+ "a photo of the large {}",
206
+ "a photo of a cool {}",
207
+ "a photo of a small {}",
208
+ ]
209
+
210
+ imagenet_style_templates_small = [
211
+ "a painting in the style of {}",
212
+ "a rendering in the style of {}",
213
+ "a cropped painting in the style of {}",
214
+ "the painting in the style of {}",
215
+ "a clean painting in the style of {}",
216
+ "a dirty painting in the style of {}",
217
+ "a dark painting in the style of {}",
218
+ "a picture in the style of {}",
219
+ "a cool painting in the style of {}",
220
+ "a close-up painting in the style of {}",
221
+ "a bright painting in the style of {}",
222
+ "a cropped painting in the style of {}",
223
+ "a good painting in the style of {}",
224
+ "a close-up painting in the style of {}",
225
+ "a rendition in the style of {}",
226
+ "a nice painting in the style of {}",
227
+ "a small painting in the style of {}",
228
+ "a weird painting in the style of {}",
229
+ "a large painting in the style of {}",
230
+ ]
231
+
232
+
233
+ class TextualInversionDataset(Dataset):
234
+ def __init__(
235
+ self,
236
+ data_root,
237
+ tokenizer,
238
+ learnable_property="object", # [object, style]
239
+ size=512,
240
+ repeats=100,
241
+ interpolation="bicubic",
242
+ flip_p=0.5,
243
+ set="train",
244
+ placeholder_token="*",
245
+ center_crop=False,
246
+ ):
247
+ self.data_root = data_root
248
+ self.tokenizer = tokenizer
249
+ self.learnable_property = learnable_property
250
+ self.size = size
251
+ self.placeholder_token = placeholder_token
252
+ self.center_crop = center_crop
253
+ self.flip_p = flip_p
254
+
255
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
256
+
257
+ self.num_images = len(self.image_paths)
258
+ self._length = self.num_images
259
+
260
+ if set == "train":
261
+ self._length = self.num_images * repeats
262
+
263
+ self.interpolation = {
264
+ "linear": PIL.Image.LINEAR,
265
+ "bilinear": PIL.Image.BILINEAR,
266
+ "bicubic": PIL.Image.BICUBIC,
267
+ "lanczos": PIL.Image.LANCZOS,
268
+ }[interpolation]
269
+
270
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
271
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
272
+
273
+ def __len__(self):
274
+ return self._length
275
+
276
+ def __getitem__(self, i):
277
+ example = {}
278
+ image = Image.open(self.image_paths[i % self.num_images])
279
+
280
+ if not image.mode == "RGB":
281
+ image = image.convert("RGB")
282
+
283
+ placeholder_string = self.placeholder_token
284
+ text = random.choice(self.templates).format(placeholder_string)
285
+
286
+ example["input_ids"] = self.tokenizer(
287
+ text,
288
+ padding="max_length",
289
+ truncation=True,
290
+ max_length=self.tokenizer.model_max_length,
291
+ return_tensors="pt",
292
+ ).input_ids[0]
293
+
294
+ # default to score-sde preprocessing
295
+ img = np.array(image).astype(np.uint8)
296
+
297
+ if self.center_crop:
298
+ crop = min(img.shape[0], img.shape[1])
299
+ h, w, = (
300
+ img.shape[0],
301
+ img.shape[1],
302
+ )
303
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
304
+
305
+ image = Image.fromarray(img)
306
+ image = image.resize((self.size, self.size), resample=self.interpolation)
307
+
308
+ image = self.flip_transform(image)
309
+ image = np.array(image).astype(np.uint8)
310
+ image = (image / 127.5 - 1.0).astype(np.float32)
311
+
312
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
313
+ return example
314
+
315
+
316
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
317
+ if token is None:
318
+ token = HfFolder.get_token()
319
+ if organization is None:
320
+ username = whoami(token)["name"]
321
+ return f"{username}/{model_id}"
322
+ else:
323
+ return f"{organization}/{model_id}"
324
+
325
+
326
+ def freeze_params(params):
327
+ for param in params:
328
+ param.requires_grad = False
329
+
330
+
331
+ def main():
332
+ args = parse_args()
333
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
334
+
335
+ accelerator = Accelerator(
336
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
337
+ mixed_precision=args.mixed_precision,
338
+ log_with="tensorboard",
339
+ logging_dir=logging_dir,
340
+ )
341
+
342
+ # If passed along, set the training seed now.
343
+ if args.seed is not None:
344
+ set_seed(args.seed)
345
+
346
+ # Handle the repository creation
347
+ if accelerator.is_main_process:
348
+ if args.push_to_hub:
349
+ if args.hub_model_id is None:
350
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
351
+ else:
352
+ repo_name = args.hub_model_id
353
+ repo = Repository(args.output_dir, clone_from=repo_name)
354
+
355
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
356
+ if "step_*" not in gitignore:
357
+ gitignore.write("step_*\n")
358
+ if "epoch_*" not in gitignore:
359
+ gitignore.write("epoch_*\n")
360
+ elif args.output_dir is not None:
361
+ os.makedirs(args.output_dir, exist_ok=True)
362
+
363
+ # Load the tokenizer and add the placeholder token as an additional special token
364
+ if args.tokenizer_name:
365
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
366
+ elif args.pretrained_model_name_or_path:
367
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
368
+
369
+ # Add the placeholder token in tokenizer
370
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
371
+ if num_added_tokens == 0:
372
+ raise ValueError(
373
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
374
+ " `placeholder_token` that is not already in the tokenizer."
375
+ )
376
+
377
+ # Convert the initializer_token, placeholder_token to ids
378
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
379
+ # Check if initializer_token is a single token or a sequence of tokens
380
+ if len(token_ids) > 1:
381
+ raise ValueError("The initializer token must be a single token.")
382
+
383
+ initializer_token_id = token_ids[0]
384
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
385
+
386
+ # Load models and create wrapper for stable diffusion
387
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
388
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
389
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
390
+
391
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
392
+ text_encoder.resize_token_embeddings(len(tokenizer))
393
+
394
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
395
+ token_embeds = text_encoder.get_input_embeddings().weight.data
396
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
397
+
398
+ # Freeze vae and unet
399
+ freeze_params(vae.parameters())
400
+ freeze_params(unet.parameters())
401
+ # Freeze all parameters except for the token embeddings in text encoder
402
+ params_to_freeze = itertools.chain(
403
+ text_encoder.text_model.encoder.parameters(),
404
+ text_encoder.text_model.final_layer_norm.parameters(),
405
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
406
+ )
407
+ freeze_params(params_to_freeze)
408
+
409
+ if args.scale_lr:
410
+ args.learning_rate = (
411
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
412
+ )
413
+
414
+ # Initialize the optimizer
415
+ optimizer = torch.optim.AdamW(
416
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
417
+ lr=args.learning_rate,
418
+ betas=(args.adam_beta1, args.adam_beta2),
419
+ weight_decay=args.adam_weight_decay,
420
+ eps=args.adam_epsilon,
421
+ )
422
+
423
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
424
+
425
+ train_dataset = TextualInversionDataset(
426
+ data_root=args.train_data_dir,
427
+ tokenizer=tokenizer,
428
+ size=args.resolution,
429
+ placeholder_token=args.placeholder_token,
430
+ repeats=args.repeats,
431
+ learnable_property=args.learnable_property,
432
+ center_crop=args.center_crop,
433
+ set="train",
434
+ )
435
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
436
+
437
+ # Scheduler and math around the number of training steps.
438
+ overrode_max_train_steps = False
439
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
440
+ if args.max_train_steps is None:
441
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
442
+ overrode_max_train_steps = True
443
+
444
+ lr_scheduler = get_scheduler(
445
+ args.lr_scheduler,
446
+ optimizer=optimizer,
447
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
448
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
449
+ )
450
+
451
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
452
+ text_encoder, optimizer, train_dataloader, lr_scheduler
453
+ )
454
+
455
+ # Move vae and unet to device
456
+ vae.to(accelerator.device)
457
+ unet.to(accelerator.device)
458
+
459
+ # Keep vae and unet in eval mode as we don't train them
460
+ vae.eval()
461
+ unet.eval()
462
+
463
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
464
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
465
+ if overrode_max_train_steps:
466
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
467
+ # Afterwards we recalculate our number of training epochs
468
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
469
+
470
+ # We need to initialize the trackers we use, and also store our configuration.
471
+ # The trackers initialize automatically on the main process.
472
+ if accelerator.is_main_process:
473
+ accelerator.init_trackers("textual_inversion", config=vars(args))
474
+
475
+ # Train!
476
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
477
+
478
+ logger.info("***** Running training *****")
479
+ logger.info(f" Num examples = {len(train_dataset)}")
480
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
481
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
482
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
483
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
484
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
485
+ # Only show the progress bar once on each machine.
486
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
487
+ progress_bar.set_description("Steps")
488
+ global_step = 0
489
+
490
+ for epoch in range(args.num_train_epochs):
491
+ text_encoder.train()
492
+ for step, batch in enumerate(train_dataloader):
493
+ with accelerator.accumulate(text_encoder):
494
+ # Convert images to latent space
495
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
496
+ latents = latents * 0.18215
497
+
498
+ # Sample noise that we'll add to the latents
499
+ noise = torch.randn(latents.shape).to(latents.device)
500
+ bsz = latents.shape[0]
501
+ # Sample a random timestep for each image
502
+ timesteps = torch.randint(
503
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
504
+ ).long()
505
+
506
+ # Add noise to the latents according to the noise magnitude at each timestep
507
+ # (this is the forward diffusion process)
508
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
509
+
510
+ # Get the text embedding for conditioning
511
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
512
+
513
+ # Predict the noise residual
514
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
515
+
516
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
517
+ accelerator.backward(loss)
518
+
519
+ # Zero out the gradients for all token embeddings except the newly added
520
+ # embeddings for the concept, as we only want to optimize the concept embeddings
521
+ # if accelerator.num_processes > 1:
522
+ # grads = text_encoder.module.get_input_embeddings().weight.grad
523
+ # else:
524
+ # grads = text_encoder.get_input_embeddings().weight.grad
525
+ grads = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.grad  # works for both single- and multi-GPU runs
526
+ # Get the index for tokens that we want to zero the grads for
527
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
528
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
529
+
530
+ optimizer.step()
531
+ lr_scheduler.step()
532
+ optimizer.zero_grad()
533
+
534
+ # Checks if the accelerator has performed an optimization step behind the scenes
535
+ if accelerator.sync_gradients:
536
+ progress_bar.update(1)
537
+ global_step += 1
538
+ if global_step % args.save_steps == 0:
539
+ save_progress(text_encoder, placeholder_token_id, accelerator, args)
540
+
541
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
542
+ progress_bar.set_postfix(**logs)
543
+ accelerator.log(logs, step=global_step)
544
+
545
+ if global_step >= args.max_train_steps:
546
+ break
547
+
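+ # Make sure every process has finished its training steps before the main process saves the pipeline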
548
+ accelerator.wait_for_everyone()
549
+
550
+ # Create the pipeline using the trained modules and save it.
551
+ if accelerator.is_main_process:
552
+ pipeline = StableDiffusionPipeline(
553
+ text_encoder=accelerator.unwrap_model(text_encoder),
554
+ vae=vae,
555
+ unet=unet,
556
+ tokenizer=tokenizer,
557
+ scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
558
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
559
+ feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
560
+ )
561
+ pipeline.save_pretrained(args.output_dir)
562
+ # Also save the newly trained embeddings
563
+ save_progress(text_encoder, placeholder_token_id, accelerator, args)
564
+
565
+ if args.push_to_hub:
566
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
567
+
568
+ accelerator.end_training()
569
+
570
+
571
+ if __name__ == "__main__":
572
+ main()
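After training, `save_progress` writes the learned concept embedding into `--output_dir` alongside the full pipeline saved above. A minimal sketch of reusing that embedding with another diffusers pipeline, assuming the file is named `learned_embeds.bin` and maps the placeholder token to its embedding (as in the upstream diffusers textual-inversion example); the base model path, embedding path and prompt below are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

base_model = "./model"                           # converted base model used for training (assumption)
embeds_path = "output_model/learned_embeds.bin"  # assumed output of save_progress in --output_dir

# learned_embeds.bin is assumed to be {placeholder_token: embedding tensor}
learned = torch.load(embeds_path, map_location="cpu")
token, embedding = next(iter(learned.items()))

pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda")

# Register the placeholder token and copy the trained embedding into the text encoder
pipe.tokenizer.add_tokens(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding.to(
    pipe.text_encoder.get_input_embeddings().weight.dtype
)

image = pipe(f"a drawing of a girl, {token}", num_inference_steps=30).images[0]
image.save("ti_test.png")
```

This is also why the note in train_textual_inversion.sh says the trained concept can only be consumed from diffusers for now.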
dreambooth-for-diffusion/tools/upload_cos.py ADDED
@@ -0,0 +1,19 @@
1
+ # -*- coding: UTF-8 -*-
2
+ # by ruochen
3
+ # Run "pip install -U cos-python-sdk-v5" first
4
+ from qcloud_cos import CosConfig
5
+ from qcloud_cos import CosS3Client
6
+
7
+ secret_id = 'abc123' # replace with your secretId
8
+ secret_key = 'abc123' # replace with your secretKey
9
+ region = 'ap-guangzhou' # replace with your Region
10
+
11
+ config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
12
+ client = CosS3Client(config)
13
+
14
+ response = client.upload_file(
15
+ Bucket='xxx', # replace with your bucket name
16
+ LocalFilePath='../ckpt_models/newModel.ckpt', # path of the local file
17
+ Key='newModel.ckpt', # object name after upload
18
+ )
19
+ print(response['ETag'])
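If you want to confirm that the checkpoint actually landed in the bucket, the same SDK can list it back. A small sketch under the same placeholder credentials, bucket name and key as above:

```python
from qcloud_cos import CosConfig, CosS3Client

# Same placeholder credentials as in upload_cos.py -- replace with your own
config = CosConfig(Region='ap-guangzhou', SecretId='abc123', SecretKey='abc123')
client = CosS3Client(config)

# List objects with the uploaded key as prefix to confirm the file exists
resp = client.list_objects(Bucket='xxx', Prefix='newModel.ckpt')
for obj in resp.get('Contents', []):
    print(obj['Key'], obj['Size'])
```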
dreambooth-for-diffusion/train_object.sh ADDED
@@ -0,0 +1,79 @@
1
+ # Script for training a specific object/person (only a single label is needed)
2
+ export MODEL_NAME="./model"
3
+ export INSTANCE_DIR="./datasets/test2"
4
+ export OUTPUT_DIR="./new_model"
5
+ export CLASS_DIR="./datasets/class" # folder for the prior-preservation images generated by the model; do not change
6
+ export LOG_DIR="/root/tf-logs"
7
+ export TEST_PROMPTS_FILE="./test_prompts_object.txt"
8
+
9
+ rm -rf $CLASS_DIR/* # if you are training a different object/person than last time, clear this folder first; otherwise you can comment this line out (add # in front)
10
+ rm -rf $LOG_DIR/*
11
+
12
+ accelerate launch tools/train_dreambooth.py \
13
+ --train_text_encoder \
14
+ --pretrained_model_name_or_path=$MODEL_NAME \
15
+ --mixed_precision="fp16" \
16
+ --instance_data_dir=$INSTANCE_DIR \
17
+ --instance_prompt="a photo of <xxx> dog" \
18
+ --with_prior_preservation --prior_loss_weight=1.0 \
19
+ --class_prompt="a photo of dog" \
20
+ --class_data_dir=$CLASS_DIR \
21
+ --num_class_images=200 \
22
+ --output_dir=$OUTPUT_DIR \
23
+ --logging_dir=$LOG_DIR \
24
+ --center_crop \
25
+ --resolution=512 \
26
+ --train_batch_size=1 \
27
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
28
+ --use_8bit_adam \
29
+ --learning_rate=2e-6 \
30
+ --lr_scheduler="constant" \
31
+ --lr_warmup_steps=0 \
32
+ --auto_test_model \
33
+ --test_prompts_file=$TEST_PROMPTS_FILE \
34
+ --test_seed=123 \
35
+ --test_num_per_prompt=3 \
36
+ --max_train_steps=1000 \
37
+ --save_model_every_n_steps=500
38
+
39
+ # If you increase max_train_steps, remember to also increase save_model_every_n_steps,
40
+ # otherwise the disk can easily fill up partway through training
41
+
42
+ # Key parameters explained below:
43
+ # The main ones:
44
+ # --train_text_encoder  also train the text encoder
45
+ # --mixed_precision="fp16"  mixed-precision training
46
+ # - center_crop
47
+ # whether to crop the images; usually needed if your dataset is not square
48
+ # - resolution
49
+ # the image resolution, usually 512; input images are automatically rescaled to this size
50
+ # can be combined with center_crop to crop to a square and rescale to 512*512
51
+ # - instance_prompt
52
+ # use this if you want to train a specific person or subject
53
+ # e.g. --instance_prompt="a photo of <xxx> girl"
54
+ # - class_prompt
55
+ # if you are training a specific class, this may improve training results somewhat
56
+ # - use_txt_as_label
57
+ # whether to read the txt file with the same name as each image as its label
58
+ # use this if you are fine-tuning the overall image style of the model
59
+ # this option ignores whatever is passed via instance_prompt
60
+ # - learning_rate
61
+ # the learning rate, usually 2e-6; the key parameter to tune during training
62
+ # too large and the model will not converge; too small and training becomes slow
63
+ # - lr_scheduler, options: constant, linear, cosine, cosine_with_restarts, cosine_with_hard_restarts
64
+ # learning-rate schedule, usually constant (no adjustment); with a large dataset you can try the others, but they may keep the model from converging and require retuning the learning rate
65
+ # - lr_warmup_steps: can be ignored if you use constant,
66
+ # for the other schedulers it can be set to 0, i.e. no warmup,
67
+ # or to some other value, e.g. 1000, meaning the learning rate ramps up from 0 to learning_rate over the first 1000 steps
68
+ # usually not needed unless your dataset is large and training converges slowly
69
+ # - max_train_steps
70
+ # the maximum number of training steps, usually 1000; increase it for larger datasets
71
+ # - save_model_every_n_steps
72
+ # save the model every n steps, handy for inspecting intermediate results to pick the best checkpoint, and for resuming training
73
+
74
+ # --with_prior_preservation and --prior_loss_weight=1.0 enable prior preservation and set the prior-loss weight
75
+ # with few training samples these two flags can improve results and help prevent overfitting (i.e. generated images that look too much like the training images)
76
+
77
+ # --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt
78
+ # automatically test the model (after every save_model_every_n_steps steps), the prompts file to test with, the random seed, and the number of images per prompt
79
+ # the test images are saved in the test folder under the model output directory
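The first part of a run only tops CLASS_DIR up with prior-preservation images (the notebook notes that TensorBoard stays empty for roughly ten minutes because of this), so it can be handy to check how many class images are already there before launching. A rough sketch, assuming train_dreambooth.py follows the usual DreamBooth pattern of filling the folder up to --num_class_images before the actual training starts:

```python
import os

class_dir = "./datasets/class"   # $CLASS_DIR from the script above
num_class_images = 200           # --num_class_images from the script above

existing = len(os.listdir(class_dir)) if os.path.isdir(class_dir) else 0
missing = max(0, num_class_images - existing)
print(f"{existing} class images present, about {missing} still to be generated before training starts")
```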
dreambooth-for-diffusion/train_style.sh ADDED
@@ -0,0 +1,62 @@
1
+ # Mainly for training a style / drawing ability (every image needs a corresponding caption)
2
+ export MODEL_NAME="./model"
3
+ export INSTANCE_DIR="./datasets/test2"
4
+ export OUTPUT_DIR="./new_model"
5
+ export LOG_DIR="/root/tf-logs"
6
+ export TEST_PROMPTS_FILE="./test_prompts_style.txt"
7
+
8
+ rm -rf $LOG_DIR/*
9
+
10
+ accelerate launch tools/train_dreambooth.py \
11
+ --pretrained_model_name_or_path=$MODEL_NAME \
12
+ --mixed_precision="fp16" \
13
+ --instance_data_dir=$INSTANCE_DIR \
14
+ --use_txt_as_label \
15
+ --output_dir=$OUTPUT_DIR \
16
+ --logging_dir=$LOG_DIR \
17
+ --center_crop \
18
+ --resolution=768 \
19
+ --train_batch_size=1 \
20
+ --use_8bit_adam \
21
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
22
+ --learning_rate=2e-6 \
23
+ --lr_scheduler="constant" \
24
+ --lr_warmup_steps=0 \
25
+ --max_train_steps=1000 \
26
+ --save_model_every_n_steps=500 \
27
+ --auto_test_model \
28
+ --test_prompts_file=$TEST_PROMPTS_FILE \
29
+ --test_seed=123 \
30
+ --test_num_per_prompt=3
31
+
32
+ # If you increase max_train_steps, remember to also increase save_model_every_n_steps, otherwise the disk can easily fill up partway through training
33
+
34
+ # Key parameters explained below:
35
+ # The main ones:
36
+ # --train_text_encoder  also train the text encoder
37
+ # --mixed_precision="fp16"  mixed-precision training
38
+ # - center_crop
39
+ # whether to crop the images; usually needed if your dataset is not square
40
+ # - resolution
41
+ # the image resolution, usually 512; input images are automatically rescaled to this size
42
+ # can be combined with center_crop to crop to a square and rescale to 512*512
43
+ # - instance_prompt
44
+ # use this if you want to train a specific person or subject
45
+ # e.g. --instance_prompt="a photo of <xxx> girl"
46
+ # - use_txt_as_label
47
+ # whether to read the txt file with the same name as each image as its label
48
+ # use this if you are fine-tuning the overall image style of the model
49
+ # this option ignores whatever is passed via instance_prompt
50
+ # - learning_rate
51
+ # the learning rate, usually 2e-6; the key parameter to tune during training
52
+ # too large and the model will not converge; too small and training becomes slow
53
+ # - max_train_steps
54
+ # the maximum number of training steps, usually 1000; increase it for larger datasets
55
+ # - save_model_every_n_steps
56
+ # save the model every n steps, handy for inspecting intermediate results to pick the best checkpoint, and for resuming training
57
+
58
+ # --train_text_encoder  # train the text encoder in addition to the image generator
59
+
60
+ # --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt
61
+ # automatically test the model (after every save_model_every_n_steps steps), the prompts file to test with, the random seed, and the number of images per prompt
62
+ # the test images are saved in the test folder under the model output directory
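--use_txt_as_label expects every image in INSTANCE_DIR to come with a caption file of the same basename and a .txt extension. A small sketch to spot images that are missing a caption before starting a run; the folder matches the INSTANCE_DIR above, and the extension list is an assumption:

```python
import os

data_dir = "./datasets/test2"  # the INSTANCE_DIR used above
image_exts = {".jpg", ".jpeg", ".png", ".webp"}

missing = []
for name in os.listdir(data_dir):
    stem, ext = os.path.splitext(name)
    if ext.lower() in image_exts and not os.path.exists(os.path.join(data_dir, stem + ".txt")):
        missing.append(name)

print(f"{len(missing)} image(s) without a matching .txt caption")
for name in missing:
    print(" -", name)
```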
dreambooth-for-diffusion/train_textual_inversion.sh ADDED
@@ -0,0 +1,29 @@
1
+ # This is another way to fine-tune the model, called textual inversion; results are so-so, and one copy is bundled for reference only.
2
+ # Note: the concept embedding trained this way can only be used with diffusers; inference frameworks other than diffusers (such as webui) are not supported for now.
3
+ #!/bin/bash
4
+ export LOG_DIR="/root/tf-logs"
5
+
6
+ accelerate launch ./tools/train_textual_inversion.py \
7
+ --pretrained_model_name_or_path="./model/" \
8
+ --train_data_dir="./datasets/test" \
9
+ --learnable_property="style" \
10
+ --placeholder_token="<xxx-girl>" --initializer_token="girl" \
11
+ --resolution=512 \
12
+ --train_batch_size=1 \
13
+ --gradient_accumulation_steps=4 \
14
+ --learning_rate=5.0e-04 --scale_lr \
15
+ --lr_scheduler="constant" \
16
+ --lr_warmup_steps=0 \
17
+ --save_steps=200 \
18
+ --max_train_steps=3000 \
19
+ --mixed_precision="fp16" \
20
+ --logging_dir=$LOG_DIR \
21
+ --output_dir="output_model"
22
+
23
+ # --learnable_property: "style" trains a specific style, "object" trains a specific object/person.
24
+ # --placeholder_token is the placeholder token used during training; --initializer_token is the word whose embedding initializes it.
25
+ # --resolution is the training resolution, --train_batch_size is the training batch size, --gradient_accumulation_steps is the number of gradient accumulation steps.
26
+ # --learning_rate is the learning rate, --scale_lr toggles scaling it by batch size and number of processes, --lr_scheduler is the learning-rate scheduler, --lr_warmup_steps is the number of warmup steps.
27
+ # --save_steps is how often the embedding is saved, --max_train_steps is the maximum number of training steps, --mixed_precision is the mixed-precision training mode.
28
+ # --logging_dir is where logs are written, --output_dir is where the model is saved.
29
+ # --pretrained_model_name_or_path is the path to the pretrained model, --train_data_dir is the training data path; it must be a folder containing the processed images.
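Because --scale_lr is passed here, the learning rate actually used is the one computed in the scale_lr block of tools/train_textual_inversion.py shown earlier. A quick calculation with this script's values; the number of processes is assumed to be 1:

```python
# Effective learning rate when --scale_lr is set:
# learning_rate * gradient_accumulation_steps * train_batch_size * num_processes
base_lr = 5.0e-04                # --learning_rate in this script
gradient_accumulation_steps = 4  # --gradient_accumulation_steps
train_batch_size = 1             # --train_batch_size
num_processes = 1                # GPUs used by accelerate launch (assumption)

effective_lr = base_lr * gradient_accumulation_steps * train_batch_size * num_processes
print(f"effective learning rate: {effective_lr}")  # 2e-3 for the values above
```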
dreambooth-for-diffusion/运行.ipynb ADDED
@@ -0,0 +1,452 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a0b34c19-4215-46f9-9def-65e73629665c",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Dreambooth Stable Diffusion 集成化环境训练\n",
9
+ "如果你是在autodl上的机器可以直接使用封装好的镜像创建实例,开箱即用 \n",
10
+ "如果是本地或者其他服务器上也可以使用,需要手动安装一些pip包\n",
11
+ "\n",
12
+ "## 注意\n",
13
+ "本项目仅供用于学习、测试人工智能技术使用 \n",
14
+ "请勿用于训练生成不良或侵权图片内容\n",
15
+ "\n",
16
+ "## 关于项目\n",
17
+ "在autodl封装的镜像名称为:dreambooth-for-diffusion \n",
18
+ "可在创建实例时直接选择公开的算法镜像使用。 \n",
19
+ "在autodl内蒙A区A5000的机器上封装,如遇到问题且无法自行解决的朋友请使用同一环境。 \n",
20
+ "白菜写教程时做了尽可能多的测试,但仍然无法确保每一个环节都完全覆盖 \n",
21
+ "如有小错误可尝试手动解决,或者访问git项目地址查看最新的README \n",
22
+ "项目地址:https://github.com/CrazyBoyM/dreambooth-for-diffusion\n",
23
+ "\n",
24
+ "## #强烈建议\n",
25
+ "1.用vscode的ssh功能远程连接到本服务器,训练体验更好,autodl自带的notebook也不错,有文件上传、下载功能。 \n",
26
+ "(vscode连接autodl教程:https://www.autodl.com/docs/vscode/ ) \n",
27
+ "### 2.(重要)把train文件夹整个移动到/root/autodl-tmp/路径下进行训练(数据盘),避免系统盘空间满\n",
28
+ "有的机器数据盘也很小,需要自行关注开合适的机器或进行扩容\n",
29
+ "\n",
30
+ "如果遇到问题可到b站主页找该教程对应训练演示的视频:https://space.bilibili.com/291593914\n",
31
+ "(因为现在写时视频还没做 \n",
32
+ "\n",
33
+ "## 服务器的数据迁移\n",
34
+ "经常关机后再开机发现机器资源被占用了,这时候你只能另外开一台机器了 \n",
35
+ "但是对于已经关机的机器在菜单上有个功能是“跨实例拷贝数据”, \n",
36
+ "可以很方便地同步/root/autodl-tmp文件夹下的内容到其他已开机的机器(所以推荐工作文件都放这) \n",
37
+ "(注意,只适用于同一区域的机器之间)\n",
38
+ "数据迁移教程:https://www.autodl.com/docs/migrate_instance/"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "f091e609-bacc-469a-b6cf-bffe331a8944",
44
+ "metadata": {},
45
+ "source": [
46
+ "### 本文件为notebook在线运行版\n",
47
+ "具体详细的教程和参数说明请在根目录下教程.md 文件中查看。 \n",
48
+ "在notebook中执行linux命令,需要前面加个!(感叹号) \n",
49
+ "代码块前如果有个[*],表示正在运行该步骤,并不是卡住了\n"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "id": "3555d8bd-fb3f-4303-8915-ec6fefcc780c",
55
+ "metadata": {},
56
+ "source": [
57
+ "# 笔者前言\n",
58
+ "\n",
59
+ "linux压缩一个文件夹为单个文件包的命令:\n",
60
+ "```\n",
61
+ "!zip xx.zip -r ./xxx\n",
62
+ "```\n",
63
+ "解压一个包到文件夹:\n",
64
+ "```\n",
65
+ "!unzip xx.zip -d xxx\n",
66
+ "```\n",
67
+ "或许你在上传、下载数据集时会用到。\n",
68
+ "\n",
69
+ "其他linux基础命令:https://www.autodl.com/docs/linux/\n",
70
+ "\n",
71
+ "关于文件上传下载的提速可查看官网文档推荐的几种方式:https://www.autodl.com/docs/scp/"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "markdown",
76
+ "id": "34cf6ed1-f2b1-4abd-baf6-565ac00567ab",
77
+ "metadata": {},
78
+ "source": [
79
+ "### 首先,进入工作文件夹(记得先把dreambooth-for-diffusion文件夹移动到autodl-tmp目录下)"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "id": "a1249a32-ce15-4b1b-8068-8149ad40588b",
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "%cd /root/autodl-tmp/dreambooth-for-diffusion"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "id": "ccba0e31-f01d-43e5-b474-7d88e0b09bd8",
95
+ "metadata": {},
96
+ "source": [
97
+ "# 准备数据集\n",
98
+ "该部分请参考教程.md文件中的详细内容自行上传并处理你的数据集 \n",
99
+ "dreambooth-for-diffusion/datasets/test中为16张仅供于学习测试的样本数据,便于你了解以下代码的用处 \n"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "id": "470113f6-795a-41f8-a6b3-09f854a4cbc3",
105
+ "metadata": {},
106
+ "source": [
107
+ "## 一键裁剪\n",
108
+ "### 图像批量center crop并处理大小、格式和背景\n",
109
+ "./datasets/test是原始图片数据文件夹,请上传你的图片数据并进行更换 \n",
110
+ "width和height请设置为8的整倍数,并记得修改训练脚本中的参数 \n",
111
+ "(在显存低于20G的设备上请修改使用小于768的分辨率数据去训练,比如512) \n",
112
+ "如果是对透明底的png图处理成纯色底可以加--png参数,具体可以看对应的代码文件"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "10d2bb3d-9002-4d3b-a4be-f5f74a008b9c",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "!python tools/handle_images.py ./datasets/test ./datasets/test2 --width=768 --height=768"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "34efda73-9cb4-4a54-8aac-489ded452a50",
128
+ "metadata": {},
129
+ "source": [
130
+ "## 一键打标签\n",
131
+ "### 图像批量自动标注\n",
132
+ "使用deepdanbooru生成tags标注文件。(仅针对纯二次元类图片效果较好,其他风格请手动标注) \n",
133
+ "./datasets/test2中是需要打标注的图片数据,请按需更换为自己的路径 "
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "8863a53a-4650-4f27-863e-2a70e8b89e11",
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "# 该步根据需要标注文件数量不同,需要运行一段时间(测试6000张图片需要10分钟)\n",
144
+ "!python tools/label_images.py --path=./datasets/test2 "
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "def72b19-9851-400f-8672-48023b3e95fb",
150
+ "metadata": {},
151
+ "source": [
152
+ "## 转换ckpt检查点文件为diffusers官方权重\n",
153
+ "输出的文件在dreambooth-for-diffusion/model下 \n",
154
+ "./ckpt_models/sd_1-5.ckpt需要更换为你自己的权重文件路径 "
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "id": "0582e3c4-e899-4a3b-a468-d49e7775efc6",
160
+ "metadata": {},
161
+ "source": [
162
+ "如需转换写实风格模型:"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": null,
168
+ "id": "05aaf7fd-315f-45b4-9b22-70a46a18424f",
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "# 该步需要运行大约一分钟 \n",
173
+ "!python tools/ckpt2diffusers.py \\\n",
174
+ " --checkpoint_path=./ckpt_models/sd_1-5.ckpt \\\n",
175
+ " --dump_path=./model \\\n",
176
+ " --original_config_file=./ckpt_models/model.yaml \\\n",
177
+ " --scheduler_type=\"ddim\""
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "id": "48c7893a-22db-4ea2-95dc-93fdbd6b5c4b",
183
+ "metadata": {},
184
+ "source": [
185
+ "如需转换二次元风格模型:"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "id": "f7afb70d-7af4-4bd1-804e-40927f1257e2",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "# 该步需要运行大约一分钟 \n",
196
+ "!python tools/ckpt2diffusers.py \\\n",
197
+ " --checkpoint_path=./ckpt_models/nd_lastest.ckpt \\\n",
198
+ " --dump_path=./model \\\n",
199
+ " --vae_path=./ckpt_models/animevae.pt \\\n",
200
+ " --original_config_file=./ckpt_models/model.yaml \\\n",
201
+ " --scheduler_type=\"ddim\""
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "markdown",
206
+ "id": "a1edb9be-1de3-488e-baa3-8f3ab6b8f269",
207
+ "metadata": {},
208
+ "source": [
209
+ "对于需要转换某个特殊模型(7g)并遇到报错的同学,ckpt_models里的nd_lastest.ckpt就是你需要的文件。 \n",
210
+ "如果希望手动转换,我在./tools下放了一份ckpt_prune.py可以参考。"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "markdown",
215
+ "id": "3a3470d3-1691-438c-b8d7-df2cbf885614",
216
+ "metadata": {},
217
+ "source": [
218
+ "# 训练Unet和text encoder\n",
219
+ "以下训练脚本会自动帮你启动tensorboard日志监控进程,入口可参考: https://www.autodl.com/docs/tensorboard/ \n",
220
+ "使用tensorboard面板可以帮助分析loss在不同step的总体下降情况 \n",
221
+ "如果你嫌输出太长,可以在以下命令每一行后加一句 &> log.txt, 会把输出都扔到这个文件中 \n",
222
+ "```\n",
223
+ "!sh train_style.sh &> log.txt\n",
224
+ "```\n",
225
+ "本代码包环境已在A5000、3090测试通过,如果你在某些机器上运行遇到问题可以尝试卸载编译的xformers\n",
226
+ "```\n",
227
+ "!pip uninstall xformers\n",
228
+ "```"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "markdown",
233
+ "id": "98645b45-4cf1-49f8-b2bb-42a5a8771164",
234
+ "metadata": {},
235
+ "source": [
236
+ "### 如果需要训练特定人、事物: \n",
237
+ "(推荐准备3~5张风格统一、特定对象的图片) \n",
238
+ "请打开train_object.sh具体修改里面的参数"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "8b6833e3-8d3f-438a-b45d-0711e9724496",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "# 大约十分钟后才会在tensorboard有日志(因为前十分钟在生成同类别伪图)\n",
249
+ "!sh train_object.sh "
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "id": "594a0352-8bb5-45de-bb19-0028b671569b",
255
+ "metadata": {},
256
+ "source": [
257
+ "### 如果要训练画风: \n",
258
+ "(推荐准备3000+张图片,包含尽可能的多样性,数据决定训练出的模型质量) \n",
259
+ "请打开train_object具体修改里面的参数 \n",
260
+ "实测速度1000步大概8分钟 "
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "id": "442cff33-d264-4096-97e2-0c578229c814",
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "# 正常训练立刻就可以在tensorboard看到日志\n",
271
+ "!sh train_style.sh "
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "id": "3aa1d170-e2d1-4f72-8b0c-b6bfd5f0c318",
277
+ "metadata": {},
278
+ "source": [
279
+ "后台训练法请参考教程.md中的内容"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "markdown",
284
+ "id": "9aece8a8-c9ec-41eb-b6ad-c6c88b6203e1",
285
+ "metadata": {},
286
+ "source": [
287
+ "省钱训练法(训练成功后自动关机,适合步数很大且夜晚训练的场景)"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "id": "52fff58d-1a88-4a59-a961-b13b52812425",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "!sh back_train.sh"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "markdown",
302
+ "id": "17557280-3a5a-4bde-95c3-f20e1ccffa4d",
303
+ "metadata": {},
304
+ "source": [
305
+ "## 拓展:训练Textual inversion"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "id": "36a543ee-56f8-405a-baaa-b784d96c7d40",
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "!sh train_textual_inversion.sh"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "markdown",
320
+ "id": "f467b2e9-9170-4f19-aea9-7ce0b4e5444e",
321
+ "metadata": {},
322
+ "source": [
323
+ "### 测试训练效果\n",
324
+ "打开dreambooth-for-diffusion/test_model.py文件修改其中的model_path和prompt,然后执行以下测试 \n",
325
+ "会生成一张图片 在左侧test-1、2、3.png"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "b462f33b-48e2-4092-b3de-463025e4ff9e",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "# 大约5~10s \n",
336
+ "!python test_model.py"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "markdown",
341
+ "id": "47abb5fd-2f84-4344-a9cf-539b52515971",
342
+ "metadata": {},
343
+ "source": [
344
+ "### 转换diffusers官方权重为ckpt检查点文件\n",
345
+ "输出的文件在dreambooth-for-diffusion/ckpt_models/中,名为newModel.ckpt"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "markdown",
350
+ "id": "5bfe9643-ef1d-42a3-a427-c4904f3a8631",
351
+ "metadata": {},
352
+ "source": [
353
+ "原始保存:"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "id": "2ad27225-10ed-4b3c-9978-bd909404949c",
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "!python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel.ckpt "
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "markdown",
368
+ "id": "b08a5e37-97d3-4c1e-9ba7-e331af23437f",
369
+ "metadata": {},
370
+ "source": [
371
+ "以下代码添加--half 保存float16半精度,权重文件大小会减半(约2g),效果基本一致"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "id": "cba99145-6aab-41b6-a5b7-6e0c4fd96641",
378
+ "metadata": {},
379
+ "outputs": [],
380
+ "source": [
381
+ "!python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel_half.ckpt --half"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "markdown",
386
+ "id": "d1f98d06-27f3-45b6-85df-c57cda5d6166",
387
+ "metadata": {},
388
+ "source": [
389
+ "下载ckpt文件,去玩吧~"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "markdown",
394
+ "id": "b13f0627-1d0a-4ae2-ab9c-90a605ee4a0e",
395
+ "metadata": {},
396
+ "source": [
397
+ "有问题可以进XDiffusion QQ Group:455521885 "
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "id": "b939a03f-23c9-410d-89be-02e154eeb6b4",
403
+ "metadata": {},
404
+ "source": [
405
+ "### 记得定期清理不需要的中间权重和文件,不然容易导致空间满\n",
406
+ "大部分问题已在教程.md中详细记录,也包含其他非autodl机器手动部署该训练一体化封装代码包的步骤"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "id": "3236d62e-fa3d-4826-874e-431f208cfb6d",
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": [
416
+ "# 清理文件的示例\n",
417
+ "!rm -rf ./model* # 删除当前目录model文件/文件夹\n",
418
+ "!rm -rf ./new_* # 删除当前目录所有new_开头的模型文件夹\n",
419
+ "# !rm -rf ./datasets/test2 #删除datasets中的test2数据集 "
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "id": "224924ae-2d6d-47d0-aa36-0989a6572bd2",
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": []
429
+ }
430
+ ],
431
+ "metadata": {
432
+ "kernelspec": {
433
+ "display_name": "Python 3 (ipykernel)",
434
+ "language": "python",
435
+ "name": "python3"
436
+ },
437
+ "language_info": {
438
+ "codemirror_mode": {
439
+ "name": "ipython",
440
+ "version": 3
441
+ },
442
+ "file_extension": ".py",
443
+ "mimetype": "text/x-python",
444
+ "name": "python",
445
+ "nbconvert_exporter": "python",
446
+ "pygments_lexer": "ipython3",
447
+ "version": "3.8.10"
448
+ }
449
+ },
450
+ "nbformat": 4,
451
+ "nbformat_minor": 5
452
+ }
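For reference, the 测试训练效果 step in the notebook amounts to loading the trained diffusers folder and sampling a few images. A minimal sketch in the same spirit as test_model.py; the model path and prompt are placeholders, and the script shipped in the package remains the source of truth:

```python
import torch
from diffusers import StableDiffusionPipeline

model_path = "./new_model"        # the OUTPUT_DIR produced by the training scripts (assumption)
prompt = "a photo of <xxx> dog"   # placeholder -- use the concept/prompt you trained with

pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")

# Generate three test images, matching the test-1/2/3.png naming mentioned in the notebook
for i in range(3):
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
    image.save(f"test-{i + 1}.png")
```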