---
license: apache-2.0
tags:
- yahoo-open-source-software-incubator
pipeline_tag: text-to-image
inference: false
---
# Salient Object-Aware Background Generation using Text-Guided Diffusion Models [![Paper](assets/arxiv.svg)](https://arxiv.org/pdf/2404.10157.pdf)
This repository accompanies our paper, [Salient Object-Aware Background Generation using Text-Guided Diffusion Models](https://arxiv.org/abs/2404.10157), which was accepted at the [CVPR 2024 Generative Models for Computer Vision](https://generative-vision.github.io/workshop-CVPR-24/) workshop.

The paper addresses an issue we call "object expansion", which arises when inpainting diffusion models are used to generate backgrounds for salient objects. We show that models such as [Stable Inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) can arbitrarily expand or distort the salient object. This is undesirable in applications where the object's identity must be preserved, such as e-commerce ads. The figure below shows examples of object expansion:

<div align="center">
  <img src="assets/fig.jpg">
</div>

# Inference

### Load pipeline
```py
from diffusers import DiffusionPipeline
model_id = "yahoo-inc/photo-background-generation"
pipeline = DiffusionPipeline.from_pretrained(model_id, custom_pipeline=model_id)
pipeline = pipeline.to('cuda')
```

### Load an image and extract its background and foreground
```py
from PIL import Image, ImageOps
import requests
from io import BytesIO
from transparent_background import Remover

def resize_with_padding(img, expected_size):
    # Shrink in place to fit inside expected_size, preserving the aspect ratio
    img.thumbnail((expected_size[0], expected_size[1]))
    delta_width = expected_size[0] - img.size[0]
    delta_height = expected_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    # Border order is (left, top, right, bottom)
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(img, padding)

image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/1/16/Granja_comary_Cisne_-_Escalavrado_e_Dedo_De_Deus_ao_fundo_-Teres%C3%B3polis.jpg/2560px-Granja_comary_Cisne_-_Escalavrado_e_Dedo_De_Deus_ao_fundo_-Teres%C3%B3polis.jpg'
response = requests.get(image_url)
response.raise_for_status()  # fail early if the download did not succeed
img = Image.open(BytesIO(response.content))
img = resize_with_padding(img, (512, 512))

# Load the salient object (foreground) detection model
remover = Remover(mode='base')

# Get foreground mask
fg_mask = remover.process(img, type='map')  # saliency map: object bright, background dark
```
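
To make the padding arithmetic in `resize_with_padding` concrete, the following dependency-free sketch reproduces the size computation without Pillow. The helper name `padded_layout` is hypothetical, and it approximates `Image.thumbnail` with plain rounding, so the result may differ from Pillow by a pixel for some aspect ratios.

```python
def padded_layout(size, expected_size):
    """Return (scaled_size, padding) for fitting `size` inside `expected_size`.

    Hypothetical helper that mirrors resize_with_padding; padding is
    (left, top, right, bottom), the order ImageOps.expand expects.
    """
    width, height = size
    target_w, target_h = expected_size
    # Shrink only (thumbnail never enlarges), preserving the aspect ratio.
    scale = min(target_w / width, target_h / height, 1.0)
    new_w = max(1, round(width * scale))
    new_h = max(1, round(height * scale))
    delta_w = target_w - new_w
    delta_h = target_h - new_h
    pad_w = delta_w // 2
    pad_h = delta_h // 2
    # Any odd leftover pixel goes to the right/bottom edge.
    return (new_w, new_h), (pad_w, pad_h, delta_w - pad_w, delta_h - pad_h)


scaled, padding = padded_layout((1024, 512), (512, 512))
print(scaled, padding)  # (512, 256) (0, 128, 0, 128)
```

A 1024x512 image is halved to 512x256 and then padded with 128 pixels above and below, so the object stays centered on the 512x512 canvas.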

### Background generation
```py
import torch

seed = 13
mask = ImageOps.invert(fg_mask)  # white = background to regenerate, black = object to keep
generator = torch.Generator(device='cuda').manual_seed(seed)
prompt = 'A dark swan in a bedroom'
cond_scale = 1.0  # ControlNet conditioning strength
with torch.autocast("cuda"):
    controlnet_image = pipeline(
        prompt=prompt,
        image=img,
        mask_image=mask,
        control_image=mask,
        num_images_per_prompt=1,
        generator=generator,
        num_inference_steps=20,
        guess_mode=False,
        controlnet_conditioning_scale=cond_scale,
    ).images[0]
controlnet_image  # displays the result when run in a notebook
```
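
One rough way to check a generated image for object expansion is to rerun the saliency model on the output and compare the two foreground masks. The sketch below is only an illustration under that assumption: `expansion_ratio` is a hypothetical helper, not the metric defined in the paper, and the masks are plain nested lists of 0/1 to keep the example dependency-free.

```python
def expansion_ratio(original_mask, generated_mask):
    """Fraction by which the foreground grew outside its original area.

    Both masks are equally sized 2D lists with 1 for object pixels and 0 for
    background. Hypothetical helper for illustration only.
    """
    if len(original_mask) != len(generated_mask):
        raise ValueError("masks must have the same number of rows")
    original_area = 0
    new_pixels = 0
    for orig_row, gen_row in zip(original_mask, generated_mask):
        if len(orig_row) != len(gen_row):
            raise ValueError("mask rows must have the same length")
        for orig_px, gen_px in zip(orig_row, gen_row):
            original_area += orig_px
            # Count object pixels that appear where there used to be background.
            if gen_px and not orig_px:
                new_pixels += 1
    if original_area == 0:
        raise ValueError("original mask contains no object pixels")
    return new_pixels / original_area


original = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
generated = [
    [0, 0, 0, 0],
    [0, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 0, 0, 0],
]
print(expansion_ratio(original, generated))  # 0.5
```

With real data, the two masks could come from thresholding `remover.process(img, type='map')` and `remover.process(controlnet_image, type='map')`; the threshold is a choice you would need to tune.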

## Citations

If you found our work useful, please consider citing our paper:

```bibtex
@misc{eshratifar2024salient,
      title={Salient Object-Aware Background Generation using Text-Guided Diffusion Models}, 
      author={Amir Erfan Eshratifar and Joao V. B. Soares and Kapil Thadani and Shaunak Mishra and Mikhail Kuznetsov and Yueh-Ning Ku and Paloma de Juan},
      year={2024},
      eprint={2404.10157},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Maintainers

- Erfan Eshratifar: [email protected]
- Joao Soares: [email protected]

## License

This project is licensed under the terms of the [Apache 2.0](LICENSE) open source license. Please refer to [LICENSE](LICENSE) for the full terms.