---
license: apache-2.0
language:
- en
pipeline_tag: image-to-text
tags:
- mplug-owl
---

# Usage

## Get the latest codebase from GitHub

```bash
git clone https://github.com/X-PLUG/mPLUG-Owl.git
```

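Install the repository's Python dependencies before running the snippets below. A minimal sketch, assuming the checkout ships a `requirements.txt` (see the repository's README for the authoritative setup steps):

```bash
cd mPLUG-Owl
pip install -r requirements.txt
```
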
## Model initialization

```python
import torch

from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b'
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)
```

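`from_pretrained` loads the weights on CPU by default. For practical inference you will usually want the model on a GPU and in eval mode; a minimal sketch, assuming a single CUDA device is available:

```python
import torch

# Move the model to the GPU if one is available and switch to inference mode.
# Assumption: a single CUDA device; the model stays on CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
```
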
## Model inference

Prepare model inputs.

```python
# We use a human/AI template to organize the context as a multi-turn conversation.
# <image> denotes an image placeholder.
prompts = [
'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <image>
Human: Explain why this meme is funny.
AI: ''']

# Image sources go in image_list, in the same order as the <image> placeholders in the prompts.
# URLs, local file paths, and base64 strings are supported. You can customize image
# pre-processing by modifying mplug_owl.modeling_mplug_owl.ImageProcessor.
image_list = ['https://xxx.com/image.jpg']
```

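The comment above notes that an `image_list` entry may be a URL, a local file path, or a base64 string. The inference snippet below only handles the URL case from the example, so here is an illustrative `load_image` helper (our own sketch, not part of the mPLUG-Owl API) for normalizing all three forms into PIL images yourself:

```python
import base64
import io

import requests
from PIL import Image


def load_image(source: str) -> Image.Image:
    """Illustrative helper: open a URL, local file path, or base64 string as a PIL image."""
    if source.startswith(('http://', 'https://')):
        # Stream the remote image and let PIL read from the raw response.
        return Image.open(requests.get(source, stream=True).raw).convert('RGB')
    try:
        # Local file path.
        return Image.open(source).convert('RGB')
    except (OSError, ValueError):
        # Fall back to interpreting the string as base64-encoded image bytes.
        return Image.open(io.BytesIO(base64.b64decode(source))).convert('RGB')
```
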
Get response.

```python
import requests
from PIL import Image

# Generation kwargs (the same as in transformers) can be passed to model.generate().
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512
}

# The example image_list entry is a URL, so stream it with requests;
# a local file path could be opened with Image.open(path) directly.
images = [Image.open(requests.get(src, stream=True).raw).convert('RGB') for src in image_list]

inputs = processor(text=prompts, images=images, return_tensors='pt')
# Cast float tensors to bfloat16 to match the model weights, then move everything to the model's device.
inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    res = model.generate(**inputs, **generate_kwargs)
sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
print(sentence)
```
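
Because the prompt template is an explicit Human/AI transcript, you can continue the conversation by appending the previous answer and a new question to the prompt and generating again. A minimal sketch (the follow-up question is illustrative):

```python
# Append the previous answer plus a new Human turn, reusing the same images.
prompts = [prompts[0] + sentence + '''
Human: Describe the image in one sentence.
AI: ''']

inputs = processor(text=prompts, images=images, return_tensors='pt')
inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    res = model.generate(**inputs, **generate_kwargs)
print(tokenizer.decode(res.tolist()[0], skip_special_tokens=True))
```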