---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: To access CodeGemma on Hugging Face, you’re required to review and agree to Google’s usage license. To do this, please ensure you’re logged-in to Hugging Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---

# CodeGemma Model Card

> [!IMPORTANT]
>
> This repository corresponds to the CodeGemma 7B checkpoint for use with [Gemma PyTorch](https://github.com/google/gemma_pytorch). If you're looking for the `transformers` implementation, or more detailed model card, visit https://huggingface.co/google/codegemma-7b.

**Model Page**: [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)

**Resources and Technical Documentation**:

* [Technical Report](https://goo.gle/codegemma)
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)

**Terms of Use**: [Terms](https://www.kaggle.com/models/google/codegemma/license/consent/verify/huggingface?returnModelRepoId=google/codegemma-7b-pytorch)

**Authors**: Google

# Sample Usage

```python
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

VARIANT = "7b"
MACHINE_TYPE = "cpu"
weights_dir = 'codegemma-7b-pytorch'

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'codegemma-{VARIANT}.pt')
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__": sys.exit(0)<|fim_middle|>"""

model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
)
```
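
The fill-in-the-middle prompt in the sample is written by hand as a prefix, a suffix, and a trailing middle marker. A minimal sketch of building the same string programmatically is shown below. The helper `build_fim_prompt` and the three constant names are hypothetical and are not part of Gemma PyTorch; only the three control-token strings come from the sample above.

```python
# Control tokens copied from the sample prompt above.
# The constant names are illustrative, not part of any library.
FIM_PREFIX = "<|fim_prefix|>"
FIM_SUFFIX = "<|fim_suffix|>"
FIM_MIDDLE = "<|fim_middle|>"


def build_fim_prompt(prefix: str, suffix: str) -> str:
    """Hypothetical helper: join the code before and after the gap.

    The layout mirrors the sample prompt: prefix marker, code before the
    gap, suffix marker, code after the gap, then the middle marker that
    asks the model to generate the missing piece.
    """
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


prompt = build_fim_prompt(
    prefix="import ",
    suffix='if __name__ == "__main__": sys.exit(0)',
)
print(prompt)
```

The resulting `prompt` can be passed to `model.generate` in place of `FIM_PROMPT` in the sample above.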