---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: To access CodeGemma on Hugging Face, you’re required to review
  and agree to Google’s usage license. To do this, please ensure you’re logged-in
  to Hugging Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---
# CodeGemma Model Card
> [!IMPORTANT]
>
> This repository corresponds to the CodeGemma 7B checkpoint for use with [Gemma PyTorch](https://github.com/google/gemma_pytorch). If you're looking for the `transformers` implementation, or a more detailed model card, visit https://huggingface.co/google/codegemma-7b.

**Model Page**: [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)

**Resources and Technical Documentation**:

* [Technical Report](https://goo.gle/codegemma)
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)

**Terms of Use**: [Terms](https://www.kaggle.com/models/google/codegemma/license/consent/verify/huggingface?returnModelRepoId=google/codegemma-7b-pytorch)

**Authors**: Google

# Sample Usage
```python
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

VARIANT = "7b"
MACHINE_TYPE = "cpu"  # or "cuda" if a GPU is available
weights_dir = 'codegemma-7b-pytorch'


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the default even if model construction fails.
        torch.set_default_dtype(torch.float)


# Pick the config that matches the checkpoint and point it at the tokenizer file.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
# Build the model and load the weights under the checkpoint's dtype.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'codegemma-{VARIANT}.pt')
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# Fill-in-the-middle prompt: the model generates the code that belongs
# between the prefix ("import ") and the suffix (the __main__ guard).
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
    sys.exit(0)<|fim_middle|>"""

print(model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
))
```
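
The fill-in-the-middle prompt in the sample above follows a fixed layout: the code before the gap, then the code after it, then a marker asking the model for the middle. Below is a minimal sketch of a helper that assembles such a prompt, assuming only the three control tokens that appear in the sample. The name `build_fim_prompt` is hypothetical and is not part of Gemma PyTorch.

```python
# Control tokens copied from the sample prompt above.
FIM_PREFIX = "<|fim_prefix|>"
FIM_SUFFIX = "<|fim_suffix|>"
FIM_MIDDLE = "<|fim_middle|>"


def build_fim_prompt(prefix: str, suffix: str) -> str:
    """Hypothetical helper: place prefix and suffix around the gap to fill."""
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


# Rebuild the prompt from the sample: the gap sits right after "import ".
prompt = build_fim_prompt(
    prefix="import ",
    suffix='if __name__ == "__main__":\n    sys.exit(0)',
)
print(prompt)
```

The resulting string can be passed to `model.generate` in place of `FIM_PROMPT`.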