编辑
2024-04-14
AIGC
00
请注意,本文编写于 282 天前,最后修改于 282 天前,其中某些信息可能已经过时。

目录

摘要
数据介绍
数据集的构建
训练

摘要

本文主要介绍了如何对卫星影像数据进行处理、构建huggingface diffuser格式的数据集并使用diffuser训练controlnet,从而实现使用全球语义图控制生成全球卫星图。

数据介绍

训练数据使用了MODIS数据集的xxxxx卫星图与xxxxx语义图。进行了一定的预处理后裁剪成512x512的小图,在SD1.5版本的controlnet上进行训练。(详细步骤还得等黄子扬同学回忆一下看看能不能加上来)

数据集的构建

从已经获取一系列source(控制图像)和target(目标图像)样本对开始记录。 需要的数据集结构如下:

txt
dataset_name: ├─dataset_name.py ├─prompt.jsonl ├─source │ ├─0_s.png │ ├─1_s.png │ ├─2_s.png │ ├─3_s.png │ └─... └─target ├─0_t.png ├─1_t.png ├─2_t.png ├─3_t.png └─...

其中的prompt.jsonl的具体内容如下,每条记录对应一个样本对。

{"source": "数据集文件夹内的控制图像相对路径", "target": "数据集文件夹内的目标图像相对路径", "prompt": "{所需要生成的目标图像的描述提示}"} {"source": "source/0_s.png", "target": "target/_t.png", "prompt": "a remote sensing image of earth. Containing Closed Shrublands,Open Shrublands,Savannas,Grasslands,Permanent Wetlands,Croplands,Urban and Built-up Lands,Barren,Water Bodies,"}

最后一个文件是和dataset文件夹同名的dataset_name.py,其主要作用是重写controlnet内的数据读取接口。

python
--dataset_name.py: import pandas as pd import datasets import os _VERSION = datasets.Version("0.0.1") _DESCRIPTION = "TODO" _HOMEPAGE = "TODO" _LICENSE = "TODO" _CITATION = "TODO" _FEATURES = datasets.Features( { "target": datasets.Image(),#!!首先需要改的是这三行,用来指定你的prompt中三个列分别对应的类型 "source": datasets.Image(),#!!首先需要改的是这三行,用来指定你的prompt中三个列分别对应的类型 "prompt": datasets.Value("string"),#!!首先需要改的是这三行,用来指定你的prompt中三个列分别对应的类型 }, ) _DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION) class MyData(datasets.GeneratorBasedBuilder): BUILDER_CONFIGS = [_DEFAULT_CONFIG] DEFAULT_CONFIG_NAME = "default" def _info(self): return datasets.DatasetInfo( description=_DESCRIPTION, features=_FEATURES, supervised_keys=None, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION, ) def _split_generators(self, dl_manager): #!其次是下面三个路径,分别代表: # metadata_path -> prompt.jsonl的路径 # images_dir -> target目标图像的路径,因为我们在prompt里已经记录了target/0.png, 所以这里只需要写上一级的dataset_name文件夹的路径 # conditioning_images_dir -> source控制图像的路径,因为我们在prompt里已经记录了source/0.png, 所以这里只需要写上一级的dataset_name文件夹的路径 metadata_path = "/home/disk/hzy/repos/diffusers/examples/controlnet/gpcvrs_small/small.jsonl" images_dir = "/home/disk/hzy/repos/diffusers/examples/controlnet/gpcvrs_small" conditioning_images_dir = "/home/disk/hzy/repos/diffusers/examples/controlnet/gpcvrs_small" dl_manager.download_and_extract return [ datasets.SplitGenerator( name=datasets.Split.TRAIN, # These kwargs will be passed to _generate_examples gen_kwargs={ "metadata_path": metadata_path, "images_dir": images_dir, "conditioning_images_dir": conditioning_images_dir, }, ), ] def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir): metadata = pd.read_json(metadata_path, lines=True) for _, row in metadata.iterrows(): text = row["prompt"]#!这里改为我们的提示词列名称,即"prompt" image_path = row["target"]#!这里改为我们的目标图像列名称,即"target" image_path = os.path.join(images_dir, image_path) image = open(image_path, "rb").read() conditioning_image_path = row["source"]#!这里改为我们的控制图像列名称,即"source" conditioning_image_path = os.path.join( conditioning_images_dir, row["source"]#!这里改为我们的控制图像列名称,即"source" ) conditioning_image = open(conditioning_image_path, "rb").read() #最后迭代返回的是两个元素: # 1、该样本对目标文件路径(此处即为row["target"]) # 2、一个包含下列元素的字典,按照描述修改即可,返回的字典的键名对应开头(line45~47)指定的三个名字: # "prompt":prompt_str, # "target":{ # "path": image_path,(目标图像路径) # "bytes": image,(目标图像路径) # } # "source":{ # "path": conditioning_image_path,(控制图像文件路径) # "bytes": conditioning_image,(控制图像本身) # } yield row["target"], { "prompt": text, "target": { "path": image_path, "bytes": image, }, "source": { "path": conditioning_image_path, "bytes": conditioning_image, }, }

训练

在完成dataset_name.py的修改后,只需要修改huggingface 给出的controlnet的训练脚本即可:

bash
#训练脚本train.sh修改示例: export MODEL_DIR="runwayml/stable-diffusion-v1-5" export OUTPUT_DIR="{输出的模型路径}" accelerate launch --multi_gpu train_controlnet.py \ --mixed_precision="fp16" \ --pretrained_model_name_or_path=$MODEL_DIR \ --output_dir=$OUTPUT_DIR \ --train_data_dir="{你的dataset文件路径}" \ --resolution=512 \ --learning_rate=1e-5 \ --train_batch_size=4 \ --num_train_epochs=6 \ --caption_column="prompt" \ #这里修改caption_column为提示词列名称prompt --image_column="target" \ #这里修改image_column为目标图像列名称target --conditioning_image_column="source" \ #这里修改conditioning_image_column为目标图像列名称source

在diffuser的example/controlnet目录下运行train.sh脚本可以直接开始训练。

本文作者:insomnia

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!