Cerule - A Tiny Mighty Vision Model
Based on Google's Gemma-2b + SigLIP
We train and release "Cerule", a tiny yet powerful Vision Language Model built on Google's newly released Gemma-2b and SigLIP.
We utilise highly efficient data selection techniques for:
- Pretraining stage: 650K images (a LAION-2M subset)
- Finetuning stage: 695K images (SVIT-mix-665K, modified for finetuning; dataset coming soon!)
Training ran on 4x A100 (80GB) GPUs and took ~6 hours to pretrain and ~13 hours to finetune. We modify and adapt the training code from LLaVA.
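Since the training code is adapted from LLaVA, the architecture follows the familiar LLaVA recipe: the SigLIP vision tower produces patch features, and a small projector maps them into Gemma-2b's embedding space. The sketch below is a hypothetical illustration of such a connector; the class name and dimensions (vision_dim=1152 for SigLIP-SO400M, hidden_dim=2048 for Gemma-2b) are assumptions for illustration, not Cerule's actual module names.

import torch
import torch.nn as nn

class VisionLanguageConnector(nn.Module):
    """Hypothetical LLaVA-style projector: SigLIP features -> Gemma-2b embeddings."""
    def __init__(self, vision_dim: int = 1152, hidden_dim: int = 2048):
        super().__init__()
        # Two-layer MLP projector, as popularized by LLaVA-1.5 (assumed, not confirmed).
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # image_features: (batch, num_patches, vision_dim) from the SigLIP tower
        # returns:        (batch, num_patches, hidden_dim) pseudo-tokens for Gemma-2b
        return self.proj(image_features)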
🚨 Training code, data, and more details to be released soon!
Training and Inference:
We will release the training code soon.
Inference:
Please note that running the inference code at this stage may result in errors; the proper training and inference code will be released soon. Before running the snippet, install the following dependencies:
pip install torch transformers accelerate pillow
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings

# Silence library chatter so only the model output is printed.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

torch.set_default_device('cuda')  # or 'cpu'

model = AutoModelForCausalLM.from_pretrained(
    'Tensoic/Cerule',
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    'Tensoic/Cerule',
    trust_remote_code=True)

# Text prompt
prompt = 'Who are these characters?'
text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"

# Tokenize the text on either side of the <image> placeholder, then splice in
# -200, the LLaVA-style image token index the model replaces with image features.
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)

# Load and preprocess the image with the model's own image processor.
image = Image.open('examples/mario.png')
image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)

# Generate
output_ids = model.generate(
    input_ids,
    images=image_tensor,
    max_new_tokens=100,
    use_cache=False)[0]  # keep use_cache=False, otherwise it may hit a torch dim error

# Decode only the newly generated tokens (everything after the prompt).
print(tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=False).strip())
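For repeated queries, the snippet above can be wrapped into a small helper. ask_about_image below is a hypothetical convenience function for illustration only; it reuses the model and tokenizer loaded above and is not part of the released code.

def ask_about_image(image_path: str, prompt: str) -> str:
    # Hypothetical wrapper around the steps above; assumes `model`, `tokenizer`,
    # `torch`, and `Image` from the snippet are already in scope.
    text = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's "
        f"questions. USER: <image>\n{prompt} ASSISTANT:"
    )
    chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(chunks[0] + [-200] + chunks[1], dtype=torch.long).unsqueeze(0)
    image_tensor = model.process_images([Image.open(image_path)], model.config).to(dtype=model.dtype)
    output_ids = model.generate(input_ids, images=image_tensor,
                                max_new_tokens=100, use_cache=False)[0]
    return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=False).strip()

# Example:
# print(ask_about_image('examples/mario.png', 'Who are these characters?'))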
License
Apache 2.0 (to be confirmed).