-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate.py
More file actions
250 lines (196 loc) · 7.94 KB
/
generate.py
File metadata and controls
250 lines (196 loc) · 7.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import json
from typing import List
from pathlib import Path
from argparse import ArgumentParser
from tqdm import tqdm
import torch
from transformers import CLIPTextModel
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from utils import fix_seed, flush, chunks
from nb_utils.eval_sets import evaluation_sets
def _abs_path(value):
    """argparse ``type`` converter: turn a CLI string into an absolute, resolved Path."""
    return Path(value).absolute().resolve()


def _build_parser():
    """Declare every command-line option of this script and return the parser."""
    parser = ArgumentParser()
    parser.add_argument(
        '--pretrained_model_or_path',
        default='stabilityai/stable-diffusion-2-base',
        help='hf model or path to pipeline',
    )
    parser.add_argument(
        '--exp_dir',
        type=_abs_path,
        help='path to experiment directory with all checkpoints',
    )
    parser.add_argument(
        '--ckpt_path',
        type=_abs_path,
        help='path to specific checkpoint dir',
    )
    parser.add_argument(
        '--out_dir',
        type=_abs_path,
        help='path to where the images will be saved, default is inside the ckpt_path/samples',
    )
    parser.add_argument(
        '--concept',
        required=True,
        help='placeholder token (e.g. "sks dog") for concept',
    )
    parser.add_argument(
        '--prompt_source',
        default='eval_set',
        choices=['eval_set', 'json', 'user'],
        help='which prompts to use: nb_utils/eval_set; json list; user input',
    )
    parser.add_argument(
        '--eval_set',
        choices=evaluation_sets.keys(),
        help='[prompt_source==eval_set] prompt evaluation set',
    )
    parser.add_argument(
        '--prompts_json_path',
        type=_abs_path,
        help='[prompt_source==json] path to json with a list of prompts',
    )
    parser.add_argument(
        '--prompt_template',
        help='[prompt_source==user] prompt template (see eval_sets for format)',
    )
    parser.add_argument(
        '--samples_per_prompt',
        type=int,
        default=16,
        help='how many samples to generate per prompt',
    )
    parser.add_argument(
        '--prompts_per_batch',
        type=int,
        default=4,
        help='how many prompts will be processed in a single batch',
    )
    parser.add_argument(
        '--guidance_scale',
        type=float,
        default=7.5,
        help='guidance scale passed to the diffusion pipeline',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='seed for reproducible runs',
    )
    parser.add_argument(
        '--version',
        type=int,
        default=0,
        help='internal version, used in path',
    )
    parser.add_argument(
        '--override',
        action='store_true',
        help='regenerate samples even if already present',
    )
    return parser


def _validate_args(parser, args):
    """Enforce cross-option constraints argparse cannot express.

    Failures are reported through ``parser.error`` (usage message + exit code 2)
    rather than ``assert``, which would be stripped under ``python -O``.
    """
    # Exactly one checkpoint source: a whole experiment or a single checkpoint.
    if (args.exp_dir is None) == (args.ckpt_path is None):
        parser.error("exactly one of `exp_dir` and `ckpt_path` should be provided")
    if args.prompt_source == 'eval_set':
        if args.eval_set is None:
            parser.error("`eval_set` should be provided if using [prompt_source==eval_set]")
    elif args.prompt_source == 'json':
        if args.prompts_json_path is None:
            parser.error("`prompts_json_path` should be provided if using [prompt_source==json]")
    elif args.prompt_source == 'user':
        if args.prompt_template is None:
            parser.error("`prompt_template` should be provided if using [prompt_source==user]")
    else:
        # Unreachable while `choices` restricts --prompt_source; kept as a guard.
        raise NotImplementedError("unknown prompt source")
    if args.exp_dir is not None and args.out_dir is not None:
        parser.error("specifying `out_dir` for experiment is unsupported")
    # Zero/negative values would skip every prompt or run empty batches.
    if args.samples_per_prompt < 1:
        parser.error("`samples_per_prompt` must be >= 1")
    if args.prompts_per_batch < 1:
        parser.error("`prompts_per_batch` must be >= 1")


def parse_args():
    """Parse and validate the command line, then report the effective batch size.

    Returns:
        The ``argparse.Namespace`` with path options already converted to
        absolute, resolved ``Path`` objects.

    Exits via ``ArgumentParser.error`` when option combinations are invalid.
    """
    parser = _build_parser()
    args = parser.parse_args()
    _validate_args(parser, args)
    print(f"Effective batch_size: {args.prompts_per_batch * args.samples_per_prompt} (= {args.prompts_per_batch} prompts/batch x {args.samples_per_prompt} samples/prompt)")
    return args
def get_prompts(args):
    """Build the concrete prompt list by filling each template with the concept.

    Templates are Python format strings; ``str.format(args.concept)`` fills their
    positional field (e.g. ``"a photo of {}"``).

    Sources, selected by ``args.prompt_source``:
      * ``eval_set`` -- templates from ``evaluation_sets[args.eval_set]``;
      * ``json``     -- a JSON list whose items are either template strings or
                        objects with a ``"prompt"`` key;
      * ``user``     -- the single template ``args.prompt_template``.
    Any other source yields no prompts.

    Returns:
        The list of formatted prompts.
    """
    if args.prompt_source == 'eval_set':
        templates = evaluation_sets[args.eval_set]
    elif args.prompt_source == 'json':
        with open(args.prompts_json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Accept a plain list of strings (as the CLI help describes) as well as
        # the list of {"prompt": ...} records this function originally required.
        templates = [sample if isinstance(sample, str) else sample['prompt'] for sample in data]
    elif args.prompt_source == 'user':
        templates = [args.prompt_template]
    else:
        templates = []
    prompts = [tmpl.format(args.concept) for tmpl in templates]
    print(f"Total of {len(prompts)} prompts")
    return prompts
def load_pipeline(ckpt_dir: Path, pretrained_model_or_path=None):
    """Build a float16 Stable Diffusion pipeline on CUDA for one checkpoint.

    Sub-models saved inside ``ckpt_dir`` (``unet/`` and/or ``text_encoder/``)
    replace the corresponding components of the base model. If a file named
    ``learned_embeds-steps-<step>.safetensors`` sits next to ``ckpt_dir`` it is
    loaded as a textual-inversion embedding.

    Args:
        ckpt_dir: checkpoint directory. ``<step>`` is taken from the text after
            the last ``-`` in its name (e.g. ``checkpoint-500`` -> ``500``).
        pretrained_model_or_path: base model id or local path. When omitted,
            falls back to ``args.pretrained_model_or_path`` from the module-level
            ``args`` bound by the ``__main__`` block (the original behavior).

    Returns:
        The pipeline, moved to ``'cuda'`` and with its UNet wrapped by
        ``torch.compile``.
    """
    if pretrained_model_or_path is None:
        # Backward-compatible fallback to the script-level global.
        pretrained_model_or_path = args.pretrained_model_or_path
    pipeline_kwargs = {
        'torch_dtype': torch.float16,
        'requires_safety_checker': False,
    }
    unet_dir = ckpt_dir / 'unet'
    if unet_dir.exists():
        pipeline_kwargs['unet'] = UNet2DConditionModel.from_pretrained(unet_dir, torch_dtype=torch.float16)
    text_encoder_dir = ckpt_dir / 'text_encoder'
    if text_encoder_dir.exists():
        pipeline_kwargs['text_encoder'] = CLIPTextModel.from_pretrained(text_encoder_dir, torch_dtype=torch.float16)
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_or_path,
        **pipeline_kwargs,
    ).to('cuda')
    step = ckpt_dir.name.split('-')[-1]
    embeds_path = ckpt_dir.parent / f"learned_embeds-steps-{step}.safetensors"
    if embeds_path.exists():
        pipeline.load_textual_inversion(embeds_path)
    pipeline.unet.set_attn_processor(AttnProcessor2_0())
    pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
    return pipeline
def str_inference_specs(num_steps: int, guidance_scale: float):
    """Encode inference settings as a compact tag, e.g. ``ns50_gs7.5``.

    The guidance scale is always rendered with exactly one decimal place.
    """
    return "ns{}_gs{:.1f}".format(num_steps, guidance_scale)
def split_prompts_into_batches(prompts, samples_dir, args):
    """Group the prompts that still need samples into batches.

    Every prompt gets its own directory ``samples_dir / prompt`` (created if
    missing). Unless ``args.override`` is set, a prompt whose directory already
    holds at least ``args.samples_per_prompt`` entries is skipped.

    Returns:
        A list of lists of ``args.prompts_per_batch`` prompts each (the last may
        be shorter); a non-positive batch size yields one prompt per batch.
    """
    pending = []
    for prompt in prompts:
        prompt_dir = samples_dir / f"{prompt}"
        prompt_dir.mkdir(parents=True, exist_ok=True)
        if not args.override and sum(1 for _ in prompt_dir.iterdir()) >= args.samples_per_prompt:
            print(f"Prompt `{prompt}` already generated, skipping")
            continue
        pending.append(prompt)
    size = max(1, args.prompts_per_batch)
    return [pending[start:start + size] for start in range(0, len(pending), size)]
def generate_batch(pipeline: StableDiffusionPipeline, prompts: List[str], args):
    """Run one pipeline call covering ``args.samples_per_prompt`` samples per prompt.

    Each prompt is repeated consecutively; sample ``k`` of every prompt is drawn
    with a fresh ``torch.Generator`` seeded to ``args.seed + k``, so the same
    sample index shares a seed across prompts.

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    per_prompt = args.samples_per_prompt
    repeated_prompts = []
    seeded_generators = []
    for prompt in prompts:
        for offset in range(per_prompt):
            repeated_prompts.append(prompt)
            seeded_generators.append(torch.Generator().manual_seed(args.seed + offset))
    output = pipeline(repeated_prompts, guidance_scale=args.guidance_scale, generator=seeded_generators)
    return output.images
def generate_and_save(pipeline: StableDiffusionPipeline, prompts: List[str], samples_dir: Path, args):
    """Generate ``args.samples_per_prompt`` images per prompt and save them as PNGs.

    Images for a prompt are saved as ``samples_dir / prompt / f"{i}.png"``,
    with ``i`` counting from 0 within that prompt.

    Args:
        pipeline: diffusion pipeline forwarded to ``generate_batch``.
        prompts: the prompts of one batch.
        samples_dir: root directory holding one sub-directory per prompt.
        args: parsed CLI namespace (``samples_per_prompt`` plus the settings
            ``generate_batch`` reads).
    """
    images = generate_batch(pipeline, prompts, args)
    # generate_batch lays out samples_per_prompt images per prompt, prompt by prompt.
    assert len(images) == len(prompts) * args.samples_per_prompt
    # NOTE(review): assumes utils.chunks yields consecutive groups of the given size.
    for prompt, images_for_one_prompt in zip(prompts, chunks(images, args.samples_per_prompt), strict=True):
        out_dir = samples_dir / prompt
        # Ensure the directory exists even when this is called without
        # split_prompts_into_batches having created it first.
        out_dir.mkdir(parents=True, exist_ok=True)
        for i, img in enumerate(images_for_one_prompt):
            img.save(out_dir / f'{i}.png')
def process_checkpoint(ckpt_dir: Path, prompts: List[str], args):
    """Generate all still-missing samples for a single checkpoint directory.

    Samples go to ``args.out_dir`` when provided, otherwise to
    ``<ckpt_dir>/samples/ns50_gs<guidance>/version_<version>``.
    """
    print(f'Processing {ckpt_dir}')
    fix_seed(args.seed)
    pipeline = load_pipeline(ckpt_dir)

    if args.out_dir:
        samples_dir = args.out_dir
    else:
        # NOTE(review): 50 is presumably the pipeline's default number of
        # inference steps, since generate_batch does not pass one -- confirm.
        specs_tag = str_inference_specs(50, args.guidance_scale)
        samples_dir = ckpt_dir / 'samples' / specs_tag / f'version_{args.version}'

    batches = split_prompts_into_batches(prompts, samples_dir, args)
    for batch in tqdm(batches, desc='batches'):
        generate_and_save(pipeline, batch, samples_dir, args)
def _checkpoint_sort_key(ckpt_dir: Path):
    """Sort key ordering ``checkpoint-<step>`` dirs by numeric step, then by name."""
    step = ckpt_dir.name.rsplit('-', 1)[-1]
    if step.isdecimal():
        return (0, int(step), ckpt_dir.name)
    # Non-numeric suffixes go last, in name order.
    return (1, 0, ckpt_dir.name)


def generate_all(prompts: List[str], args):
    """Generate samples for every selected checkpoint, freeing memory between them.

    With ``args.exp_dir`` set, every ``checkpoint-*`` sub-directory is processed
    in ascending step order; otherwise only ``args.ckpt_path`` is processed.
    """
    if args.exp_dir is not None:
        # glob() order depends on the filesystem, and a plain string sort would
        # put checkpoint-1000 before checkpoint-500, so sort by step number.
        # Files matching the pattern are skipped: load_pipeline expects a dir.
        checkpoints = sorted(
            (path for path in Path(args.exp_dir).glob('checkpoint-*') if path.is_dir()),
            key=_checkpoint_sort_key,
        )
    else:
        checkpoints = [args.ckpt_path]
    print(f"Total of {len(checkpoints)} checkpoints")
    for ckpt_dir in checkpoints:
        process_checkpoint(ckpt_dir, prompts, args)
        flush()
def main(args):
    """Resolve the prompts, render them for every selected checkpoint, then report completion."""
    generate_all(get_prompts(args), args)
    print('Done')
if __name__ == "__main__":
    # `args` is intentionally bound at module level: load_pipeline reads
    # `args.pretrained_model_or_path` from this global, so keep the name.
    args = parse_args()
    main(args)