metadata
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: >-
  To access CodeGemma on Hugging Face, you’re required to review and agree to
  Google’s usage license. To do this, please ensure you’re logged in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
# CodeGemma Model Card
This repository corresponds to the CodeGemma 7B checkpoint for use with Gemma PyTorch. If you're looking for the `transformers` implementation, or a more detailed model card, visit https://huggingface.co/google/codegemma-7b.
Model Page: CodeGemma
Resources and Technical Documentation:
Terms of Use: Terms
Authors: Google
## Sample Usage
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
# Which checkpoint to run. "2b" selects the 2B config below; anything else
# selects the 7B config. This repository ships the "7b" checkpoint.
VARIANT = "7b"
# Passed to torch.device(), e.g. "cpu" or "cuda".
MACHINE_TYPE = "cpu"
# Local directory holding the checkpoint and tokenizer.model. It is derived
# from VARIANT (as the checkpoint file name is) so the two cannot drift
# apart when VARIANT is changed.
weights_dir = f"codegemma-{VARIANT}-pytorch"
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
# Pick the architecture config that matches VARIANT and point it at the
# tokenizer stored next to the weights.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
# Build and load the model with the config's dtype as the default, so its
# parameters are created in that dtype instead of float32.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'codegemma-{VARIANT}.pt')
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# Fill-in-the-middle (FIM) prompt: the text before <|fim_suffix|> is the prefix,
# the text before <|fim_middle|> is the suffix, and the model is asked for the
# middle. The suffix is kept valid Python (indented body of the __main__ guard).
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
    sys.exit(0)<|fim_middle|>"""

# generate() presumably returns the generated text (TODO confirm against
# gemma.model). Print it so the sample shows output when run as a script,
# not only as the last expression of a notebook cell.
result = model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
)
print(result)