Spaces:
Build error
Build error
Commit
·
973977c
1
Parent(s):
8e2f608
update
Browse files
- app.py +12 -8
- ldm/models/diffusion/sync_dreamer.py +3 -7
app.py
CHANGED
|
@@ -8,6 +8,7 @@ import os
|
|
| 8 |
import fire
|
| 9 |
from omegaconf import OmegaConf
|
| 10 |
|
|
|
|
| 11 |
from ldm.util import add_margin, instantiate_from_config
|
| 12 |
from sam_utils import sam_init, sam_out_nosave
|
| 13 |
|
|
@@ -19,12 +20,12 @@ _DESCRIPTION = '''
|
|
| 19 |
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
|
| 20 |
<a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
|
| 21 |
</div>
|
| 22 |
-
Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss
|
| 23 |
|
| 24 |
-
Procedure:
|
| 25 |
-
**Step 0**. Upload an image or select an example. ==> The foreground is masked out by SAM.
|
| 26 |
-
**Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized.
|
| 27 |
-
**Step 2**. Select "Elevation angle "and click "Run generation". ==> Generate multiview images. (This costs about 2 min.)
|
| 28 |
To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
|
| 29 |
'''
|
| 30 |
_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example shown in the left)."
|
|
@@ -74,8 +75,9 @@ def resize_inputs(image_input, crop_size):
|
|
| 74 |
results = add_margin(ref_img_, size=256)
|
| 75 |
return results
|
| 76 |
|
| 77 |
-
def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
|
| 78 |
if deployed:
|
|
|
|
| 79 |
seed=int(seed)
|
| 80 |
torch.random.manual_seed(seed)
|
| 81 |
np.random.seed(seed)
|
|
@@ -97,7 +99,8 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
|
|
| 97 |
data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
|
| 98 |
|
| 99 |
if deployed:
|
| 100 |
-
|
|
|
|
| 101 |
else:
|
| 102 |
x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
|
| 103 |
|
|
@@ -219,6 +222,7 @@ def run_demo():
|
|
| 219 |
with gr.Accordion('Advanced options', open=False):
|
| 220 |
cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
|
| 221 |
sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
|
|
|
|
| 222 |
batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
|
| 223 |
seed = gr.Number(6033, label='Random seed', interactive=True)
|
| 224 |
run_btn = gr.Button('Run generation', variant='primary', interactive=True)
|
|
@@ -235,7 +239,7 @@ def run_demo():
|
|
| 235 |
crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
|
| 236 |
.success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
|
| 237 |
|
| 238 |
-
run_btn.click(partial(generate, model), inputs=[batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
|
| 239 |
.success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
|
| 240 |
|
| 241 |
demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
|
|
|
|
| 8 |
import fire
|
| 9 |
from omegaconf import OmegaConf
|
| 10 |
|
| 11 |
+
from ldm.models.diffusion.sync_dreamer import SyncDDIMSampler, SyncMultiviewDiffusion
|
| 12 |
from ldm.util import add_margin, instantiate_from_config
|
| 13 |
from sam_utils import sam_init, sam_out_nosave
|
| 14 |
|
|
|
|
| 20 |
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
|
| 21 |
<a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
|
| 22 |
</div>
|
| 23 |
+
Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss </br>
|
| 24 |
|
| 25 |
+
Procedure: </br>
|
| 26 |
+
**Step 0**. Upload an image or select an example. ==> The foreground is masked out by SAM. </br>
|
| 27 |
+
**Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized. </br>
|
| 28 |
+
**Step 2**. Select "Elevation angle" and click "Run generation". ==> Generate multiview images. (This costs about 2 min.) </br>
|
| 29 |
To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
|
| 30 |
'''
|
| 31 |
_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example shown in the left)."
|
|
|
|
| 75 |
results = add_margin(ref_img_, size=256)
|
| 76 |
return results
|
| 77 |
|
| 78 |
+
def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
|
| 79 |
if deployed:
|
| 80 |
+
assert isinstance(model, SyncMultiviewDiffusion)
|
| 81 |
seed=int(seed)
|
| 82 |
torch.random.manual_seed(seed)
|
| 83 |
np.random.seed(seed)
|
|
|
|
| 99 |
data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
|
| 100 |
|
| 101 |
if deployed:
|
| 102 |
+
sampler = SyncDDIMSampler(model, sample_steps)
|
| 103 |
+
x_sample = model.sample(sampler, data, cfg_scale, batch_view_num)
|
| 104 |
else:
|
| 105 |
x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
|
| 106 |
|
|
|
|
| 222 |
with gr.Accordion('Advanced options', open=False):
|
| 223 |
cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
|
| 224 |
sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
|
| 225 |
+
sample_steps = gr.Slider(40, 400, 200, step=10, label='Sample steps', interactive=True)
|
| 226 |
batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
|
| 227 |
seed = gr.Number(6033, label='Random seed', interactive=True)
|
| 228 |
run_btn = gr.Button('Run generation', variant='primary', interactive=True)
|
|
|
|
| 239 |
crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
|
| 240 |
.success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
|
| 241 |
|
| 242 |
+
run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
|
| 243 |
.success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
|
| 244 |
|
| 245 |
demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
|
ldm/models/diffusion/sync_dreamer.py
CHANGED
|
@@ -468,13 +468,9 @@ class SyncMultiviewDiffusion(pl.LightningModule):
|
|
| 468 |
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
|
| 469 |
return x_noisy, noise
|
| 470 |
|
| 471 |
-
def sample(self, batch, cfg_scale, batch_view_num,
|
| 472 |
-
return_inter_results=False, inter_interval=50, inter_view_interval=2):
|
| 473 |
_, clip_embed, input_info = self.prepare(batch)
|
| 474 |
-
|
| 475 |
-
x_sample, inter = self.ddim.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
|
| 476 |
-
else:
|
| 477 |
-
raise NotImplementedError
|
| 478 |
|
| 479 |
N = x_sample.shape[1]
|
| 480 |
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
|
|
@@ -540,7 +536,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
|
|
| 540 |
return [opt], scheduler
|
| 541 |
|
| 542 |
class SyncDDIMSampler:
|
| 543 |
-
def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0
|
| 544 |
super().__init__()
|
| 545 |
self.model = model
|
| 546 |
self.ddpm_num_timesteps = model.num_timesteps
|
|
|
|
| 468 |
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
|
| 469 |
return x_noisy, noise
|
| 470 |
|
| 471 |
+
def sample(self, sampler, batch, cfg_scale, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2):
|
|
|
|
| 472 |
_, clip_embed, input_info = self.prepare(batch)
|
| 473 |
+
x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
N = x_sample.shape[1]
|
| 476 |
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
|
|
|
|
| 536 |
return [opt], scheduler
|
| 537 |
|
| 538 |
class SyncDDIMSampler:
|
| 539 |
+
def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=1.0, latent_size=32):
|
| 540 |
super().__init__()
|
| 541 |
self.model = model
|
| 542 |
self.ddpm_num_timesteps = model.num_timesteps
|