本文主要介绍了如何对卫星影像数据进行处理、构建huggingface diffuser格式的数据集并使用diffuser训练controlnet,从而实现使用全球语义图控制生成全球卫星图。
训练数据使用了MODIS数据集的xxxxx卫星图与xxxxx语义图。进行了一定的预处理后裁剪成512x512的小图,在SD1.5版本的controlnet上进行训练。(详细步骤还得等黄子扬同学回忆一下看看能不能加上来)
从已经获取一系列source(控制图像)和target(目标图像)样本对开始记录。 需要的数据集结构如下:
txtdataset_name: ├─dataset_name.py ├─prompt.jsonl ├─source │ ├─0_s.png │ ├─1_s.png │ ├─2_s.png │ ├─3_s.png │ └─... └─target ├─0_t.png ├─1_t.png ├─2_t.png ├─3_t.png └─...
其中的prompt.jsonl的具体内容如下,每条记录对应一个样本对。
{"source": "数据集文件夹内的控制图像相对路径", "target": "数据集文件夹内的目标图像相对路径", "prompt": "{所需要生成的目标图像的描述提示}"} {"source": "source/0_s.png", "target": "target/_t.png", "prompt": "a remote sensing image of earth. Containing Closed Shrublands,Open Shrublands,Savannas,Grasslands,Permanent Wetlands,Croplands,Urban and Built-up Lands,Barren,Water Bodies,"}
最后一个文件是和dataset文件夹同名的dataset_name.py,其主要作用是重写controlnet内的数据读取接口。
python--dataset_name.py:
import pandas as pd
import datasets
import os
_VERSION = datasets.Version("0.0.1")
_DESCRIPTION = "TODO"
_HOMEPAGE = "TODO"
_LICENSE = "TODO"
_CITATION = "TODO"
_FEATURES = datasets.Features(
{
"target": datasets.Image(),#!!首先需要改的是这三行,用来指定你的prompt中三个列分别对应的类型
"source": datasets.Image(),#!!首先需要改的是这三行,用来指定你的prompt中三个列分别对应的类型
"prompt": datasets.Value("string"),#!!首先需要改的是这三行,用来指定你的prompt中三个列分别对应的类型
},
)
_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
class MyData(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [_DEFAULT_CONFIG]
DEFAULT_CONFIG_NAME = "default"
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=_FEATURES,
supervised_keys=None,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
#!其次是下面三个路径,分别代表:
# metadata_path -> prompt.jsonl的路径
# images_dir -> target目标图像的路径,因为我们在prompt里已经记录了target/0.png, 所以这里只需要写上一级的dataset_name文件夹的路径
# conditioning_images_dir -> source控制图像的路径,因为我们在prompt里已经记录了source/0.png, 所以这里只需要写上一级的dataset_name文件夹的路径
metadata_path = "/home/disk/hzy/repos/diffusers/examples/controlnet/gpcvrs_small/small.jsonl"
images_dir = "/home/disk/hzy/repos/diffusers/examples/controlnet/gpcvrs_small"
conditioning_images_dir = "/home/disk/hzy/repos/diffusers/examples/controlnet/gpcvrs_small"
dl_manager.download_and_extract
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"metadata_path": metadata_path,
"images_dir": images_dir,
"conditioning_images_dir": conditioning_images_dir,
},
),
]
def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
metadata = pd.read_json(metadata_path, lines=True)
for _, row in metadata.iterrows():
text = row["prompt"]#!这里改为我们的提示词列名称,即"prompt"
image_path = row["target"]#!这里改为我们的目标图像列名称,即"target"
image_path = os.path.join(images_dir, image_path)
image = open(image_path, "rb").read()
conditioning_image_path = row["source"]#!这里改为我们的控制图像列名称,即"source"
conditioning_image_path = os.path.join(
conditioning_images_dir, row["source"]#!这里改为我们的控制图像列名称,即"source"
)
conditioning_image = open(conditioning_image_path, "rb").read()
#最后迭代返回的是两个元素:
# 1、该样本对目标文件路径(此处即为row["target"])
# 2、一个包含下列元素的字典,按照描述修改即可,返回的字典的键名对应开头(line45~47)指定的三个名字:
# "prompt":prompt_str,
# "target":{
# "path": image_path,(目标图像路径)
# "bytes": image,(目标图像路径)
# }
# "source":{
# "path": conditioning_image_path,(控制图像文件路径)
# "bytes": conditioning_image,(控制图像本身)
# }
yield row["target"], {
"prompt": text,
"target": {
"path": image_path,
"bytes": image,
},
"source": {
"path": conditioning_image_path,
"bytes": conditioning_image,
},
}
在完成dataset_name.py的修改后,只需要修改huggingface 给出的controlnet的训练脚本即可:
bash#训练脚本train.sh修改示例:
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="{输出的模型路径}"
accelerate launch --multi_gpu train_controlnet.py \
--mixed_precision="fp16" \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--train_data_dir="{你的dataset文件路径}" \
--resolution=512 \
--learning_rate=1e-5 \
--train_batch_size=4 \
--num_train_epochs=6 \
--caption_column="prompt" \ #这里修改caption_column为提示词列名称prompt
--image_column="target" \ #这里修改image_column为目标图像列名称target
--conditioning_image_column="source" \ #这里修改conditioning_image_column为目标图像列名称source
在diffuser的example/controlnet目录下运行train.sh脚本可以直接开始训练。
本文作者:insomnia
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!