---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
  - pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: >-
  To access CodeGemma on Hugging Face, you’re required to review and agree to
  Google’s usage license. To do this, please ensure you’re logged in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---

CodeGemma Model Card

This repository contains the CodeGemma 7B checkpoint for use with Gemma PyTorch. If you're looking for the transformers implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-7b.

Model Page: CodeGemma

Resources and Technical Documentation:

Terms of Use: Terms (https://ai.google.dev/gemma/terms)

Authors: Google

Sample Usage

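```python
import contextlib
import os

import torch

from gemma.config import get_config_for_2b, get_config_for_7b
from gemma.model import GemmaForCausalLM

VARIANT = "7b"
MACHINE_TYPE = "cpu"  # e.g. "cuda" to run on a GPU
weights_dir = "codegemma-7b-pytorch"  # local directory with the checkpoint and tokenizer


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype, restoring float32 on exit."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the default even if model construction or loading fails.
        torch.set_default_dtype(torch.float)


# Build the config for the chosen variant and point it at the tokenizer file.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
# Create the model in the config's dtype, load the weights, and switch to eval mode.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f"codegemma-{VARIANT}.pt")
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# Fill-in-the-middle prompt: the model generates the code that belongs
# between the prefix ("import ") and the suffix (the __main__ block).
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
    sys.exit(0)<|fim_middle|>"""

# generate() returns the decoded completion; print it to see the result.
result = model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
)
print(result)
```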
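The FIM_PROMPT above hard-codes its prefix and suffix. When infilling arbitrary code, it can help to build the prompt from the two halves and splice the completion back into place. The sketch below does this with two hypothetical helpers, build_fim_prompt and splice_completion, which are not part of Gemma PyTorch. It assumes that model.generate returns only the generated continuation, and that the continuation may run past the infill into one of the FIM markers or a <|file_separator|> marker. For that reason it truncates at the first marker it finds. Check both assumptions against your own outputs.

```python
# Special tokens taken from the FIM_PROMPT in the sample above.
FIM_PREFIX = "<|fim_prefix|>"
FIM_SUFFIX = "<|fim_suffix|>"
FIM_MIDDLE = "<|fim_middle|>"
# Assumption: generation may continue past the infill and emit one of these
# markers; everything from the first marker onward is discarded.
STOP_MARKERS = (FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE, "<|file_separator|>")


def build_fim_prompt(prefix: str, suffix: str) -> str:
    """Return a prefix/suffix/middle prompt laid out like FIM_PROMPT."""
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


def splice_completion(prefix: str, suffix: str, generated: str) -> str:
    """Place the generated middle between prefix and suffix.

    The completion is cut at the earliest stop marker, if any is present.
    """
    cut = len(generated)
    for marker in STOP_MARKERS:
        index = generated.find(marker)
        if index != -1:
            cut = min(cut, index)
    return prefix + generated[:cut] + suffix


if __name__ == "__main__":
    prefix = "import "
    suffix = 'if __name__ == "__main__":\n    sys.exit(0)'
    prompt = build_fim_prompt(prefix, suffix)
    # Hypothetical model output, used only to exercise splice_completion.
    fake_output = "sys\n\n<|file_separator|>"
    print(splice_completion(prefix, suffix, fake_output))
```

With the model loaded as in the sample, you would pass build_fim_prompt(prefix, suffix) to model.generate and hand the returned string to splice_completion along with the same prefix and suffix.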