BAAI
/

ryanzhangfan commited on
Commit
6f85fd4
·
verified ·
1 Parent(s): 41fdfca

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +86 -3
README.md CHANGED
@@ -1,3 +1,86 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ #### Quickstart
6
+
7
+ ```python
8
+ from PIL import Image
9
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
10
+ from transformers.generation.configuration_utils import GenerationConfig
11
+ from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
12
+ import torch
13
+
14
+ import sys
15
+ sys.path.append(PATH_TO_BAAI_Emu3-Gen_MODEL)
16
+ from processing_emu3 import Emu3Processor
17
+
18
+ EMU_HUB = "BAAI/Emu3-Gen"
19
+ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
20
+
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ EMU_HUB,
23
+ device_map="cuda:0",
24
+ torch_dtype=torch.bfloat16,
25
+ attn_implementation="flash_attention_2",
26
+ trust_remote_code=True,
27
+ )
28
+
29
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
30
+ image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
31
+ image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
32
+ processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
33
+
34
+ POSITIVE_PROMPT = " masterpiece, film grained, best quality."
35
+ NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
36
+
37
+ classifier_free_guidance = 3.0
38
+ prompt = "a portrait of young girl."
39
+ prompt += POSITIVE_PROMPT
40
+
41
+ kwargs = dict(
42
+ mode='G',
43
+ ratio="1:1",
44
+ image_area=model.config.image_area,
45
+ return_tensors="pt",
46
+ )
47
+ pos_inputs = processor(text=prompt, **kwargs)
48
+ neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
49
+
50
+ GENERATION_CONFIG = GenerationConfig(
51
+ use_cache=True,
52
+ eos_token_id=model.config.eos_token_id,
53
+ pad_token_id=model.config.pad_token_id,
54
+ max_new_tokens=40960,
55
+ do_sample=True,
56
+ top_k=2048,
57
+ )
58
+
59
+ h, w = pos_inputs.image_size[0]
60
+ constrained_fn = processor.build_prefix_constrained_fn(h, w)
61
+ logits_processor = LogitsProcessorList([
62
+ UnbatchedClassifierFreeGuidanceLogitsProcessor(
63
+ classifier_free_guidance,
64
+ model,
65
+ unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
66
+ ),
67
+ PrefixConstrainedLogitsProcessor(
68
+ constrained_fn ,
69
+ num_beams=1,
70
+ ),
71
+ ])
72
+
73
+ outputs = model.generate(
74
+ pos_inputs.input_ids.to("cuda:0"),
75
+ GENERATION_CONFIG,
76
+ logits_processor=logits_processor
77
+ )
78
+
79
+ mm_list = processor.decode(outputs[0])
80
+ print(mm_list)
81
+ for idx, im in enumerate(mm_list):
82
+ if not isinstance(im, Image.Image):
83
+ continue
84
+ im.save(f"result_{idx}.png")
85
+
86
+ ```