-
Notifications
You must be signed in to change notification settings - Fork 4
/
sample_diffusion_ldm.py
567 lines (506 loc) · 19.9 KB
/
sample_diffusion_ldm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
import argparse, os, sys, glob, datetime, yaml
from ldm.models.diffusion.plms import PLMSSampler
import math
from quant.reconstruction_util import RLOSS
from quant.data_generate import generate_cali_data_ldm
from quant.calibration import cali_model, cali_model_multi, load_cali_model
from quant.quant_model import QuantModel
from quant.quant_layer import QMODE, Scaler
import logging
import torch
import time
import numpy as np
from tqdm import trange
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from PIL import Image
import sys
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from ldm.util import instantiate_from_config
import torch.multiprocessing as mp
logger = logging.getLogger(__name__)
rescale = lambda x: (x + 1.) / 2.
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.) / 2.
x = x.permute(1, 2, 0).numpy()
x = (255 * x).astype(np.uint8)
x = Image.fromarray(x)
if not x.mode == "RGB":
x = x.convert("RGB")
return x
def custom_to_np(x):
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
sample = x.detach().cpu()
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1)
sample = sample.contiguous()
return sample
def logs2pil(logs, keys=["sample"]):
imgs = dict()
for k in logs:
try:
if len(logs[k].shape) == 4:
img = custom_to_pil(logs[k][0, ...])
elif len(logs[k].shape) == 3:
img = custom_to_pil(logs[k])
else:
print(f"Unknown format for key {k}. ")
img = None
except:
img = None
imgs[k] = img
return imgs
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
verbose=True,
make_prog_row=False):
if not make_prog_row:
return model.p_sample_loop(None, shape,
return_intermediates=return_intermediates, verbose=verbose)
else:
return model.progressive_denoising(
None, shape, verbose=True
)
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0, use_correction=False
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
if not use_correction:
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False)
else:
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,
correct=True, slope=torch.load('./data/slope.pt'), bias=torch.load('./data/bias.pt'),
residual_error=torch.load('./data/residual_error.pt'))
return samples, intermediates
@torch.no_grad()
def convsample_dpm(model, steps, shape, eta=1.0
):
dpm = DPMSolverSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = dpm.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
return samples, intermediates
@torch.no_grad()
def convsample_plms(model, steps, shape, eta=1.0):
plms = PLMSSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = plms.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0, dpm=False, use_correction=False, plms=False):
log = dict()
shape = [batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size]
# with model.ema_scope("Plotting"):
t0 = time.time()
if vanilla:
sample, progrow = convsample(model, shape,
make_prog_row=True)
elif dpm:
logger.info(f'Using DPM sampling with {custom_steps} sampling steps and eta={eta}')
sample, intermediates = convsample_dpm(model, steps=custom_steps, shape=shape,
eta=eta)
elif plms:
logger.info(f'Using PLMS sampling with {custom_steps} sampling steps and eta={eta}')
sample, intermediates = convsample_plms(model, steps=custom_steps, shape=shape,
eta=eta)
else:
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
eta=eta, use_correction=use_correction)
t1 = time.time()
x_sample = model.decode_first_stage(sample)
log["sample"] = x_sample
log["time"] = t1 - t0
log['throughput'] = sample.shape[0] / (t1 - t0)
logger.info(f'Throughput for this batch: {log["throughput"]}')
return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None,
n_samples=50000, nplog=None, dpm=False, use_correction=False, plms=False):
if vanilla:
logger.info(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
else:
logger.info(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
tstart = time.time()
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
# path = logdir
if model.cond_stage_model is None:
all_images = []
logger.info(f"Running unconditional sampling for {n_samples} samples")
for _ in trange(math.ceil(n_samples / batch_size), desc="Sampling Batches (unconditional)"):
if (_ + 1) * batch_size > n_samples:
assert _ == math.ceil(n_samples / batch_size) - 1
batch_size = n_samples - _ * batch_size
logs = make_convolutional_sample(model, batch_size=batch_size,
vanilla=vanilla, custom_steps=custom_steps,
eta=eta, dpm=dpm, use_correction=use_correction, plms=plms)
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
all_images.extend([custom_to_np(logs["sample"])])
if n_saved >= n_samples:
logger.info(f'Finish after generating {n_saved} samples')
break
all_img = np.concatenate(all_images, axis=0)
all_img = all_img[:n_samples]
shape_str = "x".join([str(x) for x in all_img.shape])
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
np.savez(nppath, all_img)
else:
raise NotImplementedError('Currently only sampling for unconditional models supported.')
logger.info(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
for k in logs:
if k == key:
batch = logs[key]
if np_path is None:
for x in batch:
img = custom_to_pil(x)
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
img.save(imgpath)
n_saved += 1
else:
npbatch = custom_to_np(batch)
shape_str = "x".join([str(x) for x in npbatch.shape])
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
np.savez(nppath, npbatch)
n_saved += npbatch.shape[0]
return n_saved
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-n",
"--n_samples",
type=int,
nargs="?",
help="number of samples to draw",
default=50000
)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action='store_true',
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument(
"--seed",
type=int,
# default=42,
required=True,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
nargs="?",
help="extra logdir",
default="none"
)
parser.add_argument(
"-c",
"--custom_steps",
type=int,
nargs="?",
help="number of steps for ddim and fastdpm sampling",
default=50
)
parser.add_argument(
"--plms",
action='store_true',
help="use plms sampler",
)
parser.add_argument(
"--batch_size",
type=int,
nargs="?",
help="the bs",
default=10
)
# linear quantization configs
parser.add_argument(
"--ptq", action="store_true", help="apply post-training quantization"
)
parser.add_argument(
"--wq",
type=int,
default=8,
help="int bit for weight quantization",
)
parser.add_argument(
"--aq",
type=int,
default=8,
help="int bit for activation quantization",
)
# qdiff specific configs
parser.add_argument(
"--cali_ckpt", type=str,
help="path for calibrated model ckpt"
)
parser.add_argument(
"--softmax_a_bit",type=int, default=8,
help="attn softmax activation bit"
)
parser.add_argument(
"--dpm", action="store_true",
help="use dpm solver for sampling"
)
parser.add_argument(
"--verbose", action="store_true",
help="print out info like quantized model arch"
)
parser.add_argument(
"--cali",
action="store_true",
help="whether to calibrate the model"
)
parser.add_argument(
"--cali_save_path",
type=str,
default="cali_ckpt/quant_ddim.pth",
help="path to save the calibrated ckpt"
)
parser.add_argument(
"--interval_length",
type=int,
default=1,
help="calibration interval length"
)
parser.add_argument(
'--use_aq',
action='store_true',
help='whether to use activation quantization'
)
# multi-gpu configs
parser.add_argument('--multi_gpu', action='store_true', help='use multiple gpus')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:3367', type=str, help='')
parser.add_argument('--dist-backend', default='nccl', type=str, help='')
parser.add_argument('--rank', default=0, type=int, help='')
parser.add_argument('--world_size', default=1, type=int, help='')
return parser
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False)
model.cuda()
model.eval()
return model
def load_model(config, ckpt, gpu, eval_mode):
if ckpt:
logger.info(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
pl_sd["state_dict"])
return model, global_step
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
sys.path.append(os.getcwd())
command = " ".join(sys.argv)
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
# fix random seed
seed_everything(opt.seed)
ngpus_per_node = torch.cuda.device_count()
opt.world_size = ngpus_per_node * opt.world_size
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
# paths = opt.resume.split("/")
try:
logdir = '/'.join(opt.resume.split('/')[:-1])
# idx = len(paths)-paths[::-1].index("logs")+1
print(f'Logdir is {logdir}')
except ValueError:
paths = opt.resume.split("/")
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "model.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
opt.base = base_configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
gpu = True
eval_mode = True
if opt.logdir != "none":
locallog = logdir.split(os.sep)[-1]
if locallog == "": locallog = logdir.split(os.sep)[-2]
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
logdir = os.path.join(opt.logdir, locallog)
logdir = os.path.join(logdir, "samples", now)
os.makedirs(logdir)
log_path = os.path.join(logdir, "run.log")
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO,
handlers=[
logging.FileHandler(log_path),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
print(config)
logger.info(75 * "=")
logger.info(f"Host {os.uname()[1]}")
logger.info("logging to:")
imglogdir = os.path.join(logdir, "img")
numpylogdir = os.path.join(logdir, "numpy")
os.makedirs(imglogdir)
os.makedirs(numpylogdir)
logger.info(logdir)
logger.info(75 * "=")
model, global_step = load_model(config, ckpt, gpu, eval_mode)
logger.info(f"global step: {global_step}")
logger.info("Switched to EMA weights")
model.model_ema.store(model.model.parameters())
model.model_ema.copy_to(model.model)
p = [QMODE.NORMAL.value]
p.append(QMODE.QDIFF.value)
opt.q_mode = p
opt.asym = True
opt.running_stat = True
wq_params = {"bits": opt.wq,
"channel_wise": True,
"scaler": Scaler.MSE if opt.cali else Scaler.MINMAX}
aq_params = {"bits": opt.aq,
"channel_wise": False,
"scaler": Scaler.MSE if opt.cali else Scaler.MINMAX,
"leaf_param": opt.use_aq}
if opt.ptq:
if not opt.cali:
setattr(model.model.diffusion_model, "split", True)
qnn = QuantModel(model=model.model.diffusion_model,
wq_params=wq_params,
aq_params=aq_params,
cali=False,
softmax_a_bit=opt.softmax_a_bit,
aq_mode=opt.q_mode)
qnn.to('cuda')
qnn.eval()
image_size = config.model.params.image_size
channels = config.model.params.channels
cali_data = (torch.randn(1, channels, image_size, image_size), torch.randint(0, 1000, (1,)))
load_cali_model(qnn, cali_data, use_aq=opt.use_aq, path=opt.cali_ckpt)
model.model.diffusion_model = qnn
if opt.use_aq:
cali_ckpt = torch.load(opt.cali_ckpt)
tot = len(list(cali_ckpt.keys())) - 1
model.model.tot = 1000 // tot
model.model.t_max = tot - 1
model.model.ckpt = cali_ckpt
model.model.iter = 0
else:
logger.info("Generating calibration data...")
shape = [config.model.params.channels, config.model.params.image_size, config.model.params.image_size]
cali_data = generate_cali_data_ldm(model=model,
T=opt.custom_steps,
c=1,
batch_size=256,
shape=shape,
vanilla=opt.vanilla_sample,
dpm=opt.dpm,
plms=opt.plms,
eta=opt.eta)
a_cali_data = cali_data
tmp = []
for i in range(0, opt.custom_steps, opt.interval_length):
tmp.append([cali_data[0][i * 256: (i + 1) * 256], cali_data[1][i * 256: (i + 1) * 256]])
w_cali_data = [torch.cat([x[0] for x in tmp]), torch.cat([x[1] for x in tmp])]
logger.info("Calibration data generated.")
torch.cuda.empty_cache()
setattr(model.model.diffusion_model, "split", True)
if opt.multi_gpu:
kwargs = dict(iters=20000,
batch_size=32,
w=0.01,
asym=opt.asym,
warmup=0.2,
opt_mode=RLOSS.MSE,
wq_params=wq_params,
aq_params=aq_params,
softmax_a_bit=opt.softmax_a_bit,
aq_mode=opt.q_mode,
multi_gpu=ngpus_per_node > 1)
mp.spawn(cali_model_multi, args=(opt.dist_backend,
opt.world_size,
opt.dist_url,
opt.rank,
ngpus_per_node,
model.model,
opt.use_aq,
opt.cali_save_path,
w_cali_data,
a_cali_data,
256,
opt.running_stat,
kwargs), nprocs=ngpus_per_node)
else:
qnn = QuantModel(model=model.model.diffusion_model,
wq_params=wq_params,
aq_params=aq_params,
softmax_a_bit=opt.softmax_a_bit,
aq_mode=opt.q_mode)
kwargs = dict(w_cali_data=w_cali_data,
a_cali_data=a_cali_data,
iters=20000,
batch_size=32,
w=0.01,
asym=opt.asym,
warmup=0.2,
opt_mode=RLOSS.MSE,
multi_gpu=False)
qnn.to('cuda')
qnn.eval()
cali_model(qnn=qnn,
use_aq=opt.use_aq,
path=opt.cali_save_path,
running_stat=opt.running_stat,
interval=256,
**kwargs)
exit(0)
# write config out
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)
with open(sampling_file, 'a+') as f:
yaml.dump(sampling_conf, f, default_flow_style=False)
if opt.verbose:
print(sampling_conf)
logger.info("first_stage_model")
logger.info(model.first_stage_model)
logger.info("UNet model")
logger.info(model.model)
run(model, imglogdir, eta=opt.eta,
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
batch_size=opt.batch_size, nplog=numpylogdir, dpm=opt.dpm, use_correction=QMODE.PTQD.value in opt.q_mode, plms=opt.plms)
logger.info("done.")